|
import difflib
import html
import os
import time

import gradio as gr
from dotenv import load_dotenv
from openai import OpenAI

from prompts import USER_PROMPT, CALL_1_SYSTEM_PROMPT, CALL_2_SYSTEM_PROMPT, CALL_3_SYSTEM_PROMPT
|
|
|
# Pull variables from a local .env file (if present) into the process environment
# before the API key is read below.
load_dotenv()

# The OpenAI SDK client is reused, pointed at the Upstage base URL.
BASE_URL = "https://api.upstage.ai/v1"
API_KEY = os.environ.get("UPSTAGE_API_KEY")  # None when the variable is unset

# Shared client used by every model call in this module.
client = OpenAI(base_url=BASE_URL, api_key=API_KEY)
|
|
|
|
|
def call_solar_pro2(system, user, temperature=0.0, model_name="solar-pro2"):
    """Send one system + user exchange to a Solar Pro 2 chat model.

    Args:
        system: Content of the system message.
        user: Content of the user message.
        temperature: Sampling temperature forwarded to the API.
        model_name: Model identifier forwarded to the API.

    Returns:
        The ``message.content`` of the first choice in the (non-streamed)
        chat-completion response.
    """
    conversation = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    completion = client.chat.completions.create(
        model=model_name,
        messages=conversation,
        stream=False,
        temperature=temperature,
        # Always request the highest reasoning effort for these calls.
        reasoning_effort="high",
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
|
|
|
|
|
def call_proofread(paragraph):
    """Ask the fine-tuned news-correction model to proofread one paragraph.

    Args:
        paragraph: The source text to proofread; sent as the user message.

    Returns:
        The ``message.content`` of the first choice in the (non-streamed)
        chat-completion response.
    """
    # English: "Please write the proofreading result for the given document."
    # Keep the wording exact: the fine-tuned model is presumably conditioned on
    # this system prompt (NOTE(review): confirm against the training data).
    # The literal was previously mis-encoded and split across two lines, which
    # made the module unparsable.
    prompt = "입력된 문서에 대한 교열 결과를 작성해 주세요."
    response = client.chat.completions.create(
        model="ft:solar-news-correction",
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": paragraph},
        ],
        stream=False,
        # Lowest sampling temperature, to keep repeated corrections consistent.
        temperature=0.0,
    )
    return response.choices[0].message.content
|
|
|
|
|
def highlight_diff(original, corrected):
    """Render a character-level diff of two strings as an inline HTML fragment.

    - Unchanged text is wrapped in a plain ``<span>``.
    - Text removed from ``original`` gets a red, struck-through span.
    - Text added in ``corrected`` gets a green span.
    - A replacement is rendered as the removed span followed by the added span.

    Both inputs are HTML-escaped before being embedded. In this app they are
    user input and model output, and the result is shown through ``gr.HTML``.
    Without escaping, markup in the text would be interpreted (and could
    inject scripts).

    Args:
        original: The text before correction.
        corrected: The text after correction.

    Returns:
        The concatenated HTML string; empty when both inputs are empty.
    """
    removed_style = "background:#ffecec;text-decoration:line-through;"
    added_style = "background:#e6ffec;"
    # autojunk=False: for sequences of 200+ items the default heuristic treats
    # frequent characters (spaces, common syllables) as junk. That yields coarse,
    # misleading diffs for full paragraphs.
    matcher = difflib.SequenceMatcher(None, original, corrected, autojunk=False)

    parts = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        old = html.escape(original[i1:i2])
        new = html.escape(corrected[j1:j2])
        if tag == 'equal':
            parts.append(f'<span>{old}</span>')
            continue
        # 'replace' emits both spans: removed first, then added.
        if tag in ('replace', 'delete'):
            parts.append(f'<span style="{removed_style}">{old}</span>')
        if tag in ('replace', 'insert'):
            parts.append(f'<span style="{added_style}">{new}</span>')
    return ''.join(parts)
|
|
|
|
|
def process(paragraph):
    """Run one paragraph through the full correction pipeline.

    Steps:
      1. ``call_proofread`` gets a first-pass correction from the fine-tuned model.
      2. The original paragraph and that correction are formatted into
         ``USER_PROMPT``.
      3. That prompt goes through three chained ``call_solar_pro2`` calls. Each
         call uses its own system prompt and consumes the previous call's output.

    Args:
        paragraph: A single paragraph of source text.

    Returns:
        A ``(final_text, diff_html)`` tuple:
        - ``final_text`` is the output of the third call.
        - ``diff_html`` is ``highlight_diff(paragraph, final_text)``.
    """
    # The previous version recorded start/elapsed times but never used them;
    # that dead timing code has been removed.
    proofread_result = call_proofread(paragraph)
    user_step1 = USER_PROMPT.format(original=paragraph, proofread=proofread_result)
    step1 = call_solar_pro2(CALL_1_SYSTEM_PROMPT, user_step1)
    step2 = call_solar_pro2(CALL_2_SYSTEM_PROMPT, step1)
    step3 = call_solar_pro2(CALL_3_SYSTEM_PROMPT, step2)
    return step3, highlight_diff(paragraph, step3)
|
|
|
|
|
def demo_fn(input_text):
    """Gradio handler: correct each non-blank line of the input independently.

    The input is split on ``'\\n'``. Each line is stripped, blank lines are
    skipped, and the remaining lines are passed through ``process`` in order.

    Args:
        input_text: Raw text from the input textbox.

    Returns:
        ``(corrected_text, highlighted_html)``:
        - ``corrected_text`` is the corrected paragraphs joined by blank lines.
        - ``highlighted_html`` is their diff fragments joined by ``<br><br>``.
        Both are empty strings when the input has no non-blank lines.
    """
    results = [process(line.strip()) for line in input_text.split('\n') if line.strip()]
    corrected_texts = [corrected for corrected, _ in results]
    highlighted_htmls = [highlighted for _, highlighted in results]
    return '\n\n'.join(corrected_texts), '<br><br>'.join(highlighted_htmls)
|
|
|
|
|
# Gradio UI. The Korean literals were previously mis-encoded and split across
# lines, which left the module unparsable; they are restored here.
with gr.Blocks() as demo:
    # Title: "Proofreading model demo"
    gr.Markdown("# 교열 모델 데모")
    # Label "Source text input"; placeholder "Please enter paragraph by paragraph."
    # (demo_fn treats each non-blank line as one paragraph.)
    input_text = gr.Textbox(label="원문 입력", lines=10, placeholder="문단 단위로 입력해 주세요.")
    # Button: "Proofread"
    btn = gr.Button("교열하기")
    # Label: "Proofreading result"
    output_corrected = gr.Textbox(label="교열 결과", lines=10)
    # Label: "Highlighted changes"; receives the HTML built by highlight_diff.
    output_highlight = gr.HTML(label="수정된 부분 강조")

    # The listener is registered inside the Blocks context.
    btn.click(fn=demo_fn, inputs=input_text, outputs=[output_corrected, output_highlight])

if __name__ == '__main__':
    demo.launch()
|
|
|
|