sitammeur committed
Commit 4bd7e06 · verified · 1 Parent(s): 61c2cca

Update src/paligemma/response.py

Files changed (1):
  1. src/paligemma/response.py +72 -71
src/paligemma/response.py CHANGED
@@ -1,71 +1,72 @@
- # Necessary imports
- import sys
- import PIL.Image
- import torch
- import gradio as gr
- import spaces
-
- # Local imports
- from src.config import device, model_name
- from src.paligemma.model import load_model_and_processor
- from src.logger import logging
- from src.exception import CustomExceptionHandling
-
-
- # Language dictionary
- language_dict = {
-     "English": "en",
-     "Spanish": "es",
-     "French": "fr",
- }
-
- # Model and processor
- model, processor = load_model_and_processor(model_name, device)
-
-
- @spaces.GPU
- def caption_image(image: PIL.Image.Image, max_new_tokens: int, language: str) -> str:
-     """
-     Generates a caption based on the given image using the model.
-
-     Args:
-         - image (PIL.Image.Image): The input image to be processed.
-         - max_new_tokens (int): The maximum number of new tokens to generate.
-         - language (str): The language of the generated caption.
-
-     Returns:
-         str: The generated caption text.
-     """
-     try:
-         # Check if image is None
-         if not image:
-             gr.Warning("Please provide an image.")
-
-         # Prepare the inputs
-         language = language_dict[language]
-         prompt = f"<image>caption {language}"
-         model_inputs = (
-             processor(text=prompt, images=image, return_tensors="pt")
-             .to(torch.bfloat16)
-             .to(device)
-         )
-         input_len = model_inputs["input_ids"].shape[-1]
-
-         # Generate the response
-         with torch.inference_mode():
-             generation = model.generate(
-                 **model_inputs, max_new_tokens=max_new_tokens, do_sample=False
-             )
-             generation = generation[0][input_len:]
-             decoded = processor.decode(generation, skip_special_tokens=True)
-
-         # Log the successful generation of the caption
-         logging.info("Caption generated successfully.")
-
-         # Return the generated caption
-         return decoded
-
-     # Handle exceptions that may occur during caption generation
-     except Exception as e:
-         # Custom exception handling
-         raise CustomExceptionHandling(e, sys) from e
 
 
+ # Necessary imports
+ import sys
+ import PIL.Image
+ import torch
+ import gradio as gr
+ import spaces
+
+ # Local imports
+ from src.config import device, model_name
+ from src.paligemma.model import load_model_and_processor
+ from src.logger import logging
+ from src.exception import CustomExceptionHandling
+
+
+ # Language dictionary
+ language_dict = {
+     "English": "en",
+     "Spanish": "es",
+     "French": "fr",
+ }
+
+ # Model and processor
+ model, processor = load_model_and_processor(model_name, device)
+
+
+ @spaces.GPU
+ def caption_image(image: PIL.Image.Image, max_new_tokens: int, language: str) -> str:
+     """
+     Generates a caption based on the given image using the model.
+
+     Args:
+         - image (PIL.Image.Image): The input image to be processed.
+         - max_new_tokens (int): The maximum number of new tokens to generate.
+         - language (str): The language of the generated caption.
+
+     Returns:
+         str: The generated caption text.
+     """
+     try:
+         # Check if image is None
+         if not image:
+             gr.Warning("Please provide an image.")
+
+         # Prepare the inputs
+         language = language_dict[language]
+         print(language)
+         prompt = f"<image>caption {language}"
+         model_inputs = (
+             processor(text=prompt, images=image, return_tensors="pt")
+             .to(torch.bfloat16)
+             .to(device)
+         )
+         input_len = model_inputs["input_ids"].shape[-1]
+
+         # Generate the response
+         with torch.inference_mode():
+             generation = model.generate(
+                 **model_inputs, max_new_tokens=max_new_tokens, do_sample=False
+             )
+             generation = generation[0][input_len:]
+             decoded = processor.decode(generation, skip_special_tokens=True)
+
+         # Log the successful generation of the caption
+         logging.info("Caption generated successfully.")
+
+         # Return the generated caption
+         return decoded
+
+     # Handle exceptions that may occur during caption generation
+     except Exception as e:
+         # Custom exception handling
+         raise CustomExceptionHandling(e, sys) from e
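
For context, a minimal sketch of how the updated caption_image could be wired into a Gradio app, assuming an Image input, a slider for max_new_tokens, and a dropdown over the keys of language_dict. The import path, component labels, and default values below are illustrative assumptions, not part of this commit:

# Hypothetical wiring sketch (not part of this commit): expose caption_image via a Gradio Interface.
import gradio as gr
from src.paligemma.response import caption_image

demo = gr.Interface(
    fn=caption_image,
    inputs=[
        gr.Image(type="pil", label="Image"),  # delivered to caption_image as PIL.Image.Image
        gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Max new tokens"),
        gr.Dropdown(choices=["English", "Spanish", "French"], value="English", label="Language"),
    ],
    outputs=gr.Textbox(label="Caption"),
)

if __name__ == "__main__":
    demo.launch()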