From 8eb53f35aa93c79244044babc6920f8250c0f415 Mon Sep 17 00:00:00 2001 From: Drew DeVault Date: Mon, 14 Sep 2020 14:50:52 -0400 Subject: [PATCH] Implement OAuth 2.0 bearer token w/scopes --- auth/middleware.go | 31 +++++++++++++++++++++++++++++-- config/middleware.go | 13 ++++++++++++- directives.go | 26 +++++++++++++++++++++++--- server.go | 2 +- 4 files changed, 65 insertions(+), 7 deletions(-) diff --git a/auth/middleware.go b/auth/middleware.go index 862c271..d0afff0 100644 --- a/auth/middleware.go +++ b/auth/middleware.go @@ -50,7 +50,7 @@ const ( ) const ( - AUTH_OAUTH = iota + AUTH_OAUTH_LEGACY = iota AUTH_OAUTH2 = iota AUTH_COOKIE = iota AUTH_INTERNAL = iota @@ -75,6 +75,7 @@ type User struct { // Only filled out if AuthMethod == AUTH_OAUTH2 OAuth2Token *OAuth2Token + Access map[string]string } func authError(w http.ResponseWriter, reason string, code int) { @@ -404,6 +405,32 @@ func OAuth2(db *sql.DB, token string, hash [64]byte, user.AuthMethod = AUTH_OAUTH2 user.OAuth2Token = ot + if ot.Scopes != "" { + user.Access = make(map[string]string) + for _, grant := range strings.Split(ot.Scopes, ",") { + var ( + service string + scope string + access string + ) + parts := strings.Split(grant, "/") + if len(parts) != 2 { + panic(errors.New("OAuth grant without service/scope format")) + } + service = parts[0] + parts = strings.Split(parts[1], ":") + scope = parts[0] + if len(parts) == 1 { + access = "RO" + } else { + access = parts[1] + } + if service == config.ServiceName(r.Context()) { + user.Access[scope] = access + } + } + } + ctx := context.WithValue(r.Context(), userCtxKey, &user) r = r.WithContext(ctx) @@ -481,7 +508,7 @@ func LegacyOAuth(db *sql.DB, bearer string, hash [64]byte, return } - user.AuthMethod = AUTH_OAUTH + user.AuthMethod = AUTH_OAUTH_LEGACY ctx := context.WithValue(r.Context(), userCtxKey, &user) diff --git a/config/middleware.go b/config/middleware.go index b7db5b7..b5aaa30 100644 --- a/config/middleware.go +++ b/config/middleware.go @@ -9,15 +9,18 @@ import ( ) var configCtxKey = &contextKey{"config"} +var serviceCtxKey = &contextKey{"name"} type contextKey struct { name string } -func Middleware(conf ini.File) func (next http.Handler) http.Handler { +func Middleware(conf ini.File, service string) func (next http.Handler) http.Handler { + svc := service return func (next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), configCtxKey, conf) + ctx = context.WithValue(ctx, serviceCtxKey, &svc) r = r.WithContext(ctx) next.ServeHTTP(w, r) }) @@ -31,3 +34,11 @@ func ForContext(ctx context.Context) ini.File { } return raw } + +func ServiceName(ctx context.Context) string { + raw, ok := ctx.Value(serviceCtxKey).(*string) + if !ok { + panic(errors.New("Invalid config context")) + } + return *raw +} diff --git a/directives.go b/directives.go index fdd518b..785b1e5 100644 --- a/directives.go +++ b/directives.go @@ -22,10 +22,30 @@ func Internal(ctx context.Context, obj interface{}, func Access(ctx context.Context, obj interface{}, next graphql.Resolver, scope string, kind string) (interface{}, error) { - if auth.ForContext(ctx).AuthMethod == auth.AUTH_INTERNAL || - auth.ForContext(ctx).AuthMethod == auth.AUTH_COOKIE { + authctx := auth.ForContext(ctx) + + switch authctx.AuthMethod { + case auth.AUTH_INTERNAL: + case auth.AUTH_COOKIE: return next(ctx) + case auth.AUTH_OAUTH_LEGACY: + if kind == "RO" { + // Only legacy tokens with "*" scopes ever get this far + return next(ctx) + } + case auth.AUTH_OAUTH2: + if authctx.Access == nil { + return next(ctx) + } + if access, ok := authctx.Access[scope]; !ok { + break + } else if access == "RO" && kind == "RW" { + break + } + return next(ctx) + default: + panic(fmt.Errorf("Unknown auth method for access check")) } - panic(fmt.Errorf("TODO")) + return nil, fmt.Errorf("Access denied") } diff --git a/server.go b/server.go index 2fcbe60..17b7416 100644 --- a/server.go +++ b/server.go @@ -130,7 +130,7 @@ func MakeRouter(service string, conf ini.File, schema graphql.ExecutableSchema, requestsProcessed.Inc() }) }) - router.Use(config.Middleware(conf)) + router.Use(config.Middleware(conf, service)) router.Use(database.Middleware(db)) router.Use(redis.Middleware(rc)) router.Use(middleware.RealIP)