File size: 3,307 Bytes
5a56e41
 
9522bcd
 
 
 
 
 
 
 
 
 
 
0610d82
57231ec
 
5a56e41
 
 
 
 
 
 
 
 
7e847dc
c749f9d
57231ec
 
 
08ec34d
57231ec
 
 
 
7e847dc
82b37de
9522bcd
 
 
 
420f5bf
57231ec
 
 
 
08ec34d
5f69d7b
7e847dc
 
57231ec
 
 
1af1fd7
0959b2d
 
 
 
 
 
 
 
 
 
 
 
57231ec
 
0959b2d
57231ec
 
 
 
 
 
 
 
08ec34d
 
57231ec
 
 
 
 
 
08ec34d
 
 
b26fc84
57231ec
 
 
 
 
 
 
 
 
 
08ec34d
57231ec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# !pip install torch
# import torch
from transformers import AutoTokenizer, AutoModel ,AutoConfig
import torch
from transformers import  ViTImageProcessor, VisionEncoderDecoderModel,RobertaTokenizerFast
import PIL 


# Select CUDA when available, otherwise fall back to CPU.
# NOTE(review): `device2` appears unused below — initialize_model()
# hard-codes device = 'cpu' instead; confirm whether this line can go.
device2 = torch.device("cuda" if torch.cuda.is_available() else "cpu")



import streamlit as st
from PIL import Image
# from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel,RobertaTokenizerFast, VisionEncoderDecoderModel
#from transformers import BlipProcessor, BlipForConditionalGeneration


# Load model directly
from transformers import AutoTokenizer, AutoModel

# tokenizer = AutoTokenizer.from_pretrained("sourabhbargi11/Caption_generator_model")
# model = AutoModel.from_pretrained("sourabhbargi11/Caption_generator_model")



def set_page_config():
    """Configure the Streamlit page chrome: title, icon, and wide layout."""
    page_options = {
        'page_title': 'Caption an Cartoon Image',
        'page_icon': ':camera:',
        'layout': 'wide',
    }
    st.set_page_config(**page_options)

def initialize_model():
    """Load the caption-generation model, tokenizer, and image processor.

    Returns:
        tuple: (image_processor, model, tokenizer, device) — everything
        needed for one inference pass, with the model already on `device`.
    """
    device = 'cpu'  # inference on CPU; typical Streamlit hosts have no GPU
    config = AutoConfig.from_pretrained("sourabhbargi11/Caption_generator_model")
    model = VisionEncoderDecoderModel.from_pretrained(
        "sourabhbargi11/Caption_generator_model", config=config
    ).to(device)
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    # Bug fix: image processors are device-agnostic — `device` is not a
    # valid from_pretrained argument, so it is no longer passed here.
    image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
    return image_processor, model, tokenizer, device

def upload_image():
    """Show a sidebar file uploader; return the uploaded file or None."""
    accepted_types = ["jpg", "jpeg", "png"]
    prompt = "Upload an image (we aren't storing anything)"
    return st.sidebar.file_uploader(prompt, type=accepted_types)

def image_preprocess(image):
    """Resize an image to the ViT input size and normalize its color mode.

    Args:
        image: A PIL.Image (any mode, any size).

    Returns:
        A 224x224 RGB PIL.Image ready for the ViT image processor.
    """
    image = image.resize((224, 224))  # ViT-base expects 224x224 input
    # Bug fix / generalization: the original only converted grayscale ("L"),
    # so RGBA, palette ("P"), CMYK, etc. uploads reached the model unconverted.
    # Convert every non-RGB mode.
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image

def generate_caption(processor, model, device, image, tokenizer=None):
    """Generate a text caption for a preprocessed image.

    Args:
        processor: Image processor that turns `image` into pixel_values.
        model: A VisionEncoderDecoderModel used for generation.
        device: Device the model lives on (e.g. 'cpu').
        image: 224x224 RGB PIL image (see image_preprocess).
        tokenizer: Optional tokenizer for decoding the generated ids.
            Loaded from "roberta-base" when omitted, which keeps the
            original 4-argument call signature working.

    Returns:
        str: The decoded caption.
    """
    # Bug fix: the original body referenced the globals `image_processor`
    # and `tokenizer`, which are never defined at module level (the loads
    # at the top of the file are commented out) -> NameError at runtime.
    if tokenizer is None:
        tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    inputs = processor(image, return_tensors='pt').to(device)
    model.eval()
    # Generate caption
    with torch.no_grad():
        # Bug fix: pass the pixel_values tensor, not the whole BatchFeature.
        output = model.generate(
            pixel_values=inputs.pixel_values,
            max_length=1000,      # generous upper bound on caption length
            num_beams=4,          # beam search width
            early_stopping=True,  # stop once all beams are finished
        )
        # Decode the generated token ids back to text.
        caption = tokenizer.decode(output[0], skip_special_tokens=True)
    return caption


def main():
    """Streamlit entry point: upload an image, preview it, caption it."""
    set_page_config()
    st.header("Caption an Image :camera:")

    uploaded_image = upload_image()
    if uploaded_image is None:
        return  # nothing to do until the user uploads a file

    image = image_preprocess(Image.open(uploaded_image))
    st.image(image, caption='Your image')

    with st.sidebar:
        st.divider()
        if st.sidebar.button('Generate Caption'):
            with st.spinner('Generating caption...'):
                image_processor, model, tokenizer, device = initialize_model()
                caption = generate_caption(image_processor, model, device, image)
                st.header("Caption:")
                st.markdown(f'**{caption}**')


# Run the app when executed as a script (e.g. `streamlit run <file>.py`).
if __name__ == '__main__':
    main()


# Footer note; Streamlit re-renders this on every script rerun, after main().
st.markdown("""
    ---
   You are looking at partial finetuned model , please JUDGE ME!!!  """)