Skip to content

Commit

Permalink
feat: check task config policy constraints before scheduling NTSC wor… (
Browse files Browse the repository at this point in the history
#9991)

Co-authored-by: Amanda Vialva <[email protected]>
  • Loading branch information
kkunapuli and amandavialva01 authored Oct 1, 2024
1 parent 0083d7e commit 83a779e
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 7 deletions.
7 changes: 7 additions & 0 deletions master/internal/api_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/determined-ai/determined/master/internal/api/apiutils"
"github.com/determined-ai/determined/master/internal/authz"
"github.com/determined-ai/determined/master/internal/command"
"github.com/determined-ai/determined/master/internal/configpolicy"
"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/grpcutil"
"github.com/determined-ai/determined/master/internal/rbac/audit"
Expand Down Expand Up @@ -149,6 +150,12 @@ func (a *apiServer) getCommandLaunchParams(ctx context.Context, req *protoComman
return nil, nil, err
}

// Check submitted config against task config policies.
valid, err := configpolicy.CheckNTSCConstraints(ctx, int(cmdSpec.Metadata.WorkspaceID), config, a.m.rm)
if !valid {
return nil, nil, err
}

token, err := getTaskSessionToken(ctx, userModel)
if err != nil {
return nil, nil, err
Expand Down
3 changes: 3 additions & 0 deletions master/internal/api_experiment_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1833,6 +1833,9 @@ func TestAuthZGetExperimentAndCanDoActions(t *testing.T) {
mock.Anything).Return(nil).Once()
workspaceAuthZ.On("CanGetWorkspace", mock.Anything, mock.Anything, mock.Anything).
Return(nil).Once()
mockRM := MockRM()
mockRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil)
api.m.rm = mockRM
_, err := api.LaunchTensorboard(ctx, &apiv1.LaunchTensorboardRequest{
ExperimentIds: []int32{int32(id)},
})
Expand Down
3 changes: 3 additions & 0 deletions master/internal/api_ntsc_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ func TestAuthZCanCreateNSC(t *testing.T) {
).Times(3)
workspaceAuthZ.On("CanGetWorkspace", mock.Anything, mock.Anything, mock.Anything).
Return(nil).Times(3)
mockRM := MockRM()
mockRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil)
api.m.rm = mockRM
_, err = api.LaunchNotebook(ctx, &apiv1.LaunchNotebookRequest{})
require.Equal(t, codes.PermissionDenied, status.Code(err))
_, err = api.LaunchCommand(ctx, &apiv1.LaunchCommandRequest{})
Expand Down
3 changes: 3 additions & 0 deletions master/internal/api_trials_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,9 @@ func TestTrialAuthZ(t *testing.T) {
mock.Anything).Return(nil).Once()
workspaceAuthZ.On("CanGetWorkspace", mock.Anything, mock.Anything, mock.Anything).
Return(nil).Once()
mockRM := MockRM()
mockRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil)
api.m.rm = mockRM
_, err := api.LaunchTensorboard(ctx, &apiv1.LaunchTensorboardRequest{
TrialIds: []int32{int32(id)},
})
Expand Down
72 changes: 71 additions & 1 deletion master/internal/configpolicy/task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package configpolicy

import (
"context"
"encoding/json"
"errors"
"fmt"

"github.com/determined-ai/determined/master/internal/rm"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/schemas/expconf"
)
Expand All @@ -28,7 +31,74 @@ type NTSCConfigPolicies struct {
Constraints *model.Constraints `json:"constraints"`
}

// PriorityAllowed returns true if the desired priority is within the limit set by task config policies.
// Sentinel errors wrapped (via %w) into the errors returned by CheckNTSCConstraints, so that
// callers and tests can identify which kind of constraint was violated with errors.Is.
var (
// errPriorityConstraintFailure marks a workload whose requested priority is outside the limit.
errPriorityConstraintFailure = errors.New("submitted workload failed priority constraint")
// errResourceConstraintFailure marks a workload that requests more resources than allowed.
errResourceConstraintFailure = errors.New("submitted workload failed a resource constraint")
)

