aboutsummaryrefslogtreecommitdiffhomepage
path: root/http/request/context.go
blob: 5befb063f6f2facd6ad4cc8a230698189e51e96f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// Copyright 2018 Frédéric Guillot. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.

package request // import "miniflux.app/http/request"

import "net/http"

// ContextKey represents a context key.
//
// Using a dedicated named type means these keys cannot collide with plain
// int keys (or keys of any other type) that other packages store in the same
// context.Context.
type ContextKey int

// List of context keys.
//
// Values are assigned with iota, so each key's numeric value depends on its
// position in this block: inserting or reordering entries changes the values
// of every key that follows.
const (
	// UserIDContextKey holds the int64 read by UserID.
	UserIDContextKey ContextKey = iota
	// UserTimezoneContextKey holds the string read by UserTimezone.
	UserTimezoneContextKey
	// IsAdminUserContextKey holds the bool read by IsAdminUser.
	IsAdminUserContextKey
	// IsAuthenticatedContextKey holds the bool read by IsAuthenticated.
	IsAuthenticatedContextKey
	// UserSessionTokenContextKey holds the string read by UserSessionToken.
	UserSessionTokenContextKey
	// UserLanguageContextKey holds the string read by UserLanguage.
	UserLanguageContextKey
	// UserThemeContextKey holds the string read by UserTheme.
	UserThemeContextKey
	// SessionIDContextKey holds the string read by SessionID.
	SessionIDContextKey
	// CSRFContextKey holds the string read by CSRF.
	CSRFContextKey
	// OAuth2StateContextKey holds the string read by OAuth2State.
	OAuth2StateContextKey
	// FlashMessageContextKey holds the string read by FlashMessage.
	FlashMessageContextKey
	// FlashErrorMessageContextKey holds the string read by FlashErrorMessage.
	FlashErrorMessageContextKey
	// PocketRequestTokenContextKey holds the string read by PocketRequestToken.
	PocketRequestTokenContextKey
	// ClientIPContextKey holds the string read by ClientIP.
	ClientIPContextKey
)

// IsAdminUser reports whether the logged-in user is an administrator.
// It yields false when the request context carries no boolean for
// IsAdminUserContextKey.
func IsAdminUser(r *http.Request) bool {
	admin := getContextBoolValue(r, IsAdminUserContextKey)
	return admin
}

// IsAuthenticated reports whether the current request belongs to an
// authenticated user. It is false when the request context holds no
// boolean under IsAuthenticatedContextKey.
func IsAuthenticated(r *http.Request) bool {
	authenticated := getContextBoolValue(r, IsAuthenticatedContextKey)
	return authenticated
}

// UserID returns the ID of the logged-in user as stored in the request
// context, or 0 when no int64 value is present under UserIDContextKey.
func UserID(r *http.Request) int64 {
	id := getContextInt64Value(r, UserIDContextKey)
	return id
}

// UserTimezone returns the timezone of the logged-in user. It falls back to
// "UTC" when the request context has no (or an empty) timezone string.
func UserTimezone(r *http.Request) string {
	if tz := getContextStringValue(r, UserTimezoneContextKey); tz != "" {
		return tz
	}
	return "UTC"
}

// UserLanguage returns the locale of the current logged-in user. It falls
// back to "en_US" when the request context has no (or an empty) language.
func UserLanguage(r *http.Request) string {
	if locale := getContextStringValue(r, UserLanguageContextKey); locale != "" {
		return locale
	}
	return "en_US"
}

// UserTheme returns the theme selected by the current logged-in user. It
// falls back to "light_serif" when the request context has no (or an empty)
// theme.
func UserTheme(r *http.Request) string {
	if name := getContextStringValue(r, UserThemeContextKey); name != "" {
		return name
	}
	return "light_serif"
}

// CSRF returns the CSRF token attached to the request context, or an empty
// string when none is set.
func CSRF(r *http.Request) string {
	token := getContextStringValue(r, CSRFContextKey)
	return token
}

// SessionID returns the session ID attached to the request context, or an
// empty string when none is set.
func SessionID(r *http.Request) string {
	id := getContextStringValue(r, SessionIDContextKey)
	return id
}

// UserSessionToken returns the user session token attached to the request
// context, or an empty string when none is set.
func UserSessionToken(r *http.Request) string {
	token := getContextStringValue(r, UserSessionTokenContextKey)
	return token
}

// OAuth2State returns the OAuth2 state value attached to the request
// context, or an empty string when none is set.
func OAuth2State(r *http.Request) string {
	state := getContextStringValue(r, OAuth2StateContextKey)
	return state
}

// FlashMessage returns the flash message stored in the request context, or
// an empty string if there is none.
func FlashMessage(r *http.Request) string {
	return getContextStringValue(r, FlashMessageContextKey)
}

// FlashErrorMessage returns the flash error message stored in the request
// context, or an empty string if there is none.
func FlashErrorMessage(r *http.Request) string {
	return getContextStringValue(r, FlashErrorMessageContextKey)
}

// PocketRequestToken returns the Pocket request token attached to the
// request context, or an empty string when none is set.
func PocketRequestToken(r *http.Request) string {
	token := getContextStringValue(r, PocketRequestTokenContextKey)
	return token
}

// ClientIP returns the client IP address recorded in the request context,
// or an empty string when none is set.
func ClientIP(r *http.Request) string {
	ip := getContextStringValue(r, ClientIPContextKey)
	return ip
}

// getContextStringValue returns the string stored in the request context
// under key, or "" when the key is absent or holds a non-string value.
func getContextStringValue(r *http.Request, key ContextKey) string {
	// A comma-ok assertion on a nil interface yields the zero value and false
	// without panicking, so the "ok" flag can be ignored: both the missing-key
	// and wrong-type cases must produce "".
	str, _ := r.Context().Value(key).(string)
	return str
}

// getContextBoolValue returns the bool stored in the request context under
// key, or false when the key is absent or holds a value of another type.
func getContextBoolValue(r *http.Request, key ContextKey) bool {
	switch flag := r.Context().Value(key).(type) {
	case bool:
		return flag
	default:
		// Covers both a missing key (nil) and a non-bool value.
		return false
	}
}

// getContextInt64Value returns the int64 stored in the request context under
// key, or 0 when the key is absent or holds a value of another type.
func getContextInt64Value(r *http.Request, key ContextKey) int64 {
	raw := r.Context().Value(key)
	if n, ok := raw.(int64); ok {
		return n
	}
	return 0
}