Skip to content

fix(query): handle casting array filter paths underneath array filter paths with embedded discriminators #15388

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions lib/helpers/query/getEmbeddedDiscriminatorPath.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const updatedPathsByArrayFilter = require('../update/updatedPathsByArrayFilter')
*/

module.exports = function getEmbeddedDiscriminatorPath(schema, update, filter, path, options) {
const parts = path.split('.');
const parts = path.indexOf('.') === -1 ? [path] : path.split('.');
let schematype = null;
let type = 'adhocOrUndefined';

Expand All @@ -26,9 +26,10 @@ module.exports = function getEmbeddedDiscriminatorPath(schema, update, filter, p
const arrayFilters = options != null && Array.isArray(options.arrayFilters) ?
options.arrayFilters : [];
const updatedPathsByFilter = updatedPathsByArrayFilter(update);
let startIndex = 0;

for (let i = 0; i < parts.length; ++i) {
const originalSubpath = parts.slice(0, i + 1).join('.');
const originalSubpath = parts.slice(startIndex, i + 1).join('.');
const subpath = cleanPositionalOperators(originalSubpath);
schematype = schema.path(subpath);
if (schematype == null) {
Expand Down Expand Up @@ -89,6 +90,8 @@ module.exports = function getEmbeddedDiscriminatorPath(schema, update, filter, p

const rest = parts.slice(i + 1).join('.');
schematype = discriminatorSchema.path(rest);
schema = discriminatorSchema;
startIndex = i + 1;
if (schematype != null) {
type = discriminatorSchema._getPathType(rest);
break;
Expand Down
7 changes: 5 additions & 2 deletions lib/helpers/schema/getPath.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const numberRE = /^\d+$/;
* @api private
*/

module.exports = function getPath(schema, path) {
module.exports = function getPath(schema, path, discriminatorValueMap) {
let schematype = schema.path(path);
if (schematype != null) {
return schematype;
Expand All @@ -26,10 +26,13 @@ module.exports = function getPath(schema, path) {
schematype = schema.path(cur);
if (schematype != null && schematype.schema) {
schema = schematype.schema;
cur = '';
if (!isArray && schematype.$isMongooseDocumentArray) {
isArray = true;
}
if (discriminatorValueMap && discriminatorValueMap[cur]) {
schema = schema.discriminators[discriminatorValueMap[cur]] ?? schema;
}
cur = '';
}
}

Expand Down
9 changes: 7 additions & 2 deletions lib/helpers/update/castArrayFilters.js
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ function _castArrayFilters(arrayFilters, schema, strictQuery, updatedPathsByFilt
return;
}

// Map to store discriminator values for embedded documents in the array filters.
// This is used to handle cases where array filters target specific embedded document types.
const discriminatorValueMap = {};

for (const filter of arrayFilters) {
if (filter == null) {
throw new Error(`Got null array filter in ${arrayFilters}`);
Expand All @@ -58,12 +62,13 @@ function _castArrayFilters(arrayFilters, schema, strictQuery, updatedPathsByFilt
updatedPathsByFilter[filterWildcardPath]
);

const baseSchematype = getPath(schema, baseFilterPath);
const baseSchematype = getPath(schema, baseFilterPath, discriminatorValueMap);
let filterBaseSchema = baseSchematype != null ? baseSchematype.schema : null;
if (filterBaseSchema != null &&
filterBaseSchema.discriminators != null &&
filter[filterWildcardPath + '.' + filterBaseSchema.options.discriminatorKey]) {
filterBaseSchema = filterBaseSchema.discriminators[filter[filterWildcardPath + '.' + filterBaseSchema.options.discriminatorKey]] || filterBaseSchema;
discriminatorValueMap[baseFilterPath] = filter[filterWildcardPath + '.' + filterBaseSchema.options.discriminatorKey];
}

for (const key of keys) {
Expand All @@ -83,7 +88,7 @@ function _castArrayFilters(arrayFilters, schema, strictQuery, updatedPathsByFilt
// If there are multiple array filters in the path being updated, make sure
// to replace them so we can get the schema path.
filterPathRelativeToBase = cleanPositionalOperators(filterPathRelativeToBase);
schematype = getPath(filterBaseSchema, filterPathRelativeToBase);
schematype = getPath(filterBaseSchema, filterPathRelativeToBase, discriminatorValueMap);
}

if (schematype == null) {
Expand Down
48 changes: 48 additions & 0 deletions test/helpers/update.castArrayFilters.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -349,4 +349,52 @@ describe('castArrayFilters', function() {

assert.strictEqual(q.getUpdate().$set['groups.$[group].tags.$[tag]'], '42');
});

it('casts paths underneath embedded discriminators (gh-15386)', async function() {
const eventSchema = new Schema({ message: String }, { discriminatorKey: 'kind', _id: false });
const batchSchema = new Schema({ events: [eventSchema] });

const docArray = batchSchema.path('events');
docArray.discriminator('Clicked', new Schema({ element: { type: String, required: true } }, { _id: false }));

const productSchema = new Schema({
name: String,
price: Number
});

docArray.discriminator(
'Purchased',
new Schema({
products: {
type: [productSchema],
required: true
}
})
);

const q = new Query();
q.schema = batchSchema;

const filter = {};
const update = {
$set: {
'events.$[event].products.$[product].price': '20'
}
};
const purchasedId = new Types.ObjectId();
const productId = new Types.ObjectId();
const opts = {
arrayFilters: [
{ 'event._id': purchasedId, 'event.kind': 'Purchased' },
{ 'product._id': productId.toString() }
]
};

q.updateOne(filter, update, opts);
castArrayFilters(q);
q._update = q._castUpdate(q._update, false);

assert.strictEqual(q.getOptions().arrayFilters[1]['product._id'].toHexString(), productId.toHexString());
assert.strictEqual(q.getUpdate().$set['events.$[event].products.$[product].price'], 20);
});
});
Loading