diff --git a/docs/changelog/113158.yaml b/docs/changelog/113158.yaml new file mode 100644 index 0000000000000..d097ea11b3a23 --- /dev/null +++ b/docs/changelog/113158.yaml @@ -0,0 +1,5 @@ +pr: 113158 +summary: Adds a new Inference API for streaming responses back to the user. +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index a37fb3dd75673..9e9a4cf890379 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -188,4 +188,21 @@ default boolean isInClusterService() { * @return {@link TransportVersion} specifying the version */ TransportVersion getMinimalSupportedVersion(); + + /** + * The set of tasks where this service provider supports using the streaming API. + * @return set of supported task types. Defaults to empty. + */ + default Set supportedStreamingTasks() { + return Set.of(); + } + + /** + * Checks the task type against the set of supported streaming tasks returned by {@link #supportedStreamingTasks()}. + * @param taskType the task that supports streaming + * @return true if the taskType is supported + */ + default boolean canStream(TaskType taskType) { + return supportedStreamingTasks().contains(taskType); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index d898f961651f1..a19edd5a08162 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -92,6 +92,7 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType, private final Map taskSettings; private final InputType inputType; private final TimeValue inferenceTimeout; + private final boolean stream; public Request( TaskType taskType, @@ -100,7 +101,8 @@ public Request( List input, Map taskSettings, InputType inputType, - TimeValue inferenceTimeout + TimeValue inferenceTimeout, + boolean stream ) { this.taskType = taskType; this.inferenceEntityId = inferenceEntityId; @@ -109,6 +111,7 @@ public Request( this.taskSettings = taskSettings; this.inputType = inputType; this.inferenceTimeout = inferenceTimeout; + this.stream = stream; } public Request(StreamInput in) throws IOException { @@ -134,6 +137,9 @@ public Request(StreamInput in) throws IOException { this.query = null; this.inferenceTimeout = DEFAULT_TIMEOUT; } + + // streaming is not supported yet for transport traffic + this.stream = false; } public TaskType getTaskType() { @@ -165,7 +171,7 @@ public TimeValue getInferenceTimeout() { } public boolean isStreaming() { - return false; + return stream; } @Override @@ -261,6 +267,7 @@ public static class Builder { private Map taskSettings = Map.of(); private String query; private TimeValue timeout = DEFAULT_TIMEOUT; + private boolean stream = false; private Builder() {} @@ -303,8 +310,13 @@ private Builder setInferenceTimeout(String inferenceTimeout) { return setInferenceTimeout(TimeValue.parseTimeValue(inferenceTimeout, TIMEOUT.getPreferredName())); } + public Builder setStream(boolean stream) { + this.stream = stream; + return this; + } + public Request build() { - return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout); + return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index f41e117e75b9f..a9ca5e6da8720 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -46,7 +46,8 @@ protected InferenceAction.Request createTestInstance() { randomList(1, 5, () -> randomAlphaOfLength(8)), randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))), randomFrom(InputType.values()), - TimeValue.timeValueMillis(randomLongBetween(1, 2048)) + TimeValue.timeValueMillis(randomLongBetween(1, 2048)), + false ); } @@ -80,7 +81,8 @@ public void testValidation_TextEmbedding() { List.of("input"), null, null, - null + null, + false ); ActionRequestValidationException e = request.validate(); assertNull(e); @@ -94,7 +96,8 @@ public void testValidation_Rerank() { List.of("input"), null, null, - null + null, + false ); ActionRequestValidationException e = request.validate(); assertNull(e); @@ -108,7 +111,8 @@ public void testValidation_TextEmbedding_Null() { null, null, null, - null + null, + false ); ActionRequestValidationException inputNullError = inputNullRequest.validate(); assertNotNull(inputNullError); @@ -123,7 +127,8 @@ public void testValidation_TextEmbedding_Empty() { List.of(), null, null, - null + null, + false ); ActionRequestValidationException inputEmptyError = inputEmptyRequest.validate(); assertNotNull(inputEmptyError); @@ -138,7 +143,8 @@ public void testValidation_Rerank_Null() { List.of("input"), null, null, - null + null, + false ); ActionRequestValidationException queryNullError = queryNullRequest.validate(); assertNotNull(queryNullError); @@ -153,7 +159,8 @@ public void testValidation_Rerank_Empty() { List.of("input"), null, null, - null + null, + false ); ActionRequestValidationException queryEmptyError = queryEmptyRequest.validate(); assertNotNull(queryEmptyError); @@ -185,7 +192,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInput(), instance.getTaskSettings(), instance.getInputType(), - instance.getInferenceTimeout() + instance.getInferenceTimeout(), + false ); } case 1 -> new InferenceAction.Request( @@ -195,7 +203,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInput(), instance.getTaskSettings(), instance.getInputType(), - instance.getInferenceTimeout() + instance.getInferenceTimeout(), + false ); case 2 -> { var changedInputs = new ArrayList(instance.getInput()); @@ -207,7 +216,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc changedInputs, instance.getTaskSettings(), instance.getInputType(), - instance.getInferenceTimeout() + instance.getInferenceTimeout(), + false ); } case 3 -> { @@ -225,7 +235,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInput(), taskSettings, instance.getInputType(), - instance.getInferenceTimeout() + instance.getInferenceTimeout(), + false ); } case 4 -> { @@ -237,7 +248,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInput(), instance.getTaskSettings(), nextInputType, - instance.getInferenceTimeout() + instance.getInferenceTimeout(), + false ); } case 5 -> new InferenceAction.Request( @@ -247,7 +259,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInput(), instance.getTaskSettings(), instance.getInputType(), - instance.getInferenceTimeout() + instance.getInferenceTimeout(), + false ); case 6 -> { var newDuration = Duration.of( @@ -262,7 +275,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInput(), instance.getTaskSettings(), instance.getInputType(), - TimeValue.timeValueMillis(newDuration.plus(additionalTime).toMillis()) + TimeValue.timeValueMillis(newDuration.plus(additionalTime).toMillis()), + false ); } default -> throw new UnsupportedOperationException(); @@ -279,7 +293,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getInput().subList(0, 1), instance.getTaskSettings(), InputType.UNSPECIFIED, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); } else if (version.before(TransportVersions.V_8_13_0)) { return new InferenceAction.Request( @@ -289,7 +304,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getInput(), instance.getTaskSettings(), InputType.UNSPECIFIED, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); } else if (version.before(TransportVersions.V_8_13_0) && (instance.getInputType() == InputType.UNSPECIFIED @@ -302,7 +318,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getInput(), instance.getTaskSettings(), InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); } else if (version.before(TransportVersions.V_8_13_0) && (instance.getInputType() == InputType.CLUSTERING || instance.getInputType() == InputType.CLASSIFICATION)) { @@ -313,7 +330,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getInput(), instance.getTaskSettings(), InputType.UNSPECIFIED, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); } else if (version.before(TransportVersions.V_8_14_0)) { return new InferenceAction.Request( @@ -323,7 +341,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getInput(), instance.getTaskSettings(), instance.getInputType(), - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); } @@ -339,7 +358,8 @@ public void testWriteTo_WhenVersionIsOnAfterUnspecifiedAdded() throws IOExceptio List.of(), Map.of(), InputType.UNSPECIFIED, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ), TransportVersions.V_8_13_0 ); @@ -353,7 +373,8 @@ public void testWriteTo_WhenVersionIsBeforeInputTypeAdded_ShouldSetInputTypeToUn List.of(), Map.of(), InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); InferenceAction.Request deserializedInstance = copyWriteable( diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/AsyncInferenceResponseConsumer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/AsyncInferenceResponseConsumer.java new file mode 100644 index 0000000000000..eb5f3c75bab60 --- /dev/null +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/AsyncInferenceResponseConsumer.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.apache.http.HttpEntity; +import org.apache.http.HttpResponse; +import org.apache.http.entity.ContentType; +import org.apache.http.nio.ContentDecoder; +import org.apache.http.nio.IOControl; +import org.apache.http.nio.protocol.AbstractAsyncResponseConsumer; +import org.apache.http.nio.util.SimpleInputBuffer; +import org.apache.http.protocol.HttpContext; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.concurrent.atomic.AtomicReference; + +class AsyncInferenceResponseConsumer extends AbstractAsyncResponseConsumer { + private final AtomicReference httpResponse = new AtomicReference<>(); + private final Deque collector = new ArrayDeque<>(); + private final ServerSentEventParser sseParser = new ServerSentEventParser(); + private final SimpleInputBuffer inputBuffer = new SimpleInputBuffer(4096); + + @Override + protected void onResponseReceived(HttpResponse httpResponse) { + this.httpResponse.set(httpResponse); + } + + @Override + protected void onContentReceived(ContentDecoder contentDecoder, IOControl ioControl) throws IOException { + inputBuffer.consumeContent(contentDecoder); + } + + @Override + protected void onEntityEnclosed(HttpEntity httpEntity, ContentType contentType) { + httpResponse.updateAndGet(response -> { + response.setEntity(httpEntity); + return response; + }); + } + + @Override + protected HttpResponse buildResult(HttpContext httpContext) { + var allBytes = new byte[inputBuffer.length()]; + try { + inputBuffer.read(allBytes); + sseParser.parse(allBytes).forEach(collector::offer); + } catch (IOException e) { + failed(e); + } + return httpResponse.get(); + } + + @Override + protected void releaseResources() {} + + Deque events() { + return collector; + } +} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index f30f2e8fe201a..c19cd916055d3 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -9,7 +9,9 @@ import org.apache.http.util.EntityUtils; import org.elasticsearch.client.Request; +import org.elasticsearch.client.RequestOptions; import org.elasticsearch.client.Response; +import org.elasticsearch.client.ResponseListener; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; @@ -19,11 +21,15 @@ import org.elasticsearch.test.cluster.ElasticsearchCluster; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; import org.junit.ClassRule; import java.io.IOException; +import java.util.Deque; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; @@ -72,6 +78,23 @@ static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody) { """, taskType); } + static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody) { + var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\","; + return Strings.format(""" + { + %s + "service": "streaming_completion_test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + "temperature": 3 + } + } + """, taskType); + } + static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody, boolean shouldReturnHiddenField) { var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\","; return Strings.format(""" @@ -252,6 +275,32 @@ protected Map inferOnMockService(String modelId, List in return inferOnMockServiceInternal(endpoint, input); } + protected Deque streamInferOnMockService(String modelId, TaskType taskType, List input) throws Exception { + var endpoint = Strings.format("_inference/%s/%s/_stream", taskType, modelId); + return callAsync(endpoint, input); + } + + private Deque callAsync(String endpoint, List input) throws Exception { + var responseConsumer = new AsyncInferenceResponseConsumer(); + var request = new Request("POST", endpoint); + request.setJsonEntity(jsonBody(input)); + request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build()); + var latch = new CountDownLatch(1); + client().performRequestAsync(request, new ResponseListener() { + @Override + public void onSuccess(Response response) { + latch.countDown(); + } + + @Override + public void onFailure(Exception exception) { + latch.countDown(); + } + }); + assertTrue(latch.await(30, TimeUnit.SECONDS)); + return responseConsumer.events(); + } + protected Map inferOnMockService(String modelId, TaskType taskType, List input) throws IOException { var endpoint = Strings.format("_inference/%s/%s", taskType, modelId); return inferOnMockServiceInternal(endpoint, input); @@ -259,7 +308,13 @@ protected Map inferOnMockService(String modelId, TaskType taskTy private Map inferOnMockServiceInternal(String endpoint, List input) throws IOException { var request = new Request("POST", endpoint); + request.setJsonEntity(jsonBody(input)); + var response = client().performRequest(request); + assertOkOrCreated(response); + return entityAsMap(response); + } + private String jsonBody(List input) { var bodyBuilder = new StringBuilder("{\"input\": ["); for (var in : input) { bodyBuilder.append('"').append(in).append('"').append(','); @@ -267,11 +322,7 @@ private Map inferOnMockServiceInternal(String endpoint, List { + switch (event.name()) { + case EVENT -> assertThat(event.value(), equalToIgnoringCase("error")); + case DATA -> assertThat( + event.value(), + containsString( + "Streaming is not allowed for service [streaming_completion_test_service] and task [sparse_embedding]" + ) + ); + } + }); + } finally { + deleteModel(modelId); + } + } + + public void testSupportedStream() throws Exception { + String modelId = "streaming"; + putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION)); + var singleModel = getModel(modelId); + assertEquals(modelId, singleModel.get("inference_id")); + assertEquals(TaskType.COMPLETION.toString(), singleModel.get("task_type")); + + var input = IntStream.range(0, randomInt(10)).mapToObj(i -> randomAlphaOfLength(10)).toList(); + + try { + var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input); + + var expectedResponses = Stream.concat( + input.stream().map(String::toUpperCase).map(str -> "{\"completion\":[{\"delta\":\"" + str + "\"}]}"), + Stream.of("[DONE]") + ).iterator(); + assertThat(events.size(), equalTo((input.size() + 1) * 2)); + events.forEach(event -> { + switch (event.name()) { + case EVENT -> assertThat(event.value(), equalToIgnoringCase("message")); + case DATA -> assertThat(event.value(), equalTo(expectedResponses.next())); + } + }); + } finally { + deleteModel(modelId); + } + } } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java index 752472b90374b..eef0da909f529 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java @@ -44,6 +44,11 @@ public List getNamedWriteables() { ServiceSettings.class, TestRerankingServiceExtension.TestServiceSettings.NAME, TestRerankingServiceExtension.TestServiceSettings::new + ), + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + TestStreamingCompletionServiceExtension.TestServiceSettings.NAME, + TestStreamingCompletionServiceExtension.TestServiceSettings::new ) ); } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java new file mode 100644 index 0000000000000..3d72b1f2729b0 --- /dev/null +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -0,0 +1,204 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mock; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.Flow; + +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION; + +public class TestStreamingCompletionServiceExtension implements InferenceServiceExtension { + @Override + public List getInferenceServiceFactories() { + return List.of(TestInferenceService::new); + } + + public static class TestInferenceService extends AbstractTestInferenceService { + private static final String NAME = "streaming_completion_test_service"; + private static final Set supportedStreamingTasks = Set.of(TaskType.COMPLETION); + + public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {} + + @Override + public String name() { + return NAME; + } + + @Override + protected ServiceSettings getServiceSettingsFromMap(Map serviceSettingsMap) { + return TestServiceSettings.fromMap(serviceSettingsMap); + } + + @Override + @SuppressWarnings("unchecked") + public void parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + Set platformArchitectures, + ActionListener parsedModelListener + ) { + var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); + var serviceSettings = TestSparseInferenceServiceExtension.TestServiceSettings.fromMap(serviceSettingsMap); + var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap); + + var taskSettingsMap = getTaskSettingsMap(config); + var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); + + parsedModelListener.onResponse(new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings)); + } + + @Override + public void infer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + switch (model.getConfigurations().getTaskType()) { + case COMPLETION -> listener.onResponse(makeResults(input)); + default -> listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), + RestStatus.BAD_REQUEST + ) + ); + } + } + + private StreamingChatCompletionResults makeResults(List input) { + var responseIter = input.stream().map(String::toUpperCase).iterator(); + return new StreamingChatCompletionResults(subscriber -> { + subscriber.onSubscribe(new Flow.Subscription() { + @Override + public void request(long n) { + if (responseIter.hasNext()) { + subscriber.onNext(completionChunk(responseIter.next())); + } else { + subscriber.onComplete(); + } + } + + @Override + public void cancel() {} + }); + }); + } + + private ChunkedToXContent completionChunk(String delta) { + return params -> Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.startArray(COMPLETION), + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.field("delta", delta), + ChunkedToXContentHelper.endObject(), + ChunkedToXContentHelper.endArray(), + ChunkedToXContentHelper.endObject() + ); + } + + @Override + public void chunkedInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + ChunkingOptions chunkingOptions, + TimeValue timeout, + ActionListener> listener + ) { + listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), + RestStatus.BAD_REQUEST + ) + ); + } + + @Override + public Set supportedStreamingTasks() { + return supportedStreamingTasks; + } + } + + public record TestServiceSettings(String modelId) implements ServiceSettings { + public static final String NAME = "streaming_completion_test_service_settings"; + + public TestServiceSettings(StreamInput in) throws IOException { + this(in.readString()); + } + + public static TestServiceSettings fromMap(Map map) { + var modelId = map.remove("model").toString(); + + if (modelId == null) { + ValidationException validationException = new ValidationException(); + validationException.addValidationError("missing model id"); + throw validationException; + } + + return new TestServiceSettings(modelId); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId()); + } + + @Override + public ToXContentObject getFilteredXContentObject() { + return this; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.startObject().field("model", modelId()).endObject(); + } + } +} diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/resources/META-INF/services/org.elasticsearch.inference.InferenceServiceExtension b/x-pack/plugin/inference/qa/test-service-plugin/src/main/resources/META-INF/services/org.elasticsearch.inference.InferenceServiceExtension index 690168b538fb9..c996a33d1e916 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/resources/META-INF/services/org.elasticsearch.inference.InferenceServiceExtension +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/resources/META-INF/services/org.elasticsearch.inference.InferenceServiceExtension @@ -1,3 +1,4 @@ org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension org.elasticsearch.xpack.inference.mock.TestRerankingServiceExtension +org.elasticsearch.xpack.inference.mock.TestStreamingCompletionServiceExtension diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 1cec996400a97..a6972ddc214fc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -73,6 +73,7 @@ import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestInferenceAction; import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction; +import org.elasticsearch.xpack.inference.rest.RestStreamInferenceAction; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockService; @@ -167,6 +168,7 @@ public List getRestHandlers( ) { return List.of( new RestInferenceAction(), + new RestStreamInferenceAction(), new RestGetInferenceModelAction(), new RestPutInferenceModelAction(), new RestDeleteInferenceEndpointAction(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index bfdfca166ef3a..803e8f1e07612 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -17,6 +17,7 @@ import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -26,10 +27,17 @@ import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.elasticsearch.core.Strings.format; + public class TransportInferenceAction extends HandledTransportAction { private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference"; private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]"; + private static final Set> supportsStreaming = Set.of(); + private final ModelRegistry modelRegistry; private final InferenceServiceRegistry serviceRegistry; private final InferenceStats inferenceStats; @@ -101,15 +109,40 @@ private void inferOnService( InferenceService service, ActionListener listener ) { - service.infer( - model, - request.getQuery(), - request.getInput(), - request.getTaskSettings(), - request.getInputType(), - request.getInferenceTimeout(), - createListener(request, listener) - ); + if (request.isStreaming() == false || service.canStream(request.getTaskType())) { + service.infer( + model, + request.getQuery(), + request.getInput(), + request.getTaskSettings(), + request.getInputType(), + request.getInferenceTimeout(), + createListener(request, listener) + ); + } else { + listener.onFailure(unsupportedStreamingTaskException(request, service)); + } + } + + private ElasticsearchStatusException unsupportedStreamingTaskException(InferenceAction.Request request, InferenceService service) { + var supportedTasks = service.supportedStreamingTasks(); + if (supportedTasks.isEmpty()) { + return new ElasticsearchStatusException( + format("Streaming is not allowed for service [%s].", service.name()), + RestStatus.METHOD_NOT_ALLOWED + ); + } else { + var validTasks = supportedTasks.stream().map(TaskType::toString).collect(Collectors.joining(",")); + return new ElasticsearchStatusException( + format( + "Streaming is not allowed for service [%s] and task [%s]. Supported tasks: [%s]", + service.name(), + request.getTaskType(), + validTasks + ), + RestStatus.METHOD_NOT_ALLOWED + ); + } } private ActionListener createListener( @@ -118,17 +151,9 @@ private ActionListener createListener( ) { if (request.isStreaming()) { return listener.delegateFailureAndWrap((l, inferenceResults) -> { - if (inferenceResults.isStreaming()) { - var taskProcessor = streamingTaskManager.create( - STREAMING_INFERENCE_TASK_TYPE, - STREAMING_TASK_ACTION - ); - inferenceResults.publisher().subscribe(taskProcessor); - l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor)); - } else { - // if we asked for streaming but the provider doesn't support it, for now we're going to get back the single response - l.onResponse(new InferenceAction.Response(inferenceResults)); - } + var taskProcessor = streamingTaskManager.create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION); + inferenceResults.publisher().subscribe(taskProcessor); + l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor)); }); } return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults))); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 8f1e28d0d8ee4..7f21f94d33276 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -204,7 +204,8 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu List.of(query), Map.of(), InputType.SEARCH, - InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API + InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, + false ); queryRewriteContext.registerAsyncAction( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index cad11cbdc9d5b..0ff48bfd493ba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -144,7 +144,8 @@ protected InferenceAction.Request generateRequest(List docFeatures) { docFeatures, Map.of(), InputType.SEARCH, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java new file mode 100644 index 0000000000000..e72e68052f648 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.rest; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; + +import java.io.IOException; + +import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID; +import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID; + +abstract class BaseInferenceAction extends BaseRestHandler { + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String inferenceEntityId; + TaskType taskType; + if (restRequest.hasParam(INFERENCE_ID)) { + inferenceEntityId = restRequest.param(INFERENCE_ID); + taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID)); + } else { + inferenceEntityId = restRequest.param(TASK_TYPE_OR_INFERENCE_ID); + taskType = TaskType.ANY; + } + + InferenceAction.Request.Builder requestBuilder; + try (var parser = restRequest.contentParser()) { + requestBuilder = InferenceAction.Request.parseRequest(inferenceEntityId, taskType, parser); + } + + var inferTimeout = restRequest.paramAsTime( + InferenceAction.Request.TIMEOUT.getPreferredName(), + InferenceAction.Request.DEFAULT_TIMEOUT + ); + requestBuilder.setInferenceTimeout(inferTimeout); + var request = prepareInferenceRequest(requestBuilder); + return channel -> client.execute(InferenceAction.INSTANCE, request, listener(channel)); + } + + protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Request.Builder builder) { + return builder.build(); + } + + protected abstract ActionListener listener(RestChannel channel); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java index e33931f3d2f8d..9f64b58e48b55 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java @@ -15,6 +15,13 @@ public final class Paths { static final String TASK_TYPE_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/{" + INFERENCE_ID + "}"; static final String INFERENCE_DIAGNOSTICS_PATH = "_inference/.diagnostics"; + static final String STREAM_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/_stream"; + static final String STREAM_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{" + + TASK_TYPE_OR_INFERENCE_ID + + "}/{" + + INFERENCE_ID + + "}/_stream"; + private Paths() { } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java index f5c30d0a94c54..0fbc2f8214cbb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java @@ -7,26 +7,21 @@ package org.elasticsearch.xpack.inference.rest; -import org.elasticsearch.client.internal.node.NodeClient; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.rest.BaseRestHandler; -import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestChunkedToXContentListener; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.POST; -import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID; import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH; import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH; -import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID; @ServerlessScope(Scope.PUBLIC) -public class RestInferenceAction extends BaseRestHandler { +public class RestInferenceAction extends BaseInferenceAction { @Override public String getName() { return "inference_action"; @@ -38,27 +33,7 @@ public List routes() { } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - String inferenceEntityId; - TaskType taskType; - if (restRequest.hasParam(INFERENCE_ID)) { - inferenceEntityId = restRequest.param(INFERENCE_ID); - taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID)); - } else { - inferenceEntityId = restRequest.param(TASK_TYPE_OR_INFERENCE_ID); - taskType = TaskType.ANY; - } - - InferenceAction.Request.Builder requestBuilder; - try (var parser = restRequest.contentParser()) { - requestBuilder = InferenceAction.Request.parseRequest(inferenceEntityId, taskType, parser); - } - - var inferTimeout = restRequest.paramAsTime( - InferenceAction.Request.TIMEOUT.getPreferredName(), - InferenceAction.Request.DEFAULT_TIMEOUT - ); - requestBuilder.setInferenceTimeout(inferTimeout); - return channel -> client.execute(InferenceAction.INSTANCE, requestBuilder.build(), new RestChunkedToXContentListener<>(channel)); + protected ActionListener listener(RestChannel channel) { + return new RestChunkedToXContentListener<>(channel); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java new file mode 100644 index 0000000000000..875c288da52bd --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.rest; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.Scope; +import org.elasticsearch.rest.ServerlessScope; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; + +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_INFERENCE_ID_PATH; +import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_TASK_TYPE_INFERENCE_ID_PATH; + +@ServerlessScope(Scope.PUBLIC) +public class RestStreamInferenceAction extends BaseInferenceAction { + @Override + public String getName() { + return "stream_inference_action"; + } + + @Override + public List routes() { + return List.of(new Route(POST, STREAM_INFERENCE_ID_PATH), new Route(POST, STREAM_TASK_TYPE_INFERENCE_ID_PATH)); + } + + @Override + protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Request.Builder builder) { + return builder.setStream(true).build(); + } + + @Override + protected ActionListener listener(RestChannel channel) { + return new ServerSentEventsRestActionListener(channel); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java index a26dc50097cf5..a042fca44fdb5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java @@ -92,7 +92,8 @@ protected InferenceAction.Request generateRequest(List docFeatures) { docFeatures, Map.of("inferenceResultCount", inferenceResultCount), InputType.SEARCH, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); } }; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java index 6d0c15d5c0bfe..120527f489549 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java @@ -312,7 +312,8 @@ protected InferenceAction.Request generateRequest(List docFeatures) { docFeatures, Map.of("throwing", true), InputType.SEARCH, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); } }; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java new file mode 100644 index 0000000000000..05a8d52be5df4 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.rest; + +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestChunkedToXContentListener; +import org.elasticsearch.test.rest.FakeRestRequest; +import org.elasticsearch.test.rest.RestActionTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; +import org.junit.Before; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +public class BaseInferenceActionTests extends RestActionTestCase { + + @Before + public void setUpAction() { + controller().registerHandler(new BaseInferenceAction() { + @Override + protected ActionListener listener(RestChannel channel) { + return new RestChunkedToXContentListener<>(channel); + } + + @Override + public String getName() { + return "base_inference_action"; + } + + @Override + public List routes() { + return List.of(new Route(POST, route("{task_type_or_id}"))); + } + }); + } + + private static String route(String param) { + return "_route/" + param; + } + + public void testUsesDefaultTimeout() { + SetOnce executeCalled = new SetOnce<>(); + verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { + assertThat(actionRequest, instanceOf(InferenceAction.Request.class)); + + var request = (InferenceAction.Request) actionRequest; + assertThat(request.getInferenceTimeout(), is(InferenceAction.Request.DEFAULT_TIMEOUT)); + + executeCalled.set(true); + return createResponse(); + })); + + RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(route("test")) + .withContent(new BytesArray("{}"), XContentType.JSON) + .build(); + dispatchRequest(inferenceRequest); + assertThat(executeCalled.get(), equalTo(true)); + } + + public void testUses3SecondTimeoutFromParams() { + SetOnce executeCalled = new SetOnce<>(); + verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { + assertThat(actionRequest, instanceOf(InferenceAction.Request.class)); + + var request = (InferenceAction.Request) actionRequest; + assertThat(request.getInferenceTimeout(), is(TimeValue.timeValueSeconds(3))); + + executeCalled.set(true); + return createResponse(); + })); + + RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(route("test")) + .withParams(new HashMap<>(Map.of("timeout", "3s"))) + .withContent(new BytesArray("{}"), XContentType.JSON) + .build(); + dispatchRequest(inferenceRequest); + assertThat(executeCalled.get(), equalTo(true)); + } + + static InferenceAction.Response createResponse() { + return new InferenceAction.Response( + new InferenceTextEmbeddingByteResults( + List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1 })) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java index 48e5d54a62733..1b0df1b4a20da 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java @@ -9,19 +9,14 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.core.TimeValue; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.junit.Before; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - +import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -33,13 +28,13 @@ public void setUpAction() { controller().registerHandler(new RestInferenceAction()); } - public void testUsesDefaultTimeout() { + public void testStreamIsFalse() { SetOnce executeCalled = new SetOnce<>(); verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { assertThat(actionRequest, instanceOf(InferenceAction.Request.class)); var request = (InferenceAction.Request) actionRequest; - assertThat(request.getInferenceTimeout(), is(InferenceAction.Request.DEFAULT_TIMEOUT)); + assertThat(request.isStreaming(), is(false)); executeCalled.set(true); return createResponse(); @@ -52,33 +47,4 @@ public void testUsesDefaultTimeout() { dispatchRequest(inferenceRequest); assertThat(executeCalled.get(), equalTo(true)); } - - public void testUses3SecondTimeoutFromParams() { - SetOnce executeCalled = new SetOnce<>(); - verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { - assertThat(actionRequest, instanceOf(InferenceAction.Request.class)); - - var request = (InferenceAction.Request) actionRequest; - assertThat(request.getInferenceTimeout(), is(TimeValue.timeValueSeconds(3))); - - executeCalled.set(true); - return createResponse(); - })); - - RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) - .withPath("_inference/test") - .withParams(new HashMap<>(Map.of("timeout", "3s"))) - .withContent(new BytesArray("{}"), XContentType.JSON) - .build(); - dispatchRequest(inferenceRequest); - assertThat(executeCalled.get(), equalTo(true)); - } - - private static InferenceAction.Response createResponse() { - return new InferenceAction.Response( - new InferenceTextEmbeddingByteResults( - List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1 })) - ) - ); - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java new file mode 100644 index 0000000000000..b999e2c9b72f0 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.rest; + +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.rest.FakeRestRequest; +import org.elasticsearch.test.rest.RestActionTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.junit.Before; + +import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +public class RestStreamInferenceActionTests extends RestActionTestCase { + + @Before + public void setUpAction() { + controller().registerHandler(new RestStreamInferenceAction()); + } + + public void testStreamIsTrue() { + SetOnce executeCalled = new SetOnce<>(); + verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { + assertThat(actionRequest, instanceOf(InferenceAction.Request.class)); + + var request = (InferenceAction.Request) actionRequest; + assertThat(request.isStreaming(), is(true)); + + executeCalled.set(true); + return createResponse(); + })); + + RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath("_inference/test/_stream") + .withContent(new BytesArray("{}"), XContentType.JSON) + .build(); + dispatchRequest(inferenceRequest); + assertThat(executeCalled.get(), equalTo(true)); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java index fd13e3de4e6cd..ab5a9d43fd6d1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java @@ -126,7 +126,8 @@ private void doInferenceServiceModel(CoordinatedInferenceAction.Request request, request.getInputs(), request.getTaskSettings(), inputType, - request.getInferenceTimeout() + request.getInferenceTimeout(), + false ), listener.delegateFailureAndWrap((l, r) -> l.onResponse(translateInferenceServiceResponse(r.getResults()))) );