|
import { CallToolResultSchema } from '@modelcontextprotocol/sdk/types.js'; |
|
import type { JsonSchemaType } from 'librechat-data-provider'; |
|
import type { Logger } from 'winston'; |
|
import type * as t from './types/mcp'; |
|
import { formatToolContent } from './parsers'; |
|
import { MCPConnection } from './connection'; |
|
import { CONSTANTS } from './enum'; |
|
|
|
/**
 * Process-wide manager for Model Context Protocol (MCP) server connections.
 *
 * Obtain the shared instance with {@link MCPManager.getInstance}. The manager keeps one
 * `MCPConnection` per successfully initialized server, keyed by server name, and offers
 * helpers to register those servers' tools, invoke a tool, and disconnect.
 */
export class MCPManager {
  private static instance: MCPManager | null = null;

  /**
   * Default per-server timeout (ms) used by {@link MCPManager.initializeMCP}: 30 minutes.
   */
  // NOTE(review): 30 minutes is unusually long for an initial handshake — confirm intent.
  private static readonly DEFAULT_CONNECTION_TIMEOUT_MS = 1800000;

  /** Number of `connect()` attempts made per server before giving up. */
  private static readonly MAX_CONNECT_ATTEMPTS = 3;

  /** Base back-off (ms): after failed attempt `k`, wait `k * base` before retrying. */
  private static readonly RETRY_BASE_DELAY_MS = 2000;

  /** Connected servers, keyed by server name. */
  private readonly connections: Map<string, MCPConnection> = new Map();

  private readonly logger: Logger;

  /**
   * Console-backed fallback used when no winston logger is supplied.
   * Only `error`, `warn`, `info` and `debug` exist on it; the cast hides that the
   * remaining `Logger` members are missing.
   */
  private static getDefaultLogger(): Logger {
    return {
      error: console.error,
      warn: console.warn,
      info: console.info,
      debug: console.debug,
    } as Logger;
  }

  private constructor(logger?: Logger) {
    this.logger = logger || MCPManager.getDefaultLogger();
  }

  /**
   * Returns the shared manager, creating it on first use.
   *
   * @param logger - Only honoured on the call that creates the instance; later calls
   *   return the existing instance and ignore this argument.
   */
  public static getInstance(logger?: Logger): MCPManager {
    if (!MCPManager.instance) {
      MCPManager.instance = new MCPManager(logger);
    }
    return MCPManager.instance;
  }

  /**
   * Connects to every configured server in parallel and registers those that connect.
   *
   * A server that fails (after retries) or exceeds the timeout is logged and skipped.
   * It never prevents the other servers from initializing, and this method does not
   * reject because of it. A summary of successes and failures is logged at the end.
   *
   * @param mcpServers - Server configurations keyed by server name.
   * @param connectionTimeoutMs - Upper bound (ms) on the whole connect/retry sequence
   *   for each server. Defaults to {@link MCPManager.DEFAULT_CONNECTION_TIMEOUT_MS}.
   */
  public async initializeMCP(
    mcpServers: t.MCPServers,
    connectionTimeoutMs: number = MCPManager.DEFAULT_CONNECTION_TIMEOUT_MS,
  ): Promise<void> {
    this.logger.info('[MCP] Initializing servers');

    const entries = Object.entries(mcpServers);
    // Names of servers that ended up connected and registered in `this.connections`.
    const initializedServers = new Set<string>();

    // allSettled: one server's failure must not abort initialization of the others.
    await Promise.allSettled(
      entries.map(async ([serverName, config]) => {
        const connection = new MCPConnection(serverName, config, this.logger);

        connection.on('connectionChange', (state) => {
          this.logger.info(`[MCP][${serverName}] Connection state: ${state}`);
        });

        try {
          await this.connectWithTimeout(connection, serverName, connectionTimeoutMs);
        } catch (error) {
          this.logger.error(`[MCP][${serverName}] Initialization failed`, error);
          return;
        }

        if (connection.isConnected()) {
          initializedServers.add(serverName);
          this.connections.set(serverName, connection);
          // Informational only; it handles its own errors, so a listing failure cannot
          // make an already-registered server look like it failed to initialize.
          await this.logServerTools(connection, serverName);
        }
      }),
    );

    this.logInitializationSummary(
      entries.map(([serverName]) => serverName),
      initializedServers,
    );
  }

  /**
   * Runs {@link MCPManager.initializeServer}, rejecting with `Error('Connection timeout')`
   * if it has not finished within `timeoutMs`.
   *
   * The timer is always cleared once the race settles. Otherwise a pending timeout
   * would keep the Node.js event loop alive for the rest of the window.
   */
  private async connectWithTimeout(
    connection: MCPConnection,
    serverName: string,
    timeoutMs: number,
  ): Promise<void> {
    let timer: ReturnType<typeof setTimeout> | undefined;
    const connectionTimeout = new Promise<never>((_, reject) => {
      timer = setTimeout(() => reject(new Error('Connection timeout')), timeoutMs);
    });

    try {
      await Promise.race([this.initializeServer(connection, serverName), connectionTimeout]);
    } finally {
      clearTimeout(timer);
    }
  }

  /**
   * Logs the server's advertised capabilities and, if it supports tools, their names.
   * Any error here is logged as a warning and swallowed; it does not affect registration.
   */
  private async logServerTools(connection: MCPConnection, serverName: string): Promise<void> {
    try {
      const serverCapabilities = connection.client.getServerCapabilities();
      this.logger.info(
        `[MCP][${serverName}] Capabilities: ${JSON.stringify(serverCapabilities)}`,
      );

      if (!serverCapabilities?.tools) {
        return;
      }

      const { tools } = await connection.client.listTools();
      if (tools.length) {
        this.logger.info(
          `[MCP][${serverName}] Available tools: ${tools.map((tool) => tool.name).join(', ')}`,
        );
      }
    } catch (error) {
      this.logger.warn(`[MCP][${serverName}] Failed to list tools`, error);
    }
  }

  /**
   * Logs aggregate and per-server initialization results.
   *
   * @param serverNames - All configured server names, in configuration order.
   * @param initialized - The subset of `serverNames` that connected successfully.
   */
  private logInitializationSummary(
    serverNames: string[],
    initialized: ReadonlySet<string>,
  ): void {
    const total = serverNames.length;
    const succeeded = initialized.size;
    // Derived from the success set, so the "initialized" and "failed" counts always add up.
    const failed = total - succeeded;

    this.logger.info(`[MCP] Initialized ${succeeded}/${total} server(s)`);

    if (failed > 0) {
      this.logger.warn(`[MCP] ${failed}/${total} server(s) failed to initialize`);
    }

    for (const serverName of serverNames) {
      if (initialized.has(serverName)) {
        this.logger.info(`[MCP][${serverName}] ✓ Initialized`);
      } else {
        this.logger.info(`[MCP][${serverName}] ✗ Failed`);
      }
    }

    if (succeeded === total) {
      this.logger.info('[MCP] All servers initialized successfully');
    } else if (succeeded === 0) {
      this.logger.error('[MCP] No servers initialized');
    }
  }

  /**
   * Calls `connection.connect()` up to {@link MCPManager.MAX_CONNECT_ATTEMPTS} times.
   *
   * Returns as soon as the connection reports `isConnected()`. An attempt counts as
   * failed if `connect()` throws OR if it resolves without reaching the connected state.
   * The latter case previously retried forever without back-off. After failed attempt
   * `k` (when more attempts remain), the method waits `k * RETRY_BASE_DELAY_MS` ms.
   *
   * @throws The last attempt's error (or an `Error` describing the unconnected state)
   *   once every attempt has failed.
   */
  private async initializeServer(connection: MCPConnection, serverName: string): Promise<void> {
    const maxAttempts = MCPManager.MAX_CONNECT_ATTEMPTS;
    let lastError: unknown;

    for (let attempt = 1; attempt <= maxAttempts; attempt++) {
      try {
        await connection.connect();
        if (connection.isConnected()) {
          return;
        }
        lastError = new Error(
          `[MCP][${serverName}] connect() resolved but the connection is not connected`,
        );
      } catch (error) {
        lastError = error;
      }

      if (attempt < maxAttempts) {
        await new Promise((resolve) =>
          setTimeout(resolve, MCPManager.RETRY_BASE_DELAY_MS * attempt),
        );
      }
    }

    this.logger.error(`[MCP][${serverName}] Failed after ${maxAttempts} attempts`);
    throw lastError;
  }

  /** Returns the registered connection for `serverName`, or `undefined` if none. */
  public getConnection(serverName: string): MCPConnection | undefined {
    return this.connections.get(serverName);
  }

  /** Returns the live (mutable) map of registered connections, keyed by server name. */
  public getAllConnections(): Map<string, MCPConnection> {
    return this.connections;
  }

  /**
   * Adds every tool from each connected server to `availableTools`, mutating it in place.
   *
   * Each entry is keyed `<toolName><CONSTANTS.mcp_delimiter><serverName>` and stored in
   * function-calling form (`type: 'function'`). Servers that are not connected are
   * skipped with a warning. A server whose tool fetch throws is logged and skipped.
   */
  public async mapAvailableTools(availableTools: t.LCAvailableTools): Promise<void> {
    for (const [serverName, connection] of this.connections.entries()) {
      try {
        if (connection.isConnected() !== true) {
          this.logger.warn(`Connection ${serverName} is not connected. Skipping tool fetch.`);
          continue;
        }

        const tools = await connection.fetchTools();
        for (const tool of tools) {
          const name = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`;
          availableTools[name] = {
            type: 'function',
            ['function']: {
              name,
              description: tool.description,
              parameters: tool.inputSchema as JsonSchemaType,
            },
          };
        }
      } catch (error) {
        // Previously logged "Not connected" for every error and discarded the cause.
        this.logger.error(`[MCP][${serverName}] Error fetching tools`, error);
      }
    }
  }

  /**
   * Appends one manifest entry per tool from each connected server to `manifestTools`.
   *
   * `pluginKey` uses the same `<toolName><CONSTANTS.mcp_delimiter><serverName>` scheme as
   * {@link MCPManager.mapAvailableTools}. A missing description becomes `''`. Servers that
   * are not connected are skipped. A server whose tool fetch throws is logged and skipped.
   */
  public async loadManifestTools(manifestTools: t.LCToolManifest): Promise<void> {
    for (const [serverName, connection] of this.connections.entries()) {
      try {
        if (connection.isConnected() !== true) {
          this.logger.warn(`Connection ${serverName} is not connected. Skipping tool fetch.`);
          continue;
        }

        const tools = await connection.fetchTools();
        for (const tool of tools) {
          const pluginKey = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`;
          manifestTools.push({
            name: tool.name,
            pluginKey,
            description: tool.description ?? '',
            icon: connection.iconPath,
          });
        }
      } catch (error) {
        this.logger.error(`[MCP][${serverName}] Error fetching tools`, error);
      }
    }
  }

  /**
   * Invokes `toolName` on `serverName` through an MCP `tools/call` request.
   *
   * The response is validated against `CallToolResultSchema` and then formatted for
   * `provider` by `formatToolContent`.
   *
   * @throws Error if no connection is registered for `serverName`. Errors from the
   *   request itself propagate unchanged.
   */
  public async callTool(
    serverName: string,
    toolName: string,
    provider: t.Provider,
    toolArguments?: Record<string, unknown>,
  ): Promise<t.FormattedToolResponse> {
    const connection = this.connections.get(serverName);
    if (!connection) {
      throw new Error(
        `No connection found for server: ${serverName}. Please make sure to use MCP servers available under 'Connected MCP Servers'.`,
      );
    }

    const result = await connection.client.request(
      {
        method: 'tools/call',
        params: {
          name: toolName,
          arguments: toolArguments,
        },
      },
      CallToolResultSchema,
    );

    return formatToolContent(result, provider);
  }

  /**
   * Disconnects and unregisters `serverName`. Does nothing if it is not registered.
   * The entry is removed even if `disconnect()` throws; that error still propagates.
   */
  public async disconnectServer(serverName: string): Promise<void> {
    const connection = this.connections.get(serverName);
    if (!connection) {
      return;
    }

    try {
      await connection.disconnect();
    } finally {
      this.connections.delete(serverName);
    }
  }

  /**
   * Disconnects every registered server in parallel, then clears the registry.
   *
   * Every disconnect is awaited even if some fail, and the registry is always cleared.
   * Each failure is logged. If any disconnect failed, the first failure (in completion
   * order) is rethrown after cleanup, preserving the previous rejection behaviour.
   */
  public async disconnectAll(): Promise<void> {
    const failures: unknown[] = [];

    await Promise.all(
      Array.from(this.connections.entries(), async ([serverName, connection]) => {
        try {
          await connection.disconnect();
        } catch (error) {
          this.logger.error(`[MCP][${serverName}] Error disconnecting`, error);
          failures.push(error);
        }
      }),
    );

    this.connections.clear();

    if (failures.length > 0) {
      throw failures[0];
    }
  }

  /**
   * Disconnects all servers and discards the shared instance, so the next
   * {@link MCPManager.getInstance} call creates a fresh one. The instance is discarded
   * even if a disconnect fails; that error still propagates.
   */
  public static async destroyInstance(): Promise<void> {
    if (!MCPManager.instance) {
      return;
    }

    try {
      await MCPManager.instance.disconnectAll();
    } finally {
      MCPManager.instance = null;
    }
  }
}
|
|