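# Streamlit app: caption an uploaded cartoon image with the fine-tuned
# image-captioning model hosted at sourabhbargi11/Caption_generator_model.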
import streamlit as st
from PIL import Image
# The checkpoint is a vision-encoder-decoder captioning model, so load it with
# VisionEncoderDecoderModel (which supports .generate) plus its tokenizer and
# ViT image processor.
from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel
def set_page_config():
    st.set_page_config(
        page_title='Caption a Cartoon Image',
        page_icon=':camera:',
        layout='wide',
    )
def initialize_model():
    device = 'cpu'
    # Load the fine-tuned image-captioning model with its tokenizer and image processor
    model = VisionEncoderDecoderModel.from_pretrained("sourabhbargi11/Caption_generator_model").to(device)
    tokenizer = AutoTokenizer.from_pretrained("sourabhbargi11/Caption_generator_model")
    image_processor = ViTImageProcessor.from_pretrained("sourabhbargi11/Caption_generator_model")
    return image_processor, model, tokenizer, device
def upload_image():
    return st.sidebar.file_uploader("Upload an image (we aren't storing anything)", type=["jpg", "jpeg", "png"])
def image_preprocess(image):
    # Resize to the ViT input size and make sure the image has three channels
    image = image.resize((224, 224))
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image
def generate_caption(processor, tokenizer, model, device, image):
    inputs = processor(image, return_tensors='pt').to(device)
    out = model.generate(**inputs, max_new_tokens=20)
    caption = tokenizer.decode(out[0], skip_special_tokens=True)
    return caption
def main():
    set_page_config()
    st.header("Caption an Image :camera:")
    uploaded_image = upload_image()
    if uploaded_image is not None:
        image = Image.open(uploaded_image)
        image = image_preprocess(image)
        st.image(image, caption='Your image')
        with st.sidebar:
            st.divider()
        if st.sidebar.button('Generate Caption'):
            with st.spinner('Generating caption...'):
                image_processor, model, tokenizer, device = initialize_model()
                caption = generate_caption(image_processor, tokenizer, model, device, image)
                st.header("Caption:")
                st.markdown(f'**{caption}**')
if __name__ == '__main__':
    main()
    st.markdown("""
    ---
    You are looking at a partially fine-tuned model, so please judge it accordingly!""")