| """ | |
| File: vlm.py | |
| Description: Vision language model utility functions. | |
| Author: Didier Guillevic | |
| Date: 2025-05-08 | |
| """ | |
import base64
from transformers import AutoProcessor
from transformers import Mistral3ForConditionalGeneration
from transformers import TextIteratorStreamer
from threading import Thread
import torch
import spaces

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
#
# Load the model: OPEA/Mistral-Small-3.1-24B-Instruct-2503-int4-AutoRound-awq-sym
#
model_id = "OPEA/Mistral-Small-3.1-24B-Instruct-2503-int4-AutoRound-awq-sym"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

processor = AutoProcessor.from_pretrained(model_id)
model = Mistral3ForConditionalGeneration.from_pretrained(
    model_id,
    #_attn_implementation="flash_attention_2",
    torch_dtype=torch.float16
).eval().to(device)
#
# Encode images as base64
#
def encode_image(image_path):
    """Encode the image at the given path to a base64 string."""
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    except FileNotFoundError:
        logger.error(f"File not found: {image_path}")
        return None
    except Exception as e:
        logger.error(f"Error encoding {image_path}: {e}")
        return None
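# Example usage (the file name "photo.jpg" is hypothetical): the returned
# string is meant to be embedded in a data URI, as build_messages() does below.
#   encoded = encode_image("photo.jpg")
#   uri = f"data:image/jpeg;base64,{encoded}" if encoded else None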
#
# Build messages
#
def build_messages(message: dict, history: list[dict]):
    """Build messages given message & history from a **multimodal** chat interface.

    Args:
        message: dictionary with keys: 'text', 'files'
        history: list of message dictionaries from the chat interface

    Returns:
        list of messages (to be sent to the model)
    """
    logger.info(f"{message=}")
    logger.info(f"{history=}")

    # Get the user's text and list of images
    user_text = message.get("text", "")
    user_images = message.get("files", [])  # list of image file paths

    # Build the user message's content from the provided message
    user_content = []
    if user_text:
        user_content.append({"type": "text", "text": user_text})
    for image in user_images:
        encoded = encode_image(image)
        if encoded is None:
            continue  # skip images that could not be read
        user_content.append({
            "type": "image_url",
            "image_url": f"data:image/jpeg;base64,{encoded}",
        })

    # Copy the history (rather than mutating the caller's list) and append
    # the new user turn
    messages = list(history)
    messages.append({'role': 'user', 'content': user_content})
    logger.info(f"{messages=}")
    return messages
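# A minimal sketch of the shapes involved ("cat.jpg" is hypothetical):
#   build_messages({"text": "Describe this image.", "files": ["cat.jpg"]}, [])
# returns
#   [{'role': 'user',
#     'content': [{'type': 'text', 'text': 'Describe this image.'},
#                 {'type': 'image_url',
#                  'image_url': 'data:image/jpeg;base64,...'}]}]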
#
# Stream response
#
@spaces.GPU  # request a GPU for the duration of the call on ZeroGPU Spaces
def stream_response(
        messages: list[dict],
        max_new_tokens: int = 1_024,
        temperature: float = 0.15
):
    """Stream the model's response to the chat interface.

    Args:
        messages: list of messages to send to the model
        max_new_tokens: maximum number of tokens to generate
        temperature: sampling temperature

    Yields:
        the partial response accumulated so far
    """
    # Tokenize the conversation into model inputs
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.float16)

    # Run generation in a background thread; the streamer yields decoded
    # text fragments as they are produced
    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=0.9,
        do_sample=True
    )
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    # Accumulate and yield the partial message as new text arrives
    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield partial_message
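#
# A minimal sketch of how these utilities might be wired into a Gradio
# multimodal chat app. The wrapper below is an assumption (it is not part
# of this module) and presumes the gr.ChatInterface API with
# multimodal=True and type="messages", where `message` arrives as a dict
# with 'text' and 'files' keys.
#
if __name__ == "__main__":
    import gradio as gr

    def respond(message: dict, history: list[dict]):
        # Convert the Gradio message/history into model messages, then
        # stream back the partial responses
        messages = build_messages(message, history)
        yield from stream_response(messages)

    demo = gr.ChatInterface(respond, type="messages", multimodal=True)
    demo.launch()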