File size: 3,804 Bytes
977f5f8
 
 
 
 
 
 
 
 
e655ddc
f1e762d
8c3254f
f1e762d
977f5f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c17fdc
 
e655ddc
0c17fdc
977f5f8
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import base64
import io
import os
from typing import List, Literal, Optional, Union

import torch
from fastapi import Body, FastAPI, HTTPException
from PIL import Image as PILImage
from pydantic import BaseModel
from starlette.responses import FileResponse
from transformers import AutoModelForVision2Seq, AutoProcessor

# FastAPI application; docs endpoints are exposed explicitly at their defaults.
app = FastAPI(docs_url="/docs", redoc_url="/redoc", openapi_url="/openapi.json")

# Initialize model and processor
MODEL_NAME = "bytedance-research/UI-TARS-7B-DPO"
# Prefer GPU when available; may be downgraded to "cpu" in the fallback below.
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    # First attempt: half precision on the selected device to minimize VRAM use.
    model = AutoModelForVision2Seq.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(device)  # Use float16 with low CPU memory usage
except RuntimeError as e:
    # torch CUDA OOM surfaces as a RuntimeError whose message contains this phrase.
    if "CUDA out of memory" in str(e):
        print("Warning: Loading model in float16 failed due to insufficient memory. Falling back to CPU and float32.")
        device = "cpu"  # Switch to CPU
        # No torch_dtype argument -> default float32 weights on CPU.
        model = AutoModelForVision2Seq.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True).to(device)  # Load in float32 on CPU with low CPU mem usage
        import gc
        gc.collect()
        # Release any partially-allocated CUDA cache left by the failed attempt.
        torch.cuda.empty_cache()
    else:
        # Not an OOM: re-raise so startup fails loudly on genuine errors.
        raise e

# Tokenizer/image-processor pair matching the model checkpoint.
processor = AutoProcessor.from_pretrained(MODEL_NAME)

# Pydantic models
class ImageUrl(BaseModel):
    """Wrapper for an image reference; `url` is expected to be a base64 data URL."""
    url: str

class Image(BaseModel):
    """Image-only content part (OpenAI style). NOTE(review): unused in the visible
    code — requests are parsed via `Content` instead; candidate for removal."""
    type: Literal["image_url"] = "image_url"
    image_url: ImageUrl

class Content(BaseModel):
    """One content part of a multimodal message: either text or an image URL."""
    type: Literal["text", "image_url"]
    # Exactly one of the two fields below is expected, matching `type`.
    text: Optional[str] = None
    image_url: Optional[ImageUrl] = None

class Message(BaseModel):
    """A chat message: plain text, or a list of multimodal content parts."""
    role: Literal["user", "system", "assistant"]
    content: Union[str, List[Content]]

class ChatCompletionRequest(BaseModel):
    """Request body for POST /chat/completions (OpenAI-compatible subset)."""
    messages: List[Message]
    # Generation budget, forwarded as `max_new_tokens`.
    max_tokens: Optional[int] = 128

@app.post("/chat/completions")
async def chat_completion(request: ChatCompletionRequest = Body(...)):
    """OpenAI-style chat completion backed by the UI-TARS vision model.

    Only the FIRST message is consumed: its text becomes the prompt and an
    optional base64 data-URL image becomes the visual input.

    Raises:
        HTTPException 400: empty `messages`, missing text content, or an
            image that is not a valid base64 data URL.
        HTTPException 500: model inference failure.
    """
    messages = request.messages
    if not messages:
        # Previously an unhandled IndexError -> opaque HTTP 500.
        raise HTTPException(status_code=400, detail="`messages` must contain at least one message.")
    # An explicit `"max_tokens": null` would otherwise reach generate() as None.
    max_tokens = request.max_tokens or 128

    first_message = messages[0]
    image_url = None
    text_content = None

    if isinstance(first_message.content, str):
        text_content = first_message.content
    else:
        for content_item in first_message.content:
            # Guard image_url access: `type` and the payload field are separate
            # in the schema and may disagree in a malformed request.
            if content_item.type == "image_url" and content_item.image_url is not None:
                image_url = content_item.image_url.url
            elif content_item.type == "text":
                text_content = content_item.text

    if text_content is None:
        raise HTTPException(status_code=400, detail="First message contains no text content.")

    # Decode the image if one was provided.
    pil_image = None
    if image_url:
        if not image_url.startswith("data:image"):
            # Was silently ignored before, so the model answered without the
            # image the client sent; fail loudly instead.
            raise HTTPException(status_code=400, detail="Only base64 data-URL images are supported.")
        try:
            _header, encoded = image_url.split(",", 1)
            image_data = base64.b64decode(encoded)
            pil_image = PILImage.open(io.BytesIO(image_data)).convert("RGB")
        except Exception as e:
            print(f"Error processing image: {e}")
            raise HTTPException(status_code=400, detail=f"Invalid image data: {e}") from e

    # Generate response.
    try:
        inputs = processor(text=text_content, images=pil_image, return_tensors="pt").to(device)
        # inference_mode: no autograd bookkeeping during generation.
        with torch.inference_mode():
            outputs = model.generate(**inputs, max_new_tokens=max_tokens)
        # NOTE(review): batch_decode may echo the prompt tokens for decoder-only
        # checkpoints — confirm whether the prompt should be stripped here.
        response = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    except Exception as e:
        print(f"Error during model inference: {e}")
        raise HTTPException(status_code=500, detail="Model inference failed.") from e

    return {
        "choices": [{
            "message": {
                "role": "assistant",
                "content": response
            }
        }]
    }

@app.get("/")
def index():
    """Serve the static landing page of the demo UI."""
    landing_page = "static/index.html"
    return FileResponse(landing_page)

@app.on_event("startup")
def startup_event():
    """Log the assumed public URL once the server is up.

    Hugging Face Spaces serve apps at https://<space_name>.hf.space; the
    space name 'api-UI-TARS-7B-DPO' is hard-coded here.
    """
    public_url = "https://api-UI-TARS-7B-DPO.hf.space"
    print(f"Public URL: {public_url}")