| import { CallToolResultSchema } from '@modelcontextprotocol/sdk/types.js'; | |
| import type { JsonSchemaType } from 'librechat-data-provider'; | |
| import type { Logger } from 'winston'; | |
| import type * as t from './types/mcp'; | |
| import { formatToolContent } from './parsers'; | |
| import { MCPConnection } from './connection'; | |
| import { CONSTANTS } from './enum'; | |
/**
 * Process-wide registry of MCP server connections.
 *
 * Obtain the shared instance with {@link MCPManager.getInstance}. Connections are created by
 * {@link MCPManager.initializeMCP} and torn down by {@link MCPManager.disconnectServer},
 * {@link MCPManager.disconnectAll} or {@link MCPManager.destroyInstance}.
 */
export class MCPManager {
  private static instance: MCPManager | null = null;
  /** Upper bound on the whole connection phase of a single server: 1800000 ms (30 minutes). */
  private static readonly connectionTimeoutMs = 1800000;
  /** Successfully connected servers, keyed by server name. */
  private connections: Map<string, MCPConnection> = new Map();
  private logger: Logger;

  /**
   * Console-backed fallback logger used when no logger is supplied.
   * NOTE: only `error`/`warn`/`info`/`debug` are provided; the cast hides the rest of the
   * winston `Logger` surface, so only those four methods may be called on it.
   */
  private static getDefaultLogger(): Logger {
    return {
      error: console.error,
      warn: console.warn,
      info: console.info,
      debug: console.debug,
    } as Logger;
  }

  private constructor(logger?: Logger) {
    this.logger = logger ?? MCPManager.getDefaultLogger();
  }

  /**
   * Returns the shared manager, creating it on first use.
   *
   * @param logger - Used only when the instance is created; ignored on later calls.
   */
  public static getInstance(logger?: Logger): MCPManager {
    if (!MCPManager.instance) {
      MCPManager.instance = new MCPManager(logger);
    }
    return MCPManager.instance;
  }

  /**
   * Connects to every configured server in parallel and registers the ones that connect.
   *
   * A failure of one server never rejects this method: failures are logged and summarised.
   *
   * @param mcpServers - Server configurations keyed by server name.
   */
  public async initializeMCP(mcpServers: t.MCPServers): Promise<void> {
    this.logger.info('[MCP] Initializing servers');
    const entries = Object.entries(mcpServers);
    // Indices (into `entries`) of servers that ended up connected.
    const initializedServers = new Set<number>();

    const connectionResults = await Promise.allSettled(
      entries.map(async ([serverName, config], i) => {
        const connection = new MCPConnection(serverName, config, this.logger);
        connection.on('connectionChange', (state) => {
          this.logger.info(`[MCP][${serverName}] Connection state: ${state}`);
        });

        let timeoutId: ReturnType<typeof setTimeout> | undefined;
        try {
          const connectionTimeout = new Promise<void>((_, reject) => {
            timeoutId = setTimeout(
              () => reject(new Error('Connection timeout')),
              MCPManager.connectionTimeoutMs,
            );
          });
          const connectionAttempt = this.initializeServer(connection, serverName);
          // NOTE(review): if the timeout wins, `connectionAttempt` keeps running in the background
          // and the connection is never registered or disconnected here.
          await Promise.race([connectionAttempt, connectionTimeout]);
        } catch (error) {
          this.logger.error(`[MCP][${serverName}] Initialization failed`, error);
          throw error;
        } finally {
          // Otherwise the pending 30-minute timer keeps the event loop alive after the race settles.
          clearTimeout(timeoutId);
        }

        if (!connection.isConnected()) {
          return;
        }
        initializedServers.add(i);
        this.connections.set(serverName, connection);
        await this.logServerCapabilities(connection, serverName);
      }),
    );

    const failedConnections = connectionResults.filter(
      (result): result is PromiseRejectedResult => result.status === 'rejected',
    );
    this.logger.info(`[MCP] Initialized ${initializedServers.size}/${entries.length} server(s)`);
    if (failedConnections.length > 0) {
      this.logger.warn(
        `[MCP] ${failedConnections.length}/${entries.length} server(s) failed to initialize`,
      );
    }
    entries.forEach(([serverName], index) => {
      if (initializedServers.has(index)) {
        this.logger.info(`[MCP][${serverName}] ✓ Initialized`);
      } else {
        this.logger.info(`[MCP][${serverName}] ✗ Failed`);
      }
    });
    if (initializedServers.size === entries.length) {
      this.logger.info('[MCP] All servers initialized successfully');
    } else if (initializedServers.size === 0) {
      this.logger.error('[MCP] No servers initialized');
    }
  }

  /**
   * Logs a freshly connected server's capabilities and tool names.
   *
   * Diagnostics only: errors are logged and swallowed so that they cannot mark an already
   * connected server as failed.
   */
  private async logServerCapabilities(
    connection: MCPConnection,
    serverName: string,
  ): Promise<void> {
    try {
      const serverCapabilities = connection.client.getServerCapabilities();
      this.logger.info(
        `[MCP][${serverName}] Capabilities: ${JSON.stringify(serverCapabilities)}`,
      );
      if (!serverCapabilities?.tools) {
        return;
      }
      const tools = await connection.client.listTools();
      if (tools.tools.length) {
        this.logger.info(
          `[MCP][${serverName}] Available tools: ${tools.tools
            .map((tool) => tool.name)
            .join(', ')}`,
        );
      }
    } catch (error) {
      this.logger.warn(`[MCP][${serverName}] Failed to list capabilities or tools`, error);
    }
  }

