ojs595 committed on
Commit 22b2ce1 · verified · 1 Parent(s): 2c0dc53

Upload 2 files

Files changed (2)
  1. app.py +149 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,149 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ import gradio as gr
+ import pandas as pd
+ import io
+ from torch.utils.data import DataLoader, Dataset
+ from torch.optim import AdamW
+ from sklearn.model_selection import train_test_split
+
+ # Load the model and tokenizer (beomi/kcbert-base is a Korean BERT checkpoint,
+ # so the classifier expects Korean input text)
+ MODEL_NAME = "beomi/kcbert-base"
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
+
+ # Dataset class definition
+ class CustomDataset(Dataset):
+     def __init__(self, dataframe, tokenizer, max_len=128):
+         self.tokenizer = tokenizer
+         self.data = dataframe
+         self.max_len = max_len
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, index):
+         item = self.data.iloc[index]
+         description = str(item['description'])
+         label = item['label']
+
+         encoding = self.tokenizer.encode_plus(
+             description,
+             add_special_tokens=True,
+             max_length=self.max_len,
+             return_token_type_ids=False,
+             padding='max_length',
+             truncation=True,
+             return_attention_mask=True,
+             return_tensors='pt',
+         )
+
+         return {
+             'input_ids': encoding['input_ids'].flatten(),
+             'attention_mask': encoding['attention_mask'].flatten(),
+             'labels': torch.tensor(label, dtype=torch.long)
+         }
+
+ # Prepare the training data and fine-tune the model
+ def train_model():
+     csv_data = """description,gender
+ "그는 축구를 정말 좋아하고, 근육질의 몸매를 가졌다.",남자
+ "그녀는 긴 머리를 가졌고, 분홍색 원피스를 입었다.",여자
+ "짧은 머리에 정장을 입은 그는 회의에 참석했다.",남자
+ "아름다운 목소리로 노래하는 그녀는 가수다.",여자
+ "그의 취미는 자동차 정비와 컴퓨터 게임이다.",남자
+ "그녀는 섬세한 손길로 아기 인형을 만들었다.",여자
+ "군대에서 막 제대한 그는 씩씩해 보였다.",남자
+ "그녀는 친구들과 수다 떠는 것을 좋아한다.",여자
+ "강력한 리더십으로 팀을 이끄는 모습이 인상적이었다.",남자
+ "자신이 직접 만든 쿠키를 주변에 나누어주곤 한다.",여자
+ "안일찬",여자
+ """
+
+     data = pd.read_csv(io.StringIO(csv_data))
+     # Map the Korean gender labels to integers: '남자' (male) -> 0, '여자' (female) -> 1
+     data['label'] = data['gender'].apply(lambda x: 0 if x == '남자' else 1)
+     train_data, _ = train_test_split(data, test_size=0.2, random_state=42)
+
+     train_dataset = CustomDataset(train_data, tokenizer)
+     train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device)
+
+     optimizer = AdamW(model.parameters(), lr=5e-5)
+
+     print("Starting model training...")
+     model.train()
+     for epoch in range(3):
+         for batch in train_loader:
+             optimizer.zero_grad()
+             input_ids = batch['input_ids'].to(device)
+             attention_mask = batch['attention_mask'].to(device)
+             labels = batch['labels'].to(device)
+
+             outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
+             loss = outputs.loss
+             loss.backward()
+             optimizer.step()
+         print(f"Epoch {epoch + 1} complete")
+
+     print("Model training complete!")
+
+ # Prediction function
+ def predict_gender(text):
+     if not text.strip():
+         return "Please enter some text."
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.eval()
+
+     encoding = tokenizer.encode_plus(
+         text,
+         add_special_tokens=True,
+         max_length=128,
+         return_token_type_ids=False,
+         padding='max_length',
+         truncation=True,
+         return_attention_mask=True,
+         return_tensors='pt',
+     )
+
+     input_ids = encoding['input_ids'].to(device)
+     attention_mask = encoding['attention_mask'].to(device)
+
+     with torch.no_grad():
+         outputs = model(input_ids, attention_mask=attention_mask)
+         # Softmax turns the two logits into class probabilities; argmax picks
+         # the higher-scoring class, whose probability is reported as the confidence
+         probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
+         prediction = torch.argmax(outputs.logits, dim=1).flatten().item()
+         confidence = probabilities[0][prediction].item()
+
+     gender = "Male" if prediction == 0 else "Female"
+     return f"Predicted gender: {gender} (confidence: {confidence:.2%})"
+
+ # Train the model at app startup
+ print("Initializing app...")
+ train_model()
+
+ # Build the Gradio interface
+ iface = gr.Interface(
+     fn=predict_gender,
+     inputs=gr.Textbox(
+         lines=3,
+         placeholder="Enter the text to predict gender from.\ne.g. '그는 축구를 좋아하고 근육질이다.'",
+         label="Input text"
+     ),
+     outputs=gr.Textbox(label="Prediction result"),
+     title="🤖 AI Gender Predictor",
+     description="Predicts gender based on the input text.",
+     examples=[
+         ["그는 축구를 정말 좋아하고, 근육질의 몸매를 가졌다."],
+         ["그녀는 긴 머리를 가졌고, 분홍색 원피스를 입었다."],
+         ["짧은 머리에 정장을 입은 그는 회의에 참석했다."],
+         ["아름다운 목소리로 노래하는 그녀는 가수다."]
+     ],
+     theme=gr.themes.Soft()
+ )
+
+ # Run the app
+ if __name__ == "__main__":
+     iface.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ transformers
+ gradio
+ pandas
+ scikit-learn