# Streamlit demo app hosted as a Hugging Face Space
# (page header residue "Spaces: Sleeping" removed from the extracted source).
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Model loader (example model: DeepSeek-R1-Distill-Qwen-1.5B)
def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
    """Build a Hugging Face text-generation pipeline for *model_name*.

    Args:
        model_name: Hub id of a causal-LM checkpoint. Defaults to the
            DeepSeek-R1 1.5B distilled Qwen model.

    Returns:
        transformers.Pipeline: a "text-generation" pipeline placed on the
        best available device (``device_map="auto"``) in float16, with
        truncation enabled and ``max_new_tokens=2048``.

    NOTE(review): Streamlit re-executes the script on every interaction, so
    calling this on each rerun reloads the whole model. Consider decorating
    with ``@st.cache_resource`` (the caller currently works around this by
    keeping the returned pipeline in ``st.session_state``).
    """
    pipe = pipeline(
        "text-generation",
        model=model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,  # required by some DeepSeek checkpoints
        truncation=True,
        max_new_tokens=2048,
    )
    return pipe
def _extract_reply(generated_text):
    """Return only the assistant's reply from the raw pipeline output.

    The pipeline echoes the prompt, so keep the text after the last
    ``"Assistant:"`` marker, then drop any DeepSeek-R1 ``</think>``-terminated
    internal-reasoning block that precedes the final answer.
    """
    reply = generated_text.split("Assistant:")[-1].strip()
    if "</think>" in reply:  # strip the model's internal reasoning
        reply = reply.split("</think>")[-1].strip()
    return reply


# App entry point
def main():
    """Render the Streamlit chat UI and answer each message with the model."""
    st.set_page_config(page_title="DeepSeek-R1 Chatbot", page_icon="🤖")
    st.title("DeepSeek-R1 기반 대화형 챗봇")
    st.write("DeepSeek-R1-Distill-Qwen-1.5B 모델을 사용한 대화 테스트용 데모입니다.")

    # Initialize session state so the conversation survives reruns.
    # NOTE(review): "chat_history_ids" is set but never used below — kept
    # only to preserve the original module's observable state.
    if "chat_history_ids" not in st.session_state:
        st.session_state["chat_history_ids"] = None
    if "past_user_inputs" not in st.session_state:
        st.session_state["past_user_inputs"] = []
    if "generated_responses" not in st.session_state:
        st.session_state["generated_responses"] = []

    # Load the pipeline once per session: Streamlit re-executes the whole
    # script on every interaction, and reloading a 1.5B-parameter model on
    # each rerun is prohibitively slow.
    if "pipe" not in st.session_state:
        st.session_state["pipe"] = load_model()
    pipe = st.session_state["pipe"]

    # Replay the conversation so far (zip over empty lists is a no-op,
    # so no emptiness guard is needed).
    for user_text, bot_text in zip(
        st.session_state["past_user_inputs"],
        st.session_state["generated_responses"],
    ):
        with st.chat_message("user"):
            st.write(user_text)
        with st.chat_message("assistant"):
            st.write(bot_text)

    # Chat input box; None/empty means no new message this rerun.
    user_input = st.chat_input("영어로 메시지를 입력하세요...")
    if not user_input:
        return

    with st.chat_message("user"):
        st.write(user_input)

    # Simple Human/Assistant prompt format.
    prompt = f"Human: {user_input}\n\nAssistant:"
    response = pipe(
        prompt,
        max_new_tokens=2048,
        temperature=0.7,
        do_sample=True,
        truncation=True,
        # NOTE(review): 50256 is GPT-2's EOS/pad id — confirm it matches
        # this model's tokenizer (Qwen-based tokenizers use different ids).
        pad_token_id=50256,
    )

    # Best-effort extraction; narrowed from a bare `except:` (which also
    # swallowed SystemExit/KeyboardInterrupt) to `except Exception`.
    try:
        bot_text = _extract_reply(response[0]["generated_text"])
    except Exception:
        bot_text = "죄송합니다. 응답을 생성하는 데 문제가 발생했습니다."

    # Persist the turn, then show the assistant's reply.
    st.session_state["past_user_inputs"].append(user_input)
    st.session_state["generated_responses"].append(bot_text)
    with st.chat_message("assistant"):
        st.write(bot_text)
# Script entry point.
if __name__ == "__main__":
    main()