feat: add reflection-based system to manage defaults and validations per struct

This commit is contained in:
Derrick Hammer 2024-02-28 08:47:33 -05:00
parent 1a20a7d35f
commit cb558cdfc3
Signed by: pcfreak30
GPG Key ID: C997C339BE476FF2
7 changed files with 171 additions and 19 deletions

View File

@ -3,8 +3,8 @@ package config
import (
"errors"
"fmt"
"reflect"
"github.com/docker/go-units"
"github.com/spf13/viper"
"go.uber.org/zap"
)
@ -17,6 +17,14 @@ var (
}
)
type Defaults interface {
Defaults() map[string]interface{}
}
type Validator interface {
Validate() error
}
type Config struct {
Core CoreConfig `mapstructure:"core"`
Protocol map[string]ProtocolConfig `mapstructure:"protocol"`
@ -41,7 +49,7 @@ func NewManager() (*Manager, error) {
root: &config,
}
m.setDefaults(m.coreDefaults(), "")
m.setDefaultsForObject(m.root.Core, "")
err = m.maybeSave()
if err != nil {
return nil, err
@ -52,29 +60,126 @@ func NewManager() (*Manager, error) {
return nil, err
}
err = m.validateObject(m.root)
if err != nil {
return nil, err
}
return m, nil
}
func (m *Manager) ConfigureProtocol(name string, cfg ProtocolConfig) error {
defaults := cfg.Defaults()
protocolPrefix := fmt.Sprintf("protocol.%s", name)
m.setDefaults(defaults, protocolPrefix)
m.setDefaultsForObject(cfg, protocolPrefix)
err := m.maybeSave()
if err != nil {
return err
}
return m.viper.Sub(protocolPrefix).Unmarshal(cfg)
err = m.viper.Sub(protocolPrefix).Unmarshal(cfg)
if err != nil {
return err
}
err = m.validateObject(cfg)
if err != nil {
return err
}
return nil
}
func (m *Manager) setDefaults(defaults map[string]interface{}, prefix string) {
for key, value := range defaults {
if prefix != "" {
key = fmt.Sprintf("%s.%s", prefix, key)
func (m *Manager) setDefaultsForObject(obj interface{}, prefix string) {
// Reflect on the object to traverse its fields
objValue := reflect.ValueOf(obj)
objType := reflect.TypeOf(obj)
// If the object is a pointer, we need to work with its element
if objValue.Kind() == reflect.Ptr {
objValue = objValue.Elem()
objType = objType.Elem()
}
if m.setDefault(key, value) {
// Check if the object itself implements Defaults
if setter, ok := obj.(Defaults); ok {
m.applyDefaults(setter, prefix)
}
// Recursively handle struct fields
for i := 0; i < objValue.NumField(); i++ {
field := objValue.Field(i)
fieldType := objType.Field(i)
mapstructureTag := fieldType.Tag.Get("mapstructure")
// Construct new prefix based on the mapstructure tag, if available
newPrefix := prefix
if mapstructureTag != "" && mapstructureTag != "-" {
if newPrefix != "" {
newPrefix += "."
}
newPrefix += mapstructureTag
}
// If field is a struct or pointer to a struct, recurse
if field.Kind() == reflect.Struct || (field.Kind() == reflect.Ptr && field.Elem().Kind() == reflect.Struct) {
if field.Kind() == reflect.Ptr && field.IsNil() {
// Initialize nil pointer to struct
field.Set(reflect.New(fieldType.Type.Elem()))
}
m.setDefaultsForObject(field.Interface(), newPrefix)
}
}
}
func (m *Manager) validateObject(obj interface{}) error {
// Reflect on the object to traverse its fields
objValue := reflect.ValueOf(obj)
objType := reflect.TypeOf(obj)
// If the object is a pointer, we need to work with its element
if objValue.Kind() == reflect.Ptr {
objValue = objValue.Elem()
objType = objType.Elem()
}
// Check if the object itself implements Defaults
if validator, ok := obj.(Validator); ok {
err := validator.Validate()
if err != nil {
return err
}
}
// Recursively handle struct fields
for i := 0; i < objValue.NumField(); i++ {
field := objValue.Field(i)
fieldType := objType.Field(i)
// If field is a struct or pointer to a struct, recurse
if field.Kind() == reflect.Struct || (field.Kind() == reflect.Ptr && field.Elem().Kind() == reflect.Struct) {
if field.Kind() == reflect.Ptr && field.IsNil() {
// Initialize nil pointer to struct
field.Set(reflect.New(fieldType.Type.Elem()))
}
err := m.validateObject(field.Interface())
if err != nil {
return err
}
}
}
return nil
}
func (m *Manager) applyDefaults(setter Defaults, prefix string) {
defaults := setter.Defaults()
for key, value := range defaults {
fullKey := key
if prefix != "" {
fullKey = fmt.Sprintf("%s.%s", prefix, key)
}
if m.setDefault(fullKey, value) {
m.changes = true
}
}
@ -102,13 +207,7 @@ func (m *Manager) maybeSave() error {
}
func (m *Manager) coreDefaults() map[string]interface{} {
return map[string]interface{}{
"core.post_upload_limit": units.MiB * 100,
"core.log.level": "info",
"core.db.charset": "utf8mb4",
"core.db.port": 3306,
"core.db.name": "portal",
}
return map[string]interface{}{}
}
func (m *Manager) Config() *Config {

View File

@ -1,5 +1,9 @@
package config
import "github.com/docker/go-units"
var _ Defaults = (*CoreConfig)(nil)
type CoreConfig struct {
DB DatabaseConfig `mapstructure:"db"`
Domain string `mapstructure:"domain"`
@ -15,3 +19,9 @@ type CoreConfig struct {
Mail MailConfig `mapstructure:"mail"`
Clustered *ClusterConfig `mapstructure:"clustered"`
}
func (c CoreConfig) Defaults() map[string]interface{} {
return map[string]interface{}{
"post_upload_limit": units.MiB * 100,
}
}

View File

@ -6,6 +6,8 @@ import (
"github.com/mitchellh/mapstructure"
)
var _ Defaults = (*DatabaseConfig)(nil)
type DatabaseConfig struct {
Charset string `mapstructure:"charset"`
Host string `mapstructure:"host"`
@ -16,6 +18,14 @@ type DatabaseConfig struct {
Cache *CacheConfig `mapstructure:"cache"`
}
func (d DatabaseConfig) Defaults() map[string]interface{} {
return map[string]interface{}{
"charset": "utf8mb4",
"port": 3306,
"name": "portal",
}
}
type CacheConfig struct {
Mode string `mapstructure:"mode"`
Options interface{} `mapstructure:"options"`

View File

@ -2,6 +2,8 @@ package config
import "errors"
var _ Defaults = (*EtcdConfig)(nil)
type EtcdConfig struct {
Endpoints []string `mapstructure:"endpoints"`
DialTimeout int `mapstructure:"dial_timeout"`
@ -13,3 +15,9 @@ func (r *EtcdConfig) Validate() error {
}
return nil
}
func (r *EtcdConfig) Defaults() map[string]interface{} {
return map[string]interface{}{
"dial_timeout": 5,
}
}

View File

@ -1,5 +1,13 @@
package config
var _ Defaults = (*LogConfig)(nil)
type LogConfig struct {
Level string `mapstructure:"level"`
}
func (l LogConfig) Defaults() map[string]interface{} {
return map[string]interface{}{
"level": "info",
}
}

View File

@ -1,5 +1,9 @@
package config
import "errors"
var _ Validator = (*MailConfig)(nil)
type MailConfig struct {
Host string
Port int
@ -8,3 +12,16 @@ type MailConfig struct {
Username string
Password string
}
func (m MailConfig) Validate() error {
if m.Host == "" {
return errors.New("host is required")
}
if m.Username == "" {
return errors.New("username is required")
}
if m.Password == "" {
return errors.New("password is required")
}
return nil
}

View File

@ -1,5 +1,5 @@
package config
type ProtocolConfig interface {
Defaults() map[string]interface{}
Defaults
}