aboutsummaryrefslogtreecommitdiffhomepage
path: root/http
diff options
context:
space:
mode:
authorGravatar Frédéric Guillot <fred@miniflux.net>2018-09-23 21:02:26 -0700
committerGravatar Frédéric Guillot <fred@miniflux.net>2018-09-23 21:02:26 -0700
commit9d08139f4363d3503398002bc82cb3746e3438cf (patch)
treeb09c1efb14445624e4a2771bf8cbcf3a9828ecf9 /http
parent844680e57328034c20a2d2b84bd315b55fee9e23 (diff)
Improve request package and add more unit tests
Diffstat (limited to 'http')
-rw-r--r--http/request/client_ip.go38
-rw-r--r--http/request/client_ip_test.go (renamed from http/request/request_test.go)8
-rw-r--r--http/request/context.go21
-rw-r--r--http/request/context_test.go436
-rw-r--r--http/request/cookie.go17
-rw-r--r--http/request/cookie_test.go33
-rw-r--r--http/request/params.go84
-rw-r--r--http/request/params_test.go215
-rw-r--r--http/request/request.go128
9 files changed, 845 insertions, 135 deletions
diff --git a/http/request/client_ip.go b/http/request/client_ip.go
new file mode 100644
index 0000000..52fc05c
--- /dev/null
+++ b/http/request/client_ip.go
@@ -0,0 +1,38 @@
+// Copyright 2018 Frédéric Guillot. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package request // import "miniflux.app/http/request"
+
+import (
+ "net"
+ "net/http"
+ "strings"
+)
+
+// FindClientIP returns client real IP address.
+func FindClientIP(r *http.Request) string {
+ headers := []string{"X-Forwarded-For", "X-Real-Ip"}
+ for _, header := range headers {
+ value := r.Header.Get(header)
+
+ if value != "" {
+ addresses := strings.Split(value, ",")
+ address := strings.TrimSpace(addresses[0])
+
+ if net.ParseIP(address) != nil {
+ return address
+ }
+ }
+ }
+
+ // Fallback to TCP/IP source IP address.
+ var remoteIP string
+ if strings.ContainsRune(r.RemoteAddr, ':') {
+ remoteIP, _, _ = net.SplitHostPort(r.RemoteAddr)
+ } else {
+ remoteIP = r.RemoteAddr
+ }
+
+ return remoteIP
+}
diff --git a/http/request/request_test.go b/http/request/client_ip_test.go
index 946b132..12d6e16 100644
--- a/http/request/request_test.go
+++ b/http/request/client_ip_test.go
@@ -9,7 +9,7 @@ import (
"testing"
)
-func TestRealIPWithoutHeaders(t *testing.T) {
+func TestFindClientIPWithoutHeaders(t *testing.T) {
r := &http.Request{RemoteAddr: "192.168.0.1:4242"}
if ip := FindClientIP(r); ip != "192.168.0.1" {
t.Fatalf(`Unexpected result, got: %q`, ip)
@@ -21,7 +21,7 @@ func TestRealIPWithoutHeaders(t *testing.T) {
}
}
-func TestRealIPWithXFFHeader(t *testing.T) {
+func TestFindClientIPWithXFFHeader(t *testing.T) {
// Test with multiple IPv4 addresses.
headers := http.Header{}
headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178")
@@ -59,7 +59,7 @@ func TestRealIPWithXFFHeader(t *testing.T) {
}
}
-func TestRealIPWithXRealIPHeader(t *testing.T) {
+func TestClientIPWithXRealIPHeader(t *testing.T) {
headers := http.Header{}
headers.Set("X-Real-Ip", "192.168.122.1")
r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers}
@@ -69,7 +69,7 @@ func TestRealIPWithXRealIPHeader(t *testing.T) {
}
}
-func TestRealIPWithBothHeaders(t *testing.T) {
+func TestClientIPWithBothHeaders(t *testing.T) {
headers := http.Header{}
headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178")
headers.Set("X-Real-Ip", "192.168.122.1")
diff --git a/http/request/context.go b/http/request/context.go
index b77365d..51ee46a 100644
--- a/http/request/context.go
+++ b/http/request/context.go
@@ -111,7 +111,12 @@ func ClientIP(r *http.Request) string {
func getContextStringValue(r *http.Request, key ContextKey) string {
if v := r.Context().Value(key); v != nil {
- return v.(string)
+ value, valid := v.(string)
+ if !valid {
+ return ""
+ }
+
+ return value
}
return ""
@@ -119,7 +124,12 @@ func getContextStringValue(r *http.Request, key ContextKey) string {
func getContextBoolValue(r *http.Request, key ContextKey) bool {
if v := r.Context().Value(key); v != nil {
- return v.(bool)
+ value, valid := v.(bool)
+ if !valid {
+ return false
+ }
+
+ return value
}
return false
@@ -127,7 +137,12 @@ func getContextBoolValue(r *http.Request, key ContextKey) bool {
func getContextInt64Value(r *http.Request, key ContextKey) int64 {
if v := r.Context().Value(key); v != nil {
- return v.(int64)
+ value, valid := v.(int64)
+ if !valid {
+ return 0
+ }
+
+ return value
}
return 0
diff --git a/http/request/context_test.go b/http/request/context_test.go
new file mode 100644
index 0000000..e741362
--- /dev/null
+++ b/http/request/context_test.go
@@ -0,0 +1,436 @@
+// Copyright 2018 Frédéric Guillot. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package request // import "miniflux.app/http/request"
+
+import (
+ "context"
+ "net/http"
+ "testing"
+)
+
+func TestContextStringValue(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, ClientIPContextKey, "IP")
+ r = r.WithContext(ctx)
+
+ result := getContextStringValue(r, ClientIPContextKey)
+ expected := "IP"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestContextStringValueWithInvalidType(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, ClientIPContextKey, 0)
+ r = r.WithContext(ctx)
+
+ result := getContextStringValue(r, ClientIPContextKey)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestContextStringValueWhenUnset(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := getContextStringValue(r, ClientIPContextKey)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestContextBoolValue(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, IsAdminUserContextKey, true)
+ r = r.WithContext(ctx)
+
+ result := getContextBoolValue(r, IsAdminUserContextKey)
+ expected := true
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+}
+
+func TestContextBoolValueWithInvalidType(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, IsAdminUserContextKey, "invalid")
+ r = r.WithContext(ctx)
+
+ result := getContextBoolValue(r, IsAdminUserContextKey)
+ expected := false
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+}
+
+func TestContextBoolValueWhenUnset(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := getContextBoolValue(r, IsAdminUserContextKey)
+ expected := false
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+}
+
+func TestContextInt64Value(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, UserIDContextKey, int64(1234))
+ r = r.WithContext(ctx)
+
+ result := getContextInt64Value(r, UserIDContextKey)
+ expected := int64(1234)
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
+ }
+}
+
+func TestContextInt64ValueWithInvalidType(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, UserIDContextKey, "invalid")
+ r = r.WithContext(ctx)
+
+ result := getContextInt64Value(r, UserIDContextKey)
+ expected := int64(0)
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
+ }
+}
+
+func TestContextInt64ValueWhenUnset(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := getContextInt64Value(r, UserIDContextKey)
+ expected := int64(0)
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
+ }
+}
+
+func TestIsAdmin(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := IsAdminUser(r)
+ expected := false
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, IsAdminUserContextKey, true)
+ r = r.WithContext(ctx)
+
+ result = IsAdminUser(r)
+ expected = true
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+}
+
+func TestIsAuthenticated(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := IsAuthenticated(r)
+ expected := false
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, IsAuthenticatedContextKey, true)
+ r = r.WithContext(ctx)
+
+ result = IsAuthenticated(r)
+ expected = true
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+}
+
+func TestUserID(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := UserID(r)
+ expected := int64(0)
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, UserIDContextKey, int64(123))
+ r = r.WithContext(ctx)
+
+ result = UserID(r)
+ expected = int64(123)
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+}
+
+func TestUserTimezone(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := UserTimezone(r)
+ expected := "UTC"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, UserTimezoneContextKey, "Europe/Paris")
+ r = r.WithContext(ctx)
+
+ result = UserTimezone(r)
+ expected = "Europe/Paris"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestUserLanguage(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := UserLanguage(r)
+ expected := "en_US"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, UserLanguageContextKey, "fr_FR")
+ r = r.WithContext(ctx)
+
+ result = UserLanguage(r)
+ expected = "fr_FR"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestUserTheme(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := UserTheme(r)
+ expected := "default"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, UserThemeContextKey, "black")
+ r = r.WithContext(ctx)
+
+ result = UserTheme(r)
+ expected = "black"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestCSRF(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := CSRF(r)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, CSRFContextKey, "secret")
+ r = r.WithContext(ctx)
+
+ result = CSRF(r)
+ expected = "secret"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestSessionID(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := SessionID(r)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, SessionIDContextKey, "id")
+ r = r.WithContext(ctx)
+
+ result = SessionID(r)
+ expected = "id"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestUserSessionToken(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := UserSessionToken(r)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, UserSessionTokenContextKey, "token")
+ r = r.WithContext(ctx)
+
+ result = UserSessionToken(r)
+ expected = "token"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestOAuth2State(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := OAuth2State(r)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, OAuth2StateContextKey, "state")
+ r = r.WithContext(ctx)
+
+ result = OAuth2State(r)
+ expected = "state"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestFlashMessage(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := FlashMessage(r)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, FlashMessageContextKey, "message")
+ r = r.WithContext(ctx)
+
+ result = FlashMessage(r)
+ expected = "message"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestFlashErrorMessage(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := FlashErrorMessage(r)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, FlashErrorMessageContextKey, "error message")
+ r = r.WithContext(ctx)
+
+ result = FlashErrorMessage(r)
+ expected = "error message"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestPocketRequestToken(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := PocketRequestToken(r)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, PocketRequestTokenContextKey, "request token")
+ r = r.WithContext(ctx)
+
+ result = PocketRequestToken(r)
+ expected = "request token"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestClientIP(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := ClientIP(r)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, ClientIPContextKey, "127.0.0.1")
+ r = r.WithContext(ctx)
+
+ result = ClientIP(r)
+ expected = "127.0.0.1"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
diff --git a/http/request/cookie.go b/http/request/cookie.go
new file mode 100644
index 0000000..88cc626
--- /dev/null
+++ b/http/request/cookie.go
@@ -0,0 +1,17 @@
+// Copyright 2018 Frédéric Guillot. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package request // import "miniflux.app/http/request"
+
+import "net/http"
+
+// CookieValue returns the cookie value.
+func CookieValue(r *http.Request, name string) string {
+ cookie, err := r.Cookie(name)
+ if err == http.ErrNoCookie {
+ return ""
+ }
+
+ return cookie.Value
+}
diff --git a/http/request/cookie_test.go b/http/request/cookie_test.go
new file mode 100644
index 0000000..9c3b54d
--- /dev/null
+++ b/http/request/cookie_test.go
@@ -0,0 +1,33 @@
+// Copyright 2018 Frédéric Guillot. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package request // import "miniflux.app/http/request"
+
+import (
+ "net/http"
+ "testing"
+)
+
+func TestGetCookieValue(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+ r.AddCookie(&http.Cookie{Value: "cookie_value", Name: "my_cookie"})
+
+ result := CookieValue(r, "my_cookie")
+ expected := "cookie_value"
+
+ if result != expected {
+ t.Errorf(`Unexpected cookie value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestGetCookieValueWhenUnset(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := CookieValue(r, "my_cookie")
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected cookie value, got %q instead of %q`, result, expected)
+ }
+}
diff --git a/http/request/params.go b/http/request/params.go
new file mode 100644
index 0000000..8218a0d
--- /dev/null
+++ b/http/request/params.go
@@ -0,0 +1,84 @@
+// Copyright 2018 Frédéric Guillot. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package request // import "miniflux.app/http/request"
+
+import (
+ "net/http"
+ "strconv"
+
+ "github.com/gorilla/mux"
+)
+
+// FormInt64Value returns a form value as integer.
+func FormInt64Value(r *http.Request, param string) int64 {
+ value := r.FormValue(param)
+ integer, err := strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ return 0
+ }
+
+ return integer
+}
+
+// RouteInt64Param returns an URL route parameter as int64.
+func RouteInt64Param(r *http.Request, param string) int64 {
+ vars := mux.Vars(r)
+ value, err := strconv.ParseInt(vars[param], 10, 64)
+ if err != nil {
+ return 0
+ }
+
+ if value < 0 {
+ return 0
+ }
+
+ return value
+}
+
+// RouteStringParam returns a URL route parameter as string.
+func RouteStringParam(r *http.Request, param string) string {
+ vars := mux.Vars(r)
+ return vars[param]
+}
+
+// QueryStringParam returns a query string parameter as string.
+func QueryStringParam(r *http.Request, param, defaultValue string) string {
+ value := r.URL.Query().Get(param)
+ if value == "" {
+ value = defaultValue
+ }
+ return value
+}
+
+// QueryIntParam returns a query string parameter as integer.
+func QueryIntParam(r *http.Request, param string, defaultValue int) int {
+ return int(QueryInt64Param(r, param, int64(defaultValue)))
+}
+
+// QueryInt64Param returns a query string parameter as int64.
+func QueryInt64Param(r *http.Request, param string, defaultValue int64) int64 {
+ value := r.URL.Query().Get(param)
+ if value == "" {
+ return defaultValue
+ }
+
+ val, err := strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ return defaultValue
+ }
+
+ if val < 0 {
+ return defaultValue
+ }
+
+ return val
+}
+
+// HasQueryParam checks if the query string contains the given parameter.
+func HasQueryParam(r *http.Request, param string) bool {
+ values := r.URL.Query()
+ _, ok := values[param]
+ return ok
+}
diff --git a/http/request/params_test.go b/http/request/params_test.go
new file mode 100644
index 0000000..7f1f880
--- /dev/null
+++ b/http/request/params_test.go
@@ -0,0 +1,215 @@
+// Copyright 2018 Frédéric Guillot. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package request // import "miniflux.app/http/request"
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "testing"
+
+ "github.com/gorilla/mux"
+)
+
+func TestFormInt64Value(t *testing.T) {
+ f := url.Values{}
+ f.Set("integer value", "42")
+ f.Set("invalid value", "invalid integer")
+
+ r := &http.Request{Form: f}
+
+ result := FormInt64Value(r, "integer value")
+ expected := int64(42)
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = FormInt64Value(r, "invalid value")
+ expected = int64(0)
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = FormInt64Value(r, "missing value")
+ expected = int64(0)
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+}
+
+func TestRouteStringParam(t *testing.T) {
+ router := mux.NewRouter()
+ router.HandleFunc("/route/{variable}/index", func(w http.ResponseWriter, r *http.Request) {
+ result := RouteStringParam(r, "variable")
+ expected := "value"
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
+ }
+
+ result = RouteStringParam(r, "missing variable")
+ expected = ""
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
+ }
+ })
+
+ r, err := http.NewRequest("GET", "/route/value/index", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ w := httptest.NewRecorder()
+ router.ServeHTTP(w, r)
+}
+
+func TestRouteInt64Param(t *testing.T) {
+ router := mux.NewRouter()
+ router.HandleFunc("/a/{variable1}/b/{variable2}/c/{variable3}", func(w http.ResponseWriter, r *http.Request) {
+ result := RouteInt64Param(r, "variable1")
+ expected := int64(42)
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = RouteInt64Param(r, "missing variable")
+ expected = 0
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = RouteInt64Param(r, "variable2")
+ expected = 0
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = RouteInt64Param(r, "variable3")
+ expected = 0
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+ })
+
+ r, err := http.NewRequest("GET", "/a/42/b/not-int/c/-10", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ w := httptest.NewRecorder()
+ router.ServeHTTP(w, r)
+}
+
+func TestQueryStringParam(t *testing.T) {
+ u, _ := url.Parse("http://example.org/?key=value")
+ r := &http.Request{URL: u}
+
+ result := QueryStringParam(r, "key", "fallback")
+ expected := "value"
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
+ }
+
+ result = QueryStringParam(r, "missing key", "fallback")
+ expected = "fallback"
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestQueryIntParam(t *testing.T) {
+ u, _ := url.Parse("http://example.org/?key=42&invalid=value&negative=-5")
+ r := &http.Request{URL: u}
+
+ result := QueryIntParam(r, "key", 84)
+ expected := 42
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = QueryIntParam(r, "missing key", 84)
+ expected = 84
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = QueryIntParam(r, "negative", 69)
+ expected = 69
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = QueryIntParam(r, "invalid", 99)
+ expected = 99
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+}
+
+func TestQueryInt64Param(t *testing.T) {
+ u, _ := url.Parse("http://example.org/?key=42&invalid=value&negative=-5")
+ r := &http.Request{URL: u}
+
+ result := QueryInt64Param(r, "key", int64(84))
+ expected := int64(42)
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = QueryInt64Param(r, "missing key", int64(84))
+ expected = int64(84)
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = QueryInt64Param(r, "invalid", int64(69))
+ expected = int64(69)
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = QueryInt64Param(r, "invalid", int64(99))
+ expected = int64(99)
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+}
+
+func TestHasQueryParam(t *testing.T) {
+ u, _ := url.Parse("http://example.org/?key=42")
+ r := &http.Request{URL: u}
+
+ result := HasQueryParam(r, "key")
+ expected := true
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %v instead of %v`, result, expected)
+ }
+
+ result = HasQueryParam(r, "missing key")
+ expected = false
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %v instead of %v`, result, expected)
+ }
+}
diff --git a/http/request/request.go b/http/request/request.go
deleted file mode 100644
index d27137b..0000000
--- a/http/request/request.go
+++ /dev/null
@@ -1,128 +0,0 @@
-// Copyright 2018 Frédéric Guillot. All rights reserved.
-// Use of this source code is governed by the Apache 2.0
-// license that can be found in the LICENSE file.
-
-package request // import "miniflux.app/http/request"
-
-import (
- "fmt"
- "net"
- "net/http"
- "strconv"
- "strings"
-
- "github.com/gorilla/mux"
-)
-
-// Cookie returns the cookie value.
-func Cookie(r *http.Request, name string) string {
- cookie, err := r.Cookie(name)
- if err == http.ErrNoCookie {
- return ""
- }
-
- return cookie.Value
-}
-
-// FormInt64Value returns a form value as integer.
-func FormInt64Value(r *http.Request, param string) int64 {
- value := r.FormValue(param)
- integer, err := strconv.ParseInt(value, 10, 64)
- if err != nil {
- return 0
- }
-
- return integer
-}
-
-// IntParam returns an URL route parameter as integer.
-func IntParam(r *http.Request, param string) (int64, error) {
- vars := mux.Vars(r)
- value, err := strconv.Atoi(vars[param])
- if err != nil {
- return 0, fmt.Errorf("request: %s parameter is not an integer", param)
- }
-
- if value < 0 {
- return 0, nil
- }
-
- return int64(value), nil
-}
-
-// Param returns an URL route parameter as string.
-func Param(r *http.Request, param, defaultValue string) string {
- vars := mux.Vars(r)
- value := vars[param]
- if value == "" {
- value = defaultValue
- }
- return value
-}
-
-// QueryParam returns a querystring parameter as string.
-func QueryParam(r *http.Request, param, defaultValue string) string {
- value := r.URL.Query().Get(param)
- if value == "" {
- value = defaultValue
- }
- return value
-}
-
-// QueryIntParam returns a querystring parameter as integer.
-func QueryIntParam(r *http.Request, param string, defaultValue int) int {
- return int(QueryInt64Param(r, param, int64(defaultValue)))
-}
-
-// QueryInt64Param returns a querystring parameter as int64.
-func QueryInt64Param(r *http.Request, param string, defaultValue int64) int64 {
- value := r.URL.Query().Get(param)
- if value == "" {
- return defaultValue
- }
-
- val, err := strconv.ParseInt(value, 10, 64)
- if err != nil {
- return defaultValue
- }
-
- if val < 0 {
- return defaultValue
- }
-
- return val
-}
-
-// HasQueryParam checks if the query string contains the given parameter.
-func HasQueryParam(r *http.Request, param string) bool {
- values := r.URL.Query()
- _, ok := values[param]
- return ok
-}
-
-// FindClientIP returns client's real IP address.
-func FindClientIP(r *http.Request) string {
- headers := []string{"X-Forwarded-For", "X-Real-Ip"}
- for _, header := range headers {
- value := r.Header.Get(header)
-
- if value != "" {
- addresses := strings.Split(value, ",")
- address := strings.TrimSpace(addresses[0])
-
- if net.ParseIP(address) != nil {
- return address
- }
- }
- }
-
- // Fallback to TCP/IP source IP address.
- var remoteIP string
- if strings.ContainsRune(r.RemoteAddr, ':') {
- remoteIP, _, _ = net.SplitHostPort(r.RemoteAddr)
- } else {
- remoteIP = r.RemoteAddr
- }
-
- return remoteIP
-}