# app.py — fantaxy (commit e5ec0df)
# -*- coding: utf-8 -*-
import gradio as gr
from huggingface_hub import InferenceClient
from gradio_client import Client
import os
import requests
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
# Logging configuration: DEBUG level with timestamped, leveled messages.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
# API setup: text generation via the HF Inference API (token from HF_TOKEN env var),
# image generation via a remote Gradio app at a fixed endpoint.
hf_client = InferenceClient("CohereForAI/c4ai-command-r-plus-08-2024", token=os.getenv("HF_TOKEN"))
# NOTE(review): hard-coded IP endpoint for the image service — confirm it is still reachable.
IMAGE_API_URL = "http://211.233.58.201:7896"
def generate_image(prompt: str) -> tuple:
    """Request a fantasy-styled image from the remote Gradio service.

    The prompt is prefixed with "fantasy style," before submission.

    Returns:
        (image, seed) from the remote API on success, or
        (None, error_message) if any step fails.
    """
    try:
        api = Client(IMAGE_API_URL)
        # Steer the generator toward a fantasy aesthetic.
        styled_prompt = f"fantasy style, {prompt}"
        outcome = api.predict(
            prompt=styled_prompt,
            width=768,
            height=768,
            guidance=7.5,
            inference_steps=30,
            seed=3,
            do_img2img=False,
            init_image=None,
            image2image_strength=0.8,
            resize_img=True,
            api_name="/generate_image",
        )
        return outcome[0], outcome[1]
    except Exception as e:
        logging.error(f"Image generation failed: {str(e)}")
        return None, f"Error: {str(e)}"
def respond(
    message,
    history: list[tuple[str, str]],
    system_message="",
    max_tokens=7860,
    temperature=0.8,
    top_p=0.9,
):
    """Stream a chat completion, then attach a generated illustration.

    Yields ``(partial_text, None)`` while tokens stream in, and finally
    ``(full_text, image)`` once the text is complete; the first 200
    characters of the reply are reused as the image prompt.

    Args:
        message: The latest user message.
        history: Prior ``(user, assistant)`` turn pairs.
        system_message: Extra system-prompt text appended to the prefix.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
    """
    system_prefix = """
[μ‹œμŠ€ν…œ ν”„λ‘¬ν”„νŠΈ λ‚΄μš©...]
"""
    messages = [{"role": "system", "content": f"{system_prefix} {system_message}"}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    response = ""
    try:
        # FIX: the stream variable previously shadowed the `message` parameter;
        # renamed to `chunk`.  Also removed `token.strip("")`, which is a no-op
        # (strip with an empty char-set removes nothing).
        for chunk in hf_client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content
            if token is not None:
                response += token
            yield response, None  # no image while text is still streaming
        # Text finished — generate an illustration from the start of the reply.
        image, seed = generate_image(response[:200])
        yield response, image
    except Exception as e:
        yield f"Error: {str(e)}", None
# Gradio μΈν„°νŽ˜μ΄μŠ€ μ„€μ •
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as interface:
gr.Markdown("# Fantasy Novel AI Generation")
with gr.Row():
with gr.Column(scale=2):
chatbot = gr.Chatbot()
msg = gr.Textbox(label="Enter your message")
system_msg = gr.Textbox(label="System Message", value="Write(output) in ν•œκ΅­μ–΄.")
with gr.Row():
max_tokens = gr.Slider(minimum=1, maximum=8000, value=7000, label="Max Tokens")
temperature = gr.Slider(minimum=0, maximum=1, value=0.7, label="Temperature")
top_p = gr.Slider(minimum=0, maximum=1, value=0.9, label="Top P")
with gr.Column(scale=1):
image_output = gr.Image(label="Generated Image")
examples = gr.Examples(
examples=[
["νŒνƒ€μ§€ μ†Œμ„€μ˜ ν₯미둜운 μ†Œμž¬ 10가지λ₯Ό μ œμ‹œν•˜λΌ"],
["계속 μ΄μ–΄μ„œ μž‘μ„±ν•˜λΌ"],
["Translate into English"],
["λ§ˆλ²• μ‹œμŠ€ν…œμ— λŒ€ν•΄ 더 μžμ„Ένžˆ μ„€λͺ…ν•˜λΌ"],
["μ „νˆ¬ μž₯면을 더 극적으둜 λ¬˜μ‚¬ν•˜λΌ"],
["μƒˆλ‘œμš΄ νŒνƒ€μ§€ 쒅쑱을 μΆ”κ°€ν•˜λΌ"],
["κ³ λŒ€ μ˜ˆμ–Έμ— λŒ€ν•΄ 더 μžμ„Ένžˆ μ„€λͺ…ν•˜λΌ"],
["주인곡의 λ‚΄λ©΄ λ¬˜μ‚¬λ₯Ό μΆ”κ°€ν•˜λΌ"],
],
inputs=msg
)
msg.submit(
respond,
[msg, chatbot, system_msg, max_tokens, temperature, top_p],
[chatbot, image_output]
)
# μ• ν”Œλ¦¬μΌ€μ΄μ…˜ μ‹€ν–‰
if __name__ == "__main__":
interface.launch(
server_name="0.0.0.0",
server_port=7860,
share=True
)