Revert "Switch percentiles implementation to MergingDigest (#18124)" #18497

Merged: 7 commits, Jun 13, 2025
gradle/libs.versions.toml (2 changes: 1 addition & 1 deletion)
@@ -25,7 +25,7 @@ protobuf = "3.25.5"
jakarta_annotation = "1.3.5"
google_http_client = "1.44.1"
google_auth = "1.29.0"
tdigest = "3.3" # Warning: Before updating tdigest, ensure its serialization code for MergingDigest hasn't changed
tdigest = "3.3"
hdrhistogram = "2.2.2"
grpc = "1.68.2"
json_smart = "2.5.2"
@@ -43,8 +43,6 @@
import java.util.Map;
import java.util.Objects;

import com.tdunning.math.stats.Centroid;

/**
* Implementation of median absolute deviation agg
*
@@ -59,14 +57,11 @@ static double computeMedianAbsoluteDeviation(TDigestState valuesSketch) {
} else {
final double approximateMedian = valuesSketch.quantile(0.5);
final TDigestState approximatedDeviationsSketch = new TDigestState(valuesSketch.compression());
for (Centroid centroid : valuesSketch.centroids()) {
valuesSketch.centroids().forEach(centroid -> {
final double deviation = Math.abs(approximateMedian - centroid.mean());
// Weighted add() isn't supported for faster MergingDigest implementation, so add iteratively instead. see
// https://github.com/tdunning/t-digest/issues/167
for (int i = 0; i < centroid.count(); i++) {
approximatedDeviationsSketch.add(deviation);
}
}
approximatedDeviationsSketch.add(deviation, centroid.count());
});

return approximatedDeviationsSketch.quantile(0.5);
}
}
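This hunk restores the weighted add path that the MergingDigest-based implementation had to work around (see the removed comment pointing at tdunning/t-digest#167, which added each deviation once per sample). The following is a minimal, illustrative sketch of the two approaches against the t-digest 3.3 API; the class name and sample data are made up and it is not part of this change.

import com.tdunning.math.stats.AVLTreeDigest;
import com.tdunning.math.stats.Centroid;

public class WeightedAddSketch {
    public static void main(String[] args) {
        // Build a sketch of some raw values.
        AVLTreeDigest values = new AVLTreeDigest(100.0);
        for (int i = 0; i < 1000; i++) {
            values.add(i % 10);
        }
        double median = values.quantile(0.5);

        // Restored approach: AVLTreeDigest accepts a weighted add, so each centroid's
        // deviation is added once with the centroid's full count.
        AVLTreeDigest weighted = new AVLTreeDigest(values.compression());
        for (Centroid c : values.centroids()) {
            weighted.add(Math.abs(median - c.mean()), c.count());
        }

        // MergingDigest-era workaround: add the same deviation once per sample.
        AVLTreeDigest iterative = new AVLTreeDigest(values.compression());
        for (Centroid c : values.centroids()) {
            for (int i = 0; i < c.count(); i++) {
                iterative.add(Math.abs(median - c.mean()));
            }
        }

        // Both sketches should yield essentially the same median absolute deviation.
        System.out.println(weighted.quantile(0.5) + " vs " + iterative.quantile(0.5));
    }
}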
@@ -31,25 +31,21 @@

package org.opensearch.search.aggregations.metrics;

import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Iterator;
import java.util.List;

import com.tdunning.math.stats.AVLTreeDigest;
import com.tdunning.math.stats.Centroid;
import com.tdunning.math.stats.MergingDigest;

/**
* Extension of {@link com.tdunning.math.stats.TDigest} with custom serialization.
*
* @opensearch.internal
*/
public class TDigestState extends MergingDigest {
public class TDigestState extends AVLTreeDigest {

private final double compression;

@@ -58,64 +54,28 @@ public TDigestState(double compression) {
this.compression = compression;
}

private TDigestState(double compression, MergingDigest in) {
super(compression);
this.compression = compression;
this.add(List.of(in));
}

@Override
public double compression() {
return compression;
}

public static void write(TDigestState state, StreamOutput out) throws IOException {
if (out.getVersion().before(Version.V_3_1_0)) {
out.writeDouble(state.compression);
out.writeVInt(state.centroidCount());
for (Centroid centroid : state.centroids()) {
out.writeDouble(centroid.mean());
out.writeVLong(centroid.count());
}
} else {
int byteSize = state.byteSize();
out.writeVInt(byteSize);
ByteBuffer buf = ByteBuffer.allocate(byteSize);
state.asBytes(buf);
out.writeBytes(buf.array());
out.writeDouble(state.compression);
out.writeVInt(state.centroidCount());
for (Centroid centroid : state.centroids()) {
out.writeDouble(centroid.mean());
out.writeVLong(centroid.count());
}
}

public static TDigestState read(StreamInput in) throws IOException {
if (in.getVersion().before(Version.V_3_1_0)) {
// In older versions TDigestState was based on AVLTreeDigest. Load centroids into this class, then add it to MergingDigest.
double compression = in.readDouble();
AVLTreeDigest treeDigest = new AVLTreeDigest(compression);
int n = in.readVInt();
if (n > 0) {
for (int i = 0; i < n; i++) {
treeDigest.add(in.readDouble(), in.readVInt());
}
TDigestState state = new TDigestState(compression);
state.add(List.of(treeDigest));
return state;
}
return new TDigestState(compression);
} else {
// For MergingDigest, adding the original centroids in ascending order to a new, empty MergingDigest isn't guaranteed
// to produce a MergingDigest whose centroids are exactly equal to the originals.
// So, use the library's serialization code to ensure we get the exact same centroids, allowing us to compare with equals().
// The AVLTreeDigest had the same limitation for equals() where it was only guaranteed to return true if the other object was
// produced by de/serializing the object, so this should be fine.
int byteSize = in.readVInt();
byte[] bytes = new byte[byteSize];
in.readBytes(bytes, 0, byteSize);
MergingDigest mergingDigest = MergingDigest.fromBytes(ByteBuffer.wrap(bytes));
if (mergingDigest.centroids().isEmpty()) {
return new TDigestState(mergingDigest.compression());
}
return new TDigestState(mergingDigest.compression(), mergingDigest);
double compression = in.readDouble();
TDigestState state = new TDigestState(compression);
int n = in.readVInt();
for (int i = 0; i < n; i++) {
state.add(in.readDouble(), in.readVInt());
}
return state;
}

@Override
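The reverted read/write pair above returns to the centroid-based wire format: the compression value, a centroid count, then one (mean, count) pair per centroid, with no version gating and no MergingDigest byte serialization. As a rough usage sketch only, the round trip below assumes OpenSearch's BytesStreamOutput; the import paths and the class name here are assumptions for illustration, not part of this diff.

import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.search.aggregations.metrics.TDigestState;

import java.io.IOException;

public class TDigestStateRoundTripSketch {
    public static void main(String[] args) throws IOException {
        TDigestState state = new TDigestState(100.0);
        for (int i = 0; i < 100; i++) {
            state.add(i);
        }

        try (BytesStreamOutput out = new BytesStreamOutput()) {
            // Writes compression, centroid count, then each centroid's mean and count.
            TDigestState.write(state, out);
            try (StreamInput in = out.bytes().streamInput()) {
                TDigestState copy = TDigestState.read(in);
                // The deserialized sketch should answer quantile queries like the original.
                System.out.println(state.quantile(0.5) + " vs " + copy.quantile(0.5));
            }
        }
    }
}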
@@ -54,7 +54,7 @@ protected InternalTDigestPercentileRanks createTestInstance(
Arrays.stream(values).forEach(state::add);

// the number of centroids is defined as <= the number of samples inserted
assertTrue(state.centroids().size() <= values.length);
assertTrue(state.centroidCount() <= values.length);
return new InternalTDigestPercentileRanks(name, percents, state, keyed, format, metadata);
}
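The assertion above leans on a t-digest invariant: each inserted sample creates at most one centroid, so centroidCount() can never exceed the number of values added. A tiny standalone check of that property, assuming the t-digest 3.3 API (the class name is made up):

import com.tdunning.math.stats.AVLTreeDigest;

public class CentroidCountCheck {
    public static void main(String[] args) {
        double[] values = {1.5, 2.0, 2.0, 3.7, 8.2, 9.9};
        AVLTreeDigest digest = new AVLTreeDigest(100.0);
        for (double v : values) {
            digest.add(v);
        }
        // Holds because adds only ever create or merge centroids, never split them.
        System.out.println(digest.centroidCount() + " <= " + values.length);
    }
}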

@@ -66,7 +66,7 @@ protected void assertReduced(InternalTDigestPercentileRanks reduced, List<Intern
double max = Double.NEGATIVE_INFINITY;
long totalCount = 0;
for (InternalTDigestPercentileRanks ranks : inputs) {
if (ranks.state.centroids().isEmpty()) {
if (ranks.state.centroidCount() == 0) {
// quantiles would return NaN
continue;
}
@@ -54,7 +54,7 @@ protected InternalTDigestPercentiles createTestInstance(
Arrays.stream(values).forEach(state::add);

// the number of centroids is defined as <= the number of samples inserted
assertTrue(state.centroids().size() <= values.length);
assertTrue(state.centroidCount() <= values.length);
return new InternalTDigestPercentiles(name, percents, state, keyed, format, metadata);
}

@@ -104,7 +104,7 @@ public void testSomeMatchesSortedNumericDocValues() throws IOException {
iw.addDocument(singleton(new SortedNumericDocValuesField("number", 0)));
}, tdigest -> {
assertEquals(7L, tdigest.state.size());
assertEquals(7L, tdigest.state.centroids().size());
assertEquals(7L, tdigest.state.centroidCount());
assertEquals(5.0d, tdigest.percentile(75), 0.0d);
assertEquals("5.0", tdigest.percentileAsString(75));
assertEquals(3.0d, tdigest.percentile(71), 0.0d);
@@ -128,7 +128,7 @@ public void testSomeMatchesNumericDocValues() throws IOException {
iw.addDocument(singleton(new NumericDocValuesField("number", 0)));
}, tdigest -> {
assertEquals(tdigest.state.size(), 7L);
assertTrue(tdigest.state.centroids().size() <= 7L);
assertTrue(tdigest.state.centroidCount() <= 7L);
assertEquals(8.0d, tdigest.percentile(100), 0.0d);
assertEquals("8.0", tdigest.percentileAsString(100));
assertEquals(8.0d, tdigest.percentile(88), 0.0d);
@@ -156,7 +156,7 @@ public void testQueryFiltering() throws IOException {

testCase(LongPoint.newRangeQuery("row", 1, 4), docs, tdigest -> {
assertEquals(4L, tdigest.state.size());
assertEquals(4L, tdigest.state.centroids().size());
assertEquals(4L, tdigest.state.centroidCount());
assertEquals(2.0d, tdigest.percentile(100), 0.0d);
assertEquals(1.0d, tdigest.percentile(50), 0.0d);
assertEquals(1.0d, tdigest.percentile(25), 0.0d);
@@ -165,7 +165,7 @@ public void testQueryFiltering() throws IOException {

testCase(LongPoint.newRangeQuery("row", 100, 110), docs, tdigest -> {
assertEquals(0L, tdigest.state.size());
assertEquals(0L, tdigest.state.centroids().size());
assertEquals(0L, tdigest.state.centroidCount());
assertFalse(AggregationInspectionHelper.hasValue(tdigest));
});
}