"""Gradio demo for a multi-stage Korean proofreading pipeline.

The input text is split into bulks of paragraphs, each bulk is sent through
a rule-based vocabulary correction, a proofreading model and three follow-up
chat-completion calls, and the corrected text is shown together with an HTML
diff against the original input.
"""

import csv
import difflib
import html
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock

import gradio as gr
from dotenv import load_dotenv
from openai import OpenAI

from prompts import (
    USER_PROMPT,
    WRAPPER_PROMPT,
    CALL_1_SYSTEM_PROMPT,
    CALL_2_SYSTEM_PROMPT,
    CALL_3_SYSTEM_PROMPT,
)

load_dotenv()

BASE_URL = "https://api.upstage.ai/v1"
API_KEY = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)


def load_vocabulary():
    """Load the rule-based correction table from ``Vocabulary.csv``.

    The CSV must have ``original`` and ``corrected`` columns. The column
    names of the first data row are printed as a debugging aid.

    Returns:
        dict mapping each ``original`` string to its ``corrected`` string,
        in file order.

    Raises:
        OSError: if ``Vocabulary.csv`` cannot be opened.
        KeyError: if a row lacks the ``original`` or ``corrected`` column.
    """
    vocabulary = {}
    # utf-8-sig strips a leading BOM so the first column name stays clean.
    with open("Vocabulary.csv", "r", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        for row_number, row in enumerate(reader):
            if row_number == 0:
                # Debug: print first row to check column names
                print("CSV columns:", list(row.keys()))
            original = row["original"]
            corrected = row["corrected"]
            # An empty search string would make str.replace insert the
            # replacement between every character; a missing cell (None)
            # would make it raise. Skip such rows.
            if not original or corrected is None:
                continue
            vocabulary[original] = corrected
    return vocabulary


VOCABULARY = load_vocabulary()

# Thread-safe progress counters shared by the worker threads.
counter_lock = Lock()
processed_count = 0
total_bulks = 0

# Inline styles used by highlight_diff for removed / added text.
_DELETED_STYLE = "background-color:#ffd6d6;color:#000;text-decoration:line-through;"
_INSERTED_STYLE = "background-color:#d6f5d6;color:#000;"


def apply_vocabulary_correction(text):
    """Replace every ``VOCABULARY`` key found in *text* with its value.

    Replacements are applied sequentially in table order, so a later entry
    can match text produced by an earlier one.
    """
    for original, corrected in VOCABULARY.items():
        text = text.replace(original, corrected)
    return text


def create_bulk_paragraphs(text, max_chars=500):
    """Group the paragraphs of *text* into bulks of at most *max_chars*.

    Paragraphs are the non-blank lines of *text*, stripped of surrounding
    whitespace. Consecutive paragraphs are joined with newlines into one bulk
    as long as the joined length stays within *max_chars*. A single paragraph
    longer than *max_chars* is not split; it becomes a bulk of its own.

    Args:
        text: input text.
        max_chars: maximum number of characters per bulk (default 500).

    Returns:
        list[str]: the bulks, in input order; empty if *text* has no
        non-blank line.
    """
    paragraphs = [p.strip() for p in text.split("\n") if p.strip()]
    if not paragraphs:
        return []

    bulks = []
    current_bulk = []
    current_length = 0

    for para in paragraphs:
        para_length = len(para)

        if para_length > max_chars:
            # Flush whatever has been collected, then emit the oversized
            # paragraph on its own.
            if current_bulk:
                bulks.append("\n".join(current_bulk))
                current_bulk = []
                current_length = 0
            bulks.append(para)
            continue

        # len(current_bulk) is the number of newline separators the joined
        # bulk would contain after adding this paragraph.
        joined_length = current_length + para_length + len(current_bulk)
        if current_bulk and joined_length > max_chars:
            # Close the current bulk and start a new one with this paragraph.
            bulks.append("\n".join(current_bulk))
            current_bulk = [para]
            current_length = para_length
        else:
            current_bulk.append(para)
            current_length += para_length

    # Emit the trailing bulk.
    if current_bulk:
        bulks.append("\n".join(current_bulk))

    return bulks


def call_solar_pro2(system, user, temperature=0.0, model_name="solar-pro2"):
    """Send one system + user message pair to *model_name*.

    Returns:
        The message content of the first choice in the response.
    """
    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        stream=False,
        temperature=temperature,
    )
    return response.choices[0].message.content


