portal/config/config.go

171 lines
2.8 KiB
Go

package config
import (
"errors"
"fmt"
"github.com/docker/go-units"
"github.com/spf13/viper"
"go.uber.org/zap"
)
// ConfigFilePaths lists, in order, the directories registered with viper as
// search locations for the portal's "config" file (see newConfig).
var ConfigFilePaths = []string{
	"/etc/lumeweb/portal/",
	"$HOME/.lumeweb/portal/",
	".",
}
// Config is the root of the decoded portal configuration.
type Config struct {
// Core holds the settings stored under the "core" key.
Core CoreConfig `mapstructure:"core"`
}
// Manager owns the viper instance backing the portal configuration and the
// decoded root Config derived from it.
type Manager struct {
// viper is the instance obtained from newConfig.
viper *viper.Viper
// root is the decoded configuration returned by Config().
root *Config
// changes is set when setDefaults registers at least one new default;
// maybeSave writes the config file only while it is true.
changes bool
}
// NewManager builds a Manager around the viper instance prepared by
// newConfig. It registers the core defaults, writes the config file if any
// default was newly added, and decodes the resulting settings into the root
// Config using cacheConfigHook as the decode hook.
func NewManager() (*Manager, error) {
	v, err := newConfig()
	if err != nil {
		return nil, err
	}

	m := &Manager{viper: v, root: new(Config)}
	m.setDefaults(m.coreDefaults(), "")

	if err := m.maybeSave(); err != nil {
		return nil, err
	}
	if err := v.Unmarshal(m.root, viper.DecodeHook(cacheConfigHook())); err != nil {
		return nil, err
	}
	return m, nil
}
// ConfigureProtocol registers cfg's defaults under the "protocol.<name>" key
// prefix. If any default was newly added, it writes the config file. It then
// decodes the "protocol.<name>" subtree into cfg.
//
// If viper holds no map value under that prefix (for example, cfg supplies
// no defaults and the file has no such section), cfg is left untouched and
// nil is returned.
func (m *Manager) ConfigureProtocol(name string, cfg ProtocolConfig) error {
	protocolPrefix := fmt.Sprintf("protocol.%s", name)
	m.setDefaults(cfg.Defaults(), protocolPrefix)

	if err := m.maybeSave(); err != nil {
		return err
	}

	// viper.Sub returns nil when the key is unset or is not a map. Calling
	// Unmarshal on that nil *Viper would panic, so treat it as "nothing to decode".
	sub := m.viper.Sub(protocolPrefix)
	if sub == nil {
		return nil
	}
	return sub.Unmarshal(cfg)
}
// setDefaults registers every entry of defaults as a viper default. When
// prefix is non-empty, each key is qualified as "prefix.key". Keys viper
// already reports as set are skipped. m.changes is raised if at least one
// default was actually added.
func (m *Manager) setDefaults(defaults map[string]interface{}, prefix string) {
	for k, v := range defaults {
		fullKey := k
		if len(prefix) > 0 {
			fullKey = prefix + "." + k
		}
		if added := m.setDefault(fullKey, v); added {
			m.changes = true
		}
	}
}
// setDefault records value as the viper default for key, unless viper
// already reports key as set. It returns true only when a default was recorded.
func (m *Manager) setDefault(key string, value interface{}) bool {
	if m.viper.IsSet(key) {
		return false
	}
	m.viper.SetDefault(key, value)
	return true
}
// maybeSave writes the config file only when m.changes reports newly
// registered defaults. It clears the flag after a successful write.
func (m *Manager) maybeSave() error {
	if !m.changes {
		return nil
	}
	if err := m.viper.WriteConfig(); err != nil {
		return err
	}
	m.changes = false
	return nil
}
// coreDefaults returns the built-in defaults for the "core" section. Keys are
// full dotted viper paths, which is why NewManager registers them with an
// empty prefix.
func (m *Manager) coreDefaults() map[string]interface{} {
	defaults := make(map[string]interface{}, 5)
	defaults["core.post_upload_limit"] = 100 * units.MiB
	defaults["core.log.level"] = "info"
	defaults["core.db.charset"] = "utf8mb4"
	defaults["core.db.port"] = 3306
	defaults["core.db.name"] = "portal"
	return defaults
}
// Config returns the Manager's decoded root configuration. The returned
// pointer is the Manager's own instance, not a copy.
func (m *Manager) Config() *Config { return m.root }
// Viper returns the underlying viper instance, for direct access to settings
// that the typed Config does not cover.
func (m *Manager) Viper() *viper.Viper { return m.viper }
// Save writes the current viper settings to the config file and then
// re-decodes them into the root Config, so Config() reflects what was written.
func (m *Manager) Save() error {
	if err := m.viper.WriteConfig(); err != nil {
		return err
	}
	// All pending defaults are now persisted, so a later maybeSave has
	// nothing left to write.
	m.changes = false

	// Decode with the same hook NewManager uses, so hook-dependent fields are
	// populated the same way after a save as on the initial load. m.root is
	// already a *Config; decode into it directly rather than through &m.root.
	return m.viper.Unmarshal(m.root, viper.DecodeHook(cacheConfigHook()))
}
func newConfig() (*viper.Viper, error) {
logger := newFallbackLogger()
viper.SetConfigName("config")
viper.SetConfigType("yaml")
for _, path := range ConfigFilePaths {
viper.AddConfigPath(path)
}
viper.SetEnvPrefix("LUME_WEB_PORTAL")
viper.AutomaticEnv()
err := viper.ReadInConfig()
if err != nil {
if !errors.Is(err, &viper.ConfigFileNotFoundError{}) {
return nil, err
}
logger.Info("Config file not found, using default settings.")
err := viper.SafeWriteConfig()
if err != nil {
return nil, err
}
return viper.GetViper(), nil
}
return viper.GetViper(), nil
}
// newFallbackLogger returns a zap development logger for use before the
// configured logger exists. If that logger cannot be built, it returns a
// no-op logger, so callers never receive a nil *zap.Logger.
func newFallbackLogger() *zap.Logger {
	l, err := zap.NewDevelopment()
	if err != nil {
		// On error the returned logger is unusable (nil). Logging through it
		// would panic, and losing the fallback's output is preferable.
		return zap.NewNop()
	}
	return l
}