Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -10,7 +10,7 @@ from sklearn.model_selection import train_test_split
|
|
10 |
# ๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ๋ก๋
|
11 |
MODEL_NAME = "beomi/kcbert-base"
|
12 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
13 |
-
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=
|
14 |
|
15 |
# ๋ฐ์ดํฐ์
ํด๋์ค ์ ์
|
16 |
class CustomDataset(Dataset):
|
@@ -57,11 +57,12 @@ def train_model():
|
|
57 |
"๊ทธ๋
๋ ์น๊ตฌ๋ค๊ณผ ์๋ค ๋ ๋ ๊ฒ์ ์ข์ํ๋ค.",์ฌ์
|
58 |
"๊ฐ๋ ฅํ ๋ฆฌ๋์ญ์ผ๋ก ํ์ ์ด๋๋ ๋ชจ์ต์ด ์ธ์์ ์ด์๋ค.",๋จ์
|
59 |
"์์ ์ด ์ง์ ๋ง๋ ์ฟ ํค๋ฅผ ์ฃผ๋ณ์ ๋๋์ด์ฃผ๊ณค ํ๋ค.",์ฌ์
|
60 |
-
"์ ๋ฏผ์ง"
|
61 |
"""
|
62 |
|
63 |
data = pd.read_csv(io.StringIO(csv_data))
|
64 |
-
|
|
|
65 |
train_data, _ = train_test_split(data, test_size=0.2, random_state=42)
|
66 |
|
67 |
train_dataset = CustomDataset(train_data, tokenizer)
|
@@ -117,7 +118,10 @@ def predict_gender(text):
|
|
117 |
prediction = torch.argmax(outputs.logits, dim=1).flatten().item()
|
118 |
confidence = probabilities[0][prediction].item()
|
119 |
|
120 |
-
|
|
|
|
|
|
|
121 |
return f"์์ธก ์ฑ๋ณ: {gender} (์ ๋ขฐ๋: {confidence:.2%})"
|
122 |
|
123 |
# ์ฑ ์์ ์ ๋ชจ๋ธ ํ๋ จ
|
@@ -133,13 +137,15 @@ iface = gr.Interface(
|
|
133 |
label="ํ
์คํธ ์
๋ ฅ"
|
134 |
),
|
135 |
outputs=gr.Textbox(label="์์ธก ๊ฒฐ๊ณผ"),
|
136 |
-
title="๐ค AI ์ฑ๋ณ ์์ธก๊ธฐ",
|
137 |
-
description="์
๋ ฅ๋ ํ
์คํธ๋ฅผ ๋ฐํ์ผ๋ก ์ฑ๋ณ์ ์์ธกํฉ๋๋ค.",
|
138 |
examples=[
|
139 |
["๊ทธ๋ ์ถ๊ตฌ๋ฅผ ์ ๋ง ์ข์ํ๊ณ , ๊ทผ์ก์ง์ ๋ชธ๋งค๋ฅผ ๊ฐ์ก๋ค."],
|
140 |
["๊ทธ๋
๋ ๊ธด ๋จธ๋ฆฌ๋ฅผ ๊ฐ์ก๊ณ , ๋ถํ์ ์ํผ์ค๋ฅผ ์
์๋ค."],
|
141 |
["์งง์ ๋จธ๋ฆฌ์ ์ ์ฅ์ ์
์ ๊ทธ๋ ํ์์ ์ฐธ์ํ๋ค."],
|
142 |
-
["์๋ฆ๋ค์ด ๋ชฉ์๋ฆฌ๋ก ๋
ธ๋ํ๋ ๊ทธ๋
๋ ๊ฐ์๋ค."]
|
|
|
|
|
143 |
],
|
144 |
theme=gr.themes.Soft()
|
145 |
)
|
|
|
10 |
# ๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ๋ก๋
|
11 |
MODEL_NAME = "beomi/kcbert-base"
|
12 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
13 |
+
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3) # 3๊ฐ ํด๋์ค๋ก ๋ณ๊ฒฝ
|
14 |
|
15 |
# ๋ฐ์ดํฐ์
ํด๋์ค ์ ์
|
16 |
class CustomDataset(Dataset):
|
|
|
57 |
"๊ทธ๋
๋ ์น๊ตฌ๋ค๊ณผ ์๋ค ๋ ๋ ๊ฒ์ ์ข์ํ๋ค.",์ฌ์
|
58 |
"๊ฐ๋ ฅํ ๋ฆฌ๋์ญ์ผ๋ก ํ์ ์ด๋๋ ๋ชจ์ต์ด ์ธ์์ ์ด์๋ค.",๋จ์
|
59 |
"์์ ์ด ์ง์ ๋ง๋ ์ฟ ํค๋ฅผ ์ฃผ๋ณ์ ๋๋์ด์ฃผ๊ณค ํ๋ค.",์ฌ์
|
60 |
+
"์ ๋ฏผ์ง",์ฒ์ฌ
|
61 |
"""
|
62 |
|
63 |
data = pd.read_csv(io.StringIO(csv_data))
|
64 |
+
# 3๊ฐ ํด๋์ค๋ก ๋ผ๋ฒจ ๋ณ๊ฒฝ: ๋จ์=0, ์ฌ์=1, ์ฒ์ฌ=2
|
65 |
+
data['label'] = data['gender'].apply(lambda x: 0 if x == '๋จ์' else (1 if x == '์ฌ์' else 2))
|
66 |
train_data, _ = train_test_split(data, test_size=0.2, random_state=42)
|
67 |
|
68 |
train_dataset = CustomDataset(train_data, tokenizer)
|
|
|
118 |
prediction = torch.argmax(outputs.logits, dim=1).flatten().item()
|
119 |
confidence = probabilities[0][prediction].item()
|
120 |
|
121 |
+
# 3๊ฐ ํด๋์ค ๋งคํ: 0=๋จ์, 1=์ฌ์, 2=์ฒ์ฌ
|
122 |
+
gender_map = {0: "๋จ์", 1: "์ฌ์", 2: "์ฒ์ฌ"}
|
123 |
+
gender = gender_map[prediction]
|
124 |
+
|
125 |
return f"์์ธก ์ฑ๋ณ: {gender} (์ ๋ขฐ๋: {confidence:.2%})"
|
126 |
|
127 |
# ์ฑ ์์ ์ ๋ชจ๋ธ ํ๋ จ
|
|
|
137 |
label="ํ
์คํธ ์
๋ ฅ"
|
138 |
),
|
139 |
outputs=gr.Textbox(label="์์ธก ๊ฒฐ๊ณผ"),
|
140 |
+
title="๐ค AI ์ฑ๋ณ ์์ธก๊ธฐ (3๋ถ๋ฅ)",
|
141 |
+
description="์
๋ ฅ๋ ํ
์คํธ๋ฅผ ๋ฐํ์ผ๋ก ์ฑ๋ณ์ ์์ธกํฉ๋๋ค. (๋จ์/์ฌ์/์ฒ์ฌ)",
|
142 |
examples=[
|
143 |
["๊ทธ๋ ์ถ๊ตฌ๋ฅผ ์ ๋ง ์ข์ํ๊ณ , ๊ทผ์ก์ง์ ๋ชธ๋งค๋ฅผ ๊ฐ์ก๋ค."],
|
144 |
["๊ทธ๋
๋ ๊ธด ๋จธ๋ฆฌ๋ฅผ ๊ฐ์ก๊ณ , ๋ถํ์ ์ํผ์ค๋ฅผ ์
์๋ค."],
|
145 |
["์งง์ ๋จธ๋ฆฌ์ ์ ์ฅ์ ์
์ ๊ทธ๋ ํ์์ ์ฐธ์ํ๋ค."],
|
146 |
+
["์๋ฆ๋ค์ด ๋ชฉ์๋ฆฌ๋ก ๋
ธ๋ํ๋ ๊ทธ๋
๋ ๊ฐ์๋ค."],
|
147 |
+
["๊ทธ๋ค์ ์ฑ
์ฝ๊ธฐ๋ฅผ ์ข์ํ๊ณ ์กฐ์ฉํ ์ฑ๊ฒฉ์ด๋ค."],
|
148 |
+
["์๋ฆฌ์ ์ฒญ์๋ฅผ ๋ชจ๋ ์ํ๋ฉฐ ์ง์์ผ์ ๋๋งก์ ํ๋ค."]
|
149 |
],
|
150 |
theme=gr.themes.Soft()
|
151 |
)
|