Skip to content

Commit

Permalink
[Inference API] Propagate infer trace context to EIS (#113407)
Browse files Browse the repository at this point in the history
  • Loading branch information
timgrein committed Sep 30, 2024
1 parent 55078d4 commit bf329e2
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;

import java.util.Objects;

Expand All @@ -24,14 +25,17 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceSer

private final ServiceComponents serviceComponents;

public ElasticInferenceServiceActionCreator(Sender sender, ServiceComponents serviceComponents) {
private final TraceContext traceContext;

public ElasticInferenceServiceActionCreator(Sender sender, ServiceComponents serviceComponents, TraceContext traceContext) {
this.sender = Objects.requireNonNull(sender);
this.serviceComponents = Objects.requireNonNull(serviceComponents);
this.traceContext = traceContext;
}

@Override
public ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model) {
var requestManager = new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents);
var requestManager = new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents, traceContext);
var errorMessage = constructFailedToSendRequestMessage(model.uri(), "Elastic Inference Service sparse embeddings");
return new SenderExecutableAction(sender, requestManager, errorMessage);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceSparseEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;

import java.util.List;
import java.util.function.Supplier;
Expand All @@ -35,6 +36,8 @@ public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends Elast

private final Truncator truncator;

private final TraceContext traceContext;

private static ResponseHandler createSparseEmbeddingsHandler() {
return new ElasticInferenceServiceResponseHandler(
"Elastic Inference Service sparse embeddings",
Expand All @@ -44,11 +47,13 @@ private static ResponseHandler createSparseEmbeddingsHandler() {

public ElasticInferenceServiceSparseEmbeddingsRequestManager(
ElasticInferenceServiceSparseEmbeddingsModel model,
ServiceComponents serviceComponents
ServiceComponents serviceComponents,
TraceContext traceContext
) {
super(serviceComponents.threadPool(), model);
this.model = model;
this.truncator = serviceComponents.truncator();
this.traceContext = traceContext;
}

@Override
Expand All @@ -64,7 +69,8 @@ public void execute(
ElasticInferenceServiceSparseEmbeddingsRequest request = new ElasticInferenceServiceSparseEmbeddingsRequest(
truncator,
truncatedInput,
model
model,
traceContext
);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.message.BasicHeader;
import org.elasticsearch.common.Strings;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;

import java.net.URI;
import java.nio.charset.StandardCharsets;
Expand All @@ -31,15 +33,19 @@ public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticIn
private final Truncator.TruncationResult truncationResult;
private final Truncator truncator;

private final TraceContext traceContext;

public ElasticInferenceServiceSparseEmbeddingsRequest(
Truncator truncator,
Truncator.TruncationResult truncationResult,
ElasticInferenceServiceSparseEmbeddingsModel model
ElasticInferenceServiceSparseEmbeddingsModel model,
TraceContext traceContext
) {
this.truncator = truncator;
this.truncationResult = truncationResult;
this.model = Objects.requireNonNull(model);
this.uri = model.uri();
this.traceContext = traceContext;
}

@Override
Expand All @@ -50,6 +56,10 @@ public HttpRequest createHttpRequest() {
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
httpPost.setEntity(byteEntity);

if (traceContext != null) {
propagateTraceContext(httpPost);
}

httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));

return new HttpRequest(httpPost, getInferenceEntityId());
Expand All @@ -65,16 +75,32 @@ public URI getURI() {
return this.uri;
}

/**
 * Returns the trace context this request was constructed with.
 *
 * @return the trace context; may be {@code null}, since the constructor stores the argument without a null check
 */
public TraceContext getTraceContext() {
    return this.traceContext;
}

@Override
public Request truncate() {
var truncatedInput = truncator.truncate(truncationResult.input());

return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model);
return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model, traceContext);
}

/**
 * Returns the truncation flags of the current truncation result.
 *
 * @return a clone of the flag array, so the array held by the truncation result is not handed out directly
 */
@Override
public boolean[] getTruncationInfo() {
    boolean[] truncatedFlags = truncationResult.truncated();
    return truncatedFlags.clone();
}

/**
 * Copies the values of the trace context onto the outgoing request as HTTP headers.
 * A header is only set when the corresponding trace context value is non-null.
 * <p>
 * This method dereferences {@code traceContext} without a null check of its own.
 *
 * @param httpPost the outgoing request that receives the trace headers
 */
private void propagateTraceContext(HttpPost httpPost) {
    final String parentValue = traceContext.traceParent();
    if (parentValue != null) {
        httpPost.setHeader(Task.TRACE_PARENT_HTTP_HEADER, parentValue);
    }

    final String stateValue = traceContext.traceState();
    if (stateValue != null) {
        httpPost.setHeader(Task.TRACE_STATE, stateValue);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
Expand All @@ -34,6 +35,7 @@
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;

import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -75,8 +77,13 @@ protected void doInfer(
return;
}

// We extract the trace context here because it is sufficient to propagate the trace information of the REST request
// that handles the call to the inference API as a whole. That REST request also covers the outgoing request, which is
// started in a new thread and generates a different "traceparent", as every task and every REST request creates a new span.
var currentTraceInfo = getCurrentTraceInfo();

ElasticInferenceServiceModel elasticInferenceServiceModel = (ElasticInferenceServiceModel) model;
var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents());
var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), currentTraceInfo);

