chat / app.py
blockenters's picture
add
1ac6501
raw
history blame
2.96 kB
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# λͺ¨λΈ λ‘œλ“œ (DeepSeek-R1-Distill-Qwen-1.5B μ˜ˆμ‹œ)
@st.cache_resource
def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
pipe = pipeline(
"text-generation",
model=model_name,
device_map="auto",
torch_dtype=torch.float16,
trust_remote_code=True,
truncation=True,
max_new_tokens=2048
)
return pipe
# App entry point
def main():
    """Run the Streamlit chat UI for the DeepSeek-R1 distilled model.

    Side effects: configures the page, replays the chat history kept in
    ``st.session_state``, and for each new user message runs the cached
    pipeline and appends the exchange to session state.
    """
    # NOTE(review): the original strings were UTF-8 mojibake; restored to
    # the intended Korean text / robot emoji.
    st.set_page_config(page_title="DeepSeek-R1 Chatbot", page_icon="🤖")
    st.title("DeepSeek-R1 기반 대화형 챗봇")
    st.write("DeepSeek-R1-Distill-Qwen-1.5B 모델을 사용한 대화 테스트용 데모입니다.")

    # Initialize session state (history must survive Streamlit reruns).
    # The original also created a "chat_history_ids" key that was never
    # read anywhere (a DialoGPT-style leftover); dropped as dead state.
    if "past_user_inputs" not in st.session_state:
        st.session_state["past_user_inputs"] = []
    if "generated_responses" not in st.session_state:
        st.session_state["generated_responses"] = []

    # Load (or fetch the cached) model pipeline.
    pipe = load_model()

    # Replay the conversation so far.
    for user_text, bot_text in zip(
        st.session_state["past_user_inputs"],
        st.session_state["generated_responses"],
    ):
        with st.chat_message("user"):
            st.write(user_text)
        with st.chat_message("assistant"):
            st.write(bot_text)

    # Chat input box.
    user_input = st.chat_input("영어로 메시지를 입력하세요...")
    if user_input:
        # Echo the user's message immediately.
        with st.chat_message("user"):
            st.write(user_input)

        # Build the prompt and generate.
        prompt = f"Human: {user_input}\n\nAssistant:"
        response = pipe(
            prompt,
            max_new_tokens=2048,
            temperature=0.7,
            do_sample=True,
            truncation=True,
            # BUG FIX: the original hard-coded pad_token_id=50256, which is
            # GPT-2's eos id — wrong for a Qwen-based tokenizer. Use the
            # loaded tokenizer's own eos token instead.
            pad_token_id=pipe.tokenizer.eos_token_id,
        )
        raw_text = response[0]["generated_text"]
        try:
            bot_text = _extract_reply(raw_text)
        except Exception:  # narrowed from the original bare `except:`
            bot_text = "죄송합니다. 응답을 생성하는 데 문제가 발생했습니다."

        # Persist and display the new exchange.
        st.session_state["past_user_inputs"].append(user_input)
        st.session_state["generated_responses"].append(bot_text)
        with st.chat_message("assistant"):
            st.write(bot_text)


def _extract_reply(generated_text):
    """Return only the assistant's reply from the raw generated text.

    Strips the echoed prompt (everything up to the last "Assistant:") and
    any internal reasoning the R1-style models emit before "</think>".
    """
    reply = generated_text.split("Assistant:")[-1].strip()
    if "</think>" in reply:  # drop the model's chain-of-thought section
        reply = reply.split("</think>")[-1].strip()
    return reply
# Standard script entry point: run the Streamlit app when executed directly.
if __name__ == "__main__":
    main()