diff --git a/sasl/aws_msk_iam_v2/README.md b/sasl/aws_msk_iam_v2/README.md new file mode 100644 index 00000000..2a7af8a3 --- /dev/null +++ b/sasl/aws_msk_iam_v2/README.md @@ -0,0 +1,60 @@ +# AWS MSK IAM V2 + +This extension provides a capability to get authenticated with [AWS Managed Apache Kafka](https://aws.amazon.com/msk/) +through AWS IAM. + +## How to use + +This module is an extension for MSK users and thus this is isolated from `kafka-go` module. +You can add this module to your dependency by running the command below. + +```shell +go get github.com/segmentio/kafka-go/sasl/aws_msk_iam_v2 +``` + +You can use the `Mechanism` for SASL authentication, like below. + +```go +package main + +import ( + "context" + "crypto/tls" + "time" + + signer "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + awsCfg "github.com/aws/aws-sdk-go-v2/config" + "github.com/segmentio/kafka-go" + "github.com/segmentio/kafka-go/sasl/aws_msk_iam_v2" +) + +func main() { + ctx := context.Background() + + // using aws-sdk-go-v2 + // NOTE: address error properly + + cfg, _ := awsCfg.LoadDefaultConfig(ctx) + creds, _ := cfg.Credentials.Retrieve(ctx) + m := &aws_msk_iam_v2.Mechanism{ + Signer: signer.NewSigner(), + Credentials: creds, + Region: "us-east-1", + SignTime: time.Now(), + Expiry: time.Minute * 5, + } + config := kafka.ReaderConfig{ + Brokers: []string{"https://localhost"}, + GroupID: "some-consumer-group", + GroupTopics: []string{"some-topic"}, + Dialer: &kafka.Dialer{ + Timeout: 10 * time.Second, + DualStack: true, + SASLMechanism: m, + TLS: &tls.Config{}, + }, + } +} + + +``` \ No newline at end of file diff --git a/sasl/aws_msk_iam_v2/go.mod b/sasl/aws_msk_iam_v2/go.mod new file mode 100644 index 00000000..69d81125 --- /dev/null +++ b/sasl/aws_msk_iam_v2/go.mod @@ -0,0 +1,10 @@ +module github.com/segmentio/kafka-go/sasl/aws_msk_iam_v2 + +go 1.15 + +require ( + github.com/aws/aws-sdk-go-v2 v1.16.7 + github.com/aws/aws-sdk-go-v2/credentials v1.12.9 + github.com/segmentio/kafka-go v0.4.32 + github.com/stretchr/testify v1.7.1 +) diff --git a/sasl/aws_msk_iam_v2/go.sum b/sasl/aws_msk_iam_v2/go.sum new file mode 100644 index 00000000..ebef8f42 --- /dev/null +++ b/sasl/aws_msk_iam_v2/go.sum @@ -0,0 +1,49 @@ +github.com/aws/aws-sdk-go-v2 v1.16.7 h1:zfBwXus3u14OszRxGcqCDS4MfMCv10e8SMJ2r8Xm0Ns= +github.com/aws/aws-sdk-go-v2 v1.16.7/go.mod h1:6CpKuLXg2w7If3ABZCl/qZ6rEgwtjZTn4eAf4RcEyuw= +github.com/aws/aws-sdk-go-v2/credentials v1.12.9 h1:DloAJr0/jbvm0iVRFDFh8GlWxrOd9XKyX82U+dfVeZs= +github.com/aws/aws-sdk-go-v2/credentials v1.12.9/go.mod h1:2Vavxl1qqQXJ8MUcQZTsIEW8cwenFCWYXtLRPba3L/o= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.8/go.mod h1:oL1Q3KuCq1D4NykQnIvtRiBGLUXhcpY5pl6QZB2XEPU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.14/go.mod h1:kdjrMwHwrC3+FsKhNcCMJ7tUVj/8uSD5CZXeQ4wV6fM= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.8/go.mod h1:ZIV8GYoC6WLBW5KGs+o4rsc65/ozd+eQ0L31XF5VDwk= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.8/go.mod h1:rDVhIMAX9N2r8nWxDUlbubvvaFMnfsm+3jAV7q+rpM4= +github.com/aws/aws-sdk-go-v2/service/sso v1.11.12/go.mod h1:MO4qguFjs3wPGcCSpQ7kOFTwRvb+eu+fn+1vKleGHUk= +github.com/aws/aws-sdk-go-v2/service/sts v1.16.9/go.mod h1:O1IvkYxr+39hRf960Us6j0x1P8pDqhTX+oXM5kQNl/Y= +github.com/aws/smithy-go v1.12.0 h1:gXpeZel/jPoWQ7OEmLIgCUnhkFftqNfwWUwAHSlp1v0= +github.com/aws/smithy-go v1.12.0/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/klauspost/compress v1.14.2 h1:S0OHlFk/Gbon/yauFJ4FfJJF5V0fc5HbBTJazi28pRw= +github.com/klauspost/compress v1.14.2/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/pierrec/lz4/v4 v4.1.14 h1:+fL8AQEZtz/ijeNnpduH0bROTu0O3NZAlPjQxGn8LwE= +github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/segmentio/kafka-go v0.4.32 h1:Ohr+9E+kDv/Ld2UPJN9hnKZRd2qgiqCmI8v2e1qlfLM= +github.com/segmentio/kafka-go v0.4.32/go.mod h1:JAPPIiY3MQIwVHj64CWOP0LsFFfQ7H0w69kuoxnMIS0= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c h1:u40Z8hqBAAQyv+vATcGgV0YCnDjqSL7/q/JyPhhJSPk= +github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= +github.com/xdg/stringprep v1.0.0 h1:d9X0esnoa3dFsV0FG35rAT0RIhYFlPq7MiP+DW89La0= +github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190506204251-e1dfcc566284 h1:rlLehGeYg6jfoyz/eDqDU1iRXLKfR42nnNh57ytKEWo= +golang.org/x/crypto v0.0.0-20190506204251-e1dfcc566284/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99 h1:dbuHpmKjkDzSOMKAWl10QNlgaZUd3V1q99xc81tt2Kc= +gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/sasl/aws_msk_iam_v2/msk_iam.go b/sasl/aws_msk_iam_v2/msk_iam.go new file mode 100644 index 00000000..f6b06398 --- /dev/null +++ b/sasl/aws_msk_iam_v2/msk_iam.go @@ -0,0 +1,166 @@ +package aws_msk_iam_v2 + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "runtime" + "strconv" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + signer "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/segmentio/kafka-go/sasl" +) + +const ( + // These constants come from https://github.com/aws/aws-msk-iam-auth#details and + // https://github.com/aws/aws-msk-iam-auth/blob/main/src/main/java/software/amazon/msk/auth/iam/internals/AWS4SignedPayloadGenerator.java. + signAction = "kafka-cluster:Connect" + signPayload = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // the hex encoded SHA-256 of an empty string + signService = "kafka-cluster" + signVersion = "2020_10_22" + signActionKey = "action" + signHostKey = "host" + signUserAgentKey = "user-agent" + signVersionKey = "version" + queryActionKey = "Action" + queryExpiryKey = "X-Amz-Expires" +) + +var signUserAgent = fmt.Sprintf("kafka-go/sasl/aws_msk_iam/%s", runtime.Version()) + +// Mechanism implements sasl.Mechanism for the AWS_MSK_IAM mechanism, based on the official java implementation: +// https://github.com/aws/aws-msk-iam-auth +type Mechanism struct { + // The sigv4.Signer of aws-sdk-go-v2 to use when signing the request. Required. + Signer *signer.Signer + // The aws.Credentials of aws-sdk-go-v2. Required. + Credentials aws.Credentials + // The region where the msk cluster is hosted, e.g. "us-east-1". Required. + Region string + // The time the request is planned for. Optional, defaults to time.Now() at time of authentication. + SignTime time.Time + // The duration for which the presigned request is active. Optional, defaults to 5 minutes. + Expiry time.Duration +} + +func (m *Mechanism) Name() string { + return "AWS_MSK_IAM" +} + +func (m *Mechanism) Next(ctx context.Context, challenge []byte) (bool, []byte, error) { + // After the initial step, the authentication is complete + // kafka will return error if it rejected the credentials, so we'll only + // arrive here on success. + return true, nil, nil +} + +// Start produces the authentication values required for AWS_MSK_IAM. It produces the following json as a byte array, +// making use of the aws-sdk to produce the signed output. +// { +// "version" : "2020_10_22", +// "host" : "", +// "user-agent": "", +// "action": "kafka-cluster:Connect", +// "x-amz-algorithm" : "", +// "x-amz-credential" : "///kafka-cluster/aws4_request", +// "x-amz-date" : "", +// "x-amz-security-token" : "", +// "x-amz-signedheaders" : "host", +// "x-amz-expires" : "", +// "x-amz-signature" : "" +// } +func (m *Mechanism) Start(ctx context.Context) (sess sasl.StateMachine, ir []byte, err error) { + signedMap, err := m.preSign(ctx) + if err != nil { + return nil, nil, err + } + + signedJson, err := json.Marshal(signedMap) + return m, signedJson, err +} + +// preSign produces the authentication values required for AWS_MSK_IAM. +func (m *Mechanism) preSign(ctx context.Context) (map[string]string, error) { + req, err := buildReq(ctx, defaultExpiry(m.Expiry)) + if err != nil { + return nil, err + } + + signedUrl, header, err := m.Signer.PresignHTTP(ctx, m.Credentials, req, signPayload, signService, m.Region, defaultSignTime(m.SignTime)) + if err != nil { + return nil, err + } + + u, err := url.Parse(signedUrl) + if err != nil { + return nil, err + } + return buildSignedMap(u, header), nil +} + +// buildReq builds http.Request for aws PreSign. +func buildReq(ctx context.Context, expiry time.Duration) (*http.Request, error) { + query := url.Values{ + queryActionKey: {signAction}, + queryExpiryKey: {strconv.FormatInt(int64(expiry/time.Second), 10)}, + } + saslMeta := sasl.MetadataFromContext(ctx) + if saslMeta == nil { + return nil, errors.New("missing sasl metadata") + } + + signUrl := url.URL{ + Scheme: "kafka", + Host: saslMeta.Host, + Path: "/", + RawQuery: query.Encode(), + } + + req, err := http.NewRequest(http.MethodGet, signUrl.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// buildSignedMap builds signed string map which will be used to authenticate with MSK. +func buildSignedMap(u *url.URL, header http.Header) map[string]string { + signedMap := map[string]string{ + signVersionKey: signVersion, + signHostKey: u.Host, + signUserAgentKey: signUserAgent, + signActionKey: signAction, + } + // The protocol requires lowercase keys. + for key, vals := range header { + signedMap[strings.ToLower(key)] = vals[0] + } + for key, vals := range u.Query() { + signedMap[strings.ToLower(key)] = vals[0] + } + + return signedMap +} + +// defaultExpiry set default expiration time if user doesn't define Mechanism.Expiry. +func defaultExpiry(v time.Duration) time.Duration { + if v == 0 { + return 5 * time.Minute + } + return v +} + +// defaultSignTime set default sign time if user doesn't define Mechanism.SignTime. +func defaultSignTime(v time.Time) time.Time { + if v.IsZero() { + return time.Now() + } + return v +} diff --git a/sasl/aws_msk_iam_v2/msk_iam_test.go b/sasl/aws_msk_iam_v2/msk_iam_test.go new file mode 100644 index 00000000..d5acb6ad --- /dev/null +++ b/sasl/aws_msk_iam_v2/msk_iam_test.go @@ -0,0 +1,156 @@ +package aws_msk_iam_v2 + +import ( + "bytes" + "context" + "encoding/json" + "testing" + "time" + + "github.com/segmentio/kafka-go/sasl" + "github.com/stretchr/testify/assert" + + signer "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials" +) + +const ( + accessKeyId = "ACCESS_KEY" + secretAccessKey = "SECRET_KEY" +) + +// using a fixed time allows the signature to be verifiable in a test +var signTime = time.Date(2021, 10, 14, 13, 5, 0, 0, time.UTC) + +func TestAwsMskIamMechanism(t *testing.T) { + creds, err := credentialsv2.NewStaticCredentialsProvider(accessKeyId, secretAccessKey, "").Retrieve(context.Background()) + if err != nil { + t.Fatal(err) + } + + ctxWithMetadata := func() context.Context { + return sasl.WithMetadata(context.Background(), &sasl.Metadata{ + Host: "localhost", + Port: 9092, + }) + + } + + tests := []struct { + description string + ctx func() context.Context + shouldFail bool + }{ + { + description: "with metadata", + ctx: ctxWithMetadata, + }, + { + description: "without metadata", + ctx: func() context.Context { + return context.Background() + }, + shouldFail: true, + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + ctx := tt.ctx() + + mskMechanism := &Mechanism{ + Signer: signer.NewSigner(), + Credentials: creds, + Region: "us-east-1", + SignTime: signTime, + } + + sess, auth, err := mskMechanism.Start(ctx) + if tt.shouldFail { // if error is expected + if err == nil { // but we don't find one + t.Fatal("error expected") + } else { // but we do find one + return // return early since the remaining assertions are irrelevant + } + } else { // if error is not expected (typical) + if err != nil { // but we do find one + t.Fatal(err) + } + } + + if sess != mskMechanism { + t.Error( + "Unexpected session", + "expected", mskMechanism, + "got", sess, + ) + } + + expectedMap := map[string]string{ + "version": "2020_10_22", + "action": "kafka-cluster:Connect", + "host": "localhost", + "user-agent": signUserAgent, + "x-amz-algorithm": "AWS4-HMAC-SHA256", + "x-amz-credential": "ACCESS_KEY/20211014/us-east-1/kafka-cluster/aws4_request", + "x-amz-date": "20211014T130500Z", + "x-amz-expires": "300", + "x-amz-signedheaders": "host", + "x-amz-signature": "6b8d25f9b45b9c7db9da855a49112d80379224153a27fd279c305a5b7940d1a7", + } + expectedAuth, err := json.Marshal(expectedMap) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(expectedAuth, auth) { + t.Error("Unexpected authentication", + "expected", expectedAuth, + "got", auth, + ) + } + }) + } +} + +func TestDefaultExpiry(t *testing.T) { + expiry := time.Second * 5 + testCases := map[string]struct { + Expiry time.Duration + }{ + "with default": {Expiry: expiry}, + "without default": {}, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + actual := defaultExpiry(testCase.Expiry) + if testCase.Expiry == 0 { + assert.Equal(t, time.Minute*5, actual) + } else { + assert.Equal(t, expiry, actual) + } + + }) + } +} + +func TestDefaultSignTime(t *testing.T) { + testCases := map[string]struct { + SignTime time.Time + }{ + "with default": {SignTime: signTime}, + "without default": {}, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + actual := defaultSignTime(testCase.SignTime) + if testCase.SignTime.IsZero() { + assert.True(t, actual.After(signTime)) + } else { + assert.Equal(t, signTime, actual) + } + }) + } +}