Spaces:
Running
Running
sunheycho
committed on
Commit
·
ddb7f6a
1
Parent(s):
d7ccb89
feat(lora-compare): add SSE endpoints for LLaMA LoRA comparison; wire frontend component; build and copy static assets
Browse files- api.py +146 -0
- frontend/package-lock.json +5 -4
api.py
CHANGED
@@ -1253,6 +1253,152 @@ def stream_product_comparison(session_id):
|
|
1253 |
}
|
1254 |
)
|
1255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1256 |
@app.route('/api/search-similar-objects', methods=['POST'])
|
1257 |
@require_auth()
|
1258 |
def search_similar_objects():
|
|
|
1253 |
}
|
1254 |
)
|
1255 |
|
1256 |
+
# ============================
# LLM LoRA Compare Endpoints
# ============================

# Simple in-memory session store for LoRA compare.
# Maps session_id -> {'status', 'messages', 'result'}; written by the
# /api/llama/compare/start worker thread and read by the SSE stream endpoint.
# NOTE(review): entries are never evicted, so memory grows with each session,
# and the dict is per-process — presumably a single-process deployment; verify.
lora_sessions = {}
|
1262 |
+
|
1263 |
+
def lora_add_message(session_id, message, msg_type="info"):
    """Append a timestamped progress entry to a LoRA compare session.

    Does nothing when ``session_id`` is unknown, so callers may log
    unconditionally without checking session existence first.
    """
    session = lora_sessions.get(session_id)
    if not session:
        return
    entry = {
        'message': message,
        'type': msg_type,
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    }
    session['messages'].append(entry)
|
1273 |
+
|
1274 |
+
@app.route('/api/llama/compare/start', methods=['POST'])
@require_auth()
def start_llama_lora_compare():
    """Start a LoRA-vs-Base comparison session (text or image+text prompt).

    Form fields:
        session_id: optional client-supplied id; a fresh UUID otherwise.
        prompt:     text prompt sent to both the base and LoRA model.
        baseModel:  base model identifier (logged to the session stream).
        loraPath:   LoRA adapter path; mock output is produced when empty.
    Files:
        image: optional upload, re-encoded to PNG and kept as base64.

    Responds immediately with ``{'session_id', 'status': 'processing'}``;
    the inference itself runs in a daemon thread that reports progress and
    the final result through ``lora_sessions`` (consumed by the SSE
    endpoint ``/api/llama/compare/stream/<session_id>``).
    """
    session_id = request.form.get('session_id') or str(uuid.uuid4())
    prompt = request.form.get('prompt', '')
    base_model_id = request.form.get('baseModel', 'meta-llama/Llama-3.1-8B-Instruct')
    lora_path = request.form.get('loraPath', '')

    # Initialize the session BEFORE any best-effort work so that failures
    # (e.g. a corrupt image upload) can be surfaced to the client via SSE
    # instead of being silently swallowed.
    lora_sessions[session_id] = {
        'status': 'processing',
        'messages': [],
        'result': None,
    }
    lora_add_message(session_id, 'LoRA comparison started', 'system')

    image_b64 = None
    if 'image' in request.files:
        try:
            img = Image.open(request.files['image'].stream).convert('RGB')
            buffer = BytesIO()
            img.save(buffer, format='PNG')
            image_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
        except Exception as e:
            # Previously a bare `pass`: report the problem but proceed with
            # a text-only comparison (best-effort semantics preserved).
            lora_add_message(session_id, f"Image decode failed: {e}", 'error')

    def worker():
        # Runs the whole comparison off the request thread; all progress is
        # recorded in lora_sessions[session_id].
        try:
            lora_add_message(session_id, f"Base model: {base_model_id}")
            if lora_path:
                lora_add_message(session_id, f"LoRA adapter: {lora_path}")
            else:
                lora_add_message(session_id, "No LoRA adapter provided; using mock output.")

            # Prepare prompt
            full_prompt = prompt or 'Describe the content.'
            if image_b64:
                lora_add_message(session_id, 'Image provided; running vision+language prompt.')

            # Run base inference (best-effort: falls back to a mock string
            # when the shared llm_model/llm_tokenizer globals are not loaded).
            start_base = time.time()
            base_output = None
            try:
                if llm_model is not None and llm_tokenizer is not None:
                    inputs = llm_tokenizer(full_prompt, return_tensors='pt').to(device)
                    with torch.no_grad():
                        out = llm_model.generate(**inputs, max_new_tokens=128, temperature=0.7, top_p=0.9)
                    text = llm_tokenizer.decode(out[0], skip_special_tokens=True)
                    # strip prompt prefix so only the generated text remains
                    if text.startswith(full_prompt):
                        text = text[len(full_prompt):].strip()
                    base_output = text
                else:
                    base_output = f"[mock] Base response for: {full_prompt[:80]}..."
            except Exception as e:
                base_output = f"[error] Base inference failed: {e}"
            base_latency = int((time.time() - start_base) * 1000)
            lora_add_message(session_id, f"Base inference done in {base_latency} ms")

            # Run LoRA inference (mock unless PEFT is integrated)
            start_lora = time.time()
            try:
                if lora_path and llm_model is not None and llm_tokenizer is not None:
                    # Placeholder: in real integration, load LoRA via PEFT and run generate
                    lora_output = f"[mock-lora:{lora_path}] {base_output}"
                else:
                    lora_output = f"[mock] LoRA response (no adapter) for: {full_prompt[:80]}..."
            except Exception as e:
                lora_output = f"[error] LoRA inference failed: {e}"
            lora_latency = int((time.time() - start_lora) * 1000)
            lora_add_message(session_id, f"LoRA inference done in {lora_latency} ms")

            lora_sessions[session_id]['result'] = {
                'prompt': full_prompt,
                'image': image_b64,
                'base': {'output': base_output, 'latency_ms': base_latency},
                'lora': {'output': lora_output, 'latency_ms': lora_latency},
            }
            lora_sessions[session_id]['status'] = 'completed'
            lora_add_message(session_id, 'Comparison completed', 'system')
        except Exception as e:
            lora_sessions[session_id]['status'] = 'error'
            lora_sessions[session_id]['result'] = {
                'error': str(e)
            }
            lora_add_message(session_id, f"Error: {e}", 'error')

    Thread(target=worker, daemon=True).start()
    return jsonify({'session_id': session_id, 'status': 'processing'})
