Skip to content

Commit 969552a

Browse files
authored
feat(Vector Store Retriever Node): Add reranker support to retriever for QA chain (#16051)
1 parent ac1a1df commit 969552a

File tree

6 files changed

+226
-9
lines changed

6 files changed

+226
-9
lines changed

packages/@n8n/nodes-langchain/nodes/retrievers/RetrieverVectorStore/RetrieverVectorStore.node.ts

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
/* eslint-disable n8n-nodes-base/node-dirname-against-convention */
2-
import type { VectorStore } from '@langchain/core/vectorstores';
2+
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
3+
import { VectorStore } from '@langchain/core/vectorstores';
4+
import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression';
35
import {
46
NodeConnectionTypes,
57
type INodeType,
@@ -65,9 +67,23 @@ export class RetrieverVectorStore implements INodeType {
6567
const vectorStore = (await this.getInputConnectionData(
6668
NodeConnectionTypes.AiVectorStore,
6769
itemIndex,
68-
)) as VectorStore;
70+
)) as
71+
| VectorStore
72+
| {
73+
reranker: BaseDocumentCompressor;
74+
vectorStore: VectorStore;
75+
};
6976

70-
const retriever = vectorStore.asRetriever(topK);
77+
let retriever = null;
78+
79+
if (vectorStore instanceof VectorStore) {
80+
retriever = vectorStore.asRetriever(topK);
81+
} else {
82+
retriever = new ContextualCompressionRetriever({
83+
baseCompressor: vectorStore.reranker,
84+
baseRetriever: vectorStore.vectorStore.asRetriever(topK),
85+
});
86+
}
7187

7288
return {
7389
response: logWrapper(retriever, this),
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
2+
import { VectorStore } from '@langchain/core/vectorstores';
3+
import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression';
4+
import type { ISupplyDataFunctions } from 'n8n-workflow';
5+
import { NodeConnectionTypes } from 'n8n-workflow';
6+
7+
import { RetrieverVectorStore } from '../RetrieverVectorStore.node';
8+
9+
const mockLogger = {
10+
debug: jest.fn(),
11+
info: jest.fn(),
12+
warn: jest.fn(),
13+
error: jest.fn(),
14+
};
15+
16+
describe('RetrieverVectorStore', () => {
17+
let retrieverNode: RetrieverVectorStore;
18+
let mockContext: jest.Mocked<ISupplyDataFunctions>;
19+
20+
beforeEach(() => {
21+
retrieverNode = new RetrieverVectorStore();
22+
mockContext = {
23+
logger: mockLogger,
24+
getNodeParameter: jest.fn(),
25+
getInputConnectionData: jest.fn(),
26+
} as unknown as jest.Mocked<ISupplyDataFunctions>;
27+
jest.clearAllMocks();
28+
});
29+
30+
describe('supplyData', () => {
31+
it('should create a retriever from a basic VectorStore', async () => {
32+
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
33+
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' });
34+
35+
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
36+
if (param === 'topK') return 4;
37+
return defaultValue;
38+
});
39+
40+
mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore);
41+
42+
const result = await retrieverNode.supplyData.call(mockContext, 0);
43+
44+
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
45+
NodeConnectionTypes.AiVectorStore,
46+
0,
47+
);
48+
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4);
49+
expect(result).toHaveProperty('response', { test: 'retriever' });
50+
});
51+
52+
it('should create a retriever with custom topK parameter', async () => {
53+
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
54+
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' });
55+
56+
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
57+
if (param === 'topK') return 10;
58+
return defaultValue;
59+
});
60+
mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore);
61+
62+
const result = await retrieverNode.supplyData.call(mockContext, 0);
63+
64+
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(10);
65+
expect(result).toHaveProperty('response', { test: 'retriever' });
66+
});
67+
68+
it('should create a ContextualCompressionRetriever when input contains reranker and vectorStore', async () => {
69+
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
70+
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'base-retriever' });
71+
72+
const mockReranker = {} as BaseDocumentCompressor;
73+
74+
const inputWithReranker = {
75+
reranker: mockReranker,
76+
vectorStore: mockVectorStore,
77+
};
78+
79+
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
80+
if (param === 'topK') return 4;
81+
return defaultValue;
82+
});
83+
mockContext.getInputConnectionData.mockResolvedValue(inputWithReranker);
84+
85+
const result = await retrieverNode.supplyData.call(mockContext, 0);
86+
87+
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
88+
NodeConnectionTypes.AiVectorStore,
89+
0,
90+
);
91+
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4);
92+
expect(result.response).toBeInstanceOf(ContextualCompressionRetriever);
93+
});
94+
95+
it('should create a ContextualCompressionRetriever with custom topK when using reranker', async () => {
96+
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
97+
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'base-retriever' });
98+
99+
const mockReranker = {} as BaseDocumentCompressor;
100+
101+
const inputWithReranker = {
102+
reranker: mockReranker,
103+
vectorStore: mockVectorStore,
104+
};
105+
106+
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
107+
if (param === 'topK') return 8;
108+
return defaultValue;
109+
});
110+
mockContext.getInputConnectionData.mockResolvedValue(inputWithReranker);
111+
112+
const result = await retrieverNode.supplyData.call(mockContext, 0);
113+
114+
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(8);
115+
expect(result.response).toBeInstanceOf(ContextualCompressionRetriever);
116+
});
117+
118+
it('should use default topK value when parameter is not provided', async () => {
119+
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
120+
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' });
121+
122+
mockContext.getNodeParameter.mockImplementation((_param, _itemIndex, defaultValue) => {
123+
return defaultValue;
124+
});
125+
mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore);
126+
127+
await retrieverNode.supplyData.call(mockContext, 0);
128+
129+
expect(mockContext.getNodeParameter).toHaveBeenCalledWith('topK', 0, 4);
130+
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4);
131+
});
132+
});
133+
});

packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/__snapshots__/createVectorStoreNode.test.ts.snap

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] =
4444
const useReranker = parameters?.useReranker;
4545
const inputs = [{ displayName: "Embedding", type: "ai_embedding", required: true, maxConnections: 1}]
4646
47-
if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) {
47+
if (['load', 'retrieve', 'retrieve-as-tool'].includes(mode) && useReranker) {
4848
inputs.push({ displayName: "Reranker", type: "ai_reranker", required: true, maxConnections: 1})
4949
}
5050
@@ -246,6 +246,7 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] =
246246
"show": {
247247
"mode": [
248248
"load",
249+
"retrieve",
249250
"retrieve-as-tool",
250251
],
251252
},

packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/createVectorStoreNode.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ export const createVectorStoreNode = <T extends VectorStore = VectorStore>(
7272
const useReranker = parameters?.useReranker;
7373
const inputs = [{ displayName: "Embedding", type: "${NodeConnectionTypes.AiEmbedding}", required: true, maxConnections: 1}]
7474
75-
if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) {
75+
if (['load', 'retrieve', 'retrieve-as-tool'].includes(mode) && useReranker) {
7676
inputs.push({ displayName: "Reranker", type: "${NodeConnectionTypes.AiReranker}", required: true, maxConnections: 1})
7777
}
7878
@@ -215,7 +215,7 @@ export const createVectorStoreNode = <T extends VectorStore = VectorStore>(
215215
description: 'Whether or not to rerank results',
216216
displayOptions: {
217217
show: {
218-
mode: ['load', 'retrieve-as-tool'],
218+
mode: ['load', 'retrieve', 'retrieve-as-tool'],
219219
},
220220
},
221221
},

packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/__tests__/retrieveOperation.test.ts

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import type { Embeddings } from '@langchain/core/embeddings';
2+
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
23
import type { VectorStore } from '@langchain/core/vectorstores';
34
import type { MockProxy } from 'jest-mock-extended';
45
import { mock } from 'jest-mock-extended';
56
import type { ISupplyDataFunctions } from 'n8n-workflow';
7+
import { NodeConnectionTypes } from 'n8n-workflow';
68

79
import { logWrapper } from '@utils/logWrapper';
810

@@ -22,15 +24,19 @@ describe('handleRetrieveOperation', () => {
2224
let mockContext: MockProxy<ISupplyDataFunctions>;
2325
let mockEmbeddings: MockProxy<Embeddings>;
2426
let mockVectorStore: MockProxy<VectorStore>;
27+
let mockReranker: MockProxy<BaseDocumentCompressor>;
2528
let mockArgs: VectorStoreNodeConstructorArgs<VectorStore>;
2629

2730
beforeEach(() => {
2831
mockContext = mock<ISupplyDataFunctions>();
32+
mockContext.getNodeParameter.mockReturnValue(false); // Default useReranker to false
2933

3034
mockEmbeddings = mock<Embeddings>();
3135

3236
mockVectorStore = mock<VectorStore>();
3337

38+
mockReranker = mock<BaseDocumentCompressor>();
39+
3440
mockArgs = {
3541
meta: {
3642
displayName: 'Test Vector Store',
@@ -88,4 +94,46 @@ describe('handleRetrieveOperation', () => {
8894
// Call the closeFunction - should not throw error even with no release method
8995
await expect(result.closeFunction!()).resolves.not.toThrow();
9096
});
97+
98+
it('should retrieve vector store without reranker when useReranker is false', async () => {
99+
mockContext.getNodeParameter.mockReturnValue(false);
100+
101+
const result = await handleRetrieveOperation(mockContext, mockArgs, mockEmbeddings, 0);
102+
103+
expect(mockContext.getNodeParameter).toHaveBeenCalledWith('useReranker', 0, false);
104+
105+
expect(mockArgs.getVectorStoreClient).toHaveBeenCalledWith(
106+
mockContext,
107+
{ testFilter: 'value' },
108+
mockEmbeddings,
109+
0,
110+
);
111+
112+
// Result should contain vector store and close function
113+
expect(result).toHaveProperty('response', mockVectorStore);
114+
expect(result).toHaveProperty('closeFunction');
115+
116+
// Should not try to get reranker input connection
117+
expect(mockContext.getInputConnectionData).not.toHaveBeenCalled();
118+
});
119+
120+
it('should retrieve vector store with reranker when useReranker is true', async () => {
121+
mockContext.getNodeParameter.mockReturnValue(true);
122+
mockContext.getInputConnectionData.mockResolvedValue(mockReranker);
123+
124+
const result = await handleRetrieveOperation(mockContext, mockArgs, mockEmbeddings, 0);
125+
126+
expect(mockContext.getNodeParameter).toHaveBeenCalledWith('useReranker', 0, false);
127+
128+
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
129+
NodeConnectionTypes.AiReranker,
130+
0,
131+
);
132+
133+
expect(result.response).toEqual({
134+
reranker: mockReranker,
135+
vectorStore: mockVectorStore,
136+
});
137+
expect(result).toHaveProperty('closeFunction');
138+
});
91139
});

packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/retrieveOperation.ts

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import type { Embeddings } from '@langchain/core/embeddings';
2+
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
23
import type { VectorStore } from '@langchain/core/vectorstores';
3-
import type { ISupplyDataFunctions, SupplyData } from 'n8n-workflow';
4+
import { NodeConnectionTypes, type ISupplyDataFunctions, type SupplyData } from 'n8n-workflow';
45

56
import { getMetadataFiltersValues } from '@utils/helpers';
67
import { logWrapper } from '@utils/logWrapper';
@@ -19,13 +20,31 @@ export async function handleRetrieveOperation<T extends VectorStore = VectorStor
1920
): Promise<SupplyData> {
2021
// Get metadata filters
2122
const filter = getMetadataFiltersValues(context, itemIndex);
23+
const useReranker = context.getNodeParameter('useReranker', itemIndex, false) as boolean;
2224

2325
// Get the vector store client
2426
const vectorStore = await args.getVectorStoreClient(context, filter, embeddings, itemIndex);
27+
let response: VectorStore | { reranker: BaseDocumentCompressor; vectorStore: VectorStore } =
28+
vectorStore;
29+
30+
if (useReranker) {
31+
const reranker = (await context.getInputConnectionData(
32+
NodeConnectionTypes.AiReranker,
33+
0,
34+
)) as BaseDocumentCompressor;
35+
36+
// Return the reranker as-is together with the log-wrapped vector store
37+
response = {
38+
reranker,
39+
vectorStore: logWrapper(vectorStore, context),
40+
};
41+
} else {
42+
// Return the vector store with logging wrapper
43+
response = logWrapper(vectorStore, context);
44+
}
2545

26-
// Return the vector store with logging wrapper and cleanup function
2746
return {
28-
response: logWrapper(vectorStore, context),
47+
response,
2948
closeFunction: async () => {
3049
// Release the vector store client if a release method was provided
3150
args.releaseVectorStoreClient?.(vectorStore);

0 commit comments

Comments
 (0)