package config import ( "errors" "os" "github.com/gobuffalo/packr/v2" // explicit yaml designation is a work around for github.com/golang/go/issues/26882 yaml "gopkg.in/yaml.v2" ) const ( configPath = "/etc/updown/config.yml" distBoxName = "distConfig" distConfigName = "config.yml" distConfigDir = "./dist" ) // Environment represents the environment in which the app is being run, such as development or production type Environment string const ( // EnvironmentDev represents the development environment EnvironmentDev Environment = "dev" // EnvironmentProd represents the production environment EnvironmentProd = "prod" ) // ErrDefaultConfig is thrown when the default configuration is used var ErrDefaultConfig = errors.New("config: using default configuration") // ErrConfigNotLoaded is returned when the config has not been loaded yet var ErrConfigNotLoaded = errors.New("config: config not loaded") // ErrBadEnvironment is returned when a bad environment is specified var ErrBadEnvironment = errors.New("config: invalid value for environment") // To prevent reading the config file each time we want a new config value, we store it as a package variable var config Config // Config stores all configurations for the application type Config struct { Server ServerConfig StoragePath string `yaml:"storage_path"` DatabaseURI string `yaml:"database_uri"` Environment Environment } // ServerConfig stores configuration for the webserver type ServerConfig struct { ListenAddr string `yaml:"listen_addr"` Port int } // UnmarshalYAML performs standard unmarshaling, followed by validation of the Environment enum func (conf *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { // Make a config type without the UnmarshalYAML function implemented, to avoid infinite recursion type SafeConfig Config safeConfig := (*SafeConfig)(conf) err := unmarshal(safeConfig) if err != nil { return err } switch conf.Environment { case "": conf.Environment = EnvironmentProd fallthrough case EnvironmentDev, EnvironmentProd: return nil default: conf.Environment = "" return ErrBadEnvironment } } // Get loads the currently stored config. If one is not stored, panic. func Get() Config { if config == (Config{}) { panic(ErrConfigNotLoaded) } return config } // Load checks for the stored configuration, and returns the default config and DefaultConfigError if not found. func Load() (Config, error) { parsedConfig := Config{} var err error if shouldUseDistConfig() { parsedConfig, err = parseDistConfig() if err != nil { return Config{}, err } err = ErrDefaultConfig } else { parsedConfig, err = parseSuppliedConfig() } if err != nil && err != ErrDefaultConfig { return Config{}, err } config = parsedConfig return config, err } func shouldUseDistConfig() bool { _, err := os.Stat(configPath) return os.IsNotExist(err) } func parseDistConfig() (Config, error) { distBox := packr.New(distBoxName, distConfigDir) rawDistConfig, err := distBox.Find(distConfigName) if err != nil { return Config{}, err } distConfig := Config{} err = yaml.UnmarshalStrict(rawDistConfig, &distConfig) if err != nil { return Config{}, err } return distConfig, nil } func parseSuppliedConfig() (Config, error) { configFile, err := os.Open(configPath) if err != nil { return Config{}, err } configDecoder := yaml.NewDecoder(configFile) configDecoder.SetStrict(true) config := Config{} return config, configDecoder.Decode(&config) }