# Shahabmoin's picture
# Update app.py
# 55ed7de verified
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Load the model and tokenizer
@st.cache_resource
def load_model():
    """Load the Falcon instruct checkpoint and its tokenizer.

    Decorated with ``st.cache_resource`` so the returned objects are
    cached by Streamlit instead of being rebuilt on each call.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Swap this identifier to use a different Falcon checkpoint.
    checkpoint = "tiiuae/falcon-7b-instruct"

    load_options = {
        # Let the loader place layers on whatever GPUs/CPUs are available.
        "device_map": "auto",
        # Half precision for faster inference.
        "torch_dtype": torch.float16,
    }

    tok = AutoTokenizer.from_pretrained(checkpoint)
    lm = AutoModelForCausalLM.from_pretrained(checkpoint, **load_options)
    return lm, tok
model, tokenizer = load_model()

# Initialize chat history. Each entry is a string prefixed with either
# "User: " or "Assistant: "; the prefixes double as the prompt format.
if "messages" not in st.session_state:
    st.session_state["messages"] = []


def _queue_user_message():
    """Hand the submitted text to the script exactly once and clear the box.

    Runs as the text input's ``on_change`` callback. Without this, the
    widget keeps its value across reruns, so every rerun (moving a slider,
    pressing "Clear Chat") would re-append and re-generate the same message.
    """
    st.session_state["pending_input"] = st.session_state["user_input"].strip()
    st.session_state["user_input"] = ""


def _rerun():
    """Restart the script run, preferring ``st.rerun`` when available.

    ``st.experimental_rerun`` is deprecated and absent from newer Streamlit
    releases; it is kept only as a fallback for older installations.
    """
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


# Sidebar configuration
st.sidebar.title("Chatbot Settings")
st.sidebar.write("Customize your chatbot:")
max_length = st.sidebar.slider("Max Response Length (Tokens)", 50, 500, 150)
temperature = st.sidebar.slider("Response Creativity (Temperature)", 0.1, 1.0, 0.7)

# App title
st.title("🤖 Falcon Chatbot")

# Chat interface
st.write("### Chat with the bot:")
st.text_input(
    "You:",
    key="user_input",
    placeholder="Type your message here...",
    on_change=_queue_user_message,
)

# Consume the one-shot message (if any) so later reruns do not resend it.
user_input = st.session_state.get("pending_input", "")
st.session_state["pending_input"] = ""

if user_input:
    # Add user input to chat history
    st.session_state["messages"].append(f"User: {user_input}")

    # Prepare input for the model: the whole history followed by an open
    # "Assistant:" turn for the model to complete.
    # NOTE(review): truncation=True trims from the tokenizer's configured
    # side; with very long histories the trailing "Assistant:" cue may be
    # cut off - confirm the tokenizer's truncation side.
    prompt = "\n".join(st.session_state["messages"]) + "\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    prompt_token_count = inputs["input_ids"].shape[1]

    # Generate response
    with st.spinner("Thinking..."):
        output = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            # Budget applies to the reply only; max_length would also count
            # the prompt and leave no room once the history grows.
            max_new_tokens=max_length,
            # Sampling must be enabled for temperature to have any effect.
            do_sample=True,
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens rather than searching the
    # echoed prompt for the last "Assistant:" marker.
    reply_ids = output[0][prompt_token_count:]
    bot_response = tokenizer.decode(reply_ids, skip_special_tokens=True)
    # Drop any follow-up user turn the model invented.
    bot_response = bot_response.split("\nUser:")[0].strip()

    # Add bot response to chat history
    st.session_state["messages"].append(f"Assistant: {bot_response}")

# Display chat history
for msg in st.session_state["messages"]:
    if msg.startswith("User:"):
        st.markdown(f"**{msg}**")
    elif msg.startswith("Assistant:"):
        st.markdown(f"> {msg}")

# Clear chat history button
if st.button("Clear Chat"):
    st.session_state["messages"] = []
    _rerun()