aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--api/category.go12
-rw-r--r--api/entry.go47
-rw-r--r--api/feed.go30
-rw-r--r--api/icon.go6
-rw-r--r--api/user.go23
-rw-r--r--fever/fever.go2
-rw-r--r--http/request/client_ip.go38
-rw-r--r--http/request/client_ip_test.go (renamed from http/request/request_test.go)8
-rw-r--r--http/request/context.go21
-rw-r--r--http/request/context_test.go436
-rw-r--r--http/request/cookie.go17
-rw-r--r--http/request/cookie_test.go33
-rw-r--r--http/request/params.go84
-rw-r--r--http/request/params_test.go215
-rw-r--r--http/request/request.go128
-rw-r--r--middleware/app_session.go2
-rw-r--r--middleware/user_session.go2
-rw-r--r--ui/category_edit.go7
-rw-r--r--ui/category_entries.go7
-rw-r--r--ui/category_remove.go7
-rw-r--r--ui/category_update.go7
-rw-r--r--ui/entry_bookmark.go7
-rw-r--r--ui/entry_category.go13
-rw-r--r--ui/entry_feed.go13
-rw-r--r--ui/entry_read.go7
-rw-r--r--ui/entry_save.go7
-rw-r--r--ui/entry_scraper.go7
-rw-r--r--ui/entry_search.go9
-rw-r--r--ui/entry_toggle_bookmark.go7
-rw-r--r--ui/entry_unread.go7
-rw-r--r--ui/feed_edit.go7
-rw-r--r--ui/feed_entries.go7
-rw-r--r--ui/feed_icon.go7
-rw-r--r--ui/feed_refresh.go7
-rw-r--r--ui/feed_remove.go9
-rw-r--r--ui/feed_update.go7
-rw-r--r--ui/oauth2_callback.go6
-rw-r--r--ui/oauth2_redirect.go2
-rw-r--r--ui/oauth2_unlink.go2
-rw-r--r--ui/proxy.go2
-rw-r--r--ui/search_entries.go2
-rw-r--r--ui/session_remove.go12
-rw-r--r--ui/static_app_icon.go4
-rw-r--r--ui/static_javascript.go4
-rw-r--r--ui/static_stylesheet.go4
-rw-r--r--ui/subscription_bookmarklet.go2
-rw-r--r--ui/user_edit.go7
-rw-r--r--ui/user_remove.go9
-rw-r--r--ui/user_update.go7
49 files changed, 916 insertions, 400 deletions
diff --git a/api/category.go b/api/category.go
index e74aa3b..b8699e2 100644
--- a/api/category.go
+++ b/api/category.go
@@ -43,11 +43,7 @@ func (c *Controller) CreateCategory(w http.ResponseWriter, r *http.Request) {
// UpdateCategory is the API handler to update a category.
func (c *Controller) UpdateCategory(w http.ResponseWriter, r *http.Request) {
- categoryID, err := request.IntParam(r, "categoryID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
+ categoryID := request.RouteInt64Param(r, "categoryID")
category, err := decodeCategoryPayload(r.Body)
if err != nil {
@@ -85,11 +81,7 @@ func (c *Controller) GetCategories(w http.ResponseWriter, r *http.Request) {
// RemoveCategory is the API handler to remove a category.
func (c *Controller) RemoveCategory(w http.ResponseWriter, r *http.Request) {
userID := request.UserID(r)
- categoryID, err := request.IntParam(r, "categoryID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
+ categoryID := request.RouteInt64Param(r, "categoryID")
if !c.store.CategoryExists(userID, categoryID) {
json.NotFound(w, errors.New("Category not found"))
diff --git a/api/entry.go b/api/entry.go
index 7f87888..a1ea87f 100644
--- a/api/entry.go
+++ b/api/entry.go
@@ -17,17 +17,8 @@ import (
// GetFeedEntry is the API handler to get a single feed entry.
func (c *Controller) GetFeedEntry(w http.ResponseWriter, r *http.Request) {
- feedID, err := request.IntParam(r, "feedID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
-
- entryID, err := request.IntParam(r, "entryID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
+ feedID := request.RouteInt64Param(r, "feedID")
+ entryID := request.RouteInt64Param(r, "entryID")
builder := c.store.NewEntryQueryBuilder(request.UserID(r))
builder.WithFeedID(feedID)
@@ -49,12 +40,7 @@ func (c *Controller) GetFeedEntry(w http.ResponseWriter, r *http.Request) {
// GetEntry is the API handler to get a single entry.
func (c *Controller) GetEntry(w http.ResponseWriter, r *http.Request) {
- entryID, err := request.IntParam(r, "entryID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
-
+ entryID := request.RouteInt64Param(r, "entryID")
builder := c.store.NewEntryQueryBuilder(request.UserID(r))
builder.WithEntryID(entryID)
@@ -74,13 +60,9 @@ func (c *Controller) GetEntry(w http.ResponseWriter, r *http.Request) {
// GetFeedEntries is the API handler to get all feed entries.
func (c *Controller) GetFeedEntries(w http.ResponseWriter, r *http.Request) {
- feedID, err := request.IntParam(r, "feedID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
+ feedID := request.RouteInt64Param(r, "feedID")
- status := request.QueryParam(r, "status", "")
+ status := request.QueryStringParam(r, "status", "")
if status != "" {
if err := model.ValidateEntryStatus(status); err != nil {
json.BadRequest(w, err)
@@ -88,13 +70,13 @@ func (c *Controller) GetFeedEntries(w http.ResponseWriter, r *http.Request) {
}
}
- order := request.QueryParam(r, "order", model.DefaultSortingOrder)
+ order := request.QueryStringParam(r, "order", model.DefaultSortingOrder)
if err := model.ValidateEntryOrder(order); err != nil {
json.BadRequest(w, err)
return
}
- direction := request.QueryParam(r, "direction", model.DefaultSortingDirection)
+ direction := request.QueryStringParam(r, "direction", model.DefaultSortingDirection)
if err := model.ValidateDirection(direction); err != nil {
json.BadRequest(w, err)
return
@@ -133,7 +115,7 @@ func (c *Controller) GetFeedEntries(w http.ResponseWriter, r *http.Request) {
// GetEntries is the API handler to fetch entries.
func (c *Controller) GetEntries(w http.ResponseWriter, r *http.Request) {
- status := request.QueryParam(r, "status", "")
+ status := request.QueryStringParam(r, "status", "")
if status != "" {
if err := model.ValidateEntryStatus(status); err != nil {
json.BadRequest(w, err)
@@ -141,13 +123,13 @@ func (c *Controller) GetEntries(w http.ResponseWriter, r *http.Request) {
}
}
- order := request.QueryParam(r, "order", model.DefaultSortingOrder)
+ order := request.QueryStringParam(r, "order", model.DefaultSortingOrder)
if err := model.ValidateEntryOrder(order); err != nil {
json.BadRequest(w, err)
return
}
- direction := request.QueryParam(r, "direction", model.DefaultSortingDirection)
+ direction := request.QueryStringParam(r, "direction", model.DefaultSortingDirection)
if err := model.ValidateDirection(direction); err != nil {
json.BadRequest(w, err)
return
@@ -206,12 +188,7 @@ func (c *Controller) SetEntryStatus(w http.ResponseWriter, r *http.Request) {
// ToggleBookmark is the API handler to toggle bookmark status.
func (c *Controller) ToggleBookmark(w http.ResponseWriter, r *http.Request) {
- entryID, err := request.IntParam(r, "entryID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
-
+ entryID := request.RouteInt64Param(r, "entryID")
if err := c.store.ToggleBookmark(request.UserID(r), entryID); err != nil {
json.ServerError(w, err)
return
@@ -245,7 +222,7 @@ func configureFilters(builder *storage.EntryQueryBuilder, r *http.Request) {
builder.WithStarred()
}
- searchQuery := request.QueryParam(r, "search", "")
+ searchQuery := request.QueryStringParam(r, "search", "")
if searchQuery != "" {
builder.WithSearchQuery(searchQuery)
}
diff --git a/api/feed.go b/api/feed.go
index 17bb73d..d193d2e 100644
--- a/api/feed.go
+++ b/api/feed.go
@@ -65,12 +65,7 @@ func (c *Controller) CreateFeed(w http.ResponseWriter, r *http.Request) {
// RefreshFeed is the API handler to refresh a feed.
func (c *Controller) RefreshFeed(w http.ResponseWriter, r *http.Request) {
- feedID, err := request.IntParam(r, "feedID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
-
+ feedID := request.RouteInt64Param(r, "feedID")
userID := request.UserID(r)
if !c.store.FeedExists(userID, feedID) {
@@ -78,7 +73,7 @@ func (c *Controller) RefreshFeed(w http.ResponseWriter, r *http.Request) {
return
}
- err = c.feedHandler.RefreshFeed(userID, feedID)
+ err := c.feedHandler.RefreshFeed(userID, feedID)
if err != nil {
json.ServerError(w, err)
return
@@ -89,12 +84,7 @@ func (c *Controller) RefreshFeed(w http.ResponseWriter, r *http.Request) {
// UpdateFeed is the API handler that is used to update a feed.
func (c *Controller) UpdateFeed(w http.ResponseWriter, r *http.Request) {
- feedID, err := request.IntParam(r, "feedID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
-
+ feedID := request.RouteInt64Param(r, "feedID")
feedChanges, err := decodeFeedModificationPayload(r.Body)
if err != nil {
json.BadRequest(w, err)
@@ -148,12 +138,7 @@ func (c *Controller) GetFeeds(w http.ResponseWriter, r *http.Request) {
// GetFeed is the API handler to get a feed.
func (c *Controller) GetFeed(w http.ResponseWriter, r *http.Request) {
- feedID, err := request.IntParam(r, "feedID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
-
+ feedID := request.RouteInt64Param(r, "feedID")
feed, err := c.store.FeedByID(request.UserID(r), feedID)
if err != nil {
json.ServerError(w, err)
@@ -170,12 +155,7 @@ func (c *Controller) GetFeed(w http.ResponseWriter, r *http.Request) {
// RemoveFeed is the API handler to remove a feed.
func (c *Controller) RemoveFeed(w http.ResponseWriter, r *http.Request) {
- feedID, err := request.IntParam(r, "feedID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
-
+ feedID := request.RouteInt64Param(r, "feedID")
userID := request.UserID(r)
if !c.store.FeedExists(userID, feedID) {
diff --git a/api/icon.go b/api/icon.go
index f9c2964..de01fad 100644
--- a/api/icon.go
+++ b/api/icon.go
@@ -14,11 +14,7 @@ import (
// FeedIcon returns a feed icon.
func (c *Controller) FeedIcon(w http.ResponseWriter, r *http.Request) {
- feedID, err := request.IntParam(r, "feedID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
+ feedID := request.RouteInt64Param(r, "feedID")
if !c.store.HasIcon(feedID) {
json.NotFound(w, errors.New("This feed doesn't have any icon"))
diff --git a/api/user.go b/api/user.go
index 167fd72..b9274bb 100644
--- a/api/user.go
+++ b/api/user.go
@@ -63,12 +63,7 @@ func (c *Controller) UpdateUser(w http.ResponseWriter, r *http.Request) {
return
}
- userID, err := request.IntParam(r, "userID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
-
+ userID := request.RouteInt64Param(r, "userID")
userChanges, err := decodeUserModificationPayload(r.Body)
if err != nil {
json.BadRequest(w, err)
@@ -124,12 +119,7 @@ func (c *Controller) UserByID(w http.ResponseWriter, r *http.Request) {
return
}
- userID, err := request.IntParam(r, "userID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
-
+ userID := request.RouteInt64Param(r, "userID")
user, err := c.store.UserByID(userID)
if err != nil {
json.BadRequest(w, errors.New("Unable to fetch this user from the database"))
@@ -152,7 +142,7 @@ func (c *Controller) UserByUsername(w http.ResponseWriter, r *http.Request) {
return
}
- username := request.Param(r, "username", "")
+ username := request.RouteStringParam(r, "username")
user, err := c.store.UserByUsername(username)
if err != nil {
json.BadRequest(w, errors.New("Unable to fetch this user from the database"))
@@ -174,12 +164,7 @@ func (c *Controller) RemoveUser(w http.ResponseWriter, r *http.Request) {
return
}
- userID, err := request.IntParam(r, "userID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
-
+ userID := request.RouteInt64Param(r, "userID")
user, err := c.store.UserByID(userID)
if err != nil {
json.ServerError(w, err)
diff --git a/fever/fever.go b/fever/fever.go
index e1090b7..b754a89 100644
--- a/fever/fever.go
+++ b/fever/fever.go
@@ -356,7 +356,7 @@ func (c *Controller) handleItems(w http.ResponseWriter, r *http.Request) {
builder.WithOffset(maxID)
}
- csvItemIDs := request.QueryParam(r, "with_ids", "")
+ csvItemIDs := request.QueryStringParam(r, "with_ids", "")
if csvItemIDs != "" {
var itemIDs []int64
diff --git a/http/request/client_ip.go b/http/request/client_ip.go
new file mode 100644
index 0000000..52fc05c
--- /dev/null
+++ b/http/request/client_ip.go
@@ -0,0 +1,38 @@
+// Copyright 2018 Frédéric Guillot. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package request // import "miniflux.app/http/request"
+
+import (
+ "net"
+ "net/http"
+ "strings"
+)
+
+// FindClientIP returns client real IP address.
+func FindClientIP(r *http.Request) string {
+ headers := []string{"X-Forwarded-For", "X-Real-Ip"}
+ for _, header := range headers {
+ value := r.Header.Get(header)
+
+ if value != "" {
+ addresses := strings.Split(value, ",")
+ address := strings.TrimSpace(addresses[0])
+
+ if net.ParseIP(address) != nil {
+ return address
+ }
+ }
+ }
+
+ // Fallback to TCP/IP source IP address.
+ var remoteIP string
+ if strings.ContainsRune(r.RemoteAddr, ':') {
+ remoteIP, _, _ = net.SplitHostPort(r.RemoteAddr)
+ } else {
+ remoteIP = r.RemoteAddr
+ }
+
+ return remoteIP
+}
diff --git a/http/request/request_test.go b/http/request/client_ip_test.go
index 946b132..12d6e16 100644
--- a/http/request/request_test.go
+++ b/http/request/client_ip_test.go
@@ -9,7 +9,7 @@ import (
"testing"
)
-func TestRealIPWithoutHeaders(t *testing.T) {
+func TestFindClientIPWithoutHeaders(t *testing.T) {
r := &http.Request{RemoteAddr: "192.168.0.1:4242"}
if ip := FindClientIP(r); ip != "192.168.0.1" {
t.Fatalf(`Unexpected result, got: %q`, ip)
@@ -21,7 +21,7 @@ func TestRealIPWithoutHeaders(t *testing.T) {
}
}
-func TestRealIPWithXFFHeader(t *testing.T) {
+func TestFindClientIPWithXFFHeader(t *testing.T) {
// Test with multiple IPv4 addresses.
headers := http.Header{}
headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178")
@@ -59,7 +59,7 @@ func TestRealIPWithXFFHeader(t *testing.T) {
}
}
-func TestRealIPWithXRealIPHeader(t *testing.T) {
+func TestClientIPWithXRealIPHeader(t *testing.T) {
headers := http.Header{}
headers.Set("X-Real-Ip", "192.168.122.1")
r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers}
@@ -69,7 +69,7 @@ func TestRealIPWithXRealIPHeader(t *testing.T) {
}
}
-func TestRealIPWithBothHeaders(t *testing.T) {
+func TestClientIPWithBothHeaders(t *testing.T) {
headers := http.Header{}
headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178")
headers.Set("X-Real-Ip", "192.168.122.1")
diff --git a/http/request/context.go b/http/request/context.go
index b77365d..51ee46a 100644
--- a/http/request/context.go
+++ b/http/request/context.go
@@ -111,7 +111,12 @@ func ClientIP(r *http.Request) string {
func getContextStringValue(r *http.Request, key ContextKey) string {
if v := r.Context().Value(key); v != nil {
- return v.(string)
+ value, valid := v.(string)
+ if !valid {
+ return ""
+ }
+
+ return value
}
return ""
@@ -119,7 +124,12 @@ func getContextStringValue(r *http.Request, key ContextKey) string {
func getContextBoolValue(r *http.Request, key ContextKey) bool {
if v := r.Context().Value(key); v != nil {
- return v.(bool)
+ value, valid := v.(bool)
+ if !valid {
+ return false
+ }
+
+ return value
}
return false
@@ -127,7 +137,12 @@ func getContextBoolValue(r *http.Request, key ContextKey) bool {
func getContextInt64Value(r *http.Request, key ContextKey) int64 {
if v := r.Context().Value(key); v != nil {
- return v.(int64)
+ value, valid := v.(int64)
+ if !valid {
+ return 0
+ }
+
+ return value
}
return 0
diff --git a/http/request/context_test.go b/http/request/context_test.go
new file mode 100644
index 0000000..e741362
--- /dev/null
+++ b/http/request/context_test.go
@@ -0,0 +1,436 @@
+// Copyright 2018 Frédéric Guillot. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package request // import "miniflux.app/http/request"
+
+import (
+ "context"
+ "net/http"
+ "testing"
+)
+
+func TestContextStringValue(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, ClientIPContextKey, "IP")
+ r = r.WithContext(ctx)
+
+ result := getContextStringValue(r, ClientIPContextKey)
+ expected := "IP"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestContextStringValueWithInvalidType(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, ClientIPContextKey, 0)
+ r = r.WithContext(ctx)
+
+ result := getContextStringValue(r, ClientIPContextKey)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestContextStringValueWhenUnset(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := getContextStringValue(r, ClientIPContextKey)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestContextBoolValue(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, IsAdminUserContextKey, true)
+ r = r.WithContext(ctx)
+
+ result := getContextBoolValue(r, IsAdminUserContextKey)
+ expected := true
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+}
+
+func TestContextBoolValueWithInvalidType(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, IsAdminUserContextKey, "invalid")
+ r = r.WithContext(ctx)
+
+ result := getContextBoolValue(r, IsAdminUserContextKey)
+ expected := false
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+}
+
+func TestContextBoolValueWhenUnset(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := getContextBoolValue(r, IsAdminUserContextKey)
+ expected := false
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+}
+
+func TestContextInt64Value(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, UserIDContextKey, int64(1234))
+ r = r.WithContext(ctx)
+
+ result := getContextInt64Value(r, UserIDContextKey)
+ expected := int64(1234)
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
+ }
+}
+
+func TestContextInt64ValueWithInvalidType(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, UserIDContextKey, "invalid")
+ r = r.WithContext(ctx)
+
+ result := getContextInt64Value(r, UserIDContextKey)
+ expected := int64(0)
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
+ }
+}
+
+func TestContextInt64ValueWhenUnset(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := getContextInt64Value(r, UserIDContextKey)
+ expected := int64(0)
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
+ }
+}
+
+func TestIsAdmin(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := IsAdminUser(r)
+ expected := false
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, IsAdminUserContextKey, true)
+ r = r.WithContext(ctx)
+
+ result = IsAdminUser(r)
+ expected = true
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+}
+
+func TestIsAuthenticated(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := IsAuthenticated(r)
+ expected := false
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, IsAuthenticatedContextKey, true)
+ r = r.WithContext(ctx)
+
+ result = IsAuthenticated(r)
+ expected = true
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+}
+
+func TestUserID(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := UserID(r)
+ expected := int64(0)
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, UserIDContextKey, int64(123))
+ r = r.WithContext(ctx)
+
+ result = UserID(r)
+ expected = int64(123)
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+ }
+}
+
+func TestUserTimezone(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := UserTimezone(r)
+ expected := "UTC"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, UserTimezoneContextKey, "Europe/Paris")
+ r = r.WithContext(ctx)
+
+ result = UserTimezone(r)
+ expected = "Europe/Paris"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestUserLanguage(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := UserLanguage(r)
+ expected := "en_US"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, UserLanguageContextKey, "fr_FR")
+ r = r.WithContext(ctx)
+
+ result = UserLanguage(r)
+ expected = "fr_FR"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestUserTheme(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := UserTheme(r)
+ expected := "default"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, UserThemeContextKey, "black")
+ r = r.WithContext(ctx)
+
+ result = UserTheme(r)
+ expected = "black"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestCSRF(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := CSRF(r)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, CSRFContextKey, "secret")
+ r = r.WithContext(ctx)
+
+ result = CSRF(r)
+ expected = "secret"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestSessionID(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := SessionID(r)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, SessionIDContextKey, "id")
+ r = r.WithContext(ctx)
+
+ result = SessionID(r)
+ expected = "id"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestUserSessionToken(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := UserSessionToken(r)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, UserSessionTokenContextKey, "token")
+ r = r.WithContext(ctx)
+
+ result = UserSessionToken(r)
+ expected = "token"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestOAuth2State(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := OAuth2State(r)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, OAuth2StateContextKey, "state")
+ r = r.WithContext(ctx)
+
+ result = OAuth2State(r)
+ expected = "state"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestFlashMessage(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := FlashMessage(r)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, FlashMessageContextKey, "message")
+ r = r.WithContext(ctx)
+
+ result = FlashMessage(r)
+ expected = "message"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestFlashErrorMessage(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := FlashErrorMessage(r)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, FlashErrorMessageContextKey, "error message")
+ r = r.WithContext(ctx)
+
+ result = FlashErrorMessage(r)
+ expected = "error message"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestPocketRequestToken(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := PocketRequestToken(r)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, PocketRequestTokenContextKey, "request token")
+ r = r.WithContext(ctx)
+
+ result = PocketRequestToken(r)
+ expected = "request token"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestClientIP(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := ClientIP(r)
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, ClientIPContextKey, "127.0.0.1")
+ r = r.WithContext(ctx)
+
+ result = ClientIP(r)
+ expected = "127.0.0.1"
+
+ if result != expected {
+ t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+ }
+}
diff --git a/http/request/cookie.go b/http/request/cookie.go
new file mode 100644
index 0000000..88cc626
--- /dev/null
+++ b/http/request/cookie.go
@@ -0,0 +1,17 @@
+// Copyright 2018 Frédéric Guillot. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package request // import "miniflux.app/http/request"
+
+import "net/http"
+
+// CookieValue returns the cookie value.
+func CookieValue(r *http.Request, name string) string {
+ cookie, err := r.Cookie(name)
+ if err == http.ErrNoCookie {
+ return ""
+ }
+
+ return cookie.Value
+}
diff --git a/http/request/cookie_test.go b/http/request/cookie_test.go
new file mode 100644
index 0000000..9c3b54d
--- /dev/null
+++ b/http/request/cookie_test.go
@@ -0,0 +1,33 @@
+// Copyright 2018 Frédéric Guillot. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package request // import "miniflux.app/http/request"
+
+import (
+ "net/http"
+ "testing"
+)
+
+func TestGetCookieValue(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+ r.AddCookie(&http.Cookie{Value: "cookie_value", Name: "my_cookie"})
+
+ result := CookieValue(r, "my_cookie")
+ expected := "cookie_value"
+
+ if result != expected {
+ t.Errorf(`Unexpected cookie value, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestGetCookieValueWhenUnset(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+ result := CookieValue(r, "my_cookie")
+ expected := ""
+
+ if result != expected {
+ t.Errorf(`Unexpected cookie value, got %q instead of %q`, result, expected)
+ }
+}
diff --git a/http/request/params.go b/http/request/params.go
new file mode 100644
index 0000000..8218a0d
--- /dev/null
+++ b/http/request/params.go
@@ -0,0 +1,84 @@
+// Copyright 2018 Frédéric Guillot. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package request // import "miniflux.app/http/request"
+
+import (
+ "net/http"
+ "strconv"
+
+ "github.com/gorilla/mux"
+)
+
+// FormInt64Value returns a form value as integer.
+func FormInt64Value(r *http.Request, param string) int64 {
+ value := r.FormValue(param)
+ integer, err := strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ return 0
+ }
+
+ return integer
+}
+
+// RouteInt64Param returns an URL route parameter as int64.
+func RouteInt64Param(r *http.Request, param string) int64 {
+ vars := mux.Vars(r)
+ value, err := strconv.ParseInt(vars[param], 10, 64)
+ if err != nil {
+ return 0
+ }
+
+ if value < 0 {
+ return 0
+ }
+
+ return value
+}
+
+// RouteStringParam returns a URL route parameter as string.
+func RouteStringParam(r *http.Request, param string) string {
+ vars := mux.Vars(r)
+ return vars[param]
+}
+
+// QueryStringParam returns a query string parameter as string.
+func QueryStringParam(r *http.Request, param, defaultValue string) string {
+ value := r.URL.Query().Get(param)
+ if value == "" {
+ value = defaultValue
+ }
+ return value
+}
+
+// QueryIntParam returns a query string parameter as integer.
+func QueryIntParam(r *http.Request, param string, defaultValue int) int {
+ return int(QueryInt64Param(r, param, int64(defaultValue)))
+}
+
+// QueryInt64Param returns a query string parameter as int64.
+func QueryInt64Param(r *http.Request, param string, defaultValue int64) int64 {
+ value := r.URL.Query().Get(param)
+ if value == "" {
+ return defaultValue
+ }
+
+ val, err := strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ return defaultValue
+ }
+
+ if val < 0 {
+ return defaultValue
+ }
+
+ return val
+}
+
+// HasQueryParam checks if the query string contains the given parameter.
+func HasQueryParam(r *http.Request, param string) bool {
+ values := r.URL.Query()
+ _, ok := values[param]
+ return ok
+}
diff --git a/http/request/params_test.go b/http/request/params_test.go
new file mode 100644
index 0000000..7f1f880
--- /dev/null
+++ b/http/request/params_test.go
@@ -0,0 +1,215 @@
+// Copyright 2018 Frédéric Guillot. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package request // import "miniflux.app/http/request"
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "testing"
+
+ "github.com/gorilla/mux"
+)
+
+func TestFormInt64Value(t *testing.T) {
+ f := url.Values{}
+ f.Set("integer value", "42")
+ f.Set("invalid value", "invalid integer")
+
+ r := &http.Request{Form: f}
+
+ result := FormInt64Value(r, "integer value")
+ expected := int64(42)
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = FormInt64Value(r, "invalid value")
+ expected = int64(0)
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = FormInt64Value(r, "missing value")
+ expected = int64(0)
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+}
+
+func TestRouteStringParam(t *testing.T) {
+ router := mux.NewRouter()
+ router.HandleFunc("/route/{variable}/index", func(w http.ResponseWriter, r *http.Request) {
+ result := RouteStringParam(r, "variable")
+ expected := "value"
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
+ }
+
+ result = RouteStringParam(r, "missing variable")
+ expected = ""
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
+ }
+ })
+
+ r, err := http.NewRequest("GET", "/route/value/index", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ w := httptest.NewRecorder()
+ router.ServeHTTP(w, r)
+}
+
+func TestRouteInt64Param(t *testing.T) {
+ router := mux.NewRouter()
+ router.HandleFunc("/a/{variable1}/b/{variable2}/c/{variable3}", func(w http.ResponseWriter, r *http.Request) {
+ result := RouteInt64Param(r, "variable1")
+ expected := int64(42)
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = RouteInt64Param(r, "missing variable")
+ expected = 0
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = RouteInt64Param(r, "variable2")
+ expected = 0
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = RouteInt64Param(r, "variable3")
+ expected = 0
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+ })
+
+ r, err := http.NewRequest("GET", "/a/42/b/not-int/c/-10", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ w := httptest.NewRecorder()
+ router.ServeHTTP(w, r)
+}
+
+func TestQueryStringParam(t *testing.T) {
+ u, _ := url.Parse("http://example.org/?key=value")
+ r := &http.Request{URL: u}
+
+ result := QueryStringParam(r, "key", "fallback")
+ expected := "value"
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
+ }
+
+ result = QueryStringParam(r, "missing key", "fallback")
+ expected = "fallback"
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
+ }
+}
+
+func TestQueryIntParam(t *testing.T) {
+ u, _ := url.Parse("http://example.org/?key=42&invalid=value&negative=-5")
+ r := &http.Request{URL: u}
+
+ result := QueryIntParam(r, "key", 84)
+ expected := 42
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = QueryIntParam(r, "missing key", 84)
+ expected = 84
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = QueryIntParam(r, "negative", 69)
+ expected = 69
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = QueryIntParam(r, "invalid", 99)
+ expected = 99
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+}
+
+func TestQueryInt64Param(t *testing.T) {
+ u, _ := url.Parse("http://example.org/?key=42&invalid=value&negative=-5")
+ r := &http.Request{URL: u}
+
+ result := QueryInt64Param(r, "key", int64(84))
+ expected := int64(42)
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = QueryInt64Param(r, "missing key", int64(84))
+ expected = int64(84)
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = QueryInt64Param(r, "invalid", int64(69))
+ expected = int64(69)
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+
+ result = QueryInt64Param(r, "invalid", int64(99))
+ expected = int64(99)
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+ }
+}
+
+func TestHasQueryParam(t *testing.T) {
+ u, _ := url.Parse("http://example.org/?key=42")
+ r := &http.Request{URL: u}
+
+ result := HasQueryParam(r, "key")
+ expected := true
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %v instead of %v`, result, expected)
+ }
+
+ result = HasQueryParam(r, "missing key")
+ expected = false
+
+ if result != expected {
+ t.Errorf(`Unexpected result, got %v instead of %v`, result, expected)
+ }
+}
diff --git a/http/request/request.go b/http/request/request.go
deleted file mode 100644
index d27137b..0000000
--- a/http/request/request.go
+++ /dev/null
@@ -1,128 +0,0 @@
-// Copyright 2018 Frédéric Guillot. All rights reserved.
-// Use of this source code is governed by the Apache 2.0
-// license that can be found in the LICENSE file.
-
-package request // import "miniflux.app/http/request"
-
-import (
- "fmt"
- "net"
- "net/http"
- "strconv"
- "strings"
-
- "github.com/gorilla/mux"
-)
-
-// Cookie returns the cookie value.
-func Cookie(r *http.Request, name string) string {
- cookie, err := r.Cookie(name)
- if err == http.ErrNoCookie {
- return ""
- }
-
- return cookie.Value
-}
-
-// FormInt64Value returns a form value as integer.
-func FormInt64Value(r *http.Request, param string) int64 {
- value := r.FormValue(param)
- integer, err := strconv.ParseInt(value, 10, 64)
- if err != nil {
- return 0
- }
-
- return integer
-}
-
-// IntParam returns an URL route parameter as integer.
-func IntParam(r *http.Request, param string) (int64, error) {
- vars := mux.Vars(r)
- value, err := strconv.Atoi(vars[param])
- if err != nil {
- return 0, fmt.Errorf("request: %s parameter is not an integer", param)
- }
-
- if value < 0 {
- return 0, nil
- }
-
- return int64(value), nil
-}
-
-// Param returns an URL route parameter as string.
-func Param(r *http.Request, param, defaultValue string) string {
- vars := mux.Vars(r)
- value := vars[param]
- if value == "" {
- value = defaultValue
- }
- return value
-}
-
-// QueryParam returns a querystring parameter as string.
-func QueryParam(r *http.Request, param, defaultValue string) string {
- value := r.URL.Query().Get(param)
- if value == "" {
- value = defaultValue
- }
- return value
-}
-
-// QueryIntParam returns a querystring parameter as integer.
-func QueryIntParam(r *http.Request, param string, defaultValue int) int {
- return int(QueryInt64Param(r, param, int64(defaultValue)))
-}
-
-// QueryInt64Param returns a querystring parameter as int64.
-func QueryInt64Param(r *http.Request, param string, defaultValue int64) int64 {
- value := r.URL.Query().Get(param)
- if value == "" {
- return defaultValue
- }
-
- val, err := strconv.ParseInt(value, 10, 64)
- if err != nil {
- return defaultValue
- }
-
- if val < 0 {
- return defaultValue
- }
-
- return val
-}
-
-// HasQueryParam checks if the query string contains the given parameter.
-func HasQueryParam(r *http.Request, param string) bool {
- values := r.URL.Query()
- _, ok := values[param]
- return ok
-}
-
-// FindClientIP returns client's real IP address.
-func FindClientIP(r *http.Request) string {
- headers := []string{"X-Forwarded-For", "X-Real-Ip"}
- for _, header := range headers {
- value := r.Header.Get(header)
-
- if value != "" {
- addresses := strings.Split(value, ",")
- address := strings.TrimSpace(addresses[0])
-
- if net.ParseIP(address) != nil {
- return address
- }
- }
- }
-
- // Fallback to TCP/IP source IP address.
- var remoteIP string
- if strings.ContainsRune(r.RemoteAddr, ':') {
- remoteIP, _, _ = net.SplitHostPort(r.RemoteAddr)
- } else {
- remoteIP = r.RemoteAddr
- }
-
- return remoteIP
-}
diff --git a/middleware/app_session.go b/middleware/app_session.go
index 7f2b214..b505ee1 100644
--- a/middleware/app_session.go
+++ b/middleware/app_session.go
@@ -62,7 +62,7 @@ func (m *Middleware) AppSession(next http.Handler) http.Handler {
}
func (m *Middleware) getAppSessionValueFromCookie(r *http.Request) *model.Session {
- cookieValue := request.Cookie(r, cookie.CookieSessionID)
+ cookieValue := request.CookieValue(r, cookie.CookieSessionID)
if cookieValue == "" {
return nil
}
diff --git a/middleware/user_session.go b/middleware/user_session.go
index 66f0ad0..bddb47b 100644
--- a/middleware/user_session.go
+++ b/middleware/user_session.go
@@ -62,7 +62,7 @@ func (m *Middleware) isPublicRoute(r *http.Request) bool {
}
func (m *Middleware) getUserSessionFromCookie(r *http.Request) *model.UserSession {
- cookieValue := request.Cookie(r, cookie.CookieUserSessionID)
+ cookieValue := request.CookieValue(r, cookie.CookieUserSessionID)
if cookieValue == "" {
return nil
}
diff --git a/ui/category_edit.go b/ui/category_edit.go
index 0ab4b1b..b99d2a0 100644
--- a/ui/category_edit.go
+++ b/ui/category_edit.go
@@ -25,12 +25,7 @@ func (c *Controller) EditCategory(w http.ResponseWriter, r *http.Request) {
return
}
- categoryID, err := request.IntParam(r, "categoryID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
+ categoryID := request.RouteInt64Param(r, "categoryID")
category, err := c.store.Category(request.UserID(r), categoryID)
if err != nil {
html.ServerError(w, err)
diff --git a/ui/category_entries.go b/ui/category_entries.go
index ac39fa1..caa98cd 100644
--- a/ui/category_entries.go
+++ b/ui/category_entries.go
@@ -23,12 +23,7 @@ func (c *Controller) CategoryEntries(w http.ResponseWriter, r *http.Request) {
return
}
- categoryID, err := request.IntParam(r, "categoryID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
+ categoryID := request.RouteInt64Param(r, "categoryID")
category, err := c.store.Category(request.UserID(r), categoryID)
if err != nil {
html.ServerError(w, err)
diff --git a/ui/category_remove.go b/ui/category_remove.go
index 033d9e3..b424af5 100644
--- a/ui/category_remove.go
+++ b/ui/category_remove.go
@@ -21,12 +21,7 @@ func (c *Controller) RemoveCategory(w http.ResponseWriter, r *http.Request) {
return
}
- categoryID, err := request.IntParam(r, "categoryID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
+ categoryID := request.RouteInt64Param(r, "categoryID")
category, err := c.store.Category(request.UserID(r), categoryID)
if err != nil {
html.ServerError(w, err)
diff --git a/ui/category_update.go b/ui/category_update.go
index 480b4fb..90672c0 100644
--- a/ui/category_update.go
+++ b/ui/category_update.go
@@ -25,12 +25,7 @@ func (c *Controller) UpdateCategory(w http.ResponseWriter, r *http.Request) {
return
}
- categoryID, err := request.IntParam(r, "categoryID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
+ categoryID := request.RouteInt64Param(r, "categoryID")
category, err := c.store.Category(request.UserID(r), categoryID)
if err != nil {
html.ServerError(w, err)
diff --git a/ui/entry_bookmark.go b/ui/entry_bookmark.go
index e6d2f0e..7c42a5c 100644
--- a/ui/entry_bookmark.go
+++ b/ui/entry_bookmark.go
@@ -25,12 +25,7 @@ func (c *Controller) ShowStarredEntry(w http.ResponseWriter, r *http.Request) {
return
}
- entryID, err := request.IntParam(r, "entryID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
+ entryID := request.RouteInt64Param(r, "entryID")
builder := c.store.NewEntryQueryBuilder(user.ID)
builder.WithEntryID(entryID)
builder.WithoutStatus(model.EntryStatusRemoved)
diff --git a/ui/entry_category.go b/ui/entry_category.go
index 795aaf1..283f015 100644
--- a/ui/entry_category.go
+++ b/ui/entry_category.go
@@ -25,17 +25,8 @@ func (c *Controller) ShowCategoryEntry(w http.ResponseWriter, r *http.Request) {
return
}
- categoryID, err := request.IntParam(r, "categoryID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
- entryID, err := request.IntParam(r, "entryID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
+ categoryID := request.RouteInt64Param(r, "categoryID")
+ entryID := request.RouteInt64Param(r, "entryID")
builder := c.store.NewEntryQueryBuilder(user.ID)
builder.WithCategoryID(categoryID)
diff --git a/ui/entry_feed.go b/ui/entry_feed.go
index 968c8f4..86dd2c9 100644
--- a/ui/entry_feed.go
+++ b/ui/entry_feed.go
@@ -25,17 +25,8 @@ func (c *Controller) ShowFeedEntry(w http.ResponseWriter, r *http.Request) {
return
}
- entryID, err := request.IntParam(r, "entryID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
- feedID, err := request.IntParam(r, "feedID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
+ entryID := request.RouteInt64Param(r, "entryID")
+ feedID := request.RouteInt64Param(r, "feedID")
builder := c.store.NewEntryQueryBuilder(user.ID)
builder.WithFeedID(feedID)
diff --git a/ui/entry_read.go b/ui/entry_read.go
index 61c7114..eeaca8e 100644
--- a/ui/entry_read.go
+++ b/ui/entry_read.go
@@ -24,12 +24,7 @@ func (c *Controller) ShowReadEntry(w http.ResponseWriter, r *http.Request) {
return
}
- entryID, err := request.IntParam(r, "entryID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
+ entryID := request.RouteInt64Param(r, "entryID")
builder := c.store.NewEntryQueryBuilder(user.ID)
builder.WithEntryID(entryID)
builder.WithoutStatus(model.EntryStatusRemoved)
diff --git a/ui/entry_save.go b/ui/entry_save.go
index 488022a..1f846ba 100644
--- a/ui/entry_save.go
+++ b/ui/entry_save.go
@@ -16,12 +16,7 @@ import (
// SaveEntry send the link to external services.
func (c *Controller) SaveEntry(w http.ResponseWriter, r *http.Request) {
- entryID, err := request.IntParam(r, "entryID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
-
+ entryID := request.RouteInt64Param(r, "entryID")
builder := c.store.NewEntryQueryBuilder(request.UserID(r))
builder.WithEntryID(entryID)
builder.WithoutStatus(model.EntryStatusRemoved)
diff --git a/ui/entry_scraper.go b/ui/entry_scraper.go
index b4a290f..4c2d58c 100644
--- a/ui/entry_scraper.go
+++ b/ui/entry_scraper.go
@@ -17,12 +17,7 @@ import (
// FetchContent downloads the original HTML page and returns relevant contents.
func (c *Controller) FetchContent(w http.ResponseWriter, r *http.Request) {
- entryID, err := request.IntParam(r, "entryID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
-
+ entryID := request.RouteInt64Param(r, "entryID")
builder := c.store.NewEntryQueryBuilder(request.UserID(r))
builder.WithEntryID(entryID)
builder.WithoutStatus(model.EntryStatusRemoved)
diff --git a/ui/entry_search.go b/ui/entry_search.go
index 15babc1..8acf103 100644
--- a/ui/entry_search.go
+++ b/ui/entry_search.go
@@ -25,13 +25,8 @@ func (c *Controller) ShowSearchEntry(w http.ResponseWriter, r *http.Request) {
return
}
- entryID, err := request.IntParam(r, "entryID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
- searchQuery := request.QueryParam(r, "q", "")
+ entryID := request.RouteInt64Param(r, "entryID")
+ searchQuery := request.QueryStringParam(r, "q", "")
builder := c.store.NewEntryQueryBuilder(user.ID)
builder.WithSearchQuery(searchQuery)
builder.WithEntryID(entryID)
diff --git a/ui/entry_toggle_bookmark.go b/ui/entry_toggle_bookmark.go
index e8e87ca..14a2c75 100644
--- a/ui/entry_toggle_bookmark.go
+++ b/ui/entry_toggle_bookmark.go
@@ -14,12 +14,7 @@ import (
// ToggleBookmark handles Ajax request to toggle bookmark value.
func (c *Controller) ToggleBookmark(w http.ResponseWriter, r *http.Request) {
- entryID, err := request.IntParam(r, "entryID")
- if err != nil {
- json.BadRequest(w, err)
- return
- }
-
+ entryID := request.RouteInt64Param(r, "entryID")
if err := c.store.ToggleBookmark(request.UserID(r), entryID); err != nil {
logger.Error("[Controller:ToggleBookmark] %v", err)
json.ServerError(w, nil)
diff --git a/ui/entry_unread.go b/ui/entry_unread.go
index 30a6d34..4ef5731 100644
--- a/ui/entry_unread.go
+++ b/ui/entry_unread.go
@@ -25,12 +25,7 @@ func (c *Controller) ShowUnreadEntry(w http.ResponseWriter, r *http.Request) {
return
}
- entryID, err := request.IntParam(r, "entryID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
+ entryID := request.RouteInt64Param(r, "entryID")
builder := c.store.NewEntryQueryBuilder(user.ID)
builder.WithEntryID(entryID)
builder.WithoutStatus(model.EntryStatusRemoved)
diff --git a/ui/feed_edit.go b/ui/feed_edit.go
index 9063d51..8b1b3cb 100644
--- a/ui/feed_edit.go
+++ b/ui/feed_edit.go
@@ -23,12 +23,7 @@ func (c *Controller) EditFeed(w http.ResponseWriter, r *http.Request) {
return
}
- feedID, err := request.IntParam(r, "feedID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
+ feedID := request.RouteInt64Param(r, "feedID")
feed, err := c.store.FeedByID(user.ID, feedID)
if err != nil {
html.ServerError(w, err)
diff --git a/ui/feed_entries.go b/ui/feed_entries.go
index dc0f05d..06a298b 100644
--- a/ui/feed_entries.go
+++ b/ui/feed_entries.go
@@ -23,12 +23,7 @@ func (c *Controller) ShowFeedEntries(w http.ResponseWriter, r *http.Request) {
return
}
- feedID, err := request.IntParam(r, "feedID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
+ feedID := request.RouteInt64Param(r, "feedID")
feed, err := c.store.FeedByID(user.ID, feedID)
if err != nil {
html.ServerError(w, err)
diff --git a/ui/feed_icon.go b/ui/feed_icon.go
index c5a7414..0aa7089 100644
--- a/ui/feed_icon.go
+++ b/ui/feed_icon.go
@@ -15,12 +15,7 @@ import (
// ShowIcon shows the feed icon.
func (c *Controller) ShowIcon(w http.ResponseWriter, r *http.Request) {
- iconID, err := request.IntParam(r, "iconID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
+ iconID := request.RouteInt64Param(r, "iconID")
icon, err := c.store.IconByID(iconID)
if err != nil {
html.ServerError(w, err)
diff --git a/ui/feed_refresh.go b/ui/feed_refresh.go
index d9da238..df93c6e 100644
--- a/ui/feed_refresh.go
+++ b/ui/feed_refresh.go
@@ -16,12 +16,7 @@ import (
// RefreshFeed refresh a subscription and redirect to the feed entries page.
func (c *Controller) RefreshFeed(w http.ResponseWriter, r *http.Request) {
- feedID, err := request.IntParam(r, "feedID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
+ feedID := request.RouteInt64Param(r, "feedID")
if err := c.feedHandler.RefreshFeed(request.UserID(r), feedID); err != nil {
logger.Error("[Controller:RefreshFeed] %v", err)
}
diff --git a/ui/feed_remove.go b/ui/feed_remove.go
index 9071a33..d1ab01a 100644
--- a/ui/feed_remove.go
+++ b/ui/feed_remove.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.
-package ui // import "miniflux.app/ui"
+package ui // import "miniflux.app/ui"
import (
"net/http"
@@ -15,12 +15,7 @@ import (
// RemoveFeed deletes a subscription from the database and redirect to the list of feeds page.
func (c *Controller) RemoveFeed(w http.ResponseWriter, r *http.Request) {
- feedID, err := request.IntParam(r, "feedID")
- if err != nil {
- html.ServerError(w, err)
- return
- }
-
+ feedID := request.RouteInt64Param(r, "feedID")
if err := c.store.RemoveFeed(request.UserID(r), feedID); err != nil {
html.ServerError(w, err)
return
diff --git a/ui/feed_update.go b/ui/feed_update.go
index a6fbcbf..6cc4776 100644
--- a/ui/feed_update.go
+++ b/ui/feed_update.go
@@ -26,12 +26,7 @@ func (c *Controller) UpdateFeed(w http.ResponseWriter, r *http.Request) {
return
}
- feedID, err := request.IntParam(r, "feedID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
+ feedID := request.RouteInt64Param(r, "feedID")
feed, err := c.store.FeedByID(user.ID, feedID)
if err != nil {
html.ServerError(w, err)
diff --git a/ui/oauth2_callback.go b/ui/oauth2_callback.go
index 9f51d0a..1902d6e 100644
--- a/ui/oauth2_callback.go
+++ b/ui/oauth2_callback.go
@@ -23,21 +23,21 @@ func (c *Controller) OAuth2Callback(w http.ResponseWriter, r *http.Request) {
printer := locale.NewPrinter(request.UserLanguage(r))
sess := session.New(c.store, request.SessionID(r))
- provider := request.Param(r, "provider", "")
+ provider := request.RouteStringParam(r, "provider")
if provider == "" {
logger.Error("[OAuth2] Invalid or missing provider")
response.Redirect(w, r, route.Path(c.router, "login"))
return
}
- code := request.QueryParam(r, "code", "")
+ code := request.QueryStringParam(r, "code", "")
if code == "" {
logger.Error("[OAuth2] No code received on callback")
response.Redirect(w, r, route.Path(c.router, "login"))
return
}
- state := request.QueryParam(r, "state", "")
+ state := request.QueryStringParam(r, "state", "")
if state == "" || state != request.OAuth2State(r) {
logger.Error(`[OAuth2] Invalid state value: got "%s" instead of "%s"`, state, request.OAuth2State(r))
response.Redirect(w, r, route.Path(c.router, "login"))
diff --git a/ui/oauth2_redirect.go b/ui/oauth2_redirect.go
index 3b7c88a..90c20b1 100644
--- a/ui/oauth2_redirect.go
+++ b/ui/oauth2_redirect.go
@@ -18,7 +18,7 @@ import (
func (c *Controller) OAuth2Redirect(w http.ResponseWriter, r *http.Request) {
sess := session.New(c.store, request.SessionID(r))
- provider := request.Param(r, "provider", "")
+ provider := request.RouteStringParam(r, "provider")
if provider == "" {
logger.Error("[OAuth2] Invalid or missing provider: %s", provider)
response.Redirect(w, r, route.Path(c.router, "login"))
diff --git a/ui/oauth2_unlink.go b/ui/oauth2_unlink.go
index 4435733..022e282 100644
--- a/ui/oauth2_unlink.go
+++ b/ui/oauth2_unlink.go
@@ -19,7 +19,7 @@ import (
// OAuth2Unlink unlink an account from the external provider.
func (c *Controller) OAuth2Unlink(w http.ResponseWriter, r *http.Request) {
printer := locale.NewPrinter(request.UserLanguage(r))
- provider := request.Param(r, "provider", "")
+ provider := request.RouteStringParam(r, "provider")
if provider == "" {
logger.Info("[OAuth2] Invalid or missing provider")
response.Redirect(w, r, route.Path(c.router, "login"))
diff --git a/ui/proxy.go b/ui/proxy.go
index 68e4db0..553dcce 100644
--- a/ui/proxy.go
+++ b/ui/proxy.go
@@ -27,7 +27,7 @@ func (c *Controller) ImageProxy(w http.ResponseWriter, r *http.Request) {
return
}
- encodedURL := request.Param(r, "encodedURL", "")
+ encodedURL := request.RouteStringParam(r, "encodedURL")
if encodedURL == "" {
html.BadRequest(w, errors.New("No URL provided"))
return
diff --git a/ui/search_entries.go b/ui/search_entries.go
index 3f17ede..e1dcbad 100644
--- a/ui/search_entries.go
+++ b/ui/search_entries.go
@@ -23,7 +23,7 @@ func (c *Controller) ShowSearchEntries(w http.ResponseWriter, r *http.Request) {
return
}
- searchQuery := request.QueryParam(r, "q", "")
+ searchQuery := request.QueryStringParam(r, "q", "")
offset := request.QueryIntParam(r, "offset", 0)
builder := c.store.NewEntryQueryBuilder(user.ID)
builder.WithSearchQuery(searchQuery)
diff --git a/ui/session_remove.go b/ui/session_remove.go
index f08cd3d..27201fd 100644
--- a/ui/session_remove.go
+++ b/ui/session_remove.go
@@ -2,27 +2,21 @@
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.
-package ui // import "miniflux.app/ui"
+package ui // import "miniflux.app/ui"
import (
"net/http"
"miniflux.app/http/request"
"miniflux.app/http/response"
- "miniflux.app/http/response/html"
"miniflux.app/http/route"
"miniflux.app/logger"
)
// RemoveSession remove a user session.
func (c *Controller) RemoveSession(w http.ResponseWriter, r *http.Request) {
- sessionID, err := request.IntParam(r, "sessionID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
- err = c.store.RemoveUserSessionByID(request.UserID(r), sessionID)
+ sessionID := request.RouteInt64Param(r, "sessionID")
+ err := c.store.RemoveUserSessionByID(request.UserID(r), sessionID)
if err != nil {
logger.Error("[Controller:RemoveSession] %v", err)
}
diff --git a/ui/static_app_icon.go b/ui/static_app_icon.go
index 6da3f1a..2ad42fd 100644
--- a/ui/static_app_icon.go
+++ b/ui/static_app_icon.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.
-package ui // import "miniflux.app/ui"
+package ui // import "miniflux.app/ui"
import (
"encoding/base64"
@@ -18,7 +18,7 @@ import (
// AppIcon renders application icons.
func (c *Controller) AppIcon(w http.ResponseWriter, r *http.Request) {
- filename := request.Param(r, "filename", "favicon.png")
+ filename := request.RouteStringParam(r, "filename")
encodedBlob, found := static.Binaries[filename]
if !found {
logger.Info("[Controller:AppIcon] This icon doesn't exists: %s", filename)
diff --git a/ui/static_javascript.go b/ui/static_javascript.go
index c52251c..248fae3 100644
--- a/ui/static_javascript.go
+++ b/ui/static_javascript.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.
-package ui // import "miniflux.app/ui"
+package ui // import "miniflux.app/ui"
import (
"net/http"
@@ -16,7 +16,7 @@ import (
// Javascript renders application client side code.
func (c *Controller) Javascript(w http.ResponseWriter, r *http.Request) {
- filename := request.Param(r, "name", "app")
+ filename := request.RouteStringParam(r, "name")
if _, found := static.Javascripts[filename]; !found {
html.NotFound(w)
return
diff --git a/ui/static_stylesheet.go b/ui/static_stylesheet.go
index de540fe..8e475bd 100644
--- a/ui/static_stylesheet.go
+++ b/ui/static_stylesheet.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.
-package ui // import "miniflux.app/ui"
+package ui // import "miniflux.app/ui"
import (
"net/http"
@@ -16,7 +16,7 @@ import (
// Stylesheet renders the CSS.
func (c *Controller) Stylesheet(w http.ResponseWriter, r *http.Request) {
- stylesheet := request.Param(r, "name", "default")
+ stylesheet := request.RouteStringParam(r, "name")
if _, found := static.Stylesheets[stylesheet]; !found {
html.NotFound(w)
return
diff --git a/ui/subscription_bookmarklet.go b/ui/subscription_bookmarklet.go
index 3aa1392..cba039a 100644
--- a/ui/subscription_bookmarklet.go
+++ b/ui/subscription_bookmarklet.go
@@ -32,7 +32,7 @@ func (c *Controller) Bookmarklet(w http.ResponseWriter, r *http.Request) {
return
}
- bookmarkletURL := request.QueryParam(r, "uri", "")
+ bookmarkletURL := request.QueryStringParam(r, "uri", "")
view.Set("form", form.SubscriptionForm{URL: bookmarkletURL})
view.Set("categories", categories)
diff --git a/ui/user_edit.go b/ui/user_edit.go
index 79e6d4a..f2c1abf 100644
--- a/ui/user_edit.go
+++ b/ui/user_edit.go
@@ -30,12 +30,7 @@ func (c *Controller) EditUser(w http.ResponseWriter, r *http.Request) {
return
}
- userID, err := request.IntParam(r, "userID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
+ userID := request.RouteInt64Param(r, "userID")
selectedUser, err := c.store.UserByID(userID)
if err != nil {
html.ServerError(w, err)
diff --git a/ui/user_remove.go b/ui/user_remove.go
index 4b5440d..981a7a2 100644
--- a/ui/user_remove.go
+++ b/ui/user_remove.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.
-package ui // import "miniflux.app/ui"
+package ui // import "miniflux.app/ui"
import (
"net/http"
@@ -26,12 +26,7 @@ func (c *Controller) RemoveUser(w http.ResponseWriter, r *http.Request) {
return
}
- userID, err := request.IntParam(r, "userID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
+ userID := request.RouteInt64Param(r, "userID")
selectedUser, err := c.store.UserByID(userID)
if err != nil {
html.ServerError(w, err)
diff --git a/ui/user_update.go b/ui/user_update.go
index da28e39..006d49a 100644
--- a/ui/user_update.go
+++ b/ui/user_update.go
@@ -30,12 +30,7 @@ func (c *Controller) UpdateUser(w http.ResponseWriter, r *http.Request) {
return
}
- userID, err := request.IntParam(r, "userID")
- if err != nil {
- html.BadRequest(w, err)
- return
- }
-
+ userID := request.RouteInt64Param(r, "userID")
selectedUser, err := c.store.UserByID(userID)
if err != nil {
html.ServerError(w, err)