Skip to content

Commit 6971ea8

Browse files
committed
Return bad request for invalid stat parameters in stats API
Signed-off-by: Andy Qin <[email protected]>
1 parent d42efb1 commit 6971ea8

File tree

3 files changed

+152
-22
lines changed

3 files changed

+152
-22
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1414
### Bug Fixes
1515
- Add validations to prevent empty input_text_field and input_image_field in TextImageEmbeddingProcessor ([#1257](https://github.com/opensearch-project/neural-search/pull/1257))
1616
- Fix score value as null for single shard when sorting is not done on score field ([#1277](https://github.com/opensearch-project/neural-search/pull/1277))
17+
- Return bad request for stats API calls with invalid stat names instead of ignoring them ([#1291](https://github.com/opensearch-project/neural-search/pull/1291))
18+
1719

1820
### Infrastructure
1921

src/main/java/org/opensearch/neuralsearch/rest/RestNeuralStatsAction.java

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import java.util.Arrays;
2525
import java.util.EnumSet;
26+
import java.util.HashSet;
2627
import java.util.List;
2728
import java.util.Locale;
2829
import java.util.Optional;
@@ -157,47 +158,46 @@ private NeuralStatsInput createNeuralStatsInputFromRequestParams(RestRequest req
157158
boolean includeMetadata = request.paramAsBoolean(INCLUDE_METADATA_PARAM, false);
158159
neuralStatsInput.setIncludeMetadata(includeMetadata);
159160

160-
// Determine which stat names to retrieve based on user parameters
161-
Optional<String[]> stats = splitCommaSeparatedParam(request, "stat");
161+
// Process requested stats parameters
162+
processStatsRequestParameters(request, neuralStatsInput);
162163

163-
if (stats.isPresent() == false) {
164-
// No specific stats requested, add all stats by default
165-
addAllStats(neuralStatsInput);
166-
return neuralStatsInput;
167-
}
164+
return neuralStatsInput;
165+
}
168166

169-
// Process requested stats
170-
boolean anyStatAdded = processRequestedStats(stats.get(), neuralStatsInput);
167+
private void processStatsRequestParameters(RestRequest request, NeuralStatsInput neuralStatsInput) {
168+
// Determine which stat names to retrieve based on user parameters
169+
Optional<String[]> optionalStats = splitCommaSeparatedParam(request, "stat");
171170

172-
// If no valid stats were added, fall back to all stats
173-
if (anyStatAdded == false) {
171+
if (optionalStats.isPresent() == false || optionalStats.get().length == 0) {
172+
// No specific stats requested, add all stats by default
174173
addAllStats(neuralStatsInput);
174+
return;
175175
}
176176

177-
return neuralStatsInput;
178-
}
179-
180-
private boolean processRequestedStats(String[] stats, NeuralStatsInput neuralStatsInput) {
181-
boolean statAdded = false;
182-
177+
String[] stats = optionalStats.get();
178+
Set<String> invalidStatNames = new HashSet<>();
183179
for (String stat : stats) {
184180
// Validate parameter
185181
String normalizedStat = stat.toLowerCase(Locale.ROOT);
186182
if (isValidParamString(normalizedStat) == false) {
187-
log.info("Invalid stat name parameter format: {}", normalizedStat);
183+
invalidStatNames.add(normalizedStat);
188184
continue;
189185
}
190186

191187
if (EVENT_STAT_NAMES.contains(normalizedStat)) {
192188
neuralStatsInput.getEventStatNames().add(EventStatName.from(normalizedStat));
193-
statAdded = true;
194189
} else if (STATE_STAT_NAMES.contains(normalizedStat)) {
195190
neuralStatsInput.getInfoStatNames().add(InfoStatName.from(normalizedStat));
196-
statAdded = true;
191+
} else {
192+
invalidStatNames.add(normalizedStat);
197193
}
198-
log.info("Non-existent stat name parsed: {}", normalizedStat);
199194
}
200-
return statAdded;
195+
196+
// When we reach this block, we must have added at least one stat to the input, or else invalid stats will be
197+
// non empty. So throwing this exception here without adding all covers the empty input case.
198+
if (invalidStatNames.isEmpty() == false) {
199+
throw new IllegalArgumentException(unrecognized(request, invalidStatNames, Set.of(stats), "stat"));
200+
}
201201
}
202202

203203
private void addAllStats(NeuralStatsInput neuralStatsInput) {
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.rest;
6+
7+
import org.junit.Before;
8+
import org.mockito.ArgumentCaptor;
9+
import org.mockito.Mock;
10+
import org.mockito.MockitoAnnotations;
11+
import org.opensearch.common.settings.Settings;
12+
import org.opensearch.core.action.ActionListener;
13+
import org.opensearch.core.rest.RestStatus;
14+
import org.opensearch.core.xcontent.NamedXContentRegistry;
15+
import org.opensearch.neuralsearch.processor.InferenceProcessorTestCase;
16+
import org.opensearch.neuralsearch.settings.NeuralSearchSettingsAccessor;
17+
import org.opensearch.neuralsearch.stats.NeuralStatsInput;
18+
import org.opensearch.neuralsearch.stats.events.EventStatName;
19+
import org.opensearch.neuralsearch.stats.info.InfoStatName;
20+
import org.opensearch.neuralsearch.transport.NeuralStatsAction;
21+
import org.opensearch.neuralsearch.transport.NeuralStatsRequest;
22+
import org.opensearch.neuralsearch.transport.NeuralStatsResponse;
23+
import org.opensearch.rest.BytesRestResponse;
24+
import org.opensearch.rest.RestChannel;
25+
import org.opensearch.rest.RestRequest;
26+
import org.opensearch.test.rest.FakeRestRequest;
27+
import org.opensearch.threadpool.TestThreadPool;
28+
import org.opensearch.threadpool.ThreadPool;
29+
import org.opensearch.transport.client.node.NodeClient;
30+
31+
import java.util.EnumSet;
32+
import java.util.HashMap;
33+
import java.util.Map;
34+
35+
import static org.mockito.ArgumentMatchers.any;
36+
import static org.mockito.ArgumentMatchers.eq;
37+
import static org.mockito.Mockito.doAnswer;
38+
import static org.mockito.Mockito.never;
39+
import static org.mockito.Mockito.spy;
40+
import static org.mockito.Mockito.times;
41+
import static org.mockito.Mockito.verify;
42+
import static org.mockito.Mockito.when;
43+
44+
public class RestNeuralStatsActionTests extends InferenceProcessorTestCase {
45+
private NodeClient client;
46+
private ThreadPool threadPool;
47+
48+
@Mock
49+
RestChannel channel;
50+
51+
@Mock
52+
private NeuralSearchSettingsAccessor settingsAccessor;
53+
54+
@Before
55+
public void setup() {
56+
MockitoAnnotations.openMocks(this);
57+
58+
threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool");
59+
client = spy(new NodeClient(Settings.EMPTY, threadPool));
60+
61+
doAnswer(invocation -> {
62+
ActionListener<NeuralStatsResponse> actionListener = invocation.getArgument(2);
63+
return null;
64+
}).when(client).execute(eq(NeuralStatsAction.INSTANCE), any(), any());
65+
}
66+
67+
@Override
68+
public void tearDown() throws Exception {
69+
super.tearDown();
70+
threadPool.shutdown();
71+
client.close();
72+
}
73+
74+
public void test_execute() throws Exception {
75+
when(settingsAccessor.isStatsEnabled()).thenReturn(true);
76+
RestNeuralStatsAction restNeuralStatsAction = new RestNeuralStatsAction(settingsAccessor);
77+
78+
RestRequest request = getRestRequest();
79+
restNeuralStatsAction.handleRequest(request, channel, client);
80+
81+
ArgumentCaptor<NeuralStatsRequest> argumentCaptor = ArgumentCaptor.forClass(NeuralStatsRequest.class);
82+
verify(client, times(1)).execute(eq(NeuralStatsAction.INSTANCE), argumentCaptor.capture(), any());
83+
84+
NeuralStatsInput capturedInput = argumentCaptor.getValue().getNeuralStatsInput();
85+
assertEquals(capturedInput.getEventStatNames(), EnumSet.allOf(EventStatName.class));
86+
assertEquals(capturedInput.getInfoStatNames(), EnumSet.allOf(InfoStatName.class));
87+
}
88+
89+
public void test_handleRequest_disabledForbidden() throws Exception {
90+
when(settingsAccessor.isStatsEnabled()).thenReturn(false);
91+
RestNeuralStatsAction restNeuralStatsAction = new RestNeuralStatsAction(settingsAccessor);
92+
93+
RestRequest request = getRestRequest();
94+
restNeuralStatsAction.handleRequest(request, channel, client);
95+
96+
verify(client, never()).execute(eq(NeuralStatsAction.INSTANCE), any(), any());
97+
98+
ArgumentCaptor<BytesRestResponse> responseCaptor = ArgumentCaptor.forClass(BytesRestResponse.class);
99+
verify(channel).sendResponse(responseCaptor.capture());
100+
101+
BytesRestResponse response = responseCaptor.getValue();
102+
assertEquals(RestStatus.FORBIDDEN, response.status());
103+
}
104+
105+
public void test_handleRequest_invalidStatParameter() throws Exception {
106+
when(settingsAccessor.isStatsEnabled()).thenReturn(true);
107+
RestNeuralStatsAction restNeuralStatsAction = new RestNeuralStatsAction(settingsAccessor);
108+
109+
// Create request with invalid stat parameter
110+
Map<String, String> params = new HashMap<>();
111+
params.put("stat", "INVALID_STAT");
112+
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
113+
.withParams(params)
114+
.build();
115+
116+
assertThrows(
117+
IllegalArgumentException.class,
118+
() -> restNeuralStatsAction.handleRequest(request, channel, client)
119+
);
120+
121+
verify(client, never()).execute(eq(NeuralStatsAction.INSTANCE), any(), any());
122+
}
123+
124+
private RestRequest getRestRequest() {
125+
Map<String, String> params = new HashMap<>();
126+
return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build();
127+
}
128+
}

0 commit comments

Comments
 (0)