package mux

import (
	"bytes"
	"net/http"
	"testing"
)

// testMiddleware counts how many times its Middleware wrapper has been
// invoked. The tests below register it both as a value with a Middleware
// method (via useInterface) and as a plain MiddlewareFunc (via Use).
type testMiddleware struct {
	timesCalled uint
}

// Middleware wraps h so that every request passing through increments
// tm.timesCalled before being forwarded to h.
func (tm *testMiddleware) Middleware(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		tm.timesCalled++
		h.ServeHTTP(w, r)
	})
}

// dummyHandler is a no-op handler for tests where only routing matters.
func dummyHandler(w http.ResponseWriter, r *http.Request) {}

// TestMiddlewareAdd checks that useInterface and Use each append exactly one
// entry to router.middlewares.
func TestMiddlewareAdd(t *testing.T) {
	router := NewRouter()
	router.HandleFunc("/", dummyHandler).Methods("GET")

	mw := &testMiddleware{}

	router.useInterface(mw)
	if len(router.middlewares) != 1 || router.middlewares[0] != mw {
		t.Fatal("Middleware was not added correctly")
	}

	router.Use(mw.Middleware)
	if len(router.middlewares) != 2 {
		t.Fatal("MiddlewareFunc method was not added correctly")
	}

	banalMw := func(handler http.Handler) http.Handler {
		return handler
	}
	router.Use(banalMw)
	if len(router.middlewares) != 3 {
		t.Fatal("MiddlewareFunc method was not added correctly")
	}
}

// TestMiddleware checks that router-level middleware runs only for requests
// that match a route (not for a 404 or a method mismatch), and that
// registering the same middleware a second time makes it run twice per
// matched request.
func TestMiddleware(t *testing.T) {
	router := NewRouter()
	router.HandleFunc("/", dummyHandler).Methods("GET")

	mw := &testMiddleware{}
	router.useInterface(mw)

	// expectCalls fails the test if the middleware counter differs from want.
	expectCalls := func(want uint) {
		t.Helper()
		if mw.timesCalled != want {
			t.Fatalf("Expected %d calls, but got %d", want, mw.timesCalled)
		}
	}

	rw := NewRecorder()

	// Test regular middleware call
	router.ServeHTTP(rw, newRequest("GET", "/"))
	expectCalls(1)

	// Middleware should not be called for 404
	router.ServeHTTP(rw, newRequest("GET", "/not/found"))
	expectCalls(1)

	// Middleware should not be called if there is a method mismatch
	router.ServeHTTP(rw, newRequest("POST", "/"))
	expectCalls(1)

	// Add the middleware again as function; a matched request now passes
	// through both registrations, so the counter advances by two.
	router.Use(mw.Middleware)
	router.ServeHTTP(rw, newRequest("GET", "/"))
	expectCalls(3)
}

// TestMiddlewareSubrouter checks that middleware attached to a subrouter runs
// only for routes the subrouter matches. It also checks that, once the same
// middleware is attached to the parent router too, subrouter routes pass
// through both.
func TestMiddlewareSubrouter(t *testing.T) {
	router := NewRouter()
	router.HandleFunc("/", dummyHandler).Methods("GET")

	subrouter := router.PathPrefix("/sub").Subrouter()
	subrouter.HandleFunc("/x", dummyHandler).Methods("GET")

	mw := &testMiddleware{}
	subrouter.useInterface(mw)

	// expectCalls fails the test if the middleware counter differs from want.
	expectCalls := func(want uint) {
		t.Helper()
		if mw.timesCalled != want {
			t.Fatalf("Expected %d calls, but got %d", want, mw.timesCalled)
		}
	}

	rw := NewRecorder()

	// A parent-router route does not go through subrouter middleware.
	router.ServeHTTP(rw, newRequest("GET", "/"))
	expectCalls(0)

	// Matches the "/sub" prefix but no subrouter route.
	router.ServeHTTP(rw, newRequest("GET", "/sub/"))
	expectCalls(0)

	router.ServeHTTP(rw, newRequest("GET", "/sub/x"))
	expectCalls(1)

	router.ServeHTTP(rw, newRequest("GET", "/sub/not/found"))
	expectCalls(1)

	// Attach the same middleware to the parent router as well.
	router.useInterface(mw)

	router.ServeHTTP(rw, newRequest("GET", "/"))
	expectCalls(2)

	// "/sub/x" now runs through both the parent and the subrouter middleware.
	router.ServeHTTP(rw, newRequest("GET", "/sub/x"))
	expectCalls(4)
}

// TestMiddlewareExecution checks that a route's handler output is written
// unchanged when no middleware is registered. It also checks that, once a
// middleware is added, the middleware's output precedes the handler's.
func TestMiddlewareExecution(t *testing.T) {
	mwStr := []byte("Middleware\n")
	handlerStr := []byte("Logic\n")

	router := NewRouter()
	router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write(handlerStr)
	})

	rw := NewRecorder()
	req := newRequest("GET", "/")

	// Test handler-only call
	router.ServeHTTP(rw, req)

	if !bytes.Equal(rw.Body.Bytes(), handlerStr) {
		t.Fatal("Handler response is not what it should be")
	}

	// Test middleware call
	rw = NewRecorder()

	router.Use(func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Write(mwStr)
			h.ServeHTTP(w, r)
		})
	})

	router.ServeHTTP(rw, req)
	if !bytes.Equal(rw.Body.Bytes(), append(mwStr, handlerStr...)) {
		t.Fatal("Middleware + handler response is not what it should be")
	}
}

