import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError
import logging
# Page config - this must be the first Streamlit command
st.set_page_config(page_title="Chat with Quasar-32B", layout="wide")
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Display installation instructions if needed
st.sidebar.write("### Required Packages")
st.sidebar.code("""
pip install transformers torch streamlit
""")
@st.cache_resource
def load_model():
    """Load model and tokenizer with caching"""
    try:
        logger.info("Starting model loading...")
        with st.spinner("Loading model... This may take a few minutes"):
            # Basic model loading without device map
            model = AutoModelForCausalLM.from_pretrained(
                "NousResearch/Llama-3.2-1B",
                torch_dtype=torch.float32  # Use float32 for CPU
            )
            tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
        # Set up padding token
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = model.config.eos_token_id
        logger.info("Model loaded successfully")
        return model, tokenizer
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        st.error(f"Error loading model: {str(e)}")
        return None, None
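# Optional sketch (an assumption, not part of the app's current flow): if a CUDA GPU is
# available, the same checkpoint could instead be loaded in half precision and moved to
# the GPU, which the sidebar warning below alludes to. Untested here, e.g.:
#   model = AutoModelForCausalLM.from_pretrained(
#       "NousResearch/Llama-3.2-1B", torch_dtype=torch.float16
#   ).to("cuda")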
def check_for_repetition(text, threshold=3):
    """Check if the generated text has too many repetitions"""
    words = text.split()
    if len(words) < threshold:
        return False
    # Check for repeated phrases
    for i in range(len(words) - threshold):
        phrase = ' '.join(words[i:i+threshold])
        if text.count(phrase) > 2:  # If phrase appears more than twice
            return True
    return False
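# Illustrative examples of the check above (not called by the app itself):
#   check_for_repetition("the cat sat the cat sat the cat sat")  # -> True, because the
#   3-word phrase "the cat sat" occurs more than twice in the text.
#   check_for_repetition("hello there, how are you today?")      # -> False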
def generate_response_with_timeout(model, tokenizer, prompt, timeout_seconds=30):
    """Generate a response with repetition checking; the caller enforces the timeout"""
    try:
        # Prepare the input
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256  # Reduced for CPU
        )
        start_time = time.time()
        # Generate response with stricter parameters
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                max_new_tokens=100,  # Shorter responses
                min_new_tokens=20,  # Ensure some minimum content
                num_return_sequences=1,
                temperature=0.8,  # Slightly higher temperature
                pad_token_id=tokenizer.pad_token_id,
                attention_mask=inputs["attention_mask"],
                do_sample=True,
                top_p=0.92,
                top_k=40,
                repetition_penalty=1.5,  # Increased repetition penalty
                no_repeat_ngram_size=3  # Prevent 3-gram repetitions
            )
        generation_time = time.time() - start_time
        logger.info(f"Response generated in {generation_time:.2f} seconds")
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = response.replace(prompt, "").strip()
        # If the output is too repetitive, return a fallback message instead
        if check_for_repetition(response):
            logger.warning("Detected repetition in generated text, returning fallback message")
            return "I apologize, but I'm having trouble generating a coherent response. Could you try rephrasing your question?"
        return response
    except Exception as e:
        logger.error(f"Error in generation: {str(e)}")
        return f"Error generating response: {str(e)}"
# Add debug information in sidebar
with st.sidebar:
    st.write("### System Information")
    st.write("Model: Quasar-32B")
    # Device and memory information
    device = "GPU" if torch.cuda.is_available() else "CPU"
    st.write(f"Running on: {device}")
    # Warning for CPU usage
    if not torch.cuda.is_available():
        st.warning("⚠️ Running on CPU - Responses may be very slow. Consider using a GPU or a smaller model.")
    # Model settings
    st.write("### Model Settings")
    if 'temperature' not in st.session_state:
        st.session_state.temperature = 0.8
    if 'max_length' not in st.session_state:
        st.session_state.max_length = 100
    st.session_state.temperature = st.slider("Temperature", 0.1, 1.0, st.session_state.temperature)
    st.session_state.max_length = st.slider("Max Length", 50, 200, st.session_state.max_length)
st.title("Chat with Quasar-32B")
# Initialize session state for chat history
if 'messages' not in st.session_state:
    st.session_state.messages = []
# Load model and tokenizer
model, tokenizer = load_model()
# Chat interface
st.write("### Chat")
chat_container = st.container()
# Display chat history
with chat_container:
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
# User input
if prompt := st.chat_input("Type your message here"):
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Display user message
    with chat_container:
        with st.chat_message("user"):
            st.write(prompt)
        # Generate and display assistant response
        if model and tokenizer:
            with st.chat_message("assistant"):
                try:
                    with st.spinner("Generating response... (timeout: 200s)"):
                        with ThreadPoolExecutor() as executor:
                            future = executor.submit(
                                generate_response_with_timeout,
                                model,
                                tokenizer,
                                prompt
                            )
                            response = future.result(timeout=200)
                    st.write(response)
                    st.session_state.messages.append({"role": "assistant", "content": response})
                except TimeoutError:
                    error_msg = "Response generation timed out. The model might be overloaded."
                    st.error(error_msg)
                    logger.error(error_msg)
                except Exception as e:
                    error_msg = f"Error generating response: {str(e)}"
                    st.error(error_msg)
                    logger.error(error_msg)
        else:
            st.error("Model failed to load. Please check your configuration.")
# Add a button to clear chat history
if st.button("Clear Chat History"):
    st.session_state.messages = []
    st.rerun()