Skip to content

Commit

Permalink
[ML] Fix check on E5 model platform compatibility (#113437)
Browse files Browse the repository at this point in the history
Creating an endpoint for the built in multilingual e5 model failed for
linux optimised version due to an error in the logic that checks model
compatibility.
  • Loading branch information
davidkyle committed Sep 30, 2024
1 parent 1d373a5 commit 3a04f07
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 29 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/113437.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 113437
summary: Fix check on E5 model platform compatibility
area: Machine Learning
type: bug
issues:
- 113577
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest {

public void testPutE5Small_withNoModelVariant() {
{
String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
String inferenceEntityId = "testPutE5Small_withNoModelVariant";
expectThrows(
org.elasticsearch.client.ResponseException.class,
() -> putTextEmbeddingModel(inferenceEntityId, noModelIdVariantJsonEntity())
Expand All @@ -33,7 +33,7 @@ public void testPutE5Small_withNoModelVariant() {
}

public void testPutE5Small_withPlatformAgnosticVariant() throws IOException {
String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
String inferenceEntityId = "teste5mall_withplatformagnosticvariant";
putTextEmbeddingModel(inferenceEntityId, platformAgnosticModelVariantJsonEntity());
var models = getTrainedModel("_all");
assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId));
Expand All @@ -50,9 +50,8 @@ public void testPutE5Small_withPlatformAgnosticVariant() throws IOException {
deleteTextEmbeddingModel(inferenceEntityId);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/105198")
public void testPutE5Small_withPlatformSpecificVariant() throws IOException {
String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
String inferenceEntityId = "teste5mall_withplatformspecificvariant";
if ("linux-x86_64".equals(Platforms.PLATFORM_NAME)) {
putTextEmbeddingModel(inferenceEntityId, platformSpecificModelVariantJsonEntity());
var models = getTrainedModel("_all");
Expand All @@ -77,7 +76,7 @@ public void testPutE5Small_withPlatformSpecificVariant() throws IOException {
}

public void testPutE5Small_withFakeModelVariant() {
String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
String inferenceEntityId = "teste5mall_withfakevariant";
expectThrows(
org.elasticsearch.client.ResponseException.class,
() -> putTextEmbeddingModel(inferenceEntityId, fakeModelVariantJsonEntity())
Expand Down Expand Up @@ -112,7 +111,7 @@ private Map<String, Object> putTextEmbeddingModel(String inferenceEntityId, Stri
private String noModelIdVariantJsonEntity() {
return """
{
"service": "text_embedding",
"service": "elasticsearch",
"service_settings": {
"num_allocations": 1,
"num_threads": 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,7 @@ private void e5Case(
MULTILINGUAL_E5_SMALL_MODEL_ID
)
);
}

if (modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(platformArchitectures, esServiceSettingsBuilder.getModelId())) {
} else if (modelVariantValidForArchitecture(platformArchitectures, esServiceSettingsBuilder.getModelId()) == false) {
throw new IllegalArgumentException(
"Error parsing request config, model id does not match any models available on this platform. Was ["
+ esServiceSettingsBuilder.getModelId()
Expand All @@ -224,17 +222,19 @@ private void e5Case(
);
}

private static boolean modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(
Set<String> platformArchitectures,
String modelId
) {
static boolean modelVariantValidForArchitecture(Set<String> platformArchitectures, String modelId) {
if (modelId.equals(MULTILINGUAL_E5_SMALL_MODEL_ID)) {
// platform agnostic model is always compatible
return true;
}

return modelId.equals(
selectDefaultModelVariantBasedOnClusterArchitecture(
platformArchitectures,
MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86,
MULTILINGUAL_E5_SMALL_MODEL_ID
)
) && modelId.equals(MULTILINGUAL_E5_SMALL_MODEL_ID) == false;
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -167,17 +169,12 @@ public void testParseRequestConfig_E5() {
ElasticsearchInternalServiceSettings.NUM_THREADS,
4,
ElasticsearchInternalServiceSettings.MODEL_ID,
ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID
MULTILINGUAL_E5_SMALL_MODEL_ID
)
)
);

var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings(
1,
4,
ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID,
null
);
var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings(1, 4, MULTILINGUAL_E5_SMALL_MODEL_ID, null);

service.parseRequestConfig(
randomInferenceEntityId,
Expand All @@ -201,7 +198,7 @@ public void testParseRequestConfig_E5() {
ElasticsearchInternalServiceSettings.NUM_THREADS,
4,
ElasticsearchInternalServiceSettings.MODEL_ID,
ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID,
MULTILINGUAL_E5_SMALL_MODEL_ID,
"not_a_valid_service_setting",
randomAlphaOfLength(10)
)
Expand Down Expand Up @@ -435,19 +432,14 @@ public void testParsePersistedConfig() {
ElasticsearchInternalServiceSettings.NUM_THREADS,
4,
ElasticsearchInternalServiceSettings.MODEL_ID,
ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID,
MULTILINGUAL_E5_SMALL_MODEL_ID,
ServiceFields.DIMENSIONS,
1
)
)
);

var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings(
1,
4,
ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID,
null
);
var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings(1, 4, MULTILINGUAL_E5_SMALL_MODEL_ID, null);

MultilingualE5SmallModel parsedModel = (MultilingualE5SmallModel) service.parsePersistedConfig(
randomInferenceEntityId,
Expand Down Expand Up @@ -950,6 +942,31 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() {
assertThat(model, is(expectedModel));
}

public void testModelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic() {
{
var architectures = Set.of("Aarch64");
assertFalse(
ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86)
);

assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID));
}
{
var architectures = Set.of("linux-x86_64");
assertTrue(
ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86)
);
assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID));
}
{
var architectures = Set.of("linux-x86_64", "Aarch64");
assertFalse(
ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86)
);
assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID));
}
}

private ElasticsearchInternalService createService(Client client) {
var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client);
return new ElasticsearchInternalService(context);
Expand Down

0 comments on commit 3a04f07

Please sign in to comment.