File size: 1,940 Bytes
32f9449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import gradio as gr
from PIL import Image
import torch
from transformers import AutoProcessor

# Model and processor initialization.
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
# The checkpoint holds a whole pickled model object (it is used directly via
# .to/.eval/.generate below), not a state_dict. Therefore:
#  * map_location=DEVICE lets a checkpoint saved from a GPU process load on a
#    CPU-only machine instead of failing during deserialization;
#  * weights_only=False is required because newer PyTorch releases default to
#    weights_only=True, which refuses to unpickle arbitrary objects such as a
#    full nn.Module (the keyword itself needs torch >= 1.13).
# Unpickling can execute arbitrary code, so only load checkpoints you trust.
# NOTE(review): the relative path is resolved against the current working
# directory, not this script's location — confirm the launch directory.
model = torch.load(
    "../finetunned_blipv2_epoch_5_loss_0.4936.pth",
    map_location=DEVICE,
    weights_only=False,
).to(DEVICE)
model.eval()  # inference only: put dropout/normalization layers in eval mode


def caption_image(image: "Image.Image | None") -> str:
    """
    Generate a caption for ``image`` with the fine-tuned BLIP-2 model.

    Args:
        image: The uploaded PIL image, or ``None`` when the Gradio input is
            empty (e.g. the user cleared the image while ``live=True``).

    Returns:
        The decoded caption with special tokens and surrounding whitespace
        removed, or an empty string when no image was provided.
    """
    # With live=True Gradio re-runs this function on every input change,
    # including clearing the image, in which case it passes None.
    if image is None:
        return ""

    # Normalize to 3-channel RGB (uploads may be RGBA, grayscale, palette, ...).
    rgb_image = image.convert("RGB")
    inputs = processor(images=rgb_image, return_tensors="pt").to(DEVICE)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        generated_ids = model.generate(
            pixel_values=inputs.pixel_values,
            max_length=256,
        )

    # batch_decode returns one string per generated sequence; a single image
    # was submitted, so take the first one.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()


# Gradio interface: one PIL image in, one caption string out.
interface = gr.Interface(
    fn=caption_image,
    inputs=gr.Image(type="pil"),  # deliver a PIL image, which caption_image expects
    outputs=gr.Textbox(),
    title="Image Captioning with BLIP-2 and LoRA",
    description=("<div style='text-align: center; padding: 10px; border: 2px solid #FFC107; border-radius: 10px;'>"
                 "<p>Welcome to our <strong>state-of-the-art</strong> image captioning tool!</p>"
                 "<p>We combine the strengths of the <em>BLIP-2</em> model with <em>LoRA</em> to provide precise image captions.</p>"
                 "<p>Our rich dataset has been labeled using multi-modal models. Upload an image to see its caption!</p></div>"),
    article=("<div style='text-align: center; padding: 10px; background-color: #E3F2FD; border-radius: 10px;'>"
             "<a href='https://diegobonilla98.github.io/PixLore/' style='color: #1976D2; font-weight: bold;'>GitHub Project</a></div>"),
    # Re-run caption_image whenever the input changes rather than waiting for
    # an explicit submit; caption_image handles the cleared (None) input.
    live=True,
    # `layout` was removed: it was deprecated with no effect in Gradio 3.x and
    # is rejected by newer Gradio releases.
)


if __name__ == '__main__':
    interface.launch()