omar0scarf committed on
Commit 977f5f8 · 1 Parent(s): f1e762d

Add application file

Files changed (1)
  1. app.py +106 -4
app.py CHANGED
@@ -1,7 +1,109 @@
-from fastapi import FastAPI
+from typing import List, Optional, Union, Literal
+from fastapi import FastAPI, Body
+from pydantic import BaseModel
+from transformers import AutoProcessor, AutoModelForVision2Seq
+from PIL import Image as PILImage
+import torch
+import base64
+import io
+import os
 
 app = FastAPI()
 
-@app.get("/")
-def greet_json():
-    return {"Hello": "World!"}
+# Initialize model and processor
+MODEL_NAME = "bytedance-research/UI-TARS-7B-DPO"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+try:
+    model = AutoModelForVision2Seq.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(device)  # Use float16 with low CPU memory usage
+except RuntimeError as e:
+    if "CUDA out of memory" in str(e):
+        print("Warning: Loading model in float16 failed due to insufficient memory. Falling back to CPU and float32.")
+        device = "cpu"  # Switch to CPU
+        model = AutoModelForVision2Seq.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True).to(device)  # Load in float32 on CPU with low CPU mem usage
+        import gc
+        gc.collect()
+        torch.cuda.empty_cache()
+    else:
+        raise e
+
+processor = AutoProcessor.from_pretrained(MODEL_NAME)
+
+# Pydantic models
+class ImageUrl(BaseModel):
+    url: str
+
+class Image(BaseModel):
+    type: Literal["image_url"] = "image_url"
+    image_url: ImageUrl
+
+class Content(BaseModel):
+    type: Literal["text", "image_url"]
+    text: Optional[str] = None
+    image_url: Optional[ImageUrl] = None
+
+class Message(BaseModel):
+    role: Literal["user", "system", "assistant"]
+    content: Union[str, List[Content]]
+
+class ChatCompletionRequest(BaseModel):
+    messages: List[Message]
+    max_tokens: Optional[int] = 128
+
+@app.post("/chat/completions")
+async def chat_completion(request: ChatCompletionRequest = Body(...)):
+    # Extract first message content
+    messages = request.messages
+    max_tokens = request.max_tokens
+
+    first_message = messages[0]
+    image_url = None
+    text_content = None
+
+    if isinstance(first_message.content, str):
+        text_content = first_message.content
+    else:
+        for content_item in first_message.content:
+            if content_item.type == "image_url":
+                image_url = content_item.image_url.url
+            elif content_item.type == "text":
+                text_content = content_item.text
+
+    # Process image if provided
+    pil_image = None
+    if image_url:
+        try:
+            if image_url.startswith("data:image"):
+                header, encoded = image_url.split(",", 1)
+                image_data = base64.b64decode(encoded)
+                pil_image = PILImage.open(io.BytesIO(image_data)).convert("RGB")
+            else:
+                print("Image URL provided, but base64 expected.")
+        except Exception as e:
+            print(f"Error processing image: {e}")
+            raise e
+
+    # Generate response
+    try:
+        inputs = processor(text=text_content, images=pil_image, return_tensors="pt").to(device)
+        outputs = model.generate(**inputs, max_new_tokens=max_tokens)
+        response = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+    except Exception as e:
+        print(f"Error during model inference: {e}")
+        raise e
+
+    return {
+        "choices": [{
+            "message": {
+                "role": "assistant",
+                "content": response
+            }
+        }]
+    }
+
+@app.on_event("startup")
+def startup_event():
+    # In Hugging Face Spaces, the application is usually accessible at https://<space_name>.hf.space
+    # Here we assume the space name is 'api-UI-TARS-7B-DPO'
+    public_url = "https://api-UI-TARS-7B-DPO.hf.space"
+    print(f"Public URL: {public_url}")