// CheckNTSCConstraints reports whether the submitted NTSC workload config satisfies the merged
// workspace and global task config policy constraints.
//
// It returns (true, nil) when every applicable constraint is met. When a constraint is violated
// it returns false together with an error wrapping errPriorityConstraintFailure or
// errResourceConstraintFailure. If the constraints cannot be loaded, it returns false and the
// underlying error.
func CheckNTSCConstraints(
	ctx context.Context,
	workspaceID int,
	workloadConfig model.CommandConfig,
	resourceManager rm.ResourceManager,
) (bool, error) {
	constraints, err := GetMergedConstraints(ctx, workspaceID, model.NTSCType)
	if err != nil {
		return false, err
	}

	// rm.SmallerValueIsHigherPriority only returns an error if task priority is not implemented
	// for that resource manager. In that case the error is deliberately not propagated and the
	// priority limit is simply not enforced.
	smallerHigher, err := resourceManager.SmallerValueIsHigherPriority()
	if err == nil && constraints.PriorityLimit != nil && workloadConfig.Resources.Priority != nil {
		requested, limit := *workloadConfig.Resources.Priority, *constraints.PriorityLimit
		if !priorityWithinLimit(requested, limit, smallerHigher) {
			// Argument order must follow the message: requested value first, admin limit second.
			return false, fmt.Errorf("requested priority [%d] exceeds limit set by admin [%d]: %w",
				requested, limit, errPriorityConstraintFailure)
		}
	}

	// The max_slots check only applies when both the policy and the workload specify a value.
	if constraints.ResourceConstraints != nil && constraints.ResourceConstraints.MaxSlots != nil &&
		workloadConfig.Resources.MaxSlots != nil {
		requested, limit := *workloadConfig.Resources.MaxSlots, *constraints.ResourceConstraints.MaxSlots
		if limit < requested {
			return false, fmt.Errorf("requested resources.max_slots [%d] exceeds limit set by admin [%d]: %w",
				requested, limit, errResourceConstraintFailure)
		}
	}

	return true, nil
}

// GetMergedConstraints loads the workspace-level and the global constraints for the given
// workload type and folds them into a single model.Constraints value.
// workloadType is expected to be model.ExperimentType or model.NTSCType.
func GetMergedConstraints(
	ctx context.Context, workspaceID int, workloadType string,
) (*model.Constraints, error) {
	merged := &model.Constraints{}

	// Decode the workspace policy first and the global policy (nil scope) second into the same
	// struct, so that any field set globally overwrites the workspace-level value.
	for _, scope := range []*int{&workspaceID, nil} {
		policies, err := GetTaskConfigPolicies(ctx, scope, workloadType)
		if err != nil {
			return nil, err
		}
		if policies.Constraints == nil {
			continue
		}
		if err := json.Unmarshal([]byte(*policies.Constraints), merged); err != nil {
			return nil, fmt.Errorf("unable to merge workspace and global constraints: %w", err)
		}
	}

	return merged, nil
}

