Skip to content

Commit

Permalink
Support custom domains in the admin service (#5498)
Browse files Browse the repository at this point in the history
* Support custom domains in the `admin` service

* Centralize URL handling in one type

* Support WithCustomDomain

* Centralize usage of `admin.URLs`

* Fix lint

* Fix tests

* Self review

* Self review 2

* Confirmation message

* Update auth redirects

* Fix redirects on localhost

* Change RPC to only return org name and be unauthenticated

* Document auth redirects for custom domains
  • Loading branch information
begelundmuller committed Aug 23, 2024
1 parent 8174508 commit ae071ab
Show file tree
Hide file tree
Showing 39 changed files with 7,088 additions and 5,278 deletions.
11 changes: 10 additions & 1 deletion admin/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ import (
type Options struct {
DatabaseDriver string
DatabaseDSN string
ExternalURL string
FrontendURL string
ProvisionerSetJSON string
DefaultProvisioner string
ExternalURL string
VersionNumber string
VersionCommit string
MetricsProjectOrg string
Expand All @@ -30,6 +31,7 @@ type Options struct {

type Service struct {
DB database.DB
URLs *URLs
ProvisionerSet map[string]provisioner.Provisioner
Email *email.Client
Github Github
Expand All @@ -54,6 +56,12 @@ func New(ctx context.Context, opts *Options, logger *zap.Logger, issuer *auth.Is
logger.Fatal("error connecting to database", zap.Error(err))
}

// Init URLs
urls, err := NewURLs(opts.ExternalURL, opts.FrontendURL)
if err != nil {
logger.Fatal("error parsing URLs", zap.Error(err))
}

// Auto-run migrations
v1, err := db.FindMigrationVersion(ctx)
if err != nil {
Expand Down Expand Up @@ -97,6 +105,7 @@ func New(ctx context.Context, opts *Options, logger *zap.Logger, issuer *auth.Is

return &Service{
DB: db,
URLs: urls,
ProvisionerSet: provSet,
Email: emailClient,
Github: github,
Expand Down
2 changes: 2 additions & 0 deletions admin/billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func (s *Service) InitOrganizationBilling(ctx context.Context, org *database.Org
org, err = s.DB.UpdateOrganization(ctx, org.ID, &database.UpdateOrganizationOptions{
Name: org.Name,
Description: org.Description,
CustomDomain: org.CustomDomain,
QuotaProjects: valOrDefault(plan.Quotas.NumProjects, org.QuotaProjects),
QuotaDeployments: valOrDefault(plan.Quotas.NumDeployments, org.QuotaDeployments),
QuotaSlotsTotal: valOrDefault(plan.Quotas.NumSlotsTotal, org.QuotaSlotsTotal),
Expand Down Expand Up @@ -159,6 +160,7 @@ func (s *Service) RepairOrgBilling(ctx context.Context, org *database.Organizati
org, err = s.DB.UpdateOrganization(ctx, org.ID, &database.UpdateOrganizationOptions{
Name: org.Name,
Description: org.Description,
CustomDomain: org.CustomDomain,
QuotaProjects: biggerOfInt(plan.Quotas.NumProjects, org.QuotaProjects),
QuotaDeployments: biggerOfInt(plan.Quotas.NumDeployments, org.QuotaDeployments),
QuotaSlotsTotal: biggerOfInt(plan.Quotas.NumSlotsTotal, org.QuotaSlotsTotal),
Expand Down
8 changes: 6 additions & 2 deletions admin/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type DB interface {
FindOrganizationsForUser(ctx context.Context, userID string, afterName string, limit int) ([]*Organization, error)
FindOrganization(ctx context.Context, id string) (*Organization, error)
FindOrganizationByName(ctx context.Context, name string) (*Organization, error)
FindOrganizationByCustomDomain(ctx context.Context, domain string) (*Organization, error)
CheckOrganizationHasOutsideUser(ctx context.Context, orgID, userID string) (bool, error)
CheckOrganizationHasPublicProjects(ctx context.Context, orgID string) (bool, error)
InsertOrganization(ctx context.Context, opts *InsertOrganizationOptions) (*Organization, error)
Expand Down Expand Up @@ -276,6 +277,7 @@ type Organization struct {
ID string
Name string
Description string
CustomDomain string `db:"custom_domain"`
AllUsergroupID *string `db:"all_usergroup_id"`
CreatedOn time.Time `db:"created_on"`
UpdatedOn time.Time `db:"updated_on"`
Expand All @@ -294,6 +296,7 @@ type Organization struct {
type InsertOrganizationOptions struct {
Name string `validate:"slug"`
Description string
CustomDomain string `validate:"omitempty,fqdn"`
QuotaProjects int
QuotaDeployments int
QuotaSlotsTotal int
Expand All @@ -309,6 +312,7 @@ type InsertOrganizationOptions struct {
type UpdateOrganizationOptions struct {
Name string `validate:"slug"`
Description string
CustomDomain string `validate:"omitempty,fqdn"`
QuotaProjects int
QuotaDeployments int
QuotaSlotsTotal int
Expand Down Expand Up @@ -786,7 +790,7 @@ type OrganizationWhitelistedDomain struct {
type InsertOrganizationWhitelistedDomainOptions struct {
OrgID string `validate:"required"`
OrgRoleID string `validate:"required"`
Domain string `validate:"domain"`
Domain string `validate:"fqdn"`
}

// OrganizationWhitelistedDomainWithJoinedRoleNames convenience type used for display-friendly representation of an OrganizationWhitelistedDomain.
Expand All @@ -807,7 +811,7 @@ type ProjectWhitelistedDomain struct {
type InsertProjectWhitelistedDomainOptions struct {
ProjectID string `validate:"required"`
ProjectRoleID string `validate:"required"`
Domain string `validate:"domain"`
Domain string `validate:"fqdn"`
}

type ProjectWhitelistedDomainWithJoinedRoleNames struct {
Expand Down
3 changes: 3 additions & 0 deletions admin/database/postgres/migrations/0041.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ALTER TABLE orgs ADD COLUMN custom_domain TEXT NOT NULL DEFAULT '';

CREATE UNIQUE INDEX orgs_custom_domain_idx ON orgs (lower(custom_domain)) WHERE custom_domain <> '';
9 changes: 9 additions & 0 deletions admin/database/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ func (c *connection) FindOrganizationByName(ctx context.Context, name string) (*
return res, nil
}

func (c *connection) FindOrganizationByCustomDomain(ctx context.Context, domain string) (*database.Organization, error) {
res := &database.Organization{}
err := c.getDB(ctx).QueryRowxContext(ctx, "SELECT * FROM orgs WHERE lower(custom_domain)=lower($1)", domain).StructScan(res)
if err != nil {
return nil, parseErr("org", err)
}
return res, nil
}

func (c *connection) CheckOrganizationHasOutsideUser(ctx context.Context, orgID, userID string) (bool, error) {
var res bool
err := c.getDB(ctx).QueryRowxContext(ctx,
Expand Down
11 changes: 0 additions & 11 deletions admin/database/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ var validate *validator.Validate
// slugRegexp is used to validate identifying names (e.g. "rill-data", not "Rill Data").
var slugRegexp = regexp.MustCompile("^[_a-zA-Z0-9][-_a-zA-Z0-9]{2,39}$")

// Regular expression to match a domain name
var domainRegex = regexp.MustCompile(`^[a-zA-Z0-9]+([-.][a-zA-Z0-9]+)*\.[a-zA-Z]{2,}$`)

func init() {
validate = validator.New()

Expand All @@ -30,12 +27,4 @@ func init() {
if err != nil {
panic(err)
}

// Register "domain" validation rule
err = validate.RegisterValidation("domain", func(fl validator.FieldLevel) bool {
return domainRegex.MatchString(fl.Field().String())
})
if err != nil {
panic(err)
}
}
24 changes: 24 additions & 0 deletions admin/permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,30 @@ func (s *Service) OrganizationPermissionsForUser(ctx context.Context, orgID, use
composite = unionOrgRoles(composite, role)
}

// If the org has a public project, all users get read access to it.
if !composite.ReadOrg {
ok, err := s.DB.CheckOrganizationHasPublicProjects(ctx, orgID)
if err != nil {
return nil, err
}
if ok {
composite.ReadOrg = true
composite.ReadProjects = true
}
}

// If the user is an outside member of one of the org's projects, they get read access to org as well.
if !composite.ReadOrg {
ok, err := s.DB.CheckOrganizationHasOutsideUser(ctx, orgID, userID)
if err != nil {
return nil, err
}
if ok {
composite.ReadOrg = true
composite.ReadProjects = true
}
}

return composite, nil
}

Expand Down
12 changes: 12 additions & 0 deletions admin/pkg/urlutil/urlutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ package urlutil
import "net/url"

func WithQuery(urlString string, query map[string]string) (string, error) {
if len(query) == 0 {
return urlString, nil
}

parsedURL, err := url.Parse(urlString)
if err != nil {
return "", err
Expand All @@ -16,6 +20,14 @@ func WithQuery(urlString string, query map[string]string) (string, error) {
return parsedURL.String(), nil
}

func MustWithQuery(urlString string, query map[string]string) string {
newURL, err := WithQuery(urlString, query)
if err != nil {
panic(err)
}
return newURL
}

func MustJoinURL(base string, elem ...string) string {
joinedURL, err := url.JoinPath(base, elem...)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions admin/server/alerts.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ func (s *Server) GetAlertMeta(ctx context.Context, req *adminv1.GetAlertMetaRequ
}

return &adminv1.GetAlertMetaResponse{
OpenUrl: s.urls.alertOpen(org.Name, proj.Name, req.Alert),
EditUrl: s.urls.alertEdit(org.Name, proj.Name, req.Alert),
OpenUrl: s.admin.URLs.AlertOpen(org.Name, proj.Name, req.Alert),
EditUrl: s.admin.URLs.AlertEdit(org.Name, proj.Name, req.Alert),
QueryForAttributes: attrPB,
}, nil
}
Expand Down
11 changes: 1 addition & 10 deletions admin/server/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package auth

import (
"context"
"net/url"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/rilldata/rill/admin"
Expand All @@ -16,8 +15,6 @@ type AuthenticatorOptions struct {
AuthDomain string
AuthClientID string
AuthClientSecret string
ExternalURL string
FrontendURL string
}

// Authenticator wraps functionality for admin server auth.
Expand All @@ -39,16 +36,10 @@ func NewAuthenticator(logger *zap.Logger, adm *admin.Service, cookieStore *cooki
return nil, err
}

// Auth callback URL is fixed. See RegisterEndpoints.
redirectURL, err := url.JoinPath(opts.ExternalURL, "/auth/callback")
if err != nil {
return nil, err
}

oauth2Config := oauth2.Config{
ClientID: opts.AuthClientID,
ClientSecret: opts.AuthClientSecret,
RedirectURL: redirectURL,
RedirectURL: adm.URLs.AuthLoginCallback(),
Endpoint: oidcProvider.Endpoint(),
Scopes: []string{oidc.ScopeOpenID, "email", "profile"},
}
Expand Down
19 changes: 3 additions & 16 deletions admin/server/auth/device_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"github.com/rilldata/rill/admin"
"github.com/rilldata/rill/admin/database"
"github.com/rilldata/rill/admin/pkg/oauth"
"github.com/rilldata/rill/admin/pkg/urlutil"
)

const deviceCodeGrantType = "urn:ietf:params:oauth:grant-type:device_code"
Expand Down Expand Up @@ -76,33 +75,21 @@ func (a *Authenticator) handleDeviceCodeRequest(w http.ResponseWriter, r *http.R
return
}

verificationURI, err := url.JoinPath(a.opts.FrontendURL, "/-/auth/device")
if err != nil {
internalServerError(w, fmt.Errorf("failed to create verification uri: %w", err))
return
}

// add a "-" after the 4th character
readableUserCode := authCode.UserCode[:4] + "-" + authCode.UserCode[4:]

qry := map[string]string{"user_code": readableUserCode}
if values.Get("redirect") != "" {
qry["redirect"] = values.Get("redirect")
} else {
qry["redirect"] = urlutil.MustJoinURL(a.opts.FrontendURL, "/-/auth/cli/success")
}

verificationCompleteURI, err := urlutil.WithQuery(verificationURI, qry)
if err != nil {
internalServerError(w, fmt.Errorf("failed to create verification uri: %w", err))
return
qry["redirect"] = a.admin.URLs.AuthCLISuccessUI()
}

resp := DeviceCodeResponse{
DeviceCode: authCode.DeviceCode,
UserCode: readableUserCode,
VerificationURI: verificationURI,
VerificationCompleteURI: verificationCompleteURI,
VerificationURI: a.admin.URLs.AuthVerifyDeviceUI(nil),
VerificationCompleteURI: a.admin.URLs.AuthVerifyDeviceUI(qry),
ExpiresIn: int(admin.DeviceAuthCodeTTL.Seconds()),
PollingInterval: 5,
}
Expand Down
Loading

1 comment on commit ae071ab

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.