ojs595 committed on
Commit
9d26c64
·
verified ·
1 Parent(s): 09f8eff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -10,7 +10,7 @@ from sklearn.model_selection import train_test_split
10
  # 모델과 토크나이저 로드
11
  MODEL_NAME = "beomi/kcbert-base"
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
14
 
15
  # 데이터셋 클래스 정의
16
  class CustomDataset(Dataset):
@@ -57,11 +57,12 @@ def train_model():
57
  "그녀는 친구들과 수다 떠는 것을 좋아한다.",여자
58
  "강력한 리더십으로 팀을 이끄는 모습이 인상적이었다.",남자
59
  "자신이 직접 만든 쿠키를 주변에 나누어주곤 한다.",여자
60
- "정민지",천사사
61
  """
62
 
63
  data = pd.read_csv(io.StringIO(csv_data))
64
- data['label'] = data['gender'].apply(lambda x: 0 if x == '남자' else 1)
 
65
  train_data, _ = train_test_split(data, test_size=0.2, random_state=42)
66
 
67
  train_dataset = CustomDataset(train_data, tokenizer)
@@ -117,7 +118,10 @@ def predict_gender(text):
117
  prediction = torch.argmax(outputs.logits, dim=1).flatten().item()
118
  confidence = probabilities[0][prediction].item()
119
 
120
- gender = "남자" if prediction == 0 else "여자"
 
 
 
121
  return f"예측 성별: {gender} (신뢰도: {confidence:.2%})"
122
 
123
  # 앱 시작 시 모델 훈련
@@ -133,13 +137,15 @@ iface = gr.Interface(
133
  label="텍스트 입력"
134
  ),
135
  outputs=gr.Textbox(label="예측 결과"),
136
- title="🤖 AI 성별 예측기",
137
- description="입력된 텍스트를 바탕으로 성별을 예측합니다.",
138
  examples=[
139
  ["그는 축구를 정말 좋아하고, 근육질의 몸매를 가졌다."],
140
  ["그녀는 긴 머리를 가졌고, 분홍색 원피스를 입었다."],
141
  ["짧은 머리에 정장을 입은 그는 회의에 참석했다."],
142
- ["아름다운 목소리로 노래하는 그녀는 가수다."]
 
 
143
  ],
144
  theme=gr.themes.Soft()
145
  )
 
10
  # 모델과 토크나이저 로드
11
  MODEL_NAME = "beomi/kcbert-base"
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3) # 3개 클래스로 변경
14
 
15
  # 데이터셋 클래스 정의
16
  class CustomDataset(Dataset):
 
57
  "그녀는 친구들과 수다 떠는 것을 좋아한다.",여자
58
  "강력한 리더십으로 팀을 이끄는 모습이 인상적이었다.",남자
59
  "자신이 직접 만든 쿠키를 주변에 나누어주곤 한다.",여자
60
+ "정민지",천사
61
  """
62
 
63
  data = pd.read_csv(io.StringIO(csv_data))
64
+ # 3개 클래스로 라벨 변경: 남자=0, 여자=1, 천사=2
65
+ data['label'] = data['gender'].apply(lambda x: 0 if x == '남자' else (1 if x == '여자' else 2))
66
  train_data, _ = train_test_split(data, test_size=0.2, random_state=42)
67
 
68
  train_dataset = CustomDataset(train_data, tokenizer)
 
118
  prediction = torch.argmax(outputs.logits, dim=1).flatten().item()
119
  confidence = probabilities[0][prediction].item()
120
 
121
+ # 3개 클래스 매핑: 0=남자, 1=여자, 2=천사
122
+ gender_map = {0: "남자", 1: "여자", 2: "천사"}
123
+ gender = gender_map[prediction]
124
+
125
  return f"예측 성별: {gender} (신뢰도: {confidence:.2%})"
126
 
127
  # 앱 시작 시 모델 훈련
 
137
  label="텍스트 입력"
138
  ),
139
  outputs=gr.Textbox(label="예측 결과"),
140
+ title="🤖 AI 성별 예측기 (3분류)",
141
+ description="입력된 텍스트를 바탕으로 성별을 예측합니다. (남자/여자/천사)",
142
  examples=[
143
  ["그는 축구를 정말 좋아하고, 근육질의 몸매를 가졌다."],
144
  ["그녀는 긴 머리를 가졌고, 분홍색 원피스를 입었다."],
145
  ["짧은 머리에 정장을 입은 그는 회의에 참석했다."],
146
+ ["아름다운 목소리로 노래하는 그녀는 가수다."],
147
+ ["그들은 책 읽기를 좋아하고 조용한 성격이다."],
148
+ ["요리와 청소를 모두 잘하며 집안일을 도맡아 한다."]
149
  ],
150
  theme=gr.themes.Soft()
151
  )