derektan committed on
Commit
7e159c0
·
1 Parent(s): f996296

[NEW] Sound modality input. Yet to put in proper examples

Browse files
Files changed (1) hide show
  1. app.py +66 -12
app.py CHANGED
@@ -10,10 +10,13 @@ import numpy as np
10
  from PIL import Image
11
  import matplotlib.pyplot as plt
12
  import io
 
13
 
14
  from torchvision import transforms
15
  import open_clip
16
  from clip_vision_per_patch_model import CLIPVisionPerPatchModel
 
 
17
 
18
  # ────────────────────────── global config & models ────────────────────
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -25,11 +28,20 @@ bio_tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")
25
 
26
  # Satellite patch encoder (CLIP-L-336 per-patch)
27
  sat_model: CLIPVisionPerPatchModel = (
28
- CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta")
29
  .to(device)
30
  .eval()
31
  )
32
 
 
 
 
 
 
 
 
 
 
33
  logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
34
  logit_scale = logit_scale.exp()
35
  blur_kernel = (5,5)
@@ -58,6 +70,13 @@ imo_transform = transforms.Compose(
58
  ]
59
  )
60
 
 
 
 
 
 
 
 
61
  # ────────────────────────── helpers ───────────────────────────────────
62
 
63
  @torch.no_grad()
@@ -81,6 +100,16 @@ def _encode_sat(img_pil: Image.Image) -> torch.Tensor:
81
  return imo_embeds
82
 
83
 
 
 
 
 
 
 
 
 
 
 
84
  def _similarity_heatmap(query: torch.Tensor, patches: torch.Tensor) -> np.ndarray:
85
  sims = torch.matmul(query, patches.t()) * logit_scale
86
  sims = sims.t().sigmoid()
@@ -122,13 +151,14 @@ def process(
122
  sat_img: Image.Image,
123
  taxonomy: str,
124
  ground_img: Image.Image | None,
 
125
  ):
126
  if sat_img is None:
127
  return None, None
128
 
129
  patches = _encode_sat(sat_img)
130
 
131
- heat_ground, heat_text = None, None
132
 
133
  if ground_img is not None:
134
  q_img = _encode_ground(ground_img)
@@ -138,7 +168,11 @@ def process(
138
  q_txt = _encode_text(taxonomy.strip())
139
  heat_text = _array_to_pil(_similarity_heatmap(q_txt, patches))
140
 
141
- return heat_ground, heat_text
 
 
 
 
142
 
143
 
144
  # ────────────────────────── Gradio UI ─────────────────────────────────
@@ -191,6 +225,13 @@ with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
191
  label="Full Taxonomy Name (optional)",
192
  placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
193
  )
 
 
 
 
 
 
 
194
  run_btn = gr.Button("Run", variant="primary")
195
 
196
  # RIGHT COLUMN (ground image + two heat-maps)
@@ -209,6 +250,15 @@ with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
209
  label="Heatmap (Text query)",
210
  height=160,
211
  )
 
 
 
 
 
 
 
 
 
212
 
213
  # EXAMPLES
214
  with gr.Row():
@@ -218,25 +268,29 @@ with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
218
  "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus_NAIP_yosemite_v3_resized.png",
219
  "Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
220
  "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus_inat_248820933.jpeg",
 
221
  ],
222
  [
223
  "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator_sentinel2_410613_5.35573_100.28948.jpg",
224
  "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
225
  "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator_inat_461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
 
 
 
 
 
 
 
226
  ],
227
  [
228
  "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota_sentinel2_388246_45.49036_7.14796.jpg",
229
  "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
230
  "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota_inat_327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
 
231
  ],
232
- # [
233
- # "examples/satellite_coast.png",
234
- # "Animalia Chordata Aves Charadriiformes Laridae Larus argentatus",
235
- # None,
236
- # ],
237
  ],