  /**
   * Calls `connection.connect()` up to 3 times, waiting 2 s, then 4 s between attempts.
   *
   * An attempt counts as failed both when `connect()` throws and when it resolves without
   * the connection reporting itself connected (previously the latter looped forever).
   *
   * @throws The last `connect()` error, or an `Error` if the last attempt resolved unconnected.
   */
  private async initializeServer(connection: MCPConnection, serverName: string): Promise<void> {
    const maxAttempts = 3;
    let lastError: unknown;
    for (let attempt = 1; attempt <= maxAttempts; attempt++) {
      try {
        await connection.connect();
        if (connection.isConnected()) {
          return;
        }
        lastError = new Error(
          `[MCP][${serverName}] connect() resolved but the connection is not established`,
        );
      } catch (error) {
        lastError = error;
      }
      if (attempt < maxAttempts) {
        await new Promise((resolve) => setTimeout(resolve, 2000 * attempt));
      }
    }
    this.logger.error(`[MCP][${serverName}] Failed after ${maxAttempts} attempts`);
    throw lastError;
  }

  /** Returns the registered connection for `serverName`, if any. */
  public getConnection(serverName: string): MCPConnection | undefined {
    return this.connections.get(serverName);
  }

  /** Returns the live internal map; mutating it mutates the manager's registry. */
  public getAllConnections(): Map<string, MCPConnection> {
    return this.connections;
  }

  /**
   * Adds every tool of every connected server to `availableTools` (mutated in place), keyed as
   * `<tool><CONSTANTS.mcp_delimiter><server>`. Servers that are disconnected or fail are skipped.
   */
  public async mapAvailableTools(availableTools: t.LCAvailableTools): Promise<void> {
    for (const [serverName, connection] of this.connections.entries()) {
      try {
        if (connection.isConnected() !== true) {
          this.logger.warn(`[MCP][${serverName}] Not connected, skipping tool fetch`);
          continue;
        }
        const tools = await connection.fetchTools();
        for (const tool of tools) {
          const name = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`;
          availableTools[name] = {
            type: 'function',
            ['function']: {
              name,
              description: tool.description,
              parameters: tool.inputSchema as JsonSchemaType,
            },
          };
        }
      } catch (error) {
        // Previously logged as "Not connected" and discarded the actual cause.
        this.logger.error(`[MCP][${serverName}] Error fetching tools`, error);
      }
    }
  }

  /**
   * Appends a manifest entry for every tool of every connected server to `manifestTools`
   * (mutated in place). Servers that are disconnected or fail are skipped.
   */
  public async loadManifestTools(manifestTools: t.LCToolManifest): Promise<void> {
    for (const [serverName, connection] of this.connections.entries()) {
      try {
        if (connection.isConnected() !== true) {
          this.logger.warn(`[MCP][${serverName}] Not connected, skipping tool fetch`);
          continue;
        }
        const tools = await connection.fetchTools();
        for (const tool of tools) {
          const pluginKey = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`;
          manifestTools.push({
            name: tool.name,
            pluginKey,
            description: tool.description ?? '',
            icon: connection.iconPath,
          });
        }
      } catch (error) {
        this.logger.error(`[MCP][${serverName}] Error fetching tools`, error);
      }
    }
  }

  /**
   * Sends a `tools/call` request to `serverName` and formats the result for `provider`.
   *
   * @throws Error if no connection is registered for `serverName`; request errors propagate.
   */
  public async callTool(
    serverName: string,
    toolName: string,
    provider: t.Provider,
    toolArguments?: Record<string, unknown>,
  ): Promise<t.FormattedToolResponse> {
    const connection = this.connections.get(serverName);
    if (!connection) {
      throw new Error(
        `No connection found for server: ${serverName}. Please make sure to use MCP servers available under 'Connected MCP Servers'.`,
      );
    }
    const result = await connection.client.request(
      {
        method: 'tools/call',
        params: {
          name: toolName,
          arguments: toolArguments,
        },
      },
      CallToolResultSchema,
    );
    return formatToolContent(result, provider);
  }

  /**
   * Disconnects and unregisters `serverName` (no-op if unknown). The entry is removed even if
   * `disconnect()` throws; the error still propagates.
   */
  public async disconnectServer(serverName: string): Promise<void> {
    const connection = this.connections.get(serverName);
    if (!connection) {
      return;
    }
    try {
      await connection.disconnect();
    } finally {
      this.connections.delete(serverName);
    }
  }

  /**
   * Disconnects every server in parallel, best-effort: individual failures are logged rather
   * than aborting the shutdown, and the registry is always cleared afterwards.
   */
  public async disconnectAll(): Promise<void> {
    await Promise.all(
      Array.from(this.connections.entries(), async ([serverName, connection]) => {
        try {
          await connection.disconnect();
        } catch (error) {
          this.logger.error(`[MCP][${serverName}] Error disconnecting`, error);
        }
      }),
    );
    this.connections.clear();
  }

  /** Disconnects all servers and drops the singleton so the next `getInstance` starts fresh. */
  public static async destroyInstance(): Promise<void> {
    if (MCPManager.instance) {
      await MCPManager.instance.disconnectAll();
      MCPManager.instance = null;
    }
  }
}