From 9d08139f4363d3503398002bc82cb3746e3438cf Mon Sep 17 00:00:00 2001 From: Frédéric Guillot Date: Sun, 23 Sep 2018 21:02:26 -0700 Subject: Improve request package and add more unit tests --- http/request/client_ip.go | 38 ++++ http/request/client_ip_test.go | 82 ++++++++ http/request/context.go | 21 +- http/request/context_test.go | 436 +++++++++++++++++++++++++++++++++++++++++ http/request/cookie.go | 17 ++ http/request/cookie_test.go | 33 ++++ http/request/params.go | 84 ++++++++ http/request/params_test.go | 215 ++++++++++++++++++++ http/request/request.go | 128 ------------ http/request/request_test.go | 82 -------- 10 files changed, 923 insertions(+), 213 deletions(-) create mode 100644 http/request/client_ip.go create mode 100644 http/request/client_ip_test.go create mode 100644 http/request/context_test.go create mode 100644 http/request/cookie.go create mode 100644 http/request/cookie_test.go create mode 100644 http/request/params.go create mode 100644 http/request/params_test.go delete mode 100644 http/request/request.go delete mode 100644 http/request/request_test.go (limited to 'http') diff --git a/http/request/client_ip.go b/http/request/client_ip.go new file mode 100644 index 0000000..52fc05c --- /dev/null +++ b/http/request/client_ip.go @@ -0,0 +1,38 @@ +// Copyright 2018 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package request // import "miniflux.app/http/request" + +import ( + "net" + "net/http" + "strings" +) + +// FindClientIP returns client real IP address. +func FindClientIP(r *http.Request) string { + headers := []string{"X-Forwarded-For", "X-Real-Ip"} + for _, header := range headers { + value := r.Header.Get(header) + + if value != "" { + addresses := strings.Split(value, ",") + address := strings.TrimSpace(addresses[0]) + + if net.ParseIP(address) != nil { + return address + } + } + } + + // Fallback to TCP/IP source IP address. + var remoteIP string + if strings.ContainsRune(r.RemoteAddr, ':') { + remoteIP, _, _ = net.SplitHostPort(r.RemoteAddr) + } else { + remoteIP = r.RemoteAddr + } + + return remoteIP +} diff --git a/http/request/client_ip_test.go b/http/request/client_ip_test.go new file mode 100644 index 0000000..12d6e16 --- /dev/null +++ b/http/request/client_ip_test.go @@ -0,0 +1,82 @@ +// Copyright 2018 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package request // import "miniflux.app/http/request" + +import ( + "net/http" + "testing" +) + +func TestFindClientIPWithoutHeaders(t *testing.T) { + r := &http.Request{RemoteAddr: "192.168.0.1:4242"} + if ip := FindClientIP(r); ip != "192.168.0.1" { + t.Fatalf(`Unexpected result, got: %q`, ip) + } + + r = &http.Request{RemoteAddr: "192.168.0.1"} + if ip := FindClientIP(r); ip != "192.168.0.1" { + t.Fatalf(`Unexpected result, got: %q`, ip) + } +} + +func TestFindClientIPWithXFFHeader(t *testing.T) { + // Test with multiple IPv4 addresses. + headers := http.Header{} + headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178") + r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} + + if ip := FindClientIP(r); ip != "203.0.113.195" { + t.Fatalf(`Unexpected result, got: %q`, ip) + } + + // Test with single IPv6 address. + headers = http.Header{} + headers.Set("X-Forwarded-For", "2001:db8:85a3:8d3:1319:8a2e:370:7348") + r = &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} + + if ip := FindClientIP(r); ip != "2001:db8:85a3:8d3:1319:8a2e:370:7348" { + t.Fatalf(`Unexpected result, got: %q`, ip) + } + + // Test with single IPv4 address. + headers = http.Header{} + headers.Set("X-Forwarded-For", "70.41.3.18") + r = &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} + + if ip := FindClientIP(r); ip != "70.41.3.18" { + t.Fatalf(`Unexpected result, got: %q`, ip) + } + + // Test with invalid IP address. + headers = http.Header{} + headers.Set("X-Forwarded-For", "fake IP") + r = &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} + + if ip := FindClientIP(r); ip != "192.168.0.1" { + t.Fatalf(`Unexpected result, got: %q`, ip) + } +} + +func TestClientIPWithXRealIPHeader(t *testing.T) { + headers := http.Header{} + headers.Set("X-Real-Ip", "192.168.122.1") + r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} + + if ip := FindClientIP(r); ip != "192.168.122.1" { + t.Fatalf(`Unexpected result, got: %q`, ip) + } +} + +func TestClientIPWithBothHeaders(t *testing.T) { + headers := http.Header{} + headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178") + headers.Set("X-Real-Ip", "192.168.122.1") + + r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} + + if ip := FindClientIP(r); ip != "203.0.113.195" { + t.Fatalf(`Unexpected result, got: %q`, ip) + } +} diff --git a/http/request/context.go b/http/request/context.go index b77365d..51ee46a 100644 --- a/http/request/context.go +++ b/http/request/context.go @@ -111,7 +111,12 @@ func ClientIP(r *http.Request) string { func getContextStringValue(r *http.Request, key ContextKey) string { if v := r.Context().Value(key); v != nil { - return v.(string) + value, valid := v.(string) + if !valid { + return "" + } + + return value } return "" @@ -119,7 +124,12 @@ func getContextStringValue(r *http.Request, key ContextKey) string { func getContextBoolValue(r *http.Request, key ContextKey) bool { if v := r.Context().Value(key); v != nil { - return v.(bool) + value, valid := v.(bool) + if !valid { + return false + } + + return value } return false @@ -127,7 +137,12 @@ func getContextBoolValue(r *http.Request, key ContextKey) bool { func getContextInt64Value(r *http.Request, key ContextKey) int64 { if v := r.Context().Value(key); v != nil { - return v.(int64) + value, valid := v.(int64) + if !valid { + return 0 + } + + return value } return 0 diff --git a/http/request/context_test.go b/http/request/context_test.go new file mode 100644 index 0000000..e741362 --- /dev/null +++ b/http/request/context_test.go @@ -0,0 +1,436 @@ +// Copyright 2018 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package request // import "miniflux.app/http/request" + +import ( + "context" + "net/http" + "testing" +) + +func TestContextStringValue(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, ClientIPContextKey, "IP") + r = r.WithContext(ctx) + + result := getContextStringValue(r, ClientIPContextKey) + expected := "IP" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestContextStringValueWithInvalidType(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, ClientIPContextKey, 0) + r = r.WithContext(ctx) + + result := getContextStringValue(r, ClientIPContextKey) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestContextStringValueWhenUnset(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := getContextStringValue(r, ClientIPContextKey) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestContextBoolValue(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, IsAdminUserContextKey, true) + r = r.WithContext(ctx) + + result := getContextBoolValue(r, IsAdminUserContextKey) + expected := true + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestContextBoolValueWithInvalidType(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, IsAdminUserContextKey, "invalid") + r = r.WithContext(ctx) + + result := getContextBoolValue(r, IsAdminUserContextKey) + expected := false + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestContextBoolValueWhenUnset(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := getContextBoolValue(r, IsAdminUserContextKey) + expected := false + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestContextInt64Value(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, UserIDContextKey, int64(1234)) + r = r.WithContext(ctx) + + result := getContextInt64Value(r, UserIDContextKey) + expected := int64(1234) + + if result != expected { + t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected) + } +} + +func TestContextInt64ValueWithInvalidType(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, UserIDContextKey, "invalid") + r = r.WithContext(ctx) + + result := getContextInt64Value(r, UserIDContextKey) + expected := int64(0) + + if result != expected { + t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected) + } +} + +func TestContextInt64ValueWhenUnset(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := getContextInt64Value(r, UserIDContextKey) + expected := int64(0) + + if result != expected { + t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected) + } +} + +func TestIsAdmin(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := IsAdminUser(r) + expected := false + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, IsAdminUserContextKey, true) + r = r.WithContext(ctx) + + result = IsAdminUser(r) + expected = true + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestIsAuthenticated(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := IsAuthenticated(r) + expected := false + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, IsAuthenticatedContextKey, true) + r = r.WithContext(ctx) + + result = IsAuthenticated(r) + expected = true + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestUserID(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := UserID(r) + expected := int64(0) + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, UserIDContextKey, int64(123)) + r = r.WithContext(ctx) + + result = UserID(r) + expected = int64(123) + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestUserTimezone(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := UserTimezone(r) + expected := "UTC" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, UserTimezoneContextKey, "Europe/Paris") + r = r.WithContext(ctx) + + result = UserTimezone(r) + expected = "Europe/Paris" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestUserLanguage(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := UserLanguage(r) + expected := "en_US" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, UserLanguageContextKey, "fr_FR") + r = r.WithContext(ctx) + + result = UserLanguage(r) + expected = "fr_FR" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestUserTheme(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := UserTheme(r) + expected := "default" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, UserThemeContextKey, "black") + r = r.WithContext(ctx) + + result = UserTheme(r) + expected = "black" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestCSRF(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := CSRF(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, CSRFContextKey, "secret") + r = r.WithContext(ctx) + + result = CSRF(r) + expected = "secret" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestSessionID(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := SessionID(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, SessionIDContextKey, "id") + r = r.WithContext(ctx) + + result = SessionID(r) + expected = "id" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestUserSessionToken(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := UserSessionToken(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, UserSessionTokenContextKey, "token") + r = r.WithContext(ctx) + + result = UserSessionToken(r) + expected = "token" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestOAuth2State(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := OAuth2State(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, OAuth2StateContextKey, "state") + r = r.WithContext(ctx) + + result = OAuth2State(r) + expected = "state" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestFlashMessage(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := FlashMessage(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, FlashMessageContextKey, "message") + r = r.WithContext(ctx) + + result = FlashMessage(r) + expected = "message" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestFlashErrorMessage(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := FlashErrorMessage(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, FlashErrorMessageContextKey, "error message") + r = r.WithContext(ctx) + + result = FlashErrorMessage(r) + expected = "error message" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestPocketRequestToken(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := PocketRequestToken(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, PocketRequestTokenContextKey, "request token") + r = r.WithContext(ctx) + + result = PocketRequestToken(r) + expected = "request token" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestClientIP(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := ClientIP(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, ClientIPContextKey, "127.0.0.1") + r = r.WithContext(ctx) + + result = ClientIP(r) + expected = "127.0.0.1" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} diff --git a/http/request/cookie.go b/http/request/cookie.go new file mode 100644 index 0000000..88cc626 --- /dev/null +++ b/http/request/cookie.go @@ -0,0 +1,17 @@ +// Copyright 2018 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package request // import "miniflux.app/http/request" + +import "net/http" + +// CookieValue returns the cookie value. +func CookieValue(r *http.Request, name string) string { + cookie, err := r.Cookie(name) + if err == http.ErrNoCookie { + return "" + } + + return cookie.Value +} diff --git a/http/request/cookie_test.go b/http/request/cookie_test.go new file mode 100644 index 0000000..9c3b54d --- /dev/null +++ b/http/request/cookie_test.go @@ -0,0 +1,33 @@ +// Copyright 2018 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package request // import "miniflux.app/http/request" + +import ( + "net/http" + "testing" +) + +func TestGetCookieValue(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + r.AddCookie(&http.Cookie{Value: "cookie_value", Name: "my_cookie"}) + + result := CookieValue(r, "my_cookie") + expected := "cookie_value" + + if result != expected { + t.Errorf(`Unexpected cookie value, got %q instead of %q`, result, expected) + } +} + +func TestGetCookieValueWhenUnset(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := CookieValue(r, "my_cookie") + expected := "" + + if result != expected { + t.Errorf(`Unexpected cookie value, got %q instead of %q`, result, expected) + } +} diff --git a/http/request/params.go b/http/request/params.go new file mode 100644 index 0000000..8218a0d --- /dev/null +++ b/http/request/params.go @@ -0,0 +1,84 @@ +// Copyright 2018 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package request // import "miniflux.app/http/request" + +import ( + "net/http" + "strconv" + + "github.com/gorilla/mux" +) + +// FormInt64Value returns a form value as integer. +func FormInt64Value(r *http.Request, param string) int64 { + value := r.FormValue(param) + integer, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return 0 + } + + return integer +} + +// RouteInt64Param returns an URL route parameter as int64. +func RouteInt64Param(r *http.Request, param string) int64 { + vars := mux.Vars(r) + value, err := strconv.ParseInt(vars[param], 10, 64) + if err != nil { + return 0 + } + + if value < 0 { + return 0 + } + + return value +} + +// RouteStringParam returns a URL route parameter as string. +func RouteStringParam(r *http.Request, param string) string { + vars := mux.Vars(r) + return vars[param] +} + +// QueryStringParam returns a query string parameter as string. +func QueryStringParam(r *http.Request, param, defaultValue string) string { + value := r.URL.Query().Get(param) + if value == "" { + value = defaultValue + } + return value +} + +// QueryIntParam returns a query string parameter as integer. +func QueryIntParam(r *http.Request, param string, defaultValue int) int { + return int(QueryInt64Param(r, param, int64(defaultValue))) +} + +// QueryInt64Param returns a query string parameter as int64. +func QueryInt64Param(r *http.Request, param string, defaultValue int64) int64 { + value := r.URL.Query().Get(param) + if value == "" { + return defaultValue + } + + val, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return defaultValue + } + + if val < 0 { + return defaultValue + } + + return val +} + +// HasQueryParam checks if the query string contains the given parameter. +func HasQueryParam(r *http.Request, param string) bool { + values := r.URL.Query() + _, ok := values[param] + return ok +} diff --git a/http/request/params_test.go b/http/request/params_test.go new file mode 100644 index 0000000..7f1f880 --- /dev/null +++ b/http/request/params_test.go @@ -0,0 +1,215 @@ +// Copyright 2018 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package request // import "miniflux.app/http/request" + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gorilla/mux" +) + +func TestFormInt64Value(t *testing.T) { + f := url.Values{} + f.Set("integer value", "42") + f.Set("invalid value", "invalid integer") + + r := &http.Request{Form: f} + + result := FormInt64Value(r, "integer value") + expected := int64(42) + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = FormInt64Value(r, "invalid value") + expected = int64(0) + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = FormInt64Value(r, "missing value") + expected = int64(0) + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } +} + +func TestRouteStringParam(t *testing.T) { + router := mux.NewRouter() + router.HandleFunc("/route/{variable}/index", func(w http.ResponseWriter, r *http.Request) { + result := RouteStringParam(r, "variable") + expected := "value" + + if result != expected { + t.Errorf(`Unexpected result, got %q instead of %q`, result, expected) + } + + result = RouteStringParam(r, "missing variable") + expected = "" + + if result != expected { + t.Errorf(`Unexpected result, got %q instead of %q`, result, expected) + } + }) + + r, err := http.NewRequest("GET", "/route/value/index", nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + router.ServeHTTP(w, r) +} + +func TestRouteInt64Param(t *testing.T) { + router := mux.NewRouter() + router.HandleFunc("/a/{variable1}/b/{variable2}/c/{variable3}", func(w http.ResponseWriter, r *http.Request) { + result := RouteInt64Param(r, "variable1") + expected := int64(42) + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = RouteInt64Param(r, "missing variable") + expected = 0 + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = RouteInt64Param(r, "variable2") + expected = 0 + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = RouteInt64Param(r, "variable3") + expected = 0 + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + }) + + r, err := http.NewRequest("GET", "/a/42/b/not-int/c/-10", nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + router.ServeHTTP(w, r) +} + +func TestQueryStringParam(t *testing.T) { + u, _ := url.Parse("http://example.org/?key=value") + r := &http.Request{URL: u} + + result := QueryStringParam(r, "key", "fallback") + expected := "value" + + if result != expected { + t.Errorf(`Unexpected result, got %q instead of %q`, result, expected) + } + + result = QueryStringParam(r, "missing key", "fallback") + expected = "fallback" + + if result != expected { + t.Errorf(`Unexpected result, got %q instead of %q`, result, expected) + } +} + +func TestQueryIntParam(t *testing.T) { + u, _ := url.Parse("http://example.org/?key=42&invalid=value&negative=-5") + r := &http.Request{URL: u} + + result := QueryIntParam(r, "key", 84) + expected := 42 + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = QueryIntParam(r, "missing key", 84) + expected = 84 + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = QueryIntParam(r, "negative", 69) + expected = 69 + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = QueryIntParam(r, "invalid", 99) + expected = 99 + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } +} + +func TestQueryInt64Param(t *testing.T) { + u, _ := url.Parse("http://example.org/?key=42&invalid=value&negative=-5") + r := &http.Request{URL: u} + + result := QueryInt64Param(r, "key", int64(84)) + expected := int64(42) + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = QueryInt64Param(r, "missing key", int64(84)) + expected = int64(84) + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = QueryInt64Param(r, "invalid", int64(69)) + expected = int64(69) + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = QueryInt64Param(r, "invalid", int64(99)) + expected = int64(99) + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } +} + +func TestHasQueryParam(t *testing.T) { + u, _ := url.Parse("http://example.org/?key=42") + r := &http.Request{URL: u} + + result := HasQueryParam(r, "key") + expected := true + + if result != expected { + t.Errorf(`Unexpected result, got %v instead of %v`, result, expected) + } + + result = HasQueryParam(r, "missing key") + expected = false + + if result != expected { + t.Errorf(`Unexpected result, got %v instead of %v`, result, expected) + } +} diff --git a/http/request/request.go b/http/request/request.go deleted file mode 100644 index d27137b..0000000 --- a/http/request/request.go +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2018 Frédéric Guillot. All rights reserved. -// Use of this source code is governed by the Apache 2.0 -// license that can be found in the LICENSE file. - -package request // import "miniflux.app/http/request" - -import ( - "fmt" - "net" - "net/http" - "strconv" - "strings" - - "github.com/gorilla/mux" -) - -// Cookie returns the cookie value. -func Cookie(r *http.Request, name string) string { - cookie, err := r.Cookie(name) - if err == http.ErrNoCookie { - return "" - } - - return cookie.Value -} - -// FormInt64Value returns a form value as integer. -func FormInt64Value(r *http.Request, param string) int64 { - value := r.FormValue(param) - integer, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return 0 - } - - return integer -} - -// IntParam returns an URL route parameter as integer. -func IntParam(r *http.Request, param string) (int64, error) { - vars := mux.Vars(r) - value, err := strconv.Atoi(vars[param]) - if err != nil { - return 0, fmt.Errorf("request: %s parameter is not an integer", param) - } - - if value < 0 { - return 0, nil - } - - return int64(value), nil -} - -// Param returns an URL route parameter as string. -func Param(r *http.Request, param, defaultValue string) string { - vars := mux.Vars(r) - value := vars[param] - if value == "" { - value = defaultValue - } - return value -} - -// QueryParam returns a querystring parameter as string. -func QueryParam(r *http.Request, param, defaultValue string) string { - value := r.URL.Query().Get(param) - if value == "" { - value = defaultValue - } - return value -} - -// QueryIntParam returns a querystring parameter as integer. -func QueryIntParam(r *http.Request, param string, defaultValue int) int { - return int(QueryInt64Param(r, param, int64(defaultValue))) -} - -// QueryInt64Param returns a querystring parameter as int64. -func QueryInt64Param(r *http.Request, param string, defaultValue int64) int64 { - value := r.URL.Query().Get(param) - if value == "" { - return defaultValue - } - - val, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return defaultValue - } - - if val < 0 { - return defaultValue - } - - return val -} - -// HasQueryParam checks if the query string contains the given parameter. -func HasQueryParam(r *http.Request, param string) bool { - values := r.URL.Query() - _, ok := values[param] - return ok -} - -// FindClientIP returns client's real IP address. -func FindClientIP(r *http.Request) string { - headers := []string{"X-Forwarded-For", "X-Real-Ip"} - for _, header := range headers { - value := r.Header.Get(header) - - if value != "" { - addresses := strings.Split(value, ",") - address := strings.TrimSpace(addresses[0]) - - if net.ParseIP(address) != nil { - return address - } - } - } - - // Fallback to TCP/IP source IP address. - var remoteIP string - if strings.ContainsRune(r.RemoteAddr, ':') { - remoteIP, _, _ = net.SplitHostPort(r.RemoteAddr) - } else { - remoteIP = r.RemoteAddr - } - - return remoteIP -} diff --git a/http/request/request_test.go b/http/request/request_test.go deleted file mode 100644 index 946b132..0000000 --- a/http/request/request_test.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2018 Frédéric Guillot. All rights reserved. -// Use of this source code is governed by the Apache 2.0 -// license that can be found in the LICENSE file. - -package request // import "miniflux.app/http/request" - -import ( - "net/http" - "testing" -) - -func TestRealIPWithoutHeaders(t *testing.T) { - r := &http.Request{RemoteAddr: "192.168.0.1:4242"} - if ip := FindClientIP(r); ip != "192.168.0.1" { - t.Fatalf(`Unexpected result, got: %q`, ip) - } - - r = &http.Request{RemoteAddr: "192.168.0.1"} - if ip := FindClientIP(r); ip != "192.168.0.1" { - t.Fatalf(`Unexpected result, got: %q`, ip) - } -} - -func TestRealIPWithXFFHeader(t *testing.T) { - // Test with multiple IPv4 addresses. - headers := http.Header{} - headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178") - r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} - - if ip := FindClientIP(r); ip != "203.0.113.195" { - t.Fatalf(`Unexpected result, got: %q`, ip) - } - - // Test with single IPv6 address. - headers = http.Header{} - headers.Set("X-Forwarded-For", "2001:db8:85a3:8d3:1319:8a2e:370:7348") - r = &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} - - if ip := FindClientIP(r); ip != "2001:db8:85a3:8d3:1319:8a2e:370:7348" { - t.Fatalf(`Unexpected result, got: %q`, ip) - } - - // Test with single IPv4 address. - headers = http.Header{} - headers.Set("X-Forwarded-For", "70.41.3.18") - r = &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} - - if ip := FindClientIP(r); ip != "70.41.3.18" { - t.Fatalf(`Unexpected result, got: %q`, ip) - } - - // Test with invalid IP address. - headers = http.Header{} - headers.Set("X-Forwarded-For", "fake IP") - r = &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} - - if ip := FindClientIP(r); ip != "192.168.0.1" { - t.Fatalf(`Unexpected result, got: %q`, ip) - } -} - -func TestRealIPWithXRealIPHeader(t *testing.T) { - headers := http.Header{} - headers.Set("X-Real-Ip", "192.168.122.1") - r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} - - if ip := FindClientIP(r); ip != "192.168.122.1" { - t.Fatalf(`Unexpected result, got: %q`, ip) - } -} - -func TestRealIPWithBothHeaders(t *testing.T) { - headers := http.Header{} - headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178") - headers.Set("X-Real-Ip", "192.168.122.1") - - r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} - - if ip := FindClientIP(r); ip != "203.0.113.195" { - t.Fatalf(`Unexpected result, got: %q`, ip) - } -} -- cgit v1.2.3