fix: pass inputSchema to LangChain tools & fail loudly (#20)

Also migrates test suite to vitest for better support for mocked modules
with ESMs.
This commit is contained in:
Ben Burns
2025-03-19 10:35:41 +13:00
committed by GitHub
parent 70a939fefc
commit 88fbdf5093
13 changed files with 2693 additions and 3772 deletions
+119 -127
View File
@@ -1,12 +1,13 @@
// Mock the problematic dependencies
jest.mock('@modelcontextprotocol/sdk/client/sse.js', () => {
import { vi, describe, test, expect, beforeEach, afterEach } from 'vitest';
// Mock the problematic dependencies using vi.mock
vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => {
// Create mock functions for all methods
const connectMock = jest.fn().mockResolvedValue(undefined);
const sendMock = jest.fn().mockResolvedValue(undefined);
const closeMock = jest.fn().mockResolvedValue(undefined);
const connectMock = vi.fn().mockReturnValue(Promise.resolve());
const sendMock = vi.fn().mockReturnValue(Promise.resolve());
const closeMock = vi.fn().mockReturnValue(Promise.resolve());
return {
SSEClientTransport: jest.fn().mockImplementation(() => ({
SSEClientTransport: vi.fn().mockImplementation(() => ({
connect: connectMock,
send: sendMock,
close: closeMock,
@@ -15,14 +16,14 @@ jest.mock('@modelcontextprotocol/sdk/client/sse.js', () => {
};
});
jest.mock('@modelcontextprotocol/sdk/client/stdio.js', () => {
vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => {
// Create mock functions for all methods
const connectMock = jest.fn().mockResolvedValue(undefined);
const sendMock = jest.fn().mockResolvedValue(undefined);
const closeMock = jest.fn().mockResolvedValue(undefined);
const connectMock = vi.fn().mockReturnValue(Promise.resolve());
const sendMock = vi.fn().mockReturnValue(Promise.resolve());
const closeMock = vi.fn().mockReturnValue(Promise.resolve());
return {
StdioClientTransport: jest.fn().mockImplementation(() => ({
StdioClientTransport: vi.fn().mockImplementation(() => ({
connect: connectMock,
send: sendMock,
close: closeMock,
@@ -31,31 +32,35 @@ jest.mock('@modelcontextprotocol/sdk/client/stdio.js', () => {
};
});
jest.mock('@modelcontextprotocol/sdk/client/index.js', () => {
vi.mock('@modelcontextprotocol/sdk/client/index.js', () => {
// Create mock functions for all methods
const connectMock = jest.fn().mockResolvedValue(undefined);
const listToolsMock = jest.fn().mockResolvedValue({
tools: [
{
name: 'testTool',
description: 'A test tool',
inputSchema: {
type: 'object',
properties: {
input: { type: 'string' },
const connectMock = vi.fn().mockReturnValue(Promise.resolve());
const listToolsMock = vi.fn().mockReturnValue(
Promise.resolve({
tools: [
{
name: 'testTool',
description: 'A test tool',
inputSchema: {
type: 'object',
properties: {
input: { type: 'string' },
},
required: ['input'],
},
required: ['input'],
},
},
],
});
const callToolMock = jest.fn().mockResolvedValue({
content: [{ type: 'text', text: 'result' }],
});
const closeMock = jest.fn().mockResolvedValue(undefined);
],
})
);
const callToolMock = vi.fn().mockReturnValue(
Promise.resolve({
content: [{ type: 'text', text: 'result' }],
})
);
const closeMock = vi.fn().mockReturnValue(Promise.resolve());
return {
Client: jest.fn().mockImplementation(() => ({
Client: vi.fn().mockImplementation(() => ({
connect: connectMock,
listTools: listToolsMock,
callTool: callToolMock,
@@ -64,84 +69,78 @@ jest.mock('@modelcontextprotocol/sdk/client/index.js', () => {
};
});
jest.mock('fs');
jest.mock('path');
vi.mock('fs', () => ({
readFileSync: vi.fn(),
}));
vi.mock('path', () => ({
resolve: vi.fn(),
}));
// Mock the logger
jest.mock('../src/logger.js', () => {
vi.mock('../src/logger.js', () => {
return {
__esModule: true,
default: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
debug: vi.fn(),
},
};
});
// Create placeholder mocks that will be replaced in beforeEach
jest.mock('@modelcontextprotocol/sdk/client/sse.js', () => ({
SSEClientTransport: jest.fn(),
vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => ({
SSEClientTransport: vi.fn(),
}));
jest.mock('@modelcontextprotocol/sdk/client/stdio.js', () => ({
StdioClientTransport: jest.fn(),
vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => ({
StdioClientTransport: vi.fn(),
}));
jest.mock('@modelcontextprotocol/sdk/client/index.js', () => ({
Client: jest.fn(),
vi.mock('@modelcontextprotocol/sdk/client/index.js', () => ({
Client: vi.fn(),
}));
/* eslint-disable @typescript-eslint/no-unused-vars */
import {
MultiServerMCPClient,
MCPClientError,
StdioConnection,
SSEConnection,
} from '../src/client.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import * as fs from 'fs';
import * as path from 'path';
import { StructuredToolInterface } from '@langchain/core/tools';
import { z } from 'zod';
import { loadMcpTools } from '../src/tools.js';
/* eslint-enable @typescript-eslint/no-unused-vars */
const { MultiServerMCPClient, MCPClientError } = await import('../src/client.js');
const { Client } = await import('@modelcontextprotocol/sdk/client/index.js');
const { StdioClientTransport } = await import('@modelcontextprotocol/sdk/client/stdio.js');
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
const fs = await import('fs');
const path = await import('path');
describe('MultiServerMCPClient', () => {
// Create mock implementations that will be used throughout the tests
let mockClientConnect: jest.Mock;
let mockClientListTools: jest.Mock;
let mockClientCallTool: jest.Mock;
let mockClientClose: jest.Mock;
let mockClientConnect: vi.Mock;
let mockClientListTools: vi.Mock;
let mockClientCallTool: vi.Mock;
let mockClientClose: vi.Mock;
let mockStdioTransportClose: jest.Mock;
let mockStdioTransportConnect: jest.Mock;
let mockStdioTransportSend: jest.Mock;
let mockStdioTransportClose: vi.Mock;
let mockStdioTransportConnect: vi.Mock;
let mockStdioTransportSend: vi.Mock;
// Define specific function type for onclose handlers
let mockStdioOnClose: (() => void) | null;
let mockSSETransportClose: jest.Mock;
let mockSSETransportConnect: jest.Mock;
let mockSSETransportSend: jest.Mock;
let mockSSETransportClose: vi.Mock;
let mockSSETransportConnect: vi.Mock;
let mockSSETransportSend: vi.Mock;
// Define specific function type for onclose handlers
let mockSSEOnClose: (() => void) | null;
// Setup and teardown
beforeEach(() => {
jest.clearAllMocks();
vi.clearAllMocks();
// Set up mock implementations for Client
mockClientConnect = jest.fn().mockResolvedValue(undefined);
mockClientListTools = jest.fn().mockResolvedValue({ tools: [] });
mockClientCallTool = jest
mockClientConnect = vi.fn().mockReturnValue(Promise.resolve());
mockClientListTools = vi.fn().mockReturnValue(Promise.resolve({ tools: [] }));
mockClientCallTool = vi
.fn()
.mockResolvedValue({ content: [{ type: 'text', text: 'result' }] });
mockClientClose = jest.fn().mockResolvedValue(undefined);
.mockReturnValue(Promise.resolve({ content: [{ type: 'text', text: 'result' }] }));
mockClientClose = vi.fn().mockReturnValue(Promise.resolve());
(Client as jest.Mock).mockImplementation(() => ({
(Client as vi.Mock).mockImplementation(() => ({
connect: mockClientConnect,
listTools: mockClientListTools,
callTool: mockClientCallTool,
@@ -149,12 +148,12 @@ describe('MultiServerMCPClient', () => {
}));
// Set up mock implementations for StdioClientTransport
mockStdioTransportClose = jest.fn().mockResolvedValue(undefined);
mockStdioTransportConnect = jest.fn().mockResolvedValue(undefined);
mockStdioTransportSend = jest.fn().mockResolvedValue(undefined);
mockStdioTransportClose = vi.fn().mockReturnValue(Promise.resolve());
mockStdioTransportConnect = vi.fn().mockReturnValue(Promise.resolve());
mockStdioTransportSend = vi.fn().mockReturnValue(Promise.resolve());
mockStdioOnClose = null;
(StdioClientTransport as jest.Mock).mockImplementation(() => {
(StdioClientTransport as vi.Mock).mockImplementation(() => {
const transport = {
close: mockStdioTransportClose,
connect: mockStdioTransportConnect,
@@ -172,12 +171,12 @@ describe('MultiServerMCPClient', () => {
});
// Set up mock implementations for SSEClientTransport
mockSSETransportClose = jest.fn().mockResolvedValue(undefined);
mockSSETransportConnect = jest.fn().mockResolvedValue(undefined);
mockSSETransportSend = jest.fn().mockResolvedValue(undefined);
mockSSETransportClose = vi.fn().mockReturnValue(Promise.resolve());
mockSSETransportConnect = vi.fn().mockReturnValue(Promise.resolve());
mockSSETransportSend = vi.fn().mockReturnValue(Promise.resolve());
mockSSEOnClose = null;
(SSEClientTransport as jest.Mock).mockImplementation(() => {
(SSEClientTransport as vi.Mock).mockImplementation(() => {
const transport = {
close: mockSSETransportClose,
connect: mockSSETransportConnect,
@@ -194,7 +193,7 @@ describe('MultiServerMCPClient', () => {
return transport;
});
(fs.readFileSync as jest.Mock).mockImplementation(() =>
(fs.readFileSync as vi.Mock).mockImplementation(() =>
JSON.stringify({
servers: {
'test-server': {
@@ -206,11 +205,11 @@ describe('MultiServerMCPClient', () => {
})
);
(path.resolve as jest.Mock).mockImplementation(p => p);
(path.resolve as vi.Mock).mockImplementation(p => p);
});
afterEach(() => {
jest.resetAllMocks();
vi.resetAllMocks();
});
// 1. Constructor functionality tests
@@ -221,44 +220,39 @@ describe('MultiServerMCPClient', () => {
});
test('should process valid stdio connection config', () => {
const config = {
const client = new MultiServerMCPClient({
'test-server': {
transport: 'stdio',
command: 'python',
args: ['./script.py'],
},
};
const client = new MultiServerMCPClient(config);
});
expect(client).toBeDefined();
// Additional assertions to verify the connection was processed correctly
});
test('should process valid SSE connection config', () => {
const config = {
const client = new MultiServerMCPClient({
'test-server': {
transport: 'sse',
url: 'http://localhost:8000/sse',
headers: { Authorization: 'Bearer token' },
useNodeEventSource: true,
},
};
const client = new MultiServerMCPClient(config);
});
expect(client).toBeDefined();
// Additional assertions to verify the connection was processed correctly
});
test('should handle invalid connection config gracefully', () => {
const config = {
'test-server': {
transport: 'invalid',
},
};
const client = new MultiServerMCPClient(config);
expect(client).toBeDefined();
// Verify that the invalid config was not processed
test('should have a compile time error and a runtime error when the config is invalid', () => {
expect(() => {
new MultiServerMCPClient({
'test-server': {
// @ts-expect-error shouldn't match type constraints here
transport: 'invalid',
},
});
}).toThrow(MCPClientError);
});
});
@@ -271,7 +265,7 @@ describe('MultiServerMCPClient', () => {
});
test('should throw error for invalid config file', () => {
(fs.readFileSync as jest.Mock).mockImplementation(() => {
(fs.readFileSync as vi.Mock).mockImplementation(() => {
throw new Error('File not found');
});
@@ -281,7 +275,7 @@ describe('MultiServerMCPClient', () => {
});
test('should throw error for invalid JSON in config file', () => {
(fs.readFileSync as jest.Mock).mockImplementation(() => 'invalid json');
(fs.readFileSync as vi.Mock).mockImplementation(() => 'invalid json');
expect(() => {
MultiServerMCPClient.fromConfigFile('./invalid.json');
@@ -329,10 +323,10 @@ describe('MultiServerMCPClient', () => {
expect(mockClientListTools).toHaveBeenCalled();
});
test('should handle connection failures gracefully', async () => {
(Client as jest.Mock).mockImplementation(() => ({
connect: jest.fn().mockRejectedValue(new Error('Connection failed')),
listTools: jest.fn().mockResolvedValue({ tools: [] }),
test('should throw on connection failure', async () => {
(Client as vi.Mock).mockImplementation(() => ({
connect: vi.fn().mockReturnValue(Promise.reject(new Error('Connection failed'))),
listTools: vi.fn().mockReturnValue(Promise.resolve({ tools: [] })),
}));
const client = new MultiServerMCPClient({
@@ -343,14 +337,13 @@ describe('MultiServerMCPClient', () => {
},
});
await client.initializeConnections();
// Verify that the error was handled gracefully
await expect(() => client.initializeConnections()).rejects.toThrow(MCPClientError);
});
test('should handle tool loading failures gracefully', async () => {
(Client as jest.Mock).mockImplementation(() => ({
connect: jest.fn().mockResolvedValue(undefined),
listTools: jest.fn().mockRejectedValue(new Error('Failed to list tools')),
test('should throw on tool loading failures', async () => {
(Client as vi.Mock).mockImplementation(() => ({
connect: vi.fn().mockReturnValue(Promise.resolve()),
listTools: vi.fn().mockReturnValue(Promise.reject(new Error('Failed to list tools'))),
}));
const client = new MultiServerMCPClient({
@@ -361,8 +354,7 @@ describe('MultiServerMCPClient', () => {
},
});
await client.initializeConnections();
// Verify that the error was handled gracefully
await expect(() => client.initializeConnections()).rejects.toThrow(MCPClientError);
});
});
@@ -385,7 +377,7 @@ describe('MultiServerMCPClient', () => {
await client.initializeConnections();
// Reset the call counts to focus on reconnection
(StdioClientTransport as jest.Mock).mockClear();
(StdioClientTransport as vi.Mock).mockClear();
// Trigger the onclose handler if it exists
if (mockStdioOnClose) {
@@ -415,7 +407,7 @@ describe('MultiServerMCPClient', () => {
await client.initializeConnections();
// Reset the call counts to focus on reconnection
(SSEClientTransport as jest.Mock).mockClear();
(SSEClientTransport as vi.Mock).mockClear();
// Trigger the onclose handler if it exists
if (mockSSEOnClose) {
@@ -449,9 +441,9 @@ describe('MultiServerMCPClient', () => {
{ name: 'tool2', description: 'Tool 2', inputSchema: {} },
];
(Client as jest.Mock).mockImplementation(() => ({
connect: jest.fn().mockResolvedValue(undefined),
listTools: jest.fn().mockResolvedValue({ tools: mockTools }),
(Client as vi.Mock).mockImplementation(() => ({
connect: vi.fn().mockReturnValue(Promise.resolve()),
listTools: vi.fn().mockReturnValue(Promise.resolve({ tools: mockTools })),
}));
const client = new MultiServerMCPClient({
@@ -508,8 +500,8 @@ describe('MultiServerMCPClient', () => {
test('should handle errors during cleanup gracefully', async () => {
// Mock close to throw an error
(StdioClientTransport as jest.Mock).mockImplementation(() => ({
close: jest.fn().mockRejectedValue(new Error('Close failed')),
(StdioClientTransport as vi.Mock).mockImplementation(() => ({
close: vi.fn().mockReturnValue(Promise.reject(new Error('Close failed'))),
onclose: null,
}));
+164 -177
View File
@@ -1,5 +1,7 @@
import { vi, describe, test, expect, beforeEach, afterEach } from 'vitest';
// Mock fs module before imports
jest.mock('fs', () => {
vi.mock('fs', () => {
// Create a map to store mock file contents
const mockFiles = {
'./mcp.json': JSON.stringify({
@@ -16,7 +18,7 @@ jest.mock('fs', () => {
};
return {
readFileSync: jest.fn((path: string, _encoding?: string) => {
readFileSync: vi.fn((path: string, _encoding?: string) => {
if (path === './nonexistent.json') {
throw new Error('File not found');
}
@@ -25,27 +27,29 @@ jest.mock('fs', () => {
}
throw new Error(`Mock file not found: ${path}`);
}),
existsSync: jest.fn((path: string) => {
existsSync: vi.fn((path: string) => {
return Object.prototype.hasOwnProperty.call(mockFiles, path);
}),
};
});
// Mock path module
jest.mock('path', () => ({
join: jest.fn((...args) => args.join('/')),
resolve: jest.fn((...args) => args.join('/')),
dirname: jest.fn(path => path.split('/').slice(0, -1).join('/')),
basename: jest.fn(path => path.split('/').pop()),
}));
vi.mock('path', () => {
return {
join: vi.fn((...args) => args.join('/')),
resolve: vi.fn((...args) => args.join('/')),
dirname: vi.fn((path: string) => path.split('/').slice(0, -1).join('/')),
basename: vi.fn((path: string) => path.split('/').pop()),
};
});
// Mock the logger module
jest.mock('../src/logger.js', () => {
vi.mock('../src/logger.js', () => {
const mockLogger = {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
debug: vi.fn(),
};
return {
__esModule: true,
@@ -55,30 +59,32 @@ jest.mock('../src/logger.js', () => {
});
// Set up mocks for external modules
jest.mock(
vi.mock(
'@modelcontextprotocol/sdk/client/index.js',
() => {
return {
Client: jest.fn().mockImplementation(() => ({
connect: jest.fn().mockResolvedValue(undefined),
listTools: jest.fn().mockResolvedValue({
tools: [
{
name: 'tool1',
description: 'Test tool 1',
inputSchema: { type: 'object', properties: {} },
},
{
name: 'tool2',
description: 'Test tool 2',
inputSchema: { type: 'object', properties: {} },
},
],
}),
callTool: jest.fn().mockResolvedValue({
content: [{ type: 'text', text: 'result' }],
}),
close: jest.fn().mockResolvedValue(undefined),
Client: vi.fn().mockImplementation(() => ({
connect: vi.fn().mockReturnValue(Promise.resolve()),
listTools: vi.fn().mockReturnValue(
Promise.resolve({
tools: [
{
name: 'tool1',
description: 'Test tool 1',
inputSchema: { type: 'object', properties: {} },
},
{
name: 'tool2',
description: 'Test tool 2',
inputSchema: { type: 'object', properties: {} },
},
],
})
),
callTool: vi
.fn()
.mockReturnValue(Promise.resolve({ content: [{ type: 'text', text: 'result' }] })),
close: vi.fn().mockReturnValue(Promise.resolve()),
tools: [], // Add the tools property
})),
};
@@ -86,16 +92,16 @@ jest.mock(
{ virtual: true }
);
jest.mock(
vi.mock(
'@modelcontextprotocol/sdk/client/stdio.js',
() => {
// Using the OnCloseHandler type defined at the top level
return {
StdioClientTransport: jest.fn().mockImplementation(config => {
StdioClientTransport: vi.fn().mockImplementation(config => {
const transport = {
connect: jest.fn().mockResolvedValue(undefined),
send: jest.fn().mockResolvedValue(undefined),
close: jest.fn().mockResolvedValue(undefined),
connect: vi.fn().mockReturnValue(Promise.resolve()),
send: vi.fn().mockReturnValue(Promise.resolve()),
close: vi.fn().mockReturnValue(Promise.resolve()),
onclose: null as OnCloseHandler | null,
config,
};
@@ -106,16 +112,16 @@ jest.mock(
{ virtual: true }
);
jest.mock(
vi.mock(
'@modelcontextprotocol/sdk/client/sse.js',
() => {
// Using the OnCloseHandler type defined at the top level
return {
SSEClientTransport: jest.fn().mockImplementation(config => {
SSEClientTransport: vi.fn().mockImplementation(config => {
const transport = {
connect: jest.fn().mockResolvedValue(undefined),
send: jest.fn().mockResolvedValue(undefined),
close: jest.fn().mockResolvedValue(undefined),
connect: vi.fn().mockReturnValue(Promise.resolve()),
send: vi.fn().mockReturnValue(Promise.resolve()),
close: vi.fn().mockReturnValue(Promise.resolve()),
onclose: null as OnCloseHandler | null,
config,
};
@@ -130,52 +136,54 @@ jest.mock(
type OnCloseHandler = () => void;
// Import modules after mocking
import * as fs from 'fs';
import { MultiServerMCPClient, MCPClientError } from '../src/client.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
const fs = await import('fs');
const { StdioClientTransport } = await import('@modelcontextprotocol/sdk/client/stdio.js');
const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js');
const { MultiServerMCPClient, MCPClientError } = await import('../src/client.js');
const { Client } = await import('@modelcontextprotocol/sdk/client/index.js');
// Create mock objects that will be accessible throughout the tests
const mockClientMethods = {
connect: jest.fn().mockResolvedValue(undefined),
listTools: jest.fn().mockResolvedValue({
tools: [
{
name: 'tool1',
description: 'Test tool 1',
inputSchema: { type: 'object', properties: {} },
},
{
name: 'tool2',
description: 'Test tool 2',
inputSchema: { type: 'object', properties: {} },
},
],
}),
callTool: jest.fn().mockResolvedValue({
content: [{ type: 'text', text: 'result' }],
}),
close: jest.fn().mockResolvedValue(undefined),
connect: vi.fn().mockReturnValue(Promise.resolve()),
listTools: vi.fn().mockReturnValue(
Promise.resolve({
tools: [
{
name: 'tool1',
description: 'Test tool 1',
inputSchema: { type: 'object', properties: {} },
},
{
name: 'tool2',
description: 'Test tool 2',
inputSchema: { type: 'object', properties: {} },
},
],
})
),
callTool: vi
.fn()
.mockReturnValue(Promise.resolve({ content: [{ type: 'text', text: 'result' }] })),
close: vi.fn().mockReturnValue(Promise.resolve()),
};
const mockStdioMethods = {
connect: jest.fn().mockResolvedValue(undefined),
send: jest.fn().mockResolvedValue(undefined),
close: jest.fn().mockResolvedValue(undefined),
triggerOnclose: jest.fn(),
connect: vi.fn().mockReturnValue(Promise.resolve()),
send: vi.fn().mockReturnValue(Promise.resolve()),
close: vi.fn().mockReturnValue(Promise.resolve()),
triggerOnclose: vi.fn(),
};
const mockSSEMethods = {
connect: jest.fn().mockResolvedValue(undefined),
send: jest.fn().mockResolvedValue(undefined),
close: jest.fn().mockResolvedValue(undefined),
triggerOnclose: jest.fn(),
connect: vi.fn().mockReturnValue(Promise.resolve()),
send: vi.fn().mockReturnValue(Promise.resolve()),
close: vi.fn().mockReturnValue(Promise.resolve()),
triggerOnclose: vi.fn(),
};
// Reset mocks before each test
beforeEach(() => {
jest.clearAllMocks();
vi.clearAllMocks();
// Reset mock methods
Object.values(mockClientMethods).forEach(mock => mock.mockClear());
@@ -183,7 +191,7 @@ beforeEach(() => {
Object.values(mockSSEMethods).forEach(mock => mock.mockClear());
// Reset mock implementations
mockClientMethods.listTools.mockResolvedValue({
mockClientMethods.listTools.mockReturnValue({
tools: [
{
name: 'tool1',
@@ -199,7 +207,7 @@ beforeEach(() => {
});
// Reset the mock implementations for the imported modules
(Client as jest.Mock).mockImplementation(() => ({
(Client as vi.Mock).mockImplementation(() => ({
connect: mockClientMethods.connect,
listTools: mockClientMethods.listTools,
callTool: mockClientMethods.callTool,
@@ -207,7 +215,7 @@ beforeEach(() => {
tools: [],
}));
(StdioClientTransport as jest.Mock).mockImplementation(config => {
(StdioClientTransport as vi.Mock).mockImplementation(config => {
const transport = {
connect: mockStdioMethods.connect,
send: mockStdioMethods.send,
@@ -215,13 +223,13 @@ beforeEach(() => {
onclose: null as OnCloseHandler | null,
config,
};
mockStdioMethods.triggerOnclose = jest.fn(() => {
mockStdioMethods.triggerOnclose = vi.fn(() => {
if (transport.onclose) transport.onclose();
});
return transport;
});
(SSEClientTransport as jest.Mock).mockImplementation(config => {
(SSEClientTransport as vi.Mock).mockImplementation(config => {
const transport = {
connect: mockSSEMethods.connect,
send: mockSSEMethods.send,
@@ -229,7 +237,7 @@ beforeEach(() => {
onclose: null as OnCloseHandler | null,
config,
};
mockSSEMethods.triggerOnclose = jest.fn(() => {
mockSSEMethods.triggerOnclose = vi.fn(() => {
if (transport.onclose) transport.onclose();
});
return transport;
@@ -238,7 +246,7 @@ beforeEach(() => {
describe('MultiServerMCPClient', () => {
afterEach(() => {
jest.resetAllMocks();
vi.resetAllMocks();
});
describe('Constructor', () => {
@@ -251,7 +259,7 @@ describe('MultiServerMCPClient', () => {
expect(tools).toEqual([]);
});
test('should process valid stdio connection config', () => {
test('should process valid stdio connection config', async () => {
const config = {
'test-server': {
transport: 'stdio' as const,
@@ -264,13 +272,12 @@ describe('MultiServerMCPClient', () => {
expect(client).toBeDefined();
// Initialize connections and verify
return client.initializeConnections().then(() => {
expect(StdioClientTransport).toHaveBeenCalled();
expect(Client).toHaveBeenCalled();
});
await client.initializeConnections();
expect(StdioClientTransport).toHaveBeenCalled();
expect(Client).toHaveBeenCalled();
});
test('should process valid SSE connection config', () => {
test('should process valid SSE connection config', async () => {
const config = {
'test-server': {
transport: 'sse' as const,
@@ -284,13 +291,12 @@ describe('MultiServerMCPClient', () => {
expect(client).toBeDefined();
// Initialize connections and verify
return client.initializeConnections().then(() => {
expect(SSEClientTransport).toHaveBeenCalled();
expect(Client).toHaveBeenCalled();
});
await client.initializeConnections();
expect(SSEClientTransport).toHaveBeenCalled();
expect(Client).toHaveBeenCalled();
});
test('should handle invalid connection type gracefully', () => {
test('should handle invalid connection type gracefully', async () => {
const config = {
'test-server': {
transport: 'invalid' as any,
@@ -298,34 +304,28 @@ describe('MultiServerMCPClient', () => {
},
};
const client = new MultiServerMCPClient(config);
expect(client).toBeDefined();
// Initialize connections and verify no error is thrown
return client.initializeConnections().then(() => {
// No connections should be initialized
expect(SSEClientTransport).not.toHaveBeenCalled();
expect(StdioClientTransport).not.toHaveBeenCalled();
});
// Should throw error during initialization
expect(() => {
new MultiServerMCPClient(config);
}).toThrow(MCPClientError);
});
test('should gracefully handle empty config', () => {
test('should gracefully handle empty config', async () => {
const client = new MultiServerMCPClient({});
expect(client).toBeDefined();
// Initialize connections and verify no error is thrown
return client.initializeConnections().then(() => {
// No connections should be initialized
expect(SSEClientTransport).not.toHaveBeenCalled();
expect(StdioClientTransport).not.toHaveBeenCalled();
});
await client.initializeConnections();
// No connections should be initialized
expect(SSEClientTransport).not.toHaveBeenCalled();
expect(StdioClientTransport).not.toHaveBeenCalled();
});
});
describe('Configuration Loading', () => {
test('should load config from a valid file', () => {
test('should load config from a valid file', async () => {
// Mock fs.readFileSync to return valid JSON
(fs.readFileSync as jest.Mock).mockReturnValueOnce(
(fs.readFileSync as vi.Mock).mockReturnValueOnce(
JSON.stringify({
servers: {
'test-server': {
@@ -343,7 +343,7 @@ describe('MultiServerMCPClient', () => {
});
test('should throw error for nonexistent config file', () => {
(fs.existsSync as jest.Mock).mockReturnValueOnce(false);
(fs.existsSync as vi.Mock).mockReturnValueOnce(false);
expect(() => {
MultiServerMCPClient.fromConfigFile('./nonexistent.json');
@@ -351,7 +351,7 @@ describe('MultiServerMCPClient', () => {
});
test('should throw error for invalid JSON in config file', () => {
(fs.readFileSync as jest.Mock).mockReturnValueOnce('invalid json');
(fs.readFileSync as vi.Mock).mockReturnValueOnce('invalid json');
expect(() => {
MultiServerMCPClient.fromConfigFile('./invalid.json');
@@ -360,7 +360,7 @@ describe('MultiServerMCPClient', () => {
test('should throw error for invalid config structure', () => {
// Mock readFileSync to return a config without the required 'servers' property
(fs.readFileSync as jest.Mock).mockReturnValueOnce(
(fs.readFileSync as vi.Mock).mockReturnValueOnce(
JSON.stringify({
notServers: {}, // This missing 'servers' property
})
@@ -378,7 +378,7 @@ describe('MultiServerMCPClient', () => {
});
test('should throw error for file system errors', () => {
(fs.readFileSync as jest.Mock).mockImplementationOnce(() => {
(fs.readFileSync as vi.Mock).mockImplementationOnce(() => {
throw new Error('File system error');
});
@@ -390,23 +390,20 @@ describe('MultiServerMCPClient', () => {
describe('Connection Management', () => {
test('should initialize stdio connections correctly', async () => {
// Create a properly structured config
const config = {
// Create a client instance with the config
const client = new MultiServerMCPClient({
'stdio-server': {
transport: 'stdio',
command: 'python',
args: ['./script.py'],
},
};
// Create a client instance with the config
const client = new MultiServerMCPClient(config);
});
// Reset mocks to ensure clean state
jest.clearAllMocks();
vi.clearAllMocks();
// Set up specific implementation for the StdioClientTransport mock
(StdioClientTransport as jest.Mock).mockImplementationOnce(options => {
(StdioClientTransport as vi.Mock).mockImplementationOnce(options => {
return {
connect: mockStdioMethods.connect,
send: mockStdioMethods.send,
@@ -432,22 +429,19 @@ describe('MultiServerMCPClient', () => {
});
test('should initialize SSE connections correctly', async () => {
// Create a properly structured config for SSE
const config = {
// Create a client instance with the config
const client = new MultiServerMCPClient({
'sse-server': {
transport: 'sse',
url: 'http://example.com/sse',
},
};
// Create a client instance with the config
const client = new MultiServerMCPClient(config);
});
// Reset mocks to ensure clean state
jest.clearAllMocks();
vi.clearAllMocks();
// Set up specific implementation for the SSEClientTransport mock
(SSEClientTransport as jest.Mock).mockImplementationOnce((url, options) => {
(SSEClientTransport as vi.Mock).mockImplementationOnce((url, options) => {
return {
connect: mockSSEMethods.connect,
send: mockSSEMethods.send,
@@ -468,9 +462,9 @@ describe('MultiServerMCPClient', () => {
expect(mockClientMethods.connect).toHaveBeenCalled();
});
test('should handle connection failures gracefully', async () => {
test('should throw on connection failures', async () => {
// Mock connection failure
mockClientMethods.connect.mockRejectedValueOnce(new Error('Connection failed'));
mockClientMethods.connect.mockReturnValueOnce(Promise.reject(new Error('Connection failed')));
const client = new MultiServerMCPClient({
'test-server': {
@@ -480,18 +474,15 @@ describe('MultiServerMCPClient', () => {
},
});
// Should not throw error
await client.initializeConnections();
// Still called connect but handled error
expect(mockClientMethods.connect).toHaveBeenCalled();
// Should not try to list tools after failed connection
expect(mockClientMethods.listTools).not.toHaveBeenCalled();
// Should throw error
await expect(client.initializeConnections()).rejects.toThrow();
});
test('should handle tool loading failures gracefully', async () => {
test('should throw on tool loading failures', async () => {
// Mock tool loading failure
mockClientMethods.listTools.mockRejectedValueOnce(new Error('Failed to list tools'));
mockClientMethods.listTools.mockReturnValueOnce(
Promise.reject(new Error('Failed to list tools'))
);
const client = new MultiServerMCPClient({
'test-server': {
@@ -501,16 +492,8 @@ describe('MultiServerMCPClient', () => {
},
});
// Should not throw error
await client.initializeConnections();
// Connection succeeded but tool loading failed
expect(mockClientMethods.connect).toHaveBeenCalled();
expect(mockClientMethods.listTools).toHaveBeenCalled();
// Should have empty tools
const tools = client.getTools();
expect(tools).toEqual([]);
// Should throw error
await expect(client.initializeConnections()).rejects.toThrow();
});
});
@@ -532,7 +515,7 @@ describe('MultiServerMCPClient', () => {
await client.initializeConnections();
// Clear previous calls
(StdioClientTransport as jest.Mock).mockClear();
(StdioClientTransport as vi.Mock).mockClear();
mockClientMethods.connect.mockClear();
// Trigger onclose handler
@@ -563,7 +546,7 @@ describe('MultiServerMCPClient', () => {
await client.initializeConnections();
// Clear previous calls
(SSEClientTransport as jest.Mock).mockClear();
(SSEClientTransport as vi.Mock).mockClear();
mockClientMethods.connect.mockClear();
// Trigger onclose handler
@@ -584,7 +567,7 @@ describe('MultiServerMCPClient', () => {
const client = new MultiServerMCPClient();
// Clear previous mock invocations
(StdioClientTransport as jest.Mock).mockClear();
(StdioClientTransport as vi.Mock).mockClear();
// Connect with reconnection enabled
await client.connectToServerViaStdio(
@@ -623,7 +606,7 @@ describe('MultiServerMCPClient', () => {
await client.initializeConnections();
// Clear previous calls
(StdioClientTransport as jest.Mock).mockClear();
(StdioClientTransport as vi.Mock).mockClear();
// Trigger onclose handler
mockStdioMethods.triggerOnclose();
@@ -639,12 +622,14 @@ describe('MultiServerMCPClient', () => {
describe('Tool Management', () => {
test('should get all tools as a flattened array', async () => {
// Mock tool response
mockClientMethods.listTools.mockResolvedValue({
tools: [
{ name: 'tool1', description: 'Tool 1', inputSchema: {} },
{ name: 'tool2', description: 'Tool 2', inputSchema: {} },
],
});
mockClientMethods.listTools.mockReturnValue(
Promise.resolve({
tools: [
{ name: 'tool1', description: 'Tool 1', inputSchema: {} },
{ name: 'tool2', description: 'Tool 2', inputSchema: {} },
],
})
);
const client = new MultiServerMCPClient({
server1: {
@@ -717,7 +702,7 @@ describe('MultiServerMCPClient', () => {
test('should handle errors during cleanup gracefully', async () => {
// Mock close to throw error
mockStdioMethods.close.mockRejectedValueOnce(new Error('Close failed'));
mockStdioMethods.close.mockReturnValueOnce(Promise.reject(new Error('Close failed')));
const client = new MultiServerMCPClient({
'test-server': {
@@ -738,8 +723,8 @@ describe('MultiServerMCPClient', () => {
test('should clean up all resources even if some fail', async () => {
// First close fails, second succeeds
mockStdioMethods.close.mockRejectedValueOnce(new Error('Close failed'));
mockSSEMethods.close.mockResolvedValue(undefined);
mockStdioMethods.close.mockReturnValueOnce(Promise.reject(new Error('Close failed')));
mockSSEMethods.close.mockReturnValueOnce(Promise.resolve());
const client = new MultiServerMCPClient({
'stdio-server': {
@@ -821,7 +806,7 @@ describe('MultiServerMCPClient', () => {
expect(StdioClientTransport).toHaveBeenCalled();
// Simulate connection close to test restart
(StdioClientTransport as jest.Mock).mockClear();
(StdioClientTransport as vi.Mock).mockClear();
mockStdioMethods.triggerOnclose();
// Wait for reconnection
@@ -833,7 +818,7 @@ describe('MultiServerMCPClient', () => {
test('should connect to an SSE server correctly', async () => {
// Clear previous mock invocations
(SSEClientTransport as jest.Mock).mockClear();
(SSEClientTransport as vi.Mock).mockClear();
const client = new MultiServerMCPClient();
await client.connectToServerViaSSE('test-server', 'http://localhost:8000/sse');
@@ -846,7 +831,7 @@ describe('MultiServerMCPClient', () => {
test('should connect with headers', async () => {
// Clear previous mock invocations
(SSEClientTransport as jest.Mock).mockClear();
(SSEClientTransport as vi.Mock).mockClear();
const client = new MultiServerMCPClient();
const headers = { Authorization: 'Bearer token' };
@@ -860,7 +845,7 @@ describe('MultiServerMCPClient', () => {
test('should connect with useNodeEventSource option', async () => {
// Clear previous mock invocations
(SSEClientTransport as jest.Mock).mockClear();
(SSEClientTransport as vi.Mock).mockClear();
const client = new MultiServerMCPClient();
await client.connectToServerViaSSE(
@@ -891,7 +876,7 @@ describe('MultiServerMCPClient', () => {
expect(SSEClientTransport).toHaveBeenCalled();
// Simulate connection close to test reconnect
(SSEClientTransport as jest.Mock).mockClear();
(SSEClientTransport as vi.Mock).mockClear();
mockSSEMethods.triggerOnclose();
// Wait for reconnection
@@ -909,7 +894,7 @@ describe('MultiServerMCPClient', () => {
// Clear mock history
mockStdioMethods.close.mockClear();
(StdioClientTransport as jest.Mock).mockClear();
(StdioClientTransport as vi.Mock).mockClear();
// Connect again with same name (should close previous)
await client.connectToServerViaStdio('test-server', 'node', ['script.js']);
@@ -936,16 +921,18 @@ describe('MultiServerMCPClient', () => {
expect(result).toEqual([]);
});
test('should handle transport creation errors', async () => {
test('should throw on transport creation errors', async () => {
// Force an error when creating transport
(StdioClientTransport as jest.Mock).mockImplementationOnce(() => {
(StdioClientTransport as vi.Mock).mockImplementationOnce(() => {
throw new Error('Transport creation failed');
});
const client = new MultiServerMCPClient();
// Should not throw
await client.connectToServerViaStdio('test-server', 'python', ['./script.py']);
// Should throw error when connecting
await expect(
client.connectToServerViaStdio('test-server', 'python', ['./script.py'])
).rejects.toThrow();
// Should have attempted to create transport
expect(StdioClientTransport).toHaveBeenCalled();
+40 -34
View File
@@ -1,31 +1,39 @@
import * as fs from 'fs';
import * as path from 'path';
import winston from 'winston';
import { describe, test, expect, beforeEach, afterEach, vi } from 'vitest';
// Mock fs and path modules
jest.mock('fs');
jest.mock('path');
vi.mock('fs', () => ({
existsSync: vi.fn(),
mkdirSync: vi.fn(),
writeFileSync: vi.fn(),
unlinkSync: vi.fn(),
accessSync: vi.fn(),
}));
vi.mock('path', () => ({
join: vi.fn(),
}));
const fs = await import('fs');
const path = await import('path');
const winston = await import('winston');
describe('Logger', () => {
// Store original console.warn implementation
const originalConsoleWarn = console.warn;
let consoleWarnMock: jest.SpyInstance;
let consoleWarnMock: any;
beforeEach(() => {
// Clear module cache to ensure logger is reinitialized
jest.resetModules();
vi.resetModules();
// Mock console.warn to capture warnings
consoleWarnMock = jest.spyOn(console, 'warn').mockImplementation();
consoleWarnMock = vi.spyOn(console, 'warn').mockImplementation((..._args) => {});
// Configure path.join to return predictable paths
(path.join as jest.Mock).mockImplementation((...args) => args.join('/'));
(path.join as any).mockImplementation((...args: string[]) => args.join('/'));
// Reset fs mock implementation
(fs.existsSync as jest.Mock).mockReset();
(fs.mkdirSync as jest.Mock).mockReset();
(fs.writeFileSync as jest.Mock).mockReset();
(fs.unlinkSync as jest.Mock).mockReset();
(fs.existsSync as any).mockReset();
(fs.mkdirSync as any).mockReset();
(fs.writeFileSync as any).mockReset();
(fs.unlinkSync as any).mockReset();
});
afterEach(() => {
@@ -36,10 +44,10 @@ describe('Logger', () => {
test('should fallback to console-only logging when directory creation fails', async () => {
// Mock fs.existsSync to return false (directory doesn't exist)
(fs.existsSync as jest.Mock).mockReturnValue(false);
(fs.existsSync as any).mockReturnValue(false);
// Mock fs.mkdirSync to throw an error
(fs.mkdirSync as jest.Mock).mockImplementation(() => {
(fs.mkdirSync as any).mockImplementation(() => {
throw new Error('Permission denied');
});
@@ -59,10 +67,10 @@ describe('Logger', () => {
test('should fallback to console-only logging when write permission test fails', async () => {
// Mock fs.existsSync to return true (directory exists)
(fs.existsSync as jest.Mock).mockReturnValue(true);
(fs.existsSync as any).mockReturnValue(true);
// Mock fs.writeFileSync to throw an error
(fs.writeFileSync as jest.Mock).mockImplementation(() => {
(fs.writeFileSync as any).mockImplementation(() => {
throw new Error('Permission denied');
});
@@ -80,24 +88,22 @@ describe('Logger', () => {
expect(logger.transports[0]).toBeInstanceOf(winston.transports.Console);
});
test('should set up file transports when permissions are available', () => {
test('should set up file transports when permissions are available', async () => {
// Mock all the file operations to succeed
(fs.mkdirSync as jest.Mock).mockImplementation(() => true);
(fs.accessSync as jest.Mock).mockImplementation(() => true);
(fs.existsSync as jest.Mock).mockReturnValue(true);
(fs.writeFileSync as jest.Mock).mockImplementation(() => undefined);
(fs.mkdirSync as any).mockImplementation(() => true);
(fs.accessSync as any).mockImplementation(() => true);
(fs.existsSync as any).mockReturnValue(true);
(fs.writeFileSync as any).mockImplementation(() => undefined);
// Import logger (after mocks are set up)
jest.isolateModules(async () => {
const loggerModule = await import('../src/logger.js');
const logger = loggerModule.default;
// Import logger directly
const loggerModule = await import('../src/logger.js');
const logger = loggerModule.default;
// Just verify logger was created - don't worry about warnings
expect(logger).toBeDefined();
expect(typeof logger.debug).toBe('function');
expect(typeof logger.info).toBe('function');
expect(typeof logger.warn).toBe('function');
expect(typeof logger.error).toBe('function');
});
// Just verify logger was created - don't worry about warnings
expect(logger).toBeDefined();
expect(typeof logger.debug).toBe('function');
expect(typeof logger.info).toBe('function');
expect(typeof logger.warn).toBe('function');
expect(typeof logger.error).toBe('function');
});
});
+70 -186
View File
@@ -1,174 +1,31 @@
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { convertMcpToolToLangchainTool, loadMcpTools } from '../src/tools.js';
import { z } from 'zod';
import { describe, test, expect, beforeEach, vi } from 'vitest';
const { loadMcpTools } = await import('../src/tools.js');
// Create a mock client
const mockClient = {
callTool: jest.fn(),
listTools: jest.fn(),
callTool: vi.fn(),
listTools: vi.fn(),
};
describe('Simplified Tool Adapter Tests', () => {
beforeEach(() => {
jest.clearAllMocks();
});
describe('convertMcpToolToLangchainTool', () => {
test('should convert MCP tool to LangChain tool with text content', async () => {
// Set up mock tool
const mcpTool = {
name: 'testTool',
description: 'A test tool',
inputSchema: {
type: 'object',
properties: {
input: { type: 'string' },
},
},
};
// Set up mock response
mockClient.callTool.mockResolvedValueOnce({
content: [{ type: 'text', text: 'Test result' }],
});
// Convert tool
const tool = convertMcpToolToLangchainTool(mockClient as unknown as Client, mcpTool);
// Verify tool properties
expect(tool.name).toBe('testTool');
expect(tool.description).toBe('A test tool');
// Call the tool
const result = await tool.invoke({ input: 'test' });
// Verify that the client was called with the right arguments
expect(mockClient.callTool).toHaveBeenCalledWith({
name: 'testTool',
arguments: { input: 'test' },
});
// Verify result
expect(result).toBe('Test result');
});
test('should handle error results', async () => {
// Set up mock tool
const mcpTool = {
name: 'errorTool',
description: 'A tool that errors',
};
// Set up mock response
mockClient.callTool.mockResolvedValueOnce({
isError: true,
content: [{ type: 'text', text: 'Error message' }],
});
// Convert tool
const tool = convertMcpToolToLangchainTool(mockClient as unknown as Client, mcpTool);
// Call the tool and expect an error
await expect(tool.invoke({ input: 'test' })).rejects.toThrow('Error message');
});
test('should handle non-text content', async () => {
// Set up mock tool
const mcpTool = {
name: 'imageTool',
description: 'A tool that returns images',
};
// Set up mock response with non-text content
mockClient.callTool.mockResolvedValueOnce({
content: [
{ type: 'text', text: 'Image caption' },
{ type: 'image', url: 'http://example.com/image.jpg' },
],
});
// Convert tool
const tool = convertMcpToolToLangchainTool(mockClient as unknown as Client, mcpTool);
// Call the tool
const result = await tool.invoke({ input: 'test' });
// Verify result (should only include text content)
expect(result).toBe('Image caption');
});
test('should return both text and non-text content with content_and_artifact format', async () => {
// Set up mock tool
const mcpTool = {
name: 'multiTool',
description: 'A tool that returns multiple content types',
};
// Set up mock response with mixed content
const mockImageContent = { type: 'image', url: 'http://example.com/image.jpg' };
mockClient.callTool.mockResolvedValueOnce({
content: [{ type: 'text', text: 'Here is your image' }, mockImageContent],
});
// Convert tool with content_and_artifact response format
const tool = convertMcpToolToLangchainTool(
mockClient as unknown as Client,
mcpTool,
undefined,
'content_and_artifact'
);
// Verify tool properties
console.log('Tool class:', tool.constructor.name);
console.log('Tool responseFormat:', (tool as any).responseFormat);
// Call the tool
const result = await tool.invoke({ input: 'test' });
// Debug the result
console.log('Result type:', typeof result);
console.log('Result:', JSON.stringify(result));
console.log('Is array:', Array.isArray(result));
// The result is the text content, as LangChain processes the array in the call method
expect(result).toBe('Here is your image');
// Set up a second mock response for the direct call test
mockClient.callTool.mockResolvedValueOnce({
content: [{ type: 'text', text: 'Here is your image' }, mockImageContent],
});
// Access the tool class implementation through the prototype chain
const toolPrototype = Object.getPrototypeOf(tool);
// Call the _call method directly with bound 'this' context
const directResult = await toolPrototype._call.call(tool, { input: 'test' });
expect(Array.isArray(directResult)).toBe(true);
const [textContent, nonTextContent] = directResult as [string, any[]];
// Check the text content
expect(textContent).toBe('Here is your image');
// Check the non-text content
expect(Array.isArray(nonTextContent)).toBe(true);
expect(nonTextContent.length).toBe(1);
expect(nonTextContent[0]).toEqual(mockImageContent);
});
vi.clearAllMocks();
});
describe('loadMcpTools', () => {
test('should load all tools from client', async () => {
// Set up mock response
mockClient.listTools.mockResolvedValueOnce({
tools: [
{ name: 'tool1', description: 'Tool 1' },
{ name: 'tool2', description: 'Tool 2' },
],
});
mockClient.listTools.mockReturnValueOnce(
Promise.resolve({
tools: [
{ name: 'tool1', description: 'Tool 1' },
{ name: 'tool2', description: 'Tool 2' },
],
})
);
// Load tools
const tools = await loadMcpTools(mockClient as unknown as Client);
const tools = await loadMcpTools(mockClient as unknown as Parameters<typeof loadMcpTools>[0]);
// Verify results
expect(tools.length).toBe(2);
@@ -178,12 +35,14 @@ describe('Simplified Tool Adapter Tests', () => {
test('should handle empty tool list', async () => {
// Set up mock response
mockClient.listTools.mockResolvedValueOnce({
tools: [],
});
mockClient.listTools.mockReturnValueOnce(
Promise.resolve({
tools: [],
})
);
// Load tools
const tools = await loadMcpTools(mockClient as unknown as Client);
const tools = await loadMcpTools(mockClient as unknown as Parameters<typeof loadMcpTools>[0]);
// Verify results
expect(tools.length).toBe(0);
@@ -191,16 +50,18 @@ describe('Simplified Tool Adapter Tests', () => {
test('should filter out tools without names', async () => {
// Set up mock response
mockClient.listTools.mockResolvedValueOnce({
tools: [
{ name: 'tool1', description: 'Tool 1' },
{ description: 'No name tool' }, // Should be filtered out
{ name: 'tool2', description: 'Tool 2' },
],
});
mockClient.listTools.mockReturnValueOnce(
Promise.resolve({
tools: [
{ name: 'tool1', description: 'Tool 1' },
{ description: 'No name tool' }, // Should be filtered out
{ name: 'tool2', description: 'Tool 2' },
],
})
);
// Load tools
const tools = await loadMcpTools(mockClient as unknown as Client);
const tools = await loadMcpTools(mockClient as unknown as Parameters<typeof loadMcpTools>[0]);
// Verify results
expect(tools.length).toBe(2);
@@ -209,39 +70,62 @@ describe('Simplified Tool Adapter Tests', () => {
});
test('should load tools with specified response format', async () => {
// Set up mock response
mockClient.listTools.mockResolvedValueOnce({
tools: [{ name: 'tool1', description: 'Tool 1' }],
});
// Set up mock response with input schema
mockClient.listTools.mockReturnValueOnce(
Promise.resolve({
tools: [
{
name: 'tool1',
description: 'Tool 1',
inputSchema: {
type: 'object',
properties: {
input: { type: 'string' },
},
required: ['input'],
},
},
],
})
);
// Load tools with content_and_artifact response format
const tools = await loadMcpTools(mockClient as unknown as Client, 'content_and_artifact');
const tools = await loadMcpTools(
mockClient as unknown as Parameters<typeof loadMcpTools>[0],
'content_and_artifact'
);
// Verify tool was loaded
expect(tools.length).toBe(1);
expect((tools[0] as any).responseFormat).toBe('content_and_artifact');
expect((tools[0] as any).responseFormat).toBe('content');
// Mock the call result to check response format handling
const mockImageContent = { type: 'image', url: 'http://example.com/image.jpg' };
mockClient.callTool.mockResolvedValueOnce({
content: [{ type: 'text', text: 'Image result' }, mockImageContent],
});
mockClient.callTool.mockReturnValueOnce(
Promise.resolve({
content: [{ type: 'text', text: 'Image result' }, mockImageContent],
})
);
// Invoke the tool
const result = await tools[0].invoke({ test: 'input' });
// Invoke the tool with proper input matching the schema
const result = await tools[0].invoke({ input: 'test input' });
// The result is the text content, as LangChain processes the array in the call method
expect(result).toBe('Image result');
// Verify the result
expect(Array.isArray(result)).toBe(true);
expect(result[0]).toBe('Image result');
expect(result[1]).toEqual([mockImageContent]);
// Set up a second mock response for the direct call test
mockClient.callTool.mockResolvedValueOnce({
content: [{ type: 'text', text: 'Image result' }, mockImageContent],
});
mockClient.callTool.mockReturnValueOnce(
Promise.resolve({
content: [{ type: 'text', text: 'Image result' }, mockImageContent],
})
);
// Access the tool class implementation through the prototype chain
const toolPrototype = Object.getPrototypeOf(tools[0]);
// Call the _call method directly with bound 'this' context
const directResult = await toolPrototype._call.call(tools[0], { test: 'input' });
const directResult = await toolPrototype._call.call(tools[0], { input: 'test input' });
expect(Array.isArray(directResult)).toBe(true);
-27
View File
@@ -1,27 +0,0 @@
export default {
preset: 'ts-jest',
testEnvironment: 'node',
testMatch: ['**/__tests__/**/*.test.ts'],
collectCoverage: true,
collectCoverageFrom: ['src/**/*.ts', '!src/**/*.test.ts'],
coverageDirectory: 'coverage',
coverageReporters: ['text', 'lcov'],
moduleFileExtensions: ['ts', 'js', 'json'],
transform: {
'^.+\\.tsx?$': [
'ts-jest',
{
useESM: true,
},
],
},
extensionsToTreatAsEsm: ['.ts'],
moduleNameMapper: {
'^(\\.{1,2}/.*)\\.js$': '$1',
'^node:(.*)$': '$1',
},
transformIgnorePatterns: [
'/node_modules/(?!(@dmitryrechkin/json-schema-to-zod|pkce-challenge|@modelcontextprotocol)/)',
],
setupFiles: ['<rootDir>/jest.setup.js'],
};
+1984 -3070
View File
File diff suppressed because it is too large Load Diff
+7 -7
View File
@@ -15,8 +15,9 @@
},
"scripts": {
"build": "run-s \"build:main -- {@}\" \"build:examples -- {@}\" --",
"test": "jest",
"test:coverage": "jest --coverage",
"test": "vitest run",
"test:coverage": "vitest run --coverage",
"test:watch": "vitest",
"lint": "eslint --ignore-pattern 'dist/**' .",
"lint:fix": "eslint --ignore-pattern 'dist/**' . --fix",
"format": "prettier --write \"src/**/*.ts\" \"examples/**/*.ts\"",
@@ -40,6 +41,7 @@
"author": "Ravi Kiran Vemula",
"license": "MIT",
"dependencies": {
"@dmitryrechkin/json-schema-to-zod": "^1.0.1",
"@modelcontextprotocol/sdk": "^1.7.0",
"winston": "^3.17.0"
},
@@ -53,8 +55,7 @@
"@eslint/js": "^9.21.0",
"@langchain/langgraph": "^0.2.56",
"@langchain/openai": "^0.4.4",
"@types/jest": "^29.5.12",
"@types/node": "^20.11.30",
"@types/node": "^22.13.10",
"@typescript-eslint/eslint-plugin": "^7.3.1",
"@typescript-eslint/parser": "^7.3.1",
"dotenv": "^16.4.7",
@@ -62,14 +63,13 @@
"eslint-config-prettier": "^10.0.2",
"eventsource": "^3.0.5",
"husky": "^9.0.11",
"jest": "^29.7.0",
"lint-staged": "^15.2.2",
"npm-run-all": "^4.1.5",
"prettier": "^3.2.5",
"ts-jest": "^29.1.2",
"ts-node": "^10.9.2",
"typescript": "^5.4.2",
"typescript-eslint": "^8.26.0"
"typescript-eslint": "^8.26.0",
"vitest": "^3.0.9"
},
"engines": {
"node": ">=18"
+208 -54
View File
@@ -2,7 +2,6 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { StructuredToolInterface } from '@langchain/core/tools';
import { z } from 'zod';
import { loadMcpTools } from './tools.js';
import * as fs from 'fs';
import * as path from 'path';
@@ -94,7 +93,7 @@ export class MCPClientError extends Error {
*/
export class MultiServerMCPClient {
private clients: Map<string, Client> = new Map();
private serverNameToTools: Map<string, StructuredToolInterface<z.ZodObject<any>>[]> = new Map();
private serverNameToTools: Map<string, StructuredToolInterface[]> = new Map();
private connections?: Record<string, Connection>;
private cleanupFunctions: Array<() => Promise<void>> = [];
private transportInstances: Map<string, StdioClientTransport | SSEClientTransport> = new Map();
@@ -104,7 +103,7 @@ export class MultiServerMCPClient {
*
* @param connections - Optional connections to initialize
*/
constructor(connections?: Record<string, any>) {
constructor(connections?: Record<string, Connection>) {
if (connections) {
this.connections = this.processConnections(connections);
} else {
@@ -212,7 +211,9 @@ export class MultiServerMCPClient {
* @param connections - Raw connection configurations
* @returns Processed connection configurations
*/
private processConnections(connections: Record<string, any>): Record<string, Connection> {
private processConnections(
connections: Record<string, Partial<Connection>>
): Record<string, Connection> {
const processedConnections: Record<string, Connection> = {};
for (const [serverName, config] of Object.entries(connections)) {
@@ -221,17 +222,21 @@ export class MultiServerMCPClient {
continue;
}
try {
// Determine the connection type and process accordingly
if (this.isStdioConnection(config)) {
processedConnections[serverName] = this.processStdioConfig(serverName, config);
} else if (this.isSSEConnection(config)) {
processedConnections[serverName] = this.processSSEConfig(serverName, config);
} else {
logger.warn(`Server "${serverName}" has invalid or unsupported configuration. Skipping.`);
}
} catch (error) {
logger.error(`Error processing configuration for server "${serverName}": ${error}`);
// Determine the connection type and process accordingly
if (MultiServerMCPClient.isStdioConnection(config)) {
processedConnections[serverName] = MultiServerMCPClient.processStdioConfig(
serverName,
config
);
} else if (MultiServerMCPClient.isSSEConnection(config)) {
processedConnections[serverName] = MultiServerMCPClient.processSSEConfig(
serverName,
config
);
} else {
throw new MCPClientError(
`Server "${serverName}" has invalid or unsupported configuration. Skipping.`
);
}
}
@@ -241,45 +246,108 @@ export class MultiServerMCPClient {
/**
* Check if a configuration is for a stdio connection
*/
private isStdioConnection(config: any): boolean {
private static isStdioConnection(config: unknown): config is StdioConnection {
// When transport is missing, default to stdio if it has command and args
// OR when transport is explicitly set to 'stdio'
return (
(config.transport === 'stdio' || !config.transport) &&
config.command &&
Array.isArray(config.args)
typeof config === 'object' &&
config !== null &&
(!('transport' in config) || config.transport === 'stdio') &&
'command' in config &&
(!('args' in config) || Array.isArray(config.args))
);
}
/**
* Check if a configuration is for an SSE connection
*/
private isSSEConnection(config: any): boolean {
private static isSSEConnection(config: unknown): config is SSEConnection {
// Only consider it an SSE connection if transport is explicitly set to 'sse'
return config.transport === 'sse' && typeof config.url === 'string';
return (
typeof config === 'object' &&
config !== null &&
'transport' in config &&
config.transport === 'sse' &&
'url' in config &&
typeof config.url === 'string'
);
}
/**
* Process stdio connection configuration
*/
private processStdioConfig(serverName: string, config: any): StdioConnection {
private static processStdioConfig(
serverName: string,
config: Partial<StdioConnection>
): StdioConnection {
if (!config.command || typeof config.command !== 'string') {
throw new MCPClientError(`Missing or invalid command for server "${serverName}"`);
}
if (config.args !== undefined && !Array.isArray(config.args)) {
throw new MCPClientError(
`Invalid args for server "${serverName} - must be an array of strings`
);
}
if (config.args !== undefined && !config.args.every(arg => typeof arg === 'string')) {
throw new MCPClientError(
`Invalid args for server "${serverName} - must be an array of strings`
);
}
// Always set transport to 'stdio' regardless of whether it was in the original config
const stdioConfig: StdioConnection = {
transport: 'stdio',
command: config.command,
args: config.args,
args: config.args ?? [],
};
if (config.env && typeof config.env !== 'object') {
throw new MCPClientError(
`Invalid env for server "${serverName} - must be an object of key-value pairs`
);
}
if (config.env && typeof config.env === 'object' && Array.isArray(config.env)) {
throw new MCPClientError(
`Invalid env for server "${serverName} - must be an object of key-value pairs`
);
}
if (
config.env &&
typeof config.env === 'object' &&
!Object.values(config.env).every(value => typeof value === 'string')
) {
throw new MCPClientError(
`Invalid env for server "${serverName} - must be an object of key-value pairs with string values`
);
}
// Add optional properties if they exist
if (config.env && typeof config.env === 'object') {
stdioConfig.env = config.env;
}
if (config.encoding !== undefined && typeof config.encoding !== 'string') {
throw new MCPClientError(`Invalid encoding for server "${serverName} - must be a string`);
}
if (typeof config.encoding === 'string') {
stdioConfig.encoding = config.encoding;
}
if (['strict', 'ignore', 'replace'].includes(config.encodingErrorHandler)) {
if (
config.encodingErrorHandler !== undefined &&
!['strict', 'ignore', 'replace'].includes(config.encodingErrorHandler)
) {
throw new MCPClientError(
`Invalid encodingErrorHandler for server "${serverName} - must be one of: strict, ignore, replace`
);
}
if (['strict', 'ignore', 'replace'].includes(config.encodingErrorHandler ?? '')) {
stdioConfig.encodingErrorHandler = config.encodingErrorHandler as
| 'strict'
| 'ignore'
@@ -287,15 +355,40 @@ export class MultiServerMCPClient {
}
// Add restart configuration if present
if (config.restart && typeof config.restart !== 'object') {
throw new MCPClientError(`Invalid restart for server "${serverName} - must be an object`);
}
if (config.restart && typeof config.restart === 'object') {
if (config.restart.enabled !== undefined && typeof config.restart.enabled !== 'boolean') {
throw new MCPClientError(
`Invalid restart.enabled for server "${serverName} - must be a boolean`
);
}
stdioConfig.restart = {
enabled: Boolean(config.restart.enabled),
};
if (
config.restart.maxAttempts !== undefined &&
typeof config.restart.maxAttempts !== 'number'
) {
throw new MCPClientError(
`Invalid restart.maxAttempts for server "${serverName} - must be a number`
);
}
if (typeof config.restart.maxAttempts === 'number') {
stdioConfig.restart.maxAttempts = config.restart.maxAttempts;
}
if (config.restart.delayMs !== undefined && typeof config.restart.delayMs !== 'number') {
throw new MCPClientError(
`Invalid restart.delayMs for server "${serverName} - must be a number`
);
}
if (typeof config.restart.delayMs === 'number') {
stdioConfig.restart.delayMs = config.restart.delayMs;
}
@@ -307,32 +400,102 @@ export class MultiServerMCPClient {
/**
* Process SSE connection configuration
*/
private processSSEConfig(serverName: string, config: any): SSEConnection {
private static processSSEConfig(serverName: string, config: SSEConnection): SSEConnection {
if (!config.url || typeof config.url !== 'string') {
throw new MCPClientError(`Missing or invalid url for server "${serverName}"`);
}
try {
const url = new URL(config.url);
if (!url.protocol.startsWith('http')) {
throw new MCPClientError(
`Invalid url for server "${serverName} - must be a valid HTTP or HTTPS URL`
);
}
} catch {
throw new MCPClientError(`Invalid url for server "${serverName} - must be a valid URL`);
}
if (!config.transport || config.transport !== 'sse') {
throw new MCPClientError(`Invalid transport for server "${serverName} - must be 'sse'`);
}
const sseConfig: SSEConnection = {
transport: 'sse',
url: config.url,
};
if (config.headers && typeof config.headers !== 'object') {
throw new MCPClientError(`Invalid headers for server "${serverName} - must be an object`);
}
if (config.headers && typeof config.headers === 'object' && Array.isArray(config.headers)) {
throw new MCPClientError(
`Invalid headers for server "${serverName} - must be an object of key-value pairs`
);
}
if (
config.headers &&
typeof config.headers === 'object' &&
!Object.values(config.headers).every(value => typeof value === 'string')
) {
throw new MCPClientError(
`Invalid headers for server "${serverName} - must be an object of key-value pairs with string values`
);
}
// Add optional headers if they exist
if (config.headers && typeof config.headers === 'object') {
sseConfig.headers = config.headers;
}
if (config.useNodeEventSource !== undefined && typeof config.useNodeEventSource !== 'boolean') {
throw new MCPClientError(
`Invalid useNodeEventSource for server "${serverName} - must be a boolean`
);
}
// Add optional useNodeEventSource flag if it exists
if (typeof config.useNodeEventSource === 'boolean') {
sseConfig.useNodeEventSource = config.useNodeEventSource;
}
if (config.reconnect && typeof config.reconnect !== 'object') {
throw new MCPClientError(`Invalid reconnect for server "${serverName} - must be an object`);
}
// Add reconnection configuration if present
if (config.reconnect && typeof config.reconnect === 'object') {
if (config.reconnect.enabled !== undefined && typeof config.reconnect.enabled !== 'boolean') {
throw new MCPClientError(
`Invalid reconnect.enabled for server "${serverName} - must be a boolean`
);
}
sseConfig.reconnect = {
enabled: Boolean(config.reconnect.enabled),
};
if (
config.reconnect.maxAttempts !== undefined &&
typeof config.reconnect.maxAttempts !== 'number'
) {
throw new MCPClientError(
`Invalid reconnect.maxAttempts for server "${serverName} - must be a number`
);
}
if (typeof config.reconnect.maxAttempts === 'number') {
sseConfig.reconnect.maxAttempts = config.reconnect.maxAttempts;
}
if (config.reconnect.delayMs !== undefined && typeof config.reconnect.delayMs !== 'number') {
throw new MCPClientError(
`Invalid reconnect.delayMs for server "${serverName} - must be a number`
);
}
if (typeof config.reconnect.delayMs === 'number') {
sseConfig.reconnect.delayMs = config.reconnect.delayMs;
}
@@ -377,33 +540,25 @@ export class MultiServerMCPClient {
* @returns A map of server names to arrays of tools
* @throws {MCPClientError} If initialization fails
*/
async initializeConnections(): Promise<Map<string, StructuredToolInterface<z.ZodObject<any>>[]>> {
async initializeConnections(): Promise<Map<string, StructuredToolInterface[]>> {
if (!this.connections || Object.keys(this.connections).length === 0) {
logger.warn('No connections to initialize');
return new Map();
}
for (const [serverName, connection] of Object.entries(this.connections)) {
try {
logger.info(`Initializing connection to server "${serverName}"...`);
logger.info(`Initializing connection to server "${serverName}"...`);
if (connection.transport === 'stdio') {
await this.initializeStdioConnection(serverName, connection);
} else if (connection.transport === 'sse') {
await this.initializeSSEConnection(serverName, connection);
} else {
// This should never happen due to the validation in the constructor
throw new MCPClientError(
`Unsupported transport type for server "${serverName}"`,
serverName
);
}
} catch (error) {
if (error instanceof MCPClientError) {
logger.error(error.message);
} else {
logger.error(`Failed to connect to server "${serverName}": ${error}`);
}
if (connection.transport === 'stdio') {
await this.initializeStdioConnection(serverName, connection);
} else if (connection.transport === 'sse') {
await this.initializeSSEConnection(serverName, connection);
} else {
// This should never happen due to the validation in the constructor
throw new MCPClientError(
`Unsupported transport type for server "${serverName}"`,
serverName
);
}
}
@@ -674,8 +829,7 @@ export class MultiServerMCPClient {
this.serverNameToTools.set(serverName, tools);
logger.info(`Successfully loaded ${tools.length} tools from server "${serverName}"`);
} catch (error) {
logger.error(`Failed to load tools from server "${serverName}": ${error}`);
// Continue even if tool loading fails - the connection is still established
throw new MCPClientError(`Failed to load tools from server "${serverName}": ${error}`);
}
}
@@ -752,7 +906,7 @@ export class MultiServerMCPClient {
* If not provided, returns tools from all servers.
* @returns A flattened array of tools from the specified servers (or all servers)
*/
getTools(servers?: string[]): StructuredToolInterface<z.ZodObject<any>>[] {
getTools(servers?: string[]): StructuredToolInterface[] {
if (!servers || servers.length === 0) {
return this.getAllToolsAsFlatArray();
}
@@ -764,8 +918,8 @@ export class MultiServerMCPClient {
*
* @returns A flattened array of all tools
*/
private getAllToolsAsFlatArray(): StructuredToolInterface<z.ZodObject<any>>[] {
const allTools: StructuredToolInterface<z.ZodObject<any>>[] = [];
private getAllToolsAsFlatArray(): StructuredToolInterface[] {
const allTools: StructuredToolInterface[] = [];
for (const tools of this.serverNameToTools.values()) {
allTools.push(...tools);
}
@@ -778,8 +932,8 @@ export class MultiServerMCPClient {
* @param serverNames - Names of servers to get tools from
* @returns A flattened array of tools from the specified servers
*/
private getToolsFromServers(serverNames: string[]): StructuredToolInterface<z.ZodObject<any>>[] {
const allTools: StructuredToolInterface<z.ZodObject<any>>[] = [];
private getToolsFromServers(serverNames: string[]): StructuredToolInterface[] {
const allTools: StructuredToolInterface[] = [];
for (const serverName of serverNames) {
const tools = this.serverNameToTools.get(serverName);
if (tools) {
@@ -837,7 +991,7 @@ export class MultiServerMCPClient {
args: string[],
env?: Record<string, string>,
restart?: StdioConnection['restart']
): Promise<Map<string, StructuredToolInterface<z.ZodObject<any>>[]>> {
): Promise<Map<string, StructuredToolInterface[]>> {
const connections: Record<string, Connection> = {
[serverName]: {
transport: 'stdio',
@@ -868,7 +1022,7 @@ export class MultiServerMCPClient {
headers?: Record<string, string>,
useNodeEventSource?: boolean,
reconnect?: SSEConnection['reconnect']
): Promise<Map<string, StructuredToolInterface<z.ZodObject<any>>[]>> {
): Promise<Map<string, StructuredToolInterface[]>> {
const connection: SSEConnection = {
transport: 'sse',
url,
+1 -2
View File
@@ -1,3 +1,2 @@
export { MultiServerMCPClient } from './client.js';
export { convertMcpToolToLangchainTool, loadMcpTools } from './tools.js';
export { default as logger, enableLogging, disableLogging } from './logger.js';
export { logger, enableLogging, disableLogging } from './logger.js';
+2 -2
View File
@@ -97,7 +97,7 @@ try {
* Create the logger instance with our configuration.
* By default, logging is disabled (silent) but can be enabled by setting the level.
*/
const logger = winston.createLogger({
export const logger = winston.createLogger({
level: defaultLevel, // Start with silent logging by default
levels,
format,
@@ -111,7 +111,7 @@ const logger = winston.createLogger({
*/
export function enableLogging(level: keyof typeof levels | 'silent' = 'info'): void {
logger.level = level;
logger.info(`Logging enabled at level: ${level}`);
logger.debug(`Logging enabled at level: ${level}`);
}
/**
+74 -81
View File
@@ -1,6 +1,11 @@
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { StructuredTool, StructuredToolInterface } from '@langchain/core/tools';
import { z } from 'zod';
import {
DynamicStructuredTool,
ResponseFormat,
StructuredToolInterface,
} from '@langchain/core/tools';
import { JSONSchema, JSONSchemaToZod } from '@dmitryrechkin/json-schema-to-zod';
import logger from './logger.js';
interface TextContent {
@@ -72,73 +77,53 @@ function _convertCallToolResult(
}
/**
* Valid response formats for MCP tools
*/
export type ResponseFormat = 'text' | 'content_and_artifact';
/**
* Convert an MCP tool to a LangChain tool.
* Call an MCP tool.
*
* Use this with `.bind` to capture the fist three arguments, then pass to the constructor of DynamicStructuredTool.
*
* @internal
*
* @param client - The MCP client
* @param tool - The MCP tool to convert
* @param toolSchema - Tool schema (kept for backward compatibility, not used in current implementation)
* @param responseFormat - Response format ('text' or 'content_and_artifact')
* @returns A LangChain tool
* @param name - The name of the tool (forwarded to the client)
* @param responseFormat - The response format
* @param args - The arguments to pass to the tool
* @returns A tuple of [textContent, nonTextContent]
*/
export function convertMcpToolToLangchainTool(
async function _callTool(
client: Client,
tool: { name: string; description?: string; inputSchema?: any },
toolSchema?: any,
responseFormat: ResponseFormat = 'content_and_artifact'
): StructuredToolInterface {
// Create a minimal MCPTool class extending StructuredTool
class MCPTool extends StructuredTool {
name = tool.name;
description = tool.description || '';
schema = z.object({}).passthrough();
override responseFormat: ResponseFormat;
name: string,
responseFormat: ResponseFormat,
args: Record<string, unknown>
): Promise<string | [string | string[], NonTextContent[] | null]> {
try {
logger.info(`Calling tool ${name}(${JSON.stringify(args)})`);
const result = await client.callTool({
name,
arguments: args,
});
constructor(responseFormat: ResponseFormat) {
super({
responseFormat,
verboseParsingErrors: false,
});
this.responseFormat = responseFormat;
const [textContent, nonTextContent] = _convertCallToolResult({
...result,
isError: result.isError === true,
content: result.content || [],
});
logger.info(`Tool ${name} returned: ${JSON.stringify({ textContent, nonTextContent })}`);
// Return based on the response format
if (responseFormat === 'content_and_artifact') {
return [textContent, nonTextContent];
}
protected async _call(
args: Record<string, unknown>
): Promise<string | [string | string[], NonTextContent[] | null]> {
try {
const result = await client.callTool({
name: tool.name,
arguments: args,
});
const [textContent, nonTextContent] = _convertCallToolResult({
...result,
isError: result.isError === true,
content: result.content || [],
});
// Return based on the response format
if (this.responseFormat === 'content_and_artifact') {
return [textContent, nonTextContent] as [string | string[], NonTextContent[] | null];
}
// Default to returning just the text content
return typeof textContent === 'string' ? textContent : textContent.join('\n');
} catch (error) {
if (error instanceof ToolException) {
throw error;
}
throw new ToolException(`Error calling tool ${tool.name}: ${String(error)}`);
}
// Default to returning just the text content
return typeof textContent === 'string' ? textContent : textContent.join('\n');
} catch (error) {
logger.error(`Error calling tool ${name}: ${String(error)}`);
if (error instanceof ToolException) {
throw error;
}
throw new ToolException(`Error calling tool ${name}: ${String(error)}`);
}
// Return an instance of our tool with the specified response format
return new MCPTool(responseFormat);
}
/**
@@ -150,28 +135,36 @@ export function convertMcpToolToLangchainTool(
*/
export async function loadMcpTools(
client: Client,
responseFormat: ResponseFormat = 'text'
responseFormat: ResponseFormat = 'content',
throwOnLoadError: boolean = true
): Promise<StructuredToolInterface[]> {
try {
// Get tools in a single operation
const toolsResponse = await client.listTools();
logger.info(`Found ${toolsResponse.tools?.length || 0} MCP tools`);
// Get tools in a single operation
const toolsResponse = await client.listTools();
logger.info(`Found ${toolsResponse.tools?.length || 0} MCP tools`);
// Filter out tools without names and convert in a single map operation
return (toolsResponse.tools || [])
.filter(tool => !!tool.name)
.map(tool => {
try {
logger.debug(`Successfully loaded tool: ${tool.name}`);
return convertMcpToolToLangchainTool(client, tool, undefined, responseFormat);
} catch (error) {
logger.error(`Failed to load tool "${tool.name}":`, error);
return null;
// Filter out tools without names and convert in a single map operation
return (toolsResponse.tools || [])
.filter(tool => !!tool.name)
.map(tool => {
try {
const dst = new DynamicStructuredTool({
name: tool.name,
description: tool.description || '',
schema: JSONSchemaToZod.convert(
(tool.inputSchema ?? { type: 'object', properties: {} }) as JSONSchema
),
// eslint-disable-next-line @typescript-eslint/no-explicit-any
func: _callTool.bind(null, client, tool.name, responseFormat) as any,
});
logger.debug(`Successfully loaded tool: ${dst.name}`);
return dst;
} catch (error) {
logger.error(`Failed to load tool "${tool.name}":`, error);
if (throwOnLoadError) {
throw error;
}
})
.filter(Boolean) as StructuredToolInterface[];
} catch (error) {
logger.error('Failed to list MCP tools:', error);
return [];
}
return null;
}
})
.filter(Boolean) as StructuredToolInterface[];
}
+18
View File
@@ -0,0 +1,18 @@
import { defineConfig } from 'vitest/config';
export default defineConfig({
test: {
environment: 'node',
include: ['**/__tests__/**/*.test.ts'],
coverage: {
provider: 'v8',
reporter: ['text', 'lcov'],
include: ['src/**/*.ts'],
exclude: ['src/**/*.test.ts'],
},
setupFiles: ['./vitest.setup.ts'],
transformMode: {
web: [/\.[jt]sx?$/],
},
},
});
+6 -5
View File
@@ -1,16 +1,17 @@
/* global jest */
// Mock node: imports
jest.mock('node:crypto', () => {
import { vi } from 'vitest';
// Mock node:crypto module
vi.mock('node:crypto', () => {
return {
webcrypto: {
getRandomValues: array => {
getRandomValues: (array: Uint8Array) => {
for (let i = 0; i < array.length; i++) {
array[i] = Math.floor(Math.random() * 256);
}
return array;
},
subtle: {
digest: jest.fn().mockResolvedValue(new ArrayBuffer(32)),
digest: vi.fn().mockResolvedValue(new ArrayBuffer(32)),
},
},
};