File size: 3,074 Bytes
744eef2 d8a1f2b 744eef2 d8a1f2b 744eef2 d8a1f2b 744eef2 d8a1f2b 744eef2 d5b7da9 b5d1d70 d5b7da9 744eef2 b5d1d70 744eef2 d8a1f2b d5b7da9 d8a1f2b d5b7da9 d8a1f2b d5b7da9 d8a1f2b 744eef2 d8a1f2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import gradio as gr
import random
import os
from huggingface_hub import InferenceClient
# Display name (shown in the UI radio) -> Hugging Face model repo id.
MODELS = {
    "Zephyr 7B Beta": "HuggingFaceH4/zephyr-7b-beta",
    "DeepSeek Coder V2": "deepseek-ai/DeepSeek-Coder-V2-Instruct",
    "Meta Llama 3.1 8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Meta-Llama 3.1 70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
    # Label clarified: bare "Microsoft" did not identify which model is used.
    "Microsoft Phi-3 Mini 4K": "microsoft/Phi-3-mini-4k-instruct",
    # Label fixed: the repo is Mistral-7B-Instruct-v0.3, not Mixtral 8x7B.
    "Mistral 7B Instruct v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
    "Mixtral Nous-Hermes": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
    "Cohere Command R+": "CohereForAI/c4ai-command-r-plus",
    "Aya-23-35B": "CohereForAI/aya-23-35B",
}
def create_client(model_name):
    """Return an InferenceClient for *model_name*, authenticated with HF_TOKEN.

    The token is read from the environment on every call; it may be None if
    HF_TOKEN is unset, in which case the client runs unauthenticated.
    """
    hf_token = os.getenv("HF_TOKEN")
    return InferenceClient(model_name, token=hf_token)
def call_api(model, content, system_message, max_tokens, temperature, top_p):
    """Run one chat completion against the model selected in the UI.

    *model* is a display name keyed into MODELS; *content* is the user prompt
    and *system_message* the system prompt. Sampling is controlled by
    max_tokens / temperature / top_p. A fresh random seed is drawn per call so
    repeated requests produce varied copy. Returns the assistant's text.
    """
    chat = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": content},
    ]
    seed = random.randint(0, 1000000)
    completion = create_client(MODELS[model]).chat_completion(
        messages=chat,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        seed=seed,
    )
    return completion.choices[0].message.content
def generate_copywriting(model, topic, system_message, max_tokens, temperature, top_p):
    """Build the copywriting prompt for *topic* and delegate to call_api.

    NOTE(review): the Korean string literal below is mojibake in this copy of
    the file (UTF-8 Korean decoded under a wrong encoding, which also split
    the literal across lines); it is preserved byte-for-byte untouched.
    """
    # Explicitly instruct the model to always answer in Korean
    # (translated from the original Korean comment).
    content = f"'{topic}'์ ๋ํ ์นดํผ๋ผ์ดํ
์ ์์ฑํ๋, ๋ฐ๋์ ํ๊ตญ์ด๋ก ์์ฑํ์ธ์. {system_message}๋ฅผ ์ฐธ๊ณ ํ์ฌ 10๊ฐ์ ์นดํผ๋ฅผ ์์ฑํ์ธ์."
    return call_api(model, content, system_message, max_tokens, temperature, top_p)
# Page title rendered in the Markdown header below.
# NOTE(review): every Korean literal in this block is mojibake in this copy of
# the file (wrong-encoding decode; some bytes render as line breaks that split
# literals). All runtime strings are preserved byte-for-byte untouched.
title = "SNS ๋ง์ผํ
์นดํผ๋ผ์ดํ
์์ฑ๊ธฐ (ํ๊ตญ์ด ์ ์ฉ)"
# Gradio UI: model picker, topic + system-prompt inputs, sampling settings,
# and a button wired to generate_copywriting.
with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    # Model picker over the display names declared in MODELS.
    model = gr.Radio(choices=list(MODELS.keys()), label="์ธ์ด ๋ชจ๋ธ ์ ํ", value="Zephyr 7B Beta")
    # Single-line topic input with an example placeholder.
    topic = gr.Textbox(label="์นดํผ๋ผ์ดํ
์ฃผ์ ์
๋ ฅ", lines=1, placeholder="์: ์ฌ๋ฆ ์ธ์ผ, ์ ์ ํ ์ถ์")
    # Editable system prompt: the copywriting rules passed as the system message.
    system_message = gr.Textbox(label="์นดํผ๋ผ์ดํ
์์ฑ ๊ท์น", lines=10, value="""
1. ๊ด์ฌ์ ๋ ์ ์๋ ๊ฐ๋ ฅํ ๋ฌธ๊ตฌ๋ก ์์ฑํ๋ผ.
2. ํ์ํ ๊ฒฝ์ฐ ๊ฐ๋จํ ๋น์ ๋ ์์ ๋ฅผ ์ฌ์ฉํ์ฌ ๋ฌธ๊ตฌ๋ฅผ ๋ณด๊ฐํ๋ผ.
3. ๋ช
ํํ๊ณ ์ค๋๋ ฅ ์๋ CTA(Call to Action)๋ฅผ ์์ฑํ๋ผ.
4. ๊ธด๊ธ์ฑ์ด๋ ํํ์ ๊ฐ์กฐํ๋ ๋ฌธ๊ตฌ๋ฅผ ์ฌ์ฉํ๋ผ.
""")
    # Advanced sampling parameters, collapsed by default.
    with gr.Accordion("๊ณ ๊ธ ์ค์ ", open=False):
        max_tokens = gr.Slider(label="Max Tokens", minimum=0, maximum=4000, value=500, step=100)
        temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.75, step=0.05)
        top_p = gr.Slider(label="Top P", minimum=0.1, maximum=1.0, value=0.95, step=0.05)
    generate_btn = gr.Button("์นดํผ๋ผ์ดํ
์์ฑํ๊ธฐ")
    output = gr.Textbox(label="์์ฑ๋ ์นดํผ๋ผ์ดํ
", lines=10)
    # Wire the button to the generation function; output lands in the textbox.
    generate_btn.click(fn=generate_copywriting,
                       inputs=[model, topic, system_message, max_tokens, temperature, top_p],
                       outputs=[output])
# Start the Gradio server (blocking).
demo.launch()