import streamlit as st
from transformers import VisionEncoderDecoderModel, GPT2Tokenizer
import torch
from PIL import Image
from torchvision import transforms

# Load model and tokenizer once per server process.
# Streamlit re-executes this script from the top on every user interaction,
# so without caching the checkpoint would be re-loaded on each rerun.
@st.cache_resource
def _load_captioner():
    """Load the captioning model and its GPT-2 tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple: the ``VisionEncoderDecoderModel``
        loaded from the fine-tuned checkpoint and the ``GPT2Tokenizer``
        loaded from the ``"gpt2"`` vocabulary.
    """
    captioner = VisionEncoderDecoderModel.from_pretrained(
        "ashok2216/vit-gpt2-image-captioning_COCO_FineTuned"
    )
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    return captioner, gpt2_tokenizer


model, tokenizer = _load_captioner()

# Manual preprocessing pipeline, applied to each uploaded image in order:
# resize to 224x224, convert to a tensor, then normalize every one of the
# three channels with mean 0.5 and std 0.5.
_resize_step = transforms.Resize((224, 224))
_tensor_step = transforms.ToTensor()
_normalize_step = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

preprocess = transforms.Compose([_resize_step, _tensor_step, _normalize_step])

# Streamlit app setup
st.title("Image Captioning with ViT-GPT2")
st.write("Upload an image to generate a caption.")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

if uploaded_file is not None:
    image = Image.open(uploaded_file)
    # Show the upload exactly as the user provided it.
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Uploaded files (PNG in particular) may be RGBA, palette or grayscale.
    # The normalization step is configured for exactly three channels, so
    # force RGB before preprocessing to avoid a channel-count mismatch.
    rgb_image = image.convert("RGB")

    # Preprocess the image manually and add the leading batch dimension.
    inputs = preprocess(rgb_image).unsqueeze(0)

    # Generate the caption
    with st.spinner("Generating caption..."):
        output = model.generate(inputs)
        # Decode the first (and only) sequence in the batch, dropping
        # special tokens from the text.
        caption = tokenizer.decode(output[0], skip_special_tokens=True)

    st.success("Generated Caption:")
    st.write(caption)