From 9d08139f4363d3503398002bc82cb3746e3438cf Mon Sep 17 00:00:00 2001 From: Frédéric Guillot Date: Sun, 23 Sep 2018 21:02:26 -0700 Subject: Improve request package and add more unit tests --- api/category.go | 12 +- api/entry.go | 47 ++--- api/feed.go | 30 +-- api/icon.go | 6 +- api/user.go | 23 +-- fever/fever.go | 2 +- http/request/client_ip.go | 38 ++++ http/request/client_ip_test.go | 82 ++++++++ http/request/context.go | 21 +- http/request/context_test.go | 436 +++++++++++++++++++++++++++++++++++++++++ http/request/cookie.go | 17 ++ http/request/cookie_test.go | 33 ++++ http/request/params.go | 84 ++++++++ http/request/params_test.go | 215 ++++++++++++++++++++ http/request/request.go | 128 ------------ http/request/request_test.go | 82 -------- middleware/app_session.go | 2 +- middleware/user_session.go | 2 +- ui/category_edit.go | 7 +- ui/category_entries.go | 7 +- ui/category_remove.go | 7 +- ui/category_update.go | 7 +- ui/entry_bookmark.go | 7 +- ui/entry_category.go | 13 +- ui/entry_feed.go | 13 +- ui/entry_read.go | 7 +- ui/entry_save.go | 7 +- ui/entry_scraper.go | 7 +- ui/entry_search.go | 9 +- ui/entry_toggle_bookmark.go | 7 +- ui/entry_unread.go | 7 +- ui/feed_edit.go | 7 +- ui/feed_entries.go | 7 +- ui/feed_icon.go | 7 +- ui/feed_refresh.go | 7 +- ui/feed_remove.go | 9 +- ui/feed_update.go | 7 +- ui/oauth2_callback.go | 6 +- ui/oauth2_redirect.go | 2 +- ui/oauth2_unlink.go | 2 +- ui/proxy.go | 2 +- ui/search_entries.go | 2 +- ui/session_remove.go | 12 +- ui/static_app_icon.go | 4 +- ui/static_javascript.go | 4 +- ui/static_stylesheet.go | 4 +- ui/subscription_bookmarklet.go | 2 +- ui/user_edit.go | 7 +- ui/user_remove.go | 9 +- ui/user_update.go | 7 +- 50 files changed, 994 insertions(+), 478 deletions(-) create mode 100644 http/request/client_ip.go create mode 100644 http/request/client_ip_test.go create mode 100644 http/request/context_test.go create mode 100644 http/request/cookie.go create mode 100644 http/request/cookie_test.go create mode 100644 http/request/params.go create mode 100644 http/request/params_test.go delete mode 100644 http/request/request.go delete mode 100644 http/request/request_test.go diff --git a/api/category.go b/api/category.go index e74aa3b..b8699e2 100644 --- a/api/category.go +++ b/api/category.go @@ -43,11 +43,7 @@ func (c *Controller) CreateCategory(w http.ResponseWriter, r *http.Request) { // UpdateCategory is the API handler to update a category. func (c *Controller) UpdateCategory(w http.ResponseWriter, r *http.Request) { - categoryID, err := request.IntParam(r, "categoryID") - if err != nil { - json.BadRequest(w, err) - return - } + categoryID := request.RouteInt64Param(r, "categoryID") category, err := decodeCategoryPayload(r.Body) if err != nil { @@ -85,11 +81,7 @@ func (c *Controller) GetCategories(w http.ResponseWriter, r *http.Request) { // RemoveCategory is the API handler to remove a category. func (c *Controller) RemoveCategory(w http.ResponseWriter, r *http.Request) { userID := request.UserID(r) - categoryID, err := request.IntParam(r, "categoryID") - if err != nil { - json.BadRequest(w, err) - return - } + categoryID := request.RouteInt64Param(r, "categoryID") if !c.store.CategoryExists(userID, categoryID) { json.NotFound(w, errors.New("Category not found")) diff --git a/api/entry.go b/api/entry.go index 7f87888..a1ea87f 100644 --- a/api/entry.go +++ b/api/entry.go @@ -17,17 +17,8 @@ import ( // GetFeedEntry is the API handler to get a single feed entry. func (c *Controller) GetFeedEntry(w http.ResponseWriter, r *http.Request) { - feedID, err := request.IntParam(r, "feedID") - if err != nil { - json.BadRequest(w, err) - return - } - - entryID, err := request.IntParam(r, "entryID") - if err != nil { - json.BadRequest(w, err) - return - } + feedID := request.RouteInt64Param(r, "feedID") + entryID := request.RouteInt64Param(r, "entryID") builder := c.store.NewEntryQueryBuilder(request.UserID(r)) builder.WithFeedID(feedID) @@ -49,12 +40,7 @@ func (c *Controller) GetFeedEntry(w http.ResponseWriter, r *http.Request) { // GetEntry is the API handler to get a single entry. func (c *Controller) GetEntry(w http.ResponseWriter, r *http.Request) { - entryID, err := request.IntParam(r, "entryID") - if err != nil { - json.BadRequest(w, err) - return - } - + entryID := request.RouteInt64Param(r, "entryID") builder := c.store.NewEntryQueryBuilder(request.UserID(r)) builder.WithEntryID(entryID) @@ -74,13 +60,9 @@ func (c *Controller) GetEntry(w http.ResponseWriter, r *http.Request) { // GetFeedEntries is the API handler to get all feed entries. func (c *Controller) GetFeedEntries(w http.ResponseWriter, r *http.Request) { - feedID, err := request.IntParam(r, "feedID") - if err != nil { - json.BadRequest(w, err) - return - } + feedID := request.RouteInt64Param(r, "feedID") - status := request.QueryParam(r, "status", "") + status := request.QueryStringParam(r, "status", "") if status != "" { if err := model.ValidateEntryStatus(status); err != nil { json.BadRequest(w, err) @@ -88,13 +70,13 @@ func (c *Controller) GetFeedEntries(w http.ResponseWriter, r *http.Request) { } } - order := request.QueryParam(r, "order", model.DefaultSortingOrder) + order := request.QueryStringParam(r, "order", model.DefaultSortingOrder) if err := model.ValidateEntryOrder(order); err != nil { json.BadRequest(w, err) return } - direction := request.QueryParam(r, "direction", model.DefaultSortingDirection) + direction := request.QueryStringParam(r, "direction", model.DefaultSortingDirection) if err := model.ValidateDirection(direction); err != nil { json.BadRequest(w, err) return @@ -133,7 +115,7 @@ func (c *Controller) GetFeedEntries(w http.ResponseWriter, r *http.Request) { // GetEntries is the API handler to fetch entries. func (c *Controller) GetEntries(w http.ResponseWriter, r *http.Request) { - status := request.QueryParam(r, "status", "") + status := request.QueryStringParam(r, "status", "") if status != "" { if err := model.ValidateEntryStatus(status); err != nil { json.BadRequest(w, err) @@ -141,13 +123,13 @@ func (c *Controller) GetEntries(w http.ResponseWriter, r *http.Request) { } } - order := request.QueryParam(r, "order", model.DefaultSortingOrder) + order := request.QueryStringParam(r, "order", model.DefaultSortingOrder) if err := model.ValidateEntryOrder(order); err != nil { json.BadRequest(w, err) return } - direction := request.QueryParam(r, "direction", model.DefaultSortingDirection) + direction := request.QueryStringParam(r, "direction", model.DefaultSortingDirection) if err := model.ValidateDirection(direction); err != nil { json.BadRequest(w, err) return @@ -206,12 +188,7 @@ func (c *Controller) SetEntryStatus(w http.ResponseWriter, r *http.Request) { // ToggleBookmark is the API handler to toggle bookmark status. func (c *Controller) ToggleBookmark(w http.ResponseWriter, r *http.Request) { - entryID, err := request.IntParam(r, "entryID") - if err != nil { - json.BadRequest(w, err) - return - } - + entryID := request.RouteInt64Param(r, "entryID") if err := c.store.ToggleBookmark(request.UserID(r), entryID); err != nil { json.ServerError(w, err) return @@ -245,7 +222,7 @@ func configureFilters(builder *storage.EntryQueryBuilder, r *http.Request) { builder.WithStarred() } - searchQuery := request.QueryParam(r, "search", "") + searchQuery := request.QueryStringParam(r, "search", "") if searchQuery != "" { builder.WithSearchQuery(searchQuery) } diff --git a/api/feed.go b/api/feed.go index 17bb73d..d193d2e 100644 --- a/api/feed.go +++ b/api/feed.go @@ -65,12 +65,7 @@ func (c *Controller) CreateFeed(w http.ResponseWriter, r *http.Request) { // RefreshFeed is the API handler to refresh a feed. func (c *Controller) RefreshFeed(w http.ResponseWriter, r *http.Request) { - feedID, err := request.IntParam(r, "feedID") - if err != nil { - json.BadRequest(w, err) - return - } - + feedID := request.RouteInt64Param(r, "feedID") userID := request.UserID(r) if !c.store.FeedExists(userID, feedID) { @@ -78,7 +73,7 @@ func (c *Controller) RefreshFeed(w http.ResponseWriter, r *http.Request) { return } - err = c.feedHandler.RefreshFeed(userID, feedID) + err := c.feedHandler.RefreshFeed(userID, feedID) if err != nil { json.ServerError(w, err) return @@ -89,12 +84,7 @@ func (c *Controller) RefreshFeed(w http.ResponseWriter, r *http.Request) { // UpdateFeed is the API handler that is used to update a feed. func (c *Controller) UpdateFeed(w http.ResponseWriter, r *http.Request) { - feedID, err := request.IntParam(r, "feedID") - if err != nil { - json.BadRequest(w, err) - return - } - + feedID := request.RouteInt64Param(r, "feedID") feedChanges, err := decodeFeedModificationPayload(r.Body) if err != nil { json.BadRequest(w, err) @@ -148,12 +138,7 @@ func (c *Controller) GetFeeds(w http.ResponseWriter, r *http.Request) { // GetFeed is the API handler to get a feed. func (c *Controller) GetFeed(w http.ResponseWriter, r *http.Request) { - feedID, err := request.IntParam(r, "feedID") - if err != nil { - json.BadRequest(w, err) - return - } - + feedID := request.RouteInt64Param(r, "feedID") feed, err := c.store.FeedByID(request.UserID(r), feedID) if err != nil { json.ServerError(w, err) @@ -170,12 +155,7 @@ func (c *Controller) GetFeed(w http.ResponseWriter, r *http.Request) { // RemoveFeed is the API handler to remove a feed. func (c *Controller) RemoveFeed(w http.ResponseWriter, r *http.Request) { - feedID, err := request.IntParam(r, "feedID") - if err != nil { - json.BadRequest(w, err) - return - } - + feedID := request.RouteInt64Param(r, "feedID") userID := request.UserID(r) if !c.store.FeedExists(userID, feedID) { diff --git a/api/icon.go b/api/icon.go index f9c2964..de01fad 100644 --- a/api/icon.go +++ b/api/icon.go @@ -14,11 +14,7 @@ import ( // FeedIcon returns a feed icon. func (c *Controller) FeedIcon(w http.ResponseWriter, r *http.Request) { - feedID, err := request.IntParam(r, "feedID") - if err != nil { - json.BadRequest(w, err) - return - } + feedID := request.RouteInt64Param(r, "feedID") if !c.store.HasIcon(feedID) { json.NotFound(w, errors.New("This feed doesn't have any icon")) diff --git a/api/user.go b/api/user.go index 167fd72..b9274bb 100644 --- a/api/user.go +++ b/api/user.go @@ -63,12 +63,7 @@ func (c *Controller) UpdateUser(w http.ResponseWriter, r *http.Request) { return } - userID, err := request.IntParam(r, "userID") - if err != nil { - json.BadRequest(w, err) - return - } - + userID := request.RouteInt64Param(r, "userID") userChanges, err := decodeUserModificationPayload(r.Body) if err != nil { json.BadRequest(w, err) @@ -124,12 +119,7 @@ func (c *Controller) UserByID(w http.ResponseWriter, r *http.Request) { return } - userID, err := request.IntParam(r, "userID") - if err != nil { - json.BadRequest(w, err) - return - } - + userID := request.RouteInt64Param(r, "userID") user, err := c.store.UserByID(userID) if err != nil { json.BadRequest(w, errors.New("Unable to fetch this user from the database")) @@ -152,7 +142,7 @@ func (c *Controller) UserByUsername(w http.ResponseWriter, r *http.Request) { return } - username := request.Param(r, "username", "") + username := request.RouteStringParam(r, "username") user, err := c.store.UserByUsername(username) if err != nil { json.BadRequest(w, errors.New("Unable to fetch this user from the database")) @@ -174,12 +164,7 @@ func (c *Controller) RemoveUser(w http.ResponseWriter, r *http.Request) { return } - userID, err := request.IntParam(r, "userID") - if err != nil { - json.BadRequest(w, err) - return - } - + userID := request.RouteInt64Param(r, "userID") user, err := c.store.UserByID(userID) if err != nil { json.ServerError(w, err) diff --git a/fever/fever.go b/fever/fever.go index e1090b7..b754a89 100644 --- a/fever/fever.go +++ b/fever/fever.go @@ -356,7 +356,7 @@ func (c *Controller) handleItems(w http.ResponseWriter, r *http.Request) { builder.WithOffset(maxID) } - csvItemIDs := request.QueryParam(r, "with_ids", "") + csvItemIDs := request.QueryStringParam(r, "with_ids", "") if csvItemIDs != "" { var itemIDs []int64 diff --git a/http/request/client_ip.go b/http/request/client_ip.go new file mode 100644 index 0000000..52fc05c --- /dev/null +++ b/http/request/client_ip.go @@ -0,0 +1,38 @@ +// Copyright 2018 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package request // import "miniflux.app/http/request" + +import ( + "net" + "net/http" + "strings" +) + +// FindClientIP returns client real IP address. +func FindClientIP(r *http.Request) string { + headers := []string{"X-Forwarded-For", "X-Real-Ip"} + for _, header := range headers { + value := r.Header.Get(header) + + if value != "" { + addresses := strings.Split(value, ",") + address := strings.TrimSpace(addresses[0]) + + if net.ParseIP(address) != nil { + return address + } + } + } + + // Fallback to TCP/IP source IP address. + var remoteIP string + if strings.ContainsRune(r.RemoteAddr, ':') { + remoteIP, _, _ = net.SplitHostPort(r.RemoteAddr) + } else { + remoteIP = r.RemoteAddr + } + + return remoteIP +} diff --git a/http/request/client_ip_test.go b/http/request/client_ip_test.go new file mode 100644 index 0000000..12d6e16 --- /dev/null +++ b/http/request/client_ip_test.go @@ -0,0 +1,82 @@ +// Copyright 2018 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package request // import "miniflux.app/http/request" + +import ( + "net/http" + "testing" +) + +func TestFindClientIPWithoutHeaders(t *testing.T) { + r := &http.Request{RemoteAddr: "192.168.0.1:4242"} + if ip := FindClientIP(r); ip != "192.168.0.1" { + t.Fatalf(`Unexpected result, got: %q`, ip) + } + + r = &http.Request{RemoteAddr: "192.168.0.1"} + if ip := FindClientIP(r); ip != "192.168.0.1" { + t.Fatalf(`Unexpected result, got: %q`, ip) + } +} + +func TestFindClientIPWithXFFHeader(t *testing.T) { + // Test with multiple IPv4 addresses. + headers := http.Header{} + headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178") + r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} + + if ip := FindClientIP(r); ip != "203.0.113.195" { + t.Fatalf(`Unexpected result, got: %q`, ip) + } + + // Test with single IPv6 address. + headers = http.Header{} + headers.Set("X-Forwarded-For", "2001:db8:85a3:8d3:1319:8a2e:370:7348") + r = &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} + + if ip := FindClientIP(r); ip != "2001:db8:85a3:8d3:1319:8a2e:370:7348" { + t.Fatalf(`Unexpected result, got: %q`, ip) + } + + // Test with single IPv4 address. + headers = http.Header{} + headers.Set("X-Forwarded-For", "70.41.3.18") + r = &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} + + if ip := FindClientIP(r); ip != "70.41.3.18" { + t.Fatalf(`Unexpected result, got: %q`, ip) + } + + // Test with invalid IP address. + headers = http.Header{} + headers.Set("X-Forwarded-For", "fake IP") + r = &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} + + if ip := FindClientIP(r); ip != "192.168.0.1" { + t.Fatalf(`Unexpected result, got: %q`, ip) + } +} + +func TestClientIPWithXRealIPHeader(t *testing.T) { + headers := http.Header{} + headers.Set("X-Real-Ip", "192.168.122.1") + r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} + + if ip := FindClientIP(r); ip != "192.168.122.1" { + t.Fatalf(`Unexpected result, got: %q`, ip) + } +} + +func TestClientIPWithBothHeaders(t *testing.T) { + headers := http.Header{} + headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178") + headers.Set("X-Real-Ip", "192.168.122.1") + + r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} + + if ip := FindClientIP(r); ip != "203.0.113.195" { + t.Fatalf(`Unexpected result, got: %q`, ip) + } +} diff --git a/http/request/context.go b/http/request/context.go index b77365d..51ee46a 100644 --- a/http/request/context.go +++ b/http/request/context.go @@ -111,7 +111,12 @@ func ClientIP(r *http.Request) string { func getContextStringValue(r *http.Request, key ContextKey) string { if v := r.Context().Value(key); v != nil { - return v.(string) + value, valid := v.(string) + if !valid { + return "" + } + + return value } return "" @@ -119,7 +124,12 @@ func getContextStringValue(r *http.Request, key ContextKey) string { func getContextBoolValue(r *http.Request, key ContextKey) bool { if v := r.Context().Value(key); v != nil { - return v.(bool) + value, valid := v.(bool) + if !valid { + return false + } + + return value } return false @@ -127,7 +137,12 @@ func getContextBoolValue(r *http.Request, key ContextKey) bool { func getContextInt64Value(r *http.Request, key ContextKey) int64 { if v := r.Context().Value(key); v != nil { - return v.(int64) + value, valid := v.(int64) + if !valid { + return 0 + } + + return value } return 0 diff --git a/http/request/context_test.go b/http/request/context_test.go new file mode 100644 index 0000000..e741362 --- /dev/null +++ b/http/request/context_test.go @@ -0,0 +1,436 @@ +// Copyright 2018 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package request // import "miniflux.app/http/request" + +import ( + "context" + "net/http" + "testing" +) + +func TestContextStringValue(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, ClientIPContextKey, "IP") + r = r.WithContext(ctx) + + result := getContextStringValue(r, ClientIPContextKey) + expected := "IP" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestContextStringValueWithInvalidType(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, ClientIPContextKey, 0) + r = r.WithContext(ctx) + + result := getContextStringValue(r, ClientIPContextKey) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestContextStringValueWhenUnset(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := getContextStringValue(r, ClientIPContextKey) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestContextBoolValue(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, IsAdminUserContextKey, true) + r = r.WithContext(ctx) + + result := getContextBoolValue(r, IsAdminUserContextKey) + expected := true + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestContextBoolValueWithInvalidType(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, IsAdminUserContextKey, "invalid") + r = r.WithContext(ctx) + + result := getContextBoolValue(r, IsAdminUserContextKey) + expected := false + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestContextBoolValueWhenUnset(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := getContextBoolValue(r, IsAdminUserContextKey) + expected := false + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestContextInt64Value(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, UserIDContextKey, int64(1234)) + r = r.WithContext(ctx) + + result := getContextInt64Value(r, UserIDContextKey) + expected := int64(1234) + + if result != expected { + t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected) + } +} + +func TestContextInt64ValueWithInvalidType(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, UserIDContextKey, "invalid") + r = r.WithContext(ctx) + + result := getContextInt64Value(r, UserIDContextKey) + expected := int64(0) + + if result != expected { + t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected) + } +} + +func TestContextInt64ValueWhenUnset(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := getContextInt64Value(r, UserIDContextKey) + expected := int64(0) + + if result != expected { + t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected) + } +} + +func TestIsAdmin(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := IsAdminUser(r) + expected := false + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, IsAdminUserContextKey, true) + r = r.WithContext(ctx) + + result = IsAdminUser(r) + expected = true + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestIsAuthenticated(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := IsAuthenticated(r) + expected := false + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, IsAuthenticatedContextKey, true) + r = r.WithContext(ctx) + + result = IsAuthenticated(r) + expected = true + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestUserID(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := UserID(r) + expected := int64(0) + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, UserIDContextKey, int64(123)) + r = r.WithContext(ctx) + + result = UserID(r) + expected = int64(123) + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestUserTimezone(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := UserTimezone(r) + expected := "UTC" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, UserTimezoneContextKey, "Europe/Paris") + r = r.WithContext(ctx) + + result = UserTimezone(r) + expected = "Europe/Paris" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestUserLanguage(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := UserLanguage(r) + expected := "en_US" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, UserLanguageContextKey, "fr_FR") + r = r.WithContext(ctx) + + result = UserLanguage(r) + expected = "fr_FR" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestUserTheme(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := UserTheme(r) + expected := "default" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, UserThemeContextKey, "black") + r = r.WithContext(ctx) + + result = UserTheme(r) + expected = "black" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestCSRF(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := CSRF(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, CSRFContextKey, "secret") + r = r.WithContext(ctx) + + result = CSRF(r) + expected = "secret" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestSessionID(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := SessionID(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, SessionIDContextKey, "id") + r = r.WithContext(ctx) + + result = SessionID(r) + expected = "id" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestUserSessionToken(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := UserSessionToken(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, UserSessionTokenContextKey, "token") + r = r.WithContext(ctx) + + result = UserSessionToken(r) + expected = "token" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestOAuth2State(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := OAuth2State(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, OAuth2StateContextKey, "state") + r = r.WithContext(ctx) + + result = OAuth2State(r) + expected = "state" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestFlashMessage(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := FlashMessage(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, FlashMessageContextKey, "message") + r = r.WithContext(ctx) + + result = FlashMessage(r) + expected = "message" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestFlashErrorMessage(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := FlashErrorMessage(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, FlashErrorMessageContextKey, "error message") + r = r.WithContext(ctx) + + result = FlashErrorMessage(r) + expected = "error message" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestPocketRequestToken(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := PocketRequestToken(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, PocketRequestTokenContextKey, "request token") + r = r.WithContext(ctx) + + result = PocketRequestToken(r) + expected = "request token" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestClientIP(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := ClientIP(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, ClientIPContextKey, "127.0.0.1") + r = r.WithContext(ctx) + + result = ClientIP(r) + expected = "127.0.0.1" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} diff --git a/http/request/cookie.go b/http/request/cookie.go new file mode 100644 index 0000000..88cc626 --- /dev/null +++ b/http/request/cookie.go @@ -0,0 +1,17 @@ +// Copyright 2018 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package request // import "miniflux.app/http/request" + +import "net/http" + +// CookieValue returns the cookie value. +func CookieValue(r *http.Request, name string) string { + cookie, err := r.Cookie(name) + if err == http.ErrNoCookie { + return "" + } + + return cookie.Value +} diff --git a/http/request/cookie_test.go b/http/request/cookie_test.go new file mode 100644 index 0000000..9c3b54d --- /dev/null +++ b/http/request/cookie_test.go @@ -0,0 +1,33 @@ +// Copyright 2018 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package request // import "miniflux.app/http/request" + +import ( + "net/http" + "testing" +) + +func TestGetCookieValue(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + r.AddCookie(&http.Cookie{Value: "cookie_value", Name: "my_cookie"}) + + result := CookieValue(r, "my_cookie") + expected := "cookie_value" + + if result != expected { + t.Errorf(`Unexpected cookie value, got %q instead of %q`, result, expected) + } +} + +func TestGetCookieValueWhenUnset(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := CookieValue(r, "my_cookie") + expected := "" + + if result != expected { + t.Errorf(`Unexpected cookie value, got %q instead of %q`, result, expected) + } +} diff --git a/http/request/params.go b/http/request/params.go new file mode 100644 index 0000000..8218a0d --- /dev/null +++ b/http/request/params.go @@ -0,0 +1,84 @@ +// Copyright 2018 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package request // import "miniflux.app/http/request" + +import ( + "net/http" + "strconv" + + "github.com/gorilla/mux" +) + +// FormInt64Value returns a form value as integer. +func FormInt64Value(r *http.Request, param string) int64 { + value := r.FormValue(param) + integer, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return 0 + } + + return integer +} + +// RouteInt64Param returns an URL route parameter as int64. +func RouteInt64Param(r *http.Request, param string) int64 { + vars := mux.Vars(r) + value, err := strconv.ParseInt(vars[param], 10, 64) + if err != nil { + return 0 + } + + if value < 0 { + return 0 + } + + return value +} + +// RouteStringParam returns a URL route parameter as string. +func RouteStringParam(r *http.Request, param string) string { + vars := mux.Vars(r) + return vars[param] +} + +// QueryStringParam returns a query string parameter as string. +func QueryStringParam(r *http.Request, param, defaultValue string) string { + value := r.URL.Query().Get(param) + if value == "" { + value = defaultValue + } + return value +} + +// QueryIntParam returns a query string parameter as integer. +func QueryIntParam(r *http.Request, param string, defaultValue int) int { + return int(QueryInt64Param(r, param, int64(defaultValue))) +} + +// QueryInt64Param returns a query string parameter as int64. +func QueryInt64Param(r *http.Request, param string, defaultValue int64) int64 { + value := r.URL.Query().Get(param) + if value == "" { + return defaultValue + } + + val, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return defaultValue + } + + if val < 0 { + return defaultValue + } + + return val +} + +// HasQueryParam checks if the query string contains the given parameter. +func HasQueryParam(r *http.Request, param string) bool { + values := r.URL.Query() + _, ok := values[param] + return ok +} diff --git a/http/request/params_test.go b/http/request/params_test.go new file mode 100644 index 0000000..7f1f880 --- /dev/null +++ b/http/request/params_test.go @@ -0,0 +1,215 @@ +// Copyright 2018 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package request // import "miniflux.app/http/request" + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gorilla/mux" +) + +func TestFormInt64Value(t *testing.T) { + f := url.Values{} + f.Set("integer value", "42") + f.Set("invalid value", "invalid integer") + + r := &http.Request{Form: f} + + result := FormInt64Value(r, "integer value") + expected := int64(42) + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = FormInt64Value(r, "invalid value") + expected = int64(0) + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = FormInt64Value(r, "missing value") + expected = int64(0) + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } +} + +func TestRouteStringParam(t *testing.T) { + router := mux.NewRouter() + router.HandleFunc("/route/{variable}/index", func(w http.ResponseWriter, r *http.Request) { + result := RouteStringParam(r, "variable") + expected := "value" + + if result != expected { + t.Errorf(`Unexpected result, got %q instead of %q`, result, expected) + } + + result = RouteStringParam(r, "missing variable") + expected = "" + + if result != expected { + t.Errorf(`Unexpected result, got %q instead of %q`, result, expected) + } + }) + + r, err := http.NewRequest("GET", "/route/value/index", nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + router.ServeHTTP(w, r) +} + +func TestRouteInt64Param(t *testing.T) { + router := mux.NewRouter() + router.HandleFunc("/a/{variable1}/b/{variable2}/c/{variable3}", func(w http.ResponseWriter, r *http.Request) { + result := RouteInt64Param(r, "variable1") + expected := int64(42) + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = RouteInt64Param(r, "missing variable") + expected = 0 + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = RouteInt64Param(r, "variable2") + expected = 0 + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = RouteInt64Param(r, "variable3") + expected = 0 + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + }) + + r, err := http.NewRequest("GET", "/a/42/b/not-int/c/-10", nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + router.ServeHTTP(w, r) +} + +func TestQueryStringParam(t *testing.T) { + u, _ := url.Parse("http://example.org/?key=value") + r := &http.Request{URL: u} + + result := QueryStringParam(r, "key", "fallback") + expected := "value" + + if result != expected { + t.Errorf(`Unexpected result, got %q instead of %q`, result, expected) + } + + result = QueryStringParam(r, "missing key", "fallback") + expected = "fallback" + + if result != expected { + t.Errorf(`Unexpected result, got %q instead of %q`, result, expected) + } +} + +func TestQueryIntParam(t *testing.T) { + u, _ := url.Parse("http://example.org/?key=42&invalid=value&negative=-5") + r := &http.Request{URL: u} + + result := QueryIntParam(r, "key", 84) + expected := 42 + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = QueryIntParam(r, "missing key", 84) + expected = 84 + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = QueryIntParam(r, "negative", 69) + expected = 69 + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = QueryIntParam(r, "invalid", 99) + expected = 99 + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } +} + +func TestQueryInt64Param(t *testing.T) { + u, _ := url.Parse("http://example.org/?key=42&invalid=value&negative=-5") + r := &http.Request{URL: u} + + result := QueryInt64Param(r, "key", int64(84)) + expected := int64(42) + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = QueryInt64Param(r, "missing key", int64(84)) + expected = int64(84) + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = QueryInt64Param(r, "invalid", int64(69)) + expected = int64(69) + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } + + result = QueryInt64Param(r, "invalid", int64(99)) + expected = int64(99) + + if result != expected { + t.Errorf(`Unexpected result, got %d instead of %d`, result, expected) + } +} + +func TestHasQueryParam(t *testing.T) { + u, _ := url.Parse("http://example.org/?key=42") + r := &http.Request{URL: u} + + result := HasQueryParam(r, "key") + expected := true + + if result != expected { + t.Errorf(`Unexpected result, got %v instead of %v`, result, expected) + } + + result = HasQueryParam(r, "missing key") + expected = false + + if result != expected { + t.Errorf(`Unexpected result, got %v instead of %v`, result, expected) + } +} diff --git a/http/request/request.go b/http/request/request.go deleted file mode 100644 index d27137b..0000000 --- a/http/request/request.go +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2018 Frédéric Guillot. All rights reserved. -// Use of this source code is governed by the Apache 2.0 -// license that can be found in the LICENSE file. - -package request // import "miniflux.app/http/request" - -import ( - "fmt" - "net" - "net/http" - "strconv" - "strings" - - "github.com/gorilla/mux" -) - -// Cookie returns the cookie value. -func Cookie(r *http.Request, name string) string { - cookie, err := r.Cookie(name) - if err == http.ErrNoCookie { - return "" - } - - return cookie.Value -} - -// FormInt64Value returns a form value as integer. -func FormInt64Value(r *http.Request, param string) int64 { - value := r.FormValue(param) - integer, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return 0 - } - - return integer -} - -// IntParam returns an URL route parameter as integer. -func IntParam(r *http.Request, param string) (int64, error) { - vars := mux.Vars(r) - value, err := strconv.Atoi(vars[param]) - if err != nil { - return 0, fmt.Errorf("request: %s parameter is not an integer", param) - } - - if value < 0 { - return 0, nil - } - - return int64(value), nil -} - -// Param returns an URL route parameter as string. -func Param(r *http.Request, param, defaultValue string) string { - vars := mux.Vars(r) - value := vars[param] - if value == "" { - value = defaultValue - } - return value -} - -// QueryParam returns a querystring parameter as string. -func QueryParam(r *http.Request, param, defaultValue string) string { - value := r.URL.Query().Get(param) - if value == "" { - value = defaultValue - } - return value -} - -// QueryIntParam returns a querystring parameter as integer. -func QueryIntParam(r *http.Request, param string, defaultValue int) int { - return int(QueryInt64Param(r, param, int64(defaultValue))) -} - -// QueryInt64Param returns a querystring parameter as int64. -func QueryInt64Param(r *http.Request, param string, defaultValue int64) int64 { - value := r.URL.Query().Get(param) - if value == "" { - return defaultValue - } - - val, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return defaultValue - } - - if val < 0 { - return defaultValue - } - - return val -} - -// HasQueryParam checks if the query string contains the given parameter. -func HasQueryParam(r *http.Request, param string) bool { - values := r.URL.Query() - _, ok := values[param] - return ok -} - -// FindClientIP returns client's real IP address. -func FindClientIP(r *http.Request) string { - headers := []string{"X-Forwarded-For", "X-Real-Ip"} - for _, header := range headers { - value := r.Header.Get(header) - - if value != "" { - addresses := strings.Split(value, ",") - address := strings.TrimSpace(addresses[0]) - - if net.ParseIP(address) != nil { - return address - } - } - } - - // Fallback to TCP/IP source IP address. - var remoteIP string - if strings.ContainsRune(r.RemoteAddr, ':') { - remoteIP, _, _ = net.SplitHostPort(r.RemoteAddr) - } else { - remoteIP = r.RemoteAddr - } - - return remoteIP -} diff --git a/http/request/request_test.go b/http/request/request_test.go deleted file mode 100644 index 946b132..0000000 --- a/http/request/request_test.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2018 Frédéric Guillot. All rights reserved. -// Use of this source code is governed by the Apache 2.0 -// license that can be found in the LICENSE file. - -package request // import "miniflux.app/http/request" - -import ( - "net/http" - "testing" -) - -func TestRealIPWithoutHeaders(t *testing.T) { - r := &http.Request{RemoteAddr: "192.168.0.1:4242"} - if ip := FindClientIP(r); ip != "192.168.0.1" { - t.Fatalf(`Unexpected result, got: %q`, ip) - } - - r = &http.Request{RemoteAddr: "192.168.0.1"} - if ip := FindClientIP(r); ip != "192.168.0.1" { - t.Fatalf(`Unexpected result, got: %q`, ip) - } -} - -func TestRealIPWithXFFHeader(t *testing.T) { - // Test with multiple IPv4 addresses. - headers := http.Header{} - headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178") - r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} - - if ip := FindClientIP(r); ip != "203.0.113.195" { - t.Fatalf(`Unexpected result, got: %q`, ip) - } - - // Test with single IPv6 address. - headers = http.Header{} - headers.Set("X-Forwarded-For", "2001:db8:85a3:8d3:1319:8a2e:370:7348") - r = &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} - - if ip := FindClientIP(r); ip != "2001:db8:85a3:8d3:1319:8a2e:370:7348" { - t.Fatalf(`Unexpected result, got: %q`, ip) - } - - // Test with single IPv4 address. - headers = http.Header{} - headers.Set("X-Forwarded-For", "70.41.3.18") - r = &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} - - if ip := FindClientIP(r); ip != "70.41.3.18" { - t.Fatalf(`Unexpected result, got: %q`, ip) - } - - // Test with invalid IP address. - headers = http.Header{} - headers.Set("X-Forwarded-For", "fake IP") - r = &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} - - if ip := FindClientIP(r); ip != "192.168.0.1" { - t.Fatalf(`Unexpected result, got: %q`, ip) - } -} - -func TestRealIPWithXRealIPHeader(t *testing.T) { - headers := http.Header{} - headers.Set("X-Real-Ip", "192.168.122.1") - r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} - - if ip := FindClientIP(r); ip != "192.168.122.1" { - t.Fatalf(`Unexpected result, got: %q`, ip) - } -} - -func TestRealIPWithBothHeaders(t *testing.T) { - headers := http.Header{} - headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178") - headers.Set("X-Real-Ip", "192.168.122.1") - - r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers} - - if ip := FindClientIP(r); ip != "203.0.113.195" { - t.Fatalf(`Unexpected result, got: %q`, ip) - } -} diff --git a/middleware/app_session.go b/middleware/app_session.go index 7f2b214..b505ee1 100644 --- a/middleware/app_session.go +++ b/middleware/app_session.go @@ -62,7 +62,7 @@ func (m *Middleware) AppSession(next http.Handler) http.Handler { } func (m *Middleware) getAppSessionValueFromCookie(r *http.Request) *model.Session { - cookieValue := request.Cookie(r, cookie.CookieSessionID) + cookieValue := request.CookieValue(r, cookie.CookieSessionID) if cookieValue == "" { return nil } diff --git a/middleware/user_session.go b/middleware/user_session.go index 66f0ad0..bddb47b 100644 --- a/middleware/user_session.go +++ b/middleware/user_session.go @@ -62,7 +62,7 @@ func (m *Middleware) isPublicRoute(r *http.Request) bool { } func (m *Middleware) getUserSessionFromCookie(r *http.Request) *model.UserSession { - cookieValue := request.Cookie(r, cookie.CookieUserSessionID) + cookieValue := request.CookieValue(r, cookie.CookieUserSessionID) if cookieValue == "" { return nil } diff --git a/ui/category_edit.go b/ui/category_edit.go index 0ab4b1b..b99d2a0 100644 --- a/ui/category_edit.go +++ b/ui/category_edit.go @@ -25,12 +25,7 @@ func (c *Controller) EditCategory(w http.ResponseWriter, r *http.Request) { return } - categoryID, err := request.IntParam(r, "categoryID") - if err != nil { - html.BadRequest(w, err) - return - } - + categoryID := request.RouteInt64Param(r, "categoryID") category, err := c.store.Category(request.UserID(r), categoryID) if err != nil { html.ServerError(w, err) diff --git a/ui/category_entries.go b/ui/category_entries.go index ac39fa1..caa98cd 100644 --- a/ui/category_entries.go +++ b/ui/category_entries.go @@ -23,12 +23,7 @@ func (c *Controller) CategoryEntries(w http.ResponseWriter, r *http.Request) { return } - categoryID, err := request.IntParam(r, "categoryID") - if err != nil { - html.BadRequest(w, err) - return - } - + categoryID := request.RouteInt64Param(r, "categoryID") category, err := c.store.Category(request.UserID(r), categoryID) if err != nil { html.ServerError(w, err) diff --git a/ui/category_remove.go b/ui/category_remove.go index 033d9e3..b424af5 100644 --- a/ui/category_remove.go +++ b/ui/category_remove.go @@ -21,12 +21,7 @@ func (c *Controller) RemoveCategory(w http.ResponseWriter, r *http.Request) { return } - categoryID, err := request.IntParam(r, "categoryID") - if err != nil { - html.BadRequest(w, err) - return - } - + categoryID := request.RouteInt64Param(r, "categoryID") category, err := c.store.Category(request.UserID(r), categoryID) if err != nil { html.ServerError(w, err) diff --git a/ui/category_update.go b/ui/category_update.go index 480b4fb..90672c0 100644 --- a/ui/category_update.go +++ b/ui/category_update.go @@ -25,12 +25,7 @@ func (c *Controller) UpdateCategory(w http.ResponseWriter, r *http.Request) { return } - categoryID, err := request.IntParam(r, "categoryID") - if err != nil { - html.BadRequest(w, err) - return - } - + categoryID := request.RouteInt64Param(r, "categoryID") category, err := c.store.Category(request.UserID(r), categoryID) if err != nil { html.ServerError(w, err) diff --git a/ui/entry_bookmark.go b/ui/entry_bookmark.go index e6d2f0e..7c42a5c 100644 --- a/ui/entry_bookmark.go +++ b/ui/entry_bookmark.go @@ -25,12 +25,7 @@ func (c *Controller) ShowStarredEntry(w http.ResponseWriter, r *http.Request) { return } - entryID, err := request.IntParam(r, "entryID") - if err != nil { - html.BadRequest(w, err) - return - } - + entryID := request.RouteInt64Param(r, "entryID") builder := c.store.NewEntryQueryBuilder(user.ID) builder.WithEntryID(entryID) builder.WithoutStatus(model.EntryStatusRemoved) diff --git a/ui/entry_category.go b/ui/entry_category.go index 795aaf1..283f015 100644 --- a/ui/entry_category.go +++ b/ui/entry_category.go @@ -25,17 +25,8 @@ func (c *Controller) ShowCategoryEntry(w http.ResponseWriter, r *http.Request) { return } - categoryID, err := request.IntParam(r, "categoryID") - if err != nil { - html.BadRequest(w, err) - return - } - - entryID, err := request.IntParam(r, "entryID") - if err != nil { - html.BadRequest(w, err) - return - } + categoryID := request.RouteInt64Param(r, "categoryID") + entryID := request.RouteInt64Param(r, "entryID") builder := c.store.NewEntryQueryBuilder(user.ID) builder.WithCategoryID(categoryID) diff --git a/ui/entry_feed.go b/ui/entry_feed.go index 968c8f4..86dd2c9 100644 --- a/ui/entry_feed.go +++ b/ui/entry_feed.go @@ -25,17 +25,8 @@ func (c *Controller) ShowFeedEntry(w http.ResponseWriter, r *http.Request) { return } - entryID, err := request.IntParam(r, "entryID") - if err != nil { - html.BadRequest(w, err) - return - } - - feedID, err := request.IntParam(r, "feedID") - if err != nil { - html.BadRequest(w, err) - return - } + entryID := request.RouteInt64Param(r, "entryID") + feedID := request.RouteInt64Param(r, "feedID") builder := c.store.NewEntryQueryBuilder(user.ID) builder.WithFeedID(feedID) diff --git a/ui/entry_read.go b/ui/entry_read.go index 61c7114..eeaca8e 100644 --- a/ui/entry_read.go +++ b/ui/entry_read.go @@ -24,12 +24,7 @@ func (c *Controller) ShowReadEntry(w http.ResponseWriter, r *http.Request) { return } - entryID, err := request.IntParam(r, "entryID") - if err != nil { - html.BadRequest(w, err) - return - } - + entryID := request.RouteInt64Param(r, "entryID") builder := c.store.NewEntryQueryBuilder(user.ID) builder.WithEntryID(entryID) builder.WithoutStatus(model.EntryStatusRemoved) diff --git a/ui/entry_save.go b/ui/entry_save.go index 488022a..1f846ba 100644 --- a/ui/entry_save.go +++ b/ui/entry_save.go @@ -16,12 +16,7 @@ import ( // SaveEntry send the link to external services. func (c *Controller) SaveEntry(w http.ResponseWriter, r *http.Request) { - entryID, err := request.IntParam(r, "entryID") - if err != nil { - json.BadRequest(w, err) - return - } - + entryID := request.RouteInt64Param(r, "entryID") builder := c.store.NewEntryQueryBuilder(request.UserID(r)) builder.WithEntryID(entryID) builder.WithoutStatus(model.EntryStatusRemoved) diff --git a/ui/entry_scraper.go b/ui/entry_scraper.go index b4a290f..4c2d58c 100644 --- a/ui/entry_scraper.go +++ b/ui/entry_scraper.go @@ -17,12 +17,7 @@ import ( // FetchContent downloads the original HTML page and returns relevant contents. func (c *Controller) FetchContent(w http.ResponseWriter, r *http.Request) { - entryID, err := request.IntParam(r, "entryID") - if err != nil { - json.BadRequest(w, err) - return - } - + entryID := request.RouteInt64Param(r, "entryID") builder := c.store.NewEntryQueryBuilder(request.UserID(r)) builder.WithEntryID(entryID) builder.WithoutStatus(model.EntryStatusRemoved) diff --git a/ui/entry_search.go b/ui/entry_search.go index 15babc1..8acf103 100644 --- a/ui/entry_search.go +++ b/ui/entry_search.go @@ -25,13 +25,8 @@ func (c *Controller) ShowSearchEntry(w http.ResponseWriter, r *http.Request) { return } - entryID, err := request.IntParam(r, "entryID") - if err != nil { - html.BadRequest(w, err) - return - } - - searchQuery := request.QueryParam(r, "q", "") + entryID := request.RouteInt64Param(r, "entryID") + searchQuery := request.QueryStringParam(r, "q", "") builder := c.store.NewEntryQueryBuilder(user.ID) builder.WithSearchQuery(searchQuery) builder.WithEntryID(entryID) diff --git a/ui/entry_toggle_bookmark.go b/ui/entry_toggle_bookmark.go index e8e87ca..14a2c75 100644 --- a/ui/entry_toggle_bookmark.go +++ b/ui/entry_toggle_bookmark.go @@ -14,12 +14,7 @@ import ( // ToggleBookmark handles Ajax request to toggle bookmark value. func (c *Controller) ToggleBookmark(w http.ResponseWriter, r *http.Request) { - entryID, err := request.IntParam(r, "entryID") - if err != nil { - json.BadRequest(w, err) - return - } - + entryID := request.RouteInt64Param(r, "entryID") if err := c.store.ToggleBookmark(request.UserID(r), entryID); err != nil { logger.Error("[Controller:ToggleBookmark] %v", err) json.ServerError(w, nil) diff --git a/ui/entry_unread.go b/ui/entry_unread.go index 30a6d34..4ef5731 100644 --- a/ui/entry_unread.go +++ b/ui/entry_unread.go @@ -25,12 +25,7 @@ func (c *Controller) ShowUnreadEntry(w http.ResponseWriter, r *http.Request) { return } - entryID, err := request.IntParam(r, "entryID") - if err != nil { - html.BadRequest(w, err) - return - } - + entryID := request.RouteInt64Param(r, "entryID") builder := c.store.NewEntryQueryBuilder(user.ID) builder.WithEntryID(entryID) builder.WithoutStatus(model.EntryStatusRemoved) diff --git a/ui/feed_edit.go b/ui/feed_edit.go index 9063d51..8b1b3cb 100644 --- a/ui/feed_edit.go +++ b/ui/feed_edit.go @@ -23,12 +23,7 @@ func (c *Controller) EditFeed(w http.ResponseWriter, r *http.Request) { return } - feedID, err := request.IntParam(r, "feedID") - if err != nil { - html.BadRequest(w, err) - return - } - + feedID := request.RouteInt64Param(r, "feedID") feed, err := c.store.FeedByID(user.ID, feedID) if err != nil { html.ServerError(w, err) diff --git a/ui/feed_entries.go b/ui/feed_entries.go index dc0f05d..06a298b 100644 --- a/ui/feed_entries.go +++ b/ui/feed_entries.go @@ -23,12 +23,7 @@ func (c *Controller) ShowFeedEntries(w http.ResponseWriter, r *http.Request) { return } - feedID, err := request.IntParam(r, "feedID") - if err != nil { - html.BadRequest(w, err) - return - } - + feedID := request.RouteInt64Param(r, "feedID") feed, err := c.store.FeedByID(user.ID, feedID) if err != nil { html.ServerError(w, err) diff --git a/ui/feed_icon.go b/ui/feed_icon.go index c5a7414..0aa7089 100644 --- a/ui/feed_icon.go +++ b/ui/feed_icon.go @@ -15,12 +15,7 @@ import ( // ShowIcon shows the feed icon. func (c *Controller) ShowIcon(w http.ResponseWriter, r *http.Request) { - iconID, err := request.IntParam(r, "iconID") - if err != nil { - html.BadRequest(w, err) - return - } - + iconID := request.RouteInt64Param(r, "iconID") icon, err := c.store.IconByID(iconID) if err != nil { html.ServerError(w, err) diff --git a/ui/feed_refresh.go b/ui/feed_refresh.go index d9da238..df93c6e 100644 --- a/ui/feed_refresh.go +++ b/ui/feed_refresh.go @@ -16,12 +16,7 @@ import ( // RefreshFeed refresh a subscription and redirect to the feed entries page. func (c *Controller) RefreshFeed(w http.ResponseWriter, r *http.Request) { - feedID, err := request.IntParam(r, "feedID") - if err != nil { - html.BadRequest(w, err) - return - } - + feedID := request.RouteInt64Param(r, "feedID") if err := c.feedHandler.RefreshFeed(request.UserID(r), feedID); err != nil { logger.Error("[Controller:RefreshFeed] %v", err) } diff --git a/ui/feed_remove.go b/ui/feed_remove.go index 9071a33..d1ab01a 100644 --- a/ui/feed_remove.go +++ b/ui/feed_remove.go @@ -2,7 +2,7 @@ // Use of this source code is governed by the Apache 2.0 // license that can be found in the LICENSE file. -package ui // import "miniflux.app/ui" +package ui // import "miniflux.app/ui" import ( "net/http" @@ -15,12 +15,7 @@ import ( // RemoveFeed deletes a subscription from the database and redirect to the list of feeds page. func (c *Controller) RemoveFeed(w http.ResponseWriter, r *http.Request) { - feedID, err := request.IntParam(r, "feedID") - if err != nil { - html.ServerError(w, err) - return - } - + feedID := request.RouteInt64Param(r, "feedID") if err := c.store.RemoveFeed(request.UserID(r), feedID); err != nil { html.ServerError(w, err) return diff --git a/ui/feed_update.go b/ui/feed_update.go index a6fbcbf..6cc4776 100644 --- a/ui/feed_update.go +++ b/ui/feed_update.go @@ -26,12 +26,7 @@ func (c *Controller) UpdateFeed(w http.ResponseWriter, r *http.Request) { return } - feedID, err := request.IntParam(r, "feedID") - if err != nil { - html.BadRequest(w, err) - return - } - + feedID := request.RouteInt64Param(r, "feedID") feed, err := c.store.FeedByID(user.ID, feedID) if err != nil { html.ServerError(w, err) diff --git a/ui/oauth2_callback.go b/ui/oauth2_callback.go index 9f51d0a..1902d6e 100644 --- a/ui/oauth2_callback.go +++ b/ui/oauth2_callback.go @@ -23,21 +23,21 @@ func (c *Controller) OAuth2Callback(w http.ResponseWriter, r *http.Request) { printer := locale.NewPrinter(request.UserLanguage(r)) sess := session.New(c.store, request.SessionID(r)) - provider := request.Param(r, "provider", "") + provider := request.RouteStringParam(r, "provider") if provider == "" { logger.Error("[OAuth2] Invalid or missing provider") response.Redirect(w, r, route.Path(c.router, "login")) return } - code := request.QueryParam(r, "code", "") + code := request.QueryStringParam(r, "code", "") if code == "" { logger.Error("[OAuth2] No code received on callback") response.Redirect(w, r, route.Path(c.router, "login")) return } - state := request.QueryParam(r, "state", "") + state := request.QueryStringParam(r, "state", "") if state == "" || state != request.OAuth2State(r) { logger.Error(`[OAuth2] Invalid state value: got "%s" instead of "%s"`, state, request.OAuth2State(r)) response.Redirect(w, r, route.Path(c.router, "login")) diff --git a/ui/oauth2_redirect.go b/ui/oauth2_redirect.go index 3b7c88a..90c20b1 100644 --- a/ui/oauth2_redirect.go +++ b/ui/oauth2_redirect.go @@ -18,7 +18,7 @@ import ( func (c *Controller) OAuth2Redirect(w http.ResponseWriter, r *http.Request) { sess := session.New(c.store, request.SessionID(r)) - provider := request.Param(r, "provider", "") + provider := request.RouteStringParam(r, "provider") if provider == "" { logger.Error("[OAuth2] Invalid or missing provider: %s", provider) response.Redirect(w, r, route.Path(c.router, "login")) diff --git a/ui/oauth2_unlink.go b/ui/oauth2_unlink.go index 4435733..022e282 100644 --- a/ui/oauth2_unlink.go +++ b/ui/oauth2_unlink.go @@ -19,7 +19,7 @@ import ( // OAuth2Unlink unlink an account from the external provider. func (c *Controller) OAuth2Unlink(w http.ResponseWriter, r *http.Request) { printer := locale.NewPrinter(request.UserLanguage(r)) - provider := request.Param(r, "provider", "") + provider := request.RouteStringParam(r, "provider") if provider == "" { logger.Info("[OAuth2] Invalid or missing provider") response.Redirect(w, r, route.Path(c.router, "login")) diff --git a/ui/proxy.go b/ui/proxy.go index 68e4db0..553dcce 100644 --- a/ui/proxy.go +++ b/ui/proxy.go @@ -27,7 +27,7 @@ func (c *Controller) ImageProxy(w http.ResponseWriter, r *http.Request) { return } - encodedURL := request.Param(r, "encodedURL", "") + encodedURL := request.RouteStringParam(r, "encodedURL") if encodedURL == "" { html.BadRequest(w, errors.New("No URL provided")) return diff --git a/ui/search_entries.go b/ui/search_entries.go index 3f17ede..e1dcbad 100644 --- a/ui/search_entries.go +++ b/ui/search_entries.go @@ -23,7 +23,7 @@ func (c *Controller) ShowSearchEntries(w http.ResponseWriter, r *http.Request) { return } - searchQuery := request.QueryParam(r, "q", "") + searchQuery := request.QueryStringParam(r, "q", "") offset := request.QueryIntParam(r, "offset", 0) builder := c.store.NewEntryQueryBuilder(user.ID) builder.WithSearchQuery(searchQuery) diff --git a/ui/session_remove.go b/ui/session_remove.go index f08cd3d..27201fd 100644 --- a/ui/session_remove.go +++ b/ui/session_remove.go @@ -2,27 +2,21 @@ // Use of this source code is governed by the Apache 2.0 // license that can be found in the LICENSE file. -package ui // import "miniflux.app/ui" +package ui // import "miniflux.app/ui" import ( "net/http" "miniflux.app/http/request" "miniflux.app/http/response" - "miniflux.app/http/response/html" "miniflux.app/http/route" "miniflux.app/logger" ) // RemoveSession remove a user session. func (c *Controller) RemoveSession(w http.ResponseWriter, r *http.Request) { - sessionID, err := request.IntParam(r, "sessionID") - if err != nil { - html.BadRequest(w, err) - return - } - - err = c.store.RemoveUserSessionByID(request.UserID(r), sessionID) + sessionID := request.RouteInt64Param(r, "sessionID") + err := c.store.RemoveUserSessionByID(request.UserID(r), sessionID) if err != nil { logger.Error("[Controller:RemoveSession] %v", err) } diff --git a/ui/static_app_icon.go b/ui/static_app_icon.go index 6da3f1a..2ad42fd 100644 --- a/ui/static_app_icon.go +++ b/ui/static_app_icon.go @@ -2,7 +2,7 @@ // Use of this source code is governed by the Apache 2.0 // license that can be found in the LICENSE file. -package ui // import "miniflux.app/ui" +package ui // import "miniflux.app/ui" import ( "encoding/base64" @@ -18,7 +18,7 @@ import ( // AppIcon renders application icons. func (c *Controller) AppIcon(w http.ResponseWriter, r *http.Request) { - filename := request.Param(r, "filename", "favicon.png") + filename := request.RouteStringParam(r, "filename") encodedBlob, found := static.Binaries[filename] if !found { logger.Info("[Controller:AppIcon] This icon doesn't exists: %s", filename) diff --git a/ui/static_javascript.go b/ui/static_javascript.go index c52251c..248fae3 100644 --- a/ui/static_javascript.go +++ b/ui/static_javascript.go @@ -2,7 +2,7 @@ // Use of this source code is governed by the Apache 2.0 // license that can be found in the LICENSE file. -package ui // import "miniflux.app/ui" +package ui // import "miniflux.app/ui" import ( "net/http" @@ -16,7 +16,7 @@ import ( // Javascript renders application client side code. func (c *Controller) Javascript(w http.ResponseWriter, r *http.Request) { - filename := request.Param(r, "name", "app") + filename := request.RouteStringParam(r, "name") if _, found := static.Javascripts[filename]; !found { html.NotFound(w) return diff --git a/ui/static_stylesheet.go b/ui/static_stylesheet.go index de540fe..8e475bd 100644 --- a/ui/static_stylesheet.go +++ b/ui/static_stylesheet.go @@ -2,7 +2,7 @@ // Use of this source code is governed by the Apache 2.0 // license that can be found in the LICENSE file. -package ui // import "miniflux.app/ui" +package ui // import "miniflux.app/ui" import ( "net/http" @@ -16,7 +16,7 @@ import ( // Stylesheet renders the CSS. func (c *Controller) Stylesheet(w http.ResponseWriter, r *http.Request) { - stylesheet := request.Param(r, "name", "default") + stylesheet := request.RouteStringParam(r, "name") if _, found := static.Stylesheets[stylesheet]; !found { html.NotFound(w) return diff --git a/ui/subscription_bookmarklet.go b/ui/subscription_bookmarklet.go index 3aa1392..cba039a 100644 --- a/ui/subscription_bookmarklet.go +++ b/ui/subscription_bookmarklet.go @@ -32,7 +32,7 @@ func (c *Controller) Bookmarklet(w http.ResponseWriter, r *http.Request) { return } - bookmarkletURL := request.QueryParam(r, "uri", "") + bookmarkletURL := request.QueryStringParam(r, "uri", "") view.Set("form", form.SubscriptionForm{URL: bookmarkletURL}) view.Set("categories", categories) diff --git a/ui/user_edit.go b/ui/user_edit.go index 79e6d4a..f2c1abf 100644 --- a/ui/user_edit.go +++ b/ui/user_edit.go @@ -30,12 +30,7 @@ func (c *Controller) EditUser(w http.ResponseWriter, r *http.Request) { return } - userID, err := request.IntParam(r, "userID") - if err != nil { - html.BadRequest(w, err) - return - } - + userID := request.RouteInt64Param(r, "userID") selectedUser, err := c.store.UserByID(userID) if err != nil { html.ServerError(w, err) diff --git a/ui/user_remove.go b/ui/user_remove.go index 4b5440d..981a7a2 100644 --- a/ui/user_remove.go +++ b/ui/user_remove.go @@ -2,7 +2,7 @@ // Use of this source code is governed by the Apache 2.0 // license that can be found in the LICENSE file. -package ui // import "miniflux.app/ui" +package ui // import "miniflux.app/ui" import ( "net/http" @@ -26,12 +26,7 @@ func (c *Controller) RemoveUser(w http.ResponseWriter, r *http.Request) { return } - userID, err := request.IntParam(r, "userID") - if err != nil { - html.BadRequest(w, err) - return - } - + userID := request.RouteInt64Param(r, "userID") selectedUser, err := c.store.UserByID(userID) if err != nil { html.ServerError(w, err) diff --git a/ui/user_update.go b/ui/user_update.go index da28e39..006d49a 100644 --- a/ui/user_update.go +++ b/ui/user_update.go @@ -30,12 +30,7 @@ func (c *Controller) UpdateUser(w http.ResponseWriter, r *http.Request) { return } - userID, err := request.IntParam(r, "userID") - if err != nil { - html.BadRequest(w, err) - return - } - + userID := request.RouteInt64Param(r, "userID") selectedUser, err := c.store.UserByID(userID) if err != nil { html.ServerError(w, err) -- cgit v1.2.3