disohugface commited on
Commit
4bf7c22
·
verified ·
1 Parent(s): abec680

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -2
app.py CHANGED
@@ -2,21 +2,46 @@ import gradio as gr
2
  import torch
3
  from PIL import Image
4
  from transformers import ColPaliForRetrieval, ColPaliProcessor
 
5
 
6
  model_name = "vidore/colpali-v1.3-hf"
7
  model = ColPaliForRetrieval.from_pretrained(model_name, torch_dtype=torch.float32).eval()
8
  processor = ColPaliProcessor.from_pretrained(model_name)
9
 
10
  def process_image(image):
 
 
 
 
11
  inputs = processor(images=image, return_tensors="pt")
 
 
12
  with torch.no_grad():
13
  outputs = model(**inputs)
14
- return outputs.embeddings.squeeze().tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
 
16
  demo = gr.Interface(
17
  fn=process_image,
18
  inputs=gr.Image(type="pil"),
19
- outputs="json",
 
 
20
  )
21
 
 
22
  demo.launch()
 
2
  import torch
3
  from PIL import Image
4
  from transformers import ColPaliForRetrieval, ColPaliProcessor
5
+ import numpy as np
6
 
7
  model_name = "vidore/colpali-v1.3-hf"
8
  model = ColPaliForRetrieval.from_pretrained(model_name, torch_dtype=torch.float32).eval()
9
  processor = ColPaliProcessor.from_pretrained(model_name)
10
 
11
  def process_image(image):
12
+ # Ensure the image is in RGB format
13
+ image = image.convert('RGB')
14
+
15
+ # Process the image
16
  inputs = processor(images=image, return_tensors="pt")
17
+
18
+ # Generate embeddings
19
  with torch.no_grad():
20
  outputs = model(**inputs)
21
+
22
+ # Extract embeddings and convert to list
23
+ embeddings = outputs.embeddings.squeeze().cpu().numpy().tolist()
24
+
25
+ # Truncate the embeddings for display purposes
26
+ truncated_embeddings = embeddings[:10] # Show only first 10 values
27
+
28
+ # Prepare the output
29
+ output = {
30
+ "embedding_sample": truncated_embeddings,
31
+ "embedding_length": len(embeddings),
32
+ "embedding_shape": list(np.array(embeddings).shape)
33
+ }
34
+
35
+ return output
36
 
37
+ # Create Gradio interface
38
  demo = gr.Interface(
39
  fn=process_image,
40
  inputs=gr.Image(type="pil"),
41
+ outputs=gr.JSON(),
42
+ title="ColPali Image Embedding Generator",
43
+ description="Upload an image to generate its embedding using the ColPali model."
44
  )
45
 
46
+ # Launch the interface
47
  demo.launch()