// TestMiddlewareNotFound checks that router middleware does not run for an
// unmatched path. It covers both the default and a custom NotFoundHandler.
func TestMiddlewareNotFound(t *testing.T) {
	mwStr := []byte("Middleware\n")
	handlerStr := []byte("Logic\n")

	router := NewRouter()
	router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write(handlerStr)
	})
	router.Use(func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Write(mwStr)
			h.ServeHTTP(w, r)
		})
	})

	// Test not found call with default handler
	rw := NewRecorder()
	req := newRequest("GET", "/notfound")

	router.ServeHTTP(rw, req)
	if bytes.Contains(rw.Body.Bytes(), mwStr) {
		t.Fatal("Middleware was called for a 404")
	}

	// Test not found call with custom handler
	rw = NewRecorder()
	req = newRequest("GET", "/notfound")

	router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		rw.Write([]byte("Custom 404 handler"))
	})
	router.ServeHTTP(rw, req)

	if bytes.Contains(rw.Body.Bytes(), mwStr) {
		t.Fatal("Middleware was called for a custom 404")
	}
}

// TestMiddlewareMethodMismatch checks that router middleware does not run when
// the path matches but the method does not. It covers both the default and a
// custom MethodNotAllowedHandler.
func TestMiddlewareMethodMismatch(t *testing.T) {
	mwStr := []byte("Middleware\n")
	handlerStr := []byte("Logic\n")

	router := NewRouter()
	router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write(handlerStr)
	}).Methods("GET")

	router.Use(func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Write(mwStr)
			h.ServeHTTP(w, r)
		})
	})

	// Test method mismatch with default handler
	rw := NewRecorder()
	req := newRequest("POST", "/")

	router.ServeHTTP(rw, req)
	if bytes.Contains(rw.Body.Bytes(), mwStr) {
		t.Fatal("Middleware was called for a method mismatch")
	}

	// Test method mismatch with custom handler
	rw = NewRecorder()
	req = newRequest("POST", "/")

	router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		rw.Write([]byte("Method not allowed"))
	})
	router.ServeHTTP(rw, req)

	if bytes.Contains(rw.Body.Bytes(), mwStr) {
		t.Fatal("Middleware was called for a method mismatch")
	}
}

// TestMiddlewareNotFoundSubrouter checks that parent-router middleware does
// not run for a path under the subrouter prefix that no subrouter route
// matches. It covers both the default and a custom subrouter NotFoundHandler.
func TestMiddlewareNotFoundSubrouter(t *testing.T) {
	mwStr := []byte("Middleware\n")
	handlerStr := []byte("Logic\n")

	router := NewRouter()
	router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write(handlerStr)
	})

	subrouter := router.PathPrefix("/sub/").Subrouter()
	subrouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write(handlerStr)
	})

	router.Use(func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Write(mwStr)
			h.ServeHTTP(w, r)
		})
	})

	// Test not found call for default handler
	rw := NewRecorder()
	req := newRequest("GET", "/sub/notfound")

	router.ServeHTTP(rw, req)
	if bytes.Contains(rw.Body.Bytes(), mwStr) {
		t.Fatal("Middleware was called for a 404")
	}

	// Test not found call with custom handler
	rw = NewRecorder()
	req = newRequest("GET", "/sub/notfound")

	subrouter.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		rw.Write([]byte("Custom 404 handler"))
	})
	router.ServeHTTP(rw, req)

	if bytes.Contains(rw.Body.Bytes(), mwStr) {
		t.Fatal("Middleware was called for a custom 404")
	}
}

// TestMiddlewareMethodMismatchSubrouter checks that parent-router middleware
// does not run when a subrouter path matches but its method does not. The
// second case sets a custom MethodNotAllowedHandler on the parent router.
func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
	mwStr := []byte("Middleware\n")
	handlerStr := []byte("Logic\n")

	router := NewRouter()
	router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write(handlerStr)
	})

	subrouter := router.PathPrefix("/sub/").Subrouter()
	subrouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write(handlerStr)
	}).Methods("GET")

	router.Use(func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Write(mwStr)
			h.ServeHTTP(w, r)
		})
	})

	// Test method mismatch without custom handler
	rw := NewRecorder()
	req := newRequest("POST", "/sub/")

	router.ServeHTTP(rw, req)
	if bytes.Contains(rw.Body.Bytes(), mwStr) {
		t.Fatal("Middleware was called for a method mismatch")
	}

	// Test method mismatch with custom handler
	rw = NewRecorder()
	req = newRequest("POST", "/sub/")

	router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		rw.Write([]byte("Method not allowed"))
	})
	router.ServeHTTP(rw, req)

	if bytes.Contains(rw.Body.Bytes(), mwStr) {
		t.Fatal("Middleware was called for a method mismatch")
	}
}