aboutsummaryrefslogtreecommitdiffhomepage
path: root/vendor/github.com/gorilla/mux/middleware_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/gorilla/mux/middleware_test.go')
-rw-r--r--vendor/github.com/gorilla/mux/middleware_test.go336
1 files changed, 336 insertions, 0 deletions
diff --git a/vendor/github.com/gorilla/mux/middleware_test.go b/vendor/github.com/gorilla/mux/middleware_test.go
new file mode 100644
index 0000000..93947e8
--- /dev/null
+++ b/vendor/github.com/gorilla/mux/middleware_test.go
@@ -0,0 +1,336 @@
+package mux
+
+import (
+ "bytes"
+ "net/http"
+ "testing"
+)
+
+type testMiddleware struct {
+ timesCalled uint
+}
+
+func (tm *testMiddleware) Middleware(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ tm.timesCalled++
+ h.ServeHTTP(w, r)
+ })
+}
+
+func dummyHandler(w http.ResponseWriter, r *http.Request) {}
+
+func TestMiddlewareAdd(t *testing.T) {
+ router := NewRouter()
+ router.HandleFunc("/", dummyHandler).Methods("GET")
+
+ mw := &testMiddleware{}
+
+ router.useInterface(mw)
+ if len(router.middlewares) != 1 || router.middlewares[0] != mw {
+ t.Fatal("Middleware was not added correctly")
+ }
+
+ router.Use(mw.Middleware)
+ if len(router.middlewares) != 2 {
+ t.Fatal("MiddlewareFunc method was not added correctly")
+ }
+
+ banalMw := func(handler http.Handler) http.Handler {
+ return handler
+ }
+ router.Use(banalMw)
+ if len(router.middlewares) != 3 {
+ t.Fatal("MiddlewareFunc method was not added correctly")
+ }
+}
+
+func TestMiddleware(t *testing.T) {
+ router := NewRouter()
+ router.HandleFunc("/", dummyHandler).Methods("GET")
+
+ mw := &testMiddleware{}
+ router.useInterface(mw)
+
+ rw := NewRecorder()
+ req := newRequest("GET", "/")
+
+ // Test regular middleware call
+ router.ServeHTTP(rw, req)
+ if mw.timesCalled != 1 {
+ t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
+ }
+
+ // Middleware should not be called for 404
+ req = newRequest("GET", "/not/found")
+ router.ServeHTTP(rw, req)
+ if mw.timesCalled != 1 {
+ t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
+ }
+
+ // Middleware should not be called if there is a method mismatch
+ req = newRequest("POST", "/")
+ router.ServeHTTP(rw, req)
+ if mw.timesCalled != 1 {
+ t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
+ }
+
+ // Add the middleware again as function
+ router.Use(mw.Middleware)
+ req = newRequest("GET", "/")
+ router.ServeHTTP(rw, req)
+ if mw.timesCalled != 3 {
+ t.Fatalf("Expected %d calls, but got only %d", 3, mw.timesCalled)
+ }
+
+}
+
+func TestMiddlewareSubrouter(t *testing.T) {
+ router := NewRouter()
+ router.HandleFunc("/", dummyHandler).Methods("GET")
+
+ subrouter := router.PathPrefix("/sub").Subrouter()
+ subrouter.HandleFunc("/x", dummyHandler).Methods("GET")
+
+ mw := &testMiddleware{}
+ subrouter.useInterface(mw)
+
+ rw := NewRecorder()
+ req := newRequest("GET", "/")
+
+ router.ServeHTTP(rw, req)
+ if mw.timesCalled != 0 {
+ t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
+ }
+
+ req = newRequest("GET", "/sub/")
+ router.ServeHTTP(rw, req)
+ if mw.timesCalled != 0 {
+ t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
+ }
+
+ req = newRequest("GET", "/sub/x")
+ router.ServeHTTP(rw, req)
+ if mw.timesCalled != 1 {
+ t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
+ }
+
+ req = newRequest("GET", "/sub/not/found")
+ router.ServeHTTP(rw, req)
+ if mw.timesCalled != 1 {
+ t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
+ }
+
+ router.useInterface(mw)
+
+ req = newRequest("GET", "/")
+ router.ServeHTTP(rw, req)
+ if mw.timesCalled != 2 {
+ t.Fatalf("Expected %d calls, but got only %d", 2, mw.timesCalled)
+ }
+
+ req = newRequest("GET", "/sub/x")
+ router.ServeHTTP(rw, req)
+ if mw.timesCalled != 4 {
+ t.Fatalf("Expected %d calls, but got only %d", 4, mw.timesCalled)
+ }
+}
+
+func TestMiddlewareExecution(t *testing.T) {
+ mwStr := []byte("Middleware\n")
+ handlerStr := []byte("Logic\n")
+
+ router := NewRouter()
+ router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
+ w.Write(handlerStr)
+ })
+
+ rw := NewRecorder()
+ req := newRequest("GET", "/")
+
+ // Test handler-only call
+ router.ServeHTTP(rw, req)
+
+ if bytes.Compare(rw.Body.Bytes(), handlerStr) != 0 {
+ t.Fatal("Handler response is not what it should be")
+ }
+
+ // Test middleware call
+ rw = NewRecorder()
+
+ router.Use(func(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write(mwStr)
+ h.ServeHTTP(w, r)
+ })
+ })
+
+ router.ServeHTTP(rw, req)
+ if bytes.Compare(rw.Body.Bytes(), append(mwStr, handlerStr...)) != 0 {
+ t.Fatal("Middleware + handler response is not what it should be")
+ }
+}
+
+func TestMiddlewareNotFound(t *testing.T) {
+ mwStr := []byte("Middleware\n")
+ handlerStr := []byte("Logic\n")
+
+ router := NewRouter()
+ router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
+ w.Write(handlerStr)
+ })
+ router.Use(func(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write(mwStr)
+ h.ServeHTTP(w, r)
+ })
+ })
+
+ // Test not found call with default handler
+ rw := NewRecorder()
+ req := newRequest("GET", "/notfound")
+
+ router.ServeHTTP(rw, req)
+ if bytes.Contains(rw.Body.Bytes(), mwStr) {
+ t.Fatal("Middleware was called for a 404")
+ }
+
+ // Test not found call with custom handler
+ rw = NewRecorder()
+ req = newRequest("GET", "/notfound")
+
+ router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+ rw.Write([]byte("Custom 404 handler"))
+ })
+ router.ServeHTTP(rw, req)
+
+ if bytes.Contains(rw.Body.Bytes(), mwStr) {
+ t.Fatal("Middleware was called for a custom 404")
+ }
+}
+
+func TestMiddlewareMethodMismatch(t *testing.T) {
+ mwStr := []byte("Middleware\n")
+ handlerStr := []byte("Logic\n")
+
+ router := NewRouter()
+ router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
+ w.Write(handlerStr)
+ }).Methods("GET")
+
+ router.Use(func(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write(mwStr)
+ h.ServeHTTP(w, r)
+ })
+ })
+
+ // Test method mismatch
+ rw := NewRecorder()
+ req := newRequest("POST", "/")
+
+ router.ServeHTTP(rw, req)
+ if bytes.Contains(rw.Body.Bytes(), mwStr) {
+ t.Fatal("Middleware was called for a method mismatch")
+ }
+
+ // Test not found call
+ rw = NewRecorder()
+ req = newRequest("POST", "/")
+
+ router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+ rw.Write([]byte("Method not allowed"))
+ })
+ router.ServeHTTP(rw, req)
+
+ if bytes.Contains(rw.Body.Bytes(), mwStr) {
+ t.Fatal("Middleware was called for a method mismatch")
+ }
+}
+
+func TestMiddlewareNotFoundSubrouter(t *testing.T) {
+ mwStr := []byte("Middleware\n")
+ handlerStr := []byte("Logic\n")
+
+ router := NewRouter()
+ router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
+ w.Write(handlerStr)
+ })
+
+ subrouter := router.PathPrefix("/sub/").Subrouter()
+ subrouter.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
+ w.Write(handlerStr)
+ })
+
+ router.Use(func(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write(mwStr)
+ h.ServeHTTP(w, r)
+ })
+ })
+
+ // Test not found call for default handler
+ rw := NewRecorder()
+ req := newRequest("GET", "/sub/notfound")
+
+ router.ServeHTTP(rw, req)
+ if bytes.Contains(rw.Body.Bytes(), mwStr) {
+ t.Fatal("Middleware was called for a 404")
+ }
+
+ // Test not found call with custom handler
+ rw = NewRecorder()
+ req = newRequest("GET", "/sub/notfound")
+
+ subrouter.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+ rw.Write([]byte("Custom 404 handler"))
+ })
+ router.ServeHTTP(rw, req)
+
+ if bytes.Contains(rw.Body.Bytes(), mwStr) {
+ t.Fatal("Middleware was called for a custom 404")
+ }
+}
+
+func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
+ mwStr := []byte("Middleware\n")
+ handlerStr := []byte("Logic\n")
+
+ router := NewRouter()
+ router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
+ w.Write(handlerStr)
+ })
+
+ subrouter := router.PathPrefix("/sub/").Subrouter()
+ subrouter.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
+ w.Write(handlerStr)
+ }).Methods("GET")
+
+ router.Use(func(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write(mwStr)
+ h.ServeHTTP(w, r)
+ })
+ })
+
+ // Test method mismatch without custom handler
+ rw := NewRecorder()
+ req := newRequest("POST", "/sub/")
+
+ router.ServeHTTP(rw, req)
+ if bytes.Contains(rw.Body.Bytes(), mwStr) {
+ t.Fatal("Middleware was called for a method mismatch")
+ }
+
+ // Test method mismatch with custom handler
+ rw = NewRecorder()
+ req = newRequest("POST", "/sub/")
+
+ router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+ rw.Write([]byte("Method not allowed"))
+ })
+ router.ServeHTTP(rw, req)
+
+ if bytes.Contains(rw.Body.Bytes(), mwStr) {
+ t.Fatal("Middleware was called for a method mismatch")
+ }
+}