238
- inputs=[sat_input, taxonomy_input, ground_input],
239
- outputs=[heat_ground_out, heat_text_out],
240
  fn=process,
241
  cache_examples=False,
242
  )
@@ -244,8 +298,8 @@ with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
244
  # CALLBACK
245
  run_btn.click(
246
  fn=process,
247
- inputs=[sat_input, taxonomy_input, ground_input],
248
- outputs=[heat_ground_out, heat_text_out],
249
  )
250
 
251
  # Footer to point out to model and data from app page.
 
10
  from PIL import Image
11
  import matplotlib.pyplot as plt
12
  import io
13
+ import torchaudio
14
 
15
  from torchvision import transforms
16
  import open_clip
17
  from clip_vision_per_patch_model import CLIPVisionPerPatchModel
18
+ from transformers import ClapAudioModelWithProjection
19
+ from transformers import ClapProcessor
20
 
21
  # ────────────────────────── global config & models ────────────────────
22
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
28
 
29
  # Satellite patch encoder (CLIP-L-336 per-patch)
30
  sat_model: CLIPVisionPerPatchModel = (
31
+ CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta-sat")
32
  .to(device)
33
  .eval()
34
  )
35
 
36
+ # Sound CLAP model
37
+ sound_model: ClapAudioModelWithProjection = (
38
+ ClapAudioModelWithProjection.from_pretrained("derektan95/search-tta-sound")
39
+ .to(device)
40
+ .eval()
41
+ )
42
+ sound_processor: ClapProcessor = ClapProcessor.from_pretrained("derektan95/search-tta-sound")
43
+ SAMPLE_RATE = 48000
44
+
45
  logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
46
  logit_scale = logit_scale.exp()
47
  blur_kernel = (5,5)
 
70
  ]
71
  )
72
 
73
def get_audio_clap(path_to_audio, format="mp3", padding="repeatpad", truncation="fusion"):
    """Load an audio file and turn it into inputs for the CLAP sound model.

    The file is decoded with ``torchaudio.load``, averaged over axis 0
    (the channel axis in torchaudio's default channels-first layout),
    resampled to ``SAMPLE_RATE`` and passed to ``sound_processor``.

    Args:
        path_to_audio: Path of the audio file to load.
        format: Format hint for ``torchaudio.load``. It is only applied when
            the path has no file extension; otherwise torchaudio detects the
            format from the file itself, so uploads that are not MP3
            (e.g. WAV) are not decoded with the wrong hint.
        padding: Forwarded unchanged to ``sound_processor``.
        truncation: Forwarded unchanged to ``sound_processor``.

    Returns:
        The output of ``sound_processor`` (requested with
        ``return_tensors="pt"`` and ``max_length_s=10``).
    """
    import os  # local import so this block does not depend on file-level imports

    # Forcing "mp3" on a file that already names its own format breaks
    # decoding of non-MP3 uploads; keep the hint for extensionless paths only.
    has_extension = bool(os.path.splitext(str(path_to_audio))[1])
    format_hint = None if has_extension else format

    track, sr = torchaudio.load(path_to_audio, format=format_hint)
    track = track.mean(axis=0)
    track = torchaudio.functional.resample(track, orig_freq=sr, new_freq=SAMPLE_RATE)
    output = sound_processor(
        audios=track,
        sampling_rate=SAMPLE_RATE,
        max_length_s=10,
        return_tensors="pt",
        padding=padding,
        truncation=truncation,
    )
    return output
79
+
80
  # ────────────────────────── helpers ───────────────────────────────────
81
 
82
  @torch.no_grad()
 
100
  return imo_embeds
101
 
102
 
