gnilets committed on
Commit
b71e69d
·
verified ·
1 Parent(s): 0935932

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -0
app.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from asyncio import gather
2
+ from base64 import b64decode
3
+ from binascii import Error as BinasciiError
4
+ from contextlib import asynccontextmanager
5
+ from io import BytesIO
6
+ from logging import Formatter, INFO, StreamHandler, getLogger
7
+ from pathlib import Path
8
+ from random import randint
9
+ from typing import AsyncGenerator
10
+ from uuid import UUID
11
+
12
+ from PIL.Image import open as image_open
13
+ from fastapi import FastAPI, Request
14
+ from fastapi.responses import HTMLResponse, JSONResponse
15
+ from httpx import AsyncClient
16
+ from starlette.responses import Response
17
+
18
# Application-wide logging: one stream handler emitting timestamped records
# at INFO level and above.
formatter = Formatter(
    fmt='%(asctime)s | %(levelname)s : %(message)s',
    datefmt='%d.%m.%Y %H:%M:%S',
)

handler = StreamHandler()
handler.setFormatter(formatter)
handler.setLevel(INFO)

logger = getLogger('NVIDIA_VLM_API')
logger.addHandler(handler)
logger.setLevel(INFO)

# Emitted once at import time, before any route is registered.
logger.info('инициализация приложения...')
27
+
28
# Inference endpoints queried for every request; the path segment after
# `/v1/vlm/` is used elsewhere in this file as the model name.
INVOKE_URLS = [
    f'https://ai.api.nvidia.com/v1/vlm/{model}'
    for model in (
        'microsoft/phi-3-vision-128k-instruct',
        'nvidia/neva-22b',
        'nvidia/vila',
    )
]

# Endpoint used to register and delete uploaded media assets.
ASSETS_URL = 'https://api.nvcf.nvidia.com/v2/nvcf/assets'
35
+
36
+
37
def get_extension(filename: str) -> str:
    """Return the lower-cased final extension of *filename* without the dot.

    An empty string is returned when the name has no extension.
    """
    suffix = Path(filename).suffix
    return suffix.removeprefix('.').lower()
39
+
40
+
41
async def upload_asset(client: AsyncClient, media_file_bytes: bytes, description: str, api_key: str) -> UUID:
    """Register a JPEG asset and upload its bytes.

    First POSTs to ``ASSETS_URL`` to obtain an upload URL and asset id, then
    PUTs *media_file_bytes* to that URL.

    Returns:
        The asset id from the registration response, parsed as a ``UUID``.

    Raises:
        httpx.HTTPStatusError: if either request returns an error status.
    """
    content_type = 'image/jpeg'

    # Step 1: register the asset and receive where to upload it.
    registration = await client.post(
        ASSETS_URL,
        headers={
            'Authorization': f'Bearer {api_key}',
            'Content-Type': 'application/json',
            'accept': 'application/json',
        },
        json={'contentType': content_type, 'description': description},
        timeout=30,
    )
    registration.raise_for_status()
    asset_info = registration.json()

    # Step 2: send the raw bytes; the longer timeout covers large uploads.
    upload = await client.put(
        asset_info.get('uploadUrl'),
        content=media_file_bytes,
        headers={'x-amz-meta-nvcf-asset-description': description, 'content-type': content_type},
        timeout=300,
    )
    upload.raise_for_status()

    return UUID(asset_info.get('assetId'))
65
+
66
+
67
async def delete_asset(client: AsyncClient, asset_id: UUID, api_key: str) -> None:
    """Delete the asset *asset_id* via ``ASSETS_URL``.

    Raises:
        httpx.HTTPStatusError: if the service answers with an error status.
    """
    deletion = await client.delete(
        f'{ASSETS_URL}/{asset_id}',
        headers={'Authorization': f'Bearer {api_key}'},
        timeout=30,
    )
    deletion.raise_for_status()
