AssanaliAidarkhan commited on
Commit
c49bb9e
Β·
verified Β·
1 Parent(s): 0aab966

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +235 -0
app.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from huggingface_hub import hf_hub_download
4
+ from PIL import Image
5
+ import numpy as np
6
+ import traceback
7
+
8
+ # Your model configuration
9
+ MODEL_REPO = "AssanaliAidarkhan/Biomedclip" # Your uploaded model
10
+ MODEL_FILENAME = "pytorch_model.bin" # The .pt file you uploaded
11
+
12
+ # Global variables
13
+ model = None
14
+
15
+ def load_model():
16
+ """Load the BiodemCLIP model from your uploaded .pt file"""
17
+ global model
18
+
19
+ try:
20
+ print(f"Downloading model from: {MODEL_REPO}")
21
+
22
+ # Download your model file
23
+ model_path = hf_hub_download(
24
+ repo_id=MODEL_REPO,
25
+ filename=MODEL_FILENAME,
26
+ cache_dir="./model_cache"
27
+ )
28
+
29
+ print(f"Model downloaded to: {model_path}")
30
+
31
+ # Load the model
32
+ # Note: Adjust this based on how your model was saved
33
+ model = torch.load(model_path, map_location='cpu')
34
+
35
+ # If your model was saved as a state dict, you might need:
36
+ # model = YourModelClass() # Initialize your model architecture
37
+ # model.load_state_dict(torch.load(model_path, map_location='cpu'))
38
+
39
+ # Set to evaluation mode
40
+ if hasattr(model, 'eval'):
41
+ model.eval()
42
+
43
+ print("βœ“ Model loaded successfully!")
44
+ return True
45
+
46
+ except Exception as e:
47
+ print(f"Error loading model: {e}")
48
+ print(traceback.format_exc())
49
+ return False
50
+
51
+ def predict(image, text_query):
52
+ """Make prediction with your model"""
53
+ global model
54
+
55
+ if model is None:
56
+ return "❌ Model not loaded! Please wait for initialization."
57
+
58
+ if image is None:
59
+ return "❌ Please upload an image."
60
+
61
+ if not text_query or text_query.strip() == "":
62
+ return "❌ Please enter a text query."
63
+
64
+ try:
65
+ # Convert PIL image to tensor
66
+ if isinstance(image, Image.Image):
67
+ # Convert to RGB if not already
68
+ image = image.convert('RGB')
69
+
70
+ # Convert to numpy array and then tensor
71
+ image_array = np.array(image)
72
+
73
+ # Normalize pixel values to [0, 1]
74
+ image_tensor = torch.from_numpy(image_array).float() / 255.0
75
+
76
+ # Rearrange dimensions from HWC to CHW
77
+ image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0) # Add batch dimension
78
+
79
+ # Process text (this is a basic tokenization - adjust based on your model)
80
+ # You might need to use a specific tokenizer here
81
+ text_tokens = text_query.lower().split() # Basic tokenization
82
+
83
+ # Run inference
84
+ with torch.no_grad():
85
+ try:
86
+ # This depends on your model's forward method
87
+ # You'll need to adjust this based on how your model expects inputs
88
+
89
+ # Example approaches (try these and see which works):
90
+
91
+ # Option 1: If your model has a forward method that takes image and text
92
+ if hasattr(model, 'forward'):
93
+ output = model(image_tensor, text_query)
94
+
95
+ # Option 2: If your model has separate encode methods
96
+ elif hasattr(model, 'encode_image') and hasattr(model, 'encode_text'):
97
+ image_features = model.encode_image(image_tensor)
98
+ text_features = model.encode_text(text_query)
99
+
100
+ # Calculate similarity
101
+ similarity = torch.cosine_similarity(image_features, text_features, dim=-1)
102
+ output = similarity
103
+
104
+ # Option 3: If it's a different architecture
105
+ else:
106
+ # You might need to call your model differently
107
+ # For example: output = model.predict(image_tensor, text_query)
108
+ output = model(image_tensor) # Adjust this line based on your model
109
+
110
+ # Process the output
111
+ if torch.is_tensor(output):
112
+ if output.numel() == 1: # Single value (like similarity score)
113
+ score = output.item()
114
+ else: # Multiple values
115
+ score = torch.mean(output).item() # Take mean as similarity
116
+ else:
117
+ score = float(output) if isinstance(output, (int, float)) else 0.5
118
+
119
+ result = f"""
120
+ 🎯 **Similarity Score:** {score:.4f}
121
+
122
+ πŸ“ **Query:** {text_query}
123
+
124
+ πŸ–ΌοΈ **Image Shape:** {image_array.shape}
125
+
126
+ πŸ’‘ **Interpretation:**
127
+ {interpret_similarity(score)}
128
+
129
+ πŸ”§ **Model Info:** Loaded from {MODEL_REPO}
130
+ """
131
+
132
+ return result
133
+
134
+ except Exception as model_error:
135
+ return f"""
136
+ ❌ **Model Inference Error:**
137
+ {str(model_error)}
138
+
139
+ πŸ”§ **Debug Info:**
140
+ - Image shape: {image_array.shape}
141
+ - Text query: "{text_query}"
142
+ - Model type: {type(model)}
143
+
144
+ πŸ’‘ **Note:** You may need to adjust the inference code based on your specific model architecture.
145
+ """
146
+
147
+ except Exception as e:
148
+ error_msg = f"❌ Error during prediction: {str(e)}"
149
+ print(traceback.format_exc())
150
+ return error_msg
151
+
152
+ def interpret_similarity(score):
153
+ """Interpret the similarity score"""
154
+ if score >= 0.8:
155
+ return "🟒 Very high similarity - Strong match!"
156
+ elif score >= 0.6:
157
+ return "🟑 Good similarity - Reasonable match"
158
+ elif score >= 0.4:
159
+ return "🟠 Moderate similarity - Some relevance"
160
+ elif score >= 0.2:
161
+ return "πŸ”΄ Low similarity - Limited relevance"
162
+ else:
163
+ return "⚫ Very low similarity - Poor match"
164
+
165
+ # Load model on startup
166
+ print("Initializing BiodemCLIP model...")
167
+ model_loaded = load_model()
168
+
169
+ # Create Gradio interface
170
+ with gr.Blocks(title="BiodemCLIP Demo", theme=gr.themes.Soft()) as demo:
171
+ gr.Markdown("""
172
+ # 🧬 BiodemCLIP Model Demo
173
+
174
+ Upload a biomedical image and enter a text description to see how well they match!
175
+
176
+ **Model:** AssanaliAidarkhan/Biomedclip
177
+ """)
178
+
179
+ if not model_loaded:
180
+ gr.Markdown("⚠️ **Warning: Model failed to load. Check the logs for details.**")
181
+ else:
182
+ gr.Markdown("βœ… **Model loaded successfully!**")
183
+
184
+ with gr.Row():
185
+ with gr.Column(scale=1):
186
+ image_input = gr.Image(
187
+ type="pil",
188
+ label="πŸ“Έ Upload Biomedical Image",
189
+ height=400
190
+ )
191
+
192
+ text_input = gr.Textbox(
193
+ label="πŸ“ Enter Text Query",
194
+ placeholder="e.g., 'chest X-ray showing pneumonia', 'normal tissue sample', etc.",
195
+ lines=3
196
+ )
197
+
198
+ submit_btn = gr.Button("πŸ” Analyze", variant="primary", size="lg")
199
+ clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
200
+
201
+ with gr.Column(scale=1):
202
+ output = gr.Markdown(label="πŸ“Š Results")
203
+
204
+ # Event handlers
205
+ submit_btn.click(
206
+ fn=predict,
207
+ inputs=[image_input, text_input],
208
+ outputs=output
209
+ )
210
+
211
+ clear_btn.click(
212
+ fn=lambda: [None, "", ""],
213
+ inputs=[],
214
+ outputs=[image_input, text_input, output]
215
+ )
216
+
217
+ gr.Markdown("""
218
+ ### πŸ“‹ Instructions:
219
+ 1. Upload a biomedical image (X-ray, MRI, microscopy, etc.)
220
+ 2. Enter a descriptive text query
221
+ 3. Click "Analyze" to get the similarity score
222
+
223
+ ### ℹ️ About:
224
+ This model analyzes the similarity between biomedical images and text descriptions.
225
+ Higher scores indicate better matches between the image and text.
226
+
227
+ ### πŸ”§ Technical Notes:
228
+ - Model loaded from Hugging Face Hub
229
+ - Runs on CPU (may be slower for large images)
230
+ - Custom .pt model loading
231
+ """)
232
+
233
+ # Launch the app
234
+ if __name__ == "__main__":
235
+ demo.launch()