Boni98 committed on
Commit
32f9449
·
1 Parent(s): 560ffab

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ from transformers import AutoProcessor
5
+
6
# Model and processor initialization
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")

# The checkpoint holds a whole pickled model object (it is used directly via
# `.to()` / `.eval()` below), so loading it executes pickle code: only load
# checkpoint files that come from a trusted source.
# `map_location` remaps the stored tensors onto DEVICE while deserializing;
# without it, a checkpoint saved from a GPU cannot be loaded on a CPU-only
# host, because the failure happens before `.to(DEVICE)` is reached.
# NOTE(review): recent PyTorch releases default `torch.load` to
# `weights_only=True`, which rejects full-model pickles — confirm against the
# pinned torch version whether `weights_only=False` must be passed here.
model = torch.load(
    "../finetunned_blipv2_epoch_5_loss_0.4936.pth",
    map_location=DEVICE,
).to(DEVICE)
model.eval()  # inference only
11
+
12
+
13
def caption_image(image: "Image.Image") -> str:
    """
    Generate a caption for an image with the fine-tuned model.

    Args:
        image: Picture to describe, as a PIL image. ``None`` is tolerated as
            well: the interface is live, so the callback can be invoked with
            no image (for example after the input has been cleared).

    Returns:
        The decoded caption with surrounding whitespace removed, or an empty
        string when no image was supplied.
    """
    # Guard first: without it a missing image raises AttributeError on
    # `.convert` and surfaces as an error in the UI.
    if image is None:
        return ""

    # Normalise the colour mode before handing the image to the processor.
    rgb_image = image.convert("RGB")
    inputs = processor(images=rgb_image, return_tensors="pt").to(DEVICE)

    # Inference only: gradients are not needed for generation.
    with torch.no_grad():
        generated_ids = model.generate(
            pixel_values=inputs.pixel_values,
            max_length=256,
        )

    captions = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # A single image was submitted, so the first entry is its caption.
    return captions[0].strip()
29
+
30
+
31
# Gradio interface
interface = gr.Interface(
    fn=caption_image,  # function to call
    inputs=gr.Image(type="pil"),  # Image input, delivered to `fn` as a PIL image
    outputs=gr.Textbox(),  # Textbox output
    title="Image Captioning with BLIP-2 and LoRa",
    description=(
        "<div style='text-align: center; padding: 10px; border: 2px solid #FFC107; border-radius: 10px;'>"
        "<p>Welcome to our <strong>state-of-the-art</strong> image captioning tool!</p>"
        "<p>We combine the strengths of the <em>BLIP-2</em> model with <em>LoRa</em> to provide precise image captions.</p>"
        "<p>Our rich dataset has been labeled using multi-modal models. Upload an image to see its caption!</p></div>"
    ),
    article=(
        "<div style='text-align: center; padding: 10px; background-color: #E3F2FD; border-radius: 10px;'>"
        "<a href='https://diegobonilla98.github.io/PixLore/' style='color: #1976D2; font-weight: bold;'>GitHub Project</a></div>"
    ),
    # Re-run the captioner whenever the input changes; no submit click needed.
    live=True,
    # NOTE: the legacy `layout` keyword is deliberately not passed. It belongs
    # to the pre-3.0 Interface API; with the component API used here
    # (gr.Image / gr.Textbox) it is deprecated and has no effect on rendering,
    # and newer Gradio releases appear to reject it — verify against the
    # pinned Gradio version.
)


# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()