diff options
Diffstat (limited to 'config/parser.go')
-rw-r--r-- | config/parser.go | 214 |
1 files changed, 146 insertions, 68 deletions
diff --git a/config/parser.go b/config/parser.go index b2ed2e7..6de4550 100644 --- a/config/parser.go +++ b/config/parser.go @@ -5,113 +5,184 @@ package config // import "miniflux.app/config" import ( + "bufio" "errors" "fmt" - "net/url" + "io" + url_parser "net/url" "os" "strconv" "strings" ) -func parse() (opts *Options, err error) { - opts = &Options{} - opts.baseURL, opts.rootURL, opts.basePath, err = parseBaseURL() +// Parser handles configuration parsing. +type Parser struct { + opts *Options +} + +// NewParser returns a new Parser. +func NewParser() *Parser { + return &Parser{ + opts: NewOptions(), + } +} + +// ParseEnvironmentVariables loads configuration values from environment variables. +func (p *Parser) ParseEnvironmentVariables() (*Options, error) { + err := p.parseLines(os.Environ()) if err != nil { return nil, err } + return p.opts, nil +} - opts.debug = getBooleanValue("DEBUG") - opts.listenAddr = parseListenAddr() - - opts.databaseURL = getStringValue("DATABASE_URL", defaultDatabaseURL) - opts.databaseMaxConns = getIntValue("DATABASE_MAX_CONNS", defaultDatabaseMaxConns) - opts.databaseMinConns = getIntValue("DATABASE_MIN_CONNS", defaultDatabaseMinConns) - opts.runMigrations = getBooleanValue("RUN_MIGRATIONS") - - opts.hsts = !getBooleanValue("DISABLE_HSTS") - opts.HTTPS = getBooleanValue("HTTPS") - - opts.schedulerService = !getBooleanValue("DISABLE_SCHEDULER_SERVICE") - opts.httpService = !getBooleanValue("DISABLE_HTTP_SERVICE") - - opts.certFile = getStringValue("CERT_FILE", defaultCertFile) - opts.certKeyFile = getStringValue("KEY_FILE", defaultKeyFile) - opts.certDomain = getStringValue("CERT_DOMAIN", defaultCertDomain) - opts.certCache = getStringValue("CERT_CACHE", defaultCertCache) +// ParseFile loads configuration values from a local file. +func (p *Parser) ParseFile(filename string) (*Options, error) { + fp, err := os.Open(filename) + if err != nil { + return nil, err + } + defer fp.Close() - opts.cleanupFrequency = getIntValue("CLEANUP_FREQUENCY", defaultCleanupFrequency) - opts.workerPoolSize = getIntValue("WORKER_POOL_SIZE", defaultWorkerPoolSize) - opts.pollingFrequency = getIntValue("POLLING_FREQUENCY", defaultPollingFrequency) - opts.batchSize = getIntValue("BATCH_SIZE", defaultBatchSize) - opts.archiveReadDays = getIntValue("ARCHIVE_READ_DAYS", defaultArchiveReadDays) - opts.proxyImages = getStringValue("PROXY_IMAGES", defaultProxyImages) - opts.createAdmin = getBooleanValue("CREATE_ADMIN") - opts.pocketConsumerKey = getStringValue("POCKET_CONSUMER_KEY", "") + err = p.parseLines(p.parseFileContent(fp)) + if err != nil { + return nil, err + } + return p.opts, nil +} - opts.oauth2UserCreationAllowed = getBooleanValue("OAUTH2_USER_CREATION") - opts.oauth2ClientID = getStringValue("OAUTH2_CLIENT_ID", defaultOAuth2ClientID) - opts.oauth2ClientSecret = getStringValue("OAUTH2_CLIENT_SECRET", defaultOAuth2ClientSecret) - opts.oauth2RedirectURL = getStringValue("OAUTH2_REDIRECT_URL", defaultOAuth2RedirectURL) - opts.oauth2Provider = getStringValue("OAUTH2_PROVIDER", defaultOAuth2Provider) +func (p *Parser) parseFileContent(r io.Reader) (lines []string) { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if len(line) > 0 && !strings.HasPrefix(line, "#") && strings.Index(line, "=") > 0 { + lines = append(lines, line) + } + } + return lines +} - opts.httpClientTimeout = getIntValue("HTTP_CLIENT_TIMEOUT", defaultHTTPClientTimeout) - opts.httpClientMaxBodySize = int64(getIntValue("HTTP_CLIENT_MAX_BODY_SIZE", defaultHTTPClientMaxBodySize) * 1024 * 1024) +func (p *Parser) parseLines(lines []string) (err error) { + var port string + + for _, line := range lines { + fields := strings.SplitN(line, "=", 2) + key := strings.TrimSpace(fields[0]) + value := strings.TrimSpace(fields[1]) + + switch key { + case "DEBUG": + p.opts.debug = parseBool(value, defaultDebug) + case "BASE_URL": + p.opts.baseURL, p.opts.rootURL, p.opts.basePath, err = parseBaseURL(value) + if err != nil { + return err + } + case "PORT": + port = value + case "LISTEN_ADDR": + p.opts.listenAddr = parseString(value, defaultListenAddr) + case "DATABASE_URL": + p.opts.databaseURL = parseString(value, defaultDatabaseURL) + case "DATABASE_MAX_CONNS": + p.opts.databaseMaxConns = parseInt(value, defaultDatabaseMaxConns) + case "DATABASE_MIN_CONNS": + p.opts.databaseMinConns = parseInt(value, defaultDatabaseMinConns) + case "RUN_MIGRATIONS": + p.opts.runMigrations = parseBool(value, defaultRunMigrations) + case "DISABLE_HSTS": + p.opts.hsts = !parseBool(value, defaultHSTS) + case "HTTPS": + p.opts.HTTPS = parseBool(value, defaultHTTPS) + case "DISABLE_SCHEDULER_SERVICE": + p.opts.schedulerService = !parseBool(value, defaultSchedulerService) + case "DISABLE_HTTP_SERVICE": + p.opts.httpService = !parseBool(value, defaultHTTPService) + case "CERT_FILE": + p.opts.certFile = parseString(value, defaultCertFile) + case "KEY_FILE": + p.opts.certKeyFile = parseString(value, defaultKeyFile) + case "CERT_DOMAIN": + p.opts.certDomain = parseString(value, defaultCertDomain) + case "CERT_CACHE": + p.opts.certCache = parseString(value, defaultCertCache) + case "CLEANUP_FREQUENCY": + p.opts.cleanupFrequency = parseInt(value, defaultCleanupFrequency) + case "WORKER_POOL_SIZE": + p.opts.workerPoolSize = parseInt(value, defaultWorkerPoolSize) + case "POLLING_FREQUENCY": + p.opts.pollingFrequency = parseInt(value, defaultPollingFrequency) + case "BATCH_SIZE": + p.opts.batchSize = parseInt(value, defaultBatchSize) + case "ARCHIVE_READ_DAYS": + p.opts.archiveReadDays = parseInt(value, defaultArchiveReadDays) + case "PROXY_IMAGES": + p.opts.proxyImages = parseString(value, defaultProxyImages) + case "CREATE_ADMIN": + p.opts.createAdmin = parseBool(value, defaultCreateAdmin) + case "POCKET_CONSUMER_KEY": + p.opts.pocketConsumerKey = parseString(value, defaultPocketConsumerKey) + case "OAUTH2_USER_CREATION": + p.opts.oauth2UserCreationAllowed = parseBool(value, defaultOAuth2UserCreation) + case "OAUTH2_CLIENT_ID": + p.opts.oauth2ClientID = parseString(value, defaultOAuth2ClientID) + case "OAUTH2_CLIENT_SECRET": + p.opts.oauth2ClientSecret = parseString(value, defaultOAuth2ClientSecret) + case "OAUTH2_REDIRECT_URL": + p.opts.oauth2RedirectURL = parseString(value, defaultOAuth2RedirectURL) + case "OAUTH2_PROVIDER": + p.opts.oauth2Provider = parseString(value, defaultOAuth2Provider) + case "HTTP_CLIENT_TIMEOUT": + p.opts.httpClientTimeout = parseInt(value, defaultHTTPClientTimeout) + case "HTTP_CLIENT_MAX_BODY_SIZE": + p.opts.httpClientMaxBodySize = int64(parseInt(value, defaultHTTPClientMaxBodySize) * 1024 * 1024) + } + } - return opts, nil + if port != "" { + p.opts.listenAddr = ":" + port + } + return nil } -func parseBaseURL() (string, string, string, error) { - baseURL := os.Getenv("BASE_URL") - if baseURL == "" { - return defaultBaseURL, defaultBaseURL, "", nil +func parseBaseURL(value string) (string, string, string, error) { + if value == "" { + return defaultBaseURL, defaultRootURL, "", nil } - if baseURL[len(baseURL)-1:] == "/" { - baseURL = baseURL[:len(baseURL)-1] + if value[len(value)-1:] == "/" { + value = value[:len(value)-1] } - u, err := url.Parse(baseURL) + url, err := url_parser.Parse(value) if err != nil { return "", "", "", fmt.Errorf("Invalid BASE_URL: %v", err) } - scheme := strings.ToLower(u.Scheme) + scheme := strings.ToLower(url.Scheme) if scheme != "https" && scheme != "http" { return "", "", "", errors.New("Invalid BASE_URL: scheme must be http or https") } - basePath := u.Path - u.Path = "" - return baseURL, u.String(), basePath, nil + basePath := url.Path + url.Path = "" + return value, url.String(), basePath, nil } -func parseListenAddr() string { - if port := os.Getenv("PORT"); port != "" { - return ":" + port +func parseBool(value string, fallback bool) bool { + if value == "" { + return fallback } - return getStringValue("LISTEN_ADDR", defaultListenAddr) -} - -func getBooleanValue(key string) bool { - value := strings.ToLower(os.Getenv(key)) + value = strings.ToLower(value) if value == "1" || value == "yes" || value == "true" || value == "on" { return true } - return false -} - -func getStringValue(key, fallback string) string { - value := os.Getenv(key) - if value == "" { - return fallback - } - return value + return false } -func getIntValue(key string, fallback int) int { - value := os.Getenv(key) +func parseInt(value string, fallback int) int { if value == "" { return fallback } @@ -123,3 +194,10 @@ func getIntValue(key string, fallback int) int { return v } + +func parseString(value string, fallback string) string { + if value == "" { + return fallback + } + return value +} |