mirror of
https://github.com/langchain-ai/langchainjs-mcp-adapters.git
synced 2026-07-01 12:27:48 -04:00
fix: pass inputSchema to LangChain tools & fail loudly (#20)
Also migrates test suite to vitest for better support for mocked modules with ESMs.
This commit is contained in:
+119
-127
@@ -1,12 +1,13 @@
|
||||
// Mock the problematic dependencies
|
||||
jest.mock('@modelcontextprotocol/sdk/client/sse.js', () => {
|
||||
import { vi, describe, test, expect, beforeEach, afterEach } from 'vitest';
|
||||
// Mock the problematic dependencies using vi.mock
|
||||
vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => {
|
||||
// Create mock functions for all methods
|
||||
const connectMock = jest.fn().mockResolvedValue(undefined);
|
||||
const sendMock = jest.fn().mockResolvedValue(undefined);
|
||||
const closeMock = jest.fn().mockResolvedValue(undefined);
|
||||
const connectMock = vi.fn().mockReturnValue(Promise.resolve());
|
||||
const sendMock = vi.fn().mockReturnValue(Promise.resolve());
|
||||
const closeMock = vi.fn().mockReturnValue(Promise.resolve());
|
||||
|
||||
return {
|
||||
SSEClientTransport: jest.fn().mockImplementation(() => ({
|
||||
SSEClientTransport: vi.fn().mockImplementation(() => ({
|
||||
connect: connectMock,
|
||||
send: sendMock,
|
||||
close: closeMock,
|
||||
@@ -15,14 +16,14 @@ jest.mock('@modelcontextprotocol/sdk/client/sse.js', () => {
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('@modelcontextprotocol/sdk/client/stdio.js', () => {
|
||||
vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => {
|
||||
// Create mock functions for all methods
|
||||
const connectMock = jest.fn().mockResolvedValue(undefined);
|
||||
const sendMock = jest.fn().mockResolvedValue(undefined);
|
||||
const closeMock = jest.fn().mockResolvedValue(undefined);
|
||||
const connectMock = vi.fn().mockReturnValue(Promise.resolve());
|
||||
const sendMock = vi.fn().mockReturnValue(Promise.resolve());
|
||||
const closeMock = vi.fn().mockReturnValue(Promise.resolve());
|
||||
|
||||
return {
|
||||
StdioClientTransport: jest.fn().mockImplementation(() => ({
|
||||
StdioClientTransport: vi.fn().mockImplementation(() => ({
|
||||
connect: connectMock,
|
||||
send: sendMock,
|
||||
close: closeMock,
|
||||
@@ -31,31 +32,35 @@ jest.mock('@modelcontextprotocol/sdk/client/stdio.js', () => {
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('@modelcontextprotocol/sdk/client/index.js', () => {
|
||||
vi.mock('@modelcontextprotocol/sdk/client/index.js', () => {
|
||||
// Create mock functions for all methods
|
||||
const connectMock = jest.fn().mockResolvedValue(undefined);
|
||||
const listToolsMock = jest.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'testTool',
|
||||
description: 'A test tool',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
input: { type: 'string' },
|
||||
const connectMock = vi.fn().mockReturnValue(Promise.resolve());
|
||||
const listToolsMock = vi.fn().mockReturnValue(
|
||||
Promise.resolve({
|
||||
tools: [
|
||||
{
|
||||
name: 'testTool',
|
||||
description: 'A test tool',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
input: { type: 'string' },
|
||||
},
|
||||
required: ['input'],
|
||||
},
|
||||
required: ['input'],
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
const callToolMock = jest.fn().mockResolvedValue({
|
||||
content: [{ type: 'text', text: 'result' }],
|
||||
});
|
||||
const closeMock = jest.fn().mockResolvedValue(undefined);
|
||||
],
|
||||
})
|
||||
);
|
||||
const callToolMock = vi.fn().mockReturnValue(
|
||||
Promise.resolve({
|
||||
content: [{ type: 'text', text: 'result' }],
|
||||
})
|
||||
);
|
||||
const closeMock = vi.fn().mockReturnValue(Promise.resolve());
|
||||
|
||||
return {
|
||||
Client: jest.fn().mockImplementation(() => ({
|
||||
Client: vi.fn().mockImplementation(() => ({
|
||||
connect: connectMock,
|
||||
listTools: listToolsMock,
|
||||
callTool: callToolMock,
|
||||
@@ -64,84 +69,78 @@ jest.mock('@modelcontextprotocol/sdk/client/index.js', () => {
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('fs');
|
||||
jest.mock('path');
|
||||
vi.mock('fs', () => ({
|
||||
readFileSync: vi.fn(),
|
||||
}));
|
||||
vi.mock('path', () => ({
|
||||
resolve: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock the logger
|
||||
jest.mock('../src/logger.js', () => {
|
||||
vi.mock('../src/logger.js', () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
// Create placeholder mocks that will be replaced in beforeEach
|
||||
jest.mock('@modelcontextprotocol/sdk/client/sse.js', () => ({
|
||||
SSEClientTransport: jest.fn(),
|
||||
vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => ({
|
||||
SSEClientTransport: vi.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('@modelcontextprotocol/sdk/client/stdio.js', () => ({
|
||||
StdioClientTransport: jest.fn(),
|
||||
vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => ({
|
||||
StdioClientTransport: vi.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('@modelcontextprotocol/sdk/client/index.js', () => ({
|
||||
Client: jest.fn(),
|
||||
vi.mock('@modelcontextprotocol/sdk/client/index.js', () => ({
|
||||
Client: vi.fn(),
|
||||
}));
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-unused-vars */
|
||||
import {
|
||||
MultiServerMCPClient,
|
||||
MCPClientError,
|
||||
StdioConnection,
|
||||
SSEConnection,
|
||||
} from '../src/client.js';
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import * as fs from 'fs';
|
||||
import * as path from 'path';
|
||||
import { StructuredToolInterface } from '@langchain/core/tools';
|
||||
import { z } from 'zod';
|
||||
import { loadMcpTools } from '../src/tools.js';
|
||||
/* eslint-enable @typescript-eslint/no-unused-vars */
|
||||
const { MultiServerMCPClient, MCPClientError } = await import('../src/client.js');
|
||||
const { Client } = await import('@modelcontextprotocol/sdk/client/index.js');
|
||||
const { StdioClientTransport } = await import('@modelcontextprotocol/sdk/client/stdio.js');
|
||||
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
|
||||
const fs = await import('fs');
|
||||
const path = await import('path');
|
||||
|
||||
describe('MultiServerMCPClient', () => {
|
||||
// Create mock implementations that will be used throughout the tests
|
||||
let mockClientConnect: jest.Mock;
|
||||
let mockClientListTools: jest.Mock;
|
||||
let mockClientCallTool: jest.Mock;
|
||||
let mockClientClose: jest.Mock;
|
||||
let mockClientConnect: vi.Mock;
|
||||
let mockClientListTools: vi.Mock;
|
||||
let mockClientCallTool: vi.Mock;
|
||||
let mockClientClose: vi.Mock;
|
||||
|
||||
let mockStdioTransportClose: jest.Mock;
|
||||
let mockStdioTransportConnect: jest.Mock;
|
||||
let mockStdioTransportSend: jest.Mock;
|
||||
let mockStdioTransportClose: vi.Mock;
|
||||
let mockStdioTransportConnect: vi.Mock;
|
||||
let mockStdioTransportSend: vi.Mock;
|
||||
// Define specific function type for onclose handlers
|
||||
let mockStdioOnClose: (() => void) | null;
|
||||
|
||||
let mockSSETransportClose: jest.Mock;
|
||||
let mockSSETransportConnect: jest.Mock;
|
||||
let mockSSETransportSend: jest.Mock;
|
||||
let mockSSETransportClose: vi.Mock;
|
||||
let mockSSETransportConnect: vi.Mock;
|
||||
let mockSSETransportSend: vi.Mock;
|
||||
// Define specific function type for onclose handlers
|
||||
let mockSSEOnClose: (() => void) | null;
|
||||
|
||||
// Setup and teardown
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Set up mock implementations for Client
|
||||
mockClientConnect = jest.fn().mockResolvedValue(undefined);
|
||||
mockClientListTools = jest.fn().mockResolvedValue({ tools: [] });
|
||||
mockClientCallTool = jest
|
||||
mockClientConnect = vi.fn().mockReturnValue(Promise.resolve());
|
||||
mockClientListTools = vi.fn().mockReturnValue(Promise.resolve({ tools: [] }));
|
||||
mockClientCallTool = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ content: [{ type: 'text', text: 'result' }] });
|
||||
mockClientClose = jest.fn().mockResolvedValue(undefined);
|
||||
.mockReturnValue(Promise.resolve({ content: [{ type: 'text', text: 'result' }] }));
|
||||
mockClientClose = vi.fn().mockReturnValue(Promise.resolve());
|
||||
|
||||
(Client as jest.Mock).mockImplementation(() => ({
|
||||
(Client as vi.Mock).mockImplementation(() => ({
|
||||
connect: mockClientConnect,
|
||||
listTools: mockClientListTools,
|
||||
callTool: mockClientCallTool,
|
||||
@@ -149,12 +148,12 @@ describe('MultiServerMCPClient', () => {
|
||||
}));
|
||||
|
||||
// Set up mock implementations for StdioClientTransport
|
||||
mockStdioTransportClose = jest.fn().mockResolvedValue(undefined);
|
||||
mockStdioTransportConnect = jest.fn().mockResolvedValue(undefined);
|
||||
mockStdioTransportSend = jest.fn().mockResolvedValue(undefined);
|
||||
mockStdioTransportClose = vi.fn().mockReturnValue(Promise.resolve());
|
||||
mockStdioTransportConnect = vi.fn().mockReturnValue(Promise.resolve());
|
||||
mockStdioTransportSend = vi.fn().mockReturnValue(Promise.resolve());
|
||||
mockStdioOnClose = null;
|
||||
|
||||
(StdioClientTransport as jest.Mock).mockImplementation(() => {
|
||||
(StdioClientTransport as vi.Mock).mockImplementation(() => {
|
||||
const transport = {
|
||||
close: mockStdioTransportClose,
|
||||
connect: mockStdioTransportConnect,
|
||||
@@ -172,12 +171,12 @@ describe('MultiServerMCPClient', () => {
|
||||
});
|
||||
|
||||
// Set up mock implementations for SSEClientTransport
|
||||
mockSSETransportClose = jest.fn().mockResolvedValue(undefined);
|
||||
mockSSETransportConnect = jest.fn().mockResolvedValue(undefined);
|
||||
mockSSETransportSend = jest.fn().mockResolvedValue(undefined);
|
||||
mockSSETransportClose = vi.fn().mockReturnValue(Promise.resolve());
|
||||
mockSSETransportConnect = vi.fn().mockReturnValue(Promise.resolve());
|
||||
mockSSETransportSend = vi.fn().mockReturnValue(Promise.resolve());
|
||||
mockSSEOnClose = null;
|
||||
|
||||
(SSEClientTransport as jest.Mock).mockImplementation(() => {
|
||||
(SSEClientTransport as vi.Mock).mockImplementation(() => {
|
||||
const transport = {
|
||||
close: mockSSETransportClose,
|
||||
connect: mockSSETransportConnect,
|
||||
@@ -194,7 +193,7 @@ describe('MultiServerMCPClient', () => {
|
||||
return transport;
|
||||
});
|
||||
|
||||
(fs.readFileSync as jest.Mock).mockImplementation(() =>
|
||||
(fs.readFileSync as vi.Mock).mockImplementation(() =>
|
||||
JSON.stringify({
|
||||
servers: {
|
||||
'test-server': {
|
||||
@@ -206,11 +205,11 @@ describe('MultiServerMCPClient', () => {
|
||||
})
|
||||
);
|
||||
|
||||
(path.resolve as jest.Mock).mockImplementation(p => p);
|
||||
(path.resolve as vi.Mock).mockImplementation(p => p);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.resetAllMocks();
|
||||
vi.resetAllMocks();
|
||||
});
|
||||
|
||||
// 1. Constructor functionality tests
|
||||
@@ -221,44 +220,39 @@ describe('MultiServerMCPClient', () => {
|
||||
});
|
||||
|
||||
test('should process valid stdio connection config', () => {
|
||||
const config = {
|
||||
const client = new MultiServerMCPClient({
|
||||
'test-server': {
|
||||
transport: 'stdio',
|
||||
command: 'python',
|
||||
args: ['./script.py'],
|
||||
},
|
||||
};
|
||||
|
||||
const client = new MultiServerMCPClient(config);
|
||||
});
|
||||
expect(client).toBeDefined();
|
||||
// Additional assertions to verify the connection was processed correctly
|
||||
});
|
||||
|
||||
test('should process valid SSE connection config', () => {
|
||||
const config = {
|
||||
const client = new MultiServerMCPClient({
|
||||
'test-server': {
|
||||
transport: 'sse',
|
||||
url: 'http://localhost:8000/sse',
|
||||
headers: { Authorization: 'Bearer token' },
|
||||
useNodeEventSource: true,
|
||||
},
|
||||
};
|
||||
|
||||
const client = new MultiServerMCPClient(config);
|
||||
});
|
||||
expect(client).toBeDefined();
|
||||
// Additional assertions to verify the connection was processed correctly
|
||||
});
|
||||
|
||||
test('should handle invalid connection config gracefully', () => {
|
||||
const config = {
|
||||
'test-server': {
|
||||
transport: 'invalid',
|
||||
},
|
||||
};
|
||||
|
||||
const client = new MultiServerMCPClient(config);
|
||||
expect(client).toBeDefined();
|
||||
// Verify that the invalid config was not processed
|
||||
test('should have a compile time error and a runtime error when the config is invalid', () => {
|
||||
expect(() => {
|
||||
new MultiServerMCPClient({
|
||||
'test-server': {
|
||||
// @ts-expect-error shouldn't match type constraints here
|
||||
transport: 'invalid',
|
||||
},
|
||||
});
|
||||
}).toThrow(MCPClientError);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -271,7 +265,7 @@ describe('MultiServerMCPClient', () => {
|
||||
});
|
||||
|
||||
test('should throw error for invalid config file', () => {
|
||||
(fs.readFileSync as jest.Mock).mockImplementation(() => {
|
||||
(fs.readFileSync as vi.Mock).mockImplementation(() => {
|
||||
throw new Error('File not found');
|
||||
});
|
||||
|
||||
@@ -281,7 +275,7 @@ describe('MultiServerMCPClient', () => {
|
||||
});
|
||||
|
||||
test('should throw error for invalid JSON in config file', () => {
|
||||
(fs.readFileSync as jest.Mock).mockImplementation(() => 'invalid json');
|
||||
(fs.readFileSync as vi.Mock).mockImplementation(() => 'invalid json');
|
||||
|
||||
expect(() => {
|
||||
MultiServerMCPClient.fromConfigFile('./invalid.json');
|
||||
@@ -329,10 +323,10 @@ describe('MultiServerMCPClient', () => {
|
||||
expect(mockClientListTools).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should handle connection failures gracefully', async () => {
|
||||
(Client as jest.Mock).mockImplementation(() => ({
|
||||
connect: jest.fn().mockRejectedValue(new Error('Connection failed')),
|
||||
listTools: jest.fn().mockResolvedValue({ tools: [] }),
|
||||
test('should throw on connection failure', async () => {
|
||||
(Client as vi.Mock).mockImplementation(() => ({
|
||||
connect: vi.fn().mockReturnValue(Promise.reject(new Error('Connection failed'))),
|
||||
listTools: vi.fn().mockReturnValue(Promise.resolve({ tools: [] })),
|
||||
}));
|
||||
|
||||
const client = new MultiServerMCPClient({
|
||||
@@ -343,14 +337,13 @@ describe('MultiServerMCPClient', () => {
|
||||
},
|
||||
});
|
||||
|
||||
await client.initializeConnections();
|
||||
// Verify that the error was handled gracefully
|
||||
await expect(() => client.initializeConnections()).rejects.toThrow(MCPClientError);
|
||||
});
|
||||
|
||||
test('should handle tool loading failures gracefully', async () => {
|
||||
(Client as jest.Mock).mockImplementation(() => ({
|
||||
connect: jest.fn().mockResolvedValue(undefined),
|
||||
listTools: jest.fn().mockRejectedValue(new Error('Failed to list tools')),
|
||||
test('should throw on tool loading failures', async () => {
|
||||
(Client as vi.Mock).mockImplementation(() => ({
|
||||
connect: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
listTools: vi.fn().mockReturnValue(Promise.reject(new Error('Failed to list tools'))),
|
||||
}));
|
||||
|
||||
const client = new MultiServerMCPClient({
|
||||
@@ -361,8 +354,7 @@ describe('MultiServerMCPClient', () => {
|
||||
},
|
||||
});
|
||||
|
||||
await client.initializeConnections();
|
||||
// Verify that the error was handled gracefully
|
||||
await expect(() => client.initializeConnections()).rejects.toThrow(MCPClientError);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -385,7 +377,7 @@ describe('MultiServerMCPClient', () => {
|
||||
await client.initializeConnections();
|
||||
|
||||
// Reset the call counts to focus on reconnection
|
||||
(StdioClientTransport as jest.Mock).mockClear();
|
||||
(StdioClientTransport as vi.Mock).mockClear();
|
||||
|
||||
// Trigger the onclose handler if it exists
|
||||
if (mockStdioOnClose) {
|
||||
@@ -415,7 +407,7 @@ describe('MultiServerMCPClient', () => {
|
||||
await client.initializeConnections();
|
||||
|
||||
// Reset the call counts to focus on reconnection
|
||||
(SSEClientTransport as jest.Mock).mockClear();
|
||||
(SSEClientTransport as vi.Mock).mockClear();
|
||||
|
||||
// Trigger the onclose handler if it exists
|
||||
if (mockSSEOnClose) {
|
||||
@@ -449,9 +441,9 @@ describe('MultiServerMCPClient', () => {
|
||||
{ name: 'tool2', description: 'Tool 2', inputSchema: {} },
|
||||
];
|
||||
|
||||
(Client as jest.Mock).mockImplementation(() => ({
|
||||
connect: jest.fn().mockResolvedValue(undefined),
|
||||
listTools: jest.fn().mockResolvedValue({ tools: mockTools }),
|
||||
(Client as vi.Mock).mockImplementation(() => ({
|
||||
connect: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
listTools: vi.fn().mockReturnValue(Promise.resolve({ tools: mockTools })),
|
||||
}));
|
||||
|
||||
const client = new MultiServerMCPClient({
|
||||
@@ -508,8 +500,8 @@ describe('MultiServerMCPClient', () => {
|
||||
|
||||
test('should handle errors during cleanup gracefully', async () => {
|
||||
// Mock close to throw an error
|
||||
(StdioClientTransport as jest.Mock).mockImplementation(() => ({
|
||||
close: jest.fn().mockRejectedValue(new Error('Close failed')),
|
||||
(StdioClientTransport as vi.Mock).mockImplementation(() => ({
|
||||
close: vi.fn().mockReturnValue(Promise.reject(new Error('Close failed'))),
|
||||
onclose: null,
|
||||
}));
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import { vi, describe, test, expect, beforeEach, afterEach } from 'vitest';
|
||||
|
||||
// Mock fs module before imports
|
||||
jest.mock('fs', () => {
|
||||
vi.mock('fs', () => {
|
||||
// Create a map to store mock file contents
|
||||
const mockFiles = {
|
||||
'./mcp.json': JSON.stringify({
|
||||
@@ -16,7 +18,7 @@ jest.mock('fs', () => {
|
||||
};
|
||||
|
||||
return {
|
||||
readFileSync: jest.fn((path: string, _encoding?: string) => {
|
||||
readFileSync: vi.fn((path: string, _encoding?: string) => {
|
||||
if (path === './nonexistent.json') {
|
||||
throw new Error('File not found');
|
||||
}
|
||||
@@ -25,27 +27,29 @@ jest.mock('fs', () => {
|
||||
}
|
||||
throw new Error(`Mock file not found: ${path}`);
|
||||
}),
|
||||
existsSync: jest.fn((path: string) => {
|
||||
existsSync: vi.fn((path: string) => {
|
||||
return Object.prototype.hasOwnProperty.call(mockFiles, path);
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
// Mock path module
|
||||
jest.mock('path', () => ({
|
||||
join: jest.fn((...args) => args.join('/')),
|
||||
resolve: jest.fn((...args) => args.join('/')),
|
||||
dirname: jest.fn(path => path.split('/').slice(0, -1).join('/')),
|
||||
basename: jest.fn(path => path.split('/').pop()),
|
||||
}));
|
||||
vi.mock('path', () => {
|
||||
return {
|
||||
join: vi.fn((...args) => args.join('/')),
|
||||
resolve: vi.fn((...args) => args.join('/')),
|
||||
dirname: vi.fn((path: string) => path.split('/').slice(0, -1).join('/')),
|
||||
basename: vi.fn((path: string) => path.split('/').pop()),
|
||||
};
|
||||
});
|
||||
|
||||
// Mock the logger module
|
||||
jest.mock('../src/logger.js', () => {
|
||||
vi.mock('../src/logger.js', () => {
|
||||
const mockLogger = {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
};
|
||||
return {
|
||||
__esModule: true,
|
||||
@@ -55,30 +59,32 @@ jest.mock('../src/logger.js', () => {
|
||||
});
|
||||
|
||||
// Set up mocks for external modules
|
||||
jest.mock(
|
||||
vi.mock(
|
||||
'@modelcontextprotocol/sdk/client/index.js',
|
||||
() => {
|
||||
return {
|
||||
Client: jest.fn().mockImplementation(() => ({
|
||||
connect: jest.fn().mockResolvedValue(undefined),
|
||||
listTools: jest.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'tool1',
|
||||
description: 'Test tool 1',
|
||||
inputSchema: { type: 'object', properties: {} },
|
||||
},
|
||||
{
|
||||
name: 'tool2',
|
||||
description: 'Test tool 2',
|
||||
inputSchema: { type: 'object', properties: {} },
|
||||
},
|
||||
],
|
||||
}),
|
||||
callTool: jest.fn().mockResolvedValue({
|
||||
content: [{ type: 'text', text: 'result' }],
|
||||
}),
|
||||
close: jest.fn().mockResolvedValue(undefined),
|
||||
Client: vi.fn().mockImplementation(() => ({
|
||||
connect: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
listTools: vi.fn().mockReturnValue(
|
||||
Promise.resolve({
|
||||
tools: [
|
||||
{
|
||||
name: 'tool1',
|
||||
description: 'Test tool 1',
|
||||
inputSchema: { type: 'object', properties: {} },
|
||||
},
|
||||
{
|
||||
name: 'tool2',
|
||||
description: 'Test tool 2',
|
||||
inputSchema: { type: 'object', properties: {} },
|
||||
},
|
||||
],
|
||||
})
|
||||
),
|
||||
callTool: vi
|
||||
.fn()
|
||||
.mockReturnValue(Promise.resolve({ content: [{ type: 'text', text: 'result' }] })),
|
||||
close: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
tools: [], // Add the tools property
|
||||
})),
|
||||
};
|
||||
@@ -86,16 +92,16 @@ jest.mock(
|
||||
{ virtual: true }
|
||||
);
|
||||
|
||||
jest.mock(
|
||||
vi.mock(
|
||||
'@modelcontextprotocol/sdk/client/stdio.js',
|
||||
() => {
|
||||
// Using the OnCloseHandler type defined at the top level
|
||||
return {
|
||||
StdioClientTransport: jest.fn().mockImplementation(config => {
|
||||
StdioClientTransport: vi.fn().mockImplementation(config => {
|
||||
const transport = {
|
||||
connect: jest.fn().mockResolvedValue(undefined),
|
||||
send: jest.fn().mockResolvedValue(undefined),
|
||||
close: jest.fn().mockResolvedValue(undefined),
|
||||
connect: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
send: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
close: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
onclose: null as OnCloseHandler | null,
|
||||
config,
|
||||
};
|
||||
@@ -106,16 +112,16 @@ jest.mock(
|
||||
{ virtual: true }
|
||||
);
|
||||
|
||||
jest.mock(
|
||||
vi.mock(
|
||||
'@modelcontextprotocol/sdk/client/sse.js',
|
||||
() => {
|
||||
// Using the OnCloseHandler type defined at the top level
|
||||
return {
|
||||
SSEClientTransport: jest.fn().mockImplementation(config => {
|
||||
SSEClientTransport: vi.fn().mockImplementation(config => {
|
||||
const transport = {
|
||||
connect: jest.fn().mockResolvedValue(undefined),
|
||||
send: jest.fn().mockResolvedValue(undefined),
|
||||
close: jest.fn().mockResolvedValue(undefined),
|
||||
connect: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
send: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
close: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
onclose: null as OnCloseHandler | null,
|
||||
config,
|
||||
};
|
||||
@@ -130,52 +136,54 @@ jest.mock(
|
||||
type OnCloseHandler = () => void;
|
||||
|
||||
// Import modules after mocking
|
||||
import * as fs from 'fs';
|
||||
import { MultiServerMCPClient, MCPClientError } from '../src/client.js';
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
const fs = await import('fs');
|
||||
const { StdioClientTransport } = await import('@modelcontextprotocol/sdk/client/stdio.js');
|
||||
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
|
||||
const { MultiServerMCPClient, MCPClientError } = await import('../src/client.js');
|
||||
const { Client } = await import('@modelcontextprotocol/sdk/client/index.js');
|
||||
|
||||
// Create mock objects that will be accessible throughout the tests
|
||||
const mockClientMethods = {
|
||||
connect: jest.fn().mockResolvedValue(undefined),
|
||||
listTools: jest.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'tool1',
|
||||
description: 'Test tool 1',
|
||||
inputSchema: { type: 'object', properties: {} },
|
||||
},
|
||||
{
|
||||
name: 'tool2',
|
||||
description: 'Test tool 2',
|
||||
inputSchema: { type: 'object', properties: {} },
|
||||
},
|
||||
],
|
||||
}),
|
||||
callTool: jest.fn().mockResolvedValue({
|
||||
content: [{ type: 'text', text: 'result' }],
|
||||
}),
|
||||
close: jest.fn().mockResolvedValue(undefined),
|
||||
connect: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
listTools: vi.fn().mockReturnValue(
|
||||
Promise.resolve({
|
||||
tools: [
|
||||
{
|
||||
name: 'tool1',
|
||||
description: 'Test tool 1',
|
||||
inputSchema: { type: 'object', properties: {} },
|
||||
},
|
||||
{
|
||||
name: 'tool2',
|
||||
description: 'Test tool 2',
|
||||
inputSchema: { type: 'object', properties: {} },
|
||||
},
|
||||
],
|
||||
})
|
||||
),
|
||||
callTool: vi
|
||||
.fn()
|
||||
.mockReturnValue(Promise.resolve({ content: [{ type: 'text', text: 'result' }] })),
|
||||
close: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
};
|
||||
|
||||
const mockStdioMethods = {
|
||||
connect: jest.fn().mockResolvedValue(undefined),
|
||||
send: jest.fn().mockResolvedValue(undefined),
|
||||
close: jest.fn().mockResolvedValue(undefined),
|
||||
triggerOnclose: jest.fn(),
|
||||
connect: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
send: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
close: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
triggerOnclose: vi.fn(),
|
||||
};
|
||||
|
||||
const mockSSEMethods = {
|
||||
connect: jest.fn().mockResolvedValue(undefined),
|
||||
send: jest.fn().mockResolvedValue(undefined),
|
||||
close: jest.fn().mockResolvedValue(undefined),
|
||||
triggerOnclose: jest.fn(),
|
||||
connect: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
send: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
close: vi.fn().mockReturnValue(Promise.resolve()),
|
||||
triggerOnclose: vi.fn(),
|
||||
};
|
||||
|
||||
// Reset mocks before each test
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Reset mock methods
|
||||
Object.values(mockClientMethods).forEach(mock => mock.mockClear());
|
||||
@@ -183,7 +191,7 @@ beforeEach(() => {
|
||||
Object.values(mockSSEMethods).forEach(mock => mock.mockClear());
|
||||
|
||||
// Reset mock implementations
|
||||
mockClientMethods.listTools.mockResolvedValue({
|
||||
mockClientMethods.listTools.mockReturnValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'tool1',
|
||||
@@ -199,7 +207,7 @@ beforeEach(() => {
|
||||
});
|
||||
|
||||
// Reset the mock implementations for the imported modules
|
||||
(Client as jest.Mock).mockImplementation(() => ({
|
||||
(Client as vi.Mock).mockImplementation(() => ({
|
||||
connect: mockClientMethods.connect,
|
||||
listTools: mockClientMethods.listTools,
|
||||
callTool: mockClientMethods.callTool,
|
||||
@@ -207,7 +215,7 @@ beforeEach(() => {
|
||||
tools: [],
|
||||
}));
|
||||
|
||||
(StdioClientTransport as jest.Mock).mockImplementation(config => {
|
||||
(StdioClientTransport as vi.Mock).mockImplementation(config => {
|
||||
const transport = {
|
||||
connect: mockStdioMethods.connect,
|
||||
send: mockStdioMethods.send,
|
||||
@@ -215,13 +223,13 @@ beforeEach(() => {
|
||||
onclose: null as OnCloseHandler | null,
|
||||
config,
|
||||
};
|
||||
mockStdioMethods.triggerOnclose = jest.fn(() => {
|
||||
mockStdioMethods.triggerOnclose = vi.fn(() => {
|
||||
if (transport.onclose) transport.onclose();
|
||||
});
|
||||
return transport;
|
||||
});
|
||||
|
||||
(SSEClientTransport as jest.Mock).mockImplementation(config => {
|
||||
(SSEClientTransport as vi.Mock).mockImplementation(config => {
|
||||
const transport = {
|
||||
connect: mockSSEMethods.connect,
|
||||
send: mockSSEMethods.send,
|
||||
@@ -229,7 +237,7 @@ beforeEach(() => {
|
||||
onclose: null as OnCloseHandler | null,
|
||||
config,
|
||||
};
|
||||
mockSSEMethods.triggerOnclose = jest.fn(() => {
|
||||
mockSSEMethods.triggerOnclose = vi.fn(() => {
|
||||
if (transport.onclose) transport.onclose();
|
||||
});
|
||||
return transport;
|
||||
@@ -238,7 +246,7 @@ beforeEach(() => {
|
||||
|
||||
describe('MultiServerMCPClient', () => {
|
||||
afterEach(() => {
|
||||
jest.resetAllMocks();
|
||||
vi.resetAllMocks();
|
||||
});
|
||||
|
||||
describe('Constructor', () => {
|
||||
@@ -251,7 +259,7 @@ describe('MultiServerMCPClient', () => {
|
||||
expect(tools).toEqual([]);
|
||||
});
|
||||
|
||||
test('should process valid stdio connection config', () => {
|
||||
test('should process valid stdio connection config', async () => {
|
||||
const config = {
|
||||
'test-server': {
|
||||
transport: 'stdio' as const,
|
||||
@@ -264,13 +272,12 @@ describe('MultiServerMCPClient', () => {
|
||||
expect(client).toBeDefined();
|
||||
|
||||
// Initialize connections and verify
|
||||
return client.initializeConnections().then(() => {
|
||||
expect(StdioClientTransport).toHaveBeenCalled();
|
||||
expect(Client).toHaveBeenCalled();
|
||||
});
|
||||
await client.initializeConnections();
|
||||
expect(StdioClientTransport).toHaveBeenCalled();
|
||||
expect(Client).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should process valid SSE connection config', () => {
|
||||
test('should process valid SSE connection config', async () => {
|
||||
const config = {
|
||||
'test-server': {
|
||||
transport: 'sse' as const,
|
||||
@@ -284,13 +291,12 @@ describe('MultiServerMCPClient', () => {
|
||||
expect(client).toBeDefined();
|
||||
|
||||
// Initialize connections and verify
|
||||
return client.initializeConnections().then(() => {
|
||||
expect(SSEClientTransport).toHaveBeenCalled();
|
||||
expect(Client).toHaveBeenCalled();
|
||||
});
|
||||
await client.initializeConnections();
|
||||
expect(SSEClientTransport).toHaveBeenCalled();
|
||||
expect(Client).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should handle invalid connection type gracefully', () => {
|
||||
test('should handle invalid connection type gracefully', async () => {
|
||||
const config = {
|
||||
'test-server': {
|
||||
transport: 'invalid' as any,
|
||||
@@ -298,34 +304,28 @@ describe('MultiServerMCPClient', () => {
|
||||
},
|
||||
};
|
||||
|
||||
const client = new MultiServerMCPClient(config);
|
||||
expect(client).toBeDefined();
|
||||
|
||||
// Initialize connections and verify no error is thrown
|
||||
return client.initializeConnections().then(() => {
|
||||
// No connections should be initialized
|
||||
expect(SSEClientTransport).not.toHaveBeenCalled();
|
||||
expect(StdioClientTransport).not.toHaveBeenCalled();
|
||||
});
|
||||
// Should throw error during initialization
|
||||
expect(() => {
|
||||
new MultiServerMCPClient(config);
|
||||
}).toThrow(MCPClientError);
|
||||
});
|
||||
|
||||
test('should gracefully handle empty config', () => {
|
||||
test('should gracefully handle empty config', async () => {
|
||||
const client = new MultiServerMCPClient({});
|
||||
expect(client).toBeDefined();
|
||||
|
||||
// Initialize connections and verify no error is thrown
|
||||
return client.initializeConnections().then(() => {
|
||||
// No connections should be initialized
|
||||
expect(SSEClientTransport).not.toHaveBeenCalled();
|
||||
expect(StdioClientTransport).not.toHaveBeenCalled();
|
||||
});
|
||||
await client.initializeConnections();
|
||||
// No connections should be initialized
|
||||
expect(SSEClientTransport).not.toHaveBeenCalled();
|
||||
expect(StdioClientTransport).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Configuration Loading', () => {
|
||||
test('should load config from a valid file', () => {
|
||||
test('should load config from a valid file', async () => {
|
||||
// Mock fs.readFileSync to return valid JSON
|
||||
(fs.readFileSync as jest.Mock).mockReturnValueOnce(
|
||||
(fs.readFileSync as vi.Mock).mockReturnValueOnce(
|
||||
JSON.stringify({
|
||||
servers: {
|
||||
'test-server': {
|
||||
@@ -343,7 +343,7 @@ describe('MultiServerMCPClient', () => {
|
||||
});
|
||||
|
||||
test('should throw error for nonexistent config file', () => {
|
||||
(fs.existsSync as jest.Mock).mockReturnValueOnce(false);
|
||||
(fs.existsSync as vi.Mock).mockReturnValueOnce(false);
|
||||
|
||||
expect(() => {
|
||||
MultiServerMCPClient.fromConfigFile('./nonexistent.json');
|
||||
@@ -351,7 +351,7 @@ describe('MultiServerMCPClient', () => {
|
||||
});
|
||||
|
||||
test('should throw error for invalid JSON in config file', () => {
|
||||
(fs.readFileSync as jest.Mock).mockReturnValueOnce('invalid json');
|
||||
(fs.readFileSync as vi.Mock).mockReturnValueOnce('invalid json');
|
||||
|
||||
expect(() => {
|
||||
MultiServerMCPClient.fromConfigFile('./invalid.json');
|
||||
@@ -360,7 +360,7 @@ describe('MultiServerMCPClient', () => {
|
||||
|
||||
test('should throw error for invalid config structure', () => {
|
||||
// Mock readFileSync to return a config without the required 'servers' property
|
||||
(fs.readFileSync as jest.Mock).mockReturnValueOnce(
|
||||
(fs.readFileSync as vi.Mock).mockReturnValueOnce(
|
||||
JSON.stringify({
|
||||
notServers: {}, // This missing 'servers' property
|
||||
})
|
||||
@@ -378,7 +378,7 @@ describe('MultiServerMCPClient', () => {
|
||||
});
|
||||
|
||||
test('should throw error for file system errors', () => {
|
||||
(fs.readFileSync as jest.Mock).mockImplementationOnce(() => {
|
||||
(fs.readFileSync as vi.Mock).mockImplementationOnce(() => {
|
||||
throw new Error('File system error');
|
||||
});
|
||||
|
||||
@@ -390,23 +390,20 @@ describe('MultiServerMCPClient', () => {
|
||||
|
||||
describe('Connection Management', () => {
|
||||
test('should initialize stdio connections correctly', async () => {
|
||||
// Create a properly structured config
|
||||
const config = {
|
||||
// Create a client instance with the config
|
||||
const client = new MultiServerMCPClient({
|
||||
'stdio-server': {
|
||||
transport: 'stdio',
|
||||
command: 'python',
|
||||
args: ['./script.py'],
|
||||
},
|
||||
};
|
||||
|
||||
// Create a client instance with the config
|
||||
const client = new MultiServerMCPClient(config);
|
||||
});
|
||||
|
||||
// Reset mocks to ensure clean state
|
||||
jest.clearAllMocks();
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Set up specific implementation for the StdioClientTransport mock
|
||||
(StdioClientTransport as jest.Mock).mockImplementationOnce(options => {
|
||||
(StdioClientTransport as vi.Mock).mockImplementationOnce(options => {
|
||||
return {
|
||||
connect: mockStdioMethods.connect,
|
||||
send: mockStdioMethods.send,
|
||||
@@ -432,22 +429,19 @@ describe('MultiServerMCPClient', () => {
|
||||
});
|
||||
|
||||
test('should initialize SSE connections correctly', async () => {
|
||||
// Create a properly structured config for SSE
|
||||
const config = {
|
||||
// Create a client instance with the config
|
||||
const client = new MultiServerMCPClient({
|
||||
'sse-server': {
|
||||
transport: 'sse',
|
||||
url: 'http://example.com/sse',
|
||||
},
|
||||
};
|
||||
|
||||
// Create a client instance with the config
|
||||
const client = new MultiServerMCPClient(config);
|
||||
});
|
||||
|
||||
// Reset mocks to ensure clean state
|
||||
jest.clearAllMocks();
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Set up specific implementation for the SSEClientTransport mock
|
||||
(SSEClientTransport as jest.Mock).mockImplementationOnce((url, options) => {
|
||||
(SSEClientTransport as vi.Mock).mockImplementationOnce((url, options) => {
|
||||
return {
|
||||
connect: mockSSEMethods.connect,
|
||||
send: mockSSEMethods.send,
|
||||
@@ -468,9 +462,9 @@ describe('MultiServerMCPClient', () => {
|
||||
expect(mockClientMethods.connect).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should handle connection failures gracefully', async () => {
|
||||
test('should throw on connection failures', async () => {
|
||||
// Mock connection failure
|
||||
mockClientMethods.connect.mockRejectedValueOnce(new Error('Connection failed'));
|
||||
mockClientMethods.connect.mockReturnValueOnce(Promise.reject(new Error('Connection failed')));
|
||||
|
||||
const client = new MultiServerMCPClient({
|
||||
'test-server': {
|
||||
@@ -480,18 +474,15 @@ describe('MultiServerMCPClient', () => {
|
||||
},
|
||||
});
|
||||
|
||||
// Should not throw error
|
||||
await client.initializeConnections();
|
||||
|
||||
// Still called connect but handled error
|
||||
expect(mockClientMethods.connect).toHaveBeenCalled();
|
||||
// Should not try to list tools after failed connection
|
||||
expect(mockClientMethods.listTools).not.toHaveBeenCalled();
|
||||
// Should throw error
|
||||
await expect(client.initializeConnections()).rejects.toThrow();
|
||||
});
|
||||
|
||||
test('should handle tool loading failures gracefully', async () => {
|
||||
test('should throw on tool loading failures', async () => {
|
||||
// Mock tool loading failure
|
||||
mockClientMethods.listTools.mockRejectedValueOnce(new Error('Failed to list tools'));
|
||||
mockClientMethods.listTools.mockReturnValueOnce(
|
||||
Promise.reject(new Error('Failed to list tools'))
|
||||
);
|
||||
|
||||
const client = new MultiServerMCPClient({
|
||||
'test-server': {
|
||||
@@ -501,16 +492,8 @@ describe('MultiServerMCPClient', () => {
|
||||
},
|
||||
});
|
||||
|
||||
// Should not throw error
|
||||
await client.initializeConnections();
|
||||
|
||||
// Connection succeeded but tool loading failed
|
||||
expect(mockClientMethods.connect).toHaveBeenCalled();
|
||||
expect(mockClientMethods.listTools).toHaveBeenCalled();
|
||||
|
||||
// Should have empty tools
|
||||
const tools = client.getTools();
|
||||
expect(tools).toEqual([]);
|
||||
// Should throw error
|
||||
await expect(client.initializeConnections()).rejects.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -532,7 +515,7 @@ describe('MultiServerMCPClient', () => {
|
||||
await client.initializeConnections();
|
||||
|
||||
// Clear previous calls
|
||||
(StdioClientTransport as jest.Mock).mockClear();
|
||||
(StdioClientTransport as vi.Mock).mockClear();
|
||||
mockClientMethods.connect.mockClear();
|
||||
|
||||
// Trigger onclose handler
|
||||
@@ -563,7 +546,7 @@ describe('MultiServerMCPClient', () => {
|
||||
await client.initializeConnections();
|
||||
|
||||
// Clear previous calls
|
||||
(SSEClientTransport as jest.Mock).mockClear();
|
||||
(SSEClientTransport as vi.Mock).mockClear();
|
||||
mockClientMethods.connect.mockClear();
|
||||
|
||||
// Trigger onclose handler
|
||||
@@ -584,7 +567,7 @@ describe('MultiServerMCPClient', () => {
|
||||
const client = new MultiServerMCPClient();
|
||||
|
||||
// Clear previous mock invocations
|
||||
(StdioClientTransport as jest.Mock).mockClear();
|
||||
(StdioClientTransport as vi.Mock).mockClear();
|
||||
|
||||
// Connect with reconnection enabled
|
||||
await client.connectToServerViaStdio(
|
||||
@@ -623,7 +606,7 @@ describe('MultiServerMCPClient', () => {
|
||||
await client.initializeConnections();
|
||||
|
||||
// Clear previous calls
|
||||
(StdioClientTransport as jest.Mock).mockClear();
|
||||
(StdioClientTransport as vi.Mock).mockClear();
|
||||
|
||||
// Trigger onclose handler
|
||||
mockStdioMethods.triggerOnclose();
|
||||
@@ -639,12 +622,14 @@ describe('MultiServerMCPClient', () => {
|
||||
describe('Tool Management', () => {
|
||||
test('should get all tools as a flattened array', async () => {
|
||||
// Mock tool response
|
||||
mockClientMethods.listTools.mockResolvedValue({
|
||||
tools: [
|
||||
{ name: 'tool1', description: 'Tool 1', inputSchema: {} },
|
||||
{ name: 'tool2', description: 'Tool 2', inputSchema: {} },
|
||||
],
|
||||
});
|
||||
mockClientMethods.listTools.mockReturnValue(
|
||||
Promise.resolve({
|
||||
tools: [
|
||||
{ name: 'tool1', description: 'Tool 1', inputSchema: {} },
|
||||
{ name: 'tool2', description: 'Tool 2', inputSchema: {} },
|
||||
],
|
||||
})
|
||||
);
|
||||
|
||||
const client = new MultiServerMCPClient({
|
||||
server1: {
|
||||
@@ -717,7 +702,7 @@ describe('MultiServerMCPClient', () => {
|
||||
|
||||
test('should handle errors during cleanup gracefully', async () => {
|
||||
// Mock close to throw error
|
||||
mockStdioMethods.close.mockRejectedValueOnce(new Error('Close failed'));
|
||||
mockStdioMethods.close.mockReturnValueOnce(Promise.reject(new Error('Close failed')));
|
||||
|
||||
const client = new MultiServerMCPClient({
|
||||
'test-server': {
|
||||
@@ -738,8 +723,8 @@ describe('MultiServerMCPClient', () => {
|
||||
|
||||
test('should clean up all resources even if some fail', async () => {
|
||||
// First close fails, second succeeds
|
||||
mockStdioMethods.close.mockRejectedValueOnce(new Error('Close failed'));
|
||||
mockSSEMethods.close.mockResolvedValue(undefined);
|
||||
mockStdioMethods.close.mockReturnValueOnce(Promise.reject(new Error('Close failed')));
|
||||
mockSSEMethods.close.mockReturnValueOnce(Promise.resolve());
|
||||
|
||||
const client = new MultiServerMCPClient({
|
||||
'stdio-server': {
|
||||
@@ -821,7 +806,7 @@ describe('MultiServerMCPClient', () => {
|
||||
expect(StdioClientTransport).toHaveBeenCalled();
|
||||
|
||||
// Simulate connection close to test restart
|
||||
(StdioClientTransport as jest.Mock).mockClear();
|
||||
(StdioClientTransport as vi.Mock).mockClear();
|
||||
mockStdioMethods.triggerOnclose();
|
||||
|
||||
// Wait for reconnection
|
||||
@@ -833,7 +818,7 @@ describe('MultiServerMCPClient', () => {
|
||||
|
||||
test('should connect to an SSE server correctly', async () => {
|
||||
// Clear previous mock invocations
|
||||
(SSEClientTransport as jest.Mock).mockClear();
|
||||
(SSEClientTransport as vi.Mock).mockClear();
|
||||
|
||||
const client = new MultiServerMCPClient();
|
||||
await client.connectToServerViaSSE('test-server', 'http://localhost:8000/sse');
|
||||
@@ -846,7 +831,7 @@ describe('MultiServerMCPClient', () => {
|
||||
|
||||
test('should connect with headers', async () => {
|
||||
// Clear previous mock invocations
|
||||
(SSEClientTransport as jest.Mock).mockClear();
|
||||
(SSEClientTransport as vi.Mock).mockClear();
|
||||
|
||||
const client = new MultiServerMCPClient();
|
||||
const headers = { Authorization: 'Bearer token' };
|
||||
@@ -860,7 +845,7 @@ describe('MultiServerMCPClient', () => {
|
||||
|
||||
test('should connect with useNodeEventSource option', async () => {
|
||||
// Clear previous mock invocations
|
||||
(SSEClientTransport as jest.Mock).mockClear();
|
||||
(SSEClientTransport as vi.Mock).mockClear();
|
||||
|
||||
const client = new MultiServerMCPClient();
|
||||
await client.connectToServerViaSSE(
|
||||
@@ -891,7 +876,7 @@ describe('MultiServerMCPClient', () => {
|
||||
expect(SSEClientTransport).toHaveBeenCalled();
|
||||
|
||||
// Simulate connection close to test reconnect
|
||||
(SSEClientTransport as jest.Mock).mockClear();
|
||||
(SSEClientTransport as vi.Mock).mockClear();
|
||||
mockSSEMethods.triggerOnclose();
|
||||
|
||||
// Wait for reconnection
|
||||
@@ -909,7 +894,7 @@ describe('MultiServerMCPClient', () => {
|
||||
|
||||
// Clear mock history
|
||||
mockStdioMethods.close.mockClear();
|
||||
(StdioClientTransport as jest.Mock).mockClear();
|
||||
(StdioClientTransport as vi.Mock).mockClear();
|
||||
|
||||
// Connect again with same name (should close previous)
|
||||
await client.connectToServerViaStdio('test-server', 'node', ['script.js']);
|
||||
@@ -936,16 +921,18 @@ describe('MultiServerMCPClient', () => {
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
test('should handle transport creation errors', async () => {
|
||||
test('should throw on transport creation errors', async () => {
|
||||
// Force an error when creating transport
|
||||
(StdioClientTransport as jest.Mock).mockImplementationOnce(() => {
|
||||
(StdioClientTransport as vi.Mock).mockImplementationOnce(() => {
|
||||
throw new Error('Transport creation failed');
|
||||
});
|
||||
|
||||
const client = new MultiServerMCPClient();
|
||||
|
||||
// Should not throw
|
||||
await client.connectToServerViaStdio('test-server', 'python', ['./script.py']);
|
||||
// Should throw error when connecting
|
||||
await expect(
|
||||
client.connectToServerViaStdio('test-server', 'python', ['./script.py'])
|
||||
).rejects.toThrow();
|
||||
|
||||
// Should have attempted to create transport
|
||||
expect(StdioClientTransport).toHaveBeenCalled();
|
||||
|
||||
+40
-34
@@ -1,31 +1,39 @@
|
||||
import * as fs from 'fs';
|
||||
import * as path from 'path';
|
||||
import winston from 'winston';
|
||||
|
||||
import { describe, test, expect, beforeEach, afterEach, vi } from 'vitest';
|
||||
// Mock fs and path modules
|
||||
jest.mock('fs');
|
||||
jest.mock('path');
|
||||
vi.mock('fs', () => ({
|
||||
existsSync: vi.fn(),
|
||||
mkdirSync: vi.fn(),
|
||||
writeFileSync: vi.fn(),
|
||||
unlinkSync: vi.fn(),
|
||||
accessSync: vi.fn(),
|
||||
}));
|
||||
vi.mock('path', () => ({
|
||||
join: vi.fn(),
|
||||
}));
|
||||
const fs = await import('fs');
|
||||
const path = await import('path');
|
||||
const winston = await import('winston');
|
||||
|
||||
describe('Logger', () => {
|
||||
// Store original console.warn implementation
|
||||
const originalConsoleWarn = console.warn;
|
||||
let consoleWarnMock: jest.SpyInstance;
|
||||
let consoleWarnMock: any;
|
||||
|
||||
beforeEach(() => {
|
||||
// Clear module cache to ensure logger is reinitialized
|
||||
jest.resetModules();
|
||||
vi.resetModules();
|
||||
|
||||
// Mock console.warn to capture warnings
|
||||
consoleWarnMock = jest.spyOn(console, 'warn').mockImplementation();
|
||||
consoleWarnMock = vi.spyOn(console, 'warn').mockImplementation((..._args) => {});
|
||||
|
||||
// Configure path.join to return predictable paths
|
||||
(path.join as jest.Mock).mockImplementation((...args) => args.join('/'));
|
||||
(path.join as any).mockImplementation((...args: string[]) => args.join('/'));
|
||||
|
||||
// Reset fs mock implementation
|
||||
(fs.existsSync as jest.Mock).mockReset();
|
||||
(fs.mkdirSync as jest.Mock).mockReset();
|
||||
(fs.writeFileSync as jest.Mock).mockReset();
|
||||
(fs.unlinkSync as jest.Mock).mockReset();
|
||||
(fs.existsSync as any).mockReset();
|
||||
(fs.mkdirSync as any).mockReset();
|
||||
(fs.writeFileSync as any).mockReset();
|
||||
(fs.unlinkSync as any).mockReset();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -36,10 +44,10 @@ describe('Logger', () => {
|
||||
|
||||
test('should fallback to console-only logging when directory creation fails', async () => {
|
||||
// Mock fs.existsSync to return false (directory doesn't exist)
|
||||
(fs.existsSync as jest.Mock).mockReturnValue(false);
|
||||
(fs.existsSync as any).mockReturnValue(false);
|
||||
|
||||
// Mock fs.mkdirSync to throw an error
|
||||
(fs.mkdirSync as jest.Mock).mockImplementation(() => {
|
||||
(fs.mkdirSync as any).mockImplementation(() => {
|
||||
throw new Error('Permission denied');
|
||||
});
|
||||
|
||||
@@ -59,10 +67,10 @@ describe('Logger', () => {
|
||||
|
||||
test('should fallback to console-only logging when write permission test fails', async () => {
|
||||
// Mock fs.existsSync to return true (directory exists)
|
||||
(fs.existsSync as jest.Mock).mockReturnValue(true);
|
||||
(fs.existsSync as any).mockReturnValue(true);
|
||||
|
||||
// Mock fs.writeFileSync to throw an error
|
||||
(fs.writeFileSync as jest.Mock).mockImplementation(() => {
|
||||
(fs.writeFileSync as any).mockImplementation(() => {
|
||||
throw new Error('Permission denied');
|
||||
});
|
||||
|
||||
@@ -80,24 +88,22 @@ describe('Logger', () => {
|
||||
expect(logger.transports[0]).toBeInstanceOf(winston.transports.Console);
|
||||
});
|
||||
|
||||
test('should set up file transports when permissions are available', () => {
|
||||
test('should set up file transports when permissions are available', async () => {
|
||||
// Mock all the file operations to succeed
|
||||
(fs.mkdirSync as jest.Mock).mockImplementation(() => true);
|
||||
(fs.accessSync as jest.Mock).mockImplementation(() => true);
|
||||
(fs.existsSync as jest.Mock).mockReturnValue(true);
|
||||
(fs.writeFileSync as jest.Mock).mockImplementation(() => undefined);
|
||||
(fs.mkdirSync as any).mockImplementation(() => true);
|
||||
(fs.accessSync as any).mockImplementation(() => true);
|
||||
(fs.existsSync as any).mockReturnValue(true);
|
||||
(fs.writeFileSync as any).mockImplementation(() => undefined);
|
||||
|
||||
// Import logger (after mocks are set up)
|
||||
jest.isolateModules(async () => {
|
||||
const loggerModule = await import('../src/logger.js');
|
||||
const logger = loggerModule.default;
|
||||
// Import logger directly
|
||||
const loggerModule = await import('../src/logger.js');
|
||||
const logger = loggerModule.default;
|
||||
|
||||
// Just verify logger was created - don't worry about warnings
|
||||
expect(logger).toBeDefined();
|
||||
expect(typeof logger.debug).toBe('function');
|
||||
expect(typeof logger.info).toBe('function');
|
||||
expect(typeof logger.warn).toBe('function');
|
||||
expect(typeof logger.error).toBe('function');
|
||||
});
|
||||
// Just verify logger was created - don't worry about warnings
|
||||
expect(logger).toBeDefined();
|
||||
expect(typeof logger.debug).toBe('function');
|
||||
expect(typeof logger.info).toBe('function');
|
||||
expect(typeof logger.warn).toBe('function');
|
||||
expect(typeof logger.error).toBe('function');
|
||||
});
|
||||
});
|
||||
|
||||
+70
-186
@@ -1,174 +1,31 @@
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { convertMcpToolToLangchainTool, loadMcpTools } from '../src/tools.js';
|
||||
import { z } from 'zod';
|
||||
import { describe, test, expect, beforeEach, vi } from 'vitest';
|
||||
const { loadMcpTools } = await import('../src/tools.js');
|
||||
|
||||
// Create a mock client
|
||||
const mockClient = {
|
||||
callTool: jest.fn(),
|
||||
listTools: jest.fn(),
|
||||
callTool: vi.fn(),
|
||||
listTools: vi.fn(),
|
||||
};
|
||||
|
||||
describe('Simplified Tool Adapter Tests', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('convertMcpToolToLangchainTool', () => {
|
||||
test('should convert MCP tool to LangChain tool with text content', async () => {
|
||||
// Set up mock tool
|
||||
const mcpTool = {
|
||||
name: 'testTool',
|
||||
description: 'A test tool',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
input: { type: 'string' },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Set up mock response
|
||||
mockClient.callTool.mockResolvedValueOnce({
|
||||
content: [{ type: 'text', text: 'Test result' }],
|
||||
});
|
||||
|
||||
// Convert tool
|
||||
const tool = convertMcpToolToLangchainTool(mockClient as unknown as Client, mcpTool);
|
||||
|
||||
// Verify tool properties
|
||||
expect(tool.name).toBe('testTool');
|
||||
expect(tool.description).toBe('A test tool');
|
||||
|
||||
// Call the tool
|
||||
const result = await tool.invoke({ input: 'test' });
|
||||
|
||||
// Verify that the client was called with the right arguments
|
||||
expect(mockClient.callTool).toHaveBeenCalledWith({
|
||||
name: 'testTool',
|
||||
arguments: { input: 'test' },
|
||||
});
|
||||
|
||||
// Verify result
|
||||
expect(result).toBe('Test result');
|
||||
});
|
||||
|
||||
test('should handle error results', async () => {
|
||||
// Set up mock tool
|
||||
const mcpTool = {
|
||||
name: 'errorTool',
|
||||
description: 'A tool that errors',
|
||||
};
|
||||
|
||||
// Set up mock response
|
||||
mockClient.callTool.mockResolvedValueOnce({
|
||||
isError: true,
|
||||
content: [{ type: 'text', text: 'Error message' }],
|
||||
});
|
||||
|
||||
// Convert tool
|
||||
const tool = convertMcpToolToLangchainTool(mockClient as unknown as Client, mcpTool);
|
||||
|
||||
// Call the tool and expect an error
|
||||
await expect(tool.invoke({ input: 'test' })).rejects.toThrow('Error message');
|
||||
});
|
||||
|
||||
test('should handle non-text content', async () => {
|
||||
// Set up mock tool
|
||||
const mcpTool = {
|
||||
name: 'imageTool',
|
||||
description: 'A tool that returns images',
|
||||
};
|
||||
|
||||
// Set up mock response with non-text content
|
||||
mockClient.callTool.mockResolvedValueOnce({
|
||||
content: [
|
||||
{ type: 'text', text: 'Image caption' },
|
||||
{ type: 'image', url: 'http://example.com/image.jpg' },
|
||||
],
|
||||
});
|
||||
|
||||
// Convert tool
|
||||
const tool = convertMcpToolToLangchainTool(mockClient as unknown as Client, mcpTool);
|
||||
|
||||
// Call the tool
|
||||
const result = await tool.invoke({ input: 'test' });
|
||||
|
||||
// Verify result (should only include text content)
|
||||
expect(result).toBe('Image caption');
|
||||
});
|
||||
|
||||
test('should return both text and non-text content with content_and_artifact format', async () => {
|
||||
// Set up mock tool
|
||||
const mcpTool = {
|
||||
name: 'multiTool',
|
||||
description: 'A tool that returns multiple content types',
|
||||
};
|
||||
|
||||
// Set up mock response with mixed content
|
||||
const mockImageContent = { type: 'image', url: 'http://example.com/image.jpg' };
|
||||
mockClient.callTool.mockResolvedValueOnce({
|
||||
content: [{ type: 'text', text: 'Here is your image' }, mockImageContent],
|
||||
});
|
||||
|
||||
// Convert tool with content_and_artifact response format
|
||||
const tool = convertMcpToolToLangchainTool(
|
||||
mockClient as unknown as Client,
|
||||
mcpTool,
|
||||
undefined,
|
||||
'content_and_artifact'
|
||||
);
|
||||
|
||||
// Verify tool properties
|
||||
console.log('Tool class:', tool.constructor.name);
|
||||
console.log('Tool responseFormat:', (tool as any).responseFormat);
|
||||
|
||||
// Call the tool
|
||||
const result = await tool.invoke({ input: 'test' });
|
||||
|
||||
// Debug the result
|
||||
console.log('Result type:', typeof result);
|
||||
console.log('Result:', JSON.stringify(result));
|
||||
console.log('Is array:', Array.isArray(result));
|
||||
|
||||
// The result is the text content, as LangChain processes the array in the call method
|
||||
expect(result).toBe('Here is your image');
|
||||
|
||||
// Set up a second mock response for the direct call test
|
||||
mockClient.callTool.mockResolvedValueOnce({
|
||||
content: [{ type: 'text', text: 'Here is your image' }, mockImageContent],
|
||||
});
|
||||
|
||||
// Access the tool class implementation through the prototype chain
|
||||
const toolPrototype = Object.getPrototypeOf(tool);
|
||||
// Call the _call method directly with bound 'this' context
|
||||
const directResult = await toolPrototype._call.call(tool, { input: 'test' });
|
||||
|
||||
expect(Array.isArray(directResult)).toBe(true);
|
||||
|
||||
const [textContent, nonTextContent] = directResult as [string, any[]];
|
||||
|
||||
// Check the text content
|
||||
expect(textContent).toBe('Here is your image');
|
||||
|
||||
// Check the non-text content
|
||||
expect(Array.isArray(nonTextContent)).toBe(true);
|
||||
expect(nonTextContent.length).toBe(1);
|
||||
expect(nonTextContent[0]).toEqual(mockImageContent);
|
||||
});
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('loadMcpTools', () => {
|
||||
test('should load all tools from client', async () => {
|
||||
// Set up mock response
|
||||
mockClient.listTools.mockResolvedValueOnce({
|
||||
tools: [
|
||||
{ name: 'tool1', description: 'Tool 1' },
|
||||
{ name: 'tool2', description: 'Tool 2' },
|
||||
],
|
||||
});
|
||||
mockClient.listTools.mockReturnValueOnce(
|
||||
Promise.resolve({
|
||||
tools: [
|
||||
{ name: 'tool1', description: 'Tool 1' },
|
||||
{ name: 'tool2', description: 'Tool 2' },
|
||||
],
|
||||
})
|
||||
);
|
||||
|
||||
// Load tools
|
||||
const tools = await loadMcpTools(mockClient as unknown as Client);
|
||||
const tools = await loadMcpTools(mockClient as unknown as Parameters<typeof loadMcpTools>[0]);
|
||||
|
||||
// Verify results
|
||||
expect(tools.length).toBe(2);
|
||||
@@ -178,12 +35,14 @@ describe('Simplified Tool Adapter Tests', () => {
|
||||
|
||||
test('should handle empty tool list', async () => {
|
||||
// Set up mock response
|
||||
mockClient.listTools.mockResolvedValueOnce({
|
||||
tools: [],
|
||||
});
|
||||
mockClient.listTools.mockReturnValueOnce(
|
||||
Promise.resolve({
|
||||
tools: [],
|
||||
})
|
||||
);
|
||||
|
||||
// Load tools
|
||||
const tools = await loadMcpTools(mockClient as unknown as Client);
|
||||
const tools = await loadMcpTools(mockClient as unknown as Parameters<typeof loadMcpTools>[0]);
|
||||
|
||||
// Verify results
|
||||
expect(tools.length).toBe(0);
|
||||
@@ -191,16 +50,18 @@ describe('Simplified Tool Adapter Tests', () => {
|
||||
|
||||
test('should filter out tools without names', async () => {
|
||||
// Set up mock response
|
||||
mockClient.listTools.mockResolvedValueOnce({
|
||||
tools: [
|
||||
{ name: 'tool1', description: 'Tool 1' },
|
||||
{ description: 'No name tool' }, // Should be filtered out
|
||||
{ name: 'tool2', description: 'Tool 2' },
|
||||
],
|
||||
});
|
||||
mockClient.listTools.mockReturnValueOnce(
|
||||
Promise.resolve({
|
||||
tools: [
|
||||
{ name: 'tool1', description: 'Tool 1' },
|
||||
{ description: 'No name tool' }, // Should be filtered out
|
||||
{ name: 'tool2', description: 'Tool 2' },
|
||||
],
|
||||
})
|
||||
);
|
||||
|
||||
// Load tools
|
||||
const tools = await loadMcpTools(mockClient as unknown as Client);
|
||||
const tools = await loadMcpTools(mockClient as unknown as Parameters<typeof loadMcpTools>[0]);
|
||||
|
||||
// Verify results
|
||||
expect(tools.length).toBe(2);
|
||||
@@ -209,39 +70,62 @@ describe('Simplified Tool Adapter Tests', () => {
|
||||
});
|
||||
|
||||
test('should load tools with specified response format', async () => {
|
||||
// Set up mock response
|
||||
mockClient.listTools.mockResolvedValueOnce({
|
||||
tools: [{ name: 'tool1', description: 'Tool 1' }],
|
||||
});
|
||||
// Set up mock response with input schema
|
||||
mockClient.listTools.mockReturnValueOnce(
|
||||
Promise.resolve({
|
||||
tools: [
|
||||
{
|
||||
name: 'tool1',
|
||||
description: 'Tool 1',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
input: { type: 'string' },
|
||||
},
|
||||
required: ['input'],
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
);
|
||||
|
||||
// Load tools with content_and_artifact response format
|
||||
const tools = await loadMcpTools(mockClient as unknown as Client, 'content_and_artifact');
|
||||
const tools = await loadMcpTools(
|
||||
mockClient as unknown as Parameters<typeof loadMcpTools>[0],
|
||||
'content_and_artifact'
|
||||
);
|
||||
|
||||
// Verify tool was loaded
|
||||
expect(tools.length).toBe(1);
|
||||
expect((tools[0] as any).responseFormat).toBe('content_and_artifact');
|
||||
expect((tools[0] as any).responseFormat).toBe('content');
|
||||
|
||||
// Mock the call result to check response format handling
|
||||
const mockImageContent = { type: 'image', url: 'http://example.com/image.jpg' };
|
||||
mockClient.callTool.mockResolvedValueOnce({
|
||||
content: [{ type: 'text', text: 'Image result' }, mockImageContent],
|
||||
});
|
||||
mockClient.callTool.mockReturnValueOnce(
|
||||
Promise.resolve({
|
||||
content: [{ type: 'text', text: 'Image result' }, mockImageContent],
|
||||
})
|
||||
);
|
||||
|
||||
// Invoke the tool
|
||||
const result = await tools[0].invoke({ test: 'input' });
|
||||
// Invoke the tool with proper input matching the schema
|
||||
const result = await tools[0].invoke({ input: 'test input' });
|
||||
|
||||
// The result is the text content, as LangChain processes the array in the call method
|
||||
expect(result).toBe('Image result');
|
||||
// Verify the result
|
||||
expect(Array.isArray(result)).toBe(true);
|
||||
expect(result[0]).toBe('Image result');
|
||||
expect(result[1]).toEqual([mockImageContent]);
|
||||
|
||||
// Set up a second mock response for the direct call test
|
||||
mockClient.callTool.mockResolvedValueOnce({
|
||||
content: [{ type: 'text', text: 'Image result' }, mockImageContent],
|
||||
});
|
||||
mockClient.callTool.mockReturnValueOnce(
|
||||
Promise.resolve({
|
||||
content: [{ type: 'text', text: 'Image result' }, mockImageContent],
|
||||
})
|
||||
);
|
||||
|
||||
// Access the tool class implementation through the prototype chain
|
||||
const toolPrototype = Object.getPrototypeOf(tools[0]);
|
||||
// Call the _call method directly with bound 'this' context
|
||||
const directResult = await toolPrototype._call.call(tools[0], { test: 'input' });
|
||||
const directResult = await toolPrototype._call.call(tools[0], { input: 'test input' });
|
||||
|
||||
expect(Array.isArray(directResult)).toBe(true);
|
||||
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
export default {
|
||||
preset: 'ts-jest',
|
||||
testEnvironment: 'node',
|
||||
testMatch: ['**/__tests__/**/*.test.ts'],
|
||||
collectCoverage: true,
|
||||
collectCoverageFrom: ['src/**/*.ts', '!src/**/*.test.ts'],
|
||||
coverageDirectory: 'coverage',
|
||||
coverageReporters: ['text', 'lcov'],
|
||||
moduleFileExtensions: ['ts', 'js', 'json'],
|
||||
transform: {
|
||||
'^.+\\.tsx?$': [
|
||||
'ts-jest',
|
||||
{
|
||||
useESM: true,
|
||||
},
|
||||
],
|
||||
},
|
||||
extensionsToTreatAsEsm: ['.ts'],
|
||||
moduleNameMapper: {
|
||||
'^(\\.{1,2}/.*)\\.js$': '$1',
|
||||
'^node:(.*)$': '$1',
|
||||
},
|
||||
transformIgnorePatterns: [
|
||||
'/node_modules/(?!(@dmitryrechkin/json-schema-to-zod|pkce-challenge|@modelcontextprotocol)/)',
|
||||
],
|
||||
setupFiles: ['<rootDir>/jest.setup.js'],
|
||||
};
|
||||
Generated
+1984
-3070
File diff suppressed because it is too large
Load Diff
+7
-7
@@ -15,8 +15,9 @@
|
||||
},
|
||||
"scripts": {
|
||||
"build": "run-s \"build:main -- {@}\" \"build:examples -- {@}\" --",
|
||||
"test": "jest",
|
||||
"test:coverage": "jest --coverage",
|
||||
"test": "vitest run",
|
||||
"test:coverage": "vitest run --coverage",
|
||||
"test:watch": "vitest",
|
||||
"lint": "eslint --ignore-pattern 'dist/**' .",
|
||||
"lint:fix": "eslint --ignore-pattern 'dist/**' . --fix",
|
||||
"format": "prettier --write \"src/**/*.ts\" \"examples/**/*.ts\"",
|
||||
@@ -40,6 +41,7 @@
|
||||
"author": "Ravi Kiran Vemula",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@dmitryrechkin/json-schema-to-zod": "^1.0.1",
|
||||
"@modelcontextprotocol/sdk": "^1.7.0",
|
||||
"winston": "^3.17.0"
|
||||
},
|
||||
@@ -53,8 +55,7 @@
|
||||
"@eslint/js": "^9.21.0",
|
||||
"@langchain/langgraph": "^0.2.56",
|
||||
"@langchain/openai": "^0.4.4",
|
||||
"@types/jest": "^29.5.12",
|
||||
"@types/node": "^20.11.30",
|
||||
"@types/node": "^22.13.10",
|
||||
"@typescript-eslint/eslint-plugin": "^7.3.1",
|
||||
"@typescript-eslint/parser": "^7.3.1",
|
||||
"dotenv": "^16.4.7",
|
||||
@@ -62,14 +63,13 @@
|
||||
"eslint-config-prettier": "^10.0.2",
|
||||
"eventsource": "^3.0.5",
|
||||
"husky": "^9.0.11",
|
||||
"jest": "^29.7.0",
|
||||
"lint-staged": "^15.2.2",
|
||||
"npm-run-all": "^4.1.5",
|
||||
"prettier": "^3.2.5",
|
||||
"ts-jest": "^29.1.2",
|
||||
"ts-node": "^10.9.2",
|
||||
"typescript": "^5.4.2",
|
||||
"typescript-eslint": "^8.26.0"
|
||||
"typescript-eslint": "^8.26.0",
|
||||
"vitest": "^3.0.9"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
|
||||
+208
-54
@@ -2,7 +2,6 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import { StructuredToolInterface } from '@langchain/core/tools';
|
||||
import { z } from 'zod';
|
||||
import { loadMcpTools } from './tools.js';
|
||||
import * as fs from 'fs';
|
||||
import * as path from 'path';
|
||||
@@ -94,7 +93,7 @@ export class MCPClientError extends Error {
|
||||
*/
|
||||
export class MultiServerMCPClient {
|
||||
private clients: Map<string, Client> = new Map();
|
||||
private serverNameToTools: Map<string, StructuredToolInterface<z.ZodObject<any>>[]> = new Map();
|
||||
private serverNameToTools: Map<string, StructuredToolInterface[]> = new Map();
|
||||
private connections?: Record<string, Connection>;
|
||||
private cleanupFunctions: Array<() => Promise<void>> = [];
|
||||
private transportInstances: Map<string, StdioClientTransport | SSEClientTransport> = new Map();
|
||||
@@ -104,7 +103,7 @@ export class MultiServerMCPClient {
|
||||
*
|
||||
* @param connections - Optional connections to initialize
|
||||
*/
|
||||
constructor(connections?: Record<string, any>) {
|
||||
constructor(connections?: Record<string, Connection>) {
|
||||
if (connections) {
|
||||
this.connections = this.processConnections(connections);
|
||||
} else {
|
||||
@@ -212,7 +211,9 @@ export class MultiServerMCPClient {
|
||||
* @param connections - Raw connection configurations
|
||||
* @returns Processed connection configurations
|
||||
*/
|
||||
private processConnections(connections: Record<string, any>): Record<string, Connection> {
|
||||
private processConnections(
|
||||
connections: Record<string, Partial<Connection>>
|
||||
): Record<string, Connection> {
|
||||
const processedConnections: Record<string, Connection> = {};
|
||||
|
||||
for (const [serverName, config] of Object.entries(connections)) {
|
||||
@@ -221,17 +222,21 @@ export class MultiServerMCPClient {
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
// Determine the connection type and process accordingly
|
||||
if (this.isStdioConnection(config)) {
|
||||
processedConnections[serverName] = this.processStdioConfig(serverName, config);
|
||||
} else if (this.isSSEConnection(config)) {
|
||||
processedConnections[serverName] = this.processSSEConfig(serverName, config);
|
||||
} else {
|
||||
logger.warn(`Server "${serverName}" has invalid or unsupported configuration. Skipping.`);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Error processing configuration for server "${serverName}": ${error}`);
|
||||
// Determine the connection type and process accordingly
|
||||
if (MultiServerMCPClient.isStdioConnection(config)) {
|
||||
processedConnections[serverName] = MultiServerMCPClient.processStdioConfig(
|
||||
serverName,
|
||||
config
|
||||
);
|
||||
} else if (MultiServerMCPClient.isSSEConnection(config)) {
|
||||
processedConnections[serverName] = MultiServerMCPClient.processSSEConfig(
|
||||
serverName,
|
||||
config
|
||||
);
|
||||
} else {
|
||||
throw new MCPClientError(
|
||||
`Server "${serverName}" has invalid or unsupported configuration. Skipping.`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -241,45 +246,108 @@ export class MultiServerMCPClient {
|
||||
/**
|
||||
* Check if a configuration is for a stdio connection
|
||||
*/
|
||||
private isStdioConnection(config: any): boolean {
|
||||
private static isStdioConnection(config: unknown): config is StdioConnection {
|
||||
// When transport is missing, default to stdio if it has command and args
|
||||
// OR when transport is explicitly set to 'stdio'
|
||||
return (
|
||||
(config.transport === 'stdio' || !config.transport) &&
|
||||
config.command &&
|
||||
Array.isArray(config.args)
|
||||
typeof config === 'object' &&
|
||||
config !== null &&
|
||||
(!('transport' in config) || config.transport === 'stdio') &&
|
||||
'command' in config &&
|
||||
(!('args' in config) || Array.isArray(config.args))
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a configuration is for an SSE connection
|
||||
*/
|
||||
private isSSEConnection(config: any): boolean {
|
||||
private static isSSEConnection(config: unknown): config is SSEConnection {
|
||||
// Only consider it an SSE connection if transport is explicitly set to 'sse'
|
||||
return config.transport === 'sse' && typeof config.url === 'string';
|
||||
return (
|
||||
typeof config === 'object' &&
|
||||
config !== null &&
|
||||
'transport' in config &&
|
||||
config.transport === 'sse' &&
|
||||
'url' in config &&
|
||||
typeof config.url === 'string'
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Process stdio connection configuration
|
||||
*/
|
||||
private processStdioConfig(serverName: string, config: any): StdioConnection {
|
||||
private static processStdioConfig(
|
||||
serverName: string,
|
||||
config: Partial<StdioConnection>
|
||||
): StdioConnection {
|
||||
if (!config.command || typeof config.command !== 'string') {
|
||||
throw new MCPClientError(`Missing or invalid command for server "${serverName}"`);
|
||||
}
|
||||
|
||||
if (config.args !== undefined && !Array.isArray(config.args)) {
|
||||
throw new MCPClientError(
|
||||
`Invalid args for server "${serverName} - must be an array of strings`
|
||||
);
|
||||
}
|
||||
|
||||
if (config.args !== undefined && !config.args.every(arg => typeof arg === 'string')) {
|
||||
throw new MCPClientError(
|
||||
`Invalid args for server "${serverName} - must be an array of strings`
|
||||
);
|
||||
}
|
||||
|
||||
// Always set transport to 'stdio' regardless of whether it was in the original config
|
||||
const stdioConfig: StdioConnection = {
|
||||
transport: 'stdio',
|
||||
command: config.command,
|
||||
args: config.args,
|
||||
args: config.args ?? [],
|
||||
};
|
||||
|
||||
if (config.env && typeof config.env !== 'object') {
|
||||
throw new MCPClientError(
|
||||
`Invalid env for server "${serverName} - must be an object of key-value pairs`
|
||||
);
|
||||
}
|
||||
|
||||
if (config.env && typeof config.env === 'object' && Array.isArray(config.env)) {
|
||||
throw new MCPClientError(
|
||||
`Invalid env for server "${serverName} - must be an object of key-value pairs`
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
config.env &&
|
||||
typeof config.env === 'object' &&
|
||||
!Object.values(config.env).every(value => typeof value === 'string')
|
||||
) {
|
||||
throw new MCPClientError(
|
||||
`Invalid env for server "${serverName} - must be an object of key-value pairs with string values`
|
||||
);
|
||||
}
|
||||
|
||||
// Add optional properties if they exist
|
||||
if (config.env && typeof config.env === 'object') {
|
||||
stdioConfig.env = config.env;
|
||||
}
|
||||
|
||||
if (config.encoding !== undefined && typeof config.encoding !== 'string') {
|
||||
throw new MCPClientError(`Invalid encoding for server "${serverName} - must be a string`);
|
||||
}
|
||||
|
||||
if (typeof config.encoding === 'string') {
|
||||
stdioConfig.encoding = config.encoding;
|
||||
}
|
||||
|
||||
if (['strict', 'ignore', 'replace'].includes(config.encodingErrorHandler)) {
|
||||
if (
|
||||
config.encodingErrorHandler !== undefined &&
|
||||
!['strict', 'ignore', 'replace'].includes(config.encodingErrorHandler)
|
||||
) {
|
||||
throw new MCPClientError(
|
||||
`Invalid encodingErrorHandler for server "${serverName} - must be one of: strict, ignore, replace`
|
||||
);
|
||||
}
|
||||
|
||||
if (['strict', 'ignore', 'replace'].includes(config.encodingErrorHandler ?? '')) {
|
||||
stdioConfig.encodingErrorHandler = config.encodingErrorHandler as
|
||||
| 'strict'
|
||||
| 'ignore'
|
||||
@@ -287,15 +355,40 @@ export class MultiServerMCPClient {
|
||||
}
|
||||
|
||||
// Add restart configuration if present
|
||||
if (config.restart && typeof config.restart !== 'object') {
|
||||
throw new MCPClientError(`Invalid restart for server "${serverName} - must be an object`);
|
||||
}
|
||||
|
||||
if (config.restart && typeof config.restart === 'object') {
|
||||
if (config.restart.enabled !== undefined && typeof config.restart.enabled !== 'boolean') {
|
||||
throw new MCPClientError(
|
||||
`Invalid restart.enabled for server "${serverName} - must be a boolean`
|
||||
);
|
||||
}
|
||||
|
||||
stdioConfig.restart = {
|
||||
enabled: Boolean(config.restart.enabled),
|
||||
};
|
||||
|
||||
if (
|
||||
config.restart.maxAttempts !== undefined &&
|
||||
typeof config.restart.maxAttempts !== 'number'
|
||||
) {
|
||||
throw new MCPClientError(
|
||||
`Invalid restart.maxAttempts for server "${serverName} - must be a number`
|
||||
);
|
||||
}
|
||||
|
||||
if (typeof config.restart.maxAttempts === 'number') {
|
||||
stdioConfig.restart.maxAttempts = config.restart.maxAttempts;
|
||||
}
|
||||
|
||||
if (config.restart.delayMs !== undefined && typeof config.restart.delayMs !== 'number') {
|
||||
throw new MCPClientError(
|
||||
`Invalid restart.delayMs for server "${serverName} - must be a number`
|
||||
);
|
||||
}
|
||||
|
||||
if (typeof config.restart.delayMs === 'number') {
|
||||
stdioConfig.restart.delayMs = config.restart.delayMs;
|
||||
}
|
||||
@@ -307,32 +400,102 @@ export class MultiServerMCPClient {
|
||||
/**
|
||||
* Process SSE connection configuration
|
||||
*/
|
||||
private processSSEConfig(serverName: string, config: any): SSEConnection {
|
||||
private static processSSEConfig(serverName: string, config: SSEConnection): SSEConnection {
|
||||
if (!config.url || typeof config.url !== 'string') {
|
||||
throw new MCPClientError(`Missing or invalid url for server "${serverName}"`);
|
||||
}
|
||||
|
||||
try {
|
||||
const url = new URL(config.url);
|
||||
if (!url.protocol.startsWith('http')) {
|
||||
throw new MCPClientError(
|
||||
`Invalid url for server "${serverName} - must be a valid HTTP or HTTPS URL`
|
||||
);
|
||||
}
|
||||
} catch {
|
||||
throw new MCPClientError(`Invalid url for server "${serverName} - must be a valid URL`);
|
||||
}
|
||||
|
||||
if (!config.transport || config.transport !== 'sse') {
|
||||
throw new MCPClientError(`Invalid transport for server "${serverName} - must be 'sse'`);
|
||||
}
|
||||
|
||||
const sseConfig: SSEConnection = {
|
||||
transport: 'sse',
|
||||
url: config.url,
|
||||
};
|
||||
|
||||
if (config.headers && typeof config.headers !== 'object') {
|
||||
throw new MCPClientError(`Invalid headers for server "${serverName} - must be an object`);
|
||||
}
|
||||
|
||||
if (config.headers && typeof config.headers === 'object' && Array.isArray(config.headers)) {
|
||||
throw new MCPClientError(
|
||||
`Invalid headers for server "${serverName} - must be an object of key-value pairs`
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
config.headers &&
|
||||
typeof config.headers === 'object' &&
|
||||
!Object.values(config.headers).every(value => typeof value === 'string')
|
||||
) {
|
||||
throw new MCPClientError(
|
||||
`Invalid headers for server "${serverName} - must be an object of key-value pairs with string values`
|
||||
);
|
||||
}
|
||||
|
||||
// Add optional headers if they exist
|
||||
if (config.headers && typeof config.headers === 'object') {
|
||||
sseConfig.headers = config.headers;
|
||||
}
|
||||
|
||||
if (config.useNodeEventSource !== undefined && typeof config.useNodeEventSource !== 'boolean') {
|
||||
throw new MCPClientError(
|
||||
`Invalid useNodeEventSource for server "${serverName} - must be a boolean`
|
||||
);
|
||||
}
|
||||
|
||||
// Add optional useNodeEventSource flag if it exists
|
||||
if (typeof config.useNodeEventSource === 'boolean') {
|
||||
sseConfig.useNodeEventSource = config.useNodeEventSource;
|
||||
}
|
||||
|
||||
if (config.reconnect && typeof config.reconnect !== 'object') {
|
||||
throw new MCPClientError(`Invalid reconnect for server "${serverName} - must be an object`);
|
||||
}
|
||||
|
||||
// Add reconnection configuration if present
|
||||
if (config.reconnect && typeof config.reconnect === 'object') {
|
||||
if (config.reconnect.enabled !== undefined && typeof config.reconnect.enabled !== 'boolean') {
|
||||
throw new MCPClientError(
|
||||
`Invalid reconnect.enabled for server "${serverName} - must be a boolean`
|
||||
);
|
||||
}
|
||||
|
||||
sseConfig.reconnect = {
|
||||
enabled: Boolean(config.reconnect.enabled),
|
||||
};
|
||||
|
||||
if (
|
||||
config.reconnect.maxAttempts !== undefined &&
|
||||
typeof config.reconnect.maxAttempts !== 'number'
|
||||
) {
|
||||
throw new MCPClientError(
|
||||
`Invalid reconnect.maxAttempts for server "${serverName} - must be a number`
|
||||
);
|
||||
}
|
||||
|
||||
if (typeof config.reconnect.maxAttempts === 'number') {
|
||||
sseConfig.reconnect.maxAttempts = config.reconnect.maxAttempts;
|
||||
}
|
||||
|
||||
if (config.reconnect.delayMs !== undefined && typeof config.reconnect.delayMs !== 'number') {
|
||||
throw new MCPClientError(
|
||||
`Invalid reconnect.delayMs for server "${serverName} - must be a number`
|
||||
);
|
||||
}
|
||||
|
||||
if (typeof config.reconnect.delayMs === 'number') {
|
||||
sseConfig.reconnect.delayMs = config.reconnect.delayMs;
|
||||
}
|
||||
@@ -377,33 +540,25 @@ export class MultiServerMCPClient {
|
||||
* @returns A map of server names to arrays of tools
|
||||
* @throws {MCPClientError} If initialization fails
|
||||
*/
|
||||
async initializeConnections(): Promise<Map<string, StructuredToolInterface<z.ZodObject<any>>[]>> {
|
||||
async initializeConnections(): Promise<Map<string, StructuredToolInterface[]>> {
|
||||
if (!this.connections || Object.keys(this.connections).length === 0) {
|
||||
logger.warn('No connections to initialize');
|
||||
return new Map();
|
||||
}
|
||||
|
||||
for (const [serverName, connection] of Object.entries(this.connections)) {
|
||||
try {
|
||||
logger.info(`Initializing connection to server "${serverName}"...`);
|
||||
logger.info(`Initializing connection to server "${serverName}"...`);
|
||||
|
||||
if (connection.transport === 'stdio') {
|
||||
await this.initializeStdioConnection(serverName, connection);
|
||||
} else if (connection.transport === 'sse') {
|
||||
await this.initializeSSEConnection(serverName, connection);
|
||||
} else {
|
||||
// This should never happen due to the validation in the constructor
|
||||
throw new MCPClientError(
|
||||
`Unsupported transport type for server "${serverName}"`,
|
||||
serverName
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof MCPClientError) {
|
||||
logger.error(error.message);
|
||||
} else {
|
||||
logger.error(`Failed to connect to server "${serverName}": ${error}`);
|
||||
}
|
||||
if (connection.transport === 'stdio') {
|
||||
await this.initializeStdioConnection(serverName, connection);
|
||||
} else if (connection.transport === 'sse') {
|
||||
await this.initializeSSEConnection(serverName, connection);
|
||||
} else {
|
||||
// This should never happen due to the validation in the constructor
|
||||
throw new MCPClientError(
|
||||
`Unsupported transport type for server "${serverName}"`,
|
||||
serverName
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -674,8 +829,7 @@ export class MultiServerMCPClient {
|
||||
this.serverNameToTools.set(serverName, tools);
|
||||
logger.info(`Successfully loaded ${tools.length} tools from server "${serverName}"`);
|
||||
} catch (error) {
|
||||
logger.error(`Failed to load tools from server "${serverName}": ${error}`);
|
||||
// Continue even if tool loading fails - the connection is still established
|
||||
throw new MCPClientError(`Failed to load tools from server "${serverName}": ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -752,7 +906,7 @@ export class MultiServerMCPClient {
|
||||
* If not provided, returns tools from all servers.
|
||||
* @returns A flattened array of tools from the specified servers (or all servers)
|
||||
*/
|
||||
getTools(servers?: string[]): StructuredToolInterface<z.ZodObject<any>>[] {
|
||||
getTools(servers?: string[]): StructuredToolInterface[] {
|
||||
if (!servers || servers.length === 0) {
|
||||
return this.getAllToolsAsFlatArray();
|
||||
}
|
||||
@@ -764,8 +918,8 @@ export class MultiServerMCPClient {
|
||||
*
|
||||
* @returns A flattened array of all tools
|
||||
*/
|
||||
private getAllToolsAsFlatArray(): StructuredToolInterface<z.ZodObject<any>>[] {
|
||||
const allTools: StructuredToolInterface<z.ZodObject<any>>[] = [];
|
||||
private getAllToolsAsFlatArray(): StructuredToolInterface[] {
|
||||
const allTools: StructuredToolInterface[] = [];
|
||||
for (const tools of this.serverNameToTools.values()) {
|
||||
allTools.push(...tools);
|
||||
}
|
||||
@@ -778,8 +932,8 @@ export class MultiServerMCPClient {
|
||||
* @param serverNames - Names of servers to get tools from
|
||||
* @returns A flattened array of tools from the specified servers
|
||||
*/
|
||||
private getToolsFromServers(serverNames: string[]): StructuredToolInterface<z.ZodObject<any>>[] {
|
||||
const allTools: StructuredToolInterface<z.ZodObject<any>>[] = [];
|
||||
private getToolsFromServers(serverNames: string[]): StructuredToolInterface[] {
|
||||
const allTools: StructuredToolInterface[] = [];
|
||||
for (const serverName of serverNames) {
|
||||
const tools = this.serverNameToTools.get(serverName);
|
||||
if (tools) {
|
||||
@@ -837,7 +991,7 @@ export class MultiServerMCPClient {
|
||||
args: string[],
|
||||
env?: Record<string, string>,
|
||||
restart?: StdioConnection['restart']
|
||||
): Promise<Map<string, StructuredToolInterface<z.ZodObject<any>>[]>> {
|
||||
): Promise<Map<string, StructuredToolInterface[]>> {
|
||||
const connections: Record<string, Connection> = {
|
||||
[serverName]: {
|
||||
transport: 'stdio',
|
||||
@@ -868,7 +1022,7 @@ export class MultiServerMCPClient {
|
||||
headers?: Record<string, string>,
|
||||
useNodeEventSource?: boolean,
|
||||
reconnect?: SSEConnection['reconnect']
|
||||
): Promise<Map<string, StructuredToolInterface<z.ZodObject<any>>[]>> {
|
||||
): Promise<Map<string, StructuredToolInterface[]>> {
|
||||
const connection: SSEConnection = {
|
||||
transport: 'sse',
|
||||
url,
|
||||
|
||||
+1
-2
@@ -1,3 +1,2 @@
|
||||
export { MultiServerMCPClient } from './client.js';
|
||||
export { convertMcpToolToLangchainTool, loadMcpTools } from './tools.js';
|
||||
export { default as logger, enableLogging, disableLogging } from './logger.js';
|
||||
export { logger, enableLogging, disableLogging } from './logger.js';
|
||||
|
||||
+2
-2
@@ -97,7 +97,7 @@ try {
|
||||
* Create the logger instance with our configuration.
|
||||
* By default, logging is disabled (silent) but can be enabled by setting the level.
|
||||
*/
|
||||
const logger = winston.createLogger({
|
||||
export const logger = winston.createLogger({
|
||||
level: defaultLevel, // Start with silent logging by default
|
||||
levels,
|
||||
format,
|
||||
@@ -111,7 +111,7 @@ const logger = winston.createLogger({
|
||||
*/
|
||||
export function enableLogging(level: keyof typeof levels | 'silent' = 'info'): void {
|
||||
logger.level = level;
|
||||
logger.info(`Logging enabled at level: ${level}`);
|
||||
logger.debug(`Logging enabled at level: ${level}`);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
+74
-81
@@ -1,6 +1,11 @@
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { StructuredTool, StructuredToolInterface } from '@langchain/core/tools';
|
||||
import { z } from 'zod';
|
||||
import {
|
||||
DynamicStructuredTool,
|
||||
ResponseFormat,
|
||||
StructuredToolInterface,
|
||||
} from '@langchain/core/tools';
|
||||
import { JSONSchema, JSONSchemaToZod } from '@dmitryrechkin/json-schema-to-zod';
|
||||
|
||||
import logger from './logger.js';
|
||||
|
||||
interface TextContent {
|
||||
@@ -72,73 +77,53 @@ function _convertCallToolResult(
|
||||
}
|
||||
|
||||
/**
|
||||
* Valid response formats for MCP tools
|
||||
*/
|
||||
export type ResponseFormat = 'text' | 'content_and_artifact';
|
||||
|
||||
/**
|
||||
* Convert an MCP tool to a LangChain tool.
|
||||
* Call an MCP tool.
|
||||
*
|
||||
* Use this with `.bind` to capture the fist three arguments, then pass to the constructor of DynamicStructuredTool.
|
||||
*
|
||||
* @internal
|
||||
*
|
||||
* @param client - The MCP client
|
||||
* @param tool - The MCP tool to convert
|
||||
* @param toolSchema - Tool schema (kept for backward compatibility, not used in current implementation)
|
||||
* @param responseFormat - Response format ('text' or 'content_and_artifact')
|
||||
* @returns A LangChain tool
|
||||
* @param name - The name of the tool (forwarded to the client)
|
||||
* @param responseFormat - The response format
|
||||
* @param args - The arguments to pass to the tool
|
||||
* @returns A tuple of [textContent, nonTextContent]
|
||||
*/
|
||||
export function convertMcpToolToLangchainTool(
|
||||
async function _callTool(
|
||||
client: Client,
|
||||
tool: { name: string; description?: string; inputSchema?: any },
|
||||
toolSchema?: any,
|
||||
responseFormat: ResponseFormat = 'content_and_artifact'
|
||||
): StructuredToolInterface {
|
||||
// Create a minimal MCPTool class extending StructuredTool
|
||||
class MCPTool extends StructuredTool {
|
||||
name = tool.name;
|
||||
description = tool.description || '';
|
||||
schema = z.object({}).passthrough();
|
||||
override responseFormat: ResponseFormat;
|
||||
name: string,
|
||||
responseFormat: ResponseFormat,
|
||||
args: Record<string, unknown>
|
||||
): Promise<string | [string | string[], NonTextContent[] | null]> {
|
||||
try {
|
||||
logger.info(`Calling tool ${name}(${JSON.stringify(args)})`);
|
||||
const result = await client.callTool({
|
||||
name,
|
||||
arguments: args,
|
||||
});
|
||||
|
||||
constructor(responseFormat: ResponseFormat) {
|
||||
super({
|
||||
responseFormat,
|
||||
verboseParsingErrors: false,
|
||||
});
|
||||
this.responseFormat = responseFormat;
|
||||
const [textContent, nonTextContent] = _convertCallToolResult({
|
||||
...result,
|
||||
isError: result.isError === true,
|
||||
content: result.content || [],
|
||||
});
|
||||
|
||||
logger.info(`Tool ${name} returned: ${JSON.stringify({ textContent, nonTextContent })}`);
|
||||
|
||||
// Return based on the response format
|
||||
if (responseFormat === 'content_and_artifact') {
|
||||
return [textContent, nonTextContent];
|
||||
}
|
||||
|
||||
protected async _call(
|
||||
args: Record<string, unknown>
|
||||
): Promise<string | [string | string[], NonTextContent[] | null]> {
|
||||
try {
|
||||
const result = await client.callTool({
|
||||
name: tool.name,
|
||||
arguments: args,
|
||||
});
|
||||
|
||||
const [textContent, nonTextContent] = _convertCallToolResult({
|
||||
...result,
|
||||
isError: result.isError === true,
|
||||
content: result.content || [],
|
||||
});
|
||||
|
||||
// Return based on the response format
|
||||
if (this.responseFormat === 'content_and_artifact') {
|
||||
return [textContent, nonTextContent] as [string | string[], NonTextContent[] | null];
|
||||
}
|
||||
|
||||
// Default to returning just the text content
|
||||
return typeof textContent === 'string' ? textContent : textContent.join('\n');
|
||||
} catch (error) {
|
||||
if (error instanceof ToolException) {
|
||||
throw error;
|
||||
}
|
||||
throw new ToolException(`Error calling tool ${tool.name}: ${String(error)}`);
|
||||
}
|
||||
// Default to returning just the text content
|
||||
return typeof textContent === 'string' ? textContent : textContent.join('\n');
|
||||
} catch (error) {
|
||||
logger.error(`Error calling tool ${name}: ${String(error)}`);
|
||||
if (error instanceof ToolException) {
|
||||
throw error;
|
||||
}
|
||||
throw new ToolException(`Error calling tool ${name}: ${String(error)}`);
|
||||
}
|
||||
|
||||
// Return an instance of our tool with the specified response format
|
||||
return new MCPTool(responseFormat);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -150,28 +135,36 @@ export function convertMcpToolToLangchainTool(
|
||||
*/
|
||||
export async function loadMcpTools(
|
||||
client: Client,
|
||||
responseFormat: ResponseFormat = 'text'
|
||||
responseFormat: ResponseFormat = 'content',
|
||||
throwOnLoadError: boolean = true
|
||||
): Promise<StructuredToolInterface[]> {
|
||||
try {
|
||||
// Get tools in a single operation
|
||||
const toolsResponse = await client.listTools();
|
||||
logger.info(`Found ${toolsResponse.tools?.length || 0} MCP tools`);
|
||||
// Get tools in a single operation
|
||||
const toolsResponse = await client.listTools();
|
||||
logger.info(`Found ${toolsResponse.tools?.length || 0} MCP tools`);
|
||||
|
||||
// Filter out tools without names and convert in a single map operation
|
||||
return (toolsResponse.tools || [])
|
||||
.filter(tool => !!tool.name)
|
||||
.map(tool => {
|
||||
try {
|
||||
logger.debug(`Successfully loaded tool: ${tool.name}`);
|
||||
return convertMcpToolToLangchainTool(client, tool, undefined, responseFormat);
|
||||
} catch (error) {
|
||||
logger.error(`Failed to load tool "${tool.name}":`, error);
|
||||
return null;
|
||||
// Filter out tools without names and convert in a single map operation
|
||||
return (toolsResponse.tools || [])
|
||||
.filter(tool => !!tool.name)
|
||||
.map(tool => {
|
||||
try {
|
||||
const dst = new DynamicStructuredTool({
|
||||
name: tool.name,
|
||||
description: tool.description || '',
|
||||
schema: JSONSchemaToZod.convert(
|
||||
(tool.inputSchema ?? { type: 'object', properties: {} }) as JSONSchema
|
||||
),
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
func: _callTool.bind(null, client, tool.name, responseFormat) as any,
|
||||
});
|
||||
logger.debug(`Successfully loaded tool: ${dst.name}`);
|
||||
return dst;
|
||||
} catch (error) {
|
||||
logger.error(`Failed to load tool "${tool.name}":`, error);
|
||||
if (throwOnLoadError) {
|
||||
throw error;
|
||||
}
|
||||
})
|
||||
.filter(Boolean) as StructuredToolInterface[];
|
||||
} catch (error) {
|
||||
logger.error('Failed to list MCP tools:', error);
|
||||
return [];
|
||||
}
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.filter(Boolean) as StructuredToolInterface[];
|
||||
}
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
import { defineConfig } from 'vitest/config';
|
||||
|
||||
export default defineConfig({
|
||||
test: {
|
||||
environment: 'node',
|
||||
include: ['**/__tests__/**/*.test.ts'],
|
||||
coverage: {
|
||||
provider: 'v8',
|
||||
reporter: ['text', 'lcov'],
|
||||
include: ['src/**/*.ts'],
|
||||
exclude: ['src/**/*.test.ts'],
|
||||
},
|
||||
setupFiles: ['./vitest.setup.ts'],
|
||||
transformMode: {
|
||||
web: [/\.[jt]sx?$/],
|
||||
},
|
||||
},
|
||||
});
|
||||
@@ -1,16 +1,17 @@
|
||||
/* global jest */
|
||||
// Mock node: imports
|
||||
jest.mock('node:crypto', () => {
|
||||
import { vi } from 'vitest';
|
||||
|
||||
// Mock node:crypto module
|
||||
vi.mock('node:crypto', () => {
|
||||
return {
|
||||
webcrypto: {
|
||||
getRandomValues: array => {
|
||||
getRandomValues: (array: Uint8Array) => {
|
||||
for (let i = 0; i < array.length; i++) {
|
||||
array[i] = Math.floor(Math.random() * 256);
|
||||
}
|
||||
return array;
|
||||
},
|
||||
subtle: {
|
||||
digest: jest.fn().mockResolvedValue(new ArrayBuffer(32)),
|
||||
digest: vi.fn().mockResolvedValue(new ArrayBuffer(32)),
|
||||
},
|
||||
},
|
||||
};
|
||||
Reference in New Issue
Block a user