103
@torch.no_grad()
def _encode_sound(sound) -> torch.Tensor:
    """Embed an audio input with the CLAP sound model.

    ``sound`` is handed straight to ``get_audio_clap``. Every tensor it
    returns is moved to ``device``, the batch is run through ``sound_model``,
    and the ``audio_embeds`` output is L2-normalised along the last dimension.
    """
    clap_inputs = get_audio_clap(sound)
    on_device = {name: tensor.to(device) for name, tensor in clap_inputs.items()}
    raw_embeds = sound_model(**on_device).audio_embeds
    return torch.nn.functional.normalize(raw_embeds, dim=-1)
111
+
112
+
113
  def _similarity_heatmap(query: torch.Tensor, patches: torch.Tensor) -> np.ndarray:
114
  sims = torch.matmul(query, patches.t()) * logit_scale
115
  sims = sims.t().sigmoid()
 
151
  sat_img: Image.Image,
152
  taxonomy: str,
153
  ground_img: Image.Image | None,
154
+ sound: torch.Tensor | None,
155
  ):
156
  if sat_img is None:
157
  return None, None
158
 
159
  patches = _encode_sat(sat_img)
160
 
161
+ heat_ground, heat_text, heat_sound = None, None, None
162
 
163
  if ground_img is not None:
164
  q_img = _encode_ground(ground_img)
 
168
  q_txt = _encode_text(taxonomy.strip())
169
  heat_text = _array_to_pil(_similarity_heatmap(q_txt, patches))
170
 
171
+ if sound is not None:
172
+ q_sound = _encode_sound(sound)
173
+ heat_sound = _array_to_pil(_similarity_heatmap(q_sound, patches))
174
+
175
+ return heat_ground, heat_text, heat_sound
176
 
177
 
178
  # ────────────────────────── Gradio UI ─────────────────────────────────
 
225
  label="Full Taxonomy Name (optional)",
226
  placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
227
  )
228
+
229
+ # ─── NEW: sound input ───────────────────────────
230
+ sound_input = gr.Audio(
231
+ label="Sound Input",
232
+ source="upload", # or "microphone" / "url" as you prefer
233
+ type="filepath", # or "numpy" if you want raw arrays
234
+ )
235
  run_btn = gr.Button("Run", variant="primary")
236
 
237
  # RIGHT COLUMN (ground image + three heat-maps)
 
250
  label="Heatmap (Text query)",
251
  height=160,
252
  )
253
+ heat_sound_out = gr.Image(
254
+ label="Heatmap (Sound query)",
255
+ height=160,
256
+ )
257
+ # ─── NEW: sound output ─────────────────────────
258
+ # sound_output = gr.Audio(
259
+ # label="Playback",
260
+ # )
261
+
262
 
263
  # EXAMPLES
264
  with gr.Row():
 
268
  "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus_NAIP_yosemite_v3_resized.png",
269
  "Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
270
  "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus_inat_248820933.jpeg",
271
+ None
272
  ],
273
  [
274
  "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator_sentinel2_410613_5.35573_100.28948.jpg",
275
  "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
276
  "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator_inat_461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
277
+ None
278
+ ],
279
+ [
280
+ "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota_sentinel2_388246_45.49036_7.14796.jpg",
281
+ "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
282
+ "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota_inat_327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
283
+ None
284
  ],
285
  [
286
  "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota_sentinel2_388246_45.49036_7.14796.jpg",
287
  "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
288
  "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota_inat_327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
289
+ "/mnt/hdd/inat2021_ds/2_OTHERS/sound_test/sounds_mp3/386157.mp3"
290
  ],
 
 
 
 
 
291
  ],
292
+ inputs=[sat_input, taxonomy_input, ground_input, sound_input],
293
+ outputs=[heat_ground_out, heat_text_out, heat_sound_out],
294
  fn=process,
295
  cache_examples=False,
296
  )
 
298
  # CALLBACK
299
  run_btn.click(
300
  fn=process,
301
+ inputs=[sat_input, taxonomy_input, ground_input, sound_input],
302
+ outputs=[heat_ground_out, heat_text_out, heat_sound_out],
303
  )
304
 
305
  # Footer to point out to model and data from app page.