Harshu0117 committed on
Commit 2fca8ab · verified · 1 Parent(s): 43d27e8

Update app.py

Files changed (1):
  app.py +254 -287
app.py CHANGED
@@ -1,327 +1,294 @@
- import streamlit as st
  import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM
- import time
- import base64
- from io import BytesIO

- # Page configuration
- st.set_page_config(
-     page_title="Materials Science AI Assistant",
-     page_icon="🧪",
-     layout="wide",
-     initial_sidebar_state="collapsed"
- )

- # Custom CSS for styling
- st.markdown("""
- <style>
-     /* Main background gradient */
-     .stApp {
-         background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-         color: white;
-     }
-
-     /* Header styling */
-     .main-header {
-         text-align: center;
-         padding: 2rem 0;
-         background: rgba(255, 255, 255, 0.1);
-         border-radius: 20px;
-         margin-bottom: 2rem;
-         backdrop-filter: blur(10px);
-         border: 1px solid rgba(255, 255, 255, 0.2);
-     }
-
-     /* Input area styling */
-     .stTextArea textarea {
-         background: rgba(255, 255, 255, 0.15);
-         border: 1px solid rgba(255, 255, 255, 0.3);
-         border-radius: 15px;
-         color: white;
-         font-size: 16px;
-         backdrop-filter: blur(5px);
-     }
-
-     /* Button styling */
-     .stButton button {
-         background: linear-gradient(45deg, #FF6B6B, #4ECDC4);
-         border: none;
-         border-radius: 25px;
-         color: white;
-         font-weight: bold;
-         padding: 0.75rem 2rem;
-         font-size: 16px;
-         transition: all 0.3s ease;
-         box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
-     }
-
-     .stButton button:hover {
-         transform: translateY(-2px);
-         box-shadow: 0 6px 20px rgba(0, 0, 0, 0.3);
-     }
-
-     /* Response area styling */
-     .response-container {
-         background: rgba(255, 255, 255, 0.1);
-         border-radius: 15px;
-         padding: 1.5rem;
-         margin: 1rem 0;
-         backdrop-filter: blur(10px);
-         border: 1px solid rgba(255, 255, 255, 0.2);
-     }
-
-     /* Advanced options styling */
-     .advanced-options {
-         background: rgba(255, 255, 255, 0.08);
-         border-radius: 15px;
-         padding: 1rem;
-         margin: 1rem 0;
-         border: 1px solid rgba(255, 255, 255, 0.1);
-     }
-
-     /* Loading animation */
-     .loading-animation {
-         text-align: center;
-         font-size: 18px;
-         color: #4ECDC4;
-         animation: pulse 2s infinite;
-     }
-
-     @keyframes pulse {
-         0% { opacity: 1; }
-         50% { opacity: 0.5; }
-         100% { opacity: 1; }
-     }

-     /* Sidebar styling */
-     .sidebar .sidebar-content {
-         background: rgba(255, 255, 255, 0.1);
-         backdrop-filter: blur(10px);
-     }

-     /* Hide streamlit menu */
-     #MainMenu {visibility: hidden;}
-     footer {visibility: hidden;}
-     header {visibility: hidden;}
- </style>
- """, unsafe_allow_html=True)
-
- # Initialize session state
- if 'model' not in st.session_state:
-     st.session_state.model = None
-     st.session_state.tokenizer = None
-     st.session_state.model_loaded = False
-
- @st.cache_resource
- def load_model():
-     """Load the model and tokenizer from Hugging Face"""
-     try:
-         model_name = "Harshu0117/Materials_IISC_MRC"
-
-         # Load tokenizer
-         tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-         # Load model
-         model = AutoModelForCausalLM.from_pretrained(
-             model_name,
-             torch_dtype=torch.float16,
-             device_map="auto",
-             trust_remote_code=True
-         )
-
-         # Set pad token if not set
-         if tokenizer.pad_token is None:
-             tokenizer.pad_token = tokenizer.eos_token
-
-         return model, tokenizer
-     except Exception as e:
-         st.error(f"Error loading model: {str(e)}")
-         return None, None

  def generate_response(prompt, max_tokens, temperature, top_p, repetition_penalty):
-     """Generate response using the loaded model"""
-     if st.session_state.model is None or st.session_state.tokenizer is None:
-         return "❌ Model not loaded properly. Please refresh the page."

      try:
-         # Tokenize input
-         inputs = st.session_state.tokenizer(
-             prompt,
              return_tensors="pt",
              truncation=True,
-             max_length=1024
          )

-         # Move to device
-         if torch.cuda.is_available():
-             inputs = inputs.to("cuda")

-         # Generate response
          with torch.no_grad():
-             outputs = st.session_state.model.generate(
                  **inputs,
-                 max_new_tokens=max_tokens,
-                 temperature=temperature,
-                 top_p=top_p,
-                 repetition_penalty=repetition_penalty,
                  do_sample=True,
-                 pad_token_id=st.session_state.tokenizer.pad_token_id,
-                 eos_token_id=st.session_state.tokenizer.eos_token_id,
-                 use_cache=True
              )

          # Decode response
-         response = st.session_state.tokenizer.decode(
              outputs[0],
              skip_special_tokens=True
          )

          # Remove the original prompt from response
-         response = response.replace(prompt, "").strip()

          return response

      except Exception as e:
          return f"❌ Error generating response: {str(e)}"

- # Main app layout
- def main():
-     # Header
-     st.markdown("""
-     <div class="main-header">
-         <h1>🧪 Materials Science AI Assistant</h1>
-         <p style="font-size: 18px; margin-top: 10px;">
-             Powered by Fine-tuned LLaMA 3 8B | Specialized in Materials Research
-         </p>
-     </div>
-     """, unsafe_allow_html=True)
-
-     # Load model on first run
-     if not st.session_state.model_loaded:
-         with st.spinner("🔄 Loading AI model... This may take a moment..."):
-             st.session_state.model, st.session_state.tokenizer = load_model()
-             if st.session_state.model is not None:
-                 st.session_state.model_loaded = True
-                 st.success("✅ Model loaded successfully!")
-             else:
-                 st.error("❌ Failed to load model. Please refresh the page.")
-                 return
-
-     # Main input area
-     st.markdown("### 💬 Ask me anything about Materials Science!")
-
-     # Input text area
-     prompt = st.text_area(
-         "Enter your question or topic:",
-         placeholder="e.g., Crystalline MAX Phases and their 2D derivative MXenes",
-         height=100,
-         key="prompt_input"
-     )

-     # Advanced options (collapsible)
-     with st.expander("⚙️ Advanced Options"):
-         st.markdown('<div class="advanced-options">', unsafe_allow_html=True)

-         col1, col2 = st.columns(2)

-         with col1:
-             max_tokens = st.slider(
-                 "Max Tokens (Response Length)",
-                 min_value=50,
-                 max_value=500,
-                 value=200,
-                 step=10,
-                 help="Maximum number of tokens in the response"
-             )
-
-             temperature = st.slider(
-                 "Temperature (Creativity)",
-                 min_value=0.1,
-                 max_value=1.0,
-                 value=0.7,
-                 step=0.1,
-                 help="Higher values make responses more creative but less focused"
-             )

-         with col2:
-             top_p = st.slider(
-                 "Top-p (Diversity)",
-                 min_value=0.1,
-                 max_value=1.0,
-                 value=0.9,
-                 step=0.1,
-                 help="Controls diversity of word choices"
-             )
-
-             repetition_penalty = st.slider(
-                 "Repetition Penalty",
-                 min_value=1.0,
-                 max_value=2.0,
-                 value=1.2,
-                 step=0.1,
-                 help="Penalty for repeating words/phrases"
-             )

-         st.markdown('</div>', unsafe_allow_html=True)
-
-     # Generate button
-     col1, col2, col3 = st.columns([1, 2, 1])
-     with col2:
-         generate_btn = st.button("🚀 Generate Response", use_container_width=True)
-
-     # Response area
-     if generate_btn and prompt.strip():
-         if st.session_state.model_loaded:
-             with st.spinner("🧠 AI is thinking..."):
-                 response = generate_response(
-                     prompt.strip(),
-                     max_tokens,
-                     temperature,
-                     top_p,
-                     repetition_penalty
-                 )
-
-                 # Display response
-                 st.markdown("### 🤖 AI Response:")
-                 st.markdown(f"""
-                 <div class="response-container">
-                     <p style="font-size: 16px; line-height: 1.6;">
-                         {response}
-                     </p>
-                 </div>
-                 """, unsafe_allow_html=True)
-
-         else:
-             st.error("❌ Model not loaded. Please refresh the page.")
-
-     elif generate_btn and not prompt.strip():
-         st.warning("⚠️ Please enter a question or topic first!")
-
-     # Footer
-     st.markdown("---")
-     st.markdown("""
-     <div style="text-align: center; padding: 1rem; color: rgba(255, 255, 255, 0.7);">
-         <p>🔬 Specialized in Materials Science | 🧪 MAX Phases & MXenes Expert</p>
-         <p>Built with ❤️ using Streamlit & Hugging Face</p>
-     </div>
-     """, unsafe_allow_html=True)
-
- # Example prompts sidebar
- def show_examples():
-     st.sidebar.markdown("### 💡 Example Prompts")
-     examples = [
-         "Crystalline MAX Phases and their 2D derivative MXenes",
-         "Properties of titanium carbide MXenes",
-         "Synthesis methods for MAX phases",
-         "Applications of MXenes in energy storage",
-         "Mechanical properties of ceramic materials"
-     ]

-     for i, example in enumerate(examples):
-         if st.sidebar.button(f"📝 {example[:30]}...", key=f"example_{i}"):
-             st.session_state.prompt_input = example

  if __name__ == "__main__":
-     show_examples()
-     main()
 
+ import gradio as gr
  import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM
+ import gc

+ # Global variables for model and tokenizer
+ model = None
+ tokenizer = None

+ def load_model():
+     """Load the model and tokenizer from Hugging Face with CPU optimizations"""
+     global model, tokenizer
+
+     if model is None or tokenizer is None:
+         try:
+             model_name = "Harshu0117/Materials_IISC_MRC"
+
+             # Load tokenizer
+             tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+             # Load model with CPU optimizations
+             model = AutoModelForCausalLM.from_pretrained(
+                 model_name,
+                 torch_dtype=torch.float16,  # float16 halves memory; many CPU kernels still upcast to float32
+                 device_map="cpu",
+                 trust_remote_code=True,
+                 low_cpu_mem_usage=True,  # Stream weights in to reduce peak RAM
+                 offload_folder="offload"  # Spill weights to disk if they do not fit in RAM
+             )
+
+             # No-op safety net: weights are already float16 via torch_dtype above
+             model = model.half()
+
+             # Switch to inference mode (disables dropout and other training-only behavior)
+             model.eval()
+
+             # Set pad token if not set
+             if tokenizer.pad_token is None:
+                 tokenizer.pad_token = tokenizer.eos_token
+
+             # Clear GPU cache if any
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+             # Force garbage collection
+             gc.collect()
+
+             return "✅ Model loaded successfully with CPU optimizations!"
+         except Exception as e:
+             return f"❌ Error loading model: {str(e)}"
+
+     return "✅ Model already loaded!"

  def generate_response(prompt, max_tokens, temperature, top_p, repetition_penalty):
+     """Generate response using the loaded model with CPU optimizations"""
+     global model, tokenizer
+
+     # Load model if not already loaded
+     if model is None or tokenizer is None:
+         load_result = load_model()
+         if "Error" in load_result:
+             return load_result
+
+     if not prompt.strip():
+         return "⚠️ Please enter a question or topic first!"

      try:
+         # Tokenize input with truncation for faster processing
+         inputs = tokenizer(
+             prompt.strip(),
              return_tensors="pt",
              truncation=True,
+             max_length=512,  # Reduced from 1024 for faster processing
+             padding=True
          )

+         # Keep on CPU
+         inputs = inputs.to("cpu")

+         # Generate response with optimized settings
          with torch.no_grad():
+             outputs = model.generate(
                  **inputs,
+                 max_new_tokens=int(max_tokens),
+                 temperature=float(temperature),
+                 top_p=float(top_p),
+                 repetition_penalty=float(repetition_penalty),
                  do_sample=True,
+                 pad_token_id=tokenizer.pad_token_id,
+                 eos_token_id=tokenizer.eos_token_id,
+                 use_cache=True,
+                 num_beams=1,  # Single beam; with do_sample=True this is plain sampling, not greedy decoding
+                 early_stopping=True  # Only affects beam search; inert when num_beams=1
              )

          # Decode response
+         response = tokenizer.decode(
              outputs[0],
              skip_special_tokens=True
          )

          # Remove the original prompt from response
+         response = response.replace(prompt.strip(), "").strip()
+
+         # Clear memory
+         del outputs
+         gc.collect()

          return response

      except Exception as e:
          return f"❌ Error generating response: {str(e)}"

+ # Create Gradio interface
+ def create_interface():
+     # Custom CSS for styling
+     css = """
+     .gradio-container {
+         background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+         font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
+     }
+     .gr-button-primary {
+         background: linear-gradient(45deg, #FF6B6B, #4ECDC4) !important;
+         border: none !important;
+         border-radius: 25px !important;
+         color: white !important;
+         font-weight: bold !important;
+         padding: 12px 24px !important;
+         font-size: 16px !important;
+         transition: all 0.3s ease !important;
+     }
+     .gr-button-primary:hover {
+         transform: translateY(-2px) !important;
+         box-shadow: 0 4px 12px rgba(0, 0, 0, 0.2) !important;
+     }
+     .gr-textbox {
+         border-radius: 15px !important;
+         border: 2px solid #e0e0e0 !important;
+         background: rgba(255, 255, 255, 0.95) !important;
+     }
+     .gr-textbox:focus {
+         border-color: #4ECDC4 !important;
+         box-shadow: 0 0 10px rgba(78, 205, 196, 0.3) !important;
+     }
+     .output-text {
+         background: rgba(255, 255, 255, 0.95) !important;
+         border-radius: 15px !important;
+         padding: 20px !important;
+         margin: 10px 0 !important;
+         border-left: 4px solid #4ECDC4 !important;
+         box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1) !important;
+     }
+     .gr-accordion {
+         background: rgba(255, 255, 255, 0.1) !important;
+         border-radius: 15px !important;
+         border: 1px solid rgba(255, 255, 255, 0.3) !important;
+     }
+     """

+     # Create interface
+     with gr.Blocks(
+         css=css,
+         title="🧪 Materials Science AI Assistant",
+         theme=gr.themes.Soft(
+             primary_hue="blue",
+             secondary_hue="cyan",
+             neutral_hue="slate"
+         )
+     ) as demo:
+
+         # Header
+         gr.HTML("""
+         <div style="text-align: center; padding: 30px; background: rgba(255, 255, 255, 0.95); border-radius: 20px; margin-bottom: 20px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);">
+             <h1 style="color: #2c3e50; font-size: 2.5em; margin: 0; text-shadow: 2px 2px 4px rgba(0,0,0,0.1);">
+                 🧪 Materials Science AI Assistant
+             </h1>
+             <p style="color: #7f8c8d; font-size: 1.2em; margin: 10px 0 0 0; font-weight: 500;">
+                 Powered by Fine-tuned LLaMA 3 8B | Specialized in Materials Research
+             </p>
+         </div>
+         """)
+
+         # Main interface
+         with gr.Row():
+             with gr.Column(scale=2):
+                 # Input area
+                 gr.Markdown("### 💬 Ask me anything about Materials Science!")
+
+                 prompt = gr.Textbox(
+                     label="Enter your question or topic:",
+                     placeholder="e.g., Crystalline MAX Phases and their 2D derivative MXenes",
+                     lines=4,
+                     max_lines=8
+                 )
+
+                 # Advanced options
+                 with gr.Accordion("⚙️ Advanced Options", open=False):
+                     with gr.Row():
+                         max_tokens = gr.Slider(
+                             label="Max Tokens (Response Length)",
+                             minimum=50,
+                             maximum=500,
+                             value=200,
+                             step=10,
+                             info="Maximum number of tokens in the response"
+                         )
+
+                         temperature = gr.Slider(
+                             label="Temperature (Creativity)",
+                             minimum=0.1,
+                             maximum=1.0,
+                             value=0.7,
+                             step=0.1,
+                             info="Higher values make responses more creative"
+                         )
+
+                     with gr.Row():
+                         top_p = gr.Slider(
+                             label="Top-p (Diversity)",
+                             minimum=0.1,
+                             maximum=1.0,
+                             value=0.9,
+                             step=0.1,
+                             info="Controls diversity of word choices"
+                         )
+
+                         repetition_penalty = gr.Slider(
+                             label="Repetition Penalty",
+                             minimum=1.0,
+                             maximum=2.0,
+                             value=1.2,
+                             step=0.1,
+                             info="Penalty for repeating words/phrases"
+                         )
+
+                 # Generate button
+                 generate_btn = gr.Button(
+                     "🚀 Generate Response",
+                     variant="primary",
+                     size="lg"
+                 )
+
+         # Output area
+         gr.Markdown("### 🤖 AI Response:")
+         output = gr.Textbox(
+             label="Generated Response",
+             lines=10,
+             max_lines=20,
+             interactive=False,
+             elem_classes=["output-text"]
+         )
+
+         # Example prompts
+         gr.Markdown("### 💡 Example Prompts (Click to use):")
+         examples = [
+             "Crystalline MAX Phases and their 2D derivative MXenes",
+             "Properties of titanium carbide MXenes",
+             "Synthesis methods for MAX phases",
+             "Applications of MXenes in energy storage",
+             "Mechanical properties of ceramic materials"
+         ]
+
+         gr.Examples(
+             examples=examples,
+             inputs=prompt,
+             label="Click any example to try:"
+         )
+
+         # Footer
+         gr.HTML("""
+         <div style="text-align: center; padding: 20px; margin-top: 30px; background: rgba(255, 255, 255, 0.1); border-radius: 15px; border: 1px solid rgba(255, 255, 255, 0.3);">
+             <p style="color: white; font-size: 16px; margin: 0;">
+                 🔬 <strong>Specialized in Materials Science</strong> | 🧪 <strong>MAX Phases & MXenes Expert</strong>
+             </p>
+             <p style="color: rgba(255, 255, 255, 0.8); font-size: 14px; margin: 5px 0 0 0;">
+                 Built with ❤️ using Gradio & Hugging Face Spaces
+             </p>
+         </div>
+         """)
+
+         # Connect the generate button to the function
+         generate_btn.click(
+             fn=generate_response,
+             inputs=[prompt, max_tokens, temperature, top_p, repetition_penalty],
+             outputs=output,
+             show_progress=True
+         )
+
+     return demo

+ # Launch the app
  if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch()
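
For anyone exercising the committed file outside of Spaces: a minimal smoke test of the two new entry points, offered as an editorial sketch rather than part of the commit. It assumes this app.py is importable from the working directory, that the Harshu0117/Materials_IISC_MRC weights can be downloaded, and that roughly 16 GB of RAM is available for the 8B checkpoint in float16; the file name smoke_test.py is hypothetical.

    # smoke_test.py (editorial sketch, not part of the commit)
    from app import create_interface, generate_response

    if __name__ == "__main__":
        # One direct call; generate_response() invokes load_model() on first use.
        print(generate_response(
            "Properties of titanium carbide MXenes",  # prompt
            64,    # max_tokens
            0.7,   # temperature
            0.9,   # top_p
            1.2,   # repetition_penalty
        ))
        # Or serve the full Gradio UI locally:
        create_interface().launch()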