|
|
|
"""Text_to_Image_Demo.ipynb |
|
|
|
Automatically generated by Colaboratory. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/1mkGloXbrNHKFh99ryB6PQDyCJ3u4RqD5 |
|
|
|
## Generate Images from Text |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import openai |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dotenv import load_dotenv, find_dotenv |
|
_ = load_dotenv(find_dotenv()) |
|
|
|
|
|
|
|
file1 = open('./100flowers.txt', 'r') |
|
Lines = file1.readlines() |
|
Lines = [line.strip() for line in Lines] |
|
|
|
from openai import OpenAI
from PIL import Image
import urllib.request
from io import BytesIO
from IPython.display import display
directory = './Flowers'
# Names of the images already present in ./Flowers, without their .png extension.
png_files = [file[:-len('.png')].strip() for file in os.listdir(directory) if file.endswith(".png")]
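
# The image-generation cell did not survive this export. What follows is a
# minimal sketch of what it likely did, assuming the OpenAI images endpoint
# (the "dall-e-2" model name and "256x256" size are assumptions, not taken
# from the original notebook): generate one picture per flower name in
# `Lines`, skipping any that already exist, and save it under ./Flowers.
client = OpenAI()  # picks up OPENAI_API_KEY loaded from .env above

for flower in Lines:
    if flower in png_files:
        continue  # already generated on a previous run
    response = client.images.generate(model="dall-e-2", prompt=flower, n=1, size="256x256")
    with urllib.request.urlopen(response.data[0].url) as remote:
        image = Image.open(BytesIO(remote.read()))
    image.save(os.path.join(directory, flower + '.png'))
    display(image)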
|
|
|
|
|
# Import only Dataset here: importing `Image` from datasets as well would
# shadow the PIL.Image imported above.
from datasets import Dataset


def get_paths_to_images(images_directory):
    # Return the file names (not full paths) of the candidate images.
    paths = []
    for file in os.listdir(images_directory):
        paths.append(file)
    return paths


def load_dataset(images_directory):
    # Build a Dataset whose "image" column holds the candidate file names.
    paths_images = get_paths_to_images(images_directory)
    dataset = Dataset.from_dict({"image": paths_images})
    return dataset


path_images = "./Flowers"
dataset = load_dataset(path_images)
|
|
|
from transformers import AutoFeatureExtractor, AutoModel

# ViT checkpoint fine-tuned on flower images; its hidden states provide the
# image embeddings used for similarity search.
model_ckpt = "jafdxc/vit-base-patch16-224-finetuned-flower"
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

import torchvision.transforms as T
import torch
from PIL import Image

# Mirror the extractor's preprocessing: resize (keeping the 256/224 ratio),
# center-crop to the model's input size, and normalize.
transformation_chain = T.Compose(
    [
        T.Resize(int((256 / 224) * extractor.size["height"])),
        T.CenterCrop(extractor.size["height"]),
        T.ToTensor(),
        T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
    ]
)
|
def extract_embeddings(model: torch.nn.Module):
    """Utility to compute embeddings."""
    device = model.device

    def pp(batch):
        images = batch["image"]
        # Convert to RGB in case any PNG carries an alpha channel.
        image_batch_transformed = torch.stack(
            [transformation_chain(Image.open("./Flowers/" + image).convert("RGB")) for image in images]
        )
        new_batch = {"pixel_values": image_batch_transformed.to(device)}
        with torch.no_grad():
            # Use the [CLS] token's final hidden state as the image embedding.
            embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
        return {"embeddings": embeddings}

    return pp


import numpy as np

batch_size = 1
device = "cuda" if torch.cuda.is_available() else "cpu"
extract_fn = extract_embeddings(model.to(device))
candidate_subset_emb = dataset.map(extract_fn, batched=True, batch_size=batch_size)

all_candidate_embeddings = np.array(candidate_subset_emb["embeddings"])
all_candidate_embeddings = torch.from_numpy(all_candidate_embeddings)

# Number of candidate images indexed.
print(all_candidate_embeddings.shape[0])
|
|
|
def compute_scores(emb_one, emb_two):
    """Computes cosine similarity between two batches of embeddings."""
    scores = torch.nn.functional.cosine_similarity(emb_one, emb_two)
    return scores.numpy().tolist()
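
# Example: identical embeddings score 1.0.
# compute_scores(torch.ones(1, 4), torch.ones(1, 4))  ->  [1.0]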
|
|
|
|
|
def fetch_similar(image, top_k=5):
    """Fetches the `top_k` images most similar to `image` from the candidates."""
    # Embed the query with the same transformation chain used for the candidates.
    image_transformed = transformation_chain(image).unsqueeze(0)
    new_batch = {"pixel_values": image_transformed.to(device)}

    with torch.no_grad():
        query_embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()

    # Score the query against every candidate and sort by similarity.
    sim_scores = compute_scores(all_candidate_embeddings, query_embeddings)
    similarity_mapping = dict(zip([str(index) for index in range(all_candidate_embeddings.shape[0])], sim_scores))

    similarity_mapping_sorted = dict(
        sorted(similarity_mapping.items(), key=lambda x: x[1], reverse=True)
    )
    id_entries = list(similarity_mapping_sorted.keys())[:top_k]

    ids = [int(idx) for idx in id_entries]
    return ids
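
# Example (hypothetical file name): the five candidates closest to a query image.
# query = Image.open('./Flowers/rose.png').convert('RGB')
# print(fetch_similar(query, top_k=5))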
|
|
|
import streamlit as st


def plot_images(images):
    # Show the uploaded image first, then each similar image in rank order.
    count = 0
    for image, name in images:
        if name == 'original':
            st.write("Showing the original image")
            st.image(image, caption=name, width=None, use_column_width=None, clamp=False, channels='RGB', output_format='auto')
        else:
            count += 1
            st.write(f"Showing similar image {count}")
            img = Image.open(image)
            st.image(img, caption=name, width=None, use_column_width=None, clamp=False, channels='RGB', output_format='auto')
|
|
|
|
|
st.title("Flower Type Demo")
st.subheader("Upload an image of a flower and you will get the 5 most similar flowers from our dataset")

upload_file = st.file_uploader('Upload a Flower Image')
|
|
|
images = []

if upload_file:
    # Force RGB: uploaded PNGs may carry an alpha channel.
    test_sample = Image.open(upload_file).convert('RGB')

    sim_ids = fetch_similar(test_sample)

    for sim_id in sim_ids:
        images.append(("./Flowers/" + candidate_subset_emb[sim_id]["image"], candidate_subset_emb[sim_id]["image"]))

    images.insert(0, (test_sample, 'original'))
    plot_images(images)
    st.write("")