import streamlit as st
import torch
from transformers import pipeline

# λͺ¨λΈ λ‘œλ“œ (DeepSeek-R1-Distill-Qwen-1.5B μ˜ˆμ‹œ)
@st.cache_resource
def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
    pipe = pipeline(
        "text-generation",
        model=model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
        truncation=True,
        max_new_tokens=2048
    )
    return pipe

# Main application
def main():
    st.set_page_config(page_title="DeepSeek-R1 Chatbot", page_icon="🤖")
    st.title("DeepSeek-R1 Conversational Chatbot")
    st.write("A demo for testing conversations with the DeepSeek-R1-Distill-Qwen-1.5B model.")
    
    # Initialize session state for the conversation history
    if "past_user_inputs" not in st.session_state:
        st.session_state["past_user_inputs"] = []
    if "generated_responses" not in st.session_state:
        st.session_state["generated_responses"] = []
    
    # λͺ¨λΈκ³Ό ν† ν¬λ‚˜μ΄μ € 뢈러였기
    pipe = load_model()
    
    # Display the conversation so far
    for user_text, bot_text in zip(st.session_state["past_user_inputs"], st.session_state["generated_responses"]):
        # User message
        with st.chat_message("user"):
            st.write(user_text)
        # Bot message
        with st.chat_message("assistant"):
            st.write(bot_text)
    
    # Chat input box
    user_input = st.chat_input("Type your message in English...")
    
    if user_input:
        # Display the user's message
        with st.chat_message("user"):
            st.write(user_input)

        # Build the prompt
        prompt = f"Human: {user_input}\n\nAssistant:"
        
        # λͺ¨λΈ 생성
        response = pipe(
            prompt,
            max_new_tokens=2048,
            temperature=0.7,
            do_sample=True,
            truncation=True,
            pad_token_id=50256
        )
        bot_text = response[0]["generated_text"]
        
        # Extract only the Assistant reply
        try:
            bot_text = bot_text.split("Assistant:")[-1].strip()
            if "</think>" in bot_text:  # drop the model's internal reasoning trace
                bot_text = bot_text.split("</think>")[-1].strip()
        except Exception:
            bot_text = "Sorry, something went wrong while generating the response."
        
        # Append the exchange to the session state
        st.session_state["past_user_inputs"].append(user_input)
        st.session_state["generated_responses"].append(bot_text)
        
        # Display the bot's message
        with st.chat_message("assistant"):
            st.write(bot_text)

if __name__ == "__main__":
    main()
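
# Usage note (a minimal sketch, assuming this file is saved as app.py;
# the package names below are the standard PyPI ones, and accelerate is
# required by device_map="auto"):
#
#   pip install streamlit torch transformers accelerate
#   streamlit run app.py
#
# In float16 the 1.5B model needs roughly 3 GB of GPU memory; without a
# GPU it is placed on CPU and runs noticeably slower.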