var action = elasticInferenceServiceModel.accept(actionCreator, taskSettings);
action.execute(inputs, timeout, listener);
Expand Down Expand Up @@ -258,4 +265,13 @@ private ElasticInferenceServiceSparseEmbeddingsModel updateModelWithEmbeddingDet

return new ElasticInferenceServiceSparseEmbeddingsModel(model, serviceSettings);
}

/**
 * Reads the {@code Task.TRACE_PARENT} and {@code Task.TRACE_STATE} headers from the thread context of the
 * service components' thread pool and bundles them into a {@link TraceContext}.
 * The header values are passed through exactly as the thread context returns them.
 *
 * @return a new trace context holding the two header values
 */
private TraceContext getCurrentTraceInfo() {
    var threadContext = getServiceComponents().threadPool().getThreadContext();

    return new TraceContext(threadContext.getHeader(Task.TRACE_PARENT), threadContext.getHeader(Task.TRACE_STATE));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.telemetry;

/**
 * Immutable pair of trace values: the trace parent and the trace state.
 * No validation is performed, so either component may be {@code null}.
 *
 * @param traceParent the trace parent value
 * @param traceState  the trace state value
 */
public record TraceContext(
    String traceParent,
    String traceState
) {}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
import org.junit.After;
import org.junit.Before;

Expand Down Expand Up @@ -89,7 +90,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer));
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool));
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
var action = actionCreator.create(model);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
Expand Down Expand Up @@ -145,7 +146,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer));
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool));
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
var action = actionCreator.create(model);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
Expand Down Expand Up @@ -197,7 +198,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer));
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool));
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
var action = actionCreator.create(model);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
Expand Down Expand Up @@ -257,7 +258,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException {

// truncated to 1 token = 3 characters
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), 1);
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool));
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
var action = actionCreator.create(model);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
Expand Down Expand Up @@ -286,4 +287,8 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException {
}
}

/**
 * Builds a trace context whose trace parent and trace state are each produced by {@code randomAlphaOfLength(10)}.
 */
private TraceContext createTraceContext() {
    var traceParent = randomAlphaOfLength(10);
    var traceState = randomAlphaOfLength(10);
    return new TraceContext(traceParent, traceState);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.common.TruncatorTests;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;

import java.io.IOException;
import java.util.List;
Expand Down Expand Up @@ -42,6 +44,23 @@ public void testCreateHttpRequest() throws IOException {
assertThat(requestMap.get("input"), is(List.of(input)));
}

/**
 * Verifies that the trace context held by the request is written to the trace parent and trace state
 * headers of the HTTP POST produced by {@code createHttpRequest()}.
 */
public void testTraceContextPropagatedThroughHTTPHeaders() {
    var request = createRequest("http://eis-gateway.com", "input");
    var expectedContext = request.getTraceContext();

    var requestBase = request.createHttpRequest().httpRequestBase();
    assertThat(requestBase, instanceOf(HttpPost.class));
    var httpPost = (HttpPost) requestBase;

    // Look up the last header of each name, mirroring what a receiver would see.
    var traceParentHeader = httpPost.getLastHeader(Task.TRACE_PARENT_HTTP_HEADER);
    var traceStateHeader = httpPost.getLastHeader(Task.TRACE_STATE);

    assertThat(traceParentHeader.getValue(), is(expectedContext.traceParent()));
    assertThat(traceStateHeader.getValue(), is(expectedContext.traceState()));
}

public void testTruncate_ReducesInputTextSizeByHalf() throws IOException {
var url = "http://eis-gateway.com";
var input = "abcd";
Expand Down Expand Up @@ -75,7 +94,8 @@ public ElasticInferenceServiceSparseEmbeddingsRequest createRequest(String url,
return new ElasticInferenceServiceSparseEmbeddingsRequest(
TruncatorTests.createTruncator(),
new Truncator.TruncationResult(List.of(input), new boolean[] { false }),
embeddingsModel
embeddingsModel,
new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10))
);
}
}

0 comments on commit bf329e2

Please sign in to comment.