From 1d24fef3e29dc2c200e47d83f0eb7dddc1a23afe Mon Sep 17 00:00:00 2001 From: Drew DeVault Date: Thu, 29 Jul 2021 10:01:59 +0200 Subject: [PATCH] auth: s/OAuth2Token/BearerToken/g This also makes some tweaks to the organization of the crypto module to make it easier for us to add more keys in the future, which will be necessary for the internal token redesign. --- auth/{token.go => bearer.go} | 22 +++++----- auth/{token_test.go => bearer_test.go} | 60 +++++++++++++------------- auth/middleware.go | 16 +++---- cmd/token/main.go | 2 +- crypto/crypto.go | 28 ++++++------ crypto/crypto_test.go | 8 ++-- 6 files changed, 68 insertions(+), 68 deletions(-) rename auth/{token.go => bearer.go} (77%) rename auth/{token_test.go => bearer_test.go} (59%) diff --git a/auth/token.go b/auth/bearer.go similarity index 77% rename from auth/token.go rename to auth/bearer.go index 6076e02..8ea49df 100644 --- a/auth/token.go +++ b/auth/bearer.go @@ -23,7 +23,7 @@ func ToTimestamp(t time.Time) Timestamp { return Timestamp(t.UTC().Unix()) } -type OAuth2Token struct { +type BearerToken struct { Version uint Expires Timestamp Grants string @@ -31,16 +31,16 @@ type OAuth2Token struct { Username string } -func (ot *OAuth2Token) Encode() string { - plain, err := bare.Marshal(ot) +func (bt *BearerToken) Encode() string { + plain, err := bare.Marshal(bt) if err != nil { panic(err) } - mac := crypto.HMAC(plain) + mac := crypto.BearerHMAC(plain) return base64.RawStdEncoding.EncodeToString(append(plain, mac...)) } -func DecodeToken(token string) *OAuth2Token { +func DecodeBearerToken(token string) *BearerToken { payload, err := base64.RawStdEncoding.DecodeString(token) if err != nil { log.Printf("Invalid bearer token: invalid base64 %e", err) @@ -53,25 +53,25 @@ func DecodeToken(token string) *OAuth2Token { mac := payload[len(payload)-32:] payload = payload[:len(payload)-32] - if crypto.HMACVerify(payload, mac) == false { + if crypto.BearerVerify(payload, mac) == false { log.Printf("Invalid bearer token: HMAC verification failed (MAC: [%d]%s; payload: [%d]%s", len(mac), hex.EncodeToString(mac), len(payload), hex.EncodeToString(payload)) return nil } - var ot OAuth2Token - err = bare.Unmarshal(payload, &ot) + var bt BearerToken + err = bare.Unmarshal(payload, &bt) if err != nil { log.Printf("Invalid bearer token: BARE unmarshal failed: %e", err) return nil } - if ot.Version != TokenVersion { + if bt.Version != TokenVersion { log.Printf("Invalid bearer token: invalid token version") return nil } - if time.Now().UTC().After(ot.Expires.Time()) { + if time.Now().UTC().After(bt.Expires.Time()) { log.Printf("Invalid bearer token: token expired") return nil } - return &ot + return &bt } diff --git a/auth/token_test.go b/auth/bearer_test.go similarity index 59% rename from auth/token_test.go rename to auth/bearer_test.go index e33ae09..0ded53c 100644 --- a/auth/token_test.go +++ b/auth/bearer_test.go @@ -27,75 +27,75 @@ network-key=tbuG-7Vh44vrDq1L_HKWkHnWrDOtJhEkPKPiauaLeuk=`)) } func TestEncode(t *testing.T) { - ot := &OAuth2Token{ + bt := &BearerToken{ Version: TokenVersion, Expires: ToTimestamp(time.Now().Add(30 * time.Minute)), Grants: "", ClientID: "", Username: "jdoe", } - token := ot.Encode() + token := bt.Encode() bytes, err := base64.RawStdEncoding.DecodeString(token) assert.Nil(t, err) mac := bytes[len(bytes)-32:] payload := bytes[:len(bytes)-32] - assert.True(t, crypto.HMACVerify(payload, mac)) + assert.True(t, crypto.BearerVerify(payload, mac)) - var ot2 OAuth2Token - err = bare.Unmarshal(payload, &ot2) + var bt2 BearerToken + err = bare.Unmarshal(payload, &bt2) assert.Nil(t, err) - assert.Equal(t, ot.Version, ot2.Version) - assert.Equal(t, ot.Expires, ot2.Expires) - assert.Equal(t, ot.Grants, ot2.Grants) - assert.Equal(t, ot.ClientID, ot2.ClientID) - assert.Equal(t, ot.Username, ot2.Username) + assert.Equal(t, bt.Version, bt2.Version) + assert.Equal(t, bt.Expires, bt2.Expires) + assert.Equal(t, bt.Grants, bt2.Grants) + assert.Equal(t, bt.ClientID, bt2.ClientID) + assert.Equal(t, bt.Username, bt2.Username) } func TestDecode(t *testing.T) { - ot := &OAuth2Token{ + bt := &BearerToken{ Version: TokenVersion, Expires: ToTimestamp(time.Now().Add(30 * time.Minute)), Grants: "", ClientID: "", Username: "jdoe", } - token := ot.Encode() - ot2 := DecodeToken(token) - assert.NotNil(t, ot2) - assert.Equal(t, ot.Version, ot2.Version) - assert.Equal(t, ot.Expires, ot2.Expires) - assert.Equal(t, ot.Grants, ot2.Grants) - assert.Equal(t, ot.ClientID, ot2.ClientID) - assert.Equal(t, ot.Username, ot2.Username) + token := bt.Encode() + bt2 := DecodeBearerToken(token) + assert.NotNil(t, bt2) + assert.Equal(t, bt.Version, bt2.Version) + assert.Equal(t, bt.Expires, bt2.Expires) + assert.Equal(t, bt.Grants, bt2.Grants) + assert.Equal(t, bt.ClientID, bt2.ClientID) + assert.Equal(t, bt.Username, bt2.Username) // Expired token: - ot = &OAuth2Token{ + bt = &BearerToken{ Version: TokenVersion, Expires: ToTimestamp(time.Now().Add(-30 * time.Minute)), Grants: "", ClientID: "", Username: "jdoe", } - token = ot.Encode() - ot2 = DecodeToken(token) - assert.Nil(t, ot2) + token = bt.Encode() + bt2 = DecodeBearerToken(token) + assert.Nil(t, bt2) // Invalid MAC: - ot = &OAuth2Token{ + bt = &BearerToken{ Version: TokenVersion, Expires: ToTimestamp(time.Now().Add(30 * time.Minute)), Grants: "", ClientID: "", Username: "jdoe", } - plain, err := bare.Marshal(ot) + plain, err := bare.Marshal(bt) assert.Nil(t, err) - mac := crypto.HMAC(plain) - ot.Username = "rdoe" - plain, err = bare.Marshal(ot) + mac := crypto.BearerHMAC(plain) + bt.Username = "rdoe" + plain, err = bare.Marshal(bt) assert.Nil(t, err) token = base64.RawStdEncoding.EncodeToString(append(plain, mac...)) - ot2 = DecodeToken(token) - assert.Nil(t, ot2) + bt2 = DecodeBearerToken(token) + assert.Nil(t, bt2) } diff --git a/auth/middleware.go b/auth/middleware.go index b04d0a5..038a00e 100644 --- a/auth/middleware.go +++ b/auth/middleware.go @@ -75,7 +75,7 @@ type AuthContext struct { InternalAuth InternalAuth // Only filled out if AuthMethod == AUTH_OAUTH2 - OAuth2Token *OAuth2Token + BearerToken *BearerToken Access map[string]string } @@ -503,15 +503,15 @@ func OAuth2(token string, hash [64]byte, w http.ResponseWriter, ) wg.Add(2) - ot := DecodeToken(token) - if ot == nil { + bt := DecodeBearerToken(token) + if bt == nil { authError(w, `Invalid or expired OAuth 2.0 bearer token`, http.StatusForbidden) return } go func() { defer wg.Done() - err = LookupUser(r.Context(), ot.Username, &auth) + err = LookupUser(r.Context(), bt.Username, &auth) if err != nil { log.Printf("LookupUser: %e", err) atomic.AddInt32(&tempErr, 1) @@ -523,7 +523,7 @@ func OAuth2(token string, hash [64]byte, w http.ResponseWriter, go func() { defer wg.Done() isRevoked, err := LookupTokenRevocation(r.Context(), - ot.Username, hash, ot.ClientID) + bt.Username, hash, bt.ClientID) if err != nil { log.Printf("LookupTokenRevocation: %e", err) atomic.AddInt32(&tempErr, 1) @@ -550,11 +550,11 @@ func OAuth2(token string, hash [64]byte, w http.ResponseWriter, } auth.AuthMethod = AUTH_OAUTH2 - auth.OAuth2Token = ot + auth.BearerToken = bt - if ot.Grants != "" { + if bt.Grants != "" { auth.Access = make(map[string]string) - for _, grant := range strings.Split(ot.Grants, " ") { + for _, grant := range strings.Split(bt.Grants, " ") { var ( service string scope string diff --git a/cmd/token/main.go b/cmd/token/main.go index a5c7f47..de31da4 100644 --- a/cmd/token/main.go +++ b/cmd/token/main.go @@ -12,6 +12,6 @@ import ( func main() { conf := config.LoadConfig(":1111") crypto.InitCrypto(conf) - tok := auth.DecodeToken(os.Args[1]) + tok := auth.DecodeBearerToken(os.Args[1]) fmt.Printf("%+v\n", tok) } diff --git a/crypto/crypto.go b/crypto/crypto.go index babfabd..bd6dc78 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -16,10 +16,10 @@ import ( ) var ( - privateKey ed25519.PrivateKey - publicKey ed25519.PublicKey - macKey []byte - fernetKey *fernet.Key + webhookSk ed25519.PrivateKey + webhookPk ed25519.PublicKey + bearerKey []byte + fernetKey *fernet.Key ) func InitCrypto(config ini.File) { @@ -31,8 +31,8 @@ func InitCrypto(config ini.File) { if err != nil { log.Fatalf("base64 decode webhooks private key: %v", err) } - privateKey = ed25519.NewKeyFromSeed(seed) - publicKey, _ = privateKey.Public().(ed25519.PublicKey) + webhookSk = ed25519.NewKeyFromSeed(seed) + webhookPk, _ = webhookSk.Public().(ed25519.PublicKey) b64fernet, ok := config.Get("sr.ht", "network-key") if !ok { @@ -42,17 +42,17 @@ func InitCrypto(config ini.File) { if err != nil { log.Fatalf("Load Fernet network encryption key: %v", err) } - mac := hmac.New(sha256.New, privateKey) + mac := hmac.New(sha256.New, webhookSk) mac.Write([]byte("sr.ht HMAC key")) - macKey = mac.Sum(nil) + bearerKey = mac.Sum(nil) } func Sign(payload []byte) []byte { - return ed25519.Sign(privateKey, payload) + return ed25519.Sign(webhookSk, payload) } func Verify(payload, signature []byte) bool { - return ed25519.Verify(publicKey, payload, signature) + return ed25519.Verify(webhookPk, payload, signature) } func Encrypt(payload []byte) []byte { @@ -75,14 +75,14 @@ func DecryptWithExpiration(payload []byte, expiry time.Duration) []byte { return fernet.VerifyAndDecrypt(payload, expiry, []*fernet.Key{fernetKey}) } -func HMAC(payload []byte) []byte { - mac := hmac.New(sha256.New, macKey) +func BearerHMAC(payload []byte) []byte { + mac := hmac.New(sha256.New, bearerKey) mac.Write(payload) return mac.Sum(nil) } -func HMACVerify(payload []byte, signature []byte) bool { - mac := hmac.New(sha256.New, macKey) +func BearerVerify(payload []byte, signature []byte) bool { + mac := hmac.New(sha256.New, bearerKey) mac.Write(payload) expected := mac.Sum(nil) return hmac.Equal(expected, signature) diff --git a/crypto/crypto_test.go b/crypto/crypto_test.go index 7cc93b5..120e762 100644 --- a/crypto/crypto_test.go +++ b/crypto/crypto_test.go @@ -79,13 +79,13 @@ func TestEncryptWithExpire(t *testing.T) { assert.Nil(t, dec) } -func TestHMAC(t *testing.T) { +func TestBearerHMAC(t *testing.T) { payload := []byte("Hello, world!") - mac := HMAC(payload) + mac := BearerHMAC(payload) - valid := HMACVerify(payload, mac) + valid := BearerVerify(payload, mac) assert.True(t, valid) - valid = HMACVerify([]byte("Something else"), mac) + valid = BearerVerify([]byte("Something else"), mac) assert.False(t, valid) }