|
1363 |
+
|
1364 |
+
|
1365 |
+
@app.route('/api/llama/compare/stream/<session_id>', methods=['GET'])
@require_auth()
def stream_llama_lora_compare(session_id):
    """SSE stream for LoRA comparison progress and final result.

    Polls the in-memory session once per second, emitting any new progress
    messages, a status event each tick, and finally either the result or an
    error/timeout event. Gives up after 300 polls (~5 minutes).
    """
    max_polls = 300

    def generate():
        sent = 0
        for _poll in range(max_polls):
            session = lora_sessions.get(session_id)
            if not session:
                yield f"data: {json.dumps({'error': 'Session not found'})}\n\n"
                break
            messages = session['messages']
            if sent < len(messages):
                for entry in messages[sent:]:
                    yield f"data: {json.dumps(entry)}\n\n"
                sent = len(messages)
            yield f"data: {json.dumps({'status': session['status']})}\n\n"
            if session['status'] in ('completed', 'error'):
                yield f"data: {json.dumps({'final_result': session['result']})}\n\n"
                break
            time.sleep(1)
        else:
            # Loop exhausted without break: the worker never finished in time.
            yield f"data: {json.dumps({'error': 'Timeout waiting for results'})}\n\n"

    return Response(
        stream_with_context(generate()),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no',
            'Content-Type': 'text/event-stream',
        },
    )
|
1401 |
+
|
1402 |
@app.route('/api/search-similar-objects', methods=['POST'])
|
1403 |
@require_auth()
|
1404 |
def search_similar_objects():
|
frontend/package-lock.json
CHANGED
@@ -17497,16 +17497,17 @@
|
|
17497 |
"integrity": "sha512-/aCDEGatGvZ2BIk+HmLf4ifCJFwvKFNb9/JeZPMulfgFracn9QFcAf5GO8B/mweUjSoblS5In0cWhqpfs/5PQA=="
|
17498 |
},
|
17499 |
"node_modules/typescript": {
|
17500 |
-
"version": "
|
17501 |
-
"resolved": "https://registry.npmjs.org/typescript/-/typescript-
|
17502 |
-
"integrity": "sha512-
|
|
|
17503 |
"peer": true,
|
17504 |
"bin": {
|
17505 |
"tsc": "bin/tsc",
|
17506 |
"tsserver": "bin/tsserver"
|
17507 |
},
|
17508 |
"engines": {
|
17509 |
-
"node": ">=
|
17510 |
}
|
17511 |
},
|
17512 |
"node_modules/unbox-primitive": {
|
|
|
17497 |
"integrity": "sha512-/aCDEGatGvZ2BIk+HmLf4ifCJFwvKFNb9/JeZPMulfgFracn9QFcAf5GO8B/mweUjSoblS5In0cWhqpfs/5PQA=="
|
17498 |
},
|
17499 |
"node_modules/typescript": {
|
17500 |
+
"version": "3.9.10",
|
17501 |
+
"resolved": "https://registry.npmjs.org/typescript/-/typescript-3.9.10.tgz",
|
17502 |
+
"integrity": "sha512-w6fIxVE/H1PkLKcCPsFqKE7Kv7QUwhU8qQY2MueZXWx5cPZdwFupLgKK3vntcK98BtNHZtAF4LA/yl2a7k8R6Q==",
|
17503 |
+
"license": "Apache-2.0",
|
17504 |
"peer": true,
|
17505 |
"bin": {
|
17506 |
"tsc": "bin/tsc",
|
17507 |
"tsserver": "bin/tsserver"
|
17508 |
},
|
17509 |
"engines": {
|
17510 |
+
"node": ">=4.2.0"
|
17511 |
}
|
17512 |
},
|
17513 |
"node_modules/unbox-primitive": {
|