# -*- coding: utf-8 -*-
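"""Fantasy Novel AI Generation: a Gradio app that streams chat completions
from CohereForAI/c4ai-command-r-plus-08-2024 via the Hugging Face
InferenceClient and, after each reply, renders a matching illustration
through a remote gradio_client image endpoint."""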

import gradio as gr
from huggingface_hub import InferenceClient
from gradio_client import Client
import os
import logging

# Logging configuration
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

# API configuration
hf_client = InferenceClient("CohereForAI/c4ai-command-r-plus-08-2024", token=os.getenv("HF_TOKEN"))
IMAGE_API_URL = "http://211.233.58.201:7896"
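
# NOTE: HF_TOKEN must be present in the environment; without a valid token,
# calls to the hosted model are rejected.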

def generate_image(prompt: str) -> tuple:
    """이미지 생성 ν•¨μˆ˜"""
    try:
        client = Client(IMAGE_API_URL)
        # Prepend "fantasy style," so generated images match the genre
        enhanced_prompt = f"fantasy style, {prompt}"
        result = client.predict(
            prompt=enhanced_prompt,
            width=768,
            height=768,
            guidance=7.5,
            inference_steps=30,
            seed=3,
            do_img2img=False,
            init_image=None,
            image2image_strength=0.8,
            resize_img=True,
            api_name="/generate_image"
        )
        return result[0], result[1]
    except Exception as e:
        logging.error(f"Image generation failed: {str(e)}")
        return None, f"Error: {str(e)}"
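
# Usage sketch (assumes the endpoint at IMAGE_API_URL is reachable and exposes
# /generate_image with the signature used above):
#
#   image_path, seed = generate_image("a dragon circling a ruined tower at dusk")
#   if image_path is None:
#       print(seed)  # on failure, the second element carries the error message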

def respond(
    message,
    history: list[tuple[str, str]],
    system_message="",  
    max_tokens=7860,  
    temperature=0.8, 
    top_p=0.9,  
):
    system_prefix = """
    [System prompt content...]
    """

    messages = [{"role": "system", "content": f"{system_prefix} {system_message}"}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    response = ""
    try:
        for message in hf_client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = message.choices[0].delta.content
            if token is not None:
                response += token.strip("")
            yield response, None  # 이미지λ₯Ό μœ„ν•œ None μΆ”κ°€

        # ν…μŠ€νŠΈ 생성이 μ™„λ£Œλœ ν›„ 이미지 생성
        image, seed = generate_image(response[:200])  # 처음 200자λ₯Ό 이미지 ν”„λ‘¬ν”„νŠΈλ‘œ μ‚¬μš©
        yield response, image

    except Exception as e:
        yield f"Error: {str(e)}", None

# Gradio interface setup
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as interface:
    gr.Markdown("# Fantasy Novel AI Generation")
    
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Enter your message")
            system_msg = gr.Textbox(label="System Message", value="Write (output) in Korean.")
            
            with gr.Row():
                max_tokens = gr.Slider(minimum=1, maximum=8000, value=7000, label="Max Tokens")
                temperature = gr.Slider(minimum=0, maximum=1, value=0.7, label="Temperature")
                top_p = gr.Slider(minimum=0, maximum=1, value=0.9, label="Top P")
        
        with gr.Column(scale=1):
            image_output = gr.Image(label="Generated Image")

    examples = gr.Examples(
        examples=[
            ["νŒνƒ€μ§€ μ†Œμ„€μ˜ ν₯미둜운 μ†Œμž¬ 10가지λ₯Ό μ œμ‹œν•˜λΌ"],
            ["계속 μ΄μ–΄μ„œ μž‘μ„±ν•˜λΌ"],
            ["Translate into English"],
            ["λ§ˆλ²• μ‹œμŠ€ν…œμ— λŒ€ν•΄ 더 μžμ„Ένžˆ μ„€λͺ…ν•˜λΌ"],
            ["μ „νˆ¬ μž₯면을 더 극적으둜 λ¬˜μ‚¬ν•˜λΌ"],
            ["μƒˆλ‘œμš΄ νŒνƒ€μ§€ 쒅쑱을 μΆ”κ°€ν•˜λΌ"],
            ["κ³ λŒ€ μ˜ˆμ–Έμ— λŒ€ν•΄ 더 μžμ„Ένžˆ μ„€λͺ…ν•˜λΌ"],
            ["주인곡의 λ‚΄λ©΄ λ¬˜μ‚¬λ₯Ό μΆ”κ°€ν•˜λΌ"],
        ],
        inputs=msg
    )

    msg.submit(
        respond,
        [msg, chatbot, system_msg, max_tokens, temperature, top_p],
        [chatbot, image_output]
    )

# μ• ν”Œλ¦¬μΌ€μ΄μ…˜ μ‹€ν–‰
if __name__ == "__main__":
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
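
# Launching binds to all interfaces on port 7860; share=True additionally
# requests a temporary public *.gradio.live URL.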