import csv
import difflib
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import gradio as gr
from dotenv import load_dotenv
from openai import OpenAI

from prompts import (
    USER_PROMPT,
    WRAPPER_PROMPT,
    CALL_1_SYSTEM_PROMPT,
    CALL_2_SYSTEM_PROMPT,
    CALL_3_SYSTEM_PROMPT,
)

load_dotenv()

BASE_URL = "https://api.upstage.ai/v1"
# The Upstage key is read from OPENAI_API_KEY because the stock OpenAI client
# is pointed at the Upstage-compatible endpoint.
API_KEY = os.getenv("OPENAI_API_KEY")

client = OpenAI(api_key=API_KEY, base_url=BASE_URL)


def load_vocabulary():
    """Load the original -> corrected term map from Vocabulary.csv."""
    vocabulary = {}
    with open("Vocabulary.csv", "r", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        for row in reader:
            # Print the header once so column or encoding problems surface early.
            if len(vocabulary) == 0:
                print("CSV columns:", list(row.keys()))
            vocabulary[row["original"]] = row["corrected"]
    return vocabulary
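
# A minimal sketch of the Vocabulary.csv layout that load_vocabulary() above
# expects (the column names come from the code; the rows are hypothetical):
#
#   original,corrected
#   solar pro 2,Solar Pro 2
#   up stage,Upstage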

VOCABULARY = load_vocabulary()

# Progress counters shared across worker threads, guarded by counter_lock.
counter_lock = threading.Lock()
processed_count = 0
total_bulks = 0


def apply_vocabulary_correction(text):
    """Apply the vocabulary replacements in dictionary insertion order."""
    for original, corrected in VOCABULARY.items():
        text = text.replace(original, corrected)
    return text


def create_bulk_paragraphs(text, max_chars=500):
    """Split text into bulk units of roughly 500 characters.

    Args:
        text: Input text.
        max_chars: Maximum characters per bulk (default: 500).

    Returns:
        List[str]: The text split into bulk units.
    """
    paragraphs = [p.strip() for p in text.split("\n") if p.strip()]

    if not paragraphs:
        return []

    bulks = []
    current_bulk = []
    current_length = 0

    for para in paragraphs:
        para_length = len(para)

        if para_length > max_chars:
            # An oversized paragraph flushes the current bulk and becomes a
            # bulk of its own.
            if current_bulk:
                bulks.append("\n".join(current_bulk))
                current_bulk = []
                current_length = 0
            bulks.append(para)
        else:
            # len(current_bulk) accounts for the newlines "\n".join would add.
            if (
                current_length + para_length + len(current_bulk) > max_chars
                and current_bulk
            ):
                bulks.append("\n".join(current_bulk))
                current_bulk = [para]
                current_length = para_length
            else:
                current_bulk.append(para)
                current_length += para_length

    if current_bulk:
        bulks.append("\n".join(current_bulk))

    return bulks
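
# Usage sketch (hypothetical input): short paragraphs are packed together up
# to max_chars, while an oversized paragraph becomes a bulk of its own:
#
#   create_bulk_paragraphs("intro\n" + "x" * 600 + "\noutro")
#   # -> ["intro", "xxx...x" (the 600-character paragraph), "outro"]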

def process_bulk(bulk_text, bulk_index, max_retries=3):
    """Run one bulk through the correction pipeline, retrying on API errors."""
    global processed_count

    thread_id = threading.get_ident()
    start = time.time()

    for attempt in range(max_retries):
        try:
            # Pipeline: vocabulary pass -> fine-tuned proofread -> three
            # chained solar-pro2 calls -> vocabulary pass again.
            step0 = apply_vocabulary_correction(bulk_text)
            proofread_result = call_proofread(step0)
            system_step1 = WRAPPER_PROMPT.format(system_prompt=CALL_1_SYSTEM_PROMPT)
            user_step1 = USER_PROMPT.format(original=step0, proofread=proofread_result)
            step1 = call_solar_pro2(system_step1, user_step1)
            step2 = call_solar_pro2(CALL_2_SYSTEM_PROMPT, step1)
            step3 = call_solar_pro2(CALL_3_SYSTEM_PROMPT, step2)
            step4 = apply_vocabulary_correction(step3)

            elapsed = time.time() - start

            with counter_lock:
                processed_count += 1

            return {
                "bulk_index": bulk_index,
                "original": bulk_text,
                "final": step4,
                "processing_time": elapsed,
                "character_count": len(bulk_text),
                "attempts": attempt + 1,
            }

        except Exception as e:
            if attempt < max_retries - 1:
                print(
                    f"[Thread-{thread_id}] bulk {bulk_index + 1} attempt {attempt + 1} failed, retrying: {e}"
                )
                # Linear backoff: wait 1s, then 2s, before the next attempt.
                time.sleep(1 * (attempt + 1))
                continue
            else:
                print(f"[Thread-{thread_id}] bulk {bulk_index + 1} failed permanently: {e}")
                # Fall back to the unmodified input so the document still
                # reassembles in order.
                return {
                    "bulk_index": bulk_index,
                    "original": bulk_text,
                    "final": bulk_text,
                    "processing_time": 0,
                    "character_count": len(bulk_text),
                    "error": str(e),
                    "attempts": max_retries,
                }

def call_solar_pro2(system, user, temperature=0.0, model_name="solar-pro2"):
    """Single non-streaming chat completion against the configured endpoint."""
    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        stream=False,
        temperature=temperature,
    )
    return response.choices[0].message.content
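
# Usage sketch: the pipeline stages differ only in their system prompt,
# e.g. call_solar_pro2(CALL_2_SYSTEM_PROMPT, draft_text); temperature stays
# at 0.0 so repeated runs produce stable corrections.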

def call_proofread(paragraph):
    """Call the fine-tuned news-correction model on a single paragraph."""
    prompt = "Please generate a proofreading result for the input document."
    response = client.chat.completions.create(
        model="ft:solar-news-correction",
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": paragraph},
        ],
        stream=False,
        temperature=0.0,
    )
    return response.choices[0].message.content


def highlight_diff(original, corrected):
    """Render a character-level diff of the two strings as inline HTML."""
    matcher = difflib.SequenceMatcher(None, original, corrected)
    result_html = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            result_html.append(f"<span>{original[i1:i2]}</span>")
        elif tag == "replace":
            # Show the removed text struck through, the replacement beside it.
            result_html.append(
                f'<span style="background:#ffecec;text-decoration:line-through;">{original[i1:i2]}</span>'
            )
            result_html.append(
                f'<span style="background:#e6ffec;">{corrected[j1:j2]}</span>'
            )
        elif tag == "delete":
            result_html.append(
                f'<span style="background:#ffecec;text-decoration:line-through;">{original[i1:i2]}</span>'
            )
        elif tag == "insert":
            result_html.append(
                f'<span style="background:#e6ffec;">{corrected[j1:j2]}</span>'
            )
    return "".join(result_html)
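
# Example: highlight_diff("recieve", "receive") returns HTML in which the
# characters unique to the original are wrapped in red strikethrough spans
# and the characters unique to the correction in green spans.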

def process_text_parallel(input_text, max_workers=10):
    """Split the text into bulks and process them in parallel."""
    global processed_count, total_bulks

    bulks = create_bulk_paragraphs(input_text)
    total_bulks = len(bulks)
    processed_count = 0

    if not bulks:
        return []

    results = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_bulk = {
            executor.submit(process_bulk, bulk, i): i for i, bulk in enumerate(bulks)
        }

        for future in as_completed(future_to_bulk):
            try:
                result = future.result()
                results.append(result)
            except Exception as e:
                bulk_index = future_to_bulk[future]
                print(f"Exception while processing bulk {bulk_index + 1}: {e}")
                # Keep the original text for this bulk so nothing is lost.
                results.append(
                    {
                        "bulk_index": bulk_index,
                        "original": bulks[bulk_index],
                        "final": bulks[bulk_index],
                        "processing_time": 0,
                        "character_count": len(bulks[bulk_index]),
                        "error": str(e),
                    }
                )

    # as_completed yields out of order; restore document order.
    results.sort(key=lambda x: x["bulk_index"])

    return results
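
# Usage sketch: results come back sorted by bulk_index, so the corrected
# document can be rebuilt with "\n".join(r["final"] for r in results), which
# is what demo_fn() below does. processed_count / total_bulks could be polled
# from another thread as a rough progress indicator (assumed use; nothing in
# this file reads them back).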

def process(paragraph):
    """Single-paragraph variant of the pipeline (not wired into the UI)."""
    step0 = apply_vocabulary_correction(paragraph)
    proofread_result = call_proofread(step0)
    system_step1 = WRAPPER_PROMPT.format(system_prompt=CALL_1_SYSTEM_PROMPT)
    user_step1 = USER_PROMPT.format(original=step0, proofread=proofread_result)
    step1 = call_solar_pro2(system_step1, user_step1)
    step2 = call_solar_pro2(CALL_2_SYSTEM_PROMPT, step1)
    step3 = call_solar_pro2(CALL_3_SYSTEM_PROMPT, step2)
    step4 = apply_vocabulary_correction(step3)
    return step4, highlight_diff(paragraph, step4)


def demo_fn(input_text):
    """Gradio handler: correct the full text and highlight the changes."""
    bulk_results = process_text_parallel(input_text, max_workers=10)

    if not bulk_results:
        return input_text, input_text

    # Reassemble the corrected bulks in document order.
    final_texts = [r["final"] for r in bulk_results]
    final_result = "\n".join(final_texts)

    highlighted = highlight_diff(input_text, final_result)

    return final_result, highlighted


with gr.Blocks() as demo:
    gr.Markdown("# Proofreading Model Demo")
    input_text = gr.Textbox(
        label="Original text", lines=10, placeholder="Enter the text paragraph by paragraph."
    )
    btn = gr.Button("Proofread")
    output_corrected = gr.Textbox(label="Corrected text", lines=10)
    output_highlight = gr.HTML(label="Highlighted changes")

    btn.click(
        fn=demo_fn, inputs=input_text, outputs=[output_corrected, output_highlight]
    )

if __name__ == "__main__":
    demo.launch()