71
+
72
+
73
async def chat_with_media_nvcf(infer_url: str, media_file_bytes: bytes, query: str, api_key: str) -> str | None:
    """Ask one vision-language endpoint about an image.

    The image is uploaded as a temporary asset, referenced from the prompt,
    and the asset is removed again once the request has finished, whether it
    succeeded or failed.

    Args:
        infer_url: Inference endpoint; the part after ``/v1/vlm/`` is sent as
            the ``model`` field.
        media_file_bytes: JPEG bytes to upload.
        query: User prompt placed before the image reference.
        api_key: Bearer token used for every request.

    Returns:
        The ``content`` of the first choice in the response, or ``None`` when
        the response has no such field or any step fails. This function never
        raises ``Exception`` subclasses; failures are logged instead.
    """
    try:
        async with AsyncClient(follow_redirects=True, timeout=45) as client:
            asset_id = await upload_asset(client, media_file_bytes, 'Reference media file', api_key)
            try:
                asset_ref = str(asset_id)
                headers = {
                    'Authorization': f'Bearer {api_key}',
                    'Content-Type': 'application/json',
                    'NVCF-INPUT-ASSET-REFERENCES': asset_ref,
                    'NVCF-FUNCTION-ASSET-IDS': asset_ref,
                    'Accept': 'application/json',
                }
                media_content = f'<img src="data:image/jpeg;asset_id,{asset_id}" />'
                payload = {
                    'max_tokens': 1024,
                    'temperature': 0.65,
                    'top_p': 0.95,
                    'seed': randint(0, 999999999),
                    'messages': [{'role': 'user', 'content': f'{query} {media_content}'}],
                    'stream': False,
                    'model': infer_url.split('/v1/vlm/')[-1],
                }

                response = await client.post(infer_url, headers=headers, json=payload)
                # Surface HTTP errors in the log instead of silently parsing
                # an error body that has no `choices`.
                response.raise_for_status()

                # `or [{}]` also covers an explicitly empty `choices` list.
                choices = response.json().get('choices') or [{}]
                return choices[0].get('message', {}).get('content')
            finally:
                # Always remove the uploaded asset. A cleanup failure must not
                # discard an answer that was already obtained, so it is only
                # logged.
                try:
                    await delete_asset(client, asset_id, api_key)
                except Exception:
                    logger.warning('не удалось удалить ассет %s', asset_id, exc_info=True)
    except Exception:
        # Boundary of a best-effort call: callers expect `None` on failure.
        logger.exception('ошибка запроса к %s', infer_url)
        return None
108
+
109
+
110
def base64_to_jpeg_bytes(base64_str: str) -> bytes:
    """Decode a ``<prefix>,<base64>`` string and re-encode the image as JPEG.

    Everything up to the first comma is discarded; the remainder is
    base64-decoded, opened as an image, converted to RGB and saved as JPEG
    (quality 90, optimized).

    Raises:
        ValueError: if the string contains no comma, the payload is not valid
            base64, or opening/saving the image fails with ``OSError``.
    """
    if ',' not in base64_str:
        raise ValueError('недопустимый формат строки base64')
    _, encoded = base64_str.split(',', 1)

    try:
        raw = b64decode(encoded)
        buffer = BytesIO()
        with image_open(BytesIO(raw)) as img:
            rgb = img.convert('RGB')
            rgb.save(buffer, format='JPEG', quality=90, optimize=True)
        return buffer.getvalue()
    except (BinasciiError, OSError) as e:
        raise ValueError('данные не являются корректным изображением') from e
122
+
123
+
124
async def get_captions(image_base64_str: str, query: str, api_key: str) -> dict[str, str]:
    """Query every endpoint in ``INVOKE_URLS`` concurrently about one image.

    Returns:
        A mapping from model name (the URL part after ``/v1/vlm/``) to the
        result of ``chat_with_media_nvcf`` for that endpoint.

    Raises:
        ValueError: propagated from ``base64_to_jpeg_bytes``.
    """
    jpeg_bytes = base64_to_jpeg_bytes(image_base64_str)
    answers = await gather(
        *(chat_with_media_nvcf(endpoint, jpeg_bytes, query, api_key) for endpoint in INVOKE_URLS)
    )
    model_names = [endpoint.split('/v1/vlm/')[-1] for endpoint in INVOKE_URLS]
    return {name: answer for name, answer in zip(model_names, answers)}
129
+
130
+
131
@asynccontextmanager
async def app_lifespan(_) -> AsyncGenerator:
    """Lifespan hook that logs application start-up and shutdown.

    The single argument is ignored.
    """
    logger.info('запуск приложения')
    try:
        logger.info('старт API')
        # Control returns here when the application shuts down.
        yield
    finally:
        logger.info('приложение завершено')
139
+
140
+
141
app = FastAPI(title='NVIDIA_VLM_API', lifespan=app_lifespan)

