feat: add reflection-based system to manage defaults and validations per struct
This commit is contained in:
parent
1a20a7d35f
commit
cb558cdfc3
135
config/config.go
135
config/config.go
|
@ -3,8 +3,8 @@ package config
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/docker/go-units"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
@ -17,6 +17,14 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
type Defaults interface {
|
||||
Defaults() map[string]interface{}
|
||||
}
|
||||
|
||||
type Validator interface {
|
||||
Validate() error
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Core CoreConfig `mapstructure:"core"`
|
||||
Protocol map[string]ProtocolConfig `mapstructure:"protocol"`
|
||||
|
@ -41,7 +49,7 @@ func NewManager() (*Manager, error) {
|
|||
root: &config,
|
||||
}
|
||||
|
||||
m.setDefaults(m.coreDefaults(), "")
|
||||
m.setDefaultsForObject(m.root.Core, "")
|
||||
err = m.maybeSave()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -52,29 +60,126 @@ func NewManager() (*Manager, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
err = m.validateObject(m.root)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) ConfigureProtocol(name string, cfg ProtocolConfig) error {
|
||||
defaults := cfg.Defaults()
|
||||
|
||||
protocolPrefix := fmt.Sprintf("protocol.%s", name)
|
||||
|
||||
m.setDefaults(defaults, protocolPrefix)
|
||||
m.setDefaultsForObject(cfg, protocolPrefix)
|
||||
err := m.maybeSave()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.viper.Sub(protocolPrefix).Unmarshal(cfg)
|
||||
err = m.viper.Sub(protocolPrefix).Unmarshal(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.validateObject(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) setDefaults(defaults map[string]interface{}, prefix string) {
|
||||
for key, value := range defaults {
|
||||
if prefix != "" {
|
||||
key = fmt.Sprintf("%s.%s", prefix, key)
|
||||
func (m *Manager) setDefaultsForObject(obj interface{}, prefix string) {
|
||||
// Reflect on the object to traverse its fields
|
||||
objValue := reflect.ValueOf(obj)
|
||||
objType := reflect.TypeOf(obj)
|
||||
|
||||
// If the object is a pointer, we need to work with its element
|
||||
if objValue.Kind() == reflect.Ptr {
|
||||
objValue = objValue.Elem()
|
||||
objType = objType.Elem()
|
||||
}
|
||||
if m.setDefault(key, value) {
|
||||
|
||||
// Check if the object itself implements Defaults
|
||||
if setter, ok := obj.(Defaults); ok {
|
||||
m.applyDefaults(setter, prefix)
|
||||
}
|
||||
|
||||
// Recursively handle struct fields
|
||||
for i := 0; i < objValue.NumField(); i++ {
|
||||
field := objValue.Field(i)
|
||||
fieldType := objType.Field(i)
|
||||
mapstructureTag := fieldType.Tag.Get("mapstructure")
|
||||
|
||||
// Construct new prefix based on the mapstructure tag, if available
|
||||
newPrefix := prefix
|
||||
if mapstructureTag != "" && mapstructureTag != "-" {
|
||||
if newPrefix != "" {
|
||||
newPrefix += "."
|
||||
}
|
||||
newPrefix += mapstructureTag
|
||||
}
|
||||
|
||||
// If field is a struct or pointer to a struct, recurse
|
||||
if field.Kind() == reflect.Struct || (field.Kind() == reflect.Ptr && field.Elem().Kind() == reflect.Struct) {
|
||||
if field.Kind() == reflect.Ptr && field.IsNil() {
|
||||
// Initialize nil pointer to struct
|
||||
field.Set(reflect.New(fieldType.Type.Elem()))
|
||||
}
|
||||
m.setDefaultsForObject(field.Interface(), newPrefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) validateObject(obj interface{}) error {
|
||||
// Reflect on the object to traverse its fields
|
||||
objValue := reflect.ValueOf(obj)
|
||||
objType := reflect.TypeOf(obj)
|
||||
|
||||
// If the object is a pointer, we need to work with its element
|
||||
if objValue.Kind() == reflect.Ptr {
|
||||
objValue = objValue.Elem()
|
||||
objType = objType.Elem()
|
||||
}
|
||||
|
||||
// Check if the object itself implements Defaults
|
||||
if validator, ok := obj.(Validator); ok {
|
||||
err := validator.Validate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Recursively handle struct fields
|
||||
for i := 0; i < objValue.NumField(); i++ {
|
||||
field := objValue.Field(i)
|
||||
fieldType := objType.Field(i)
|
||||
|
||||
// If field is a struct or pointer to a struct, recurse
|
||||
if field.Kind() == reflect.Struct || (field.Kind() == reflect.Ptr && field.Elem().Kind() == reflect.Struct) {
|
||||
if field.Kind() == reflect.Ptr && field.IsNil() {
|
||||
// Initialize nil pointer to struct
|
||||
field.Set(reflect.New(fieldType.Type.Elem()))
|
||||
}
|
||||
err := m.validateObject(field.Interface())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) applyDefaults(setter Defaults, prefix string) {
|
||||
defaults := setter.Defaults()
|
||||
for key, value := range defaults {
|
||||
fullKey := key
|
||||
if prefix != "" {
|
||||
fullKey = fmt.Sprintf("%s.%s", prefix, key)
|
||||
}
|
||||
if m.setDefault(fullKey, value) {
|
||||
m.changes = true
|
||||
}
|
||||
}
|
||||
|
@ -102,13 +207,7 @@ func (m *Manager) maybeSave() error {
|
|||
}
|
||||
|
||||
func (m *Manager) coreDefaults() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"core.post_upload_limit": units.MiB * 100,
|
||||
"core.log.level": "info",
|
||||
"core.db.charset": "utf8mb4",
|
||||
"core.db.port": 3306,
|
||||
"core.db.name": "portal",
|
||||
}
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
|
||||
func (m *Manager) Config() *Config {
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
package config
|
||||
|
||||
import "github.com/docker/go-units"
|
||||
|
||||
var _ Defaults = (*CoreConfig)(nil)
|
||||
|
||||
type CoreConfig struct {
|
||||
DB DatabaseConfig `mapstructure:"db"`
|
||||
Domain string `mapstructure:"domain"`
|
||||
|
@ -15,3 +19,9 @@ type CoreConfig struct {
|
|||
Mail MailConfig `mapstructure:"mail"`
|
||||
Clustered *ClusterConfig `mapstructure:"clustered"`
|
||||
}
|
||||
|
||||
func (c CoreConfig) Defaults() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"post_upload_limit": units.MiB * 100,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
var _ Defaults = (*DatabaseConfig)(nil)
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Charset string `mapstructure:"charset"`
|
||||
Host string `mapstructure:"host"`
|
||||
|
@ -16,6 +18,14 @@ type DatabaseConfig struct {
|
|||
Cache *CacheConfig `mapstructure:"cache"`
|
||||
}
|
||||
|
||||
func (d DatabaseConfig) Defaults() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"charset": "utf8mb4",
|
||||
"port": 3306,
|
||||
"name": "portal",
|
||||
}
|
||||
}
|
||||
|
||||
type CacheConfig struct {
|
||||
Mode string `mapstructure:"mode"`
|
||||
Options interface{} `mapstructure:"options"`
|
||||
|
|
|
@ -2,6 +2,8 @@ package config
|
|||
|
||||
import "errors"
|
||||
|
||||
var _ Defaults = (*EtcdConfig)(nil)
|
||||
|
||||
type EtcdConfig struct {
|
||||
Endpoints []string `mapstructure:"endpoints"`
|
||||
DialTimeout int `mapstructure:"dial_timeout"`
|
||||
|
@ -13,3 +15,9 @@ func (r *EtcdConfig) Validate() error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *EtcdConfig) Defaults() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"dial_timeout": 5,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,13 @@
|
|||
package config
|
||||
|
||||
var _ Defaults = (*LogConfig)(nil)
|
||||
|
||||
type LogConfig struct {
|
||||
Level string `mapstructure:"level"`
|
||||
}
|
||||
|
||||
func (l LogConfig) Defaults() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"level": "info",
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
package config
|
||||
|
||||
import "errors"
|
||||
|
||||
var _ Validator = (*MailConfig)(nil)
|
||||
|
||||
type MailConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
|
@ -8,3 +12,16 @@ type MailConfig struct {
|
|||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
func (m MailConfig) Validate() error {
|
||||
if m.Host == "" {
|
||||
return errors.New("host is required")
|
||||
}
|
||||
if m.Username == "" {
|
||||
return errors.New("username is required")
|
||||
}
|
||||
if m.Password == "" {
|
||||
return errors.New("password is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
package config
|
||||
|
||||
type ProtocolConfig interface {
|
||||
Defaults() map[string]interface{}
|
||||
Defaults
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue