Skip to content

Commit 8a1cabe

Browse files
schrothbnggozadmutdmour
authored
feat: Add Cohere reranking capability to vector stores (#16014)
Co-authored-by: Yiorgis Gozadinos <[email protected]> Co-authored-by: Mutasem Aldmour <[email protected]>
1 parent 47ad74d commit 8a1cabe

File tree

14 files changed

+585
-20
lines changed

14 files changed

+585
-20
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/* eslint-disable n8n-nodes-base/node-dirname-against-convention */
2+
import { CohereRerank } from '@langchain/cohere';
3+
import {
4+
NodeConnectionTypes,
5+
type INodeType,
6+
type INodeTypeDescription,
7+
type ISupplyDataFunctions,
8+
type SupplyData,
9+
} from 'n8n-workflow';
10+
11+
import { logWrapper } from '@utils/logWrapper';
12+
13+
export class RerankerCohere implements INodeType {
14+
description: INodeTypeDescription = {
15+
displayName: 'Reranker Cohere',
16+
name: 'rerankerCohere',
17+
icon: { light: 'file:cohere.svg', dark: 'file:cohere.dark.svg' },
18+
group: ['transform'],
19+
version: 1,
20+
description:
21+
'Use Cohere Reranker to reorder documents after retrieval from a vector store by relevance to the given query.',
22+
defaults: {
23+
name: 'Reranker Cohere',
24+
},
25+
requestDefaults: {
26+
ignoreHttpStatusErrors: true,
27+
baseURL: '={{ $credentials.host }}',
28+
},
29+
credentials: [
30+
{
31+
name: 'cohereApi',
32+
required: true,
33+
},
34+
],
35+
codex: {
36+
categories: ['AI'],
37+
subcategories: {
38+
AI: ['Rerankers'],
39+
},
40+
resources: {
41+
primaryDocumentation: [
42+
{
43+
url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/sub-nodes/n8n-nodes-langchain.rerankercohere/',
44+
},
45+
],
46+
},
47+
},
48+
inputs: [],
49+
outputs: [NodeConnectionTypes.AiReranker],
50+
outputNames: ['Reranker'],
51+
properties: [
52+
{
53+
displayName: 'Model',
54+
name: 'modelName',
55+
type: 'options',
56+
description:
57+
'The model that should be used to rerank the documents. <a href="https://docs.cohere.com/docs/models">Learn more</a>.',
58+
default: 'rerank-v3.5',
59+
options: [
60+
{
61+
name: 'rerank-v3.5',
62+
value: 'rerank-v3.5',
63+
},
64+
{
65+
name: 'rerank-english-v3.0',
66+
value: 'rerank-english-v3.0',
67+
},
68+
{
69+
name: 'rerank-multilingual-v3.0',
70+
value: 'rerank-multilingual-v3.0',
71+
},
72+
],
73+
},
74+
],
75+
};
76+
77+
async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> {
78+
this.logger.debug('Supply data for reranking Cohere');
79+
const modelName = this.getNodeParameter('modelName', itemIndex, 'rerank-v3.5') as string;
80+
const credentials = await this.getCredentials<{ apiKey: string }>('cohereApi');
81+
const reranker = new CohereRerank({
82+
apiKey: credentials.apiKey,
83+
model: modelName,
84+
});
85+
86+
return {
87+
response: logWrapper(reranker, this),
88+
};
89+
}
90+
}
Lines changed: 5 additions & 0 deletions
Loading
Lines changed: 5 additions & 0 deletions
Loading
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import { CohereRerank } from '@langchain/cohere';
2+
import { mock } from 'jest-mock-extended';
3+
import type { ISupplyDataFunctions } from 'n8n-workflow';
4+
5+
import { logWrapper } from '@utils/logWrapper';
6+
7+
import { RerankerCohere } from '../RerankerCohere.node';
8+
9+
// Mock the CohereRerank class
10+
jest.mock('@langchain/cohere', () => ({
11+
CohereRerank: jest.fn(),
12+
}));
13+
14+
// Mock the logWrapper utility
15+
jest.mock('@utils/logWrapper', () => ({
16+
logWrapper: jest.fn().mockImplementation((obj) => ({ logWrapped: obj })),
17+
}));
18+
19+
describe('RerankerCohere', () => {
20+
let rerankerCohere: RerankerCohere;
21+
let mockSupplyDataFunctions: ISupplyDataFunctions;
22+
let mockCohereRerank: jest.Mocked<CohereRerank>;
23+
24+
beforeEach(() => {
25+
rerankerCohere = new RerankerCohere();
26+
27+
// Reset the mock
28+
jest.clearAllMocks();
29+
30+
// Create a mock CohereRerank instance
31+
mockCohereRerank = {
32+
compressDocuments: jest.fn(),
33+
} as unknown as jest.Mocked<CohereRerank>;
34+
35+
// Make the CohereRerank constructor return our mock instance
36+
(CohereRerank as jest.MockedClass<typeof CohereRerank>).mockImplementation(
37+
() => mockCohereRerank,
38+
);
39+
40+
// Create mock supply data functions
41+
mockSupplyDataFunctions = mock<ISupplyDataFunctions>({
42+
logger: {
43+
debug: jest.fn(),
44+
error: jest.fn(),
45+
info: jest.fn(),
46+
warn: jest.fn(),
47+
},
48+
});
49+
50+
// Mock specific methods with proper jest functions
51+
mockSupplyDataFunctions.getNodeParameter = jest.fn();
52+
mockSupplyDataFunctions.getCredentials = jest.fn();
53+
});
54+
55+
it('should create CohereRerank with default model and return wrapped instance', async () => {
56+
// Setup mocks
57+
const mockCredentials = { apiKey: 'test-api-key' };
58+
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5');
59+
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
60+
61+
// Execute
62+
const result = await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
63+
64+
expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith(
65+
'modelName',
66+
0,
67+
'rerank-v3.5',
68+
);
69+
expect(mockSupplyDataFunctions.getCredentials).toHaveBeenCalledWith('cohereApi');
70+
expect(CohereRerank).toHaveBeenCalledWith({
71+
apiKey: 'test-api-key',
72+
model: 'rerank-v3.5',
73+
});
74+
expect(logWrapper).toHaveBeenCalledWith(mockCohereRerank, mockSupplyDataFunctions);
75+
expect(result.response).toEqual({ logWrapped: mockCohereRerank });
76+
});
77+
78+
it('should create CohereRerank with custom model', async () => {
79+
// Setup mocks
80+
const mockCredentials = { apiKey: 'custom-api-key' };
81+
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue(
82+
'rerank-multilingual-v3.0',
83+
);
84+
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
85+
86+
// Execute
87+
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
88+
89+
// Verify
90+
expect(CohereRerank).toHaveBeenCalledWith({
91+
apiKey: 'custom-api-key',
92+
model: 'rerank-multilingual-v3.0',
93+
});
94+
});
95+
96+
it('should handle different item indices', async () => {
97+
// Setup mocks
98+
const mockCredentials = { apiKey: 'test-api-key' };
99+
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-english-v3.0');
100+
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
101+
102+
// Execute with different item index
103+
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 2);
104+
105+
// Verify the correct item index is passed
106+
expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith(
107+
'modelName',
108+
2,
109+
'rerank-v3.5',
110+
);
111+
});
112+
113+
it('should throw error when credentials are missing', async () => {
114+
// Setup mocks
115+
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5');
116+
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockRejectedValue(
117+
new Error('Missing credentials'),
118+
);
119+
120+
// Execute and verify error
121+
await expect(rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0)).rejects.toThrow(
122+
'Missing credentials',
123+
);
124+
});
125+
126+
it('should use fallback model when parameter is not provided', async () => {
127+
// Setup mocks - getNodeParameter returns the fallback value
128+
const mockCredentials = { apiKey: 'test-api-key' };
129+
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5'); // fallback value
130+
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
131+
132+
// Execute
133+
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
134+
135+
// Verify fallback is used
136+
expect(CohereRerank).toHaveBeenCalledWith({
137+
apiKey: 'test-api-key',
138+
model: 'rerank-v3.5',
139+
});
140+
});
141+
});

packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/__snapshots__/createVectorStoreNode.test.ts.snap

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,13 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] =
4141
"inputs": "={{
4242
((parameters) => {
4343
const mode = parameters?.mode;
44+
const useReranker = parameters?.useReranker;
4445
const inputs = [{ displayName: "Embedding", type: "ai_embedding", required: true, maxConnections: 1}]
4546
47+
if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) {
48+
inputs.push({ displayName: "Reranker", type: "ai_reranker", required: true, maxConnections: 1})
49+
}
50+
4651
if (mode === 'retrieve-as-tool') {
4752
return inputs;
4853
}
@@ -233,6 +238,21 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] =
233238
"name": "includeDocumentMetadata",
234239
"type": "boolean",
235240
},
241+
{
242+
"default": false,
243+
"description": "Whether or not to rerank results",
244+
"displayName": "Rerank Results",
245+
"displayOptions": {
246+
"show": {
247+
"mode": [
248+
"load",
249+
"retrieve-as-tool",
250+
],
251+
},
252+
},
253+
"name": "useReranker",
254+
"type": "boolean",
255+
},
236256
{
237257
"default": "",
238258
"description": "ID of an embedding entry",

packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/createVectorStoreNode.ts

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,13 @@ export const createVectorStoreNode = <T extends VectorStore = VectorStore>(
6969
inputs: `={{
7070
((parameters) => {
7171
const mode = parameters?.mode;
72+
const useReranker = parameters?.useReranker;
7273
const inputs = [{ displayName: "Embedding", type: "${NodeConnectionTypes.AiEmbedding}", required: true, maxConnections: 1}]
7374
75+
if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) {
76+
inputs.push({ displayName: "Reranker", type: "${NodeConnectionTypes.AiReranker}", required: true, maxConnections: 1})
77+
}
78+
7479
if (mode === 'retrieve-as-tool') {
7580
return inputs;
7681
}
@@ -202,6 +207,18 @@ export const createVectorStoreNode = <T extends VectorStore = VectorStore>(
202207
},
203208
},
204209
},
210+
{
211+
displayName: 'Rerank Results',
212+
name: 'useReranker',
213+
type: 'boolean',
214+
default: false,
215+
description: 'Whether or not to rerank results',
216+
displayOptions: {
217+
show: {
218+
mode: ['load', 'retrieve-as-tool'],
219+
},
220+
},
221+
},
205222
// ID is always used for update operation
206223
{
207224
displayName: 'ID',
@@ -233,7 +250,6 @@ export const createVectorStoreNode = <T extends VectorStore = VectorStore>(
233250
*/
234251
async execute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
235252
const mode = this.getNodeParameter('mode', 0) as NodeOperationMode;
236-
237253
// Get the embeddings model connected to this node
238254
const embeddings = (await this.getInputConnectionData(
239255
NodeConnectionTypes.AiEmbedding,

0 commit comments

Comments
 (0)