// PriorityAllowed returns true if the desired priority is within the task config policy limit.
func PriorityAllowed(wkspID int, workloadType string, priority int, smallerHigher bool) (bool, error) {
// Check if a priority limit has been set with a constraint policy.
// Global policies have highest precedence.
Expand Down
159 changes: 153 additions & 6 deletions master/internal/configpolicy/task_config_policy_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import (
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/mocks"
"github.com/determined-ai/determined/master/pkg/etc"
"github.com/determined-ai/determined/master/pkg/model"
)
Expand All @@ -26,7 +28,7 @@ func TestPriorityAllowed(t *testing.T) {

wkspLimit := 50
user := db.RequireMockUser(t, pgDB)
w := addWorkspacePriorityLimit(t, pgDB, user, wkspLimit)
w := addWorkspacePriorityLimit(t, user, wkspLimit)

// Priority is outside workspace limit.
smallerValueIsHigherPriority := true
Expand All @@ -35,7 +37,7 @@ func TestPriorityAllowed(t *testing.T) {
require.False(t, ok)

globalLimit := 42
addGlobalPriorityLimit(t, pgDB, user, globalLimit)
addConstraints(t, user, nil, fmt.Sprintf(`{"priority_limit": %d}`, globalLimit))

// Priority is within global limit.
ok, err = PriorityAllowed(w.ID, model.NTSCType, wkspLimit-1, true)
Expand All @@ -48,7 +50,142 @@ func TestPriorityAllowed(t *testing.T) {
require.False(t, ok)
}

func addWorkspacePriorityLimit(t *testing.T, pgDB *db.PgDB, user model.User, limit int) model.Workspace {
// TestValidateNTSCConstraints runs CheckNTSCConstraints against a freshly migrated test database
// under several constraint setups. The subtests share that database, so policies written by one
// subtest remain visible to the ones that follow.
func TestValidateNTSCConstraints(t *testing.T) {
require.NoError(t, etc.SetRootPath(db.RootFromDB))
pgDB, cleanup := db.MustResolveNewPostgresDatabase(t)
defer cleanup()
db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB)

wkspPriorityLimit := 7
user := db.RequireMockUser(t, pgDB)

// Nothing has been written for workspace 1 or globally yet, so the check must pass.
t.Run("no constraints set - ok", func(t *testing.T) {
resourceManager := mocks.ResourceManager{}
resourceManager.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, nil)

config := defaultConfig()
ok, err := CheckNTSCConstraints(context.Background(), 1, config, &resourceManager)
require.NoError(t, err)
require.True(t, ok)
})

// defaultConfig requests priority 50 while the workspace limit is 7; the test expects the
// priority sentinel error.
t.Run("running in wksp with constraints - not ok", func(t *testing.T) {
w := addWorkspacePriorityLimit(t, user, wkspPriorityLimit)
resourceManager := mocks.ResourceManager{}
resourceManager.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, nil)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
config := defaultConfig()
_, err := CheckNTSCConstraints(ctx, w.ID, config, &resourceManager)
require.Error(t, err)
require.ErrorIs(t, err, errPriorityConstraintFailure)
})

// A brand-new workspace with no policy attached must not be restricted.
t.Run("running in wksp without constraints - ok", func(t *testing.T) {
resourceManager := mocks.ResourceManager{}
resourceManager.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, nil)
w := model.Workspace{Name: uuid.NewString(), UserID: user.ID}
_, err := db.Bun().NewInsert().Model(&w).Exec(context.Background())
require.NoError(t, err)

config := defaultConfig()
ok, err := CheckNTSCConstraints(context.Background(), w.ID, config, &resourceManager)
require.True(t, ok)
require.NoError(t, err)
})

// defaultConfig requests max_slots 12; the workspace gets DefaultConstraints, which is expected
// to cap max_slots below that (TestGetMergedConstraints expects 8 — confirm against the helper).
t.Run("exceeds max slots - not ok", func(t *testing.T) {
constraints := DefaultConstraints()
w := model.Workspace{Name: uuid.NewString(), UserID: user.ID}
_, err := db.Bun().NewInsert().Model(&w).Exec(context.Background())
require.NoError(t, err)
addConstraints(t, user, &w.ID, *constraints)

resourceManager := mocks.ResourceManager{}
resourceManager.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil)

config := defaultConfig()
_, err = CheckNTSCConstraints(context.Background(), w.ID, config, &resourceManager)
require.Error(t, err)
require.ErrorIs(t, err, errResourceConstraintFailure)
})

// Same workspace limit checked twice: first with an RM that supports priority (violation
// expected), then with an RM whose SmallerValueIsHigherPriority returns an error.
t.Run("rm priority not supported - ok", func(t *testing.T) {
w := addWorkspacePriorityLimit(t, user, wkspPriorityLimit)
rm1 := mocks.ResourceManager{}
rm1.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, nil).Once()

config := defaultConfig()
_, err := CheckNTSCConstraints(context.Background(), w.ID, config, &rm1)
require.Error(t, err)
require.ErrorIs(t, err, errPriorityConstraintFailure)

// Validate constraints again. This time, the RM does not support priority.
rmNoPriority := mocks.ResourceManager{}
rmNoPriority.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, fmt.Errorf("not supported")).Once()
ok, err := CheckNTSCConstraints(context.Background(), w.ID, config, &rmNoPriority)
require.True(t, ok)
require.NoError(t, err)
})

