sunheycho committed
Commit ddb7f6a · 1 Parent(s): d7ccb89

feat(lora-compare): add SSE endpoints for LLaMA LoRA comparison; wire frontend component; build and copy static assets

Files changed (2):
  1. api.py +146 -0
  2. frontend/package-lock.json +5 -4
api.py CHANGED
@@ -1253,6 +1253,152 @@ def stream_product_comparison(session_id):
         }
     )
 
+# ============================
+# LLM LoRA Compare Endpoints
+# ============================
+
+# Simple in-memory session store for LoRA compare
+lora_sessions = {}
+
+def lora_add_message(session_id, message, msg_type="info"):
+    sess = lora_sessions.get(session_id)
+    if not sess:
+        return
+    ts = time.strftime('%Y-%m-%d %H:%M:%S')
+    sess['messages'].append({
+        'message': message,
+        'type': msg_type,
+        'timestamp': ts
+    })
+
+@app.route('/api/llama/compare/start', methods=['POST'])
+@require_auth()
+def start_llama_lora_compare():
+    """Start a LoRA-vs-Base comparison session (text or image+text prompt)."""
+    session_id = request.form.get('session_id') or str(uuid.uuid4())
+    prompt = request.form.get('prompt', '')
+    base_model_id = request.form.get('baseModel', 'meta-llama/Llama-3.1-8B-Instruct')
+    lora_path = request.form.get('loraPath', '')
+    image_b64 = None
+    if 'image' in request.files:
+        try:
+            img = Image.open(request.files['image'].stream).convert('RGB')
+            buffer = BytesIO()
+            img.save(buffer, format='PNG')
+            image_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
+        except Exception as _e:
+            pass
+
+    # Initialize session
+    lora_sessions[session_id] = {
+        'status': 'processing',
+        'messages': [],
+        'result': None,
+    }
+    lora_add_message(session_id, 'LoRA comparison started', 'system')
+
+    def worker():
+        try:
+            lora_add_message(session_id, f"Base model: {base_model_id}")
+            if lora_path:
+                lora_add_message(session_id, f"LoRA adapter: {lora_path}")
+            else:
+                lora_add_message(session_id, "No LoRA adapter provided; using mock output.")
+
+            # Prepare prompt
+            full_prompt = prompt or 'Describe the content.'
+            if image_b64:
+                lora_add_message(session_id, 'Image provided; running vision+language prompt.')
+
+            # Run base inference (best-effort)
+            start_base = time.time()
+            base_output = None
+            try:
+                if llm_model is not None and llm_tokenizer is not None:
+                    inputs = llm_tokenizer(full_prompt, return_tensors='pt').to(device)
+                    with torch.no_grad():
+                        out = llm_model.generate(**inputs, max_new_tokens=128, temperature=0.7, top_p=0.9)
+                    text = llm_tokenizer.decode(out[0], skip_special_tokens=True)
+                    # strip prompt prefix
+                    if text.startswith(full_prompt):
+                        text = text[len(full_prompt):].strip()
+                    base_output = text
+                else:
+                    base_output = f"[mock] Base response for: {full_prompt[:80]}..."
+            except Exception as e:
+                base_output = f"[error] Base inference failed: {e}"
+            base_latency = int((time.time() - start_base) * 1000)
+            lora_add_message(session_id, f"Base inference done in {base_latency} ms")
+
+            # Run LoRA inference (mock unless PEFT is integrated)
+            start_lora = time.time()
+            try:
+                if lora_path and llm_model is not None and llm_tokenizer is not None:
+                    # Placeholder: in real integration, load LoRA via PEFT and run generate
+                    lora_output = f"[mock-lora:{lora_path}] {base_output}"
+                else:
+                    lora_output = f"[mock] LoRA response (no adapter) for: {full_prompt[:80]}..."
+            except Exception as e:
+                lora_output = f"[error] LoRA inference failed: {e}"
+            lora_latency = int((time.time() - start_lora) * 1000)
+            lora_add_message(session_id, f"LoRA inference done in {lora_latency} ms")
+
+            lora_sessions[session_id]['result'] = {
+                'prompt': full_prompt,
+                'image': image_b64,
+                'base': { 'output': base_output, 'latency_ms': base_latency },
+                'lora': { 'output': lora_output, 'latency_ms': lora_latency },
+            }
+            lora_sessions[session_id]['status'] = 'completed'
+            lora_add_message(session_id, 'Comparison completed', 'system')
+        except Exception as e:
+            lora_sessions[session_id]['status'] = 'error'
+            lora_sessions[session_id]['result'] = {
+                'error': str(e)
+            }
+            lora_add_message(session_id, f"Error: {e}", 'error')
+
+    Thread(target=worker, daemon=True).start()
+    return jsonify({ 'session_id': session_id, 'status': 'processing' })
+
+
+@app.route('/api/llama/compare/stream/<session_id>', methods=['GET'])
+@require_auth()
+def stream_llama_lora_compare(session_id):
+    """SSE stream for LoRA comparison progress and final result."""
+    def generate():
+        last_idx = 0
+        retries = 0
+        max_retries = 300
+        while retries < max_retries:
+            sess = lora_sessions.get(session_id)
+            if not sess:
+                yield f"data: {json.dumps({'error': 'Session not found'})}\n\n"
+                break
+            msgs = sess['messages']
+            if len(msgs) > last_idx:
+                for m in msgs[last_idx:]:
+                    yield f"data: {json.dumps(m)}\n\n"
+                last_idx = len(msgs)
+            yield f"data: {json.dumps({'status': sess['status']})}\n\n"
+            if sess['status'] in ('completed', 'error'):
+                yield f"data: {json.dumps({'final_result': sess['result']})}\n\n"
+                break
+            time.sleep(1)
+            retries += 1
+        if retries >= max_retries:
+            yield f"data: {json.dumps({'error': 'Timeout waiting for results'})}\n\n"
+
+    return Response(
+        stream_with_context(generate()),
+        mimetype='text/event-stream',
+        headers={
+            'Cache-Control': 'no-cache',
+            'X-Accel-Buffering': 'no',
+            'Content-Type': 'text/event-stream',
+        }
+    )
+
 @app.route('/api/search-similar-objects', methods=['POST'])
 @require_auth()
 def search_similar_objects():
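
Note on the LoRA branch: it is stubbed in this commit (see the "# Placeholder" comment in the diff) and only echoes the base output. A minimal sketch of what a real integration could look like, assuming the PEFT library is installed and reusing the diff's generation parameters; the helper name generate_with_lora and its signature are illustrative, not part of this commit:

import torch
from peft import PeftModel  # assumed dependency: pip install peft

def generate_with_lora(base_model, tokenizer, lora_path, prompt, device):
    """Illustrative sketch only: wrap the loaded base model with a LoRA adapter."""
    lora_model = PeftModel.from_pretrained(base_model, lora_path)
    lora_model.eval()
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        out = lora_model.generate(**inputs, max_new_tokens=128,
                                  temperature=0.7, top_p=0.9)
    text = tokenizer.decode(out[0], skip_special_tokens=True)
    # Strip the echoed prompt, mirroring the base-inference path in the diff.
    if text.startswith(prompt):
        text = text[len(prompt):].strip()
    return text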
frontend/package-lock.json CHANGED
@@ -17497,16 +17497,17 @@
       "integrity": "sha512-/aCDEGatGvZ2BIk+HmLf4ifCJFwvKFNb9/JeZPMulfgFracn9QFcAf5GO8B/mweUjSoblS5In0cWhqpfs/5PQA=="
     },
     "node_modules/typescript": {
-      "version": "5.9.2",
-      "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.9.2.tgz",
-      "integrity": "sha512-CWBzXQrc/qOkhidw1OzBTQuYRbfyxDXJMVJ1XNwUHGROVmuaeiEm3OslpZ1RV96d7SKKjZKrSJu3+t/xlw3R9A==",
+      "version": "3.9.10",
+      "resolved": "https://registry.npmjs.org/typescript/-/typescript-3.9.10.tgz",
+      "integrity": "sha512-w6fIxVE/H1PkLKcCPsFqKE7Kv7QUwhU8qQY2MueZXWx5cPZdwFupLgKK3vntcK98BtNHZtAF4LA/yl2a7k8R6Q==",
+      "license": "Apache-2.0",
       "peer": true,
       "bin": {
         "tsc": "bin/tsc",
         "tsserver": "bin/tsserver"
       },
       "engines": {
-        "node": ">=14.17"
+        "node": ">=4.2.0"
       }
     },
     "node_modules/unbox-primitive": {