import streamlit as st
import torch
from transformers import pipeline
# Load the model (DeepSeek-R1-Distill-Qwen-1.5B as an example)
@st.cache_resource
def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
    pipe = pipeline(
        "text-generation",
        model=model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
        truncation=True,
        max_new_tokens=2048,
    )
    return pipe
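
# Note: @st.cache_resource keeps one pipeline instance per server process, so
# the model weights are loaded on the first run and then reused across
# Streamlit reruns and user sessions.
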
# Main app function
def main():
    st.set_page_config(page_title="DeepSeek-R1 Chatbot", page_icon="🤖")
    st.title("DeepSeek-R1 Conversational Chatbot")
    st.write("A demo for testing conversation with the DeepSeek-R1-Distill-Qwen-1.5B model.")

    # Initialize session state
    if "chat_history_ids" not in st.session_state:
        st.session_state["chat_history_ids"] = None
    if "past_user_inputs" not in st.session_state:
        st.session_state["past_user_inputs"] = []
    if "generated_responses" not in st.session_state:
        st.session_state["generated_responses"] = []

    # Load the model and tokenizer
    pipe = load_model()

    # Display the existing conversation history
    if st.session_state["past_user_inputs"]:
        for user_text, bot_text in zip(st.session_state["past_user_inputs"], st.session_state["generated_responses"]):
            # User message
            with st.chat_message("user"):
                st.write(user_text)
            # Bot message
            with st.chat_message("assistant"):
                st.write(bot_text)

    # Chat input box
    user_input = st.chat_input("Enter a message in English...")
    if user_input:
        # Display the user message
        with st.chat_message("user"):
            st.write(user_input)

        # Build the prompt
        prompt = f"Human: {user_input}\n\nAssistant:"

        # Generate with the model
        response = pipe(
            prompt,
            max_new_tokens=2048,
            temperature=0.7,
            do_sample=True,
            truncation=True,
            # The original hard-coded pad_token_id=50256 (GPT-2's EOS), which
            # does not match Qwen-family tokenizers; use the pipeline's own EOS id.
            pad_token_id=pipe.tokenizer.eos_token_id,
        )
        bot_text = response[0]["generated_text"]

        # Extract only the Assistant reply (improved approach)
        try:
            bot_text = bot_text.split("Assistant:")[-1].strip()
            if "</think>" in bot_text:  # strip the internal reasoning trace
                bot_text = bot_text.split("</think>")[-1].strip()
        except Exception:
            bot_text = "Sorry, something went wrong while generating the response."

        # Update the conversation history in session state
        st.session_state["past_user_inputs"].append(user_input)
        st.session_state["generated_responses"].append(bot_text)

        # Display the bot message
        with st.chat_message("assistant"):
            st.write(bot_text)
if __name__ == "__main__":
    main()
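
# To run the demo locally (assuming this file is saved as app.py):
#   streamlit run app.py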