diff options
Diffstat (limited to 'http/request/context_test.go')
-rw-r--r-- | http/request/context_test.go | 436 |
1 files changed, 436 insertions, 0 deletions
diff --git a/http/request/context_test.go b/http/request/context_test.go new file mode 100644 index 0000000..e741362 --- /dev/null +++ b/http/request/context_test.go @@ -0,0 +1,436 @@ +// Copyright 2018 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package request // import "miniflux.app/http/request" + +import ( + "context" + "net/http" + "testing" +) + +func TestContextStringValue(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, ClientIPContextKey, "IP") + r = r.WithContext(ctx) + + result := getContextStringValue(r, ClientIPContextKey) + expected := "IP" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestContextStringValueWithInvalidType(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, ClientIPContextKey, 0) + r = r.WithContext(ctx) + + result := getContextStringValue(r, ClientIPContextKey) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestContextStringValueWhenUnset(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := getContextStringValue(r, ClientIPContextKey) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestContextBoolValue(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, IsAdminUserContextKey, true) + r = r.WithContext(ctx) + + result := getContextBoolValue(r, IsAdminUserContextKey) + expected := true + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestContextBoolValueWithInvalidType(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, IsAdminUserContextKey, "invalid") + r = r.WithContext(ctx) + + result := getContextBoolValue(r, IsAdminUserContextKey) + expected := false + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestContextBoolValueWhenUnset(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := getContextBoolValue(r, IsAdminUserContextKey) + expected := false + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestContextInt64Value(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, UserIDContextKey, int64(1234)) + r = r.WithContext(ctx) + + result := getContextInt64Value(r, UserIDContextKey) + expected := int64(1234) + + if result != expected { + t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected) + } +} + +func TestContextInt64ValueWithInvalidType(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + ctx := r.Context() + ctx = context.WithValue(ctx, UserIDContextKey, "invalid") + r = r.WithContext(ctx) + + result := getContextInt64Value(r, UserIDContextKey) + expected := int64(0) + + if result != expected { + t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected) + } +} + +func TestContextInt64ValueWhenUnset(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := getContextInt64Value(r, UserIDContextKey) + expected := int64(0) + + if result != expected { + t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected) + } +} + +func TestIsAdmin(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := IsAdminUser(r) + expected := false + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, IsAdminUserContextKey, true) + r = r.WithContext(ctx) + + result = IsAdminUser(r) + expected = true + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestIsAuthenticated(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := IsAuthenticated(r) + expected := false + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, IsAuthenticatedContextKey, true) + r = r.WithContext(ctx) + + result = IsAuthenticated(r) + expected = true + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestUserID(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := UserID(r) + expected := int64(0) + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, UserIDContextKey, int64(123)) + r = r.WithContext(ctx) + + result = UserID(r) + expected = int64(123) + + if result != expected { + t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected) + } +} + +func TestUserTimezone(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := UserTimezone(r) + expected := "UTC" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, UserTimezoneContextKey, "Europe/Paris") + r = r.WithContext(ctx) + + result = UserTimezone(r) + expected = "Europe/Paris" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestUserLanguage(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := UserLanguage(r) + expected := "en_US" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, UserLanguageContextKey, "fr_FR") + r = r.WithContext(ctx) + + result = UserLanguage(r) + expected = "fr_FR" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestUserTheme(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := UserTheme(r) + expected := "default" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, UserThemeContextKey, "black") + r = r.WithContext(ctx) + + result = UserTheme(r) + expected = "black" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestCSRF(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := CSRF(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, CSRFContextKey, "secret") + r = r.WithContext(ctx) + + result = CSRF(r) + expected = "secret" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestSessionID(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := SessionID(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, SessionIDContextKey, "id") + r = r.WithContext(ctx) + + result = SessionID(r) + expected = "id" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestUserSessionToken(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := UserSessionToken(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, UserSessionTokenContextKey, "token") + r = r.WithContext(ctx) + + result = UserSessionToken(r) + expected = "token" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestOAuth2State(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := OAuth2State(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, OAuth2StateContextKey, "state") + r = r.WithContext(ctx) + + result = OAuth2State(r) + expected = "state" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestFlashMessage(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := FlashMessage(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, FlashMessageContextKey, "message") + r = r.WithContext(ctx) + + result = FlashMessage(r) + expected = "message" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestFlashErrorMessage(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := FlashErrorMessage(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, FlashErrorMessageContextKey, "error message") + r = r.WithContext(ctx) + + result = FlashErrorMessage(r) + expected = "error message" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestPocketRequestToken(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := PocketRequestToken(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, PocketRequestTokenContextKey, "request token") + r = r.WithContext(ctx) + + result = PocketRequestToken(r) + expected = "request token" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} + +func TestClientIP(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.org", nil) + + result := ClientIP(r) + expected := "" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, ClientIPContextKey, "127.0.0.1") + r = r.WithContext(ctx) + + result = ClientIP(r) + expected = "127.0.0.1" + + if result != expected { + t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected) + } +} |