File size: 3,074 Bytes
744eef2
 
d8a1f2b
 
744eef2
 
 
 
 
 
 
 
 
 
 
 
 
d8a1f2b
 
744eef2
d8a1f2b
 
 
744eef2
d8a1f2b
 
744eef2
d5b7da9
b5d1d70
 
d5b7da9
744eef2
b5d1d70
744eef2
d8a1f2b
 
 
 
d5b7da9
 
 
 
 
 
 
d8a1f2b
 
 
 
 
 
d5b7da9
 
d8a1f2b
d5b7da9
 
d8a1f2b
744eef2
d8a1f2b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import gradio as gr
import random
import os
from huggingface_hub import InferenceClient

# Display label -> Hugging Face Hub model id, offered in the UI radio selector.
# NOTE(review): the "Mixtral 8x7B" label actually points at Mistral-7B-Instruct-v0.3
# (a 7B model, not Mixtral) — confirm whether the label or the id is wrong.
MODELS = {
    "Zephyr 7B Beta": "HuggingFaceH4/zephyr-7b-beta",
    "DeepSeek Coder V2": "deepseek-ai/DeepSeek-Coder-V2-Instruct",
    "Meta Llama 3.1 8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Meta-Llama 3.1 70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
    "Microsoft": "microsoft/Phi-3-mini-4k-instruct",
    "Mixtral 8x7B": "mistralai/Mistral-7B-Instruct-v0.3",
    "Mixtral Nous-Hermes": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
    "Cohere Command R+": "CohereForAI/c4ai-command-r-plus",
    "Aya-23-35B": "CohereForAI/aya-23-35B"
}

def create_client(model_name):
    """Build an InferenceClient for *model_name*, authenticated with the HF_TOKEN env var."""
    hf_token = os.getenv("HF_TOKEN")
    return InferenceClient(model_name, token=hf_token)

def call_api(model, content, system_message, max_tokens, temperature, top_p):
    """Send one chat-completion request to the selected model and return its reply text.

    *model* is a display name looked up in MODELS; a fresh random seed is used
    per call so repeated requests can produce varied output.
    """
    chat = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": content},
    ]
    seed = random.randint(0, 1000000)
    client = create_client(MODELS[model])
    result = client.chat_completion(
        messages=chat,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        seed=seed,
    )
    return result.choices[0].message.content

def generate_copywriting(model, topic, system_message, max_tokens, temperature, top_p):
    """Generate ad copy for *topic*; the prompt forces Korean-only output."""
    # The user prompt restates the rules and explicitly requires Korean,
    # asking the model for 10 copy variants.
    prompt = f"'{topic}'에 대한 카피라이팅을 작성하되, 반드시 한국어로 작성하세요. {system_message}를 참고하여 10개의 카피를 작성하세요."
    return call_api(model, prompt, system_message, max_tokens, temperature, top_p)

title = "SNS ๋งˆ์ผ€ํŒ… ์นดํ”ผ๋ผ์ดํŒ… ์ƒ์„ฑ๊ธฐ (ํ•œ๊ตญ์–ด ์ „์šฉ)"

with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    
    model = gr.Radio(choices=list(MODELS.keys()), label="์–ธ์–ด ๋ชจ๋ธ ์„ ํƒ", value="Zephyr 7B Beta")
    topic = gr.Textbox(label="์นดํ”ผ๋ผ์ดํŒ… ์ฃผ์ œ ์ž…๋ ฅ", lines=1, placeholder="์˜ˆ: ์—ฌ๋ฆ„ ์„ธ์ผ, ์‹ ์ œํ’ˆ ์ถœ์‹œ")
    system_message = gr.Textbox(label="์นดํ”ผ๋ผ์ดํŒ… ์ž‘์„ฑ ๊ทœ์น™", lines=10, value="""
      1. ๊ด€์‹ฌ์„ ๋Œ ์ˆ˜ ์žˆ๋Š” ๊ฐ•๋ ฅํ•œ ๋ฌธ๊ตฌ๋กœ ์ž‘์„ฑํ•˜๋ผ.
      2. ํ•„์š”ํ•œ ๊ฒฝ์šฐ ๊ฐ„๋‹จํ•œ ๋น„์œ ๋‚˜ ์€์œ ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ฌธ๊ตฌ๋ฅผ ๋ณด๊ฐ•ํ•˜๋ผ.
      3. ๋ช…ํ™•ํ•˜๊ณ  ์„ค๋“๋ ฅ ์žˆ๋Š” CTA(Call to Action)๋ฅผ ์ž‘์„ฑํ•˜๋ผ.
      4. ๊ธด๊ธ‰์„ฑ์ด๋‚˜ ํ˜œํƒ์„ ๊ฐ•์กฐํ•˜๋Š” ๋ฌธ๊ตฌ๋ฅผ ์‚ฌ์šฉํ•˜๋ผ.
    """)
    
    with gr.Accordion("๊ณ ๊ธ‰ ์„ค์ •", open=False):
        max_tokens = gr.Slider(label="Max Tokens", minimum=0, maximum=4000, value=500, step=100)
        temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.75, step=0.05)
        top_p = gr.Slider(label="Top P", minimum=0.1, maximum=1.0, value=0.95, step=0.05)
    
    generate_btn = gr.Button("์นดํ”ผ๋ผ์ดํŒ… ์ƒ์„ฑํ•˜๊ธฐ")
    output = gr.Textbox(label="์ƒ์„ฑ๋œ ์นดํ”ผ๋ผ์ดํŒ…", lines=10)
    
    generate_btn.click(fn=generate_copywriting, 
                       inputs=[model, topic, system_message, max_tokens, temperature, top_p], 
                       outputs=[output])

demo.launch()