from configs import CFG
from costum_datasets import make_pairs
from text_image_audio import OneEncoder
import torch
import gradio as gr
import torchaudio

# Construct pairs of text and image
training_pairs = make_pairs(CFG.image_dir, CFG.image_dir, 5)  # 413,915 -> 82,783 images

# Sort the pairs by image path
training_pairs = sorted(training_pairs, key=lambda x: x[0])
coco_images, coco_captions = zip(*training_pairs)

# Keep only the first pair for each image (set.add returns None, so the
# "not unique_images.add(...)" trick records the image as seen while keeping the pair)
unique_images = set()
unique_pairs = [(item[0], item[1]) for item in training_pairs
                if item[0] not in unique_images and not unique_images.add(item[0])]
coco_images, _ = zip(*unique_pairs)

# Load the pretrained model
model = OneEncoder.from_pretrained("bilalfaye/OneEncoder-text-image-audio")

# Load precomputed COCO image features and keep the first 3000 for the light demo
coco_image_features = torch.load("image_embeddings_best.pt", map_location=CFG.device)
coco_image_features = coco_image_features[:3000]


def text_image(query):
    # Retrieve the top images for a text query and save the plot to img.png
    model.text_image_encoder.image_retrieval(
        query,
        image_paths=coco_images,
        image_embeddings=coco_image_features,
        n=9,
        plot=True,
        temperature=0.0,
    )
    return "img.png"


def audio_image(query):
    # Load the audio with torchaudio (returns a waveform tensor and its sample rate)
    waveform, sample_rate = torchaudio.load(query)

    if waveform.shape[0] > 1:
        # Stereo (2 channels): average the channels to get mono
        mono_audio = waveform.mean(dim=0, keepdim=True)
    else:
        # Audio is already mono
        mono_audio = waveform

    # Resample to 16000 Hz if not already at that rate
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        mono_audio = resampler(mono_audio)

    # Convert to a numpy array for the audio processor
    mono_audio = mono_audio.squeeze(0).numpy()

    # Encode the audio and retrieve the top images, saving the plot to img.png
    audio_encoding = model.process_audio([mono_audio])
    model.image_retrieval(
        audio_encoding,
        image_paths=coco_images,
        image_embeddings=coco_image_features,
        n=9,
        plot=True,
        temperature=0.0,
        display_audio=False,
    )
    return "img.png"


# Gradio interface with one tab per query modality
iface = gr.TabbedInterface(
    [
        gr.Interface(
            fn=text_image,
            inputs=gr.Textbox(label="Text Query"),
            outputs="image",
            title="Retrieve images using text as query",
            description="Implementation of OneEncoder using one layer on UP for a light demo. Only the COCO train dataset is used in this example (3,000 images).",
        ),
        gr.Interface(
            fn=audio_image,
            inputs=gr.Audio(sources=["upload", "microphone"], type="filepath", label="Provide Audio Query"),
            outputs="image",
            title="Retrieve images using audio as query",
            description="Implementation of OneEncoder using one layer on UP for a light demo. Only the COCO train dataset is used in this example (3,000 images).",
        ),
    ],
    tab_names=["Text - Image", "Audio - Image"],
)

iface.launch(debug=True, share=True)