import difflib
import html
import os
import time

import gradio as gr
from dotenv import load_dotenv
from openai import OpenAI

from prompts import USER_PROMPT, CALL_1_SYSTEM_PROMPT, CALL_2_SYSTEM_PROMPT, CALL_3_SYSTEM_PROMPT
# Pull variables from a local .env file (when one exists) into the process
# environment, so the API key lookup below can see them.
load_dotenv()

# The stock OpenAI client is reused, pointed at the Upstage endpoint.
BASE_URL = "https://api.upstage.ai/v1"
# NOTE(review): this is None when UPSTAGE_API_KEY is unset; confirm how the
# OpenAI client handles a None key (it may fall back to other env vars).
API_KEY = os.environ.get("UPSTAGE_API_KEY")

client = OpenAI(base_url=BASE_URL, api_key=API_KEY)
def call_solar_pro2(system, user, temperature=0.0, model_name="solar-pro2",
                    reasoning_effort="high"):
    """Send a single system + user chat turn and return the reply text.

    Args:
        system: System prompt.
        user: User message.
        temperature: Sampling temperature passed to the API.
        model_name: Model identifier for the chat completions call.
        reasoning_effort: Value forwarded as the ``reasoning_effort`` request
            field; defaults to ``"high"``.

    Returns:
        The ``content`` of the first choice's message.
    """
    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        stream=False,
        temperature=temperature,
        reasoning_effort=reasoning_effort,
    )
    return response.choices[0].message.content
def call_proofread(paragraph, model_name="ft:solar-news-correction"):
    """Run the proofreading model on *paragraph* and return its output text.

    Args:
        paragraph: Text to proofread; sent verbatim as the user message.
        model_name: Model identifier for the chat completions call. Defaults
            to ``"ft:solar-news-correction"``. The parameter mirrors the
            ``model_name`` argument of ``call_solar_pro2``.

    Returns:
        The ``content`` of the first choice's message.
    """
    # Fixed system instruction ("Please generate the proofreading result for
    # the input document.").
    prompt = "입력된 문서에 대한 교열 결과를 생성해 주세요."
    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": paragraph},
        ],
        stream=False,
        temperature=0.0,
    )
    return response.choices[0].message.content
def highlight_diff(original, corrected):
    """Return an HTML fragment showing how *corrected* differs from *original*.

    The diff is character-level (``difflib.SequenceMatcher``). Unchanged text
    is emitted as-is. Removed text is wrapped in ``<del>`` and inserted text
    in ``<ins>``. A replacement emits both, with the removed part first.

    All text is HTML-escaped, because the fragment is rendered as HTML
    (``gr.HTML``) and contains user- and model-supplied text. Newlines become
    ``<br>`` so that line structure survives rendering.

    Args:
        original: Text before correction.
        corrected: Text after correction.

    Returns:
        The HTML fragment as a string.
    """
    def esc(text):
        return html.escape(text).replace("\n", "<br>")

    def removed(text):
        return f'<del style="color:#c0392b;background:#fdecea;">{esc(text)}</del>'

    def added(text):
        return f'<ins style="color:#1e8449;background:#e9f7ef;">{esc(text)}</ins>'

    # autojunk=False: with the default heuristic, items that repeat very often
    # in a second sequence of 200+ characters (spaces, common syllables) are
    # treated as junk, which yields noisy diffs on long paragraphs.
    matcher = difflib.SequenceMatcher(None, original, corrected, autojunk=False)
    result_html = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            result_html.append(esc(original[i1:i2]))
        elif tag == "replace":
            result_html.append(removed(original[i1:i2]))
            result_html.append(added(corrected[j1:j2]))
        elif tag == "delete":
            result_html.append(removed(original[i1:i2]))
        elif tag == "insert":
            result_html.append(added(corrected[j1:j2]))
    return "".join(result_html)
def process(paragraph):
    """Proofread one paragraph and return the corrected text plus an HTML diff.

    Pipeline:
      1. ``call_proofread`` produces a first-pass correction.
      2. ``USER_PROMPT`` is filled with the original paragraph and that first
         pass.
      3. The filled prompt is refined through three chained
         ``call_solar_pro2`` calls (CALL_1 -> CALL_2 -> CALL_3); each call
         receives the previous call's output as its user message.

    Args:
        paragraph: One paragraph of source text (``demo_fn`` passes stripped,
            non-blank lines).

    Returns:
        ``(corrected_text, highlighted_html)``, where the HTML is
        ``highlight_diff(paragraph, corrected_text)``.
    """
    proofread_result = call_proofread(paragraph)
    user_step1 = USER_PROMPT.format(original=paragraph, proofread=proofread_result)
    step1 = call_solar_pro2(CALL_1_SYSTEM_PROMPT, user_step1)
    step2 = call_solar_pro2(CALL_2_SYSTEM_PROMPT, step1)
    step3 = call_solar_pro2(CALL_3_SYSTEM_PROMPT, step2)
    return step3, highlight_diff(paragraph, step3)
def demo_fn(input_text):
    """Gradio callback: proofread every paragraph of *input_text*.

    Each non-blank line, after stripping, is treated as one paragraph and
    passed to ``process`` independently.

    Args:
        input_text: Raw textbox contents. ``None`` is treated as empty.

    Returns:
        ``(corrected_text, highlighted_html)``:
        - the corrected paragraphs joined by a blank line, for the plain-text
          output;
        - the per-paragraph diff fragments joined by ``<br><br>``, for the
          HTML output.
        Both strings are empty when there are no non-blank lines.
    """
    paragraphs = [p.strip() for p in (input_text or "").split("\n") if p.strip()]
    corrected_all = []
    highlighted_all = []
    for para in paragraphs:
        corrected, highlighted = process(para)
        corrected_all.append(corrected)
        highlighted_all.append(highlighted)
    # The second output is rendered as HTML, where a bare newline collapses
    # into a space, so paragraphs are separated with explicit line breaks.
    return "\n\n".join(corrected_all), "<br><br>".join(highlighted_all)
# Gradio UI: one input textbox, a button wired to demo_fn, and two outputs
# (the corrected plain text and an HTML view of the changes).
with gr.Blocks() as demo:
    gr.Markdown("# 교열 모델 데모")  # "Proofreading model demo"
    input_text = gr.Textbox(
        label="원문 입력",  # "Original text input"
        lines=10,
        placeholder="문단 단위로 입력해 주세요.",  # "Please enter text paragraph by paragraph."
    )
    btn = gr.Button("교열하기")  # "Proofread"
    output_corrected = gr.Textbox(label="교열 결과", lines=10)  # "Proofreading result"
    output_highlight = gr.HTML(label="수정된 부분 강조")  # "Highlighted changes"
    btn.click(
        fn=demo_fn,
        inputs=input_text,
        outputs=[output_corrected, output_highlight],
    )

# Start the web UI only when this file is executed as a script.
if __name__ == '__main__':
    demo.launch()