auth: s/OAuth2Token/BearerToken/g

This also makes some tweaks to the organization of the crypto module to
make it easier for us to add more keys in the future, which will be
necessary for the internal token redesign.
This commit is contained in:
Drew DeVault 2021-07-29 10:01:59 +02:00
parent 39cd24bd09
commit 1d24fef3e2
6 changed files with 68 additions and 68 deletions

View File

@ -23,7 +23,7 @@ func ToTimestamp(t time.Time) Timestamp {
return Timestamp(t.UTC().Unix()) return Timestamp(t.UTC().Unix())
} }
type OAuth2Token struct { type BearerToken struct {
Version uint Version uint
Expires Timestamp Expires Timestamp
Grants string Grants string
@ -31,16 +31,16 @@ type OAuth2Token struct {
Username string Username string
} }
func (ot *OAuth2Token) Encode() string { func (bt *BearerToken) Encode() string {
plain, err := bare.Marshal(ot) plain, err := bare.Marshal(bt)
if err != nil { if err != nil {
panic(err) panic(err)
} }
mac := crypto.HMAC(plain) mac := crypto.BearerHMAC(plain)
return base64.RawStdEncoding.EncodeToString(append(plain, mac...)) return base64.RawStdEncoding.EncodeToString(append(plain, mac...))
} }
func DecodeToken(token string) *OAuth2Token { func DecodeBearerToken(token string) *BearerToken {
payload, err := base64.RawStdEncoding.DecodeString(token) payload, err := base64.RawStdEncoding.DecodeString(token)
if err != nil { if err != nil {
log.Printf("Invalid bearer token: invalid base64 %e", err) log.Printf("Invalid bearer token: invalid base64 %e", err)
@ -53,25 +53,25 @@ func DecodeToken(token string) *OAuth2Token {
mac := payload[len(payload)-32:] mac := payload[len(payload)-32:]
payload = payload[:len(payload)-32] payload = payload[:len(payload)-32]
if crypto.HMACVerify(payload, mac) == false { if crypto.BearerVerify(payload, mac) == false {
log.Printf("Invalid bearer token: HMAC verification failed (MAC: [%d]%s; payload: [%d]%s", log.Printf("Invalid bearer token: HMAC verification failed (MAC: [%d]%s; payload: [%d]%s",
len(mac), hex.EncodeToString(mac), len(payload), hex.EncodeToString(payload)) len(mac), hex.EncodeToString(mac), len(payload), hex.EncodeToString(payload))
return nil return nil
} }
var ot OAuth2Token var bt BearerToken
err = bare.Unmarshal(payload, &ot) err = bare.Unmarshal(payload, &bt)
if err != nil { if err != nil {
log.Printf("Invalid bearer token: BARE unmarshal failed: %e", err) log.Printf("Invalid bearer token: BARE unmarshal failed: %e", err)
return nil return nil
} }
if ot.Version != TokenVersion { if bt.Version != TokenVersion {
log.Printf("Invalid bearer token: invalid token version") log.Printf("Invalid bearer token: invalid token version")
return nil return nil
} }
if time.Now().UTC().After(ot.Expires.Time()) { if time.Now().UTC().After(bt.Expires.Time()) {
log.Printf("Invalid bearer token: token expired") log.Printf("Invalid bearer token: token expired")
return nil return nil
} }
return &ot return &bt
} }

View File

@ -27,75 +27,75 @@ network-key=tbuG-7Vh44vrDq1L_HKWkHnWrDOtJhEkPKPiauaLeuk=`))
} }
func TestEncode(t *testing.T) { func TestEncode(t *testing.T) {
ot := &OAuth2Token{ bt := &BearerToken{
Version: TokenVersion, Version: TokenVersion,
Expires: ToTimestamp(time.Now().Add(30 * time.Minute)), Expires: ToTimestamp(time.Now().Add(30 * time.Minute)),
Grants: "", Grants: "",
ClientID: "", ClientID: "",
Username: "jdoe", Username: "jdoe",
} }
token := ot.Encode() token := bt.Encode()
bytes, err := base64.RawStdEncoding.DecodeString(token) bytes, err := base64.RawStdEncoding.DecodeString(token)
assert.Nil(t, err) assert.Nil(t, err)
mac := bytes[len(bytes)-32:] mac := bytes[len(bytes)-32:]
payload := bytes[:len(bytes)-32] payload := bytes[:len(bytes)-32]
assert.True(t, crypto.HMACVerify(payload, mac)) assert.True(t, crypto.BearerVerify(payload, mac))
var ot2 OAuth2Token var bt2 BearerToken
err = bare.Unmarshal(payload, &ot2) err = bare.Unmarshal(payload, &bt2)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, ot.Version, ot2.Version) assert.Equal(t, bt.Version, bt2.Version)
assert.Equal(t, ot.Expires, ot2.Expires) assert.Equal(t, bt.Expires, bt2.Expires)
assert.Equal(t, ot.Grants, ot2.Grants) assert.Equal(t, bt.Grants, bt2.Grants)
assert.Equal(t, ot.ClientID, ot2.ClientID) assert.Equal(t, bt.ClientID, bt2.ClientID)
assert.Equal(t, ot.Username, ot2.Username) assert.Equal(t, bt.Username, bt2.Username)
} }
func TestDecode(t *testing.T) { func TestDecode(t *testing.T) {
ot := &OAuth2Token{ bt := &BearerToken{
Version: TokenVersion, Version: TokenVersion,
Expires: ToTimestamp(time.Now().Add(30 * time.Minute)), Expires: ToTimestamp(time.Now().Add(30 * time.Minute)),
Grants: "", Grants: "",
ClientID: "", ClientID: "",
Username: "jdoe", Username: "jdoe",
} }
token := ot.Encode() token := bt.Encode()
ot2 := DecodeToken(token) bt2 := DecodeBearerToken(token)
assert.NotNil(t, ot2) assert.NotNil(t, bt2)
assert.Equal(t, ot.Version, ot2.Version) assert.Equal(t, bt.Version, bt2.Version)
assert.Equal(t, ot.Expires, ot2.Expires) assert.Equal(t, bt.Expires, bt2.Expires)
assert.Equal(t, ot.Grants, ot2.Grants) assert.Equal(t, bt.Grants, bt2.Grants)
assert.Equal(t, ot.ClientID, ot2.ClientID) assert.Equal(t, bt.ClientID, bt2.ClientID)
assert.Equal(t, ot.Username, ot2.Username) assert.Equal(t, bt.Username, bt2.Username)
// Expired token: // Expired token:
ot = &OAuth2Token{ bt = &BearerToken{
Version: TokenVersion, Version: TokenVersion,
Expires: ToTimestamp(time.Now().Add(-30 * time.Minute)), Expires: ToTimestamp(time.Now().Add(-30 * time.Minute)),
Grants: "", Grants: "",
ClientID: "", ClientID: "",
Username: "jdoe", Username: "jdoe",
} }
token = ot.Encode() token = bt.Encode()
ot2 = DecodeToken(token) bt2 = DecodeBearerToken(token)
assert.Nil(t, ot2) assert.Nil(t, bt2)
// Invalid MAC: // Invalid MAC:
ot = &OAuth2Token{ bt = &BearerToken{
Version: TokenVersion, Version: TokenVersion,
Expires: ToTimestamp(time.Now().Add(30 * time.Minute)), Expires: ToTimestamp(time.Now().Add(30 * time.Minute)),
Grants: "", Grants: "",
ClientID: "", ClientID: "",
Username: "jdoe", Username: "jdoe",
} }
plain, err := bare.Marshal(ot) plain, err := bare.Marshal(bt)
assert.Nil(t, err) assert.Nil(t, err)
mac := crypto.HMAC(plain) mac := crypto.BearerHMAC(plain)
ot.Username = "rdoe" bt.Username = "rdoe"
plain, err = bare.Marshal(ot) plain, err = bare.Marshal(bt)
assert.Nil(t, err) assert.Nil(t, err)
token = base64.RawStdEncoding.EncodeToString(append(plain, mac...)) token = base64.RawStdEncoding.EncodeToString(append(plain, mac...))
ot2 = DecodeToken(token) bt2 = DecodeBearerToken(token)
assert.Nil(t, ot2) assert.Nil(t, bt2)
} }

View File

@ -75,7 +75,7 @@ type AuthContext struct {
InternalAuth InternalAuth InternalAuth InternalAuth
// Only filled out if AuthMethod == AUTH_OAUTH2 // Only filled out if AuthMethod == AUTH_OAUTH2
OAuth2Token *OAuth2Token BearerToken *BearerToken
Access map[string]string Access map[string]string
} }
@ -503,15 +503,15 @@ func OAuth2(token string, hash [64]byte, w http.ResponseWriter,
) )
wg.Add(2) wg.Add(2)
ot := DecodeToken(token) bt := DecodeBearerToken(token)
if ot == nil { if bt == nil {
authError(w, `Invalid or expired OAuth 2.0 bearer token`, http.StatusForbidden) authError(w, `Invalid or expired OAuth 2.0 bearer token`, http.StatusForbidden)
return return
} }
go func() { go func() {
defer wg.Done() defer wg.Done()
err = LookupUser(r.Context(), ot.Username, &auth) err = LookupUser(r.Context(), bt.Username, &auth)
if err != nil { if err != nil {
log.Printf("LookupUser: %e", err) log.Printf("LookupUser: %e", err)
atomic.AddInt32(&tempErr, 1) atomic.AddInt32(&tempErr, 1)
@ -523,7 +523,7 @@ func OAuth2(token string, hash [64]byte, w http.ResponseWriter,
go func() { go func() {
defer wg.Done() defer wg.Done()
isRevoked, err := LookupTokenRevocation(r.Context(), isRevoked, err := LookupTokenRevocation(r.Context(),
ot.Username, hash, ot.ClientID) bt.Username, hash, bt.ClientID)
if err != nil { if err != nil {
log.Printf("LookupTokenRevocation: %e", err) log.Printf("LookupTokenRevocation: %e", err)
atomic.AddInt32(&tempErr, 1) atomic.AddInt32(&tempErr, 1)
@ -550,11 +550,11 @@ func OAuth2(token string, hash [64]byte, w http.ResponseWriter,
} }
auth.AuthMethod = AUTH_OAUTH2 auth.AuthMethod = AUTH_OAUTH2
auth.OAuth2Token = ot auth.BearerToken = bt
if ot.Grants != "" { if bt.Grants != "" {
auth.Access = make(map[string]string) auth.Access = make(map[string]string)
for _, grant := range strings.Split(ot.Grants, " ") { for _, grant := range strings.Split(bt.Grants, " ") {
var ( var (
service string service string
scope string scope string

View File

@ -12,6 +12,6 @@ import (
func main() { func main() {
conf := config.LoadConfig(":1111") conf := config.LoadConfig(":1111")
crypto.InitCrypto(conf) crypto.InitCrypto(conf)
tok := auth.DecodeToken(os.Args[1]) tok := auth.DecodeBearerToken(os.Args[1])
fmt.Printf("%+v\n", tok) fmt.Printf("%+v\n", tok)
} }

View File

@ -16,10 +16,10 @@ import (
) )
var ( var (
privateKey ed25519.PrivateKey webhookSk ed25519.PrivateKey
publicKey ed25519.PublicKey webhookPk ed25519.PublicKey
macKey []byte bearerKey []byte
fernetKey *fernet.Key fernetKey *fernet.Key
) )
func InitCrypto(config ini.File) { func InitCrypto(config ini.File) {
@ -31,8 +31,8 @@ func InitCrypto(config ini.File) {
if err != nil { if err != nil {
log.Fatalf("base64 decode webhooks private key: %v", err) log.Fatalf("base64 decode webhooks private key: %v", err)
} }
privateKey = ed25519.NewKeyFromSeed(seed) webhookSk = ed25519.NewKeyFromSeed(seed)
publicKey, _ = privateKey.Public().(ed25519.PublicKey) webhookPk, _ = webhookSk.Public().(ed25519.PublicKey)
b64fernet, ok := config.Get("sr.ht", "network-key") b64fernet, ok := config.Get("sr.ht", "network-key")
if !ok { if !ok {
@ -42,17 +42,17 @@ func InitCrypto(config ini.File) {
if err != nil { if err != nil {
log.Fatalf("Load Fernet network encryption key: %v", err) log.Fatalf("Load Fernet network encryption key: %v", err)
} }
mac := hmac.New(sha256.New, privateKey) mac := hmac.New(sha256.New, webhookSk)
mac.Write([]byte("sr.ht HMAC key")) mac.Write([]byte("sr.ht HMAC key"))
macKey = mac.Sum(nil) bearerKey = mac.Sum(nil)
} }
func Sign(payload []byte) []byte { func Sign(payload []byte) []byte {
return ed25519.Sign(privateKey, payload) return ed25519.Sign(webhookSk, payload)
} }
func Verify(payload, signature []byte) bool { func Verify(payload, signature []byte) bool {
return ed25519.Verify(publicKey, payload, signature) return ed25519.Verify(webhookPk, payload, signature)
} }
func Encrypt(payload []byte) []byte { func Encrypt(payload []byte) []byte {
@ -75,14 +75,14 @@ func DecryptWithExpiration(payload []byte, expiry time.Duration) []byte {
return fernet.VerifyAndDecrypt(payload, expiry, []*fernet.Key{fernetKey}) return fernet.VerifyAndDecrypt(payload, expiry, []*fernet.Key{fernetKey})
} }
func HMAC(payload []byte) []byte { func BearerHMAC(payload []byte) []byte {
mac := hmac.New(sha256.New, macKey) mac := hmac.New(sha256.New, bearerKey)
mac.Write(payload) mac.Write(payload)
return mac.Sum(nil) return mac.Sum(nil)
} }
func HMACVerify(payload []byte, signature []byte) bool { func BearerVerify(payload []byte, signature []byte) bool {
mac := hmac.New(sha256.New, macKey) mac := hmac.New(sha256.New, bearerKey)
mac.Write(payload) mac.Write(payload)
expected := mac.Sum(nil) expected := mac.Sum(nil)
return hmac.Equal(expected, signature) return hmac.Equal(expected, signature)

View File

@ -79,13 +79,13 @@ func TestEncryptWithExpire(t *testing.T) {
assert.Nil(t, dec) assert.Nil(t, dec)
} }
func TestHMAC(t *testing.T) { func TestBearerHMAC(t *testing.T) {
payload := []byte("Hello, world!") payload := []byte("Hello, world!")
mac := HMAC(payload) mac := BearerHMAC(payload)
valid := HMACVerify(payload, mac) valid := BearerVerify(payload, mac)
assert.True(t, valid) assert.True(t, valid)
valid = HMACVerify([]byte("Something else"), mac) valid = BearerVerify([]byte("Something else"), mac)
assert.False(t, valid) assert.False(t, valid)
} }