From cb558cdfc3da70a8e5ab7f1be1981df0f2ba7c4f Mon Sep 17 00:00:00 2001 From: Derrick Hammer Date: Wed, 28 Feb 2024 08:47:33 -0500 Subject: [PATCH] feat: add reflection-based system to manage defaults and validations per struct --- config/config.go | 135 +++++++++++++++++++++++++++++++++++++++------ config/core.go | 10 ++++ config/database.go | 10 ++++ config/etcd.go | 8 +++ config/log.go | 8 +++ config/mail.go | 17 ++++++ config/protocol.go | 2 +- 7 files changed, 171 insertions(+), 19 deletions(-) diff --git a/config/config.go b/config/config.go index ac3184f..682f5ba 100644 --- a/config/config.go +++ b/config/config.go @@ -3,8 +3,8 @@ package config import ( "errors" "fmt" + "reflect" - "github.com/docker/go-units" "github.com/spf13/viper" "go.uber.org/zap" ) @@ -17,6 +17,14 @@ var ( } ) +type Defaults interface { + Defaults() map[string]interface{} +} + +type Validator interface { + Validate() error +} + type Config struct { Core CoreConfig `mapstructure:"core"` Protocol map[string]ProtocolConfig `mapstructure:"protocol"` @@ -41,7 +49,7 @@ func NewManager() (*Manager, error) { root: &config, } - m.setDefaults(m.coreDefaults(), "") + m.setDefaultsForObject(m.root.Core, "") err = m.maybeSave() if err != nil { return nil, err @@ -52,29 +60,126 @@ func NewManager() (*Manager, error) { return nil, err } + err = m.validateObject(m.root) + if err != nil { + return nil, err + } + return m, nil } func (m *Manager) ConfigureProtocol(name string, cfg ProtocolConfig) error { - defaults := cfg.Defaults() - protocolPrefix := fmt.Sprintf("protocol.%s", name) - m.setDefaults(defaults, protocolPrefix) + m.setDefaultsForObject(cfg, protocolPrefix) err := m.maybeSave() if err != nil { return err } - return m.viper.Sub(protocolPrefix).Unmarshal(cfg) + err = m.viper.Sub(protocolPrefix).Unmarshal(cfg) + if err != nil { + return err + } + + err = m.validateObject(cfg) + if err != nil { + return err + } + + return nil } -func (m *Manager) setDefaults(defaults map[string]interface{}, prefix string) { - for key, value := range defaults { - if prefix != "" { - key = fmt.Sprintf("%s.%s", prefix, key) +func (m *Manager) setDefaultsForObject(obj interface{}, prefix string) { + // Reflect on the object to traverse its fields + objValue := reflect.ValueOf(obj) + objType := reflect.TypeOf(obj) + + // If the object is a pointer, we need to work with its element + if objValue.Kind() == reflect.Ptr { + objValue = objValue.Elem() + objType = objType.Elem() + } + + // Check if the object itself implements Defaults + if setter, ok := obj.(Defaults); ok { + m.applyDefaults(setter, prefix) + } + + // Recursively handle struct fields + for i := 0; i < objValue.NumField(); i++ { + field := objValue.Field(i) + fieldType := objType.Field(i) + mapstructureTag := fieldType.Tag.Get("mapstructure") + + // Construct new prefix based on the mapstructure tag, if available + newPrefix := prefix + if mapstructureTag != "" && mapstructureTag != "-" { + if newPrefix != "" { + newPrefix += "." + } + newPrefix += mapstructureTag } - if m.setDefault(key, value) { + + // If field is a struct or pointer to a struct, recurse + if field.Kind() == reflect.Struct || (field.Kind() == reflect.Ptr && field.Elem().Kind() == reflect.Struct) { + if field.Kind() == reflect.Ptr && field.IsNil() { + // Initialize nil pointer to struct + field.Set(reflect.New(fieldType.Type.Elem())) + } + m.setDefaultsForObject(field.Interface(), newPrefix) + } + } +} + +func (m *Manager) validateObject(obj interface{}) error { + // Reflect on the object to traverse its fields + objValue := reflect.ValueOf(obj) + objType := reflect.TypeOf(obj) + + // If the object is a pointer, we need to work with its element + if objValue.Kind() == reflect.Ptr { + objValue = objValue.Elem() + objType = objType.Elem() + } + + // Check if the object itself implements Defaults + if validator, ok := obj.(Validator); ok { + err := validator.Validate() + if err != nil { + return err + } + } + + // Recursively handle struct fields + for i := 0; i < objValue.NumField(); i++ { + field := objValue.Field(i) + fieldType := objType.Field(i) + + // If field is a struct or pointer to a struct, recurse + if field.Kind() == reflect.Struct || (field.Kind() == reflect.Ptr && field.Elem().Kind() == reflect.Struct) { + if field.Kind() == reflect.Ptr && field.IsNil() { + // Initialize nil pointer to struct + field.Set(reflect.New(fieldType.Type.Elem())) + } + err := m.validateObject(field.Interface()) + if err != nil { + return err + } + } + } + + return nil +} + +func (m *Manager) applyDefaults(setter Defaults, prefix string) { + defaults := setter.Defaults() + for key, value := range defaults { + fullKey := key + if prefix != "" { + fullKey = fmt.Sprintf("%s.%s", prefix, key) + } + if m.setDefault(fullKey, value) { m.changes = true } } @@ -102,13 +207,7 @@ func (m *Manager) maybeSave() error { } func (m *Manager) coreDefaults() map[string]interface{} { - return map[string]interface{}{ - "core.post_upload_limit": units.MiB * 100, - "core.log.level": "info", - "core.db.charset": "utf8mb4", - "core.db.port": 3306, - "core.db.name": "portal", - } + return map[string]interface{}{} } func (m *Manager) Config() *Config { diff --git a/config/core.go b/config/core.go index 5037938..4b1507a 100644 --- a/config/core.go +++ b/config/core.go @@ -1,5 +1,9 @@ package config +import "github.com/docker/go-units" + +var _ Defaults = (*CoreConfig)(nil) + type CoreConfig struct { DB DatabaseConfig `mapstructure:"db"` Domain string `mapstructure:"domain"` @@ -15,3 +19,9 @@ type CoreConfig struct { Mail MailConfig `mapstructure:"mail"` Clustered *ClusterConfig `mapstructure:"clustered"` } + +func (c CoreConfig) Defaults() map[string]interface{} { + return map[string]interface{}{ + "post_upload_limit": units.MiB * 100, + } +} diff --git a/config/database.go b/config/database.go index 4581f32..dc3e106 100644 --- a/config/database.go +++ b/config/database.go @@ -6,6 +6,8 @@ import ( "github.com/mitchellh/mapstructure" ) +var _ Defaults = (*DatabaseConfig)(nil) + type DatabaseConfig struct { Charset string `mapstructure:"charset"` Host string `mapstructure:"host"` @@ -16,6 +18,14 @@ type DatabaseConfig struct { Cache *CacheConfig `mapstructure:"cache"` } +func (d DatabaseConfig) Defaults() map[string]interface{} { + return map[string]interface{}{ + "charset": "utf8mb4", + "port": 3306, + "name": "portal", + } +} + type CacheConfig struct { Mode string `mapstructure:"mode"` Options interface{} `mapstructure:"options"` diff --git a/config/etcd.go b/config/etcd.go index 01a9957..42ab703 100644 --- a/config/etcd.go +++ b/config/etcd.go @@ -2,6 +2,8 @@ package config import "errors" +var _ Defaults = (*EtcdConfig)(nil) + type EtcdConfig struct { Endpoints []string `mapstructure:"endpoints"` DialTimeout int `mapstructure:"dial_timeout"` @@ -13,3 +15,9 @@ func (r *EtcdConfig) Validate() error { } return nil } + +func (r *EtcdConfig) Defaults() map[string]interface{} { + return map[string]interface{}{ + "dial_timeout": 5, + } +} diff --git a/config/log.go b/config/log.go index 4a83a2b..24ddf70 100644 --- a/config/log.go +++ b/config/log.go @@ -1,5 +1,13 @@ package config +var _ Defaults = (*LogConfig)(nil) + type LogConfig struct { Level string `mapstructure:"level"` } + +func (l LogConfig) Defaults() map[string]interface{} { + return map[string]interface{}{ + "level": "info", + } +} diff --git a/config/mail.go b/config/mail.go index 6a1ae9e..f5bc4cc 100644 --- a/config/mail.go +++ b/config/mail.go @@ -1,5 +1,9 @@ package config +import "errors" + +var _ Validator = (*MailConfig)(nil) + type MailConfig struct { Host string Port int @@ -8,3 +12,16 @@ type MailConfig struct { Username string Password string } + +func (m MailConfig) Validate() error { + if m.Host == "" { + return errors.New("host is required") + } + if m.Username == "" { + return errors.New("username is required") + } + if m.Password == "" { + return errors.New("password is required") + } + return nil +} diff --git a/config/protocol.go b/config/protocol.go index c0cef98..8d5d781 100644 --- a/config/protocol.go +++ b/config/protocol.go @@ -1,5 +1,5 @@ package config type ProtocolConfig interface { - Defaults() map[string]interface{} + Defaults }