mirror of https://git.sr.ht/~sircmpwn/gql.sr.ht
600 lines
14 KiB
Go
600 lines
14 KiB
Go
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha512"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/vaughan0/go-ini"
|
|
"github.com/vektah/gqlparser/gqlerror"
|
|
|
|
"git.sr.ht/~sircmpwn/gql.sr.ht/config"
|
|
"git.sr.ht/~sircmpwn/gql.sr.ht/crypto"
|
|
"git.sr.ht/~sircmpwn/gql.sr.ht/database"
|
|
)
|
|
|
|
// contextKey is the unexported key type for values this package stores in a
// context.Context; because the type is unexported, no other package can
// construct an equal key.
type contextKey struct {
	name string
}

// userCtxKey is the context key under which the authenticated *User is stored.
var userCtxKey = &contextKey{name: "user"}
|
|
|
|
// oauthBearerRegex recognises legacy OAuth bearer tokens: exactly 32
// lowercase hexadecimal characters.
var oauthBearerRegex = regexp.MustCompile(`^[0-9a-f]{32}$`)

// oauth2BearerRegex recognises OAuth 2.0 bearer tokens: 33 or more characters
// drawn from [0-9a-zA-Z_+/]. The minimum length keeps it disjoint from
// legacy tokens.
var oauth2BearerRegex = regexp.MustCompile(`^[0-9a-zA-Z_+/]{33,}$`)
|
|
|
|
// Values of User.UserType, as read from the user_type column of the "user"
// table. Authentication is refused for accounts whose type is USER_SUSPENDED.
const (
	USER_UNCONFIRMED       = "unconfirmed"
	USER_ACTIVE_NON_PAYING = "active_non_paying"
	USER_ACTIVE_FREE       = "active_free"
	USER_ACTIVE_PAYING     = "active_paying"
	USER_ACTIVE_DELINQUENT = "active_delinquent"
	USER_ADMIN             = "admin"
	USER_UNKNOWN           = "unknown"
	USER_SUSPENDED         = "suspended"
)
|
|
|
|
// Authentication mechanisms, recorded in User.AuthMethod by the handler that
// authenticated the request.
const (
	AUTH_OAUTH_LEGACY = iota // legacy OAuth token (LegacyOAuth)
	AUTH_OAUTH2              // OAuth 2.0 bearer token (OAuth2)
	AUTH_COOKIE              // login cookie (cookieAuth)
	AUTH_INTERNAL            // "Authorization: Internal" header (internalAuth)
)
|
|
|
|
// XXX: Rename to AuthContext
|
|
type User struct {
|
|
ID int
|
|
Created time.Time
|
|
Updated time.Time
|
|
Username string
|
|
Email string
|
|
UserType string
|
|
URL *string
|
|
Location *string
|
|
Bio *string
|
|
SuspensionNotice *string
|
|
AuthMethod int
|
|
|
|
// Only filled out if AuthMethod == AUTH_INTERNAL
|
|
InternalAuth InternalAuth
|
|
|
|
// Only filled out if AuthMethod == AUTH_OAUTH2
|
|
OAuth2Token *OAuth2Token
|
|
Access map[string]string
|
|
}
|
|
|
|
func authError(w http.ResponseWriter, reason string, code int) {
|
|
gqlerr := gqlerror.Errorf("Authentication error: %s", reason)
|
|
b, err := json.Marshal(gqlerr)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(code)
|
|
w.Write(b)
|
|
}
|
|
|
|
func authForUsername(ctx context.Context, db *sql.DB, username string) (*User, error) {
|
|
var (
|
|
err error
|
|
rows *sql.Rows
|
|
user User
|
|
)
|
|
|
|
query := database.
|
|
Select(ctx, []string{
|
|
`u.id`, `u.username`,
|
|
`u.created`, `u.updated`,
|
|
`u.email`,
|
|
`u.user_type`,
|
|
`u.url`, `u.location`, `u.bio`,
|
|
`u.suspension_notice`,
|
|
}).
|
|
From(`"user" u`).
|
|
Where(`u.username = ?`, username)
|
|
if rows, err = query.RunWith(db).Query(); err != nil {
|
|
panic(err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
if !rows.Next() {
|
|
if err := rows.Err(); err != nil {
|
|
panic(err)
|
|
}
|
|
return nil, fmt.Errorf("Authenticating for unknown user %s", username)
|
|
}
|
|
if err := rows.Scan(&user.ID, &user.Username, &user.Created, &user.Updated,
|
|
&user.Email, &user.UserType, &user.URL, &user.Location, &user.Bio,
|
|
&user.SuspensionNotice); err != nil {
|
|
panic(err)
|
|
}
|
|
if rows.Next() {
|
|
if err := rows.Err(); err != nil {
|
|
panic(err)
|
|
}
|
|
panic(errors.New("Multiple matching user accounts; invariant broken"))
|
|
}
|
|
|
|
if user.UserType == USER_SUSPENDED {
|
|
return nil, fmt.Errorf("Account suspended with the following notice: %s\nContact support", user.SuspensionNotice)
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
// AuthCookie is the JSON payload carried, encrypted, in the login cookie;
// cookieAuth decrypts the cookie value and unmarshals it into this struct.
type AuthCookie struct {
	// The username of the authenticated user
	Name string `json:"name"`
}
|
|
|
|
func cookieAuth(db *sql.DB, cookie *http.Cookie,
|
|
w http.ResponseWriter, r *http.Request, next http.Handler) {
|
|
|
|
payload := crypto.Decrypt([]byte(cookie.Value))
|
|
if payload == nil {
|
|
authError(w, "Invalid authentication cookie", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
var auth AuthCookie
|
|
if err := json.Unmarshal(payload, &auth); err != nil {
|
|
panic(err) // Programmer error
|
|
}
|
|
|
|
user, err := authForUsername(r.Context(), db, auth.Name)
|
|
if err != nil {
|
|
authError(w, err.Error(), http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
user.AuthMethod = AUTH_COOKIE
|
|
|
|
ctx := context.WithValue(r.Context(), userCtxKey, user)
|
|
|
|
r = r.WithContext(ctx)
|
|
next.ServeHTTP(w, r)
|
|
}
|
|
|
|
// InternalAuth is the JSON payload of an "Authorization: Internal <blob>"
// header, where <blob> is this struct encrypted with crypto.Encrypt (OAuth2
// builds one this way when calling meta.sr.ht). internalAuth decodes it and
// attaches it to the resulting User.
type InternalAuth struct {
	// The username of the authenticated user
	Name string `json:"name"`

	// An arbitrary identifier for this internal user, e.g. "git.sr.ht"
	ClientID string `json:"client_id"`

	// An arbitrary identifier for this internal node, e.g. "us-east-3.git.sr.ht"
	NodeID string `json:"node_id"`
}
|
|
|
|
func internalAuth(internalNet []*net.IPNet, db *sql.DB, payload []byte,
|
|
w http.ResponseWriter, r *http.Request, next http.Handler) {
|
|
|
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
if err != nil {
|
|
host = r.RemoteAddr
|
|
}
|
|
ip := net.ParseIP(host)
|
|
if ip == nil {
|
|
panic(fmt.Errorf("Unable to parse remote address"))
|
|
}
|
|
var ok bool = false
|
|
for _, ipnet := range internalNet {
|
|
ok = ok || ipnet.Contains(ip)
|
|
if ok {
|
|
break
|
|
}
|
|
}
|
|
if !ok {
|
|
authError(w, "Invalid source IP for internal auth", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
payload = crypto.DecryptWithExpiration(payload, 30*time.Second)
|
|
if payload == nil {
|
|
authError(w, "Invalid Authorization header", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
var auth InternalAuth
|
|
if err := json.Unmarshal(payload, &auth); err != nil {
|
|
panic(err) // Programmer error
|
|
}
|
|
|
|
if auth.ClientID == "" || auth.NodeID == "" {
|
|
authError(w, "Invalid Authorization header", http.StatusForbidden)
|
|
}
|
|
|
|
user, err := authForUsername(r.Context(), db, auth.Name)
|
|
if err != nil {
|
|
authError(w, err.Error(), http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
user.AuthMethod = AUTH_INTERNAL
|
|
user.InternalAuth = auth
|
|
|
|
ctx := context.WithValue(r.Context(), userCtxKey, user)
|
|
|
|
r = r.WithContext(ctx)
|
|
next.ServeHTTP(w, r)
|
|
}
|
|
|
|
func OAuth2(db *sql.DB, token string, hash [64]byte,
|
|
w http.ResponseWriter, r *http.Request, next http.Handler) {
|
|
|
|
var (
|
|
err error
|
|
res int32
|
|
rows *sql.Rows
|
|
user User
|
|
wg sync.WaitGroup
|
|
)
|
|
wg.Add(2)
|
|
|
|
ot := DecodeToken(token)
|
|
if ot == nil {
|
|
authError(w, `Invalid or expired OAuth 2.0 bearer token`, http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
// XXX: This branch could probably be eliminated if we:
|
|
// - Revoked all tokens upon suspending a user
|
|
// - Deferred loading the additional user details from the database
|
|
// until they were necessary (e.g. resolving query { me })
|
|
|
|
query := database.
|
|
Select(r.Context(), []string{
|
|
`u.id`, `u.username`,
|
|
`u.created`, `u.updated`,
|
|
`u.email`,
|
|
`u.user_type`,
|
|
`u.url`, `u.location`, `u.bio`,
|
|
`u.suspension_notice`,
|
|
}).
|
|
From(`"user" u`).
|
|
Where(`u.username = ?`, ot.Username)
|
|
if rows, err = query.RunWith(db).Query(); err != nil {
|
|
panic(err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
if !rows.Next() {
|
|
if err := rows.Err(); err != nil {
|
|
panic(err)
|
|
}
|
|
log.Println("Failed to look up user associated with bearer token")
|
|
return
|
|
}
|
|
if err := rows.Scan(
|
|
&user.ID, &user.Username,
|
|
&user.Created, &user.Updated,
|
|
&user.Email,
|
|
&user.UserType,
|
|
&user.URL,
|
|
&user.Location,
|
|
&user.Bio,
|
|
&user.SuspensionNotice); err != nil {
|
|
panic(err)
|
|
}
|
|
if rows.Next() {
|
|
if err := rows.Err(); err != nil {
|
|
panic(err)
|
|
}
|
|
panic(errors.New("Multiple users of the same username; invariant broken"))
|
|
}
|
|
|
|
atomic.AddInt32(&res, 1)
|
|
}()
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
// Fetch revocation status for this token
|
|
|
|
conf := config.ForContext(r.Context())
|
|
meta, ok := conf.Get("meta.sr.ht", "origin")
|
|
if !ok {
|
|
panic(errors.New("No meta.sr.ht origin specified in config.ini"))
|
|
}
|
|
|
|
type GraphQLQuery struct {
|
|
Query string `json:"query"`
|
|
Variables map[string]interface{} `json:"variables"`
|
|
}
|
|
type GraphQLResponse struct {
|
|
Data struct {
|
|
RevocationStatus bool `json:"tokenRevocationStatus"`
|
|
} `json:"data"`
|
|
}
|
|
|
|
query := GraphQLQuery{
|
|
Query: `
|
|
query RevocationStatus($hash: String!, $clientId: String) {
|
|
tokenRevocationStatus(hash: $hash, clientId: $clientId)
|
|
}`,
|
|
Variables: map[string]interface{} {
|
|
"hash": hex.EncodeToString(hash[:]),
|
|
"clientId": ot.ClientID,
|
|
},
|
|
}
|
|
|
|
body, err := json.Marshal(query)
|
|
if err != nil {
|
|
panic(err) // Programmer error
|
|
}
|
|
|
|
reader := bytes.NewBuffer(body)
|
|
req, err := http.NewRequestWithContext(r.Context(),
|
|
"POST", fmt.Sprintf("%s/query", meta), reader)
|
|
if err != nil {
|
|
log.Printf("http.NewRequest: %e")
|
|
return
|
|
}
|
|
req.Header.Add("Content-Type", "application/json")
|
|
auth := InternalAuth{
|
|
Name: ot.Username,
|
|
// TODO: Populate these better
|
|
ClientID: "gql.sr.ht",
|
|
NodeID: "gql.sr.ht",
|
|
}
|
|
authBlob, err := json.Marshal(&auth)
|
|
if err != nil {
|
|
panic(err) // Programmer error
|
|
}
|
|
req.Header.Add("Authorization", fmt.Sprintf("Internal %s",
|
|
crypto.Encrypt(authBlob)))
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
log.Printf("http.Do: %e", err)
|
|
return
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
respBody, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
log.Printf("ioutil.ReadAll: %e")
|
|
return
|
|
}
|
|
|
|
if resp.StatusCode != 200 {
|
|
log.Printf("meta.sr.ht returned status %d: %s",
|
|
resp.StatusCode, string(respBody))
|
|
return
|
|
}
|
|
|
|
var result GraphQLResponse
|
|
if err = json.Unmarshal(respBody, &result); err != nil {
|
|
log.Printf("json.Unmarshal: %e")
|
|
return
|
|
}
|
|
|
|
if !result.Data.RevocationStatus {
|
|
atomic.AddInt32(&res, 1)
|
|
}
|
|
}()
|
|
|
|
wg.Wait()
|
|
if res != 2 {
|
|
authError(w, "Invalid or expired OAuth 2.0 bearer token", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
if user.UserType == USER_SUSPENDED {
|
|
authError(w, fmt.Sprintf("Account suspended with the following notice: %s\nContact support",
|
|
user.SuspensionNotice), http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
user.AuthMethod = AUTH_OAUTH2
|
|
user.OAuth2Token = ot
|
|
|
|
if ot.Scopes != "" {
|
|
user.Access = make(map[string]string)
|
|
for _, grant := range strings.Split(ot.Scopes, ",") {
|
|
var (
|
|
service string
|
|
scope string
|
|
access string
|
|
)
|
|
parts := strings.Split(grant, "/")
|
|
if len(parts) != 2 {
|
|
panic(errors.New("OAuth grant without service/scope format"))
|
|
}
|
|
service = parts[0]
|
|
parts = strings.Split(parts[1], ":")
|
|
scope = parts[0]
|
|
if len(parts) == 1 {
|
|
access = "RO"
|
|
} else {
|
|
access = parts[1]
|
|
}
|
|
if service == config.ServiceName(r.Context()) {
|
|
user.Access[scope] = access
|
|
}
|
|
}
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), userCtxKey, &user)
|
|
|
|
r = r.WithContext(ctx)
|
|
next.ServeHTTP(w, r)
|
|
}
|
|
|
|
// TODO: Remove legacy OAuth support
|
|
func LegacyOAuth(db *sql.DB, bearer string, hash [64]byte,
|
|
w http.ResponseWriter, r *http.Request, next http.Handler) {
|
|
|
|
var (
|
|
err error
|
|
expires time.Time
|
|
rows *sql.Rows
|
|
scopes string
|
|
user User
|
|
)
|
|
query := database.
|
|
Select(r.Context(), []string{
|
|
`ot.expires`,
|
|
`ot.scopes`,
|
|
`u.id`, `u.username`,
|
|
`u.created`, `u.updated`,
|
|
`u.email`,
|
|
`u.user_type`,
|
|
`u.url`, `u.location`, `u.bio`,
|
|
`u.suspension_notice`,
|
|
}).
|
|
From(`oauthtoken ot`).
|
|
Join(`"user" u ON u.id = ot.user_id`).
|
|
Where(`ot.token_hash = ?`, bearer)
|
|
if rows, err = query.RunWith(db).Query(); err != nil {
|
|
panic(err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
if !rows.Next() {
|
|
if err := rows.Err(); err != nil {
|
|
panic(err)
|
|
}
|
|
authError(w, "Invalid or expired OAuth token", http.StatusForbidden)
|
|
return
|
|
}
|
|
if err := rows.Scan(&expires, &scopes,
|
|
&user.ID, &user.Username,
|
|
&user.Created, &user.Updated,
|
|
&user.Email,
|
|
&user.UserType,
|
|
&user.URL,
|
|
&user.Location,
|
|
&user.Bio,
|
|
&user.SuspensionNotice); err != nil {
|
|
panic(err)
|
|
}
|
|
if rows.Next() {
|
|
if err := rows.Err(); err != nil {
|
|
panic(err)
|
|
}
|
|
panic(errors.New("Multiple matching OAuth tokens; invariant broken"))
|
|
}
|
|
|
|
if time.Now().UTC().After(expires) {
|
|
authError(w, "Invalid or expired OAuth token", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
if user.UserType == USER_SUSPENDED {
|
|
authError(w, fmt.Sprintf("Account suspended with the following notice: %s\nContact support",
|
|
user.SuspensionNotice), http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
if scopes != "*" {
|
|
authError(w, "Presently, OAuth authentication to the GraphQL API is only supported for OAuth tokens with all permissions, namely '*'.", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
user.AuthMethod = AUTH_OAUTH_LEGACY
|
|
|
|
ctx := context.WithValue(r.Context(), userCtxKey, &user)
|
|
|
|
r = r.WithContext(ctx)
|
|
next.ServeHTTP(w, r)
|
|
}
|
|
|
|
func Middleware(conf ini.File, apiconf string) func(http.Handler) http.Handler {
|
|
var internalNet []*net.IPNet
|
|
src, ok := conf.Get(apiconf, "internal-ipnet")
|
|
if !ok {
|
|
// Conservative default
|
|
src = "127.0.0.1/24,::1/64"
|
|
}
|
|
for _, cidr := range strings.Split(src, ",") {
|
|
_, ipnet, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
internalNet = append(internalNet, ipnet)
|
|
}
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/query" {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
db := database.ForContext(r.Context())
|
|
|
|
cookie, err := r.Cookie("sr.ht.unified-login.v1")
|
|
if err == nil {
|
|
cookieAuth(db, cookie, w, r, next)
|
|
return
|
|
}
|
|
|
|
auth := r.Header.Get("Authorization")
|
|
if auth == "" {
|
|
authError(w, `Authorization header is required. Expected 'Authorization: Bearer <token>'`, http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
z := strings.SplitN(auth, " ", 2)
|
|
if len(z) != 2 {
|
|
authError(w, "Invalid Authorization header", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
var bearer string
|
|
switch z[0] {
|
|
case "Bearer":
|
|
token := []byte(z[1])
|
|
if oauth2BearerRegex.Match(token) {
|
|
hash := sha512.Sum512(token)
|
|
bearer = z[1]
|
|
OAuth2(db, bearer, hash, w, r, next)
|
|
return
|
|
}
|
|
if oauthBearerRegex.Match(token) {
|
|
hash := sha512.Sum512(token)
|
|
bearer = hex.EncodeToString(hash[:])
|
|
LegacyOAuth(db, bearer, hash, w, r, next)
|
|
return
|
|
}
|
|
authError(w, "Invalid OAuth bearer token", http.StatusBadRequest)
|
|
return
|
|
case "Internal":
|
|
payload := []byte(z[1])
|
|
internalAuth(internalNet, db, payload, w, r, next)
|
|
return
|
|
default:
|
|
authError(w, "Invalid Authorization header", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
})
|
|
}
|
|
}
|
|
|
|
|
|
func ForContext(ctx context.Context) *User {
|
|
raw, ok := ctx.Value(userCtxKey).(*User)
|
|
if !ok {
|
|
panic(errors.New("Invalid authentication context"))
|
|
}
|
|
return raw
|
|
}
|