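# Gradio app: fine-tunes beomi/kcbert-base on a small hard-coded Korean dataset at
# startup, then predicts the gender implied by an input sentence.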
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr
import pandas as pd
import io
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
# Load the model and tokenizer
MODEL_NAME = "beomi/kcbert-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
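# kcbert-base ships only the pretrained encoder, so the 2-label classification head
# created by num_labels=2 starts with random weights and is trained in train_model() below.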
# Define the dataset class
class CustomDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len=128):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data.iloc[index]
        description = str(item['description'])
        label = item['label']
        # Tokenize and pad/truncate every description to exactly max_len tokens
        encoding = self.tokenizer.encode_plus(
            description,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }
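# Every item comes back padded/truncated to max_len tokens, so the default DataLoader
# collate can stack a batch into input_ids/attention_mask tensors of shape [batch_size, 128].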
# Prepare the training data and fine-tune the model
def train_model():
    # Tiny in-memory CSV of Korean descriptions labeled 남자 (male) or 여자 (female)
    csv_data = """description,gender
"그는 축구를 정말 좋아하고, 근육질의 몸매를 가졌다.",남자
"그녀는 긴 머리를 가졌고, 분홍색 원피스를 입었다.",여자
"짧은 머리의 정장을 입은 그는 회의에 참석했다.",남자
"아름다운 목소리로 노래하는 그녀는 가수였다.",여자
"그의 취미는 자동차 정비와 컴퓨터 게임이다.",남자
"그녀는 섬세한 손길로 아기 인형을 만들었다.",여자
"군대에서 막 제대한 그는 씩씩해 보였다.",남자
"그녀는 친구들과 수다 떠는 것을 좋아했다.",여자
"강력한 리더십으로 팀을 이끄는 모습이 인상적이었다.",남자
"자신이 직접 만든 쿠키를 주변에 나누어주곤 했다.",여자
"์ ๋ฏผ์ง",์ฒ์ฌ์ฌ | |
""" | |
data = pd.read_csv(io.StringIO(csv_data)) | |
data['label'] = data['gender'].apply(lambda x: 0 if x == '๋จ์' else 1) | |
train_data, _ = train_test_split(data, test_size=0.2, random_state=42) | |
train_dataset = CustomDataset(train_data, tokenizer) | |
train_loader = DataLoader(train_dataset, batch_size=2) | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
model.to(device) | |
optimizer = AdamW(model.parameters(), lr=5e-5) | |
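    # 5e-5 sits in the 2e-5 to 5e-5 range commonly used when fine-tuning BERT-family models.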
    print("Starting model training...")
    model.train()
    for epoch in range(3):
        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            # Passing labels makes the model compute the cross-entropy loss for us
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch + 1} complete")
    print("Model training complete!")
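# Note: the fine-tuned weights live only in process memory. If they should survive a
# restart, one option (not done here) is model.save_pretrained() on a persistent path.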
# Prediction function
def predict_gender(text):
    if not text.strip():
        return "Please enter some text."
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=128,
        return_token_type_ids=False,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        # Softmax turns the two logits into class probabilities; argmax picks the class
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
        prediction = torch.argmax(outputs.logits, dim=1).flatten().item()
    confidence = probabilities[0][prediction].item()
    gender = "Male" if prediction == 0 else "Female"
    return f"Predicted gender: {gender} (confidence: {confidence:.2%})"
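# Illustrative call (the confidence value is made up):
#   predict_gender("그는 축구를 정말 좋아하고, 근육질의 몸매를 가졌다.")
#   -> "Predicted gender: Male (confidence: 98.00%)"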
# Train the model at app startup
print("Initializing app...")
train_model()
# Build the Gradio interface
iface = gr.Interface(
    fn=predict_gender,
    inputs=gr.Textbox(
        lines=3,
        placeholder="Enter text to predict gender from.\nExample: '그는 축구를 좋아하고 근육질이다.'",
        label="Input text"
    ),
    outputs=gr.Textbox(label="Prediction result"),
    title="AI Gender Predictor",
    description="Predicts a gender based on the input text.",
    examples=[
        ["그는 축구를 정말 좋아하고, 근육질의 몸매를 가졌다."],
        ["그녀는 긴 머리를 가졌고, 분홍색 원피스를 입었다."],
        ["짧은 머리의 정장을 입은 그는 회의에 참석했다."],
        ["아름다운 목소리로 노래하는 그녀는 가수였다."]
    ],
    theme=gr.themes.Soft()
)
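# Clicking one of the examples fills the input box with that sentence. The four examples
# are rows from the training CSV, so they are easy cases for the model.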
# Run the app
if __name__ == "__main__":
    iface.launch()
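# launch() starts the local Gradio server (port 7860 by default).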