diff options
Diffstat (limited to 'server/middleware')
-rw-r--r-- | server/middleware/basic_auth.go | 16 | ||||
-rw-r--r-- | server/middleware/context_keys.go | 26 | ||||
-rw-r--r-- | server/middleware/csrf.go | 6 | ||||
-rw-r--r-- | server/middleware/middleware.go | 15 | ||||
-rw-r--r-- | server/middleware/session.go | 11 |
5 files changed, 57 insertions, 17 deletions
diff --git a/server/middleware/basic_auth.go b/server/middleware/basic_auth.go index 73dfb98..3ad5318 100644 --- a/server/middleware/basic_auth.go +++ b/server/middleware/basic_auth.go @@ -6,15 +6,18 @@ package middleware import ( "context" - "github.com/miniflux/miniflux2/storage" "log" "net/http" + + "github.com/miniflux/miniflux2/storage" ) +// BasicAuthMiddleware is the middleware for HTTP Basic authentication. type BasicAuthMiddleware struct { store *storage.Storage } +// Handler executes the middleware. func (b *BasicAuthMiddleware) Handler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) @@ -35,7 +38,7 @@ func (b *BasicAuthMiddleware) Handler(next http.Handler) http.Handler { return } - user, err := b.store.GetUserByUsername(username) + user, err := b.store.UserByUsername(username) if err != nil || user == nil { log.Println("[Middleware:BasicAuth] User not found:", username) w.WriteHeader(http.StatusUnauthorized) @@ -47,15 +50,16 @@ func (b *BasicAuthMiddleware) Handler(next http.Handler) http.Handler { b.store.SetLastLogin(user.ID) ctx := r.Context() - ctx = context.WithValue(ctx, "UserId", user.ID) - ctx = context.WithValue(ctx, "UserTimezone", user.Timezone) - ctx = context.WithValue(ctx, "IsAdminUser", user.IsAdmin) - ctx = context.WithValue(ctx, "IsAuthenticated", true) + ctx = context.WithValue(ctx, UserIDContextKey, user.ID) + ctx = context.WithValue(ctx, UserTimezoneContextKey, user.Timezone) + ctx = context.WithValue(ctx, IsAdminUserContextKey, user.IsAdmin) + ctx = context.WithValue(ctx, IsAuthenticatedContextKey, true) next.ServeHTTP(w, r.WithContext(ctx)) }) } +// NewBasicAuthMiddleware returns a new BasicAuthMiddleware. func NewBasicAuthMiddleware(s *storage.Storage) *BasicAuthMiddleware { return &BasicAuthMiddleware{store: s} } diff --git a/server/middleware/context_keys.go b/server/middleware/context_keys.go new file mode 100644 index 0000000..c011fbb --- /dev/null +++ b/server/middleware/context_keys.go @@ -0,0 +1,26 @@ +// Copyright 2017 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package middleware + +type contextKey struct { + name string +} + +var ( + // UserIDContextKey is the context key used to store the user ID. + UserIDContextKey = &contextKey{"UserID"} + + // UserTimezoneContextKey is the context key used to store the user timezone. + UserTimezoneContextKey = &contextKey{"UserTimezone"} + + // IsAdminUserContextKey is the context key used to store the user role. + IsAdminUserContextKey = &contextKey{"IsAdminUser"} + + // IsAuthenticatedContextKey is the context key used to store the authentication flag. + IsAuthenticatedContextKey = &contextKey{"IsAuthenticated"} + + // CsrfContextKey is the context key used to store CSRF token. + CsrfContextKey = &contextKey{"CSRF"} +) diff --git a/server/middleware/csrf.go b/server/middleware/csrf.go index 74736b5..0c07e42 100644 --- a/server/middleware/csrf.go +++ b/server/middleware/csrf.go @@ -6,11 +6,13 @@ package middleware import ( "context" - "github.com/miniflux/miniflux2/helper" "log" "net/http" + + "github.com/miniflux/miniflux2/helper" ) +// Csrf is a middleware that handle CSRF tokens. func Csrf(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var csrfToken string @@ -32,7 +34,7 @@ func Csrf(next http.Handler) http.Handler { } ctx := r.Context() - ctx = context.WithValue(ctx, "CsrfToken", csrfToken) + ctx = context.WithValue(ctx, CsrfContextKey, csrfToken) w.Header().Add("Vary", "Cookie") isTokenValid := csrfToken == r.FormValue("csrf") || csrfToken == r.Header.Get("X-Csrf-Token") diff --git a/server/middleware/middleware.go b/server/middleware/middleware.go index cab01c8..9853bc3 100644 --- a/server/middleware/middleware.go +++ b/server/middleware/middleware.go @@ -8,13 +8,16 @@ import ( "net/http" ) +// Middleware represents a HTTP middleware. type Middleware func(http.Handler) http.Handler -type MiddlewareChain struct { +// Chain handles a list of middlewares. +type Chain struct { middlewares []Middleware } -func (m *MiddlewareChain) Wrap(h http.Handler) http.Handler { +// Wrap adds a HTTP handler into the chain. +func (m *Chain) Wrap(h http.Handler) http.Handler { for i := range m.middlewares { h = m.middlewares[len(m.middlewares)-1-i](h) } @@ -22,10 +25,12 @@ func (m *MiddlewareChain) Wrap(h http.Handler) http.Handler { return h } -func (m *MiddlewareChain) WrapFunc(fn http.HandlerFunc) http.Handler { +// WrapFunc adds a HTTP handler function into the chain. +func (m *Chain) WrapFunc(fn http.HandlerFunc) http.Handler { return m.Wrap(fn) } -func NewMiddlewareChain(middlewares ...Middleware) *MiddlewareChain { - return &MiddlewareChain{append(([]Middleware)(nil), middlewares...)} +// NewChain returns a new Chain. +func NewChain(middlewares ...Middleware) *Chain { + return &Chain{append(([]Middleware)(nil), middlewares...)} } diff --git a/server/middleware/session.go b/server/middleware/session.go index 1ab0d0a..e857c1f 100644 --- a/server/middleware/session.go +++ b/server/middleware/session.go @@ -16,11 +16,13 @@ import ( "github.com/gorilla/mux" ) +// SessionMiddleware represents a session middleware. type SessionMiddleware struct { store *storage.Storage router *mux.Router } +// Handler execute the middleware. func (s *SessionMiddleware) Handler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { session := s.getSessionFromCookie(r) @@ -30,13 +32,13 @@ func (s *SessionMiddleware) Handler(next http.Handler) http.Handler { if s.isPublicRoute(r) { next.ServeHTTP(w, r) } else { - http.Redirect(w, r, route.GetRoute(s.router, "login"), http.StatusFound) + http.Redirect(w, r, route.Path(s.router, "login"), http.StatusFound) } } else { log.Println("[Middleware:Session]", session) ctx := r.Context() - ctx = context.WithValue(ctx, "UserId", session.UserID) - ctx = context.WithValue(ctx, "IsAuthenticated", true) + ctx = context.WithValue(ctx, UserIDContextKey, session.UserID) + ctx = context.WithValue(ctx, IsAuthenticatedContextKey, true) next.ServeHTTP(w, r.WithContext(ctx)) } @@ -59,7 +61,7 @@ func (s *SessionMiddleware) getSessionFromCookie(r *http.Request) *model.Session return nil } - session, err := s.store.GetSessionByToken(sessionCookie.Value) + session, err := s.store.SessionByToken(sessionCookie.Value) if err != nil { log.Println(err) return nil @@ -68,6 +70,7 @@ func (s *SessionMiddleware) getSessionFromCookie(r *http.Request) *model.Session return session } +// NewSessionMiddleware returns a new SessionMiddleware. func NewSessionMiddleware(s *storage.Storage, r *mux.Router) *SessionMiddleware { return &SessionMiddleware{store: s, router: r} } |