# Streamlit demo app: AttnGAN text-to-image generation for bird descriptions.
from __future__ import print_function

import pprint
import random
import re
import time
from pathlib import Path

import numpy as np
import streamlit as st
import torch
import torchvision.transforms as transforms

from src.dataset import TextDataset
from src.misc.config import cfg, cfg_from_file
from src.trainer import condGANTrainer as trainer
def gen_example(wordtoix, algo, text):
    """Generate images from a single example sentence.

    Tokenizes *text*, maps every in-vocabulary word to its integer index,
    packs the encoded caption the way the batched pipeline expects
    (length-sorted, zero-padded), and hands it to ``algo.gen_example``.

    Args:
        wordtoix: dict mapping lowercase words to integer vocabulary indices.
        algo: trainer object exposing ``gen_example(data_dic)``.
        text: raw caption string entered by the user.

    Raises:
        ValueError: if no word of *text* is found in the vocabulary
            (the original silently produced an empty (1, 0) caption array
            that failed cryptically downstream).
    """
    import re

    # Strip the paired replacement characters some encodings leave behind.
    sent = text.replace("\ufffd\ufffd", " ")
    # \w+ tokenization — exactly what nltk's RegexpTokenizer(r"\w+") does,
    # without pulling in the nltk dependency.
    tokens = re.findall(r"\w+", sent.lower())

    rev = []
    for tok in tokens:
        # Drop any non-ASCII bytes before the vocabulary lookup.
        tok = tok.encode("ascii", "ignore").decode("ascii")
        if len(tok) > 0 and tok in wordtoix:
            rev.append(wordtoix[tok])
    if not rev:
        raise ValueError("None of the words in the input are in the vocabulary.")

    captions = [rev]
    cap_lens = [len(rev)]
    max_len = np.max(cap_lens)
    # Sort captions by decreasing length. Trivial for a single caption, but
    # kept because the downstream consumer expects length-sorted batches.
    sorted_indices = np.argsort(cap_lens)[::-1]
    cap_lens = np.asarray(cap_lens)[sorted_indices]
    cap_array = np.zeros((len(captions), max_len), dtype="int64")
    for row, src_idx in enumerate(sorted_indices):
        cap = captions[src_idx]
        cap_array[row, : len(cap)] = cap

    # The original derived this key via name.rfind("/") on the constant
    # "output", which always yields "output" — use the literal directly.
    data_dic = {"output": [cap_array, cap_lens, sorted_indices]}
    algo.gen_example(data_dic)
def center_element(type, text=None, img_path=None):
    """Render a streamlit element horizontally centered on the page.

    Centering is faked with a three-column layout whose side columns are
    left empty; the middle-column width ratio depends on the element kind.

    Args:
        type: one of "image", "text", "heading", "subheading", "title".
            (Name kept for backward compatibility with keyword callers,
            even though it shadows the builtin.)
        text: content for the textual element kinds.
        img_path: image path, used only when ``type == "image"``.

    Raises:
        ValueError: if *type* is not a supported kind. (The original fell
        through with col1/col2/col3 unbound and died with a NameError.)
    """
    # Column-width ratios tuned per element kind.
    ratios = {
        "image": [1, 2, 1],
        "text": [1, 6, 1],
        "heading": [1, 6, 1],
        "subheading": [1, 2, 1],
        "title": [1, 8, 1],
    }
    if type not in ratios:
        raise ValueError(f"Unsupported element type: {type!r}")
    col1, col2, col3 = st.columns(ratios[type])
    with col1:
        st.write("")
    with col2:
        if type == "heading":
            st.header(text)
        elif type == "title":
            st.title(text)
        elif type == "image":
            st.image(img_path)
        elif type == "text":
            st.write(text)
        else:  # "subheading" — only remaining supported kind
            st.subheader(text)
    with col3:
        st.write("")
def demo_gan():
    """Streamlit page: read a bird description, synthesize images with AttnGAN.

    Loads the evaluation config, builds the test-split dataset and
    dataloader, instantiates the GAN trainer, and renders the generated
    bird images plus per-word attention maps for the caption the user
    types in.
    """
    cfg_from_file("eval_bird.yml")
    cfg.CUDA = False  # the demo runs on CPU

    # Fix all RNG seeds so identical captions yield identical images.
    manualSeed = 100
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)

    output_dir = "output/"
    split_dir = "test"
    bshuffle = True

    # Final image size doubles with every extra branch of the GAN tree.
    imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1))
    image_transform = transforms.Compose(
        [
            transforms.Resize(int(imsize * 76 / 64)),
            transforms.RandomCrop(imsize),
            transforms.RandomHorizontalFlip(),
        ]
    )

    # NOTE(review): the original called st.cache(func=TextDataset, ...) and
    # st.cache(func=trainer, ...) and discarded the returned wrappers, so
    # nothing was ever cached. The no-op calls are removed; to actually
    # cache, the wrapped callable must be invoked instead of the raw
    # constructor, e.g. dataset = st.cache(TextDataset, persist=True)(...).
    dataset = TextDataset(
        cfg.DATA_DIR, split_dir, base_size=cfg.TREE.BASE_SIZE, transform=image_transform
    )
    assert dataset
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        drop_last=True,
        shuffle=bshuffle,
        num_workers=int(cfg.WORKERS),
    )

    # Define the model and go to evaluation.
    algo = trainer(output_dir, dataloader, dataset.n_words, dataset.ixtoword)

    st.title("Text To Image Generator ")
    st.subheader("Enter the description of the bird in the text box you like !!!")
    st.write(
        "**Example**: A yellow bird with red crown, black short beak and long tail"
    )
    st.markdown("**PS**: The synthesized birds might not even exist on earth ")
    st.markdown("#")
    user_input = st.text_input("Write the bird description below")
    st.markdown("---")

    if user_input:
        start_t = time.time()
        # Generate images for the customized caption.
        gen_example(dataset.wordtoix, algo, text=user_input)
        end_t = time.time()
        # Original message said "training"; this block times generation.
        print("Total time for generation:", end_t - start_t)

        st.write(f"**Your input**: {user_input}")
        center_element(type="subheading", text="AttnGAN synthesized bird")
        st.text("")
        center_element(
            type="image", img_path="models/bird_AttnGAN2/output/0_s_0_g2.png"
        )
        center_element(type="subheading", text="The attention given for each word")
        st.image("models/bird_AttnGAN2/output/0_s_0_a1.png")
        st.markdown("---")
        with st.expander("Click to see the first stage images"):
            st.write("First stage image")
            st.image("models/bird_AttnGAN2/output/0_s_0_g1.png")
            st.write("First stage attention")
            st.image("models/bird_AttnGAN2/output/0_s_0_a0.png")
def attngan_explained():
    """Render the page that explains the AttnGAN architecture."""
    st.header(
        "**AttnGAN**: Fine-Grained Text To Image Generation with Attentional Generative Adverserial Networks"
    )
    # Imported lazily so the explanation module only loads for this page.
    from attngan_explanation import attngan_explanation

    attngan_explanation()