File size: 2,403 Bytes
9522bcd
 
 
 
57231ec
 
c749f9d
57231ec
 
 
08ec34d
57231ec
 
 
 
7e847dc
82b37de
9522bcd
cf6cf6f
9522bcd
cf6cf6f
420f5bf
57231ec
 
 
 
08ec34d
5f69d7b
7e847dc
 
57231ec
 
0c24f96
cf6cf6f
b654e77
 
 
57231ec
 
0959b2d
57231ec
 
 
 
 
 
 
 
08ec34d
 
57231ec
 
 
 
 
 
08ec34d
0c24f96
b26fc84
57231ec
 
 
 
 
 
27a4a72
 
 
 
57231ec
 
27a4a72
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from transformers import AutoTokenizer, AutoModel ,AutoConfig
import torch
from transformers import  ViTImageProcessor, VisionEncoderDecoderModel,RobertaTokenizerFast
import PIL 
import streamlit as st
from PIL import Image


def set_page_config():
    """Apply the page-level Streamlit settings (title, icon, wide layout)."""
    # NOTE(review): the title reads "an Cartoon" - probably meant "a Cartoon";
    # left untouched here because it is user-visible text.
    page_options = {
        'page_title': 'Caption an Cartoon Image',
        'page_icon': ':camera:',
        'layout': 'wide',
    }
    st.set_page_config(**page_options)

@st.cache_resource(show_spinner=False)
def initialize_model():
    """Load the captioning model and its pre/post-processing helpers.

    Streamlit re-executes the whole script on every interaction, so without
    caching each click would reload all pretrained artefacts. The
    ``st.cache_resource`` decorator keeps the loaded objects for later calls.
    ``show_spinner=False`` because the caller already wraps this call in its
    own spinner.

    Returns:
        A 4-tuple ``(image_processor, model, tokenizer, device)`` where
        ``device`` is the string ``'cpu'``.
    """
    device = 'cpu'
    repo_id = "sourabhbargi11/Caption_generator_model"

    config = AutoConfig.from_pretrained(repo_id)
    model = VisionEncoderDecoderModel.from_pretrained(repo_id, config=config)
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
    return image_processor, model, tokenizer, device

def upload_image():
    """Render the sidebar file uploader and return its current value.

    The widget is restricted to ``jpg``, ``jpeg`` and ``png`` files. The
    return value is whatever ``st.sidebar.file_uploader`` yields; ``main``
    treats ``None`` as "nothing uploaded yet".
    """
    prompt = "Upload an image (we aren't storing anything)"
    accepted_extensions = ["jpg", "jpeg", "png"]
    return st.sidebar.file_uploader(prompt, type=accepted_extensions)

def image_preprocess(image):
    """Return *image* as a 224x224 RGB image.

    Every non-RGB mode is converted, not just greyscale ("L"): the uploader
    accepts PNG files, which are frequently RGBA or palette ("P") images.
    The conversion happens before resizing so the resize always runs on
    RGB pixel data.

    Args:
        image: A PIL-style image exposing ``mode``, ``convert`` and ``resize``.

    Returns:
        A new image of size ``(224, 224)`` in mode ``"RGB"``.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image.resize((224, 224))

def generate_caption(image_processor, model, tokenizer,device, image):
    """Produce a text caption for *image*.

    The image is encoded with ``image_processor`` (requesting ``'pt'``
    tensors), the encoded inputs are passed to ``model.generate`` and the
    first generated sequence is decoded with ``tokenizer``, skipping special
    tokens.

    Args:
        image_processor: Callable turning an image into model inputs.
        model: Object exposing ``generate(**inputs)``.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
        device: Accepted for interface compatibility; not used in this function.
        image: The image to caption.

    Returns:
        The decoded caption.
    """
    encoded = image_processor(image, return_tensors='pt')
    sequences = model.generate(**encoded)

    first_sequence = sequences[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)


def main():
    """Run the captioning page.

    Flow: configure the page, show the header, wait for an upload, display
    the preprocessed image, and when the sidebar button is pressed load the
    model and render the generated caption in the sidebar.
    """
    set_page_config()
    st.header("Caption an Image :camera:")

    uploaded_image = upload_image()
    if uploaded_image is None:
        # Nothing uploaded yet; there is nothing further to render.
        return

    try:
        # Decoding can fail either when the file is opened or when its pixel
        # data is first read during preprocessing, so both calls are guarded.
        image = image_preprocess(Image.open(uploaded_image))
    except OSError:
        # Pillow reports unidentifiable or truncated files as OSError
        # subclasses; show a message instead of a raw traceback.
        st.error("Could not read the uploaded file as an image. Please try a different file.")
        return

    st.image(image, caption='Your image')

    with st.sidebar:
        st.divider()
        if not st.sidebar.button('Generate Caption'):
            return

        with st.spinner('Generating caption...'):
            image_processor, model, tokenizer, device = initialize_model()
            caption = generate_caption(image_processor, model, tokenizer, device, image)

            st.header("Caption:")
            st.markdown(f'**{caption}**')

if __name__ == '__main__':
    main()

# Footer note. This call sits outside the ``__main__`` guard, so it runs
# whenever the module's top level is executed, after ``main()`` returns.
st.markdown("""
    ---
   You are looking at a partially tuned model. Judge me! (I am Funny and Creative) 😄🎨""")