diff --git a/docs/changelog/113437.yaml b/docs/changelog/113437.yaml new file mode 100644 index 0000000000000..98831958e63f8 --- /dev/null +++ b/docs/changelog/113437.yaml @@ -0,0 +1,6 @@ +pr: 113437 +summary: Fix check on E5 model platform compatibility +area: Machine Learning +type: bug +issues: + - 113577 diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java index 6c15b42dc65d5..01e8c30e3bf27 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java @@ -24,7 +24,7 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest { public void testPutE5Small_withNoModelVariant() { { - String inferenceEntityId = randomAlphaOfLength(10).toLowerCase(); + String inferenceEntityId = "testPutE5Small_withNoModelVariant"; expectThrows( org.elasticsearch.client.ResponseException.class, () -> putTextEmbeddingModel(inferenceEntityId, noModelIdVariantJsonEntity()) @@ -33,7 +33,7 @@ public void testPutE5Small_withNoModelVariant() { } public void testPutE5Small_withPlatformAgnosticVariant() throws IOException { - String inferenceEntityId = randomAlphaOfLength(10).toLowerCase(); + String inferenceEntityId = "teste5mall_withplatformagnosticvariant"; putTextEmbeddingModel(inferenceEntityId, platformAgnosticModelVariantJsonEntity()); var models = getTrainedModel("_all"); assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId)); @@ -50,9 +50,8 @@ public void testPutE5Small_withPlatformAgnosticVariant() throws IOException { deleteTextEmbeddingModel(inferenceEntityId); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/105198") public void testPutE5Small_withPlatformSpecificVariant() throws IOException { - String inferenceEntityId = randomAlphaOfLength(10).toLowerCase(); + String inferenceEntityId = "teste5mall_withplatformspecificvariant"; if ("linux-x86_64".equals(Platforms.PLATFORM_NAME)) { putTextEmbeddingModel(inferenceEntityId, platformSpecificModelVariantJsonEntity()); var models = getTrainedModel("_all"); @@ -77,7 +76,7 @@ public void testPutE5Small_withPlatformSpecificVariant() throws IOException { } public void testPutE5Small_withFakeModelVariant() { - String inferenceEntityId = randomAlphaOfLength(10).toLowerCase(); + String inferenceEntityId = "teste5mall_withfakevariant"; expectThrows( org.elasticsearch.client.ResponseException.class, () -> putTextEmbeddingModel(inferenceEntityId, fakeModelVariantJsonEntity()) @@ -112,7 +111,7 @@ private Map putTextEmbeddingModel(String inferenceEntityId, Stri private String noModelIdVariantJsonEntity() { return """ { - "service": "text_embedding", + "service": "elasticsearch", "service_settings": { "num_allocations": 1, "num_threads": 1 diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 93408c067098b..675bc275c8bd1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -201,9 +201,7 @@ private void e5Case( MULTILINGUAL_E5_SMALL_MODEL_ID ) ); - } - - if (modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(platformArchitectures, esServiceSettingsBuilder.getModelId())) { + } else if (modelVariantValidForArchitecture(platformArchitectures, esServiceSettingsBuilder.getModelId()) == false) { throw new IllegalArgumentException( "Error parsing request config, model id does not match any models available on this platform. Was [" + esServiceSettingsBuilder.getModelId() @@ -224,17 +222,19 @@ private void e5Case( ); } - private static boolean modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic( - Set platformArchitectures, - String modelId - ) { + static boolean modelVariantValidForArchitecture(Set platformArchitectures, String modelId) { + if (modelId.equals(MULTILINGUAL_E5_SMALL_MODEL_ID)) { + // platform agnostic model is always compatible + return true; + } + return modelId.equals( selectDefaultModelVariantBasedOnClusterArchitecture( platformArchitectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86, MULTILINGUAL_E5_SMALL_MODEL_ID ) - ) && modelId.equals(MULTILINGUAL_E5_SMALL_MODEL_ID) == false; + ); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 257616033f080..8569117c348b1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -65,6 +65,8 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -167,17 +169,12 @@ public void testParseRequestConfig_E5() { ElasticsearchInternalServiceSettings.NUM_THREADS, 4, ElasticsearchInternalServiceSettings.MODEL_ID, - ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID + MULTILINGUAL_E5_SMALL_MODEL_ID ) ) ); - var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings( - 1, - 4, - ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID, - null - ); + var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings(1, 4, MULTILINGUAL_E5_SMALL_MODEL_ID, null); service.parseRequestConfig( randomInferenceEntityId, @@ -201,7 +198,7 @@ public void testParseRequestConfig_E5() { ElasticsearchInternalServiceSettings.NUM_THREADS, 4, ElasticsearchInternalServiceSettings.MODEL_ID, - ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID, + MULTILINGUAL_E5_SMALL_MODEL_ID, "not_a_valid_service_setting", randomAlphaOfLength(10) ) @@ -435,19 +432,14 @@ public void testParsePersistedConfig() { ElasticsearchInternalServiceSettings.NUM_THREADS, 4, ElasticsearchInternalServiceSettings.MODEL_ID, - ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID, + MULTILINGUAL_E5_SMALL_MODEL_ID, ServiceFields.DIMENSIONS, 1 ) ) ); - var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings( - 1, - 4, - ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID, - null - ); + var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings(1, 4, MULTILINGUAL_E5_SMALL_MODEL_ID, null); MultilingualE5SmallModel parsedModel = (MultilingualE5SmallModel) service.parsePersistedConfig( randomInferenceEntityId, @@ -950,6 +942,31 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() { assertThat(model, is(expectedModel)); } + public void testModelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic() { + { + var architectures = Set.of("Aarch64"); + assertFalse( + ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86) + ); + + assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID)); + } + { + var architectures = Set.of("linux-x86_64"); + assertTrue( + ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86) + ); + assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID)); + } + { + var architectures = Set.of("linux-x86_64", "Aarch64"); + assertFalse( + ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86) + ); + assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID)); + } + } + private ElasticsearchInternalService createService(Client client) { var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client); return new ElasticsearchInternalService(context);