Spaces:
Running
on
Zero
Running
on
Zero
derektan
commited on
Commit
Β·
7e159c0
1
Parent(s):
f996296
[NEW] Sound modality input. Yet to put in proper examples
Browse files
app.py
CHANGED
|
@@ -10,10 +10,13 @@ import numpy as np
|
|
| 10 |
from PIL import Image
|
| 11 |
import matplotlib.pyplot as plt
|
| 12 |
import io
|
|
|
|
| 13 |
|
| 14 |
from torchvision import transforms
|
| 15 |
import open_clip
|
| 16 |
from clip_vision_per_patch_model import CLIPVisionPerPatchModel
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# ββββββββββββββββββββββββββ global config & models ββββββββββββββββββββ
|
| 19 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
@@ -25,11 +28,20 @@ bio_tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")
|
|
| 25 |
|
| 26 |
# Satellite patch encoder CLIP-L-336 per-patch)
|
| 27 |
sat_model: CLIPVisionPerPatchModel = (
|
| 28 |
-
CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta")
|
| 29 |
.to(device)
|
| 30 |
.eval()
|
| 31 |
)
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 34 |
logit_scale = logit_scale.exp()
|
| 35 |
blur_kernel = (5,5)
|
|
@@ -58,6 +70,13 @@ imo_transform = transforms.Compose(
|
|
| 58 |
]
|
| 59 |
)
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
# ββββββββββββββββββββββββββ helpers βββββββββββββββββββββββββββββββββββ
|
| 62 |
|
| 63 |
@torch.no_grad()
|
|
@@ -81,6 +100,16 @@ def _encode_sat(img_pil: Image.Image) -> torch.Tensor:
|
|
| 81 |
return imo_embeds
|
| 82 |
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
def _similarity_heatmap(query: torch.Tensor, patches: torch.Tensor) -> np.ndarray:
|
| 85 |
sims = torch.matmul(query, patches.t()) * logit_scale
|
| 86 |
sims = sims.t().sigmoid()
|
|
@@ -122,13 +151,14 @@ def process(
|
|
| 122 |
sat_img: Image.Image,
|
| 123 |
taxonomy: str,
|
| 124 |
ground_img: Image.Image | None,
|
|
|
|
| 125 |
):
|
| 126 |
if sat_img is None:
|
| 127 |
return None, None
|
| 128 |
|
| 129 |
patches = _encode_sat(sat_img)
|
| 130 |
|
| 131 |
-
heat_ground, heat_text = None, None
|
| 132 |
|
| 133 |
if ground_img is not None:
|
| 134 |
q_img = _encode_ground(ground_img)
|
|
@@ -138,7 +168,11 @@ def process(
|
|
| 138 |
q_txt = _encode_text(taxonomy.strip())
|
| 139 |
heat_text = _array_to_pil(_similarity_heatmap(q_txt, patches))
|
| 140 |
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
|
| 144 |
# ββββββββββββββββββββββββββ Gradio UI βββββββββββββββββββββββββββββββββ
|
|
@@ -191,6 +225,13 @@ with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
|
|
| 191 |
label="Full Taxonomy Name (optional)",
|
| 192 |
placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
|
| 193 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
run_btn = gr.Button("Run", variant="primary")
|
| 195 |
|
| 196 |
# RIGHT COLUMN (ground image + two heat-maps)
|
|
@@ -209,6 +250,15 @@ with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
|
|
| 209 |
label="Heatmap (Text query)",
|
| 210 |
height=160,
|
| 211 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
# EXAMPLES
|
| 214 |
with gr.Row():
|
|
@@ -218,25 +268,29 @@ with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
|
|
| 218 |
"examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus_NAIP_yosemite_v3_resized.png",
|
| 219 |
"Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
|
| 220 |
"examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus_inat_248820933.jpeg",
|
|
|
|
| 221 |
],
|
| 222 |
[
|
| 223 |
"examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator_sentinel2_410613_5.35573_100.28948.jpg",
|
| 224 |
"Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
|
| 225 |
"examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator_inat_461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
],
|
| 227 |
[
|
| 228 |
"examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota_sentinel2_388246_45.49036_7.14796.jpg",
|
| 229 |
"Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
|
| 230 |
"examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota_inat_327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
|
|
|
|
| 231 |
],
|
| 232 |
-
# [
|
| 233 |
-
# "examples/satellite_coast.png",
|
| 234 |
-
# "Animalia Chordata Aves Charadriiformes Laridae Larus argentatus",
|
| 235 |
-
# None,
|
| 236 |
-
# ],
|
| 237 |
],
|
| 238 |
-
inputs=[sat_input, taxonomy_input, ground_input],
|
| 239 |
-
outputs=[heat_ground_out, heat_text_out],
|
| 240 |
fn=process,
|
| 241 |
cache_examples=False,
|
| 242 |
)
|
|
@@ -244,8 +298,8 @@ with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
|
|
| 244 |
# CALLBACK
|
| 245 |
run_btn.click(
|
| 246 |
fn=process,
|
| 247 |
-
inputs=[sat_input, taxonomy_input, ground_input],
|
| 248 |
-
outputs=[heat_ground_out, heat_text_out],
|
| 249 |
)
|
| 250 |
|
| 251 |
# Footer to point out to model and data from app page.
|
|
|
|
| 10 |
from PIL import Image
|
| 11 |
import matplotlib.pyplot as plt
|
| 12 |
import io
|
| 13 |
+
import torchaudio
|
| 14 |
|
| 15 |
from torchvision import transforms
|
| 16 |
import open_clip
|
| 17 |
from clip_vision_per_patch_model import CLIPVisionPerPatchModel
|
| 18 |
+
from transformers import ClapAudioModelWithProjection
|
| 19 |
+
from transformers import ClapProcessor
|
| 20 |
|
| 21 |
# ββββββββββββββββββββββββββ global config & models ββββββββββββββββββββ
|
| 22 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 28 |
|
| 29 |
# Satellite patch encoder CLIP-L-336 per-patch)
|
| 30 |
sat_model: CLIPVisionPerPatchModel = (
|
| 31 |
+
CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta-sat")
|
| 32 |
.to(device)
|
| 33 |
.eval()
|
| 34 |
)
|
| 35 |
|
| 36 |
+
# Sound CLAP model
|
| 37 |
+
sound_model: ClapAudioModelWithProjection = (
|
| 38 |
+
ClapAudioModelWithProjection.from_pretrained("derektan95/search-tta-sound")
|
| 39 |
+
.to(device)
|
| 40 |
+
.eval()
|
| 41 |
+
)
|
| 42 |
+
sound_processor: ClapProcessor = ClapProcessor.from_pretrained("derektan95/search-tta-sound")
|
| 43 |
+
SAMPLE_RATE = 48000
|
| 44 |
+
|
| 45 |
logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 46 |
logit_scale = logit_scale.exp()
|
| 47 |
blur_kernel = (5,5)
|
|
|
|
| 70 |
]
|
| 71 |
)
|
| 72 |
|
| 73 |
+
def get_audio_clap(path_to_audio,format="mp3",padding="repeatpad",truncation="fusion"):
|
| 74 |
+
track, sr = torchaudio.load(path_to_audio, format=format) # torchaudio.load(path_to_audio)
|
| 75 |
+
track = track.mean(axis=0)
|
| 76 |
+
track = torchaudio.functional.resample(track, orig_freq=sr, new_freq=SAMPLE_RATE)
|
| 77 |
+
output = sound_processor(audios=track, sampling_rate=SAMPLE_RATE, max_length_s=10, return_tensors="pt",padding=padding,truncation=truncation)
|
| 78 |
+
return output
|
| 79 |
+
|
| 80 |
# ββββββββββββββββββββββββββ helpers βββββββββββββββββββββββββββββββββββ
|
| 81 |
|
| 82 |
@torch.no_grad()
|
|
|
|
| 100 |
return imo_embeds
|
| 101 |
|
| 102 |
|
| 103 |
+
@torch.no_grad()
|
| 104 |
+
def _encode_sound(sound) -> torch.Tensor:
|
| 105 |
+
processed_sound = get_audio_clap(sound)
|
| 106 |
+
for k in processed_sound.keys():
|
| 107 |
+
processed_sound[k] = processed_sound[k].to(device)
|
| 108 |
+
unnormalized_audio_embeds = sound_model(**processed_sound).audio_embeds
|
| 109 |
+
sound_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1)
|
| 110 |
+
return sound_embeds
|
| 111 |
+
|
| 112 |
+
|
| 113 |
def _similarity_heatmap(query: torch.Tensor, patches: torch.Tensor) -> np.ndarray:
|
| 114 |
sims = torch.matmul(query, patches.t()) * logit_scale
|
| 115 |
sims = sims.t().sigmoid()
|
|
|
|
| 151 |
sat_img: Image.Image,
|
| 152 |
taxonomy: str,
|
| 153 |
ground_img: Image.Image | None,
|
| 154 |
+
sound: torch.Tensor | None,
|
| 155 |
):
|
| 156 |
if sat_img is None:
|
| 157 |
return None, None
|
| 158 |
|
| 159 |
patches = _encode_sat(sat_img)
|
| 160 |
|
| 161 |
+
heat_ground, heat_text, heat_sound = None, None, None
|
| 162 |
|
| 163 |
if ground_img is not None:
|
| 164 |
q_img = _encode_ground(ground_img)
|
|
|
|
| 168 |
q_txt = _encode_text(taxonomy.strip())
|
| 169 |
heat_text = _array_to_pil(_similarity_heatmap(q_txt, patches))
|
| 170 |
|
| 171 |
+
if sound is not None:
|
| 172 |
+
q_sound = _encode_sound(sound)
|
| 173 |
+
heat_sound = _array_to_pil(_similarity_heatmap(q_sound, patches))
|
| 174 |
+
|
| 175 |
+
return heat_ground, heat_text, heat_sound
|
| 176 |
|
| 177 |
|
| 178 |
# ββββββββββββββββββββββββββ Gradio UI βββββββββββββββββββββββββββββββββ
|
|
|
|
| 225 |
label="Full Taxonomy Name (optional)",
|
| 226 |
placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
|
| 227 |
)
|
| 228 |
+
|
| 229 |
+
# βββ NEW: sound input βββββββββββββββββββββββββββ
|
| 230 |
+
sound_input = gr.Audio(
|
| 231 |
+
label="Sound Input",
|
| 232 |
+
source="upload", # or "microphone" / "url" as you prefer
|
| 233 |
+
type="filepath", # or "numpy" if you want raw arrays
|
| 234 |
+
)
|
| 235 |
run_btn = gr.Button("Run", variant="primary")
|
| 236 |
|
| 237 |
# RIGHT COLUMN (ground image + two heat-maps)
|
|
|
|
| 250 |
label="Heatmap (Text query)",
|
| 251 |
height=160,
|
| 252 |
)
|
| 253 |
+
heat_sound_out = gr.Image(
|
| 254 |
+
label="Heatmap (Sound query)",
|
| 255 |
+
height=160,
|
| 256 |
+
)
|
| 257 |
+
# βββ NEW: sound output βββββββββββββββββββββββββ
|
| 258 |
+
# sound_output = gr.Audio(
|
| 259 |
+
# label="Playback",
|
| 260 |
+
# )
|
| 261 |
+
|
| 262 |
|
| 263 |
# EXAMPLES
|
| 264 |
with gr.Row():
|
|
|
|
| 268 |
"examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus_NAIP_yosemite_v3_resized.png",
|
| 269 |
"Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
|
| 270 |
"examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus_inat_248820933.jpeg",
|
| 271 |
+
None
|
| 272 |
],
|
| 273 |
[
|
| 274 |
"examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator_sentinel2_410613_5.35573_100.28948.jpg",
|
| 275 |
"Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
|
| 276 |
"examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator_inat_461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
|
| 277 |
+
None
|
| 278 |
+
],
|
| 279 |
+
[
|
| 280 |
+
"examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota_sentinel2_388246_45.49036_7.14796.jpg",
|
| 281 |
+
"Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
|
| 282 |
+
"examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota_inat_327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
|
| 283 |
+
None
|
| 284 |
],
|
| 285 |
[
|
| 286 |
"examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota_sentinel2_388246_45.49036_7.14796.jpg",
|
| 287 |
"Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
|
| 288 |
"examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota_inat_327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
|
| 289 |
+
"/mnt/hdd/inat2021_ds/2_OTHERS/sound_test/sounds_mp3/386157.mp3"
|
| 290 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
],
|
| 292 |
+
inputs=[sat_input, taxonomy_input, ground_input, sound_input],
|
| 293 |
+
outputs=[heat_ground_out, heat_text_out, heat_sound_out],
|
| 294 |
fn=process,
|
| 295 |
cache_examples=False,
|
| 296 |
)
|
|
|
|
| 298 |
# CALLBACK
|
| 299 |
run_btn.click(
|
| 300 |
fn=process,
|
| 301 |
+
inputs=[sat_input, taxonomy_input, ground_input, sound_input],
|
| 302 |
+
outputs=[heat_ground_out, heat_text_out, heat_sound_out],
|
| 303 |
)
|
| 304 |
|
| 305 |
# Footer to point out to model and data from app page.
|