From e74d8b31c0f4ddec7143be166155f1c95abf096d Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Thu, 25 Apr 2024 13:01:36 +0200 Subject: [PATCH] server: unify logic to decode SASL response --- conn.go | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/conn.go b/conn.go index 57e7c9c..ec9f686 100644 --- a/conn.go +++ b/conn.go @@ -773,15 +773,11 @@ func (c *Conn) handleAuth(arg string) { // Parse client initial response if there is one var ir []byte if len(parts) > 1 { - if parts[1] == "=" { - ir = []byte{} - } else { - var err error - ir, err = base64.StdEncoding.DecodeString(parts[1]) - if err != nil { - c.writeResponse(454, EnhancedCode{4, 7, 0}, "Invalid base64 data") - return - } + var err error + ir, err = decodeSASLResponse(parts[1]) + if err != nil { + c.writeResponse(454, EnhancedCode{4, 7, 0}, "Invalid base64 data") + return } } @@ -820,14 +816,10 @@ func (c *Conn) handleAuth(arg string) { return } - if encoded == "=" { - response = []byte{} - } else { - response, err = base64.StdEncoding.DecodeString(encoded) - if err != nil { - c.writeResponse(454, EnhancedCode{4, 7, 0}, "Invalid base64 data") - return - } + response, err = decodeSASLResponse(encoded) + if err != nil { + c.writeResponse(454, EnhancedCode{4, 7, 0}, "Invalid base64 data") + return } } @@ -835,6 +827,13 @@ func (c *Conn) handleAuth(arg string) { c.didAuth = true } +func decodeSASLResponse(s string) ([]byte, error) { + if s == "=" { + return []byte{}, nil + } + return base64.StdEncoding.DecodeString(s) +} + func (c *Conn) authMechanisms() []string { if authSession, ok := c.Session().(AuthSession); ok { return authSession.AuthMechanisms()