From 88a4c679b7d6e4f4010f27637dd4b068881597f2 Mon Sep 17 00:00:00 2001 From: Saloni Gupta <131198887+salonig23@users.noreply.github.com> Date: Tue, 17 Sep 2024 15:11:39 -0700 Subject: [PATCH] feat: add Config Policies GET API and modify CRUD functions to accept both Workload types (#9946) --- master/internal/api_config_policies.go | 49 +++++- .../internal/api_config_policies_intg_test.go | 136 +++++++++++++++- .../postgres_task_config_policy.go | 90 +++++------ .../postgres_task_config_policy_intg_test.go | 148 ++++++++++-------- master/pkg/model/task_config_policy.go | 29 +--- 5 files changed, 311 insertions(+), 141 deletions(-) diff --git a/master/internal/api_config_policies.go b/master/internal/api_config_policies.go index 760bc701e9a..a8773378d9c 100644 --- a/master/internal/api_config_policies.go +++ b/master/internal/api_config_policies.go @@ -63,14 +63,55 @@ func (*apiServer) PutGlobalConfigPolicies( } // Get workspace task config policies. -func (*apiServer) GetWorkspaceConfigPolicies( +func (a *apiServer) GetWorkspaceConfigPolicies( ctx context.Context, req *apiv1.GetWorkspaceConfigPoliciesRequest, ) (*apiv1.GetWorkspaceConfigPoliciesResponse, error) { + license.RequireLicense("manage config policies") + + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } + + w, err := a.GetWorkspaceByID(ctx, req.WorkspaceId, *curUser, false) + if err != nil { + return nil, err + } + + err = workspace.AuthZProvider.Get().CanViewWorkspaceConfigPolicies(ctx, *curUser, w) + if err != nil { + return nil, err + } + if !configpolicy.ValidWorkloadType(req.WorkloadType) { - return nil, fmt.Errorf("invalid workload type: %s", req.WorkloadType) + errMessage := fmt.Sprintf("invalid workload type: %s.", req.WorkloadType) + if len(req.WorkloadType) == 0 { + errMessage = noWorkloadErr + } + return nil, status.Errorf(codes.InvalidArgument, errMessage) } - data, err := stubData() - return &apiv1.GetWorkspaceConfigPoliciesResponse{ConfigPolicies: data}, err + + configPolicies, err := configpolicy.GetTaskConfigPolicies( + ctx, ptrs.Ptr(int(req.WorkspaceId)), req.WorkloadType) + if err != nil { + return nil, err + } + policyMap := map[string]interface{}{} + if configPolicies.InvariantConfig != nil { + var configMap map[string]interface{} + if err := yaml.Unmarshal([]byte(*configPolicies.InvariantConfig), &configMap); err != nil { + return nil, fmt.Errorf("unable to unmarshal json: %w", err) + } + policyMap["invariant_config"] = configMap + } + if configPolicies.Constraints != nil { + var constraintsMap map[string]interface{} + if err := yaml.Unmarshal([]byte(*configPolicies.Constraints), &constraintsMap); err != nil { + return nil, fmt.Errorf("unable to unmarshal json: %w", err) + } + policyMap["constraints"] = constraintsMap + } + return &apiv1.GetWorkspaceConfigPoliciesResponse{ConfigPolicies: configpolicy.MarshalConfigPolicy(policyMap)}, nil } // Get global task config policies. diff --git a/master/internal/api_config_policies_intg_test.go b/master/internal/api_config_policies_intg_test.go index 3f905391403..6192f551bb0 100644 --- a/master/internal/api_config_policies_intg_test.go +++ b/master/internal/api_config_policies_intg_test.go @@ -61,12 +61,12 @@ func TestDeleteWorkspaceConfigPolicies(t *testing.T) { for _, test := range cases { t.Run(test.name, func(t *testing.T) { - ntscPolicies := &model.NTSCTaskConfigPolicies{ + ntscPolicies := &model.TaskConfigPolicies{ WorkspaceID: ptrs.Ptr(int(test.req.WorkspaceId)), WorkloadType: model.NTSCType, LastUpdatedBy: curUser.ID, } - err = configpolicy.SetNTSCConfigPolicies(ctx, ntscPolicies) + err = configpolicy.SetTaskConfigPolicies(ctx, ntscPolicies) require.NoError(t, err) resp, err := api.DeleteWorkspaceConfigPolicies(ctx, test.req) @@ -79,7 +79,7 @@ func TestDeleteWorkspaceConfigPolicies(t *testing.T) { require.NotNil(t, resp) // Policies removed? - policies, err := configpolicy.GetNTSCConfigPolicies(ctx, ptrs.Ptr(int(workspaceID))) + policies, err := configpolicy.GetTaskConfigPolicies(ctx, ptrs.Ptr(int(workspaceID)), test.req.WorkloadType) require.Nil(t, policies) require.ErrorIs(t, err, sql.ErrNoRows) }) @@ -130,7 +130,7 @@ func TestDeleteGlobalConfigPolicies(t *testing.T) { for _, test := range cases { t.Run(test.name, func(t *testing.T) { - err := configpolicy.SetNTSCConfigPolicies(ctx, &model.NTSCTaskConfigPolicies{ + err := configpolicy.SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ WorkloadType: model.NTSCType, LastUpdatedBy: curUser.ID, }) @@ -146,7 +146,7 @@ func TestDeleteGlobalConfigPolicies(t *testing.T) { require.NotNil(t, resp) // Policies removed? - policies, err := configpolicy.GetNTSCConfigPolicies(ctx, nil) + policies, err := configpolicy.GetTaskConfigPolicies(ctx, nil, test.req.WorkloadType) require.Nil(t, policies) require.ErrorIs(t, err, sql.ErrNoRows) }) @@ -212,3 +212,129 @@ func TestBasicRBACConfigPolicyPerms(t *testing.T) { }) } } + +func TestGetWorkspaceConfigPolicies(t *testing.T) { + api, curUser, ctx := setupAPITest(t, nil) + testutils.MustLoadLicenseAndKeyFromFilesystem("../../") + + wkspResp, err := api.PostWorkspace(ctx, &apiv1.PostWorkspaceRequest{Name: uuid.New().String()}) + require.NoError(t, err) + workspaceID1 := wkspResp.Workspace.Id + wkspResp, err = api.PostWorkspace(ctx, &apiv1.PostWorkspaceRequest{Name: uuid.New().String()}) + require.NoError(t, err) + workspaceID2 := wkspResp.Workspace.Id + + // set only config policy + taskConfigPolicies := &model.TaskConfigPolicies{ + WorkspaceID: ptrs.Ptr(int(workspaceID1)), + WorkloadType: model.NTSCType, + LastUpdatedBy: curUser.ID, + InvariantConfig: ptrs.Ptr(configpolicy.DefaultInvariantConfigStr), + } + err = configpolicy.SetTaskConfigPolicies(ctx, taskConfigPolicies) + require.NoError(t, err) + + // set only constraints policy + taskConfigPolicies = &model.TaskConfigPolicies{ + WorkspaceID: ptrs.Ptr(int(workspaceID1)), + WorkloadType: model.ExperimentType, + LastUpdatedBy: curUser.ID, + Constraints: ptrs.Ptr(configpolicy.DefaultConstraintsStr), + } + err = configpolicy.SetTaskConfigPolicies(ctx, taskConfigPolicies) + require.NoError(t, err) + + // set both config and constraints policy + taskConfigPolicies = &model.TaskConfigPolicies{ + WorkspaceID: ptrs.Ptr(int(workspaceID2)), + WorkloadType: model.NTSCType, + LastUpdatedBy: curUser.ID, + InvariantConfig: ptrs.Ptr(configpolicy.DefaultInvariantConfigStr), + Constraints: ptrs.Ptr(configpolicy.DefaultConstraintsStr), + } + err = configpolicy.SetTaskConfigPolicies(ctx, taskConfigPolicies) + require.NoError(t, err) + + cases := []struct { + name string + req *apiv1.GetWorkspaceConfigPoliciesRequest + err error + hasConfig bool + hasConstraints bool + }{ + { + "invalid workload type", + &apiv1.GetWorkspaceConfigPoliciesRequest{ + WorkspaceId: workspaceID1, + WorkloadType: "bad workload type", + }, + fmt.Errorf("invalid workload type"), + false, + false, + }, + { + "empty workload type", + &apiv1.GetWorkspaceConfigPoliciesRequest{ + WorkspaceId: workspaceID1, + WorkloadType: "", + }, + fmt.Errorf(noWorkloadErr), + false, + false, + }, + { + "valid request only config", + &apiv1.GetWorkspaceConfigPoliciesRequest{ + WorkspaceId: workspaceID1, + WorkloadType: model.NTSCType, + }, + nil, + true, + false, + }, + { + "valid request only constraints", + &apiv1.GetWorkspaceConfigPoliciesRequest{ + WorkspaceId: workspaceID1, + WorkloadType: model.ExperimentType, + }, + nil, + false, + true, + }, + { + "valid request both configs and constraints", + &apiv1.GetWorkspaceConfigPoliciesRequest{ + WorkspaceId: workspaceID2, + WorkloadType: model.NTSCType, + }, + nil, + true, + true, + }, + } + + for _, test := range cases { + t.Run(test.name, func(t *testing.T) { + resp, err := api.GetWorkspaceConfigPolicies(ctx, test.req) + if test.err != nil { + require.ErrorContains(t, err, test.err.Error()) + return + } + require.NoError(t, err) + require.NotNil(t, resp) + + if test.hasConfig { + require.Contains(t, resp.ConfigPolicies.String(), "config") + } else { + require.NotContains(t, resp.ConfigPolicies.String(), "config") + } + + if test.hasConstraints { + require.Contains(t, resp.ConfigPolicies.String(), "constraints") + } else { + require.NotContains(t, resp.ConfigPolicies.String(), "constraints") + } + }) + } +} diff --git a/master/internal/configpolicy/postgres_task_config_policy.go b/master/internal/configpolicy/postgres_task_config_policy.go index 2096cf62a2d..b160ef5d477 100644 --- a/master/internal/configpolicy/postgres_task_config_policy.go +++ b/master/internal/configpolicy/postgres_task_config_policy.go @@ -6,8 +6,6 @@ import ( "strings" "github.com/uptrace/bun" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/determined-ai/determined/master/internal/db" "github.com/determined-ai/determined/master/pkg/model" @@ -16,76 +14,78 @@ import ( const ( wkspIDQuery = "workspace_id = ?" wkspIDGlobalQuery = "workspace_id IS ?" + // DefaultInvariantConfigStr is the default invariant config val used for tests. + DefaultInvariantConfigStr = `{"description": "random description", "resources": {"slots": 4, "max_slots": 8}}` + // DefaultConstraintsStr is the default constraints val used for tests. + DefaultConstraintsStr = `{"priority_limit": 10, "resources": {"max_slots": 8}}` ) -// SetNTSCConfigPolicies adds the NTSC invariant config and constraints config policies to +// SetTaskConfigPolicies adds the task invariant config and constraints config policies to // the database. -func SetNTSCConfigPolicies(ctx context.Context, - ntscTCPs *model.NTSCTaskConfigPolicies, +func SetTaskConfigPolicies(ctx context.Context, + tcp *model.TaskConfigPolicies, ) error { return db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { - return SetNTSCConfigPoliciesTx(ctx, &tx, ntscTCPs) + return SetTaskConfigPoliciesTx(ctx, &tx, tcp) }) } -// SetNTSCConfigPoliciesTx adds the NTSC invariant config and constraints config policies to +// SetTaskConfigPoliciesTx adds the task invariant config and constraints config policies to // the database. -func SetNTSCConfigPoliciesTx(ctx context.Context, tx *bun.Tx, - ntscTCPs *model.NTSCTaskConfigPolicies, +func SetTaskConfigPoliciesTx(ctx context.Context, tx *bun.Tx, + tcp *model.TaskConfigPolicies, ) error { - if ntscTCPs.WorkloadType != model.NTSCType { - return status.Error(codes.InvalidArgument, - "invalid workload type for config policies: "+ntscTCPs.WorkloadType) + q := db.Bun().NewInsert(). + Model(tcp) + + if tcp.InvariantConfig == nil { + q = q.ExcludeColumn("invariant_config") + } + if tcp.Constraints == nil { + q = q.ExcludeColumn("constraints") } - q := ` - INSERT INTO task_config_policies (workspace_id, workload_type, last_updated_by, - last_updated_time, invariant_config, constraints) VALUES (?, ?, ?, ?, ?, ?) - ON CONFLICT (workspace_id, workload_type) WHERE workspace_id IS NOT NULL - DO UPDATE SET last_updated_by = ?, last_updated_time = ?, invariant_config = ?, - constraints = ? - ` - if ntscTCPs.WorkspaceID == nil { - q = ` - INSERT INTO task_config_policies (workspace_id, workload_type, last_updated_by, - last_updated_time, invariant_config, constraints) VALUES (?, ?, ?, ?, ?, ?) - ON CONFLICT (workload_type) WHERE workspace_id IS NULL - DO UPDATE SET last_updated_by = ?, last_updated_time = ?, invariant_config = ?, - constraints = ? - ` + if tcp.WorkspaceID == nil { + q = q.On("CONFLICT (workload_type) WHERE workspace_id IS NULL DO UPDATE") + } else { + q = q.On("CONFLICT (workspace_id, workload_type) WHERE workspace_id IS NOT NULL DO UPDATE") } - _, err := db.Bun().NewRaw(q, ntscTCPs.WorkspaceID, model.NTSCType, - ntscTCPs.LastUpdatedBy, ntscTCPs.LastUpdatedTime, ntscTCPs.InvariantConfig, - ntscTCPs.Constraints, ntscTCPs.LastUpdatedBy, ntscTCPs.LastUpdatedTime, - ntscTCPs.InvariantConfig, ntscTCPs.Constraints). - Exec(ctx) - if err != nil { - return fmt.Errorf("error setting NTSC task config policies: %w", err) + + q = q.Set("last_updated_by = ?, last_updated_time = ?", tcp.LastUpdatedBy, tcp.LastUpdatedTime) + if tcp.InvariantConfig != nil { + q = q.Set("invariant_config = ?", tcp.InvariantConfig) + } + if tcp.Constraints != nil { + q = q.Set("constraints = ?", tcp.Constraints) } + _, err := q.Exec(ctx) + if err != nil { + return fmt.Errorf("error setting task config policies: %w", err) + } return nil } -// GetNTSCConfigPolicies retrieves the invariant NTSC config and constraints for the -// given scope (global or workspace-level). -func GetNTSCConfigPolicies(ctx context.Context, - scope *int, -) (*model.NTSCTaskConfigPolicies, error) { - var ntscTCP model.NTSCTaskConfigPolicies +// GetTaskConfigPolicies retrieves the invariant config and constraints for the +// given scope (global or workspace-level) and workload Type. +func GetTaskConfigPolicies(ctx context.Context, + scope *int, workloadType string, +) (*model.TaskConfigPolicies, error) { + var tcp model.TaskConfigPolicies wkspQuery := wkspIDQuery if scope == nil { wkspQuery = wkspIDGlobalQuery } err := db.Bun().NewSelect(). - Model(&ntscTCP). + Model(&tcp). Where(wkspQuery, scope). - Where("workload_type = ?", model.NTSCType). + Where("workload_type = ?", workloadType). Scan(ctx) if err != nil { - return nil, fmt.Errorf("error retrieving NTSC task config policies for "+ - "workspace with ID %d: %w", scope, err) + return nil, fmt.Errorf("error retrieving %v task config policies for "+ + "workspace with ID %d: %w", workloadType, scope, err) } - return &ntscTCP, nil + return &tcp, nil } // DeleteConfigPolicies deletes the invariant experiment config and constraints for the diff --git a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go index 01d9dfbf37d..a7dcd69a0b2 100644 --- a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go +++ b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go @@ -20,7 +20,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestSetNTSCConfigPolicies(t *testing.T) { +func TestSetTaskConfigPolicies(t *testing.T) { ctx := context.Background() require.NoError(t, etc.SetRootPath(db.RootFromDB)) pgDB, cleanup := db.MustResolveNewPostgresDatabase(t) @@ -40,18 +40,18 @@ func TestSetNTSCConfigPolicies(t *testing.T) { }() tests := []struct { - name string - ntscTCPs *model.NTSCTaskConfigPolicies - global bool - err *string + name string + tcps *model.TaskConfigPolicies + global bool + err *string }{ { "invalid user id", - &model.NTSCTaskConfigPolicies{ + &model.TaskConfigPolicies{ WorkloadType: model.NTSCType, LastUpdatedBy: -1, LastUpdatedTime: time.Now().UTC().Truncate(time.Second), - InvariantConfig: DefaultCommandConfig(), + InvariantConfig: DefaultInvariantConfig(), Constraints: DefaultConstraints(), }, false, @@ -59,23 +59,23 @@ func TestSetNTSCConfigPolicies(t *testing.T) { }, { "valid config no constraint", - &model.NTSCTaskConfigPolicies{ + &model.TaskConfigPolicies{ WorkloadType: model.NTSCType, LastUpdatedBy: user.ID, LastUpdatedTime: time.Now().UTC().Truncate(time.Second), - InvariantConfig: DefaultCommandConfig(), - Constraints: model.Constraints{}, + InvariantConfig: DefaultInvariantConfig(), + Constraints: nil, }, false, nil, }, { "valid constraint no config", - &model.NTSCTaskConfigPolicies{ + &model.TaskConfigPolicies{ WorkloadType: model.NTSCType, LastUpdatedBy: user.ID, LastUpdatedTime: time.Now().UTC().Truncate(time.Second), - InvariantConfig: model.CommandConfig{}, + InvariantConfig: nil, Constraints: DefaultConstraints(), }, false, @@ -83,11 +83,11 @@ func TestSetNTSCConfigPolicies(t *testing.T) { }, { "valid constraint valid config", - &model.NTSCTaskConfigPolicies{ + &model.TaskConfigPolicies{ WorkloadType: model.NTSCType, LastUpdatedBy: user.ID, LastUpdatedTime: time.Now().UTC().Truncate(time.Second), - InvariantConfig: DefaultCommandConfig(), + InvariantConfig: DefaultInvariantConfig(), Constraints: DefaultConstraints(), }, false, @@ -95,12 +95,12 @@ func TestSetNTSCConfigPolicies(t *testing.T) { }, { "global valid constraint no config", - &model.NTSCTaskConfigPolicies{ + &model.TaskConfigPolicies{ WorkloadType: model.NTSCType, WorkspaceID: nil, LastUpdatedBy: user.ID, LastUpdatedTime: time.Now().UTC().Truncate(time.Second), - InvariantConfig: model.CommandConfig{}, + InvariantConfig: nil, Constraints: DefaultConstraints(), }, true, @@ -108,23 +108,23 @@ func TestSetNTSCConfigPolicies(t *testing.T) { }, { "global valid config no constraint", - &model.NTSCTaskConfigPolicies{ + &model.TaskConfigPolicies{ WorkloadType: model.NTSCType, LastUpdatedBy: user.ID, LastUpdatedTime: time.Now().UTC().Truncate(time.Second), - InvariantConfig: DefaultCommandConfig(), - Constraints: model.Constraints{}, + InvariantConfig: DefaultInvariantConfig(), + Constraints: nil, }, true, nil, }, { "global valid constraint valid config", - &model.NTSCTaskConfigPolicies{ + &model.TaskConfigPolicies{ WorkloadType: model.NTSCType, LastUpdatedBy: user.ID, LastUpdatedTime: time.Now().UTC().Truncate(time.Second), - InvariantConfig: DefaultCommandConfig(), + InvariantConfig: DefaultInvariantConfig(), Constraints: DefaultConstraints(), }, true, @@ -132,27 +132,27 @@ func TestSetNTSCConfigPolicies(t *testing.T) { }, { "global no constraint no config", - &model.NTSCTaskConfigPolicies{ + &model.TaskConfigPolicies{ WorkloadType: model.NTSCType, LastUpdatedBy: user.ID, LastUpdatedTime: time.Now().UTC().Truncate(time.Second), - InvariantConfig: model.CommandConfig{}, - Constraints: model.Constraints{}, + InvariantConfig: nil, + Constraints: nil, }, true, nil, }, { - "experiment workload type for NTSC policies", - &model.NTSCTaskConfigPolicies{ + "experiment workload type for TCP policies", + &model.TaskConfigPolicies{ WorkloadType: model.ExperimentType, LastUpdatedBy: user.ID, LastUpdatedTime: time.Now().UTC().Truncate(time.Second), - InvariantConfig: model.CommandConfig{}, - Constraints: model.Constraints{}, + InvariantConfig: nil, + Constraints: nil, }, true, - ptrs.Ptr("invalid workload type"), + nil, }, } @@ -164,11 +164,11 @@ func TestSetNTSCConfigPolicies(t *testing.T) { _, err := db.Bun().NewInsert().Model(&w).Exec(ctx) require.NoError(t, err) workspaceIDs = append(workspaceIDs, int32(w.ID)) - test.ntscTCPs.WorkspaceID = ptrs.Ptr(w.ID) + test.tcps.WorkspaceID = ptrs.Ptr(w.ID) } // Test add NTSC task config policies. - err := SetNTSCConfigPolicies(ctx, test.ntscTCPs) + err := SetTaskConfigPolicies(ctx, test.tcps) if test.err != nil { require.ErrorContains(t, err, *test.err) return @@ -176,32 +176,31 @@ func TestSetNTSCConfigPolicies(t *testing.T) { require.NoError(t, err) // Test get NTSC task config policies. - ntscTCPs, err := GetNTSCConfigPolicies(ctx, test.ntscTCPs.WorkspaceID) + tcps, err := GetTaskConfigPolicies(ctx, test.tcps.WorkspaceID, test.tcps.WorkloadType) require.NoError(t, err) - ntscTCPs.LastUpdatedTime = ntscTCPs.LastUpdatedTime.UTC() - require.Equal(t, test.ntscTCPs, ntscTCPs) + tcps.LastUpdatedTime = tcps.LastUpdatedTime.UTC() + requireEqualTaskPolicy(t, test.tcps, tcps) // Test update NTSC task config policies. - test.ntscTCPs.InvariantConfig.Environment.Image = model.RuntimeItem{ - CPU: uuid.NewString(), - } - err = SetNTSCConfigPolicies(ctx, test.ntscTCPs) + test.tcps.InvariantConfig = ptrs.Ptr( + `{"description":"random description","resources":{"slots":4,"max_slots":8},"notebook_idle_type":"activity"}`) + err = SetTaskConfigPolicies(ctx, test.tcps) require.NoError(t, err) // Test get NTSC task config policies. - ntscTCPs, err = GetNTSCConfigPolicies(ctx, test.ntscTCPs.WorkspaceID) + tcps, err = GetTaskConfigPolicies(ctx, test.tcps.WorkspaceID, test.tcps.WorkloadType) require.NoError(t, err) - ntscTCPs.LastUpdatedTime = ntscTCPs.LastUpdatedTime.UTC().Truncate(time.Second) - require.Equal(t, test.ntscTCPs, ntscTCPs) + tcps.LastUpdatedTime = tcps.LastUpdatedTime.UTC().Truncate(time.Second) + requireEqualTaskPolicy(t, test.tcps, tcps) }) } // Test invalid workspace ID. - err := SetNTSCConfigPolicies(ctx, &model.NTSCTaskConfigPolicies{ + err := SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ WorkspaceID: ptrs.Ptr(-1), LastUpdatedBy: user.ID, WorkloadType: model.NTSCType, - InvariantConfig: DefaultCommandConfig(), + InvariantConfig: DefaultInvariantConfig(), Constraints: DefaultConstraints(), }) require.ErrorContains(t, err, "violates foreign key constraint") @@ -219,7 +218,7 @@ func TestTaskConfigPoliciesUnique(t *testing.T) { // Global scope. _, _, ntscTCPs := CreateMockTaskConfigPolicies(ctx, t, pgDB, user, true, true, true) - ntscTCPs.Constraints = model.Constraints{} + ntscTCPs.Constraints = nil expInvariantConfig, err := json.Marshal(ntscTCPs.InvariantConfig) require.NoError(t, err) @@ -239,7 +238,7 @@ func TestTaskConfigPoliciesUnique(t *testing.T) { // Workspace-level. w, _, ntscTCPs := CreateMockTaskConfigPolicies(ctx, t, pgDB, user, false, true, true) - ntscTCPs.Constraints = model.Constraints{} + ntscTCPs.Constraints = nil expInvariantConfig, err = json.Marshal(ntscTCPs.InvariantConfig) require.NoError(t, err) @@ -419,14 +418,14 @@ func TestDeleteConfigPolicies(t *testing.T) { // requested for the specified scope. func CreateMockTaskConfigPolicies(ctx context.Context, t *testing.T, pgDB *db.PgDB, user model.User, global bool, hasInvariantConfig bool, - hasConstraints bool) (*model.Workspace, *model.ExperimentTaskConfigPolicies, - *model.NTSCTaskConfigPolicies, + hasConstraints bool) (*model.Workspace, *model.TaskConfigPolicies, + *model.TaskConfigPolicies, ) { var scope *int var w model.Workspace - var ntscConfig model.CommandConfig + var ntscConfig *string - var constraints model.Constraints + var constraints *string if !global { w = model.Workspace{Name: uuid.NewString(), UserID: user.ID} @@ -435,40 +434,57 @@ func CreateMockTaskConfigPolicies(ctx context.Context, t *testing.T, scope = ptrs.Ptr(w.ID) } if hasInvariantConfig { - ntscConfig = DefaultCommandConfig() + ntscConfig = DefaultInvariantConfig() } if hasConstraints { constraints = DefaultConstraints() } - ntscTCP := &model.NTSCTaskConfigPolicies{ + ntscTCP := &model.TaskConfigPolicies{ WorkspaceID: scope, WorkloadType: model.NTSCType, LastUpdatedBy: user.ID, InvariantConfig: ntscConfig, Constraints: constraints, } - err := SetNTSCConfigPolicies(ctx, ntscTCP) + err := SetTaskConfigPolicies(ctx, ntscTCP) require.NoError(t, err) return &w, nil, ntscTCP } -func DefaultCommandConfig() model.CommandConfig { - return model.CommandConfig{ - Description: "random description", - Resources: model.ResourcesConfig{ - Slots: 4, - MaxSlots: ptrs.Ptr(8), - }, - } +func DefaultInvariantConfig() *string { + return ptrs.Ptr(DefaultInvariantConfigStr) } -func DefaultConstraints() model.Constraints { - return model.Constraints{ - PriorityLimit: ptrs.Ptr[int](10), - ResourceConstraints: &model.ResourceConstraints{ - MaxSlots: ptrs.Ptr(10), - }, +func DefaultConstraints() *string { + return ptrs.Ptr(DefaultConstraintsStr) +} + +func requireEqualTaskPolicy(t *testing.T, exp *model.TaskConfigPolicies, act *model.TaskConfigPolicies) { + require.Equal(t, exp.LastUpdatedBy, act.LastUpdatedBy) + require.Equal(t, exp.LastUpdatedTime, act.LastUpdatedTime) + require.Equal(t, exp.WorkloadType, act.WorkloadType) + require.Equal(t, exp.WorkspaceID, act.WorkspaceID) + + if exp.Constraints == nil { + require.Nil(t, exp.Constraints, act.Constraints) + } else { + var expJSONMap, actJSONMap map[string]interface{} + err := json.Unmarshal([]byte(*exp.Constraints), &expJSONMap) + require.NoError(t, err) + err = json.Unmarshal([]byte(*act.Constraints), &actJSONMap) + require.NoError(t, err) + require.Equal(t, expJSONMap, actJSONMap) + } + if exp.InvariantConfig == nil { + require.Nil(t, exp.InvariantConfig, act.InvariantConfig) + } else { + var expJSONMap, actJSONMap map[string]interface{} + err := json.Unmarshal([]byte(*exp.InvariantConfig), &expJSONMap) + require.NoError(t, err) + err = json.Unmarshal([]byte(*act.InvariantConfig), &actJSONMap) + require.NoError(t, err) + require.Equal(t, expJSONMap, actJSONMap) } } diff --git a/master/pkg/model/task_config_policy.go b/master/pkg/model/task_config_policy.go index 58713a8e373..74311ec6495 100644 --- a/master/pkg/model/task_config_policy.go +++ b/master/pkg/model/task_config_policy.go @@ -4,8 +4,6 @@ import ( "time" "github.com/uptrace/bun" - - "github.com/determined-ai/determined/master/pkg/schemas/expconf" ) // Constants. @@ -19,26 +17,15 @@ const ( NTSCType string = "NTSC" ) -// ExperimentTaskConfigPolicies is the bun model of a task config policy. -type ExperimentTaskConfigPolicies struct { - bun.BaseModel `bun:"table:task_config_policies"` - WorkspaceID *int `bun:"workspace_id"` - WorkloadType string `bun:"workload_type,notnull"` - LastUpdatedBy UserID `bun:"last_updated_by,notnull"` - LastUpdatedTime time.Time `bun:"last_updated_time,notnull"` - InvariantConfig expconf.ExperimentConfig `bun:"invariant_config"` - Constraints Constraints `bun:"constraints"` -} - -// NTSCTaskConfigPolicies is the bun model of a task config policy. -type NTSCTaskConfigPolicies struct { +// TaskConfigPolicies is the bun model of a task config policy. +type TaskConfigPolicies struct { bun.BaseModel `bun:"table:task_config_policies"` - WorkspaceID *int `bun:"workspace_id"` - WorkloadType string `bun:"workload_type,notnull"` - LastUpdatedBy UserID `bun:"last_updated_by,notnull"` - LastUpdatedTime time.Time `bun:"last_updated_time,notnull"` - InvariantConfig CommandConfig `bun:"invariant_config"` - Constraints Constraints `bun:"constraints"` + WorkspaceID *int `bun:"workspace_id"` + WorkloadType string `bun:"workload_type,notnull"` + LastUpdatedBy UserID `bun:"last_updated_by,notnull"` + LastUpdatedTime time.Time `bun:"last_updated_time,notnull"` + InvariantConfig *string `bun:"invariant_config"` + Constraints *string `bun:"constraints"` } // ResourceConstraints are non-overridable resource constraints.