Eiad Gomaa committed on
Commit f488be3 · 1 parent: 6f6da11

Update app.py

Files changed (1): app.py (+24 -24)
app.py CHANGED

(Most of the +24/-24 lines below are whitespace-only: paired "-"/"+" blank lines where the commit appears to strip trailing whitespace. The substantive change is moving st.set_page_config() to the top of the script.)
@@ -5,6 +5,9 @@ import time
 from concurrent.futures import ThreadPoolExecutor, TimeoutError
 import logging
 
+# Page config - this must be the first Streamlit command
+st.set_page_config(page_title="Chat with Quasar-32B", layout="wide")
+
 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
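Why the move matters: Streamlit requires st.set_page_config() to be the first Streamlit command a script executes; calling anything else first raises a StreamlitAPIException. The commit therefore moves it here from line 107 (removed in a hunk below). A minimal reproduction of the failure mode (illustrative, not from the app):

import streamlit as st

st.write("hello")  # any other Streamlit command running first...
st.set_page_config(page_title="Demo")  # ...makes this raise StreamlitAPIException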
@@ -21,20 +24,20 @@ def load_model():
     try:
         st.spinner("Loading model... This may take a few minutes")
         logger.info("Starting model loading...")
-
+
         # Basic model loading without device map
         model = AutoModelForCausalLM.from_pretrained(
             "NousResearch/Llama-3.2-1B",
             torch_dtype=torch.float32  # Use float32 for CPU
         )
-
+
         tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
-
+
         # Set up padding token
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
             model.config.pad_token_id = model.config.eos_token_id
-
+
         logger.info("Model loaded successfully")
         return model, tokenizer
     except Exception as e:
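Two side notes on load_model() as it stands (unchanged by this commit): st.spinner() returns a context manager, so the bare call above never actually displays a spinner, and nothing visible in the diff caches the model across reruns. A sketch of both fixes, assuming no caching decorator exists outside the visible hunk:

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

@st.cache_resource  # load the weights once per process, not on every rerun
def load_model():
    model = AutoModelForCausalLM.from_pretrained(
        "NousResearch/Llama-3.2-1B",
        torch_dtype=torch.float32,  # float32 for CPU
    )
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id
    return model, tokenizer

with st.spinner("Loading model... This may take a few minutes"):
    model, tokenizer = load_model()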
@@ -47,7 +50,7 @@ def check_for_repetition(text, threshold=3):
     words = text.split()
     if len(words) < threshold:
         return False
-
+
     # Check for repeated phrases
     for i in range(len(words) - threshold):
         phrase = ' '.join(words[i:i+threshold])
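The diff cuts check_for_repetition() off after the loop header. The counting logic below is an assumed completion of the sliding-window idea the visible lines set up, not the file's actual code:

def check_for_repetition(text, threshold=3):
    # Heuristic: flag text in which a `threshold`-word phrase recurs.
    words = text.split()
    if len(words) < threshold:
        return False
    # Slide a window over the words and look for the same phrase
    # occurring again later in the text (assumed body from here on).
    for i in range(len(words) - threshold):
        phrase = ' '.join(words[i:i + threshold])
        if phrase in ' '.join(words[i + threshold:]):
            return True
    return False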
@@ -66,9 +69,9 @@ def generate_response_with_timeout(model, tokenizer, prompt, timeout_seconds=30):
             truncation=True,
             max_length=256  # Reduced for CPU
         )
-
+
         start_time = time.time()
-
+
         # Generate response with stricter parameters
         with torch.no_grad():
             outputs = model.generate(
@@ -86,47 +89,44 @@ def generate_response_with_timeout(model, tokenizer, prompt, timeout_seconds=30):
                 no_repeat_ngram_size=3,  # Prevent 3-gram repetitions
                 early_stopping=True
             )
-
+
         generation_time = time.time() - start_time
         logger.info(f"Response generated in {generation_time:.2f} seconds")
-
+
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
         response = response.replace(prompt, "").strip()
-
+
         # Check for repetitions and retry if necessary
         if check_for_repetition(response):
             logger.warning("Detected repetition, retrying with stricter parameters")
             return "I apologize, but I'm having trouble generating a coherent response. Could you try rephrasing your question?"
-
+
         return response
-
+
     except Exception as e:
         logger.error(f"Error in generation: {str(e)}")
         return f"Error generating response: {str(e)}"
 
-# Page config
-st.set_page_config(page_title="Chat with Quasar-32B", layout="wide")
-
 # Add debug information in sidebar
 with st.sidebar:
     st.write("### System Information")
     st.write("Model: Quasar-32B")
-
+
     # Device and memory information
     device = "GPU" if torch.cuda.is_available() else "CPU"
     st.write(f"Running on: {device}")
-
+
     # Warning for CPU usage
     if not torch.cuda.is_available():
        st.warning("⚠️ Running on CPU - Responses may be very slow. Consider using a GPU or a smaller model.")
-
+
    # Model settings
    st.write("### Model Settings")
    if 'temperature' not in st.session_state:
        st.session_state.temperature = 0.8
    if 'max_length' not in st.session_state:
        st.session_state.max_length = 100
-
+
    st.session_state.temperature = st.slider("Temperature", 0.1, 1.0, st.session_state.temperature)
    st.session_state.max_length = st.slider("Max Length", 50, 200, st.session_state.max_length)
 
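Two notes on this hunk. In the generate() call, early_stopping=True only affects beam search; no num_beams is visible in the diff, and the temperature slider suggests sampling, in which case recent transformers releases typically warn that the flag is ignored. The control that actually suppresses loops is no_repeat_ngram_size=3, which forbids any 3-gram from occurring twice in the output. Separately, the sidebar seeds st.session_state defaults by hand and assigns each slider's return value back into it; the more idiomatic pattern binds the widget to session state with key=, making the manual seeding unnecessary. A sketch of that variant (not what the commit does):

import streamlit as st

# A slider with key= reads and writes st.session_state["temperature"] itself;
# value= only seeds the first run, so no manual default-setting is needed.
st.sidebar.slider("Temperature", 0.1, 1.0, value=0.8, key="temperature")
st.sidebar.slider("Max Length", 50, 200, value=100, key="max_length")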
@@ -153,12 +153,12 @@ with chat_container:
 if prompt := st.chat_input("Type your message here"):
     # Add user message to chat history
     st.session_state.messages.append({"role": "user", "content": prompt})
-
+
     # Display user message
     with chat_container:
         with st.chat_message("user"):
             st.write(prompt)
-
+
     # Generate and display assistant response
     if model and tokenizer:
         with st.chat_message("assistant"):
@@ -172,10 +172,10 @@ if prompt := st.chat_input("Type your message here"):
                         prompt
                     )
                     response = future.result(timeout=30)
-
+
                     st.write(response)
                     st.session_state.messages.append({"role": "assistant", "content": response})
-
+
             except TimeoutError:
                 error_msg = "Response generation timed out. The model might be overloaded."
                 st.error(error_msg)
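On the timeout pattern used here: future.result(timeout=30) lets the UI give up after 30 seconds, but it cannot cancel a generate() call that is already executing; the worker thread runs to completion in the background, and an executor used as a context manager blocks on exit until it does. A self-contained demonstration of the same mechanics:

from concurrent.futures import ThreadPoolExecutor, TimeoutError
import time

def slow_task():
    time.sleep(60)  # stand-in for a long model.generate() call
    return "done"

with ThreadPoolExecutor(max_workers=1) as executor:
    future = executor.submit(slow_task)
    try:
        print(future.result(timeout=3))  # gives up waiting after 3 s
    except TimeoutError:
        print("timed out")  # reached at ~3 s...
# ...but the 'with' block still waits for slow_task() to finish
# (executor.shutdown(wait=True)), so the script exits after ~60 s.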
@@ -190,4 +190,4 @@ if prompt := st.chat_input("Type your message here"):
 # Add a button to clear chat history
 if st.button("Clear Chat History"):
     st.session_state.messages = []
-    st.experimental_rerun()
+    st.experimental_rerun()
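A version note, outside the scope of this commit: st.experimental_rerun() was deprecated in favor of st.rerun() (Streamlit 1.27) and has since been removed, so on a current runtime this handler would read:

if st.button("Clear Chat History"):
    st.session_state.messages = []
    st.rerun()  # replaces st.experimental_rerun() on modern Streamlit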
 