YEINJEONG commited on
Commit
224d6a6
·
verified ·
1 Parent(s): f7e4a64

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +42 -0
  2. sentiment7_model_acc0.9653.pth +3 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import BertConfig, BertForSequenceClassification, BertTokenizer
4
+ import numpy as np
5
+
6
+
7
+ # Load the model and tokenizer
8
+ def load_model():
9
+ tokenizer = BertTokenizer.from_pretrained('beomi/kcbert-base')
10
+ config = BertConfig.from_pretrained('beomi/kcbert-base', num_labels=7)
11
+ model = BertForSequenceClassification.from_pretrained('beomi/kcbert-base', config=config)
12
+ model_state_dict = torch.load('sentiment7_model_acc8878.pth', map_location=torch.device('cpu')) # cpu 사용
13
+ model.load_state_dict(model_state_dict)
14
+ model.eval()
15
+ return model, tokenizer
16
+
17
+ model, tokenizer = load_model()
18
+
19
+ # Define the inference function
20
+ def inference(input_doc):
21
+ inputs = tokenizer(input_doc, return_tensors='pt')
22
+ outputs = model(**inputs)
23
+ probs = torch.softmax(outputs.logits, dim=1).squeeze().tolist()
24
+ class_idx = {'공포': 0, '놀람': 1, '분노': 2, '슬픔': 3, '중립': 4, '행복': 5, '혐오': 6}
25
+ results = {class_name: prob for class_name, prob in zip(class_idx, probs)}
26
+ # Find the class with the highest probability
27
+ max_prob_class = max(results, key=results.get)
28
+ max_prob = results[max_prob_class]
29
+ # Display results
30
+ return [results, f"가장 강하게 나타난 감정: {max_prob_class}"]
31
+ ''' for class_name, prob in results.items():
32
+ print(f"{class_name}: {prob:.2%}")'''
33
+
34
+ # Set up the Streamlit interface
35
+ st.title('감정분석(Sentiment Analysis)')
36
+ st.markdown('<small style="color:grey;">글에 나타난 공포, 놀람, 분노, 슬픔, 중립, 행복, 혐오의 정도를 비율로 알려드립니다.</small>', unsafe_allow_html=True)
37
+ user_input = st.text_area("이 곳에 글 입력(100자 이하 권장):")
38
+ if st.button('시작'):
39
+ result = inference(user_input)
40
+ st.write(result[0])
41
+ st.write(result[1])
42
+
sentiment7_model_acc0.9653.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a6e54803c4c1d0eae2c24cc17cf411f68f8e5e37af89dce068b3140abc7761d
3
+ size 435781207