def call_proofread(paragraph):
    """Ask the ``ft:solar-news-correction-dev`` model to proofread *paragraph*.

    Returns:
        The message content of the first choice in the response.
    """
    prompt = "입력된 문서에 대한 교열 결과를 생성해 주세요."
    response = client.chat.completions.create(
        model="ft:solar-news-correction-dev",
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": paragraph},
        ],
        stream=False,
        temperature=0.0,
    )
    return response.choices[0].message.content


def _run_pipeline(text):
    """Run *text* through the full correction pipeline and return the result.

    Stages: vocabulary correction, proofreading model, three sequential
    ``call_solar_pro2`` calls, and a final vocabulary correction.
    """
    # Step 0: apply vocabulary correction to the input.
    step0 = apply_vocabulary_correction(text)
    proofread_result = call_proofread(step0)

    system_step1 = WRAPPER_PROMPT.format(system_prompt=CALL_1_SYSTEM_PROMPT)
    user_step1 = USER_PROMPT.format(original=step0, proofread=proofread_result)
    step1 = call_solar_pro2(system_step1, user_step1)
    step2 = call_solar_pro2(CALL_2_SYSTEM_PROMPT, step1)
    step3 = call_solar_pro2(CALL_3_SYSTEM_PROMPT, step2)

    # Step 4: apply vocabulary correction to the final output.
    return apply_vocabulary_correction(step3)


def process_bulk(bulk_text, bulk_index, max_retries=3):
    """Run one bulk through the pipeline, retrying on failure.

    Args:
        bulk_text: text of the bulk.
        bulk_index: zero-based position of the bulk in the input.
        max_retries: maximum number of attempts; values below 1 are treated
            as 1 so that a result dict is always returned.

    Returns:
        dict with keys ``bulk_index``, ``original``, ``final``,
        ``processing_time``, ``character_count`` and ``attempts``. If every
        attempt fails, ``final`` is the unmodified input, ``processing_time``
        is 0 and an ``error`` key holds the last exception message.
    """
    global processed_count
    thread_id = threading.get_ident()
    start = time.time()
    attempts = max(1, max_retries)

    for attempt in range(attempts):
        try:
            final_text = _run_pipeline(bulk_text)
        except Exception as e:
            # Broad catch is deliberate: any failure of a bulk falls back to
            # the original text instead of aborting the whole request.
            if attempt < attempts - 1:
                print(
                    f"[Thread-{thread_id}] 벌크 {bulk_index+1} 시도 {attempt+1} 실패, 재시도: {e}"
                )
                time.sleep(1 * (attempt + 1))  # linearly increasing back-off
                continue
            print(f"[Thread-{thread_id}] 벌크 {bulk_index+1} 최종 실패: {e}")
            return {
                "bulk_index": bulk_index,
                "original": bulk_text,
                "final": bulk_text,  # return the original on error
                "processing_time": 0,
                "character_count": len(bulk_text),
                "error": str(e),
                "attempts": attempts,
            }

        # Measured from the first attempt, so it includes retry waits.
        elapsed = time.time() - start
        with counter_lock:
            processed_count += 1
        return {
            "bulk_index": bulk_index,
            "original": bulk_text,
            "final": final_text,
            "processing_time": elapsed,
            "character_count": len(bulk_text),
            "attempts": attempt + 1,
        }


