# gradio.chat.app-HFIPs / mcp_client.py
# ysharma's picture
# ysharma HF Staff
# Create mcp_client.py
# f0735e5 verified
# raw
# history blame
# 14.4 kB
"""
MCP Client implementation for Universal MCP Client
"""
import asyncio
import json
import re
import base64
from typing import Dict, Optional, Tuple
import anthropic
import logging
import traceback
# Import the proper MCP client components
from mcp import ClientSession
from mcp.client.sse import sse_client
from config import MCPServerConfig, AppConfig, HTTPX_AVAILABLE
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class UniversalMCPClient:
    """Universal MCP Client for connecting to various MCP servers.

    Servers are contacted over SSE at ``<base_url>/gradio_api/mcp/sse``.
    Servers whose connection test succeeds are stored in ``self.servers``
    keyed by ``config.name``. An Anthropic client is created at construction
    time when ``AppConfig.ANTHROPIC_API_KEY`` is set; otherwise
    ``self.anthropic_client`` stays ``None``.
    """

    # Path appended to a server's base URL to reach its MCP SSE endpoint.
    _MCP_SSE_PATH = "/gradio_api/mcp/sse"
    # Endpoint spellings removed from user-supplied URLs. Trailing slashes are
    # stripped first, so the "/gradio_api/mcp/" spelling is covered as well.
    _MCP_ENDPOINT_SUFFIXES = ("/gradio_api/mcp/sse", "/gradio_api/mcp")
    # Keys under which a JSON result may nest a ``{"url": ...}`` object.
    _MEDIA_KEYS = ("image", "audio", "video")
    # Leading base64 characters of known file headers and their MIME types.
    # "UklGR" is the RIFF container header; it is assumed to be WebP here
    # because this check targets images -- TODO confirm no WAV/AVI payloads.
    _BASE64_SIGNATURES = (
        ("iVBORw0KGgoAAAANSUhEU", "image/png"),
        ("/9j/", "image/jpeg"),
        ("UklGR", "image/webp"),
    )

    def __init__(self):
        """Create an empty server registry and, if configured, an Anthropic client."""
        self.servers: Dict[str, MCPServerConfig] = {}
        self.anthropic_client = None

        # Initialize Anthropic client if API key is available
        if AppConfig.ANTHROPIC_API_KEY:
            self.anthropic_client = anthropic.Anthropic(
                api_key=AppConfig.ANTHROPIC_API_KEY
            )
            logger.info("βœ… Anthropic client initialized")
        else:
            logger.warning("⚠️ ANTHROPIC_API_KEY not found")

    @classmethod
    def _normalize_base_url(cls, url: str) -> str:
        """Return *url* without surrounding whitespace, trailing slashes or MCP endpoint path."""
        # Strip trailing slashes BEFORE matching the endpoint suffix; otherwise
        # ".../gradio_api/mcp/sse/" would keep its endpoint and get a second
        # one appended by the caller.
        base = url.strip().rstrip("/")
        for suffix in cls._MCP_ENDPOINT_SUFFIXES:
            if base.endswith(suffix):
                base = base[: -len(suffix)]
                break
        return base.rstrip("/")

    @staticmethod
    def _extract_space_id(base_url: str) -> Optional[str]:
        """Derive a ``user/space`` ID from a ``*.hf.space`` host, or return ``None``."""
        # Look at the host only, so a path component cannot be mistaken for
        # the Space name.
        host = base_url.split("://", 1)[-1].split("/", 1)[0].split(":", 1)[0]
        suffix = ".hf.space"
        if not host.endswith(suffix):
            return None
        subdomain = host[: -len(suffix)]
        if not subdomain:
            return None
        # Format: username-spacename.hf.space. Only the first "-" is turned
        # into "/"; user names that themselves contain "-" cannot be
        # distinguished from the host alone.
        return subdomain.replace("-", "/", 1)

    @classmethod
    def _find_media_url(cls, payload: object) -> Optional[Tuple[Optional[str], str]]:
        """Find a media URL inside one decoded JSON object.

        Returns ``(media_type, url)`` for a nested ``{"image"|"audio"|"video":
        {"url": ...}}`` entry, ``(None, url)`` for a top-level ``"url"`` key,
        or ``None`` when *payload* is not a dict or holds no string URL.
        """
        if not isinstance(payload, dict):
            return None
        for media_type in cls._MEDIA_KEYS:
            nested = payload.get(media_type)
            if isinstance(nested, dict) and isinstance(nested.get("url"), str):
                return media_type, nested["url"]
        direct = payload.get("url")
        if isinstance(direct, str):
            return None, direct
        return None

    @classmethod
    def _guess_base64_mime(cls, text: str) -> Optional[str]:
        """Return the MIME type whose base64 header signature prefixes *text*, if any."""
        for signature, mime_type in cls._BASE64_SIGNATURES:
            if text.startswith(signature):
                return mime_type
        return None

    async def add_server_async(self, config: "MCPServerConfig") -> Tuple[bool, str]:
        """Add an MCP server using pure MCP protocol.

        Normalizes ``config.url`` to ``<base>/gradio_api/mcp/sse``, sets
        ``config.space_id`` for ``*.hf.space`` hosts, then tests the
        connection. The server is registered in ``self.servers`` only when the
        test succeeds. Note that *config* is mutated in place before the test
        runs, so the rewritten URL remains on it even when the test fails.

        Returns:
            ``(success, message)`` where *message* is a human-readable status.
            Exceptions are logged and reported through the tuple, not raised.
        """
        try:
            logger.info(f"πŸ”§ Adding MCP server: {config.name} at {config.url}")

            # Clean and validate URL - handle various input formats
            original_url = config.url.strip()
            base_url = self._normalize_base_url(original_url)

            # Construct proper MCP URL
            mcp_url = f"{base_url}{self._MCP_SSE_PATH}"

            logger.info(f"πŸ”§ Original URL: {original_url}")
            logger.info(f"πŸ”§ Base URL: {base_url}")
            logger.info(f"πŸ”§ MCP URL: {mcp_url}")

            # Extract space ID if it's a HuggingFace space
            space_id = self._extract_space_id(base_url)
            if space_id is not None:
                config.space_id = space_id
                logger.info(f"πŸ“ Detected HF Space ID: {config.space_id}")

            # Update config with proper MCP URL
            config.url = mcp_url

            # Test MCP connection
            success, message = await self._test_mcp_connection(config)

            if success:
                self.servers[config.name] = config
                logger.info(f"βœ… MCP Server {config.name} added successfully")
                return True, f"βœ… Successfully added MCP server: {config.name}\n{message}"

            logger.error(f"❌ Failed to connect to MCP server {config.name}: {message}")
            return False, f"❌ Failed to add server: {config.name}\n{message}"

        except Exception as e:
            # Boundary handler: callers receive failures via the return tuple.
            error_msg = f"Failed to add server {config.name}: {str(e)}"
            logger.error(error_msg)
            logger.error(traceback.format_exc())
            return False, f"❌ {error_msg}"

    async def _test_mcp_connection(self, config: "MCPServerConfig") -> Tuple[bool, str]:
        """Test MCP server connection with detailed debugging.

        Opens an SSE session to ``config.url``, initializes it and lists the
        server's tools.

        Returns:
            ``(True, summary)`` listing the tools when at least one exists;
            ``(False, reason)`` when no tools are found, on timeout, or on any
            other exception (which is logged with its traceback).
        """
        try:
            logger.info(f"πŸ” Testing MCP connection to {config.url}")

            async with sse_client(config.url, timeout=AppConfig.MCP_TIMEOUT_SECONDS) as (read_stream, write_stream):
                async with ClientSession(read_stream, write_stream) as session:
                    # Initialize MCP session
                    logger.info("πŸ”§ Initializing MCP session...")
                    await session.initialize()

                    # List available tools
                    logger.info("πŸ“‹ Listing available tools...")
                    tools = await session.list_tools()

                    tool_info = []
                    for tool in tools.tools:
                        tool_info.append(f" - {tool.name}: {tool.description}")
                        logger.info(f" πŸ“ Tool: {tool.name}")
                        logger.info(f" Description: {tool.description}")
                        if getattr(tool, "inputSchema", None):
                            logger.info(f" Input Schema: {tool.inputSchema}")

                    if not tools.tools:
                        return False, "No tools found on MCP server"

                    message = f"Connected successfully!\nFound {len(tools.tools)} tools:\n" + "\n".join(tool_info)
                    return True, message

        except asyncio.TimeoutError:
            return False, "Connection timeout - server may be sleeping or unreachable"
        except Exception as e:
            logger.error(f"MCP connection failed: {e}")
            logger.error(traceback.format_exc())
            return False, f"Connection failed: {str(e)}"

    def _extract_media_from_mcp_response(self, result_text: str, config: "MCPServerConfig") -> Optional[str]:
        """Enhanced media extraction from MCP responses.

        Checks, in order: a JSON object (or the first element of a JSON
        array) carrying a media URL, a ``data:`` URL, raw base64 data with a
        known file header, an ``http(s)`` URL that ``AppConfig.is_media_file``
        accepts, and finally a media file path, which is mapped to
        ``<base_url>/file=<basename>``.

        Args:
            result_text: Raw tool result. Non-string values yield ``None``.
            config: Server config; its ``url`` minus the MCP endpoint is used
                as the base for relative references.

        Returns:
            A URL or data URL for the detected media, or ``None``.
        """
        if not isinstance(result_text, str):
            logger.info(f"πŸ” Non-string result: {type(result_text)}")
            return None

        base_url = config.url.replace("/gradio_api/mcp/sse", "")
        logger.info(f"πŸ” Processing MCP result for media: {result_text[:300]}...")
        logger.info(f"πŸ” Base URL: {base_url}")

        # Normalize once so every check below sees the same text.
        text = result_text.strip()

        # 1. Try to parse as JSON (most Gradio MCP servers return structured data)
        if text.startswith(("[", "{")):
            logger.info("πŸ” Attempting JSON parse...")
            try:
                data = json.loads(text)
            except json.JSONDecodeError:
                logger.info("πŸ” Not valid JSON, trying other formats...")
            except Exception as e:
                # Best effort: any other decoding problem falls through to
                # the non-JSON checks below.
                logger.warning(f"πŸ” JSON parsing error: {e}")
            else:
                logger.info(f"πŸ” Parsed JSON structure: {data}")
                item = None
                if isinstance(data, list) and data:
                    # Array format: only the first element is inspected.
                    item = data[0]
                    logger.info(f"πŸ” First array item: {item}")
                elif isinstance(data, dict):
                    item = data
                    logger.info(f"πŸ” Processing dict: {data}")

                found = self._find_media_url(item)
                if found is not None:
                    media_type, url = found
                    if media_type is None:
                        logger.info(f"🎯 Found direct URL: {url}")
                    else:
                        logger.info(f"🎯 Found {media_type} URL: {url}")
                    return self._resolve_media_url(url, base_url)

        # 2. Check for data URLs (base64 encoded media)
        if text.startswith("data:"):
            logger.info("🎯 Found data URL")
            return text

        # 3. Check for base64 image patterns, labelling each with its own
        #    MIME type instead of always claiming PNG.
        mime_type = self._guess_base64_mime(text)
        if mime_type is not None:
            logger.info("🎯 Found base64 image data")
            return f"data:{mime_type};base64,{text}"

        if AppConfig.is_media_file(text):
            # 4. Absolute HTTP(S) media URLs are returned untouched. This must
            #    be tested before the file-path case, which would otherwise
            #    reduce the URL to its basename on this server.
            if text.startswith(("http://", "https://")):
                logger.info(f"🎯 Found HTTP media URL: {text}")
                return text

            # 5. File paths: keep just the filename and build a Gradio file URL
            filename = text.rsplit("/", 1)[-1]
            media_url = f"{base_url}/file={filename}"
            logger.info(f"🎯 Found media file: {media_url}")
            return media_url

        logger.info("❌ No media detected in result")
        return None

    def _resolve_media_url(self, url: str, base_url: str) -> str:
        """Resolve relative URLs to absolute URLs.

        ``http...`` and ``data:`` values are returned unchanged; anything else
        is served through ``<base_url>/file=<url>``.
        """
        if url.startswith(("http", "data:")):
            return url
        return f"{base_url}/file={url}"

    def _convert_file_to_accessible_url(self, file_path: str, base_url: str) -> str:
        """Convert local file path to accessible URL for MCP servers.

        Returns ``<base_url>/file=<basename of file_path>``. If building the
        URL fails, the error is logged and *file_path* is returned unchanged.
        """
        try:
            # Extract filename
            filename = file_path.rsplit("/", 1)[-1]

            # For Gradio MCP servers, we can use the /file= endpoint
            # This assumes the MCP server can access the same file system or we upload it
            accessible_url = f"{base_url}/file={filename}"
            logger.info(f"πŸ”— Converted file path to accessible URL: {accessible_url}")
            return accessible_url
        except Exception as e:
            logger.error(f"Failed to convert file to accessible URL: {e}")
            return file_path  # Fallback to original path

    async def upload_file_to_gradio_server(self, file_path: str, target_server_url: str) -> Optional[str]:
        """Upload a local file to a Gradio server and return the accessible URL.

        Posts the file to ``<base_url>/upload`` where ``base_url`` is
        *target_server_url* without the MCP endpoint.

        Returns:
            ``<base_url>/file=<uploaded path>`` when the server answers 200
            with a non-empty JSON list; ``None`` when httpx is unavailable,
            the status is not 200, the payload has another shape, or any
            exception occurs (all of which are logged).
        """
        if not HTTPX_AVAILABLE:
            logger.error("httpx not available for file upload")
            return None

        try:
            import httpx

            # Remove MCP endpoint to get base URL
            base_url = target_server_url.replace("/gradio_api/mcp/sse", "")
            upload_url = f"{base_url}/upload"

            # Read the file
            with open(file_path, "rb") as f:
                file_content = f.read()

            # Get filename
            filename = file_path.rsplit("/", 1)[-1]

            # Upload file to Gradio server
            files = {"file": (filename, file_content)}

            async with httpx.AsyncClient() as client:
                response = await client.post(upload_url, files=files, timeout=30.0)

                if response.status_code != 200:
                    logger.warning(f"File upload failed with status {response.status_code}")
                    return None

                # Gradio usually returns the file path/URL in the response
                result = response.json()
                if isinstance(result, list) and result:
                    uploaded_path = result[0]
                    # Convert to accessible URL
                    accessible_url = f"{base_url}/file={uploaded_path}"
                    logger.info(f"πŸ“€ Successfully uploaded file: {accessible_url}")
                    return accessible_url

                # A 200 with an unusable body is reported as such, rather
                # than as "failed with status 200".
                logger.warning(f"File upload returned an unexpected payload: {result!r}")
                return None

        except Exception as e:
            logger.error(f"Failed to upload file to Gradio server: {e}")
            return None

    def _check_file_upload_compatibility(self, config: "MCPServerConfig") -> str:
        """Check if a server likely supports file uploads.

        Classifies ``config.url`` by substring and returns a display label.
        """
        # NOTE(review): add_server_async rewrites every URL to end in
        # "/gradio_api/mcp/sse", so the "gradio" test matches all servers
        # added that way and the two branches after it are not reached for
        # them. Order kept as-is pending a decision on the intended labels.
        if "hf.space" in config.url:
            return "🟑 Hugging Face Space (usually compatible)"
        elif "gradio" in config.url.lower():
            return "🟒 Gradio server (likely compatible)"
        elif "localhost" in config.url or "127.0.0.1" in config.url:
            return "🟒 Local server (file access available)"
        else:
            return "πŸ”΄ Remote server (may need public URLs)"

    def get_server_status(self) -> Dict[str, str]:
        """Get status of all configured servers.

        Returns a dict mapping each registered server name to a status line
        that includes its file-upload compatibility label.
        """
        return {
            name: f"βœ… Connected (MCP Protocol) - {self._check_file_upload_compatibility(server)}"
            for name, server in self.servers.items()
        }