Harshu0117 committed on
Commit 43d27e8 · verified · 1 Parent(s): 76c6708

Update app.py

Files changed (1): app.py (+282 -226)
app.py CHANGED
@@ -1,271 +1,327 @@
-import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM

-# Global variables for model and tokenizer
-model = None
-tokenizer = None

-def load_model():
-    """Load the model and tokenizer from Hugging Face"""
-    global model, tokenizer
-
-    if model is None or tokenizer is None:
-        try:
-            model_name = "Harshu0117/Materials_IISC_MRC"
-
-            # Load tokenizer
-            tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-            # Load model for CPU
-            model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                torch_dtype=torch.float32,  # Use float32 for CPU
-                device_map="cpu",  # Force CPU usage
-                trust_remote_code=True
-            )
-
-            # Set pad token if not set
-            if tokenizer.pad_token is None:
-                tokenizer.pad_token = tokenizer.eos_token
-
-            return "✅ Model loaded successfully!"
-        except Exception as e:
-            return f"❌ Error loading model: {str(e)}"
-
-    return "✅ Model already loaded!"

 def generate_response(prompt, max_tokens, temperature, top_p, repetition_penalty):
     """Generate response using the loaded model"""
-    global model, tokenizer
-
-    # Load model if not already loaded
-    if model is None or tokenizer is None:
-        load_result = load_model()
-        if "Error" in load_result:
-            return load_result
-
-    if not prompt.strip():
-        return "⚠️ Please enter a question or topic first!"

     try:
         # Tokenize input
-        inputs = tokenizer(
-            prompt.strip(),
             return_tensors="pt",
             truncation=True,
             max_length=1024
         )

-        # Keep on CPU
-        inputs = inputs.to("cpu")

         # Generate response
         with torch.no_grad():
-            outputs = model.generate(
                 **inputs,
-                max_new_tokens=int(max_tokens),
-                temperature=float(temperature),
-                top_p=float(top_p),
-                repetition_penalty=float(repetition_penalty),
                 do_sample=True,
-                pad_token_id=tokenizer.pad_token_id,
-                eos_token_id=tokenizer.eos_token_id,
                 use_cache=True
             )

         # Decode response
-        response = tokenizer.decode(
             outputs[0],
             skip_special_tokens=True
         )

         # Remove the original prompt from response
-        response = response.replace(prompt.strip(), "").strip()

         return response

     except Exception as e:
         return f"❌ Error generating response: {str(e)}"

-# Create Gradio interface
-def create_interface():
-    # Custom CSS for styling
-    css = """
-    .gradio-container {
-        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
-    }
-    .gr-button-primary {
-        background: linear-gradient(45deg, #FF6B6B, #4ECDC4) !important;
-        border: none !important;
-        border-radius: 25px !important;
-        color: white !important;
-        font-weight: bold !important;
-        padding: 12px 24px !important;
-        font-size: 16px !important;
-        transition: all 0.3s ease !important;
-    }
-    .gr-button-primary:hover {
-        transform: translateY(-2px) !important;
-        box-shadow: 0 4px 12px rgba(0, 0, 0, 0.2) !important;
-    }
-    .gr-textbox {
-        border-radius: 15px !important;
-        border: 2px solid #e0e0e0 !important;
-        background: rgba(255, 255, 255, 0.95) !important;
-    }
-    .gr-textbox:focus {
-        border-color: #4ECDC4 !important;
-        box-shadow: 0 0 10px rgba(78, 205, 196, 0.3) !important;
-    }
-    .output-text {
-        background: rgba(255, 255, 255, 0.95) !important;
-        border-radius: 15px !important;
-        padding: 20px !important;
-        margin: 10px 0 !important;
-        border-left: 4px solid #4ECDC4 !important;
-        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1) !important;
-    }
-    .gr-accordion {
-        background: rgba(255, 255, 255, 0.1) !important;
-        border-radius: 15px !important;
-        border: 1px solid rgba(255, 255, 255, 0.3) !important;
-    }
-    """

-    # Create interface
-    with gr.Blocks(
-        css=css,
-        title="🧪 Materials Science AI Assistant",
-        theme=gr.themes.Soft(
-            primary_hue="blue",
-            secondary_hue="cyan",
-            neutral_hue="slate"
-        )
-    ) as demo:
-
-        # Header
-        gr.HTML("""
-        <div style="text-align: center; padding: 30px; background: rgba(255, 255, 255, 0.95); border-radius: 20px; margin-bottom: 20px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);">
-            <h1 style="color: #2c3e50; font-size: 2.5em; margin: 0; text-shadow: 2px 2px 4px rgba(0,0,0,0.1);">
-                🧪 Materials Science AI Assistant
-            </h1>
-            <p style="color: #7f8c8d; font-size: 1.2em; margin: 10px 0 0 0; font-weight: 500;">
-                Powered by Fine-tuned LLaMA 3 8B | Specialized in Materials Research
-            </p>
-        </div>
-        """)
-
-        # Main interface
-        with gr.Row():
-            with gr.Column(scale=2):
-                # Input area
-                gr.Markdown("### 💬 Ask me anything about Materials Science!")
-
-                prompt = gr.Textbox(
-                    label="Enter your question or topic:",
-                    placeholder="e.g., Crystalline MAX Phases and their 2D derivative MXenes",
-                    lines=4,
-                    max_lines=8
-                )
-
-                # Advanced options
-                with gr.Accordion("⚙️ Advanced Options", open=False):
-                    with gr.Row():
-                        max_tokens = gr.Slider(
-                            label="Max Tokens (Response Length)",
-                            minimum=50,
-                            maximum=500,
-                            value=200,
-                            step=10,
-                            info="Maximum number of tokens in the response"
-                        )
-
-                        temperature = gr.Slider(
-                            label="Temperature (Creativity)",
-                            minimum=0.1,
-                            maximum=1.0,
-                            value=0.7,
-                            step=0.1,
-                            info="Higher values make responses more creative"
-                        )
-
-                    with gr.Row():
-                        top_p = gr.Slider(
-                            label="Top-p (Diversity)",
-                            minimum=0.1,
-                            maximum=1.0,
-                            value=0.9,
-                            step=0.1,
-                            info="Controls diversity of word choices"
-                        )
-
-                        repetition_penalty = gr.Slider(
-                            label="Repetition Penalty",
-                            minimum=1.0,
-                            maximum=2.0,
-                            value=1.2,
-                            step=0.1,
-                            info="Penalty for repeating words/phrases"
-                        )
-
-                # Generate button
-                generate_btn = gr.Button(
-                    "🚀 Generate Response",
-                    variant="primary",
-                    size="lg"
-                )
-
-                # Output area
-                gr.Markdown("### 🤖 AI Response:")
-                output = gr.Textbox(
-                    label="Generated Response",
-                    lines=10,
-                    max_lines=20,
-                    interactive=False,
-                    elem_classes=["output-text"]
-                )

-        # Example prompts
-        gr.Markdown("### 💡 Example Prompts (Click to use):")
-        examples = [
-            "Crystalline MAX Phases and their 2D derivative MXenes",
-            "Properties of titanium carbide MXenes",
-            "Synthesis methods for MAX phases",
-            "Applications of MXenes in energy storage",
-            "Mechanical properties of ceramic materials"
-        ]

-        gr.Examples(
-            examples=examples,
-            inputs=prompt,
-            label="Click any example to try:"
-        )

-        # Footer
-        gr.HTML("""
-        <div style="text-align: center; padding: 20px; margin-top: 30px; background: rgba(255, 255, 255, 0.1); border-radius: 15px; border: 1px solid rgba(255, 255, 255, 0.3);">
-            <p style="color: white; font-size: 16px; margin: 0;">
-                🔬 <strong>Specialized in Materials Science</strong> | 🧪 <strong>MAX Phases & MXenes Expert</strong>
-            </p>
-            <p style="color: rgba(255, 255, 255, 0.8); font-size: 14px; margin: 5px 0 0 0;">
-                Built with ❤️ using Gradio & Hugging Face Spaces
-            </p>
-        </div>
-        """)

-        # Connect the generate button to the function
-        generate_btn.click(
-            fn=generate_response,
-            inputs=[prompt, max_tokens, temperature, top_p, repetition_penalty],
-            outputs=output,
-            show_progress=True
-        )

-    return demo

-# Launch the app
 if __name__ == "__main__":
-    demo = create_interface()
-    demo.launch()
+import streamlit as st
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
+import time
+import base64
+from io import BytesIO

+# Page configuration
+st.set_page_config(
+    page_title="Materials Science AI Assistant",
+    page_icon="🧪",
+    layout="wide",
+    initial_sidebar_state="collapsed"
+)

+# Custom CSS for styling
+st.markdown("""
+<style>
+    /* Main background gradient */
+    .stApp {
+        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+        color: white;
+    }
+
+    /* Header styling */
+    .main-header {
+        text-align: center;
+        padding: 2rem 0;
+        background: rgba(255, 255, 255, 0.1);
+        border-radius: 20px;
+        margin-bottom: 2rem;
+        backdrop-filter: blur(10px);
+        border: 1px solid rgba(255, 255, 255, 0.2);
+    }
+
+    /* Input area styling */
+    .stTextArea textarea {
+        background: rgba(255, 255, 255, 0.15);
+        border: 1px solid rgba(255, 255, 255, 0.3);
+        border-radius: 15px;
+        color: white;
+        font-size: 16px;
+        backdrop-filter: blur(5px);
+    }
+
+    /* Button styling */
+    .stButton button {
+        background: linear-gradient(45deg, #FF6B6B, #4ECDC4);
+        border: none;
+        border-radius: 25px;
+        color: white;
+        font-weight: bold;
+        padding: 0.75rem 2rem;
+        font-size: 16px;
+        transition: all 0.3s ease;
+        box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
+    }
+
+    .stButton button:hover {
+        transform: translateY(-2px);
+        box-shadow: 0 6px 20px rgba(0, 0, 0, 0.3);
+    }
+
+    /* Response area styling */
+    .response-container {
+        background: rgba(255, 255, 255, 0.1);
+        border-radius: 15px;
+        padding: 1.5rem;
+        margin: 1rem 0;
+        backdrop-filter: blur(10px);
+        border: 1px solid rgba(255, 255, 255, 0.2);
+    }
+
+    /* Advanced options styling */
+    .advanced-options {
+        background: rgba(255, 255, 255, 0.08);
+        border-radius: 15px;
+        padding: 1rem;
+        margin: 1rem 0;
+        border: 1px solid rgba(255, 255, 255, 0.1);
+    }
+
+    /* Loading animation */
+    .loading-animation {
+        text-align: center;
+        font-size: 18px;
+        color: #4ECDC4;
+        animation: pulse 2s infinite;
+    }
+
+    @keyframes pulse {
+        0% { opacity: 1; }
+        50% { opacity: 0.5; }
+        100% { opacity: 1; }
+    }
+
+    /* Sidebar styling */
+    .sidebar .sidebar-content {
+        background: rgba(255, 255, 255, 0.1);
+        backdrop-filter: blur(10px);
+    }
+
+    /* Hide streamlit menu */
+    #MainMenu {visibility: hidden;}
+    footer {visibility: hidden;}
+    header {visibility: hidden;}
+</style>
+""", unsafe_allow_html=True)

+# Initialize session state
+if 'model' not in st.session_state:
+    st.session_state.model = None
+    st.session_state.tokenizer = None
+    st.session_state.model_loaded = False

+@st.cache_resource
+def load_model():
+    """Load the model and tokenizer from Hugging Face"""
+    try:
+        model_name = "Harshu0117/Materials_IISC_MRC"
+
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        # Load model
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            trust_remote_code=True
+        )
+
+        # Set pad token if not set
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        return model, tokenizer
+    except Exception as e:
+        st.error(f"Error loading model: {str(e)}")
+        return None, None

 def generate_response(prompt, max_tokens, temperature, top_p, repetition_penalty):
     """Generate response using the loaded model"""
+    if st.session_state.model is None or st.session_state.tokenizer is None:
+        return "❌ Model not loaded properly. Please refresh the page."

     try:
         # Tokenize input
+        inputs = st.session_state.tokenizer(
+            prompt,
             return_tensors="pt",
             truncation=True,
             max_length=1024
         )

+        # Move to device
+        if torch.cuda.is_available():
+            inputs = inputs.to("cuda")

         # Generate response
         with torch.no_grad():
+            outputs = st.session_state.model.generate(
                 **inputs,
+                max_new_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
                 do_sample=True,
+                pad_token_id=st.session_state.tokenizer.pad_token_id,
+                eos_token_id=st.session_state.tokenizer.eos_token_id,
                 use_cache=True
             )

         # Decode response
+        response = st.session_state.tokenizer.decode(
             outputs[0],
             skip_special_tokens=True
         )

         # Remove the original prompt from response
+        response = response.replace(prompt, "").strip()

         return response

     except Exception as e:
         return f"❌ Error generating response: {str(e)}"

+# Main app layout
+def main():
+    # Header
+    st.markdown("""
+    <div class="main-header">
+        <h1>🧪 Materials Science AI Assistant</h1>
+        <p style="font-size: 18px; margin-top: 10px;">
+            Powered by Fine-tuned LLaMA 3 8B | Specialized in Materials Research
+        </p>
+    </div>
+    """, unsafe_allow_html=True)

+    # Load model on first run
+    if not st.session_state.model_loaded:
+        with st.spinner("🔄 Loading AI model... This may take a moment..."):
+            st.session_state.model, st.session_state.tokenizer = load_model()
+            if st.session_state.model is not None:
+                st.session_state.model_loaded = True
+                st.success("✅ Model loaded successfully!")
+            else:
+                st.error("❌ Failed to load model. Please refresh the page.")
+                return

+    # Main input area
+    st.markdown("### 💬 Ask me anything about Materials Science!")

+    # Input text area
+    prompt = st.text_area(
+        "Enter your question or topic:",
+        placeholder="e.g., Crystalline MAX Phases and their 2D derivative MXenes",
+        height=100,
+        key="prompt_input"
+    )

+    # Advanced options (collapsible)
+    with st.expander("⚙️ Advanced Options"):
+        st.markdown('<div class="advanced-options">', unsafe_allow_html=True)

+        col1, col2 = st.columns(2)

+        with col1:
+            max_tokens = st.slider(
+                "Max Tokens (Response Length)",
+                min_value=50,
+                max_value=500,
+                value=200,
+                step=10,
+                help="Maximum number of tokens in the response"
+            )
+
+            temperature = st.slider(
+                "Temperature (Creativity)",
+                min_value=0.1,
+                max_value=1.0,
+                value=0.7,
+                step=0.1,
+                help="Higher values make responses more creative but less focused"
+            )

+        with col2:
+            top_p = st.slider(
+                "Top-p (Diversity)",
+                min_value=0.1,
+                max_value=1.0,
+                value=0.9,
+                step=0.1,
+                help="Controls diversity of word choices"
+            )
+
+            repetition_penalty = st.slider(
+                "Repetition Penalty",
+                min_value=1.0,
+                max_value=2.0,
+                value=1.2,
+                step=0.1,
+                help="Penalty for repeating words/phrases"
+            )

+        st.markdown('</div>', unsafe_allow_html=True)

+    # Generate button
+    col1, col2, col3 = st.columns([1, 2, 1])
+    with col2:
+        generate_btn = st.button("🚀 Generate Response", use_container_width=True)

+    # Response area
+    if generate_btn and prompt.strip():
+        if st.session_state.model_loaded:
+            with st.spinner("🧠 AI is thinking..."):
+                response = generate_response(
+                    prompt.strip(),
+                    max_tokens,
+                    temperature,
+                    top_p,
+                    repetition_penalty
+                )
+
+            # Display response
+            st.markdown("### 🤖 AI Response:")
+            st.markdown(f"""
+            <div class="response-container">
+                <p style="font-size: 16px; line-height: 1.6;">
+                    {response}
+                </p>
+            </div>
+            """, unsafe_allow_html=True)
+
+        else:
+            st.error("❌ Model not loaded. Please refresh the page.")
+
+    elif generate_btn and not prompt.strip():
+        st.warning("⚠️ Please enter a question or topic first!")

+    # Footer
+    st.markdown("---")
+    st.markdown("""
+    <div style="text-align: center; padding: 1rem; color: rgba(255, 255, 255, 0.7);">
+        <p>🔬 Specialized in Materials Science | 🧪 MAX Phases & MXenes Expert</p>
+        <p>Built with ❤️ using Streamlit & Hugging Face</p>
+    </div>
+    """, unsafe_allow_html=True)

+# Example prompts sidebar
+def show_examples():
+    st.sidebar.markdown("### 💡 Example Prompts")
+    examples = [
+        "Crystalline MAX Phases and their 2D derivative MXenes",
+        "Properties of titanium carbide MXenes",
+        "Synthesis methods for MAX phases",
+        "Applications of MXenes in energy storage",
+        "Mechanical properties of ceramic materials"
+    ]

+    for i, example in enumerate(examples):
+        if st.sidebar.button(f"📝 {example[:30]}...", key=f"example_{i}"):
+            st.session_state.prompt_input = example

 if __name__ == "__main__":
+    show_examples()
+    main()
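
Editor's note on the new loading path: the Gradio version pinned torch_dtype=torch.float32 with device_map="cpu", while the rewritten load_model() requests torch.float16 with device_map="auto". On a CPU-only Space there is no GPU for "auto" to place weights on, and float16 inference on CPU is slow and not supported by every op, so loading can fail or crawl. A minimal device-aware sketch, not part of the commit (the helper pick_dtype_and_device is hypothetical):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def pick_dtype_and_device():
    # Hypothetical helper, not in the commit: fall back to the old
    # float32/CPU settings when no GPU is present.
    if torch.cuda.is_available():
        return torch.float16, "auto"  # half precision on GPU
    return torch.float32, "cpu"       # safe CPU default, as in the old version

dtype, device_map = pick_dtype_and_device()
tokenizer = AutoTokenizer.from_pretrained("Harshu0117/Materials_IISC_MRC")
model = AutoModelForCausalLM.from_pretrained(
    "Harshu0117/Materials_IISC_MRC",
    torch_dtype=dtype,
    device_map=device_map,
    trust_remote_code=True,
)

Since the Space is also switching UI frameworks, the Space README's sdk metadata (sdk: gradio to sdk: streamlit) and requirements.txt (adding streamlit) presumably need matching updates; this commit touches only app.py, which would now be launched locally with "streamlit run app.py" rather than "python app.py".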