# -*- coding: utf-8 -*-
"""Text_to_Image_Demo.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1mkGloXbrNHKFh99ryB6PQDyCJ3u4RqD5
## Generate Images from Text
"""
# Important installations
# pip install openai
# pip install python-dotenv
# pip install transformers datasets -q
# pip install streamlit
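# The code below also imports torch, torchvision, numpy, and Pillow (PIL),
# so those packages need to be available as well, e.g.:
# pip install torch torchvision numpy pillow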
import os
import openai
# open_ai_key_file = "openai_api_key_llm_2023.txt"
# with open(open_ai_key_file, "r") as f:
# for line in f:
# OPENAI_KEY = line
# break
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())
# openai.api_key = OPENAI_KEY

# Read 100 flower names from 100flowers.txt
with open('./100flowers.txt', 'r') as file1:
    Lines = [line.strip() for line in file1.readlines()]
from openai import OpenAI
from PIL import Image
import urllib.request
from io import BytesIO
from IPython.display import display
# client = OpenAI(api_key=OPENAI_KEY)
# Code to generate images from the names in 100flowers.txt (run once, offline,
# to populate the Flowers folder):
# for prompt in Lines:
#     response = client.images.generate(
#         model="dall-e-3",
#         prompt=prompt,
#         size="1024x1024",
#         quality="standard",
#         n=1,
#     )
#     # Save each generated image as a PNG in the Flowers folder
#     image_url = response.data[0].url
#     with urllib.request.urlopen(image_url) as image_response:
#         img = Image.open(BytesIO(image_response.read()))
#     img.save(f'./Flowers/{prompt}.png')
# from transformers.utils import send_example_telemetry
# send_example_telemetry("image_similarity_notebook", framework="pytorch")
# Creates a list of flower names
directory = './Flowers'
png_files = [file[:-len('.png')].strip() for file in os.listdir(directory) if file.endswith(".png")]
from datasets import Dataset  # note: do not import datasets.Image here, it would shadow PIL.Image
# Collects the image file names in the directory (the "./Flowers/" prefix is added later)
def get_paths_to_images(images_directory):
    paths = []
    for file in os.listdir(images_directory):
        print(file)
        paths.append(file)
    return paths
# Creates a Hugging Face Dataset whose "image" column holds the file names
def load_dataset(images_directory):
    paths_images = get_paths_to_images(images_directory)
    print(paths_images[0])
    dataset = Dataset.from_dict({"image": paths_images})
    return dataset
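# Example of a resulting record (hypothetical file name):
#   dataset[0] -> {"image": "Rose.png"}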
path_images = "./Flowers"
dataset = load_dataset(path_images)
from transformers import AutoFeatureExtractor, AutoModel
model_ckpt = "jafdxc/vit-base-patch16-224-finetuned-flower"
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
import torchvision.transforms as T
import torch
from PIL import Image
# Data transformation chain.
transformation_chain = T.Compose(
    [
        # Resize (scaled relative to the model's expected input size), then
        # take a center crop of the size the feature extractor expects.
        T.Resize(int((256 / 224) * extractor.size["height"])),
        T.CenterCrop(extractor.size["height"]),
        T.ToTensor(),
        T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
    ]
)
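# Optional sanity check (assumes at least one PNG already exists in ./Flowers):
#   sample = Image.open("./Flowers/" + os.listdir("./Flowers")[0])
#   transformation_chain(sample).shape  # -> torch.Size([3, 224, 224]) for this checkpoint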
def extract_embeddings(model: torch.nn.Module):
    """Utility to compute embeddings."""
    device = model.device

    def pp(batch):
        images = batch["image"]
        image_batch_transformed = torch.stack(
            [transformation_chain(Image.open("./Flowers/" + image)) for image in images]
        )
        new_batch = {"pixel_values": image_batch_transformed.to(device)}
        with torch.no_grad():
            # Use the [CLS] token representation as the image embedding.
            embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
        return {"embeddings": embeddings}

    return pp
import numpy as np
# Here, we map embedding extraction utility on our subset of candidate images.
batch_size = 1
device = "cuda" if torch.cuda.is_available() else "cpu"
extract_fn = extract_embeddings(model.to(device))
candidate_subset_emb = dataset.map(extract_fn, batched=True, batch_size=batch_size)
all_candidate_embeddings = np.array(candidate_subset_emb["embeddings"])
all_candidate_embeddings = torch.from_numpy(all_candidate_embeddings)
print(all_candidate_embeddings.shape[0])
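# all_candidate_embeddings has shape (num_candidate_images, hidden_size);
# hidden_size is 768 for this ViT-base backbone.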
def compute_scores(emb_one, emb_two):
    """Computes cosine similarity between two vectors."""
    scores = torch.nn.functional.cosine_similarity(emb_one, emb_two)
    return scores.numpy().tolist()
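# Illustrative example: identical embeddings give a score of 1.0.
#   compute_scores(torch.ones(1, 4), torch.ones(1, 4))  # -> [1.0]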
def fetch_similar(image, top_k=5):
    """Fetches the `top_k` most similar candidate images for the query `image`."""
    # Prepare the input query image for embedding computation.
    image_transformed = transformation_chain(image).unsqueeze(0)
    new_batch = {"pixel_values": image_transformed.to(device)}

    # Compute the embedding of the query image.
    with torch.no_grad():
        query_embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()

    # Compute similarity scores with all the candidate images in one go,
    # and map each candidate index to its similarity score.
    sim_scores = compute_scores(all_candidate_embeddings, query_embeddings)
    similarity_mapping = dict(
        zip([str(index) for index in range(all_candidate_embeddings.shape[0])], sim_scores)
    )

    # Sort by score (descending) and return the indices of the `top_k` candidates.
    similarity_mapping_sorted = dict(
        sorted(similarity_mapping.items(), key=lambda x: x[1], reverse=True)
    )
    id_entries = list(similarity_mapping_sorted.keys())[:top_k]
    ids = [int(entry) for entry in id_entries]
    return ids
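# Example usage (hypothetical file name):
#   query = Image.open("./Flowers/Rose.png").convert("RGB")
#   fetch_similar(query, top_k=3)  # -> e.g. [12, 40, 7], row indices into the dataset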
def plot_images(images):
    count = 0
    for image, name in images:
        if name == 'original':
            st.write("Showing the original image")
            st.image(image, caption=name, channels='RGB', output_format='auto')
        else:
            count += 1
            st.write(f"Showing similar image {count}")
            img = Image.open(image)
            st.image(img, caption=name, channels='RGB', output_format='auto')
# Streamlit webpage code
import streamlit as st

st.title("Flower Type Demo")
st.subheader("Upload an image of a flower and you will get the 5 most similar flowers from our dataset")

upload_file = st.file_uploader('Upload a Flower Image')
images = []
if upload_file:
    # Convert to RGB so the transformation chain always receives 3 channels.
    test_sample = Image.open(upload_file).convert("RGB")
    sim_ids = fetch_similar(test_sample)

    for idx in sim_ids:
        images.append(("./Flowers/" + candidate_subset_emb[idx]["image"], candidate_subset_emb[idx]["image"]))
    images.insert(0, (test_sample, 'original'))
    print(images)

    plot_images(images)
    st.write("")