diff --git a/redis/middleware.go b/redis/middleware.go index f70a64d..769f3bd 100644 --- a/redis/middleware.go +++ b/redis/middleware.go @@ -2,10 +2,9 @@ package redis import ( "context" - "errors" "net/http" - goRedis "github.com/go-redis/redis/v8" + "github.com/go-redis/redis/v8" ) var redisCtxKey = &contextKey{"redis"} @@ -14,23 +13,24 @@ type contextKey struct { name string } -func Middleware(client *goRedis.Client) func(http.Handler) http.Handler { +func Middleware(client *redis.Client) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r = r.WithContext(Context(r.Context(), client)) + ctx := Context(r.Context(), client) + r = r.WithContext(ctx) next.ServeHTTP(w, r) }) } } -func Context(ctx context.Context, client *goRedis.Client) context.Context { +func Context(ctx context.Context, client *redis.Client) context.Context { return context.WithValue(ctx, redisCtxKey, client) } -func ForContext(ctx context.Context) *goRedis.Client { - raw, ok := ctx.Value(redisCtxKey).(*goRedis.Client) +func ForContext(ctx context.Context) *redis.Client { + raw, ok := ctx.Value(redisCtxKey).(*redis.Client) if !ok { - panic(errors.New("Invalid redis context")) + panic("Invalid redis context") } return raw }