import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the model and tokenizer
@st.cache_resource
def load_model():
    """Download (or reuse) the Falcon checkpoint and its tokenizer.

    Cached with ``st.cache_resource`` so the multi-GB model is loaded
    exactly once per server process, not on every Streamlit rerun.

    Returns:
        tuple: ``(model, tokenizer)`` ready for generation.
    """
    name = "tiiuae/falcon-7b-instruct"  # Replace with the desired Falcon model
    tok = AutoTokenizer.from_pretrained(name)
    lm = AutoModelForCausalLM.from_pretrained(
        name,
        # Let accelerate place layers on whatever GPUs/CPU are available.
        device_map="auto",
        # Half precision halves memory use and speeds up inference.
        torch_dtype=torch.float16,
    )
    return lm, tok

model, tokenizer = load_model()

# The chat transcript lives in session state so it survives Streamlit's
# top-to-bottom reruns; seed it with an empty list on first visit.
st.session_state.setdefault("messages", [])

# Sidebar configuration
st.sidebar.title("Chatbot Settings")
st.sidebar.write("Customize your chatbot:")
max_length = st.sidebar.slider("Max Response Length (Tokens)", 50, 500, 150)
temperature = st.sidebar.slider("Response Creativity (Temperature)", 0.1, 1.0, 0.7)

# App title
st.title("🤖 Falcon Chatbot")

# Chat interface
st.write("### Chat with the bot:")
user_input = st.text_input("You:", key="user_input", placeholder="Type your message here...")

if user_input:
    # Record the user's turn before building the prompt.
    st.session_state["messages"].append(f"User: {user_input}")

    # The prompt is the whole transcript plus a cue for the assistant's turn.
    prompt = "\n".join(st.session_state["messages"]) + "\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)

    # Generate response
    with st.spinner("Thinking..."):
        with torch.inference_mode():  # no gradients needed for generation
            output = model.generate(
                # Pass the attention mask too, not just input_ids, so the
                # model knows which positions are real tokens.
                **inputs,
                # max_new_tokens counts only generated tokens; the original
                # max_length included the (growing) prompt, so once the chat
                # history exceeded the slider value nothing was generated.
                max_new_tokens=max_length,
                # temperature only takes effect when sampling is enabled;
                # without do_sample=True, generate() decodes greedily and
                # silently ignores the temperature slider.
                do_sample=True,
                temperature=temperature,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens. Splitting the full decoded
        # text on "Assistant:" breaks once earlier turns contain that marker.
        new_tokens = output[0][inputs.input_ids.shape[-1]:]
        bot_response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Add bot response to chat history
    st.session_state["messages"].append(f"Assistant: {bot_response}")

    # Render the transcript. NOTE(review): this loop only runs when there is
    # input, so history vanishes on reruns without a message — consider
    # hoisting it out of the `if user_input:` branch.
    for msg in st.session_state["messages"]:
        if msg.startswith("User:"):
            st.markdown(f"**{msg}**")
        elif msg.startswith("Assistant:"):
            st.markdown(f"> {msg}")

# Clear chat history button: wipe the transcript and force a fresh rerun.
if st.button("Clear Chat"):
    st.session_state["messages"] = []
    # st.experimental_rerun() was deprecated and has been removed from recent
    # Streamlit releases; st.rerun() is the supported replacement.
    st.rerun()