# Request paths rejected with 403 by the `block_banned_endpoints` middleware.
# NOTE(review): 'swagger_ui_redirect' has no leading slash, so it looks like a
# route name rather than a path and presumably never matches
# `request.url.path` - confirm whether it is needed.
banned_endpoints = [
    '/openapi.json',
    '/docs',
    '/docs/oauth2-redirect',
    'swagger_ui_redirect',
    '/redoc',
]
150
+
151
+
152
@app.middleware('http')
async def block_banned_endpoints(request: Request, call_next):
    """Answer 403 for paths listed in ``banned_endpoints``; pass others on."""
    path = request.url.path
    logger.debug(f'получен запрос: {path}')

    if path not in banned_endpoints:
        return await call_next(request)

    logger.warning(f'запрещенный endpoint: {path}')
    return Response(status_code=403)
160
+
161
+
162
def _extract_prompt_and_image(messages: list) -> tuple[str, str]:
    """Collect prompt text and the last inline image from chat messages.

    Only ``system`` and ``user`` messages are read. String contents and
    ``text`` items are joined with single spaces; for ``image_url`` items only
    URLs starting with ``data:image/`` are accepted, and a later one replaces
    an earlier one. Entries of an unexpected shape are skipped.

    Returns:
        ``(text, image_data_url)``, both stripped; either may be empty.
    """
    text_parts: list[str] = []
    image_data = ''

    for message in messages:
        if not isinstance(message, dict) or message.get('role') not in ('system', 'user'):
            continue
        content = message.get('content')
        if isinstance(content, str):
            text_parts.append(content)
            continue
        if not isinstance(content, list):
            continue
        for item in content:
            if not isinstance(item, dict):
                continue
            item_type = item.get('type')
            if item_type == 'text':
                text = item.get('text', '')
                if isinstance(text, str):
                    text_parts.append(text)
            elif item_type == 'image_url':
                image_url = item.get('image_url')
                url = image_url.get('url') if isinstance(image_url, dict) else None
                if isinstance(url, str) and url.startswith('data:image/'):
                    image_data = url

    return ' '.join(text_parts).strip(), image_data.strip()


@app.post('/v1/describe')
async def describe_v1(request: Request):
    """Describe an inline image with every configured model.

    Expects an ``Authorization: Bearer <NVAPI_KEY>`` header and a JSON body
    with a ``messages`` list holding text and a ``data:image/...`` URL.

    Responses:
        200: mapping of model name to caption, as returned by ``get_captions``.
        400: body is not a JSON object, text or image is missing, or the image
            data is invalid.
        401: the bearer token is missing or empty.
        500: any other failure while obtaining captions.
    """
    logger.info('запрос `describe_v1`')

    # Default to '' so a missing header yields 401 instead of crashing on
    # `None.removeprefix`. Header lookup is case-insensitive.
    authorization = request.headers.get('Authorization') or ''
    nvapi_key = authorization.removeprefix('Bearer ').strip()
    if not nvapi_key:
        return JSONResponse({'caption': 'в запросе нужно передать заголовок `Authorization: Bearer <NVAPI_KEY>`'}, status_code=401)

    try:
        body = await request.json()
    except ValueError:
        # Covers both JSON syntax errors and undecodable bytes.
        body = None
    if not isinstance(body, dict):
        return JSONResponse({'caption': 'тело запроса должно быть JSON-объектом'}, status_code=400)

    messages = body.get('messages')
    content_text, image_data = _extract_prompt_and_image(messages if isinstance(messages, list) else [])

    if not content_text or not image_data:
        return JSONResponse({'caption': 'изображение должно быть передано как строка base64 `data:image/jpeg;base64,{base64_img}` а также текст'}, status_code=400)

    try:
        captions = await get_captions(image_data, content_text, nvapi_key)
    except ValueError as e:
        # Raised when the supplied image data cannot be decoded: client error.
        return JSONResponse({'caption': str(e)}, status_code=400)
    except Exception as e:
        logger.exception('ошибка `describe_v1`')
        return JSONResponse({'caption': str(e)}, status_code=500)
    return JSONResponse(captions, status_code=200)
200
+
201
+
202
@app.get('/')
async def root():
    """Serve a fixed HTML body with status 200 at the site root."""
    return HTMLResponse(content='ну пролапс, ну и что', status_code=200)
205
+
206
+
207
if __name__ == '__main__':
    # Imported lazily so uvicorn is only needed when run as a script.
    import uvicorn

    logger.info('запуск сервера uvicorn')
    # Listen on all interfaces, port 7860.
    uvicorn.run(app, host='0.0.0.0', port=7860)