# ChloeLee22's picture
# dev 모델로 업데이트 (update to the dev model)
# 4b70df2 verified
import csv
import difflib
import html
import os
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock

import gradio as gr
from dotenv import load_dotenv
from openai import OpenAI

from prompts import (
    USER_PROMPT,
    WRAPPER_PROMPT,
    CALL_1_SYSTEM_PROMPT,
    CALL_2_SYSTEM_PROMPT,
    CALL_3_SYSTEM_PROMPT,
)
# Populate os.environ from a local .env file (if one exists) before any
# environment variable is read below.
load_dotenv()
# All model calls go through the OpenAI SDK client pointed at Upstage's
# endpoint (used here as an OpenAI-compatible API).
BASE_URL = "https://api.upstage.ai/v1"
# NOTE(review): the Upstage key is read from OPENAI_API_KEY; confirm this
# variable name is intentional for deployment.
API_KEY = os.getenv("OPENAI_API_KEY")
# Shared client used by call_solar_pro2 and call_proofread.
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
# Load vocabulary for rule-based correction
def load_vocabulary(path="Vocabulary.csv"):
    """Load the rule-based correction vocabulary from a CSV file.

    The file needs a header row containing ``original`` and ``corrected``
    columns; each data row maps a string to its replacement.

    Args:
        path: CSV file to read. Defaults to ``"Vocabulary.csv"``, resolved
            relative to the current working directory.

    Returns:
        dict[str, str]: ``original -> corrected``. When the same ``original``
        appears on several rows, the last row wins. Rows whose ``original`` is
        empty are skipped: replacing the empty string would insert the
        correction between every character of the text.

    Raises:
        OSError: if the file cannot be opened.
        KeyError: if a row lacks the ``original`` or ``corrected`` column.
    """
    vocabulary = {}
    # newline="" lets the csv module handle line endings itself, as the csv
    # documentation requires; otherwise quoted fields containing newlines are
    # altered. utf-8-sig drops a leading BOM so the first header name is clean.
    with open(path, "r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        for row_number, row in enumerate(reader):
            if row_number == 0:
                # Debug aid: show the detected header so column-name problems
                # are visible at startup.
                print("CSV columns:", list(row.keys()))
            original = row["original"]
            if not original:
                continue
            vocabulary[original] = row["corrected"]
    return vocabulary
# Loaded once at import time, so a missing or unreadable Vocabulary.csv
# makes the module fail on import.
VOCABULARY = load_vocabulary()
# Progress counters shared by the worker threads.
# process_bulk increments processed_count while holding counter_lock;
# process_text_parallel resets both counters at the start of each run.
# NOTE(review): neither counter is read anywhere in the code visible here.
counter_lock = Lock()
processed_count = 0
total_bulks = 0
def apply_vocabulary_correction(text, vocabulary=None):
    """Apply the rule-based vocabulary replacements to *text*.

    All entries are matched in one left-to-right pass. At each position the
    longest matching ``original`` wins. Compared with chaining ``str.replace``
    once per entry, this means:

    * the result does not depend on the CSV row order when one key overlaps
      another (e.g. both ``"A"`` and ``"AB"``), and
    * text produced by one replacement is never re-matched by a later entry.

    Args:
        text: Text to correct.
        vocabulary: ``original -> corrected`` mapping. Defaults to the
            module-level ``VOCABULARY`` loaded at import time.

    Returns:
        str: The corrected text (unchanged when no key occurs in it).
    """
    if vocabulary is None:
        vocabulary = VOCABULARY
    # Empty keys would match at every position, so they can never be a
    # meaningful rule. Longest keys go first because regex alternation takes
    # the first alternative that matches, not the longest one.
    keys = sorted((key for key in vocabulary if key), key=len, reverse=True)
    if not keys:
        return text
    pattern = re.compile("|".join(re.escape(key) for key in keys))
    return pattern.sub(lambda match: vocabulary[match.group(0)], text)
def create_bulk_paragraphs(text, max_chars=500):
    """Group the lines of *text* into chunks of at most ``max_chars`` characters.

    Every line is stripped and blank lines are dropped. Consecutive lines are
    packed greedily into one chunk, joined by newline characters. A line is
    added only while the joined chunk (separators included) stays within
    ``max_chars``. A single line longer than ``max_chars`` is never split; it
    becomes a chunk on its own.

    Args:
        text: Input text; lines are separated by newline characters.
        max_chars: Maximum chunk length in characters (default 500).

    Returns:
        list[str]: The chunks in document order; ``[]`` when *text* has no
        non-blank lines.
    """
    chunks = []
    group = []
    group_chars = 0

    for line in (piece.strip() for piece in text.split("\n")):
        if not line:
            continue
        size = len(line)
        oversized = size > max_chars
        # Joining adds one newline per line already in the group.
        overflows = bool(group) and group_chars + size + len(group) > max_chars

        if oversized or overflows:
            if group:
                chunks.append("\n".join(group))
            group, group_chars = [], 0

        if oversized:
            chunks.append(line)
        else:
            group.append(line)
            group_chars += size

    if group:
        chunks.append("\n".join(group))
    return chunks
def _run_correction_pipeline(text):
    """Run *text* once through vocabulary + model correction; raise on failure."""
    # Step 0: rule-based vocabulary correction of the input.
    step0 = apply_vocabulary_correction(text)
    # Step 1: the proofreading model's output for the vocabulary-corrected text.
    proofread_result = call_proofread(step0)
    # Step 2: a Solar Pro 2 call whose user message carries both the
    # vocabulary-corrected text and the proofreading output.
    system_step1 = WRAPPER_PROMPT.format(system_prompt=CALL_1_SYSTEM_PROMPT)
    user_step1 = USER_PROMPT.format(original=step0, proofread=proofread_result)
    step1 = call_solar_pro2(system_step1, user_step1)
    # Steps 3-4: two further Solar Pro 2 passes, each with its own system prompt.
    step2 = call_solar_pro2(CALL_2_SYSTEM_PROMPT, step1)
    step3 = call_solar_pro2(CALL_3_SYSTEM_PROMPT, step2)
    # Step 5: re-apply the vocabulary to the final model output.
    return apply_vocabulary_correction(step3)


def process_bulk(bulk_text, bulk_index, max_retries=3):
    """Correct one chunk through the full pipeline, retrying on any error.

    Any exception restarts the whole pipeline after a linear back-off
    (1 s, 2 s, ...). Once every attempt has failed, the chunk's original text
    is returned unchanged, so one bad chunk does not sink the whole document.

    Args:
        bulk_text: The chunk to correct.
        bulk_index: 0-based position of the chunk in the document. It is
            echoed back so results can be re-ordered after parallel execution.
        max_retries: Total number of attempts. Values below 1 are treated as 1.

    Returns:
        dict: Contains ``bulk_index``, ``original``, ``final``,
        ``processing_time`` (seconds since the first attempt started,
        including failed attempts and back-off), ``character_count`` and
        ``attempts``. A failed chunk also carries ``error`` and has
        ``final == original``. ``processed_count`` is incremented (under
        ``counter_lock``) only on success.
    """
    global processed_count
    thread_id = threading.get_ident()
    start = time.time()
    # At least one attempt, so a result dict is always returned.
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            final = _run_correction_pipeline(bulk_text)
        except Exception as e:
            if attempt < attempts - 1:
                print(
                    f"[Thread-{thread_id}] ๋ฒŒํฌ {bulk_index+1} ์‹œ๋„ {attempt+1} ์‹คํŒจ, ์žฌ์‹œ๋„: {e}"
                )
                time.sleep(1 * (attempt + 1))  # linear back-off: 1 s, 2 s, ...
                continue
            print(f"[Thread-{thread_id}] ๋ฒŒํฌ {bulk_index+1} ์ตœ์ข… ์‹คํŒจ: {e}")
            return {
                "bulk_index": bulk_index,
                "original": bulk_text,
                "final": bulk_text,  # fall back to the untouched input
                # Report the time actually spent, as the success path does.
                "processing_time": time.time() - start,
                "character_count": len(bulk_text),
                "error": str(e),
                "attempts": attempts,
            }
        elapsed = time.time() - start
        with counter_lock:
            processed_count += 1
        return {
            "bulk_index": bulk_index,
            "original": bulk_text,
            "final": final,
            "processing_time": elapsed,
            "character_count": len(bulk_text),
            "attempts": attempt + 1,
        }
def call_solar_pro2(system, user, temperature=0.0, model_name="solar-pro2"):
    """Send a single system + user chat turn and return the reply text.

    Args:
        system: System-message content.
        user: User-message content.
        temperature: Sampling temperature passed to the API (default 0.0).
        model_name: Model identifier (default ``"solar-pro2"``).

    Returns:
        The message content of the first choice in the (non-streamed) response.
    """
    conversation = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    completion = client.chat.completions.create(
        model=model_name,
        messages=conversation,
        stream=False,
        temperature=temperature,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
def call_proofread(paragraph):
    """Ask the ``ft:solar-news-correction-dev`` model to proofread *paragraph*.

    A fixed instruction is sent as the system message and the text as the
    user message, with temperature 0.0 and no streaming. The instruction
    literal is mis-encoded Korean in this copy. It decodes to "입력된 문서에
    대한 교열 결과를 생성해 주세요." ("Please generate the proofreading result
    for the input document.").

    NOTE(review): the ``ft:`` prefix suggests a fine-tuned model; confirm.

    Returns:
        The message content of the first choice in the response.
    """
    instruction = "์ž…๋ ฅ๋œ ๋ฌธ์„œ์— ๋Œ€ํ•œ ๊ต์—ด ๊ฒฐ๊ณผ๋ฅผ ์ƒ์„ฑํ•ด ์ฃผ์„ธ์š”."
    chat = [
        {"role": "system", "content": instruction},
        {"role": "user", "content": paragraph},
    ]
    reply = client.chat.completions.create(
        model="ft:solar-news-correction-dev",
        messages=chat,
        stream=False,
        temperature=0.0,
    )
    return reply.choices[0].message.content
def highlight_diff(original, corrected):
    """Render a character-level diff of *original* vs *corrected* as HTML.

    Unchanged text is wrapped in a plain ``<span>``. Text removed from the
    original is struck through on a red background, and text added by the
    correction gets a green background. A replaced run emits the removed span
    followed by the added span.

    Every text fragment is HTML-escaped, so user text containing ``<``, ``>``,
    ``&`` or quotes is shown literally when the result is rendered as HTML,
    instead of being interpreted as markup.

    Args:
        original: Text before correction.
        corrected: Text after correction.

    Returns:
        str: Concatenated ``<span>`` elements; ``""`` when both inputs are empty.
    """
    removed_style = "background:#ffecec;text-decoration:line-through;"
    added_style = "background:#e6ffec;"

    def span(fragment, style=None):
        # Escape after slicing the raw text, so a diff boundary can never
        # split an HTML entity.
        body = html.escape(fragment)
        if style is None:
            return f"<span>{body}</span>"
        return f'<span style="{style}">{body}</span>'

    pieces = []
    matcher = difflib.SequenceMatcher(None, original, corrected)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            pieces.append(span(original[i1:i2]))
            continue
        if tag in ("replace", "delete"):
            pieces.append(span(original[i1:i2], removed_style))
        if tag in ("replace", "insert"):
            pieces.append(span(corrected[j1:j2], added_style))
    return "".join(pieces)
def process_text_parallel(input_text, max_workers=10):
    """Correct *input_text* chunk by chunk on a thread pool.

    The text is split with ``create_bulk_paragraphs``. Each chunk is handed to
    ``process_bulk`` on a ``ThreadPoolExecutor`` with ``max_workers`` threads.

    Side effects: sets the module-level ``total_bulks`` to the chunk count and
    resets ``processed_count`` to 0 before any work is submitted.

    Args:
        input_text: Full text to correct.
        max_workers: Size of the thread pool.

    Returns:
        list[dict]: One result dict per chunk, sorted by ``bulk_index``
        (document order); ``[]`` when the text has no non-blank lines.
    """
    global processed_count, total_bulks

    chunks = create_bulk_paragraphs(input_text)
    total_bulks = len(chunks)
    processed_count = 0
    if not chunks:
        return []

    collected = []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        index_of = {
            pool.submit(process_bulk, chunk, position): position
            for position, chunk in enumerate(chunks)
        }
        # Results arrive in completion order; they are re-sorted below.
        for done in as_completed(index_of):
            try:
                collected.append(done.result())
            except Exception as e:
                # process_bulk already catches Exception itself, so this is a
                # last-resort fallback that keeps the chunk's text unchanged.
                bulk_index = index_of[done]
                print(f"๋ฒŒํฌ {bulk_index+1} ์ฒ˜๋ฆฌ ์ค‘ ์˜ˆ์™ธ ๋ฐœ์ƒ: {e}")
                source = chunks[bulk_index]
                collected.append(
                    {
                        "bulk_index": bulk_index,
                        "original": source,
                        "final": source,
                        "processing_time": 0,
                        "character_count": len(source),
                        "error": str(e),
                    }
                )
    return sorted(collected, key=lambda item: item["bulk_index"])
def process(paragraph):
    """Correct *paragraph* in a single pass (no chunking, no retries) and diff it.

    The steps are the same as the per-chunk pipeline in ``process_bulk``:
    1. vocabulary correction;
    2. the proofreading model;
    3. three Solar Pro 2 calls;
    4. vocabulary correction of the final output.

    NOTE(review): not called by the Gradio handler in the code visible here
    (``demo_fn`` goes through ``process_text_parallel``).

    Args:
        paragraph: Text to correct.

    Returns:
        tuple[str, str]: The corrected text, and an HTML diff of *paragraph*
        against it (from ``highlight_diff``).

    Raises:
        Any exception from the model calls propagates; there is no retry here.
    """
    # Step 0: apply vocabulary correction to the input.
    step0 = apply_vocabulary_correction(paragraph)
    proofread_result = call_proofread(step0)
    system_step1 = WRAPPER_PROMPT.format(system_prompt=CALL_1_SYSTEM_PROMPT)
    user_step1 = USER_PROMPT.format(original=step0, proofread=proofread_result)
    step1 = call_solar_pro2(system_step1, user_step1)
    step2 = call_solar_pro2(CALL_2_SYSTEM_PROMPT, step1)
    step3 = call_solar_pro2(CALL_3_SYSTEM_PROMPT, step2)
    # Step 4: apply vocabulary correction to the final output.
    step4 = apply_vocabulary_correction(step3)
    return step4, highlight_diff(paragraph, step4)
def demo_fn(input_text):
    """Gradio click handler: correct *input_text* and build the diff view.

    The text is corrected chunk by chunk in parallel (10 workers). The
    corrected chunks are joined with newlines. Because chunking strips lines
    and drops blank ones, the diff against the raw input also reflects that
    normalisation.

    Returns:
        tuple[str, str]: The corrected text and its HTML highlight. When the
        input has no non-blank lines, the input is returned unchanged as both
        values.
    """
    chunk_results = process_text_parallel(input_text, max_workers=10)
    if not chunk_results:
        return input_text, input_text
    # process_text_parallel returns results sorted by bulk_index, so joining
    # them restores document order.
    corrected = "\n".join(item["final"] for item in chunk_results)
    return corrected, highlight_diff(input_text, corrected)
# Gradio UI: one input textbox, a button, and two outputs (the corrected
# text, and an HTML view of the diff), both filled by demo_fn.
# NOTE(review): the Korean UI strings below appear mis-encoded in this copy;
# English glosses are given in the trailing comments.
with gr.Blocks() as demo:
    gr.Markdown("# ๊ต์—ด ๋ชจ๋ธ ๋ฐ๋ชจ")  # title: "Proofreading model demo"
    # Input: "Original text"; placeholder: "Please enter text paragraph by paragraph."
    input_text = gr.Textbox(
        label="์›๋ฌธ ์ž…๋ ฅ", lines=10, placeholder="๋ฌธ๋‹จ ๋‹จ์œ„๋กœ ์ž…๋ ฅํ•ด ์ฃผ์„ธ์š”."
    )
    btn = gr.Button("๊ต์—ดํ•˜๊ธฐ")  # "Proofread"
    output_corrected = gr.Textbox(label="๊ต์—ด ๊ฒฐ๊ณผ", lines=10)  # "Proofreading result"
    output_highlight = gr.HTML(label="์ˆ˜์ •๋œ ๋ถ€๋ถ„ ๊ฐ•์กฐ")  # "Highlighted changes"
    # demo_fn returns (corrected_text, diff_html), in the order of `outputs`.
    btn.click(
        fn=demo_fn, inputs=input_text, outputs=[output_corrected, output_highlight]
    )
# Start the Gradio server only when this file is run as a script, not when
# it is imported.
if __name__ == "__main__":
    demo.launch()