Commit 637dbbd
David Ko committed
1 Parent(s): 96aa590

feat(vision-rag): add LangChain deps and verify .venv; vision_rag_query uses ChatOpenAI; set default OPENAI_MODEL to gpt-4o

Files changed:
- api.py +131 -0
- requirements.txt +4 -0
api.py
CHANGED
@@ -23,6 +23,15 @@ try:
     from openai import OpenAI
 except Exception as _e:
     OpenAI = None
+try:
+    # LangChain for RAG answering
+    from langchain_openai import ChatOpenAI
+    from langchain_core.prompts import ChatPromptTemplate
+    from langchain_core.output_parsers import StrOutputParser
+except Exception as _e:
+    ChatOpenAI = None
+    ChatPromptTemplate = None
+    StrOutputParser = None
 from flask_login import (
     LoginManager,
     UserMixin,
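For reference, the guarded imports above feed the small LangChain pipeline that vision_rag_query builds in the next hunk (prompt | llm | StrOutputParser). A standalone sketch of that pattern, assuming the packages added in requirements.txt are installed and an OPENAI_API_KEY is present in the environment (the question text is made up):

import os
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Same default as the endpoint: OPENAI_MODEL falls back to gpt-4o.
llm = ChatOpenAI(api_key=os.environ["OPENAI_API_KEY"],
                 model=os.environ.get("OPENAI_MODEL", "gpt-4o"))
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a vision assistant. Use ONLY the provided detected object context to answer."),
    ("human", "{input}"),
])
chain = prompt | llm | StrOutputParser()  # template -> LLM -> plain string
print(chain.invoke({"input": "User question: what was detected?\n\nDetected context: []"}))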
@@ -1582,6 +1591,128 @@ def openai_chat_api():
         'latency_sec': latency
     })
 
+@app.route('/api/vision-rag/query', methods=['POST'])
+@login_required
+def vision_rag_query():
+    """Vision RAG endpoint.
+    Expects JSON with one of the following query modes and a user question:
+      - { userQuery, searchType: 'image', image, n_results? }
+      - { userQuery, searchType: 'object', objectId, n_results? }
+      - { userQuery, searchType: 'class', class_name, n_results? }
+    Returns: { answer, retrieved: [...], model, latency_sec }
+    """
+    if ChatOpenAI is None:
+        return jsonify({"error": "LangChain not installed on server"}), 500
+
+    data = request.get_json(silent=True) or {}
+    user_query = (data.get('userQuery') or '').strip()
+    if not user_query:
+        return jsonify({"error": "Missing 'userQuery'"}), 400
+
+    api_key = data.get('api_key') or os.environ.get('OPENAI_API_KEY')
+    if not api_key:
+        return jsonify({"error": "Missing OpenAI API key. Provide in request or set OPENAI_API_KEY env."}), 400
+
+    search_type = data.get('searchType', 'image')
+    n_results = int(data.get('n_results', 5))
+
+    # Build query embedding or filtered fetch similar to /api/search-similar-objects
+    results = None
+    try:
+        if search_type == 'image' and 'image' in data:
+            image_data = data['image']
+            if isinstance(image_data, str) and image_data.startswith('data:image'):
+                image_data = image_data.split(',')[1]
+            image = Image.open(BytesIO(base64.b64decode(image_data))).convert('RGB')
+            query_embedding = generate_image_embedding(image)
+            if query_embedding is None:
+                return jsonify({"error": "Failed to generate image embedding"}), 500
+            results = object_collection.query(
+                query_embeddings=[query_embedding],
+                n_results=n_results,
+                include=["metadatas", "distances"]
+            ) if object_collection is not None else None
+        elif search_type == 'object' and 'objectId' in data:
+            obj_id = data['objectId']
+            base = object_collection.get(ids=[obj_id], include=["embeddings"]) if object_collection is not None else None
+            emb = base["embeddings"][0] if base and "embeddings" in base and base["embeddings"] else None
+            if emb is None:
+                return jsonify({"error": "objectId not found or has no embedding"}), 400
+            results = object_collection.query(
+                query_embeddings=[emb],
+                n_results=n_results,
+                include=["metadatas", "distances"]
+            )
+        elif search_type == 'class' and 'class_name' in data:
+            filter_query = {"class": {"$eq": data['class_name']}}
+            results = object_collection.get(
+                where=filter_query,
+                limit=n_results,
+                include=["metadatas", "embeddings", "documents"]
+            ) if object_collection is not None else None
+        else:
+            return jsonify({"error": "Invalid search parameters"}), 400
+    except Exception as e:
+        return jsonify({"error": f"Retrieval failed: {str(e)}"}), 500
+
+    # Format results using existing helper
+    formatted = format_object_results(results) if results else []
+
+    # Build concise context for LLM
+    def _shorten(md):
+        try:
+            bbox = md.get('bbox') if isinstance(md, dict) else None
+            if isinstance(bbox, dict):
+                bbox = {k: round(float(v), 3) for k, v in bbox.items() if isinstance(v, (int, float))}
+            return {
+                'image_id': md.get('image_id'),
+                'class': md.get('class'),
+                'confidence': md.get('confidence'),
+                'bbox': bbox,
+            }
+        except Exception:
+            return {k: md.get(k) for k in ('image_id', 'class', 'confidence') if k in md}
+
+    context_items = []
+    for r in formatted[:n_results]:
+        md = r.get('metadata', {})
+        item = {
+            'id': r.get('id'),
+            'distance': r.get('distance'),
+            'meta': _shorten(md)
+        }
+        context_items.append(item)
+
+    # Compose prompt
+    system_text = (
+        "You are a vision assistant. Use ONLY the provided detected object context to answer. "
+        "Be concise and state uncertainty if context is insufficient."
+    )
+    # Provide the minimal JSON-like context to the model
+    context_text = json.dumps(context_items, ensure_ascii=False, indent=2)
+    user_text = f"User question: {user_query}\n\nDetected context (top {len(context_items)}):\n{context_text}"
+
+    try:
+        start = time.time()
+        llm = ChatOpenAI(api_key=api_key, model=os.environ.get('OPENAI_MODEL', 'gpt-4o'))
+        # Keep it simple: template -> LLM -> string
+        prompt = ChatPromptTemplate.from_messages([
+            ("system", system_text),
+            ("human", "{input}")
+        ])
+        chain = prompt | llm | StrOutputParser()
+        answer = chain.invoke({"input": user_text})
+        latency = round(time.time() - start, 3)
+    except Exception as e:
+        return jsonify({"error": f"LLM call failed: {str(e)}"}), 502
+
+    return jsonify({
+        "answer": answer,
+        "retrieved": context_items,
+        "model": getattr(llm, 'model', None),
+        "latency_sec": latency
+    })
+
 @app.route('/api/status', methods=['GET'])
 @fresh_login_required
 def status():
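A minimal client-side sketch of calling the new endpoint, assuming the Flask app runs on localhost:5000 and the requests session already carries an authenticated flask-login cookie (the route is wrapped in @login_required); the port, class name, and question are illustrative:

import requests

session = requests.Session()
# ... authenticate here so the session holds the flask-login cookie ...

resp = session.post(
    "http://localhost:5000/api/vision-rag/query",
    json={
        "userQuery": "How many cars were detected and where are they?",
        "searchType": "class",      # or 'image' / 'object', per the docstring above
        "class_name": "car",
        "n_results": 5,
        # "api_key": "sk-...",      # optional; otherwise the server's OPENAI_API_KEY is used
    },
)
print(resp.status_code)
print(resp.json())  # expected keys: answer, retrieved, model, latency_sec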
requirements.txt
CHANGED
@@ -40,3 +40,7 @@ pysqlite3-binary>=0.5.0
 
 # OpenAI Python SDK
 openai>=1.30.0
+
+# LangChain (RAG pipeline)
+langchain>=0.2.6
+langchain-openai>=0.1.16
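The commit message also mentions verifying the .venv; a quick sanity check, run with the project's venv interpreter (e.g. .venv/bin/python), is to exercise the same imports that api.py guards with try/except:

# Fails with ImportError if langchain / langchain-openai are missing from the venv.
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

print("LangChain imports OK:",
      ChatOpenAI.__name__, ChatPromptTemplate.__name__, StrOutputParser.__name__)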