# Radixpert / handler.py
import base64
import io
import json
import torch
from unsloth import FastVisionModel
from PIL import Image
# Global variables to hold the model and tokenizer.
model = None
tokenizer = None


def initialize():
    """
    Called once when the model is loaded.

    Loads the model and tokenizer from the pretrained checkpoint
    and prepares the model for inference.
    """
    global model, tokenizer
    model, tokenizer = FastVisionModel.from_pretrained(
        "abdurafeyf/Radixpert",
        device_map="cuda",
    )
    FastVisionModel.for_inference(model)


def inference(payload):
    """
    Expects a payload that is either a dict or a JSON string with the following format:

        {
            "data": {
                "image": "<base64-encoded image string>",
                "instruction": "<text instruction>"
            }
        }

    The function decodes the image, applies the chat template to the instruction,
    tokenizes both image and text, runs the model's generate method, and returns
    the generated text as output.
    """
    global model, tokenizer
    try:
        # If the payload is a JSON string, decode it into a dict.
        if isinstance(payload, str):
            payload = json.loads(payload)

        data = payload.get("data")
        if data is None:
            return {"error": "Missing 'data' in payload."}

        image_b64 = data.get("image")
        instruction = data.get("instruction")
        if image_b64 is None or instruction is None:
            return {"error": "Both 'image' and 'instruction' are required in the payload."}

        # Decode the base64-encoded image and load it as an RGB image.
        image_bytes = base64.b64decode(image_b64)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Construct the chat messages as expected by the tokenizer.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": instruction},
                ],
            }
        ]
        input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

        # Tokenize both image and text inputs.
        inputs = tokenizer(
            image,
            input_text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to("cuda")

        # Generate output tokens.
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            use_cache=True,
            temperature=1.5,
            min_p=0.1,
        )

        # Decode the tokens to obtain the generated text.
        output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"output": output_text}
    except Exception as e:
        return {"error": str(e)}


# Optional: for local testing of the handler.
if __name__ == "__main__":
    # Run initialization.
    initialize()

    # Example payload (replace with an actual base64-encoded image string).
    sample_payload = {
        "data": {
            "image": "",  # Insert a valid base64-encoded image string here.
            "instruction": (
                "You are an expert radiologist. Describe accurately and in detail, like a radiology report, "
                "what you see in this chest X-ray scan."
            ),
        }
    }
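
    # Example using the hypothetical encode_image_file() helper defined above to fill
    # in the image field from a local file (path shown for illustration only):
    # sample_payload["data"]["image"] = encode_image_file("chest_xray.png")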

    result = inference(sample_payload)
    print(result)
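
    # The handler also accepts the same payload as a JSON string, e.g.:
    # result = inference(json.dumps(sample_payload))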