// With global constraints in place the default config is rejected, while an empty config
// (no priority and no max_slots requested) leaves nothing to compare and passes.
t.Run("no config set - ok", func(t *testing.T) {
constraints := DefaultConstraints()
addConstraints(t, user, nil, *constraints) // add global constraints

resourceManager := mocks.ResourceManager{}
resourceManager.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil)

config := defaultConfig()
ok, err := CheckNTSCConstraints(context.Background(), 1, config, &resourceManager)
require.False(t, ok)
require.Error(t, err)

emptyConfig := model.CommandConfig{}
ok, err = CheckNTSCConstraints(context.Background(), 1, emptyConfig, &resourceManager)
require.True(t, ok)
require.NoError(t, err)
})
}

// TestGetMergedConstraints walks GetMergedConstraints through an empty database, a workspace-only
// priority limit, a global limit that takes precedence, and finally a workspace policy that
// contributes max_slots while the global priority limit still wins.
func TestGetMergedConstraints(t *testing.T) {
	require.NoError(t, etc.SetRootPath(db.RootFromDB))
	pgDB, closeDB := db.MustResolveNewPostgresDatabase(t)
	defer closeDB()
	db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB)

	ctx := context.Background()

	// Step 1: nothing stored, so the merged result carries no limits at all.
	merged, err := GetMergedConstraints(ctx, 0, model.NTSCType)
	require.NoError(t, err)
	require.Nil(t, merged.PriorityLimit)
	require.Nil(t, merged.ResourceConstraints)

	// Step 2: only the workspace defines a priority limit.
	workspaceLimit := 42
	owner := db.RequireMockUser(t, pgDB)
	workspace := addWorkspacePriorityLimit(t, owner, workspaceLimit)
	merged, err = GetMergedConstraints(ctx, workspace.ID, model.NTSCType)
	require.NoError(t, err)
	require.Nil(t, merged.ResourceConstraints)
	require.Equal(t, workspaceLimit, *merged.PriorityLimit)

	// Step 3: a global priority limit replaces the workspace one.
	overrideLimit := 25
	globalJSON := fmt.Sprintf(`{"priority_limit": %d}`, overrideLimit)
	addConstraints(t, owner, nil, globalJSON)
	merged, err = GetMergedConstraints(ctx, workspace.ID, model.NTSCType)
	require.NoError(t, err)
	require.Nil(t, merged.ResourceConstraints)
	require.Equal(t, overrideLimit, *merged.PriorityLimit)

	// Step 4: the workspace policy is replaced by DefaultConstraints, which is expected to supply
	// max_slots = 8; the priority limit must still be the global value.
	addConstraints(t, owner, &workspace.ID, *DefaultConstraints())
	merged, err = GetMergedConstraints(ctx, workspace.ID, model.NTSCType)
	require.NoError(t, err)
	require.Equal(t, 8, *merged.ResourceConstraints.MaxSlots)
	require.Equal(t, overrideLimit, *merged.PriorityLimit)
}

func addWorkspacePriorityLimit(t *testing.T, user model.User, limit int) model.Workspace {
ctx := context.Background()

// add a workspace to use
Expand All @@ -69,16 +206,26 @@ func addWorkspacePriorityLimit(t *testing.T, pgDB *db.PgDB, user model.User, lim
return w
}

func addGlobalPriorityLimit(t *testing.T, pgDB *db.PgDB, user model.User, limit int) {
func addConstraints(t *testing.T, user model.User, wkspID *int, constraints string) {
ctx := context.Background()

constraints := fmt.Sprintf(`{"priority_limit": %d}`, limit)
input := model.TaskConfigPolicies{
WorkloadType: model.NTSCType,
WorkspaceID: nil,
WorkspaceID: wkspID,
Constraints: &constraints,
LastUpdatedBy: user.ID,
}
err := SetTaskConfigPolicies(ctx, &input)
require.NoError(t, err)
}

// defaultConfig returns model.DefaultConfig(nil) with resources.priority pinned to 50 and
// resources.max_slots pinned to 12, giving the constraint tests concrete values to compare.
func defaultConfig() model.CommandConfig {
	priority, maxSlots := 50, 12

	cfg := model.DefaultConfig(nil)
	cfg.Resources.Priority = &priority
	cfg.Resources.MaxSlots = &maxSlots
	return cfg
}

0 comments on commit 83a779e

Please sign in to comment.