# app.py from the Hugging Face Space by sourabhbargi11
# (commit cf6cf6f, "Update app.py", verified; file size 2.56 kB).
from transformers import AutoTokenizer, AutoModel ,AutoConfig
import torch
from transformers import ViTImageProcessor, VisionEncoderDecoderModel,RobertaTokenizerFast
import PIL
import streamlit as st
from PIL import Image
import trasformer
def set_page_config():
    """Configure the Streamlit page: browser-tab title, icon and layout.

    ``main()`` calls this before issuing any other Streamlit command.
    """
    st.set_page_config(
        # Fixed the grammatical slip in the user-visible title
        # ("an Cartoon" -> "a Cartoon").
        page_title='Caption a Cartoon Image',
        page_icon=':camera:',
        layout='wide',
    )
# Streamlit re-executes the script on every interaction; caching keeps a
# single loaded copy of the model objects instead of reloading them each time
# the "Generate Caption" button is pressed. The caller already shows its own
# spinner, so the built-in cache spinner is disabled.
@st.cache_resource(show_spinner=False)
def initialize_model():
    """Load the caption model, its tokenizer and the image processor.

    Returns:
        tuple: ``(image_processor, model, tokenizer, device)`` where
        ``device`` is the string ``'cpu'``.
    """
    device = 'cpu'
    checkpoint = "sourabhbargi11/Caption_generator_model"

    config = AutoConfig.from_pretrained(checkpoint)
    model = VisionEncoderDecoderModel.from_pretrained(checkpoint, config=config)
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
    return image_processor, model, tokenizer, device
def upload_image():
    """Render the sidebar file-upload widget and return its current value.

    The widget only accepts ``jpg``, ``jpeg`` and ``png`` files. The value is
    whatever ``st.sidebar.file_uploader`` returns; ``main()`` treats ``None``
    as "nothing uploaded yet".
    """
    prompt = "Upload an image (we aren't storing anything)"
    accepted_extensions = ["jpg", "jpeg", "png"]
    return st.sidebar.file_uploader(prompt, type=accepted_extensions)
def image_preprocess(image, size=(224, 224)):
    """Return an RGB copy of ``image`` resized to ``size``.

    Args:
        image: A PIL-style image exposing ``mode``, ``convert()`` and
            ``resize()``.
        size: Target ``(width, height)``; defaults to ``(224, 224)``.

    Returns:
        The converted and resized image.
    """
    # Previously only greyscale ("L") images were converted, so RGBA or
    # palette ("P") PNGs reached the model with the wrong number of channels.
    # Converting before resizing also lets palette images be resampled as
    # true-colour pixels rather than palette indices.
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image.resize(size)
def generate_caption(image_processor, model, tokenizer, device, image):
    """Generate a text caption for ``image``.

    Args:
        image_processor: Callable that turns the image into model inputs; its
            result must provide a ``"pixel_values"`` entry.
        model: Model exposing ``eval()`` and ``generate(pixel_values=...)``.
        tokenizer: Tokenizer used to decode the generated token ids.
        device: Device the pixel tensor is moved to before generation. The
            model is not moved here, so it is expected to already live on
            this device.
        image: The (pre-processed) input image.

    Returns:
        The decoded caption with special tokens stripped.
    """
    inputs = image_processor(image, return_tensors='pt')
    # The processor returns a mapping of tensors; generate() needs the pixel
    # tensor itself, not the whole mapping.
    pixel_values = inputs["pixel_values"].to(device)

    model.eval()
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        output = model.generate(pixel_values=pixel_values)

    # Decode the first (and only) generated sequence.
    return tokenizer.decode(output[0], skip_special_tokens=True)
def main():
    """Run the app: upload an image, preview it and caption it on demand."""
    set_page_config()
    st.header("Caption an Image :camera:")

    uploaded_image = upload_image()
    if uploaded_image is None:
        # Nothing uploaded yet, so there is nothing to preview or caption.
        return

    try:
        image = image_preprocess(Image.open(uploaded_image))
    except OSError:
        # Pillow signals unreadable or corrupt image data with OSError
        # (UnidentifiedImageError is a subclass); report it instead of
        # crashing the app.
        st.error("Could not read the uploaded file as an image. Please try another file.")
        return

    st.image(image, caption='Your image')

    with st.sidebar:
        st.divider()

    # NOTE(review): the original nesting was lost with the file's indentation;
    # the button is kept in the sidebar and the caption is rendered in the
    # main area. Confirm against the deployed layout.
    if st.sidebar.button('Generate Caption'):
        with st.spinner('Generating caption...'):
            image_processor, model, tokenizer, device = initialize_model()
            caption = generate_caption(image_processor, model, tokenizer, device, image)
        st.header("Caption:")
        st.markdown(f'**{caption}**')
if __name__ == '__main__':
    main()

    # Footer rendered below the app content. The emoji are written as escapes
    # (U+1F604 smiling face, U+1F3A8 artist palette) so they survive encoding
    # mishaps like the mojibake the previous literal contained.
    footer = (
        "\n---\n"
        "You are looking at a partially tuned model. Judge me! "
        "(I am Funny and Creative) \U0001F604\U0001F3A8"
    )
    st.markdown(footer)