blockenters committed on
Commit
ee6db9e
Β·
1 Parent(s): 69396f3
Files changed (2) hide show
  1. app.py +75 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+
5
+ # λͺ¨λΈ λ‘œλ“œ (DialoGPT-medium μ˜ˆμ‹œ)
6
+ @st.cache_resource
7
+ def load_model(model_name="microsoft/DialoGPT-medium"):
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModelForCausalLM.from_pretrained(model_name)
10
+ return tokenizer, model
11
+
12
# App entry point.
def main():
    """Render a minimal ChatGPT-style chat demo backed by DialoGPT.

    Conversation state is kept in ``st.session_state``:
      * ``chat_history_ids``    — token ids of the whole dialog so far (or None)
      * ``past_user_inputs``    — user utterances, oldest first
      * ``generated_responses`` — bot replies, parallel to the above
    """
    st.title("ChatGPT μœ μ‚¬ λŒ€ν™” 데λͺ¨")
    st.write("μ—¬κΈ°λŠ” DialoGPT λͺ¨λΈμ„ ν™œμš©ν•œ κ°„λ‹¨ν•œ λŒ€ν™” ν…ŒμŠ€νŠΈμš© 데λͺ¨μž…λ‹ˆλ‹€.")

    # Initialize session state on first run.
    if "chat_history_ids" not in st.session_state:
        st.session_state["chat_history_ids"] = None
    if "past_user_inputs" not in st.session_state:
        st.session_state["past_user_inputs"] = []
    if "generated_responses" not in st.session_state:
        st.session_state["generated_responses"] = []

    tokenizer, model = load_model()

    # Replay the conversation so far (zip over empty lists is a no-op).
    for user_text, bot_text in zip(
        st.session_state["past_user_inputs"],
        st.session_state["generated_responses"],
    ):
        with st.chat_message("user"):
            st.write(user_text)
        with st.chat_message("assistant"):
            st.write(bot_text)

    # Chat input box.
    user_input = st.chat_input("λ©”μ‹œμ§€λ₯Ό μž…λ ₯ν•˜μ„Έμš”...")

    if user_input:
        # Echo the user's message.
        with st.chat_message("user"):
            st.write(user_input)

        # Tokenize the new utterance, EOS-terminated as DialoGPT expects.
        new_user_input_ids = tokenizer.encode(
            user_input + tokenizer.eos_token, return_tensors="pt"
        )

        if st.session_state["chat_history_ids"] is not None:
            # Append to the running dialog history.
            bot_input_ids = torch.cat(
                [st.session_state["chat_history_ids"], new_user_input_ids], dim=-1
            )
        else:
            bot_input_ids = new_user_input_ids

        # Model inference.
        # BUGFIX: the original passed max_length=1000, which counts the prompt
        # as well as the reply — once the accumulated history neared 1000
        # tokens, generate() silently produced no new tokens and the chat
        # died. max_new_tokens bounds only the reply. An explicit attention
        # mask is also supplied because pad_token_id == eos_token_id here,
        # which otherwise makes padding ambiguous.
        with torch.no_grad():
            chat_history_ids = model.generate(
                bot_input_ids,
                attention_mask=torch.ones_like(bot_input_ids),
                max_new_tokens=256,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated portion (tokens after the prompt).
        bot_text = tokenizer.decode(
            chat_history_ids[:, bot_input_ids.shape[-1]:][0],
            skip_special_tokens=True,
        )

        # Persist this turn in session state.
        st.session_state["past_user_inputs"].append(user_input)
        st.session_state["generated_responses"].append(bot_text)
        st.session_state["chat_history_ids"] = chat_history_ids

        # Show the bot's reply.
        with st.chat_message("assistant"):
            st.write(bot_text)


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ streamlit
2
+ torch
3
+ transformers