feat: add reflection-based system to manage defaults and validations per struct

This commit is contained in:
Derrick Hammer 2024-02-28 08:47:33 -05:00
parent 1a20a7d35f
commit cb558cdfc3
Signed by: pcfreak30
GPG Key ID: C997C339BE476FF2
7 changed files with 171 additions and 19 deletions

View File

@ -3,8 +3,8 @@ package config
import ( import (
"errors" "errors"
"fmt" "fmt"
"reflect"
"github.com/docker/go-units"
"github.com/spf13/viper" "github.com/spf13/viper"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -17,6 +17,14 @@ var (
} }
) )
type Defaults interface {
Defaults() map[string]interface{}
}
type Validator interface {
Validate() error
}
type Config struct { type Config struct {
Core CoreConfig `mapstructure:"core"` Core CoreConfig `mapstructure:"core"`
Protocol map[string]ProtocolConfig `mapstructure:"protocol"` Protocol map[string]ProtocolConfig `mapstructure:"protocol"`
@ -41,7 +49,7 @@ func NewManager() (*Manager, error) {
root: &config, root: &config,
} }
m.setDefaults(m.coreDefaults(), "") m.setDefaultsForObject(m.root.Core, "")
err = m.maybeSave() err = m.maybeSave()
if err != nil { if err != nil {
return nil, err return nil, err
@ -52,29 +60,126 @@ func NewManager() (*Manager, error) {
return nil, err return nil, err
} }
err = m.validateObject(m.root)
if err != nil {
return nil, err
}
return m, nil return m, nil
} }
func (m *Manager) ConfigureProtocol(name string, cfg ProtocolConfig) error { func (m *Manager) ConfigureProtocol(name string, cfg ProtocolConfig) error {
defaults := cfg.Defaults()
protocolPrefix := fmt.Sprintf("protocol.%s", name) protocolPrefix := fmt.Sprintf("protocol.%s", name)
m.setDefaults(defaults, protocolPrefix) m.setDefaultsForObject(cfg, protocolPrefix)
err := m.maybeSave() err := m.maybeSave()
if err != nil { if err != nil {
return err return err
} }
return m.viper.Sub(protocolPrefix).Unmarshal(cfg) err = m.viper.Sub(protocolPrefix).Unmarshal(cfg)
if err != nil {
return err
}
err = m.validateObject(cfg)
if err != nil {
return err
}
return nil
} }
func (m *Manager) setDefaults(defaults map[string]interface{}, prefix string) { func (m *Manager) setDefaultsForObject(obj interface{}, prefix string) {
for key, value := range defaults { // Reflect on the object to traverse its fields
if prefix != "" { objValue := reflect.ValueOf(obj)
key = fmt.Sprintf("%s.%s", prefix, key) objType := reflect.TypeOf(obj)
// If the object is a pointer, we need to work with its element
if objValue.Kind() == reflect.Ptr {
objValue = objValue.Elem()
objType = objType.Elem()
}
// Check if the object itself implements Defaults
if setter, ok := obj.(Defaults); ok {
m.applyDefaults(setter, prefix)
}
// Recursively handle struct fields
for i := 0; i < objValue.NumField(); i++ {
field := objValue.Field(i)
fieldType := objType.Field(i)
mapstructureTag := fieldType.Tag.Get("mapstructure")
// Construct new prefix based on the mapstructure tag, if available
newPrefix := prefix
if mapstructureTag != "" && mapstructureTag != "-" {
if newPrefix != "" {
newPrefix += "."
}
newPrefix += mapstructureTag
} }
if m.setDefault(key, value) {
// If field is a struct or pointer to a struct, recurse
if field.Kind() == reflect.Struct || (field.Kind() == reflect.Ptr && field.Elem().Kind() == reflect.Struct) {
if field.Kind() == reflect.Ptr && field.IsNil() {
// Initialize nil pointer to struct
field.Set(reflect.New(fieldType.Type.Elem()))
}
m.setDefaultsForObject(field.Interface(), newPrefix)
}
}
}
func (m *Manager) validateObject(obj interface{}) error {
// Reflect on the object to traverse its fields
objValue := reflect.ValueOf(obj)
objType := reflect.TypeOf(obj)
// If the object is a pointer, we need to work with its element
if objValue.Kind() == reflect.Ptr {
objValue = objValue.Elem()
objType = objType.Elem()
}
// Check if the object itself implements Defaults
if validator, ok := obj.(Validator); ok {
err := validator.Validate()
if err != nil {
return err
}
}
// Recursively handle struct fields
for i := 0; i < objValue.NumField(); i++ {
field := objValue.Field(i)
fieldType := objType.Field(i)
// If field is a struct or pointer to a struct, recurse
if field.Kind() == reflect.Struct || (field.Kind() == reflect.Ptr && field.Elem().Kind() == reflect.Struct) {
if field.Kind() == reflect.Ptr && field.IsNil() {
// Initialize nil pointer to struct
field.Set(reflect.New(fieldType.Type.Elem()))
}
err := m.validateObject(field.Interface())
if err != nil {
return err
}
}
}
return nil
}
func (m *Manager) applyDefaults(setter Defaults, prefix string) {
defaults := setter.Defaults()
for key, value := range defaults {
fullKey := key
if prefix != "" {
fullKey = fmt.Sprintf("%s.%s", prefix, key)
}
if m.setDefault(fullKey, value) {
m.changes = true m.changes = true
} }
} }
@ -102,13 +207,7 @@ func (m *Manager) maybeSave() error {
} }
func (m *Manager) coreDefaults() map[string]interface{} { func (m *Manager) coreDefaults() map[string]interface{} {
return map[string]interface{}{ return map[string]interface{}{}
"core.post_upload_limit": units.MiB * 100,
"core.log.level": "info",
"core.db.charset": "utf8mb4",
"core.db.port": 3306,
"core.db.name": "portal",
}
} }
func (m *Manager) Config() *Config { func (m *Manager) Config() *Config {

View File

@ -1,5 +1,9 @@
package config package config
import "github.com/docker/go-units"
var _ Defaults = (*CoreConfig)(nil)
type CoreConfig struct { type CoreConfig struct {
DB DatabaseConfig `mapstructure:"db"` DB DatabaseConfig `mapstructure:"db"`
Domain string `mapstructure:"domain"` Domain string `mapstructure:"domain"`
@ -15,3 +19,9 @@ type CoreConfig struct {
Mail MailConfig `mapstructure:"mail"` Mail MailConfig `mapstructure:"mail"`
Clustered *ClusterConfig `mapstructure:"clustered"` Clustered *ClusterConfig `mapstructure:"clustered"`
} }
func (c CoreConfig) Defaults() map[string]interface{} {
return map[string]interface{}{
"post_upload_limit": units.MiB * 100,
}
}

View File

@ -6,6 +6,8 @@ import (
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
) )
var _ Defaults = (*DatabaseConfig)(nil)
type DatabaseConfig struct { type DatabaseConfig struct {
Charset string `mapstructure:"charset"` Charset string `mapstructure:"charset"`
Host string `mapstructure:"host"` Host string `mapstructure:"host"`
@ -16,6 +18,14 @@ type DatabaseConfig struct {
Cache *CacheConfig `mapstructure:"cache"` Cache *CacheConfig `mapstructure:"cache"`
} }
func (d DatabaseConfig) Defaults() map[string]interface{} {
return map[string]interface{}{
"charset": "utf8mb4",
"port": 3306,
"name": "portal",
}
}
type CacheConfig struct { type CacheConfig struct {
Mode string `mapstructure:"mode"` Mode string `mapstructure:"mode"`
Options interface{} `mapstructure:"options"` Options interface{} `mapstructure:"options"`

View File

@ -2,6 +2,8 @@ package config
import "errors" import "errors"
var _ Defaults = (*EtcdConfig)(nil)
type EtcdConfig struct { type EtcdConfig struct {
Endpoints []string `mapstructure:"endpoints"` Endpoints []string `mapstructure:"endpoints"`
DialTimeout int `mapstructure:"dial_timeout"` DialTimeout int `mapstructure:"dial_timeout"`
@ -13,3 +15,9 @@ func (r *EtcdConfig) Validate() error {
} }
return nil return nil
} }
func (r *EtcdConfig) Defaults() map[string]interface{} {
return map[string]interface{}{
"dial_timeout": 5,
}
}

View File

@ -1,5 +1,13 @@
package config package config
var _ Defaults = (*LogConfig)(nil)
type LogConfig struct { type LogConfig struct {
Level string `mapstructure:"level"` Level string `mapstructure:"level"`
} }
func (l LogConfig) Defaults() map[string]interface{} {
return map[string]interface{}{
"level": "info",
}
}

View File

@ -1,5 +1,9 @@
package config package config
import "errors"
var _ Validator = (*MailConfig)(nil)
type MailConfig struct { type MailConfig struct {
Host string Host string
Port int Port int
@ -8,3 +12,16 @@ type MailConfig struct {
Username string Username string
Password string Password string
} }
func (m MailConfig) Validate() error {
if m.Host == "" {
return errors.New("host is required")
}
if m.Username == "" {
return errors.New("username is required")
}
if m.Password == "" {
return errors.New("password is required")
}
return nil
}

View File

@ -1,5 +1,5 @@
package config package config
type ProtocolConfig interface { type ProtocolConfig interface {
Defaults() map[string]interface{} Defaults
} }