// portal/config/config.go
//
// 293 lines
// 5.7 KiB
// Go

package config
import (
"errors"
"fmt"
"reflect"
"github.com/spf13/viper"
"go.uber.org/zap"
)
var (
// ConfigFilePaths lists the directories that newConfig registers, in this
// order, as viper config search paths.
ConfigFilePaths = []string{
"/etc/lumeweb/portal/",
"$HOME/.lumeweb/portal/",
".",
}
)
// Defaults is implemented by configuration objects that provide default
// values. Each key of the returned map is joined to the object's key prefix
// with a "." (see applyDefaults) and registered only if the key is not already
// set.
type Defaults interface {
Defaults() map[string]interface{}
}
// Validator is implemented by configuration objects that can check their own
// contents. validateObject calls Validate and stops at the first non-nil error.
type Validator interface {
Validate() error
}
// Config is the root of the decoded configuration.
type Config struct {
// Core is decoded from the "core" key.
Core CoreConfig `mapstructure:"core"`
// Protocol is decoded from the "protocol" key; ConfigureProtocol stores the
// configured object for each protocol name in it.
Protocol map[string]interface{} `mapstructure:"protocol"`
}
// Manager ties a viper instance to the Config decoded from it.
type Manager struct {
// viper is the instance values are read from and written through.
viper *viper.Viper
// root is the decoded configuration returned by Config.
root *Config
// changes is set when a default was newly registered and cleared by
// maybeSave once the config file has been written.
changes bool
}
// NewManager loads the configuration through newConfig, registers the defaults
// of the "core" section, writes the config file back if any default was newly
// registered, decodes the settings into a Config, validates it and applies the
// cluster cache override.
//
// On failure it returns a nil Manager together with the error.
func NewManager() (*Manager, error) {
	v, err := newConfig()
	if err != nil {
		return nil, err
	}

	var config Config
	m := &Manager{
		viper: v,
		root:  &config,
	}

	m.setDefaultsForObject(m.root.Core, "core")

	// Only touches the file when setDefaultsForObject registered something new.
	if err = m.maybeSave(); err != nil {
		return nil, fmt.Errorf("saving config defaults: %w", err)
	}

	// NOTE(review): viper.DecodeHook sets DecoderConfig.DecodeHook outright, so
	// with two DecodeHook options presumably only the last one
	// (cacheConfigHook) takes effect. Combining both hooks likely needs
	// mapstructure.ComposeDecodeHookFunc — confirm and fix where mapstructure
	// can be imported.
	err = v.Unmarshal(&config, viper.DecodeHook(clusterConfigHook()), viper.DecodeHook(cacheConfigHook()))
	if err != nil {
		return nil, fmt.Errorf("decoding config: %w", err)
	}

	if err = m.validateObject(m.root); err != nil {
		return nil, fmt.Errorf("validating config: %w", err)
	}

	// Like every other failure path, do not hand back a half-built manager.
	if err = m.maybeConfigureCluster(); err != nil {
		return nil, err
	}

	return m, nil
}
// ConfigureProtocol sets up the configuration object of one protocol.
//
// It registers cfg's defaults under the "protocol.<name>" key prefix, writes
// the config file back if any default was newly registered, decodes that
// subtree into cfg, validates cfg and stores it under name in the root
// config's Protocol map.
//
// It returns the first error from saving, decoding or validating; the errors
// are wrapped and remain matchable with errors.Is / errors.As.
func (m *Manager) ConfigureProtocol(name string, cfg ProtocolConfig) error {
	protocolPrefix := fmt.Sprintf("protocol.%s", name)

	m.setDefaultsForObject(cfg, protocolPrefix)

	if err := m.maybeSave(); err != nil {
		return fmt.Errorf("saving defaults for protocol %q: %w", name, err)
	}

	// Sub returns nil when the key cannot be found (no file entry and no
	// defaults); calling Unmarshal on that nil instance would panic. With
	// nothing to decode, cfg is left as the caller provided it.
	if sub := m.viper.Sub(protocolPrefix); sub != nil {
		if err := sub.Unmarshal(cfg); err != nil {
			return fmt.Errorf("decoding config for protocol %q: %w", name, err)
		}
	}

	if err := m.validateObject(cfg); err != nil {
		return fmt.Errorf("validating config for protocol %q: %w", name, err)
	}

	// Nothing in this file initializes Protocol, so it can still be nil here,
	// and assigning into a nil map panics.
	if m.root.Protocol == nil {
		m.root.Protocol = make(map[string]interface{})
	}
	m.root.Protocol[name] = cfg

	return nil
}
// setDefaultsForObject registers the defaults of obj and of every struct
// nested inside it.
//
// obj may be a struct or a pointer to one. If obj implements Defaults, its
// defaults are applied under prefix. Each exported field that is a struct or a
// non-nil pointer to a struct is then visited recursively; its prefix is the
// parent prefix extended with the field's `mapstructure` tag (a field whose
// tag is empty or "-" reuses the parent prefix).
//
// Values that are not structs after dereferencing (nil pointers, nil
// interfaces, maps, ...) are not traversed, and nil pointer fields are
// skipped.
func (m *Manager) setDefaultsForObject(obj interface{}, prefix string) {
	objValue := reflect.ValueOf(obj)

	// Work on the pointed-to value when given a pointer.
	if objValue.Kind() == reflect.Ptr {
		objValue = objValue.Elem()
	}

	// The type assertion is made on obj itself, so a pointer keeps its
	// pointer-receiver methods.
	if setter, ok := obj.(Defaults); ok {
		m.applyDefaults(setter, prefix)
	}

	// NumField panics on anything but a struct, so stop here for nil
	// pointers, nil interfaces and other non-struct values.
	if objValue.Kind() != reflect.Struct {
		return
	}
	objType := objValue.Type()

	for i := 0; i < objValue.NumField(); i++ {
		field := objValue.Field(i)

		// Unexported fields cannot be turned back into an interface value.
		if !field.CanInterface() {
			continue
		}

		newPrefix := prefix
		if tag := objType.Field(i).Tag.Get("mapstructure"); tag != "" && tag != "-" {
			if newPrefix != "" {
				newPrefix += "."
			}
			newPrefix += tag
		}

		// Elem of a nil pointer is the zero Value (kind Invalid), so nil
		// pointer fields never match and are left untouched.
		isStruct := field.Kind() == reflect.Struct
		isStructPtr := field.Kind() == reflect.Ptr && field.Elem().Kind() == reflect.Struct
		if isStruct || isStructPtr {
			m.setDefaultsForObject(field.Interface(), newPrefix)
		}
	}
}
// validateObject validates obj and every struct nested inside it.
//
// obj may be a struct or a pointer to one. If obj implements Validator, its
// Validate method runs first. Each exported field that is a struct or a
// non-nil pointer to a struct is then validated recursively. The first error
// encountered is returned unchanged.
//
// Values that are not structs after dereferencing (nil pointers, nil
// interfaces, maps, ...) are not traversed, and nil pointer fields are
// skipped; validation never modifies obj.
func (m *Manager) validateObject(obj interface{}) error {
	objValue := reflect.ValueOf(obj)

	// Work on the pointed-to value when given a pointer.
	if objValue.Kind() == reflect.Ptr {
		objValue = objValue.Elem()
	}

	// Check whether the object itself implements Validator.
	if validator, ok := obj.(Validator); ok {
		if err := validator.Validate(); err != nil {
			return err
		}
	}

	// NumField panics on anything but a struct, so stop here for nil
	// pointers, nil interfaces and other non-struct values.
	if objValue.Kind() != reflect.Struct {
		return nil
	}

	for i := 0; i < objValue.NumField(); i++ {
		field := objValue.Field(i)

		// Unexported fields cannot be turned back into an interface value.
		if !field.CanInterface() {
			continue
		}

		// Elem of a nil pointer is the zero Value (kind Invalid), so nil
		// pointer fields never match and are left untouched.
		isStruct := field.Kind() == reflect.Struct
		isStructPtr := field.Kind() == reflect.Ptr && field.Elem().Kind() == reflect.Struct
		if !isStruct && !isStructPtr {
			continue
		}

		if err := m.validateObject(field.Interface()); err != nil {
			return err
		}
	}

	return nil
}
// applyDefaults registers every entry of setter.Defaults() with viper. Keys
// are qualified as "<prefix>.<key>" unless prefix is empty. The manager is
// marked as changed whenever an entry was newly registered.
func (m *Manager) applyDefaults(setter Defaults, prefix string) {
	for name, def := range setter.Defaults() {
		path := name
		if prefix != "" {
			path = prefix + "." + name
		}

		if !m.setDefault(path, def) {
			continue
		}
		m.changes = true
	}
}
// setDefault registers value as the default for key unless viper already
// reports the key as set. It reports whether the default was registered.
func (m *Manager) setDefault(key string, value interface{}) bool {
	if m.viper.IsSet(key) {
		return false
	}

	m.viper.SetDefault(key, value)
	return true
}
// maybeSave writes the config file only when changes are pending. The pending
// flag is cleared after a successful write and left set when the write fails.
func (m *Manager) maybeSave() error {
	if !m.changes {
		return nil
	}

	if err := m.viper.WriteConfig(); err != nil {
		return err
	}

	m.changes = false
	return nil
}
// maybeConfigureCluster switches the database cache to redis, using the
// cluster's Redis settings as cache options, when a cluster section is present
// and enabled. It currently always returns nil.
func (m *Manager) maybeConfigureCluster() error {
	core := &m.root.Core

	if cluster := core.Clustered; cluster != nil && cluster.Enabled {
		core.DB.Cache.Mode = "redis"
		core.DB.Cache.Options = cluster.Redis
	}

	return nil
}
// Config returns the manager's root configuration. The pointer is the
// manager's own, not a copy.
func (m *Manager) Config() *Config {
return m.root
}
// Viper returns the viper instance the manager reads from and writes through.
func (m *Manager) Viper() *viper.Viper {
return m.viper
}
// Save writes the current viper state to the config file and then refreshes
// the in-memory Config from viper.
//
// The refresh uses the same decode options as NewManager and re-applies the
// cluster cache override, so the Config seen after Save is built the same way
// as the one produced at startup.
func (m *Manager) Save() error {
	if err := m.viper.WriteConfig(); err != nil {
		return fmt.Errorf("writing config: %w", err)
	}

	// NOTE(review): this decodes the "protocol" section into
	// map[string]interface{}, so entries stored by ConfigureProtocol are
	// presumably replaced by plain decoded values — confirm whether they need
	// to be restored afterwards.
	err := m.viper.Unmarshal(&m.root, viper.DecodeHook(clusterConfigHook()), viper.DecodeHook(cacheConfigHook()))
	if err != nil {
		return fmt.Errorf("decoding config: %w", err)
	}

	// Decoding may overwrite the cache settings; re-apply the cluster
	// override exactly as NewManager does after its own decode.
	return m.maybeConfigureCluster()
}
// newConfig configures the package-level viper instance (file name "config",
// type yaml, the ConfigFilePaths search paths, the LUME_WEB_PORTAL env prefix)
// and reads the config file.
//
// When no config file is found, a message is logged and a new file is written
// with SafeWriteConfig. Any other read error, or a failure to write the new
// file, is returned with a nil viper.
func newConfig() (*viper.Viper, error) {
	logger := newFallbackLogger()

	viper.SetConfigName("config")
	viper.SetConfigType("yaml")
	for i := range ConfigFilePaths {
		viper.AddConfigPath(ConfigFilePaths[i])
	}

	viper.SetEnvPrefix("LUME_WEB_PORTAL")
	viper.AutomaticEnv()

	readErr := viper.ReadInConfig()
	if readErr == nil {
		return viper.GetViper(), nil
	}

	// Anything other than "file not found" is fatal.
	var notFound viper.ConfigFileNotFoundError
	if !errors.As(readErr, &notFound) {
		return nil, readErr
	}

	logger.Info("Config file not found, using default settings.")

	if writeErr := viper.SafeWriteConfig(); writeErr != nil {
		return nil, writeErr
	}

	return viper.GetViper(), nil
}
// newFallbackLogger returns a zap development logger for use before the
// configuration is available. If the development logger cannot be built, a
// no-op logger is returned instead, so callers never receive nil.
func newFallbackLogger() *zap.Logger {
	l, err := zap.NewDevelopment()
	if err != nil {
		return zap.NewNop()
	}
	return l
}