parthraninga committed on
Commit c77c778 · verified · 1 Parent(s): 58a3ef2

Update app.py

Files changed (1)
  1. app.py +30 -285
app.py CHANGED
@@ -1,285 +1,30 @@
- from fastapi import FastAPI, File, UploadFile, HTTPException
- from fastapi.responses import JSONResponse
- from pydantic import BaseModel
- import torch
- import torch.nn.functional as F
- from transformers import AutoImageProcessor, AutoModelForImageClassification
- from PIL import Image
- import io
- import numpy as np
- from typing import List, Dict, Any
- import logging
- import os
-
- # Configure logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
- app = FastAPI(
-     title="ChatGPT Oasis Model Inference API",
-     description="FastAPI inference server for Oasis and ViT models deployed on Hugging Face Spaces with Docker",
-     version="1.0.0"
- )
-
- # Global variables to store loaded models
- oasis_model = None
- oasis_processor = None
- vit_model = None
- vit_processor = None
-
- class InferenceRequest(BaseModel):
-     image: str  # Base64 encoded image
-     model_name: str = "oasis500m"  # Default to oasis model
-
- class InferenceResponse(BaseModel):
-     predictions: List[Dict[str, Any]]
-     model_used: str
-     confidence_scores: List[float]
-
- def load_models():
-     """Load both models from local files"""
-     global oasis_model, oasis_processor, vit_model, vit_processor
-
-     try:
-         logger.info("Loading Oasis 500M model from local files...")
-         # Load Oasis model from local files
-         oasis_processor = AutoImageProcessor.from_pretrained("microsoft/oasis-500m")
-         oasis_model = AutoModelForImageClassification.from_pretrained(
-             "microsoft/oasis-500m",
-             local_files_only=False  # Will download config but use local weights
-         )
-
-         # Load local weights if available
-         oasis_model_path = "/app/models/oasis500m.safetensors"
-         if os.path.exists(oasis_model_path):
-             logger.info("Loading Oasis weights from local file...")
-             from safetensors.torch import load_file
-             state_dict = load_file(oasis_model_path)
-             oasis_model.load_state_dict(state_dict, strict=False)
-
-         oasis_model.eval()
-
-         logger.info("Loading ViT-L-20 model from local files...")
-         # Load ViT model from local files
-         vit_processor = AutoImageProcessor.from_pretrained("google/vit-large-patch16-224")
-         vit_model = AutoModelForImageClassification.from_pretrained(
-             "google/vit-large-patch16-224",
-             local_files_only=False  # Will download config but use local weights
-         )
-
-         # Load local weights if available
-         vit_model_path = "/app/models/vit-l-20.safetensors"
-         if os.path.exists(vit_model_path):
-             logger.info("Loading ViT weights from local file...")
-             from safetensors.torch import load_file
-             state_dict = load_file(vit_model_path)
-             vit_model.load_state_dict(state_dict, strict=False)
-
-         vit_model.eval()
-
-         logger.info("All models loaded successfully!")
-
-     except Exception as e:
-         logger.error(f"Error loading models: {e}")
-         raise e
-
- @app.on_event("startup")
- async def startup_event():
-     """Load models when the application starts"""
-     load_models()
-
- @app.get("/")
- async def root():
-     """Root endpoint with API information"""
-     return {
-         "message": "ChatGPT Oasis Model Inference API",
-         "version": "1.0.0",
-         "deployed_on": "Hugging Face Spaces (Docker)",
-         "available_models": ["oasis500m", "vit-l-20"],
-         "endpoints": {
-             "health": "/health",
-             "inference": "/inference",
-             "upload_inference": "/upload_inference",
-             "predict": "/predict"
-         },
-         "usage": {
-             "base64_inference": "POST /inference with JSON body containing 'image' (base64) and 'model_name'",
-             "file_upload": "POST /upload_inference with multipart form containing 'file' and optional 'model_name'",
-             "simple_predict": "POST /predict with file upload for quick inference"
-         }
-     }
-
- @app.get("/health")
- async def health_check():
-     """Health check endpoint"""
-     models_status = {
-         "oasis500m": oasis_model is not None,
-         "vit-l-20": vit_model is not None
-     }
-
-     # Check if model files exist
-     model_files = {
-         "oasis500m": os.path.exists("/app/models/oasis500m.safetensors"),
-         "vit-l-20": os.path.exists("/app/models/vit-l-20.safetensors")
-     }
-
-     return {
-         "status": "healthy",
-         "models_loaded": models_status,
-         "model_files_present": model_files,
-         "deployment": "huggingface-spaces-docker"
-     }
-
- def process_image_with_model(image: Image.Image, model_name: str):
-     """Process image with the specified model"""
-     if model_name == "oasis500m":
-         if oasis_model is None or oasis_processor is None:
-             raise HTTPException(status_code=500, detail="Oasis model not loaded")
-
-         inputs = oasis_processor(images=image, return_tensors="pt")
-         with torch.no_grad():
-             outputs = oasis_model(**inputs)
-             logits = outputs.logits
-             probabilities = F.softmax(logits, dim=-1)
-
-         # Get top predictions
-         top_probs, top_indices = torch.topk(probabilities, 5)
-
-         predictions = []
-         for i in range(top_indices.shape[1]):
-             pred = {
-                 "label": oasis_model.config.id2label[top_indices[0][i].item()],
-                 "confidence": top_probs[0][i].item()
-             }
-             predictions.append(pred)
-
-         return predictions
-
-     elif model_name == "vit-l-20":
-         if vit_model is None or vit_processor is None:
-             raise HTTPException(status_code=500, detail="ViT model not loaded")
-
-         inputs = vit_processor(images=image, return_tensors="pt")
-         with torch.no_grad():
-             outputs = vit_model(**inputs)
-             logits = outputs.logits
-             probabilities = F.softmax(logits, dim=-1)
-
-         # Get top predictions
-         top_probs, top_indices = torch.topk(probabilities, 5)
-
-         predictions = []
-         for i in range(top_indices.shape[1]):
-             pred = {
-                 "label": vit_model.config.id2label[top_indices[0][i].item()],
-                 "confidence": top_probs[0][i].item()
-             }
-             predictions.append(pred)
-
-         return predictions
-
-     else:
-         raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")
-
- @app.post("/inference", response_model=InferenceResponse)
- async def inference(request: InferenceRequest):
-     """Inference endpoint using base64 encoded image"""
-     try:
-         import base64
-
-         # Decode base64 image
-         image_data = base64.b64decode(request.image)
-         image = Image.open(io.BytesIO(image_data)).convert('RGB')
-
-         # Process with model
-         predictions = process_image_with_model(image, request.model_name)
-
-         # Extract confidence scores
-         confidence_scores = [pred["confidence"] for pred in predictions]
-
-         return InferenceResponse(
-             predictions=predictions,
-             model_used=request.model_name,
-             confidence_scores=confidence_scores
-         )
-
-     except Exception as e:
-         logger.error(f"Inference error: {e}")
-         raise HTTPException(status_code=500, detail=str(e))
-
- @app.post("/upload_inference", response_model=InferenceResponse)
- async def upload_inference(
-     file: UploadFile = File(...),
-     model_name: str = "oasis500m"
- ):
-     """Inference endpoint using file upload"""
-     try:
-         # Validate file type
-         if not file.content_type.startswith('image/'):
-             raise HTTPException(status_code=400, detail="File must be an image")
-
-         # Read and process image
-         image_data = await file.read()
-         image = Image.open(io.BytesIO(image_data)).convert('RGB')
-
-         # Process with model
-         predictions = process_image_with_model(image, model_name)
-
-         # Extract confidence scores
-         confidence_scores = [pred["confidence"] for pred in predictions]
-
-         return InferenceResponse(
-             predictions=predictions,
-             model_used=model_name,
-             confidence_scores=confidence_scores
-         )
-
-     except Exception as e:
-         logger.error(f"Upload inference error: {e}")
-         raise HTTPException(status_code=500, detail=str(e))
-
- @app.get("/models")
- async def list_models():
-     """List available models and their status"""
-     return {
-         "available_models": [
-             {
-                 "name": "oasis500m",
-                 "description": "Oasis 500M vision model",
-                 "loaded": oasis_model is not None,
-                 "file_present": os.path.exists("/app/models/oasis500m.safetensors")
-             },
-             {
-                 "name": "vit-l-20",
-                 "description": "Vision Transformer Large model",
-                 "loaded": vit_model is not None,
-                 "file_present": os.path.exists("/app/models/vit-l-20.safetensors")
-             }
-         ]
-     }
-
- # Hugging Face Spaces specific endpoint for Gradio compatibility
- @app.post("/predict")
- async def predict(file: UploadFile = File(...)):
-     """Simple prediction endpoint for Hugging Face Spaces integration"""
-     try:
-         # Validate file type
-         if not file.content_type.startswith('image/'):
-             raise HTTPException(status_code=400, detail="File must be an image")
-
-         # Read and process image
-         image_data = await file.read()
-         image = Image.open(io.BytesIO(image_data)).convert('RGB')
-
-         # Process with default model (oasis500m)
-         predictions = process_image_with_model(image, "oasis500m")
-
-         # Return simplified format for Gradio
-         return {
-             "predictions": predictions[:3],  # Top 3 predictions
-             "model_used": "oasis500m"
-         }
-
-     except Exception as e:
-         logger.error(f"Predict error: {e}")
-         raise HTTPException(status_code=500, detail=str(e))
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+
+ app = FastAPI()
+
+ # Load model & tokenizer
+ MODEL_PATH = "./"  # since it's inside the same repo
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_PATH,
+     torch_dtype=torch.float16,
+     device_map="auto"
+ )
+
+ class RequestBody(BaseModel):
+     prompt: str
+     max_length: int = 100
+
+ @app.post("/generate")
+ def generate_text(req: RequestBody):
+     inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
+     outputs = model.generate(**inputs, max_length=req.max_length)
+     text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return {"generated_text": text}
+
+ @app.get("/")
+ def root():
+     return {"message": "FastAPI Hugging Face Space is running!"}