import csv
import difflib
import html
import os
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock

import gradio as gr
from dotenv import load_dotenv
from openai import OpenAI

from prompts import (
    USER_PROMPT,
    WRAPPER_PROMPT,
    CALL_1_SYSTEM_PROMPT,
    CALL_2_SYSTEM_PROMPT,
    CALL_3_SYSTEM_PROMPT,
)
# Copy variables from a local .env file (if present) into os.environ before
# the API key is looked up below.
load_dotenv()

# The Upstage API is reached through the OpenAI SDK by overriding base_url;
# the key is read from the OPENAI_API_KEY environment variable.
BASE_URL = "https://api.upstage.ai/v1"
API_KEY = os.environ.get("OPENAI_API_KEY")
client = OpenAI(base_url=BASE_URL, api_key=API_KEY)
# Load vocabulary for rule-based correction
def load_vocabulary(path="Vocabulary.csv"):
    """Read the original -> corrected replacement table from a CSV file.

    The file needs a header row with ``original`` and ``corrected`` columns;
    any other columns are ignored. ``utf-8-sig`` transparently strips a BOM
    (as written by Excel). When the same ``original`` appears more than once,
    the last row wins.

    Rows are skipped when:
      * ``original`` is empty -- an empty search string would match between
        every character and corrupt all text it is applied to;
      * ``corrected`` is missing (a short row, which ``csv.DictReader`` fills
        with ``None``) -- it cannot be used as a replacement string.

    Args:
        path: CSV file to read. Defaults to ``"Vocabulary.csv"`` relative to
            the current working directory.

    Returns:
        dict[str, str]: Mapping from original form to corrected form. Empty
        when the file is completely empty.

    Raises:
        FileNotFoundError: If *path* does not exist.
        KeyError: If the header lacks ``original`` or ``corrected``.
    """
    vocabulary = {}
    # newline="" is what the csv module expects so quoted newlines parse correctly.
    with open(path, "r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        columns = reader.fieldnames
        if columns is None:  # completely empty file: nothing to load
            return vocabulary
        missing = [name for name in ("original", "corrected") if name not in columns]
        if missing:
            raise KeyError(f"{path}: missing column(s) {missing}; found {columns}")
        for row in reader:
            original = row["original"]
            corrected = row["corrected"]
            if not original or corrected is None:
                continue
            vocabulary[original] = corrected
    return vocabulary
# Rule-based replacement table, read once when the module is imported.
VOCABULARY = load_vocabulary()

# Progress counters shared by the worker threads. processed_count is
# incremented under counter_lock by process_bulk; total_bulks and
# processed_count are reset at the start of each process_text_parallel run.
counter_lock = Lock()
processed_count, total_bulks = 0, 0
def apply_vocabulary_correction(text, vocabulary=None):
    """Replace every vocabulary ``original`` in *text* with its ``corrected`` form.

    The replacement is a single left-to-right pass that always prefers the
    longest matching entry, so:
      * an entry is never shadowed by a shorter entry that is its prefix
        (``{"AB": "X", "ABC": "Y"}`` turns ``"ABC"`` into ``"Y"``, not ``"XC"``);
      * text produced by one rule is never re-scanned by another rule.

    Corrected forms are also treated as "already correct" tokens and left
    untouched. The pipeline applies this function both before and after the
    LLM passes; without this guard, a rule whose corrected form contains its
    original (e.g. ``{"COVID": "COVID-19"}``) would fire twice and produce
    ``"COVID-19-19"``. If a corrected form is also another entry's original,
    the real original entry wins.

    Empty originals are ignored; they would otherwise match everywhere.

    Args:
        text: Input text.
        vocabulary: Mapping original -> corrected. Defaults to the
            module-level ``VOCABULARY``.

    Returns:
        str: The corrected text.
    """
    if vocabulary is None:
        vocabulary = VOCABULARY
    # Protect corrected forms first, then let genuine rules override them.
    table = {fixed: fixed for fixed in vocabulary.values() if fixed}
    table.update(vocabulary)
    # Longest keys first: regex alternation takes the first alternative that
    # matches, so this yields leftmost-longest matching.
    keys = sorted((key for key in table if key), key=len, reverse=True)
    if not keys:
        return text
    pattern = re.compile("|".join(map(re.escape, keys)))
    return pattern.sub(lambda match: table[match.group(0)], text)
def create_bulk_paragraphs(text, max_chars=500):
    """Group the non-blank lines of *text* into bulks of at most *max_chars*.

    The text is split on ``"\\n"``; each line is stripped and blank lines are
    dropped. Consecutive lines are packed greedily into one bulk (joined with
    ``"\\n"``) as long as the joined bulk, newlines included, stays within
    *max_chars*. A single line longer than *max_chars* is never split; it
    becomes a bulk of its own.

    Args:
        text: Input text.
        max_chars: Maximum bulk length in characters (default 500).

    Returns:
        list[str]: The bulks in document order; empty if *text* has no
        non-blank lines.
    """
    stripped = (line.strip() for line in text.split("\n"))
    lines = [line for line in stripped if line]

    chunks = []
    pending = []
    pending_chars = 0  # characters in `pending`, not counting joining newlines

    for line in lines:
        size = len(line)

        if size > max_chars:
            # Flush whatever is pending, then emit the oversized line alone.
            if pending:
                chunks.append("\n".join(pending))
                pending, pending_chars = [], 0
            chunks.append(line)
            continue

        # Joining len(pending) + 1 lines adds len(pending) newline characters.
        would_be = pending_chars + size + len(pending)
        if pending and would_be > max_chars:
            chunks.append("\n".join(pending))
            pending, pending_chars = [line], size
        else:
            pending.append(line)
            pending_chars += size

    if pending:
        chunks.append("\n".join(pending))
    return chunks
def process_bulk(bulk_text, bulk_index, max_retries=3):
    """Run one bulk through the correction pipeline, retrying on any failure.

    Pipeline (re-run from the start on every attempt):
      0. ``apply_vocabulary_correction`` on the input;
      1. the fine-tuned proofreading model (``call_proofread``);
      2. solar-pro2 with ``CALL_1_SYSTEM_PROMPT`` wrapped in ``WRAPPER_PROMPT``,
         given both the vocabulary-corrected text and the proofread result;
      3. solar-pro2 with ``CALL_2_SYSTEM_PROMPT`` on the previous output;
      4. solar-pro2 with ``CALL_3_SYSTEM_PROMPT`` on the previous output;
      5. ``apply_vocabulary_correction`` again on the model output.

    Between attempts the thread sleeps 1 s, 2 s, ... (linear backoff). After
    the last failed attempt the original text is returned unchanged, so one
    bad bulk does not abort the whole document.

    Args:
        bulk_text: Text of one bulk.
        bulk_index: 0-based bulk position, echoed back so callers can restore
            document order after parallel execution.
        max_retries: Total number of attempts. Values below 1 are treated as
            1 so a result dict is always returned (previously such values made
            the function return ``None``).

    Returns:
        dict with ``bulk_index``, ``original``, ``final``, ``processing_time``
        (seconds since the first attempt started, also on failure),
        ``character_count`` and ``attempts``. On final failure ``final`` equals
        ``bulk_text`` and ``error`` holds the last exception message.
    """
    global processed_count
    thread_id = threading.get_ident()
    attempts = max(1, max_retries)
    start = time.time()
    for attempt in range(attempts):
        try:
            # Step 0: Apply vocabulary correction to input
            step0 = apply_vocabulary_correction(bulk_text)
            proofread_result = call_proofread(step0)
            system_step1 = WRAPPER_PROMPT.format(system_prompt=CALL_1_SYSTEM_PROMPT)
            user_step1 = USER_PROMPT.format(original=step0, proofread=proofread_result)
            step1 = call_solar_pro2(system_step1, user_step1)
            step2 = call_solar_pro2(CALL_2_SYSTEM_PROMPT, step1)
            step3 = call_solar_pro2(CALL_3_SYSTEM_PROMPT, step2)
            # Step 4: enforce the vocabulary on the model output as well.
            step4 = apply_vocabulary_correction(step3)
        except Exception as e:  # worker boundary: any failure is retried, then reported
            if attempt < attempts - 1:
                print(
                    f"[Thread-{thread_id}] 벌크 {bulk_index+1} 시도 {attempt+1} 실패, 재시도: {e}"
                )
                time.sleep(1 * (attempt + 1))  # linear backoff
                continue
            print(f"[Thread-{thread_id}] 벌크 {bulk_index+1} 최종 실패: {e}")
            return {
                "bulk_index": bulk_index,
                "original": bulk_text,
                "final": bulk_text,  # fall back to the original text on error
                "processing_time": time.time() - start,
                "character_count": len(bulk_text),
                "error": str(e),
                "attempts": attempts,
            }

        with counter_lock:
            processed_count += 1
        return {
            "bulk_index": bulk_index,
            "original": bulk_text,
            "final": step4,
            "processing_time": time.time() - start,
            "character_count": len(bulk_text),
            "attempts": attempt + 1,
        }
def call_solar_pro2(system, user, temperature=0.0, model_name="solar-pro2"):
    """Send one system + user chat turn to *model_name* and return the reply.

    The request is non-streaming and goes through the module-level ``client``.

    Args:
        system: System prompt text.
        user: User message text.
        temperature: Sampling temperature (default 0.0).
        model_name: Model identifier (default ``"solar-pro2"``).

    Returns:
        The ``message.content`` of the first returned choice.
    """
    conversation = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    completion = client.chat.completions.create(
        model=model_name,
        temperature=temperature,
        stream=False,
        messages=conversation,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
def call_proofread(paragraph):
    """Run *paragraph* through the fine-tuned news-correction model.

    The system prompt is a fixed Korean instruction ("Please generate the
    proofreading result for the input document."). The request is
    non-streaming with temperature 0.0 and returns the model's reply text.
    It is delegated to ``call_solar_pro2``, which issues exactly the same
    request with a different model name.
    """
    instruction = "입력된 문서에 대한 교열 결과를 생성해 주세요."
    return call_solar_pro2(
        instruction,
        paragraph,
        temperature=0.0,
        model_name="ft:solar-news-correction-dev",
    )
def highlight_diff(original, corrected):
    """Render a character-level diff of *original* -> *corrected* as HTML.

    Unchanged text is emitted as-is. Removed text is wrapped in a red
    ``<del>`` and added text in a green ``<ins>``; a replacement emits both.
    All text is HTML-escaped, because it is user input that ends up in a
    ``gr.HTML`` component. The fragment is wrapped in a ``white-space:
    pre-wrap`` container so that line breaks in the text survive rendering.

    Args:
        original: Text before correction.
        corrected: Text after correction.

    Returns:
        str: HTML fragment.
    """
    del_open = '<del style="color:#c00;background:#fdd;">'
    ins_open = '<ins style="color:#060;background:#dfd;text-decoration:none;">'
    # autojunk=False: with the default heuristic, characters making up >1% of a
    # text longer than 200 chars (spaces, common syllables) are treated as junk,
    # which badly fragments character-level diffs of prose.
    matcher = difflib.SequenceMatcher(None, original, corrected, autojunk=False)
    parts = ['<div style="white-space:pre-wrap;">']
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        removed = html.escape(original[i1:i2])
        added = html.escape(corrected[j1:j2])
        if tag == "equal":
            parts.append(removed)
            continue
        if tag in ("replace", "delete"):
            parts.append(f"{del_open}{removed}</del>")
        if tag in ("replace", "insert"):
            parts.append(f"{ins_open}{added}</ins>")
    parts.append("</div>")
    return "".join(parts)
def process_text_parallel(input_text, max_workers=10):
    """Split *input_text* into bulks and correct them concurrently.

    Each bulk from ``create_bulk_paragraphs`` is submitted to a thread pool
    running ``process_bulk``. As a side effect, the module-level progress
    counters are reset: ``total_bulks`` is set to the number of bulks and
    ``processed_count`` to 0.

    ``process_bulk`` already catches its own errors. If a future still raises,
    that bulk is reported with its original text as ``final`` and the
    exception message under ``error``.

    Args:
        input_text: Full document text.
        max_workers: Thread pool size (default 10).

    Returns:
        list[dict]: One result dict per bulk, ordered by ``bulk_index``; an
        empty list when the input has no non-blank lines.
    """
    global processed_count, total_bulks

    bulks = create_bulk_paragraphs(input_text)
    total_bulks = len(bulks)
    processed_count = 0
    if not bulks:
        return []

    collected = []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        index_of = {}
        for position, chunk in enumerate(bulks):
            index_of[pool.submit(process_bulk, chunk, position)] = position

        # Gather results as they finish; order is restored below.
        for done in as_completed(index_of):
            try:
                collected.append(done.result())
            except Exception as e:
                bulk_index = index_of[done]
                print(f"벌크 {bulk_index+1} 처리 중 예외 발생: {e}")
                source = bulks[bulk_index]
                collected.append(
                    {
                        "bulk_index": bulk_index,
                        "original": source,
                        "final": source,
                        "processing_time": 0,
                        "character_count": len(source),
                        "error": str(e),
                    }
                )

    collected.sort(key=lambda item: item["bulk_index"])
    return collected
def process(paragraph):
    """Correct one paragraph synchronously, without bulking or retries.

    Runs the same steps as ``process_bulk``: vocabulary correction, the
    fine-tuned proofreading model, three solar-pro2 passes (CALL_1 wrapped in
    WRAPPER_PROMPT, then CALL_2, then CALL_3), and vocabulary correction again.
    The unused timing code of the previous version has been removed.

    Args:
        paragraph: Text to correct.

    Returns:
        tuple[str, str]: ``(corrected_text, highlight_diff(paragraph, corrected_text))``.

    Raises:
        Any exception from the API calls propagates unchanged; nothing is retried.
    """
    # Step 0: Apply vocabulary correction to input
    step0 = apply_vocabulary_correction(paragraph)
    proofread_result = call_proofread(step0)
    system_step1 = WRAPPER_PROMPT.format(system_prompt=CALL_1_SYSTEM_PROMPT)
    user_step1 = USER_PROMPT.format(original=step0, proofread=proofread_result)
    step1 = call_solar_pro2(system_step1, user_step1)
    step2 = call_solar_pro2(CALL_2_SYSTEM_PROMPT, step1)
    step3 = call_solar_pro2(CALL_3_SYSTEM_PROMPT, step2)
    # Step 4: Apply vocabulary correction to final output
    step4 = apply_vocabulary_correction(step3)
    return step4, highlight_diff(paragraph, step4)
def demo_fn(input_text):
    """Gradio callback: correct *input_text* and build the highlighted diff.

    The text is processed in parallel bulks. Bulking strips every line and
    drops blank lines, so the corrected text is the per-bulk results joined
    with ``"\\n"``; blank lines and indentation of the input are not kept.

    Args:
        input_text: Raw text from the input textbox.

    Returns:
        tuple[str, str]: ``(corrected_text, highlighted_html)``. When the input
        has no non-blank lines, both values are the input unchanged.
    """
    bulk_results = process_text_parallel(input_text, max_workers=10)
    if not bulk_results:
        return input_text, input_text

    final_result = "\n".join(r["final"] for r in bulk_results)
    # Diff against the same normalized text the pipeline actually saw.
    # Diffing against the raw input would highlight every dropped blank line,
    # indent and trailing space as a correction.
    normalized_input = "\n".join(r["original"] for r in bulk_results)
    highlighted = highlight_diff(normalized_input, final_result)
    return final_result, highlighted
with gr.Blocks() as demo:
    # Page title ("Proofreading model demo").
    gr.Markdown(value="# 교열 모델 데모")

    # Input: raw text; the placeholder asks for paragraph-by-paragraph input.
    input_text = gr.Textbox(
        label="원문 입력",
        lines=10,
        placeholder="문단 단위로 입력해 주세요.",
    )
    btn = gr.Button(value="교열하기")

    # Outputs: the corrected text and an HTML view highlighting the edits.
    output_corrected = gr.Textbox(label="교열 결과", lines=10)
    output_highlight = gr.HTML(label="수정된 부분 강조")

    btn.click(
        demo_fn,
        inputs=[input_text],
        outputs=[output_corrected, output_highlight],
    )

if __name__ == "__main__":
    demo.launch()