def highlight_diff(original, corrected):
    """Return an HTML fragment showing how *corrected* differs from *original*.

    The comparison is character-based. Unchanged text is emitted as is,
    removed text is struck through on a red background and added text is
    shown on a green background. All text is HTML-escaped, because both
    strings come from user input or model output and the result is rendered
    as raw HTML.
    """
    matcher = difflib.SequenceMatcher(None, original, corrected)
    fragments = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        removed = html.escape(original[i1:i2])
        added = html.escape(corrected[j1:j2])
        if tag == "equal":
            fragments.append(removed)
            continue
        # "replace" contributes both a removed and an added span.
        if tag in ("replace", "delete"):
            fragments.append(f'<span style="{_DELETED_STYLE}">{removed}</span>')
        if tag in ("replace", "insert"):
            fragments.append(f'<span style="{_INSERTED_STYLE}">{added}</span>')
    return "".join(fragments)


def process_text_parallel(input_text, max_workers=10):
    """Split *input_text* into bulks and process them in a thread pool.

    Resets the module-level ``processed_count`` and sets ``total_bulks``.

    Args:
        input_text: text to process.
        max_workers: size of the thread pool.

    Returns:
        list of result dicts as produced by ``process_bulk``, sorted by
        ``bulk_index``; empty if the input yields no bulks.
    """
    global processed_count, total_bulks

    bulks = create_bulk_paragraphs(input_text)
    # Reset under the same lock the workers use for incrementing.
    with counter_lock:
        total_bulks = len(bulks)
        processed_count = 0

    if not bulks:
        return []

    results = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit every bulk up front.
        future_to_bulk = {
            executor.submit(process_bulk, bulk, i): i for i, bulk in enumerate(bulks)
        }

        # Collect results in completion order.
        for future in as_completed(future_to_bulk):
            try:
                result = future.result()
                results.append(result)
            except Exception as e:
                bulk_index = future_to_bulk[future]
                print(f"벌크 {bulk_index+1} 처리 중 예외 발생: {e}")
                results.append(
                    {
                        "bulk_index": bulk_index,
                        "original": bulks[bulk_index],
                        "final": bulks[bulk_index],
                        "processing_time": 0,
                        "character_count": len(bulks[bulk_index]),
                        "error": str(e),
                    }
                )

    # Restore input order.
    results.sort(key=lambda x: x["bulk_index"])

    return results


def process(paragraph):
    """Run a single paragraph through the pipeline without bulk splitting.

    Returns:
        tuple ``(corrected_text, diff_html)`` where ``diff_html`` is the
        output of ``highlight_diff(paragraph, corrected_text)``.
    """
    corrected = _run_pipeline(paragraph)
    return corrected, highlight_diff(paragraph, corrected)


def demo_fn(input_text):
    """Gradio callback: correct *input_text* and build the diff view.

    Returns:
        tuple ``(corrected_text, diff_html)``. If the input contains no
        non-blank line it is returned unchanged (HTML-escaped for the diff
        view).
    """
    # Process the text bulk by bulk in parallel.
    bulk_results = process_text_parallel(input_text, max_workers=10)

    if not bulk_results:
        return input_text, html.escape(input_text)

    # Merge the per-bulk results; blank lines of the input are not restored.
    final_result = "\n".join(r["final"] for r in bulk_results)

    # Build the highlighted diff against the raw input.
    highlighted = highlight_diff(input_text, final_result)

    return final_result, highlighted


with gr.Blocks() as demo:
    gr.Markdown("# 교열 모델 데모")
    input_text = gr.Textbox(
        label="원문 입력", lines=10, placeholder="문단 단위로 입력해 주세요."
    )
    btn = gr.Button("교열하기")
    output_corrected = gr.Textbox(label="교열 결과", lines=10)
    output_highlight = gr.HTML(label="수정된 부분 강조")

    btn.click(
        fn=demo_fn, inputs=input_text, outputs=[output_corrected, output_highlight]
    )

if __name__ == "__main__":
    demo.launch()