|
import gradio as gr |
|
import random |
|
import os |
|
from huggingface_hub import InferenceClient |
|
|
|
# Display name shown in the model picker -> Hugging Face Hub model id passed
# to InferenceClient. Each label should name the checkpoint it actually calls.
MODELS = {
    "Zephyr 7B Beta": "HuggingFaceH4/zephyr-7b-beta",
    "DeepSeek Coder V2": "deepseek-ai/DeepSeek-Coder-V2-Instruct",
    "Meta Llama 3.1 8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Meta-Llama 3.1 70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
    "Microsoft": "microsoft/Phi-3-mini-4k-instruct",
    # Previously labelled "Mixtral 8x7B", but this id is Mistral 7B Instruct,
    # not Mixtral 8x7B; the label now matches the model that is called.
    "Mistral 7B Instruct v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
    "Mixtral Nous-Hermes": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
    "Cohere Command R+": "CohereForAI/c4ai-command-r-plus",
    "Aya-23-35B": "CohereForAI/aya-23-35B",
}
|
|
|
def create_client(model_name):
    """Return an ``InferenceClient`` bound to the Hub model id *model_name*.

    The access token is read from the ``HF_TOKEN`` environment variable on
    every call; when that variable is unset, ``None`` is passed as the token.
    """
    hf_token = os.environ.get("HF_TOKEN")
    return InferenceClient(model_name, token=hf_token)
|
|
|
def call_api(model, content, system_message, max_tokens, temperature, top_p):
    """Send a single-turn chat request to the selected model and return its reply.

    Args:
        model: Display name of the model; must be a key of ``MODELS``.
        content: Text sent as the user-role message.
        system_message: Text sent as the system-role message.
        max_tokens: Upper bound on generated tokens; coerced to ``int``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The message text of the first choice, or ``""`` when the response
        carries no text content.

    Raises:
        KeyError: If ``model`` is not a key of ``MODELS``.
        Errors raised by the inference client are not caught here.
    """
    client = create_client(MODELS[model])
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": content},
    ]
    # A new random seed per request, so repeated requests are not pinned to
    # one sampled output.
    random_seed = random.randint(0, 1000000)
    response = client.chat_completion(
        messages=messages,
        # Gradio sliders may hand over floats; the API expects an integer count.
        max_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
        seed=random_seed,
    )
    # ``message.content`` is typed Optional in huggingface_hub; normalise to "".
    return response.choices[0].message.content or ""
|
|
|
def generate_copywriting(model, topic, system_message, max_tokens, temperature, top_p):
    """Ask the selected model for ten Korean marketing-copy lines about *topic*.

    The user-editable writing rules (``system_message``) are sent both as the
    system prompt and inline in the user prompt.

    Args:
        model: Display name of the model; a key of ``MODELS``.
        topic: Subject of the copy (e.g. a sale or product launch).
        system_message: Writing rules to follow.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The model's reply text. When ``topic`` is empty or whitespace-only,
        returns a Korean prompt asking for a topic, and no API call is made.
    """
    if not (topic or "").strip():
        # "Please enter a copywriting topic." -- avoids a pointless API call.
        return "카피라이팅 주제를 입력해 주세요."

    # Prompt (Korean): "Write copywriting about '{topic}', strictly in Korean.
    # Referring to {system_message}, write 10 pieces of copy."
    content = (
        f"'{topic}'에 대한 카피라이팅을 작성하되, 반드시 한국어로 작성하세요. "
        f"{system_message}를 참고하여 10개의 카피를 작성하세요."
    )
    return call_api(model, content, system_message, max_tokens, temperature, top_p)
|
|
|
# Page heading rendered at the top of the UI:
# "SNS marketing copywriting generator (Korean only)".
title = "SNS 마케팅 카피라이팅 생성기 (한국어 전용)"
|
|
|
# Gradio UI: model picker, topic input, editable writing rules, collapsible
# sampling settings, and an output box wired to generate_copywriting.
with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")

    # "Select language model"
    model = gr.Radio(
        choices=list(MODELS.keys()),
        label="언어 모델 선택",
        value="Zephyr 7B Beta",
    )
    # "Enter copywriting topic"; placeholder: "e.g. summer sale, new product launch"
    topic = gr.Textbox(
        label="카피라이팅 주제 입력",
        lines=1,
        placeholder="예: 여름 세일, 신제품 출시",
    )
    # "Copywriting rules". The default rules ask for attention-grabbing phrasing,
    # optional simple metaphors/examples, a clear and persuasive CTA, and
    # wording that stresses urgency or benefits.
    system_message = gr.Textbox(
        label="카피라이팅 작성 규칙",
        lines=10,
        value=(
            "1. 관심을 끌 수 있는 강력한 문구로 작성하라.\n"
            "2. 필요한 경우 간단한 비유나 예제를 사용하여 문구를 보강하라.\n"
            "3. 명확하고 설득력 있는 CTA(Call to Action)를 작성하라.\n"
            "4. 긴급성이나 혜택을 강조하는 문구를 사용하라.\n"
        ),
    )

    # "Advanced settings"
    with gr.Accordion("고급 설정", open=False):
        # Minimum equals the step: a budget of 0 would ask the model for no output.
        max_tokens = gr.Slider(label="Max Tokens", minimum=100, maximum=4000, value=500, step=100)
        temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.75, step=0.05)
        top_p = gr.Slider(label="Top P", minimum=0.1, maximum=1.0, value=0.95, step=0.05)

    # "Generate copywriting" / "Generated copywriting"
    generate_btn = gr.Button("카피라이팅 생성하기")
    output = gr.Textbox(label="생성된 카피라이팅", lines=10)

    # Input order must match generate_copywriting's positional parameters.
    generate_btn.click(
        fn=generate_copywriting,
        inputs=[model, topic, system_message, max_tokens, temperature, top_p],
        outputs=[output],
    )
|
|
|
# Start the Gradio server only when executed as a script (e.g. `python app.py`),
# so importing this module does not block on launch().
if __name__ == "__main__":
    demo.launch()