// gql.sr.ht/auth/middleware.go
//
// 600 lines, 14 KiB, Go

package auth
import (
"bytes"
"context"
"crypto/sha512"
"database/sql"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/vaughan0/go-ini"
"github.com/vektah/gqlparser/gqlerror"
"git.sr.ht/~sircmpwn/gql.sr.ht/config"
"git.sr.ht/~sircmpwn/gql.sr.ht/crypto"
"git.sr.ht/~sircmpwn/gql.sr.ht/database"
)
// contextKey is the unexported type used for keys this package stores in a
// context.Context. Because it is a distinct type, its keys cannot collide with
// keys defined by any other package.
type contextKey struct {
	name string
}

// userCtxKey is the context key under which the middleware stores the
// authenticated *User and from which ForContext reads it back.
var userCtxKey = &contextKey{name: "user"}
// oauthBearerRegex recognises legacy OAuth bearer tokens: exactly 32
// lowercase hexadecimal digits.
var oauthBearerRegex = regexp.MustCompile(`^[0-9a-f]{32}$`)

// oauth2BearerRegex recognises OAuth 2.0 bearer tokens: at least 33
// characters drawn from [0-9a-zA-Z_+/]. Because of the length floor, a token
// can never match both this pattern and oauthBearerRegex.
var oauth2BearerRegex = regexp.MustCompile(`^[0-9a-zA-Z_+/]{33,}$`)
// Account types: the values that User.UserType takes, as read from the
// user_type column of the "user" table. Accounts of type USER_SUSPENDED are
// refused by every authentication path in this file.
const (
	USER_UNCONFIRMED       = "unconfirmed"
	USER_ACTIVE_NON_PAYING = "active_non_paying"
	USER_ACTIVE_FREE       = "active_free"
	USER_ACTIVE_PAYING     = "active_paying"
	USER_ACTIVE_DELINQUENT = "active_delinquent"
	USER_ADMIN             = "admin"
	USER_UNKNOWN           = "unknown"
	USER_SUSPENDED         = "suspended"
)
// Authentication methods, recorded in User.AuthMethod to say which code path
// authenticated the request.
const (
	AUTH_OAUTH_LEGACY = iota // legacy 32-hex-digit bearer token (LegacyOAuth)
	AUTH_OAUTH2              // OAuth 2.0 bearer token (OAuth2)
	AUTH_COOKIE              // unified-login cookie (cookieAuth)
	AUTH_INTERNAL            // "Internal" Authorization header (internalAuth)
)
// User is the authentication context attached to each request that the auth
// middleware accepts. It holds the account row loaded from the "user" table
// and a record of how the request was authenticated.
//
// XXX: Rename to AuthContext
type User struct {
	ID               int
	Created, Updated time.Time

	Username, Email, UserType string

	// Pointer-typed so that a NULL column scans as nil.
	URL, Location, Bio *string
	// Included in the rejection message when UserType is USER_SUSPENDED.
	SuspensionNotice *string

	// AuthMethod is one of the AUTH_* constants.
	AuthMethod int
	// Only filled out if AuthMethod == AUTH_INTERNAL
	InternalAuth InternalAuth
	// Only filled out if AuthMethod == AUTH_OAUTH2
	OAuth2Token *OAuth2Token
	// Access maps scope name to access level for the OAuth 2.0 grants that
	// target this service; it stays nil unless the token carried scopes.
	Access map[string]string
}
// authError rejects the request by writing a JSON-encoded GraphQL error,
// whose message is "Authentication error: " followed by reason, with the HTTP
// status code given by code.
func authError(w http.ResponseWriter, reason string, code int) {
	body, err := json.Marshal(gqlerror.Errorf("Authentication error: %s", reason))
	if err != nil {
		panic(err)
	}
	header := w.Header()
	header.Set("Content-Type", "application/json")
	w.WriteHeader(code)
	w.Write(body)
}
// authForUsername loads the account named username from the "user" table.
// The returned *User has AuthMethod left at its zero value; callers set it.
//
// It returns an error when no such user exists or the account is suspended.
// Callers relay that error message to the client. Query failures and
// duplicate username matches indicate a broken invariant and panic.
func authForUsername(ctx context.Context, db *sql.DB, username string) (*User, error) {
	var user User
	query := database.
		Select(ctx, []string{
			`u.id`, `u.username`,
			`u.created`, `u.updated`,
			`u.email`,
			`u.user_type`,
			`u.url`, `u.location`, `u.bio`,
			`u.suspension_notice`,
		}).
		From(`"user" u`).
		Where(`u.username = ?`, username)
	rows, err := query.RunWith(db).Query()
	if err != nil {
		panic(err)
	}
	defer rows.Close()
	if !rows.Next() {
		if err := rows.Err(); err != nil {
			panic(err)
		}
		return nil, fmt.Errorf("Authenticating for unknown user %s", username)
	}
	if err := rows.Scan(&user.ID, &user.Username, &user.Created, &user.Updated,
		&user.Email, &user.UserType, &user.URL, &user.Location, &user.Bio,
		&user.SuspensionNotice); err != nil {
		panic(err)
	}
	if rows.Next() {
		if err := rows.Err(); err != nil {
			panic(err)
		}
		panic(errors.New("Multiple matching user accounts; invariant broken"))
	}
	if user.UserType == USER_SUSPENDED {
		// SuspensionNotice is a *string. Formatting the pointer itself with
		// %s would print its address, so dereference it (NULL -> "").
		notice := ""
		if user.SuspensionNotice != nil {
			notice = *user.SuspensionNotice
		}
		return nil, fmt.Errorf("Account suspended with the following notice: %s\nContact support", notice)
	}
	return &user, nil
}
// AuthCookie is the JSON document carried, encrypted, inside the
// unified-login cookie. It names the user the cookie authenticates.
type AuthCookie struct {
	Name string `json:"name"` // username of the authenticated user
}
// cookieAuth authenticates a request with the unified-login cookie.
//
// The cookie value is decrypted and decoded into an AuthCookie naming the
// user. That user is then loaded, tagged with AuthMethod = AUTH_COOKIE, and
// attached to the request context before next is invoked. An undecryptable
// cookie, an unknown user, or a suspended account yields 403 Forbidden.
func cookieAuth(db *sql.DB, cookie *http.Cookie,
	w http.ResponseWriter, r *http.Request, next http.Handler) {
	decrypted := crypto.Decrypt([]byte(cookie.Value))
	if decrypted == nil {
		authError(w, "Invalid authentication cookie", http.StatusForbidden)
		return
	}

	var session AuthCookie
	if err := json.Unmarshal(decrypted, &session); err != nil {
		// The payload decrypted successfully, so malformed JSON is treated
		// as a bug rather than as bad client input.
		panic(err) // Programmer error
	}

	user, err := authForUsername(r.Context(), db, session.Name)
	if err != nil {
		authError(w, err.Error(), http.StatusForbidden)
		return
	}
	user.AuthMethod = AUTH_COOKIE

	next.ServeHTTP(w, r.WithContext(
		context.WithValue(r.Context(), userCtxKey, user)))
}
// InternalAuth is the JSON payload of an "Authorization: Internal" header.
// It identifies the user acting and which internal service and node issued
// the request.
type InternalAuth struct {
	Name     string `json:"name"`      // username of the authenticated user
	ClientID string `json:"client_id"` // arbitrary client identifier, e.g. "git.sr.ht"
	NodeID   string `json:"node_id"`   // arbitrary node identifier, e.g. "us-east-3.git.sr.ht"
}
func internalAuth(internalNet []*net.IPNet, db *sql.DB, payload []byte,
w http.ResponseWriter, r *http.Request, next http.Handler) {
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
host = r.RemoteAddr
}
ip := net.ParseIP(host)
if ip == nil {
panic(fmt.Errorf("Unable to parse remote address"))
}
var ok bool = false
for _, ipnet := range internalNet {
ok = ok || ipnet.Contains(ip)
if ok {
break
}
}
if !ok {
authError(w, "Invalid source IP for internal auth", http.StatusForbidden)
return
}
payload = crypto.DecryptWithExpiration(payload, 30*time.Second)
if payload == nil {
authError(w, "Invalid Authorization header", http.StatusForbidden)
return
}
var auth InternalAuth
if err := json.Unmarshal(payload, &auth); err != nil {
panic(err) // Programmer error
}
if auth.ClientID == "" || auth.NodeID == "" {
authError(w, "Invalid Authorization header", http.StatusForbidden)
}
user, err := authForUsername(r.Context(), db, auth.Name)
if err != nil {
authError(w, err.Error(), http.StatusForbidden)
return
}
user.AuthMethod = AUTH_INTERNAL
user.InternalAuth = auth
ctx := context.WithValue(r.Context(), userCtxKey, user)
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
}
func OAuth2(db *sql.DB, token string, hash [64]byte,
w http.ResponseWriter, r *http.Request, next http.Handler) {
var (
err error
res int32
rows *sql.Rows
user User
wg sync.WaitGroup
)
wg.Add(2)
ot := DecodeToken(token)
if ot == nil {
authError(w, `Invalid or expired OAuth 2.0 bearer token`, http.StatusForbidden)
return
}
go func() {
defer wg.Done()
// XXX: This branch could probably be eliminated if we:
// - Revoked all tokens upon suspending a user
// - Deferred loading the additional user details from the database
// until they were necessary (e.g. resolving query { me })
query := database.
Select(r.Context(), []string{
`u.id`, `u.username`,
`u.created`, `u.updated`,
`u.email`,
`u.user_type`,
`u.url`, `u.location`, `u.bio`,
`u.suspension_notice`,
}).
From(`"user" u`).
Where(`u.username = ?`, ot.Username)
if rows, err = query.RunWith(db).Query(); err != nil {
panic(err)
}
defer rows.Close()
if !rows.Next() {
if err := rows.Err(); err != nil {
panic(err)
}
log.Println("Failed to look up user associated with bearer token")
return
}
if err := rows.Scan(
&user.ID, &user.Username,
&user.Created, &user.Updated,
&user.Email,
&user.UserType,
&user.URL,
&user.Location,
&user.Bio,
&user.SuspensionNotice); err != nil {
panic(err)
}
if rows.Next() {
if err := rows.Err(); err != nil {
panic(err)
}
panic(errors.New("Multiple users of the same username; invariant broken"))
}
atomic.AddInt32(&res, 1)
}()
go func() {
defer wg.Done()
// Fetch revocation status for this token
conf := config.ForContext(r.Context())
meta, ok := conf.Get("meta.sr.ht", "origin")
if !ok {
panic(errors.New("No meta.sr.ht origin specified in config.ini"))
}
type GraphQLQuery struct {
Query string `json:"query"`
Variables map[string]interface{} `json:"variables"`
}
type GraphQLResponse struct {
Data struct {
RevocationStatus bool `json:"tokenRevocationStatus"`
} `json:"data"`
}
query := GraphQLQuery{
Query: `
query RevocationStatus($hash: String!, $clientId: String) {
tokenRevocationStatus(hash: $hash, clientId: $clientId)
}`,
Variables: map[string]interface{} {
"hash": hex.EncodeToString(hash[:]),
"clientId": ot.ClientID,
},
}
body, err := json.Marshal(query)
if err != nil {
panic(err) // Programmer error
}
reader := bytes.NewBuffer(body)
req, err := http.NewRequestWithContext(r.Context(),
"POST", fmt.Sprintf("%s/query", meta), reader)
if err != nil {
log.Printf("http.NewRequest: %e")
return
}
req.Header.Add("Content-Type", "application/json")
auth := InternalAuth{
Name: ot.Username,
// TODO: Populate these better
ClientID: "gql.sr.ht",
NodeID: "gql.sr.ht",
}
authBlob, err := json.Marshal(&auth)
if err != nil {
panic(err) // Programmer error
}
req.Header.Add("Authorization", fmt.Sprintf("Internal %s",
crypto.Encrypt(authBlob)))
resp, err := http.DefaultClient.Do(req)
if err != nil {
log.Printf("http.Do: %e", err)
return
}
defer resp.Body.Close()
respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Printf("ioutil.ReadAll: %e")
return
}
if resp.StatusCode != 200 {
log.Printf("meta.sr.ht returned status %d: %s",
resp.StatusCode, string(respBody))
return
}
var result GraphQLResponse
if err = json.Unmarshal(respBody, &result); err != nil {
log.Printf("json.Unmarshal: %e")
return
}
if !result.Data.RevocationStatus {
atomic.AddInt32(&res, 1)
}
}()
wg.Wait()
if res != 2 {
authError(w, "Invalid or expired OAuth 2.0 bearer token", http.StatusForbidden)
return
}
if user.UserType == USER_SUSPENDED {
authError(w, fmt.Sprintf("Account suspended with the following notice: %s\nContact support",
user.SuspensionNotice), http.StatusForbidden)
return
}
user.AuthMethod = AUTH_OAUTH2
user.OAuth2Token = ot
if ot.Scopes != "" {
user.Access = make(map[string]string)
for _, grant := range strings.Split(ot.Scopes, ",") {
var (
service string
scope string
access string
)
parts := strings.Split(grant, "/")
if len(parts) != 2 {
panic(errors.New("OAuth grant without service/scope format"))
}
service = parts[0]
parts = strings.Split(parts[1], ":")
scope = parts[0]
if len(parts) == 1 {
access = "RO"
} else {
access = parts[1]
}
if service == config.ServiceName(r.Context()) {
user.Access[scope] = access
}
}
}
ctx := context.WithValue(r.Context(), userCtxKey, &user)
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
}
// LegacyOAuth authenticates a request bearing a legacy (pre-OAuth 2.0)
// bearer token.
//
// As called from Middleware, bearer is the hex-encoded SHA-512 digest of the
// token. It is matched against oauthtoken.token_hash together with the
// owning user row. hash (the raw digest) is currently unused.
//
// The request is rejected with 403 Forbidden in any of these cases:
//   - no token matches;
//   - the token has expired;
//   - the account is suspended;
//   - the token's scopes are anything other than "*".
//
// On success the user is attached to the request context with
// AuthMethod = AUTH_OAUTH_LEGACY and next is invoked.
//
// TODO: Remove legacy OAuth support
func LegacyOAuth(db *sql.DB, bearer string, hash [64]byte,
	w http.ResponseWriter, r *http.Request, next http.Handler) {
	var (
		expires time.Time
		scopes  string
		user    User
	)
	query := database.
		Select(r.Context(), []string{
			`ot.expires`,
			`ot.scopes`,
			`u.id`, `u.username`,
			`u.created`, `u.updated`,
			`u.email`,
			`u.user_type`,
			`u.url`, `u.location`, `u.bio`,
			`u.suspension_notice`,
		}).
		From(`oauthtoken ot`).
		Join(`"user" u ON u.id = ot.user_id`).
		Where(`ot.token_hash = ?`, bearer)
	rows, err := query.RunWith(db).Query()
	if err != nil {
		panic(err)
	}
	defer rows.Close()
	if !rows.Next() {
		if err := rows.Err(); err != nil {
			panic(err)
		}
		authError(w, "Invalid or expired OAuth token", http.StatusForbidden)
		return
	}
	if err := rows.Scan(&expires, &scopes,
		&user.ID, &user.Username,
		&user.Created, &user.Updated,
		&user.Email,
		&user.UserType,
		&user.URL,
		&user.Location,
		&user.Bio,
		&user.SuspensionNotice); err != nil {
		panic(err)
	}
	if rows.Next() {
		if err := rows.Err(); err != nil {
			panic(err)
		}
		panic(errors.New("Multiple matching OAuth tokens; invariant broken"))
	}
	if time.Now().UTC().After(expires) {
		authError(w, "Invalid or expired OAuth token", http.StatusForbidden)
		return
	}
	if user.UserType == USER_SUSPENDED {
		// Dereference the *string; %s on the pointer would print an address.
		notice := ""
		if user.SuspensionNotice != nil {
			notice = *user.SuspensionNotice
		}
		authError(w, fmt.Sprintf("Account suspended with the following notice: %s\nContact support",
			notice), http.StatusForbidden)
		return
	}
	if scopes != "*" {
		authError(w, "Presently, OAuth authentication to the GraphQL API is only supported for OAuth tokens with all permissions, namely '*'.", http.StatusForbidden)
		return
	}
	user.AuthMethod = AUTH_OAUTH_LEGACY
	ctx := context.WithValue(r.Context(), userCtxKey, &user)
	next.ServeHTTP(w, r.WithContext(ctx))
}
// Middleware returns HTTP middleware that authenticates requests to /query.
//
// Requests for any other path pass through untouched. For /query, a
// "sr.ht.unified-login.v1" cookie, if present, takes precedence. Otherwise
// the Authorization header is required (403 if absent) and must be one of:
//
//	Bearer <token>     OAuth 2.0 token, or 32-hex-digit legacy OAuth token
//	Internal <payload> encrypted service-to-service credentials
//
// Any other header shape is rejected with 400 Bad Request.
//
// Internal auth is only accepted from the comma-separated CIDRs in the
// "internal-ipnet" key of the apiconf section of conf. The default is
// "127.0.0.1/24,::1/64". Entries are trimmed of surrounding whitespace and
// empty entries are ignored. An invalid CIDR panics when Middleware is called.
func Middleware(conf ini.File, apiconf string) func(http.Handler) http.Handler {
	src, ok := conf.Get(apiconf, "internal-ipnet")
	if !ok {
		// Conservative default
		src = "127.0.0.1/24,::1/64"
	}
	var internalNet []*net.IPNet
	for _, cidr := range strings.Split(src, ",") {
		cidr = strings.TrimSpace(cidr)
		if cidr == "" {
			continue
		}
		_, ipnet, err := net.ParseCIDR(cidr)
		if err != nil {
			panic(fmt.Errorf("invalid internal-ipnet entry %q: %w", cidr, err))
		}
		internalNet = append(internalNet, ipnet)
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path != "/query" {
				next.ServeHTTP(w, r)
				return
			}
			db := database.ForContext(r.Context())
			if cookie, err := r.Cookie("sr.ht.unified-login.v1"); err == nil {
				cookieAuth(db, cookie, w, r, next)
				return
			}

			auth := r.Header.Get("Authorization")
			if auth == "" {
				authError(w, `Authorization header is required. Expected 'Authorization: Bearer <token>'`, http.StatusForbidden)
				return
			}
			z := strings.SplitN(auth, " ", 2)
			if len(z) != 2 {
				authError(w, "Invalid Authorization header", http.StatusBadRequest)
				return
			}

			switch z[0] {
			case "Bearer":
				token := []byte(z[1])
				if oauth2BearerRegex.Match(token) {
					hash := sha512.Sum512(token)
					OAuth2(db, z[1], hash, w, r, next)
					return
				}
				if oauthBearerRegex.Match(token) {
					// Legacy tokens are stored by the hex form of their digest.
					hash := sha512.Sum512(token)
					LegacyOAuth(db, hex.EncodeToString(hash[:]), hash, w, r, next)
					return
				}
				authError(w, "Invalid OAuth bearer token", http.StatusBadRequest)
			case "Internal":
				internalAuth(internalNet, db, []byte(z[1]), w, r, next)
			default:
				authError(w, "Invalid Authorization header", http.StatusBadRequest)
			}
		})
	}
}
// ForContext returns the authenticated *User that the auth middleware stored
// in ctx. It panics when ctx carries no *User. That happens, for example, for
// a request the middleware passed through without authenticating (any path
// other than /query).
func ForContext(ctx context.Context) *User {
	if user, ok := ctx.Value(userCtxKey).(*User); ok {
		return user
	}
	panic(errors.New("Invalid authentication context"))
}