Spaces:
Running
Running
Andy Lee
commited on
Commit
·
4d37e51
1
Parent(s):
23ee129
feat: more models, including qwen
Browse files- app.py +104 -25
- config.py +26 -0
- geo_bot.py +32 -12
- hf_chat.py +142 -0
app.py
CHANGED
|
@@ -136,63 +136,138 @@ if start_button:
|
|
| 136 |
# --- Inner agent exploration loop ---
|
| 137 |
history = []
|
| 138 |
final_guess = None
|
|
|
|
| 139 |
|
| 140 |
for step in range(steps_per_sample):
|
| 141 |
step_num = step + 1
|
| 142 |
reasoning_placeholder.info(
|
| 143 |
-
f"Thinking... (Step {step_num}/{steps_per_sample})"
|
| 144 |
)
|
| 145 |
action_placeholder.empty()
|
| 146 |
|
| 147 |
# Observe and label arrows
|
| 148 |
bot.controller.label_arrows_on_screen()
|
| 149 |
screenshot_bytes = bot.controller.take_street_view_screenshot()
|
|
|
|
|
|
|
| 150 |
image_placeholder.image(
|
| 151 |
-
screenshot_bytes,
|
|
|
|
|
|
|
| 152 |
)
|
| 153 |
|
| 154 |
# Update history
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
)
|
| 163 |
|
| 164 |
# Think
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
prompt = AGENT_PROMPT_TEMPLATE.format(
|
| 166 |
remaining_steps=steps_per_sample - step,
|
| 167 |
-
history_text=
|
| 168 |
-
|
| 169 |
-
),
|
| 170 |
-
available_actions=json.dumps(bot.controller.get_available_actions()),
|
| 171 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
message = bot._create_message_with_history(
|
| 173 |
prompt, [h["image_b64"] for h in history]
|
| 174 |
)
|
|
|
|
|
|
|
| 175 |
response = bot.model.invoke(message)
|
| 176 |
decision = bot._parse_agent_response(response)
|
| 177 |
|
| 178 |
if not decision: # Fallback
|
| 179 |
decision = {
|
| 180 |
"action_details": {"action": "PAN_RIGHT"},
|
| 181 |
-
"reasoning": "
|
| 182 |
}
|
| 183 |
|
| 184 |
action = decision.get("action_details", {}).get("action")
|
| 185 |
history[-1]["action"] = action
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
|
|
|
|
|
|
| 189 |
)
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
# Force a GUESS on the last step
|
| 193 |
if step_num == steps_per_sample and action != "GUESS":
|
| 194 |
-
st.warning("Max steps reached. Forcing a GUESS action.")
|
| 195 |
action = "GUESS"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
# Act
|
| 198 |
if action == "GUESS":
|
|
@@ -204,18 +279,22 @@ if start_button:
|
|
| 204 |
final_guess = (lat, lon)
|
| 205 |
else:
|
| 206 |
st.error(
|
| 207 |
-
"GUESS action was missing coordinates. Guess failed for this sample."
|
| 208 |
)
|
| 209 |
break # End exploration for the current sample
|
| 210 |
|
| 211 |
elif action == "MOVE_FORWARD":
|
| 212 |
-
|
|
|
|
| 213 |
elif action == "MOVE_BACKWARD":
|
| 214 |
-
|
|
|
|
| 215 |
elif action == "PAN_LEFT":
|
| 216 |
-
|
|
|
|
| 217 |
elif action == "PAN_RIGHT":
|
| 218 |
-
|
|
|
|
| 219 |
|
| 220 |
time.sleep(1) # A brief pause between steps for better visualization
|
| 221 |
|
|
|
|
| 136 |
# --- Inner agent exploration loop ---
|
| 137 |
history = []
|
| 138 |
final_guess = None
|
| 139 |
+
step_history_container = st.container()
|
| 140 |
|
| 141 |
for step in range(steps_per_sample):
|
| 142 |
step_num = step + 1
|
| 143 |
reasoning_placeholder.info(
|
| 144 |
+
f"🤔 Thinking... (Step {step_num}/{steps_per_sample})"
|
| 145 |
)
|
| 146 |
action_placeholder.empty()
|
| 147 |
|
| 148 |
# Observe and label arrows
|
| 149 |
bot.controller.label_arrows_on_screen()
|
| 150 |
screenshot_bytes = bot.controller.take_street_view_screenshot()
|
| 151 |
+
|
| 152 |
+
# Current view
|
| 153 |
image_placeholder.image(
|
| 154 |
+
screenshot_bytes,
|
| 155 |
+
caption=f"🔍 Step {step_num} - What AI Sees Now",
|
| 156 |
+
use_column_width=True,
|
| 157 |
)
|
| 158 |
|
| 159 |
# Update history
|
| 160 |
+
current_step_data = {
|
| 161 |
+
"image_b64": bot.pil_to_base64(Image.open(BytesIO(screenshot_bytes))),
|
| 162 |
+
"action": "N/A",
|
| 163 |
+
"screenshot_bytes": screenshot_bytes,
|
| 164 |
+
"step_num": step_num,
|
| 165 |
+
}
|
| 166 |
+
history.append(current_step_data)
|
|
|
|
| 167 |
|
| 168 |
# Think
|
| 169 |
+
available_actions = bot.controller.get_available_actions()
|
| 170 |
+
history_text = "\n".join(
|
| 171 |
+
[f"Step {j + 1}: {h['action']}" for j, h in enumerate(history[:-1])]
|
| 172 |
+
)
|
| 173 |
+
if not history_text:
|
| 174 |
+
history_text = "No history yet. This is the first step."
|
| 175 |
+
|
| 176 |
prompt = AGENT_PROMPT_TEMPLATE.format(
|
| 177 |
remaining_steps=steps_per_sample - step,
|
| 178 |
+
history_text=history_text,
|
| 179 |
+
available_actions=json.dumps(available_actions),
|
|
|
|
|
|
|
| 180 |
)
|
| 181 |
+
|
| 182 |
+
# Show what AI is considering
|
| 183 |
+
with reasoning_placeholder:
|
| 184 |
+
st.info("🧠 **AI is analyzing the situation...**")
|
| 185 |
+
with st.expander("🔍 Available Actions", expanded=False):
|
| 186 |
+
st.json(available_actions)
|
| 187 |
+
with st.expander("📝 Context Being Considered", expanded=False):
|
| 188 |
+
st.text_area(
|
| 189 |
+
"History Context:", history_text, height=100, disabled=True
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
message = bot._create_message_with_history(
|
| 193 |
prompt, [h["image_b64"] for h in history]
|
| 194 |
)
|
| 195 |
+
|
| 196 |
+
# Get AI response
|
| 197 |
response = bot.model.invoke(message)
|
| 198 |
decision = bot._parse_agent_response(response)
|
| 199 |
|
| 200 |
if not decision: # Fallback
|
| 201 |
decision = {
|
| 202 |
"action_details": {"action": "PAN_RIGHT"},
|
| 203 |
+
"reasoning": "⚠️ Response parsing failed. Using default recovery action.",
|
| 204 |
}
|
| 205 |
|
| 206 |
action = decision.get("action_details", {}).get("action")
|
| 207 |
history[-1]["action"] = action
|
| 208 |
+
history[-1]["reasoning"] = decision.get("reasoning", "N/A")
|
| 209 |
+
history[-1]["raw_response"] = (
|
| 210 |
+
response.content[:500] + "..."
|
| 211 |
+
if len(response.content) > 500
|
| 212 |
+
else response.content
|
| 213 |
)
|
| 214 |
+
|
| 215 |
+
# Display AI's decision process
|
| 216 |
+
reasoning_placeholder.success("✅ **AI Decision Made!**")
|
| 217 |
+
|
| 218 |
+
with action_placeholder:
|
| 219 |
+
st.success(f"🎯 **AI Action:** `{action}`")
|
| 220 |
+
|
| 221 |
+
# Detailed reasoning display
|
| 222 |
+
with st.expander("🧠 AI's Detailed Thinking Process", expanded=True):
|
| 223 |
+
col_reason, col_raw = st.columns([2, 1])
|
| 224 |
+
|
| 225 |
+
with col_reason:
|
| 226 |
+
st.markdown("**🤔 AI's Reasoning:**")
|
| 227 |
+
st.info(decision.get("reasoning", "N/A"))
|
| 228 |
+
|
| 229 |
+
if action == "GUESS":
|
| 230 |
+
lat = decision.get("action_details", {}).get("lat")
|
| 231 |
+
lon = decision.get("action_details", {}).get("lon")
|
| 232 |
+
if lat and lon:
|
| 233 |
+
st.success(f"📍 **Final Guess:** {lat:.4f}, {lon:.4f}")
|
| 234 |
+
|
| 235 |
+
with col_raw:
|
| 236 |
+
st.markdown("**🔤 Raw AI Response:**")
|
| 237 |
+
st.text_area(
|
| 238 |
+
"Full Response:",
|
| 239 |
+
history[-1]["raw_response"],
|
| 240 |
+
height=200,
|
| 241 |
+
disabled=True,
|
| 242 |
+
key=f"raw_response_{step_num}",
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
# Store step in history display
|
| 246 |
+
with step_history_container:
|
| 247 |
+
with st.expander(f"📚 Step {step_num} History", expanded=False):
|
| 248 |
+
hist_col1, hist_col2 = st.columns([1, 2])
|
| 249 |
+
with hist_col1:
|
| 250 |
+
st.image(
|
| 251 |
+
screenshot_bytes, caption=f"Step {step_num} View", width=200
|
| 252 |
+
)
|
| 253 |
+
with hist_col2:
|
| 254 |
+
st.write(f"**Action:** {action}")
|
| 255 |
+
st.write(
|
| 256 |
+
f"**Reasoning:** {decision.get('reasoning', 'N/A')[:150]}..."
|
| 257 |
+
)
|
| 258 |
|
| 259 |
# Force a GUESS on the last step
|
| 260 |
if step_num == steps_per_sample and action != "GUESS":
|
| 261 |
+
st.warning("⏰ Max steps reached. Forcing a GUESS action.")
|
| 262 |
action = "GUESS"
|
| 263 |
+
# Force coordinates if missing
|
| 264 |
+
if not decision.get("action_details", {}).get("lat"):
|
| 265 |
+
st.error("❌ AI didn't provide coordinates. Using fallback guess.")
|
| 266 |
+
decision["action_details"] = {
|
| 267 |
+
"action": "GUESS",
|
| 268 |
+
"lat": 0.0,
|
| 269 |
+
"lon": 0.0,
|
| 270 |
+
}
|
| 271 |
|
| 272 |
# Act
|
| 273 |
if action == "GUESS":
|
|
|
|
| 279 |
final_guess = (lat, lon)
|
| 280 |
else:
|
| 281 |
st.error(
|
| 282 |
+
"❌ GUESS action was missing coordinates. Guess failed for this sample."
|
| 283 |
)
|
| 284 |
break # End exploration for the current sample
|
| 285 |
|
| 286 |
elif action == "MOVE_FORWARD":
|
| 287 |
+
with st.spinner("🚶 Moving forward..."):
|
| 288 |
+
bot.controller.move("forward")
|
| 289 |
elif action == "MOVE_BACKWARD":
|
| 290 |
+
with st.spinner("🔄 Moving backward..."):
|
| 291 |
+
bot.controller.move("backward")
|
| 292 |
elif action == "PAN_LEFT":
|
| 293 |
+
with st.spinner("⬅️ Panning left..."):
|
| 294 |
+
bot.controller.pan_view("left")
|
| 295 |
elif action == "PAN_RIGHT":
|
| 296 |
+
with st.spinner("➡️ Panning right..."):
|
| 297 |
+
bot.controller.pan_view("right")
|
| 298 |
|
| 299 |
time.sleep(1) # A brief pause between steps for better visualization
|
| 300 |
|
config.py
CHANGED
|
@@ -31,18 +31,44 @@ MODELS_CONFIG = {
|
|
| 31 |
"gpt-4o": {
|
| 32 |
"class": "ChatOpenAI",
|
| 33 |
"model_name": "gpt-4o",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
},
|
| 35 |
"claude-3.5-sonnet": {
|
| 36 |
"class": "ChatAnthropic",
|
| 37 |
"model_name": "claude-3-5-sonnet-20240620",
|
|
|
|
|
|
|
| 38 |
},
|
| 39 |
"gemini-1.5-pro": {
|
| 40 |
"class": "ChatGoogleGenerativeAI",
|
| 41 |
"model_name": "gemini-1.5-pro-latest",
|
|
|
|
|
|
|
| 42 |
},
|
| 43 |
"gemini-2.5-pro": {
|
| 44 |
"class": "ChatGoogleGenerativeAI",
|
| 45 |
"model_name": "gemini-2.5-pro-preview-06-05",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
},
|
| 47 |
}
|
| 48 |
|
|
|
|
| 31 |
"gpt-4o": {
|
| 32 |
"class": "ChatOpenAI",
|
| 33 |
"model_name": "gpt-4o",
|
| 34 |
+
"api_key_env": "OPENAI_API_KEY",
|
| 35 |
+
"description": "OpenAI GPT-4o",
|
| 36 |
+
},
|
| 37 |
+
"gpt-4o-mini": {
|
| 38 |
+
"class": "ChatOpenAI",
|
| 39 |
+
"model_name": "gpt-4o-mini",
|
| 40 |
+
"api_key_env": "OPENAI_API_KEY",
|
| 41 |
+
"description": "OpenAI GPT-4o Mini (cheaper)",
|
| 42 |
},
|
| 43 |
"claude-3.5-sonnet": {
|
| 44 |
"class": "ChatAnthropic",
|
| 45 |
"model_name": "claude-3-5-sonnet-20240620",
|
| 46 |
+
"api_key_env": "ANTHROPIC_API_KEY",
|
| 47 |
+
"description": "Anthropic Claude 3.5 Sonnet",
|
| 48 |
},
|
| 49 |
"gemini-1.5-pro": {
|
| 50 |
"class": "ChatGoogleGenerativeAI",
|
| 51 |
"model_name": "gemini-1.5-pro-latest",
|
| 52 |
+
"api_key_env": "GOOGLE_API_KEY",
|
| 53 |
+
"description": "Google Gemini 1.5 Pro",
|
| 54 |
},
|
| 55 |
"gemini-2.5-pro": {
|
| 56 |
"class": "ChatGoogleGenerativeAI",
|
| 57 |
"model_name": "gemini-2.5-pro-preview-06-05",
|
| 58 |
+
"api_key_env": "GOOGLE_API_KEY",
|
| 59 |
+
"description": "Google Gemini 2.5 Pro",
|
| 60 |
+
},
|
| 61 |
+
"qwen2-vl-72b": {
|
| 62 |
+
"class": "HuggingFaceChat",
|
| 63 |
+
"model_name": "Qwen/Qwen2-VL-72B-Instruct",
|
| 64 |
+
"api_key_env": "HUGGINGFACE_API_KEY",
|
| 65 |
+
"description": "Qwen2-VL 72B (via HF Inference API)",
|
| 66 |
+
},
|
| 67 |
+
"qwen2-vl-7b": {
|
| 68 |
+
"class": "HuggingFaceChat",
|
| 69 |
+
"model_name": "Qwen/Qwen2-VL-7B-Instruct",
|
| 70 |
+
"api_key_env": "HUGGINGFACE_API_KEY",
|
| 71 |
+
"description": "Qwen2-VL 7B (via HF Inference API)",
|
| 72 |
},
|
| 73 |
}
|
| 74 |
|
geo_bot.py
CHANGED
|
@@ -11,10 +11,11 @@ from langchain_openai import ChatOpenAI
|
|
| 11 |
from langchain_anthropic import ChatAnthropic
|
| 12 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 13 |
|
|
|
|
|
|
|
| 14 |
from mapcrunch_controller import MapCrunchController
|
| 15 |
|
| 16 |
# The "Golden" Prompt (v6): Combines clear mechanics with robust strategic principles.
|
| 17 |
-
|
| 18 |
AGENT_PROMPT_TEMPLATE = """
|
| 19 |
**Mission:** You are an expert geo-location agent. Your goal is to find clues to determine your location within a limited number of steps.
|
| 20 |
|
|
@@ -68,11 +69,20 @@ class GeoBot:
|
|
| 68 |
):
|
| 69 |
# Initialize model with temperature parameter
|
| 70 |
model_kwargs = {
|
| 71 |
-
"model": model_name,
|
| 72 |
"temperature": temperature,
|
| 73 |
}
|
| 74 |
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
self.model_name = model_name
|
| 77 |
self.temperature = temperature
|
| 78 |
self.use_selenium = use_selenium
|
|
@@ -90,6 +100,7 @@ class GeoBot:
|
|
| 90 |
) -> List[HumanMessage]:
|
| 91 |
"""Creates a message for the LLM that includes text and a sequence of images."""
|
| 92 |
content = [{"type": "text", "text": prompt}]
|
|
|
|
| 93 |
# Add the JSON format instructions right after the main prompt text
|
| 94 |
content.append(
|
| 95 |
{
|
|
@@ -145,7 +156,6 @@ class GeoBot:
|
|
| 145 |
print(f"\n--- Step {max_steps - step + 1}/{max_steps} ---")
|
| 146 |
|
| 147 |
self.controller.setup_clean_environment()
|
| 148 |
-
|
| 149 |
self.controller.label_arrows_on_screen()
|
| 150 |
|
| 151 |
screenshot_bytes = self.controller.take_street_view_screenshot()
|
|
@@ -178,17 +188,22 @@ class GeoBot:
|
|
| 178 |
available_actions=json.dumps(available_actions),
|
| 179 |
)
|
| 180 |
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
if not decision:
|
| 187 |
print(
|
| 188 |
-
"Response parsing failed. Using default recovery action: PAN_RIGHT."
|
| 189 |
)
|
| 190 |
decision = {
|
| 191 |
-
"reasoning": "Recovery due to parsing failure.",
|
| 192 |
"action_details": {"action": "PAN_RIGHT"},
|
| 193 |
}
|
| 194 |
|
|
@@ -219,8 +234,13 @@ class GeoBot:
|
|
| 219 |
def analyze_image(self, image: Image.Image) -> Optional[Tuple[float, float]]:
|
| 220 |
image_b64 = self.pil_to_base64(image)
|
| 221 |
message = self._create_llm_message(BENCHMARK_PROMPT, image_b64)
|
| 222 |
-
|
| 223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
content = response.content.strip()
|
| 226 |
last_line = ""
|
|
|
|
| 11 |
from langchain_anthropic import ChatAnthropic
|
| 12 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 13 |
|
| 14 |
+
from hf_chat import HuggingFaceChat
|
| 15 |
+
|
| 16 |
from mapcrunch_controller import MapCrunchController
|
| 17 |
|
| 18 |
# The "Golden" Prompt (v6): Combines clear mechanics with robust strategic principles.
|
|
|
|
| 19 |
AGENT_PROMPT_TEMPLATE = """
|
| 20 |
**Mission:** You are an expert geo-location agent. Your goal is to find clues to determine your location within a limited number of steps.
|
| 21 |
|
|
|
|
| 69 |
):
|
| 70 |
# Initialize model with temperature parameter
|
| 71 |
model_kwargs = {
|
|
|
|
| 72 |
"temperature": temperature,
|
| 73 |
}
|
| 74 |
|
| 75 |
+
# Handle different model types
|
| 76 |
+
if model == HuggingFaceChat and HuggingFaceChat is not None:
|
| 77 |
+
model_kwargs["model"] = model_name
|
| 78 |
+
else:
|
| 79 |
+
model_kwargs["model"] = model_name
|
| 80 |
+
|
| 81 |
+
try:
|
| 82 |
+
self.model = model(**model_kwargs)
|
| 83 |
+
except Exception as e:
|
| 84 |
+
raise ValueError(f"Failed to initialize model {model_name}: {e}")
|
| 85 |
+
|
| 86 |
self.model_name = model_name
|
| 87 |
self.temperature = temperature
|
| 88 |
self.use_selenium = use_selenium
|
|
|
|
| 100 |
) -> List[HumanMessage]:
|
| 101 |
"""Creates a message for the LLM that includes text and a sequence of images."""
|
| 102 |
content = [{"type": "text", "text": prompt}]
|
| 103 |
+
|
| 104 |
# Add the JSON format instructions right after the main prompt text
|
| 105 |
content.append(
|
| 106 |
{
|
|
|
|
| 156 |
print(f"\n--- Step {max_steps - step + 1}/{max_steps} ---")
|
| 157 |
|
| 158 |
self.controller.setup_clean_environment()
|
|
|
|
| 159 |
self.controller.label_arrows_on_screen()
|
| 160 |
|
| 161 |
screenshot_bytes = self.controller.take_street_view_screenshot()
|
|
|
|
| 188 |
available_actions=json.dumps(available_actions),
|
| 189 |
)
|
| 190 |
|
| 191 |
+
try:
|
| 192 |
+
message = self._create_message_with_history(
|
| 193 |
+
prompt, image_b64_for_prompt
|
| 194 |
+
)
|
| 195 |
+
response = self.model.invoke(message)
|
| 196 |
+
decision = self._parse_agent_response(response)
|
| 197 |
+
except Exception as e:
|
| 198 |
+
print(f"Error during model invocation: {e}")
|
| 199 |
+
decision = None
|
| 200 |
|
| 201 |
if not decision:
|
| 202 |
print(
|
| 203 |
+
"Response parsing failed or model error. Using default recovery action: PAN_RIGHT."
|
| 204 |
)
|
| 205 |
decision = {
|
| 206 |
+
"reasoning": "Recovery due to parsing failure or model error.",
|
| 207 |
"action_details": {"action": "PAN_RIGHT"},
|
| 208 |
}
|
| 209 |
|
|
|
|
| 234 |
def analyze_image(self, image: Image.Image) -> Optional[Tuple[float, float]]:
|
| 235 |
image_b64 = self.pil_to_base64(image)
|
| 236 |
message = self._create_llm_message(BENCHMARK_PROMPT, image_b64)
|
| 237 |
+
|
| 238 |
+
try:
|
| 239 |
+
response = self.model.invoke(message)
|
| 240 |
+
print(f"\nLLM Response:\n{response.content}")
|
| 241 |
+
except Exception as e:
|
| 242 |
+
print(f"Error during image analysis: {e}")
|
| 243 |
+
return None
|
| 244 |
|
| 245 |
content = response.content.strip()
|
| 246 |
last_line = ""
|
hf_chat.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HuggingFace Chat Model Wrapper for vision models like Qwen2-VL
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import base64
|
| 7 |
+
import requests
|
| 8 |
+
from typing import List, Dict, Any, Optional
|
| 9 |
+
from langchain_core.messages import BaseMessage, HumanMessage
|
| 10 |
+
from langchain_core.language_models.chat_models import BaseChatModel
|
| 11 |
+
from langchain_core.outputs import ChatResult, ChatGeneration
|
| 12 |
+
from pydantic import Field
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class HuggingFaceChat(BaseChatModel):
|
| 16 |
+
"""Chat model wrapper for HuggingFace Inference API"""
|
| 17 |
+
|
| 18 |
+
model: str = Field(description="HuggingFace model name")
|
| 19 |
+
temperature: float = Field(default=0.0, description="Temperature for sampling")
|
| 20 |
+
max_tokens: int = Field(default=1000, description="Max tokens to generate")
|
| 21 |
+
api_token: Optional[str] = Field(default=None, description="HF API token")
|
| 22 |
+
|
| 23 |
+
def __init__(self, model: str, temperature: float = 0.0, **kwargs):
|
| 24 |
+
api_token = kwargs.get("api_token") or os.getenv("HUGGINGFACE_API_KEY")
|
| 25 |
+
if not api_token:
|
| 26 |
+
raise ValueError("HUGGINGFACE_API_KEY environment variable is required")
|
| 27 |
+
|
| 28 |
+
super().__init__(
|
| 29 |
+
model=model, temperature=temperature, api_token=api_token, **kwargs
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
@property
|
| 33 |
+
def _llm_type(self) -> str:
|
| 34 |
+
return "huggingface_chat"
|
| 35 |
+
|
| 36 |
+
def _format_message_for_hf(self, message: HumanMessage) -> Dict[str, Any]:
|
| 37 |
+
"""Convert LangChain message to HuggingFace format"""
|
| 38 |
+
if isinstance(message.content, str):
|
| 39 |
+
return {"role": "user", "content": message.content}
|
| 40 |
+
|
| 41 |
+
# Handle multi-modal content (text + images)
|
| 42 |
+
formatted_content = []
|
| 43 |
+
for item in message.content:
|
| 44 |
+
if item["type"] == "text":
|
| 45 |
+
formatted_content.append({"type": "text", "text": item["text"]})
|
| 46 |
+
elif item["type"] == "image_url":
|
| 47 |
+
# Extract base64 data from data URL
|
| 48 |
+
image_url = item["image_url"]["url"]
|
| 49 |
+
if image_url.startswith("data:image"):
|
| 50 |
+
# Extract base64 data
|
| 51 |
+
base64_data = image_url.split(",")[1]
|
| 52 |
+
formatted_content.append({"type": "image", "image": base64_data})
|
| 53 |
+
|
| 54 |
+
return {"role": "user", "content": formatted_content}
|
| 55 |
+
|
| 56 |
+
def _generate(self, messages: List[BaseMessage], **kwargs) -> ChatResult:
|
| 57 |
+
"""Generate response using HuggingFace Inference API"""
|
| 58 |
+
|
| 59 |
+
# Format messages for HF API
|
| 60 |
+
formatted_messages = []
|
| 61 |
+
for msg in messages:
|
| 62 |
+
if isinstance(msg, HumanMessage):
|
| 63 |
+
formatted_messages.append(self._format_message_for_hf(msg))
|
| 64 |
+
|
| 65 |
+
# Prepare API request
|
| 66 |
+
api_url = f"https://api-inference.huggingface.co/models/{self.model}/v1/chat/completions"
|
| 67 |
+
headers = {
|
| 68 |
+
"Authorization": f"Bearer {self.api_token}",
|
| 69 |
+
"Content-Type": "application/json",
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
payload = {
|
| 73 |
+
"model": self.model,
|
| 74 |
+
"messages": formatted_messages,
|
| 75 |
+
"temperature": self.temperature,
|
| 76 |
+
"max_tokens": self.max_tokens,
|
| 77 |
+
"stream": False,
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
try:
|
| 81 |
+
response = requests.post(api_url, headers=headers, json=payload, timeout=60)
|
| 82 |
+
response.raise_for_status()
|
| 83 |
+
|
| 84 |
+
result = response.json()
|
| 85 |
+
content = result["choices"][0]["message"]["content"]
|
| 86 |
+
|
| 87 |
+
return ChatResult(
|
| 88 |
+
generations=[ChatGeneration(message=HumanMessage(content=content))]
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
except requests.exceptions.RequestException as e:
|
| 92 |
+
# Fallback to simple text-only API if chat completions fail
|
| 93 |
+
return self._fallback_generate(messages, **kwargs)
|
| 94 |
+
|
| 95 |
+
def _fallback_generate(self, messages: List[BaseMessage], **kwargs) -> ChatResult:
|
| 96 |
+
"""Fallback to simple HF Inference API"""
|
| 97 |
+
try:
|
| 98 |
+
# Use simple inference API as fallback
|
| 99 |
+
api_url = f"https://api-inference.huggingface.co/models/{self.model}"
|
| 100 |
+
headers = {
|
| 101 |
+
"Authorization": f"Bearer {self.api_token}",
|
| 102 |
+
"Content-Type": "application/json",
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
# Extract text content only for fallback
|
| 106 |
+
text_content = ""
|
| 107 |
+
for msg in messages:
|
| 108 |
+
if isinstance(msg, HumanMessage):
|
| 109 |
+
if isinstance(msg.content, str):
|
| 110 |
+
text_content += msg.content
|
| 111 |
+
else:
|
| 112 |
+
for item in msg.content:
|
| 113 |
+
if item["type"] == "text":
|
| 114 |
+
text_content += item["text"] + "\n"
|
| 115 |
+
|
| 116 |
+
payload = {
|
| 117 |
+
"inputs": text_content,
|
| 118 |
+
"parameters": {
|
| 119 |
+
"temperature": self.temperature,
|
| 120 |
+
"max_new_tokens": self.max_tokens,
|
| 121 |
+
},
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
response = requests.post(api_url, headers=headers, json=payload, timeout=60)
|
| 125 |
+
response.raise_for_status()
|
| 126 |
+
|
| 127 |
+
result = response.json()
|
| 128 |
+
if isinstance(result, list) and len(result) > 0:
|
| 129 |
+
content = result[0].get("generated_text", "No response generated")
|
| 130 |
+
else:
|
| 131 |
+
content = "Error: Invalid response format"
|
| 132 |
+
|
| 133 |
+
return ChatResult(
|
| 134 |
+
generations=[ChatGeneration(message=HumanMessage(content=content))]
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
except Exception as e:
|
| 138 |
+
# Last resort fallback
|
| 139 |
+
error_msg = f"HuggingFace API Error: {str(e)}. Please check your API key and model availability."
|
| 140 |
+
return ChatResult(
|
| 141 |
+
generations=[ChatGeneration(message=HumanMessage(content=error_msg))]
|
| 142 |
+
)
|