# -*- coding: utf-8 -*-
"""Text_to_Image_Demo.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1mkGloXbrNHKFh99ryB6PQDyCJ3u4RqD5
## Generate Images from Text
"""
# Important installations
# pip install openai
# pip install python-dotenv
# pip install transformers datasets -q
# pip install streamlit
import os
import openai
# open_ai_key_file = "openai_api_key_llm_2023.txt"
# with open(open_ai_key_file, "r") as f:
#     for line in f:
#         OPENAI_KEY = line
#         break
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())
# Read 100 flower names from 100flowers.txt
# openai.api_key = OPENAI_KEY
with open('./100flowers.txt', 'r') as file1:
    Lines = [line.strip() for line in file1.readlines()]
from openai import OpenAI
from PIL import Image
import urllib.request
from io import BytesIO
from IPython.display import display
# client = OpenAI(api_key=OPENAI_KEY)
# Code to generate an image for each name in 100flowers.txt
# and save it as a PNG in the Flowers folder
# for prompt in Lines:
#     response = client.images.generate(
#         model="dall-e-3",
#         prompt=prompt,
#         size="1024x1024",
#         quality="standard",
#         n=1,
#     )
#     image_url = response.data[0].url
#     with urllib.request.urlopen(image_url) as url_response:
#         img = Image.open(BytesIO(url_response.read()))
#     img.save(f'./Flowers/{prompt}.png')
# from transformers.utils import send_example_telemetry
# send_example_telemetry("image_similarity_notebook", framework="pytorch")
# Creates a list of flower names
directory = './Flowers'
png_files = [file[:-len('.png')].strip() for file in os.listdir(directory) if file.endswith(".png")]
from datasets import Dataset
# Gets the list of image file names in the images directory
def get_paths_to_images(images_directory):
    paths = []
    for file in os.listdir(images_directory):
        print(file)
        paths.append(file)
    return paths
# Creates dataset
def load_dataset(images_directory):
    paths_images = get_paths_to_images(images_directory)
    print(paths_images[0])
    dataset = Dataset.from_dict({"image": paths_images})
    return dataset
path_images = "./Flowers"
dataset = load_dataset(path_images)
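# Note: the "image" column of this Dataset holds plain file names (strings), not
# decoded images; the embedding step below opens each file from ./Flowers itself.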
from transformers import AutoFeatureExtractor, AutoModel
model_ckpt = "jafdxc/vit-base-patch16-224-finetuned-flower"
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
import torchvision.transforms as T
import torch
from PIL import Image
# Data transformation chain.
transformation_chain = T.Compose(
    [
        # Resize so the shorter side matches the pre-training setup, then take a
        # center crop at the model's expected input resolution.
        T.Resize(int((256 / 224) * extractor.size["height"])),
        T.CenterCrop(extractor.size["height"]),
        T.ToTensor(),
        T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
    ]
)
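# Optional sanity check (a sketch, assuming at least one PNG already exists in
# ./Flowers and that this checkpoint uses the usual ViT-base 224x224 input size):
# sample_name = get_paths_to_images(path_images)[0]
# sample_tensor = transformation_chain(Image.open("./Flowers/" + sample_name))
# print(sample_tensor.shape)  # expected: torch.Size([3, 224, 224])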
def extract_embeddings(model: torch.nn.Module):
    """Utility to compute embeddings."""
    device = model.device

    def pp(batch):
        images = batch["image"]
        image_batch_transformed = torch.stack(
            [transformation_chain(Image.open("./Flowers/" + image)) for image in images]
        )
        new_batch = {"pixel_values": image_batch_transformed.to(device)}
        with torch.no_grad():
            # Use the [CLS] token's last hidden state as the image embedding.
            embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
        return {"embeddings": embeddings}

    return pp
import numpy as np

# Here, we map the embedding extraction utility over our candidate images.
batch_size = 1
device = "cuda" if torch.cuda.is_available() else "cpu"
extract_fn = extract_embeddings(model.to(device))
candidate_subset_emb = dataset.map(extract_fn, batched=True, batch_size=batch_size)
all_candidate_embeddings = np.array(candidate_subset_emb["embeddings"])
all_candidate_embeddings = torch.from_numpy(all_candidate_embeddings)
print(all_candidate_embeddings.shape[0])
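# Optional check: ViT-base has a hidden size of 768, so all_candidate_embeddings
# should be a tensor of shape (number of candidate images, 768).
# print(all_candidate_embeddings.shape)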
def compute_scores(emb_one, emb_two):
    """Computes cosine similarity between two vectors."""
    scores = torch.nn.functional.cosine_similarity(emb_one, emb_two)
    return scores.numpy().tolist()
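# Illustrative usage with dummy data (the shapes here are hypothetical): cosine
# similarity broadcasts a single query embedding against every candidate row.
# query = torch.randn(1, 768)
# candidates = torch.randn(10, 768)
# print(compute_scores(candidates, query))  # list of 10 floats in [-1, 1]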
def fetch_similar(image, top_k=5):
    """Fetches the `top_k` images most similar to `image`, the query."""
    # Prepare the input query image for embedding computation.
    image_transformed = transformation_chain(image).unsqueeze(0)
    new_batch = {"pixel_values": image_transformed.to(device)}

    # Compute the embedding.
    with torch.no_grad():
        query_embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()

    # Compute similarity scores with all the candidate images in one go.
    # We also create a mapping between the candidate image identifiers
    # and their similarity scores with the query image.
    sim_scores = compute_scores(all_candidate_embeddings, query_embeddings)
    similarity_mapping = dict(
        zip([str(index) for index in range(all_candidate_embeddings.shape[0])], sim_scores)
    )

    # Sort the mapping dictionary and return the `top_k` candidate indices.
    similarity_mapping_sorted = dict(
        sorted(similarity_mapping.items(), key=lambda x: x[1], reverse=True)
    )
    id_entries = list(similarity_mapping_sorted.keys())[:top_k]
    ids = [int(entry) for entry in id_entries]
    return ids
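# Example of calling fetch_similar outside the Streamlit app (a sketch; the file
# name "rose.png" is hypothetical and would need to exist in ./Flowers):
# query_img = Image.open("./Flowers/rose.png").convert("RGB")
# top_ids = fetch_similar(query_img, top_k=5)
# print([candidate_subset_emb[i]["image"] for i in top_ids])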
import matplotlib.pyplot as plt

def plot_images(images):
    count = 0
    for image, name in images:
        if name == 'original':
            st.write("Showing the original image")
            st.image(image, caption=name, width=None, use_column_width=None, clamp=False, channels='RGB', output_format='auto')
        else:
            count += 1
            st.write(f"Showing similar image {count}")
            img = Image.open(image)
            st.image(img, caption=name, width=None, use_column_width=None, clamp=False, channels='RGB', output_format='auto')
# Streamlit webpage code
import streamlit as st
from io import StringIO

# Image similarity search
st.title("Flower Type Demo")
st.subheader("Upload an image of a flower and you will get the 5 most similar flowers from our dataset")
upload_file = st.file_uploader('Upload a Flower Image')
images = []
if upload_file:
    # Convert to RGB so the transformation chain always sees a 3-channel image.
    test_sample = Image.open(upload_file).convert('RGB')
    sim_ids = fetch_similar(test_sample)
    for sim_id in sim_ids:
        images.append(("./Flowers/" + candidate_subset_emb[sim_id]["image"], candidate_subset_emb[sim_id]["image"]))
    images.insert(0, (test_sample, 'original'))
    print(images)
    plot_images(images)
st.write("")