marinap's picture
gallery args cleanup: older gradio version is installed on top of requirements.txt
d79a08c
raw
history blame
3.02 kB
import io
import requests
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import uform
from datetime import datetime
# Multilingual UForm vision-language model; used to embed the search prompt.
model_multi = uform.get_model('unum-cloud/uform-vl-multilingual')

# Precomputed image embeddings: read from the .npy file with NumPy and copied
# into a torch tensor so they can be compared against the prompt embedding.
embeddings = torch.tensor(np.load('tensors/embeddings.npy'))

# NOTE: image *features* (multilingual-image-search/tensors/features.npy) are
# not loaded; they would only be needed by the disabled rerank step.

# Image metadata; find_topk selects rows by position and reads the
# 'photo_image_url' column -- presumably row i describes embeddings[i]
# (TODO confirm the CSV and .npy share the same ordering).
img_df = pd.read_csv('image_data.csv')
def url2img(url, resize=False, fix_height=150, timeout=10):
    """Download the image at ``url`` and open it with Pillow.

    Args:
        url: Address of the image; HTTP redirects are followed.
        resize: When true, shrink the image in place (aspect ratio kept,
            never enlarged) so that neither side exceeds ``fix_height``.
        fix_height: Upper bound, in pixels, for both width and height when
            ``resize`` is true (despite the name it bounds both sides).
        timeout: Seconds to wait for the server, forwarded to
            ``requests.get``. Without a timeout a stalled host would block
            the caller indefinitely.

    Returns:
        The opened ``PIL.Image.Image``.

    Raises:
        requests.HTTPError: The server answered with a 4xx/5xx status. This
            surfaces the real failure instead of failing later while trying
            to decode an HTML error page as an image.
        requests.Timeout: The server did not respond within ``timeout``.
        PIL.UnidentifiedImageError: The response body is not an image Pillow
            can decode.
    """
    response = requests.get(url, allow_redirects=True, timeout=timeout)
    response.raise_for_status()
    img = Image.open(io.BytesIO(response.content))
    if resize:
        # thumbnail() mutates img and preserves the aspect ratio.
        img.thumbnail((fix_height, fix_height), Image.LANCZOS)
    return img
def find_topk(text, top_k=20):
    """Return the URLs of the images whose embeddings best match ``text``.

    The prompt is encoded with the multilingual UForm model (``model_multi``)
    and scored by cosine similarity against the precomputed ``embeddings``.
    The best-scoring positions are then looked up in ``img_df``.

    Args:
        text: Prompt typed by the user. A missing or blank prompt returns no
            results instead of running the model.
        top_k: Maximum number of URLs to return (default 20). It is clamped
            to the number of stored embeddings, so ``topk`` cannot fail on a
            small index.

    Returns:
        The ``photo_image_url`` values of the best matches, most similar
        first, as returned by ``.values``. Returns an empty list for a blank
        prompt.
    """
    print('text', text)
    if not text or not text.strip():
        # Nothing to search for; clear the gallery rather than showing
        # arbitrary images ranked against an empty prompt.
        return []

    # Pure inference: disable autograd so no graph or activations are kept.
    with torch.no_grad():
        text_data = model_multi.preprocess_text(text)
        # Only the embedding is used; the token-level features are discarded.
        _text_features, text_embedding = model_multi.encode_text(
            text_data, return_features=True
        )
        print('Got features', datetime.now().strftime("%H:%M:%S"))
        # Presumably text_embedding is (1, D) and embeddings is (N, D), so
        # this broadcasts to N scores (TODO confirm shapes).
        sims = F.cosine_similarity(text_embedding, embeddings)

    k = min(top_k, sims.shape[-1])
    _vals, inds = sims.topk(k)
    # Plain ints index the DataFrame positionally without relying on how
    # pandas interprets a torch tensor.
    top_k_urls = img_df.iloc[inds.tolist()]['photo_image_url'].values
    print('Got top_k_urls', top_k_urls)
    print(datetime.now().strftime("%H:%M:%S"))
    return top_k_urls
# def rerank(text_features, text_data):
# # create joint embeddings & get scores
# joint_embedding = model_multi.encode_multimodal(
# image_features=image_features,
# text_features=text_features,
# attention_mask=text_data['attention_mask']
# )
# score = model_multi.get_matching_scores(joint_embedding)
# # argmax to get top N
# return
#demo = gr.Interface(find_topk, inputs = 'text', outputs = 'image')
# Log the installed Gradio version; the Gallery styling below depends on it.
print('version', gr.__version__)

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown('# Enter a prompt in one of the supported languages.')
    with gr.Row():
        # Supported languages, shown as two side-by-side column pairs ("#" is
        # only a visual separator). Codes are <ISO 639-3 language>_<ISO 15924
        # script>: Spanish is spa_Latn, distinct from Italian's ita_Latn.
        gr.Markdown('| Code | Lang | # | Code | Lang |\n'
                    '| :------- | :------- | :--- | :------- | :------------------- |\n'
                    '| eng_Latn | English | # | fra_Latn | French |\n'
                    '| deu_Latn | German | # | ita_Latn | Italian |\n'
                    '| spa_Latn | Spanish | # | jpn_Jpan | Japanese |\n'
                    '| tur_Latn | Turkish | # | zho_Hans | Chinese (Simplified) |\n'
                    '| kor_Hang | Korean | # | pol_Latn | Polish |\n'
                    '| rus_Cyrl | Russian | # | . | . |\n')
        with gr.Column():
            prompt_box = gr.Textbox(label='Enter your prompt', lines=3)
            btn_search = gr.Button("Find images")
    # `.style(...)` is the older Gradio styling API; it is kept on purpose
    # because an older Gradio version is installed for this app.
    gallery = gr.Gallery().style(grid=[5], height="auto")
    # Clicking the button sends the prompt to find_topk and shows the
    # returned image URLs in the gallery.
    btn_search.click(find_topk, inputs=prompt_box, outputs=gallery)

if __name__ == "__main__":
    demo.launch()