sksstudio committed
Commit 4a73fad · 1 Parent(s): 5401975
Files changed (3)
  1. app.py +106 -42
  2. image.png +0 -0
  3. requirements.txt +5 -1
app.py CHANGED
@@ -1,5 +1,6 @@
  # app.py
- from fastapi import FastAPI, HTTPException, UploadFile, File
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
+ from fastapi.middleware.cors import CORSMiddleware
  from pydantic import BaseModel
  from llama_cpp import Llama
  from typing import Optional
@@ -9,6 +10,11 @@ import os
  from PIL import Image
  import io
  import base64
+ import logging
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)

  app = FastAPI(
      title="OmniVLM API",
@@ -16,20 +22,39 @@ app = FastAPI(
      version="1.0.0"
  )

- # Download the model from Hugging Face Hub
- model_path = huggingface_hub.hf_hub_download(
-     repo_id="NexaAIDev/OmniVLM-968M",
-     filename="omnivision-text-optimized-llm-Q8_0.gguf"
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
  )

+ # Download the model from Hugging Face Hub
+ try:
+     model_path = huggingface_hub.hf_hub_download(
+         repo_id="NexaAIDev/OmniVLM-968M",
+         filename="omnivision-text-optimized-llm-Q8_0.gguf"
+     )
+     logger.info(f"Model downloaded successfully to {model_path}")
+ except Exception as e:
+     logger.error(f"Error downloading model: {e}")
+     raise
+
  # Initialize the model with the downloaded file
- llm = Llama(
-     model_path=model_path,
-     n_ctx=2048,
-     n_threads=4,
-     n_batch=512,
-     verbose=True
- )
+ try:
+     llm = Llama(
+         model_path=model_path,
+         n_ctx=2048,
+         n_threads=4,
+         n_batch=512,
+         verbose=True
+     )
+     logger.info("Model initialized successfully")
+ except Exception as e:
+     logger.error(f"Error initializing model: {e}")
+     raise

  class GenerationRequest(BaseModel):
      prompt: str
@@ -37,13 +62,16 @@ class GenerationRequest(BaseModel):
      temperature: Optional[float] = 0.7
      top_p: Optional[float] = 0.9

- class ImageRequest(BaseModel):
-     prompt: Optional[str] = "Describe this image in detail"
-     max_tokens: Optional[int] = 200
-     temperature: Optional[float] = 0.7
-
  class GenerationResponse(BaseModel):
      generated_text: str
+     error: Optional[str] = None
+
+ ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
+ MAX_IMAGE_SIZE = 10 * 1024 * 1024  # 10MB
+
+ def allowed_file(filename):
+     return '.' in filename and \
+         filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

  @app.post("/generate", response_model=GenerationResponse)
  async def generate_text(request: GenerationRequest):
@@ -57,44 +85,80 @@ async def generate_text(request: GenerationRequest):

          return GenerationResponse(generated_text=output["choices"][0]["text"])
      except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
+         logger.error(f"Error in text generation: {e}")
+         return GenerationResponse(generated_text="", error=str(e))

  @app.post("/process-image", response_model=GenerationResponse)
  async def process_image(
      file: UploadFile = File(...),
-     request: ImageRequest = None
+     prompt: str = Form("Describe this image in detail"),
+     max_tokens: int = Form(200),
+     temperature: float = Form(0.7)
  ):
      try:
-         # Read and validate the image
-         image_data = await file.read()
-         image = Image.open(io.BytesIO(image_data))
-
-         # Convert image to base64
-         buffered = io.BytesIO()
-         image.save(buffered, format=image.format or "JPEG")
-         img_str = base64.b64encode(buffered.getvalue()).decode()
+         # Validate file size
+         file_size = 0
+         file_content = await file.read()
+         file_size = len(file_content)

-         # Create prompt with image
-         prompt = f"""
-         <image>data:image/jpeg;base64,{img_str}</image>
-         {request.prompt if request else "Describe this image in detail"}
-         """
+         if file_size > MAX_IMAGE_SIZE:
+             raise HTTPException(status_code=400, detail="File too large")

-         # Generate description
-         output = llm(
-             prompt,
-             max_tokens=request.max_tokens if request else 200,
-             temperature=request.temperature if request else 0.7
-         )
+         # Validate file type
+         if not allowed_file(file.filename):
+             raise HTTPException(status_code=400, detail="File type not allowed")

-         return GenerationResponse(generated_text=output["choices"][0]["text"])
+         # Process image
+         try:
+             image = Image.open(io.BytesIO(file_content))
+
+             # Convert image to RGB if necessary
+             if image.mode != 'RGB':
+                 image = image.convert('RGB')
+
+             # Resize image if too large
+             max_size = (1024, 1024)
+             if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
+                 image.thumbnail(max_size, Image.Resampling.LANCZOS)
+
+             # Convert to base64
+             buffered = io.BytesIO()
+             image.save(buffered, format="JPEG", quality=85)
+             img_str = base64.b64encode(buffered.getvalue()).decode()
+
+             # Create prompt with image
+             full_prompt = f"""
+             <image>data:image/jpeg;base64,{img_str}</image>
+             {prompt}
+             """
+
+             logger.info("Processing image with prompt")
+             # Generate description
+             output = llm(
+                 full_prompt,
+                 max_tokens=max_tokens,
+                 temperature=temperature
+             )
+
+             return GenerationResponse(generated_text=output["choices"][0]["text"])
+
+         except Exception as e:
+             logger.error(f"Error processing image: {e}")
+             raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
+
+     except HTTPException as he:
+         raise he
      except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
+         logger.error(f"Unexpected error: {e}")
+         return GenerationResponse(generated_text="", error=str(e))

  @app.get("/health")
  async def health_check():
-     return {"status": "healthy"}
+     return {
+         "status": "healthy",
+         "model_loaded": llm is not None
+     }

  if __name__ == "__main__":
      port = int(os.environ.get("PORT", 7860))
-     uvicorn.run(app, host="0.0.0.0", port=port)
+     uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")
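With this change the /process-image endpoint no longer reads a JSON ImageRequest body; it expects a multipart file upload plus optional form fields, while /generate still takes JSON. A minimal client sketch of the new API, assuming the server is reachable at http://localhost:7860, the requests package is installed, and "image.png" is a placeholder path:

import requests

BASE_URL = "http://localhost:7860"  # assumed local deployment; adjust to your host/port

# /generate still accepts a JSON body matching GenerationRequest
resp = requests.post(
    f"{BASE_URL}/generate",
    json={"prompt": "Write a haiku about APIs.", "temperature": 0.7},
)
print(resp.json()["generated_text"])

# /process-image now takes a multipart upload plus optional form fields
with open("image.png", "rb") as f:  # placeholder image path
    resp = requests.post(
        f"{BASE_URL}/process-image",
        files={"file": ("image.png", f, "image/png")},
        data={"prompt": "Describe this image in detail", "max_tokens": 200, "temperature": 0.7},
    )
print(resp.json())

The multipart call only works if python-multipart is installed, which the requirements.txt change below adds.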
image.png ADDED
requirements.txt CHANGED
@@ -2,4 +2,8 @@ fastapi==0.104.1
  uvicorn==0.24.0
  pydantic==2.4.2
  llama-cpp-python>=0.2.20
- huggingface-hub>=0.19.0
+ huggingface-hub>=0.19.0
+ python-multipart>=0.0.6  # FastAPI file upload support
+ pillow>=10.0.0  # Image processing
+ requests>=2.31.0  # HTTP requests
+ python-dotenv>=1.0.0  # Environment variables management
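Of the new dependencies, python-multipart is what FastAPI needs to parse the File(...) and Form(...) parameters in /process-image, and pillow backs the PIL.Image calls. python-dotenv is listed for environment-variable management but is not imported anywhere in this commit; a minimal sketch of how it could be wired up for the PORT setting (hypothetical usage, not part of the diff):

# hypothetical: load a local .env file before reading PORT in app.py
import os
from dotenv import load_dotenv

load_dotenv()  # reads a .env file, e.g. one containing PORT=7860
port = int(os.environ.get("PORT", 7860))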