# Quasar / app.py
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError
import logging
# Page config - this must be the first Streamlit command
st.set_page_config(page_title="Chat with Quasar-32B", layout="wide")
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Display installation instructions if needed
st.sidebar.write("### Required Packages")
st.sidebar.code("""
pip install transformers torch streamlit
""")
@st.cache_resource
def load_model():
"""Load model and tokenizer with caching"""
try:
st.spinner("Loading model... This may take a few minutes")
logger.info("Starting model loading...")
# Basic model loading without device map
model = AutoModelForCausalLM.from_pretrained(
"NousResearch/Llama-3.2-1B",
torch_dtype=torch.float32 # Use float32 for CPU
)
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
# Set up padding token
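        # (Llama tokenizers ship without a dedicated pad token, so the EOS token is reused for padding.)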
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
logger.info("Model loaded successfully")
return model, tokenizer
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
st.error(f"Error loading model: {str(e)}")
return None, None
def check_for_repetition(text, threshold=3):
"""Check if the generated text has too many repetitions"""
words = text.split()
if len(words) < threshold:
return False
# Check for repeated phrases
    for i in range(len(words) - threshold + 1):  # +1 so the final window is also checked
phrase = ' '.join(words[i:i+threshold])
if text.count(phrase) > 2: # If phrase appears more than twice
return True
return False
def generate_response_with_timeout(model, tokenizer, prompt, temperature=0.8, max_new_tokens=100):
    """Generate a response with repetition checking; the timeout itself is
    enforced by the caller via ThreadPoolExecutor and future.result()."""
try:
# Prepare the input
inputs = tokenizer(
prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=256 # Reduced for CPU
)
start_time = time.time()
        # Generate the response; max_new_tokens/min_new_tokens count only newly
        # generated tokens, so a long prompt cannot exhaust the length budget
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                max_new_tokens=max_new_tokens,  # From the sidebar "Max Length" slider
                min_new_tokens=20,  # Ensure some minimum content
                num_return_sequences=1,
                temperature=temperature,  # From the sidebar "Temperature" slider
                pad_token_id=tokenizer.pad_token_id,
                attention_mask=inputs["attention_mask"],
                do_sample=True,
                top_p=0.92,
                top_k=40,
                repetition_penalty=1.5,  # Penalize repeated tokens
                no_repeat_ngram_size=3  # Prevent 3-gram repetitions
            )
generation_time = time.time() - start_time
logger.info(f"Response generated in {generation_time:.2f} seconds")
        # Decode only the newly generated tokens so the echoed prompt is excluded
        gen_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
        # If the output is too repetitive, return a fallback message instead
        if check_for_repetition(response):
            logger.warning("Detected repetition, returning fallback message")
            return "I apologize, but I'm having trouble generating a coherent response. Could you try rephrasing your question?"
return response
except Exception as e:
logger.error(f"Error in generation: {str(e)}")
return f"Error generating response: {str(e)}"
# Add debug information in sidebar
with st.sidebar:
st.write("### System Information")
st.write("Model: Quasar-32B")
# Device and memory information
device = "GPU" if torch.cuda.is_available() else "CPU"
st.write(f"Running on: {device}")
# Warning for CPU usage
if not torch.cuda.is_available():
st.warning("⚠️ Running on CPU - Responses may be very slow. Consider using a GPU or a smaller model.")
# Model settings
st.write("### Model Settings")
if 'temperature' not in st.session_state:
st.session_state.temperature = 0.8
if 'max_length' not in st.session_state:
st.session_state.max_length = 100
st.session_state.temperature = st.slider("Temperature", 0.1, 1.0, st.session_state.temperature)
st.session_state.max_length = st.slider("Max Length", 50, 200, st.session_state.max_length)
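    # These slider values are passed to generate_response_with_timeout() in the chat handler below.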
st.title("Chat with Quasar-32B")
# Initialize session state for chat history
if 'messages' not in st.session_state:
st.session_state.messages = []
# Load model and tokenizer
model, tokenizer = load_model()
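# load_model() returns (None, None) on failure; the chat handler below checks for this before generating.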
# Chat interface
st.write("### Chat")
chat_container = st.container()
# Display chat history
with chat_container:
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.write(message["content"])
# User input
if prompt := st.chat_input("Type your message here"):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message
with chat_container:
with st.chat_message("user"):
st.write(prompt)
# Generate and display assistant response
if model and tokenizer:
with st.chat_message("assistant"):
try:
with st.spinner("Generating response... (timeout: 30s)"):
with ThreadPoolExecutor() as executor:
                        future = executor.submit(
                            generate_response_with_timeout,
                            model,
                            tokenizer,
                            prompt,
                            temperature=st.session_state.temperature,
                            max_new_tokens=st.session_state.max_length
                        )
                        response = future.result(timeout=30)  # Match the 30s promised in the spinner text
st.write(response)
st.session_state.messages.append({"role": "assistant", "content": response})
except TimeoutError:
error_msg = "Response generation timed out. The model might be overloaded."
st.error(error_msg)
logger.error(error_msg)
except Exception as e:
error_msg = f"Error generating response: {str(e)}"
st.error(error_msg)
logger.error(error_msg)
else:
st.error("Model failed to load. Please check your configuration.")
# Add a button to clear chat history
if st.button("Clear Chat History"):
st.session_state.messages = []
    st.rerun()  # st.experimental_rerun() is deprecated in current Streamlit releases