feat: add reflection-based system to manage defaults and validations per struct
This commit is contained in:
parent
1a20a7d35f
commit
cb558cdfc3
135
config/config.go
135
config/config.go
|
@ -3,8 +3,8 @@ package config
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
"github.com/docker/go-units"
|
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
@ -17,6 +17,14 @@ var (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Defaults interface {
|
||||||
|
Defaults() map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Validator interface {
|
||||||
|
Validate() error
|
||||||
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Core CoreConfig `mapstructure:"core"`
|
Core CoreConfig `mapstructure:"core"`
|
||||||
Protocol map[string]ProtocolConfig `mapstructure:"protocol"`
|
Protocol map[string]ProtocolConfig `mapstructure:"protocol"`
|
||||||
|
@ -41,7 +49,7 @@ func NewManager() (*Manager, error) {
|
||||||
root: &config,
|
root: &config,
|
||||||
}
|
}
|
||||||
|
|
||||||
m.setDefaults(m.coreDefaults(), "")
|
m.setDefaultsForObject(m.root.Core, "")
|
||||||
err = m.maybeSave()
|
err = m.maybeSave()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -52,29 +60,126 @@ func NewManager() (*Manager, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = m.validateObject(m.root)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) ConfigureProtocol(name string, cfg ProtocolConfig) error {
|
func (m *Manager) ConfigureProtocol(name string, cfg ProtocolConfig) error {
|
||||||
defaults := cfg.Defaults()
|
|
||||||
|
|
||||||
protocolPrefix := fmt.Sprintf("protocol.%s", name)
|
protocolPrefix := fmt.Sprintf("protocol.%s", name)
|
||||||
|
|
||||||
m.setDefaults(defaults, protocolPrefix)
|
m.setDefaultsForObject(cfg, protocolPrefix)
|
||||||
err := m.maybeSave()
|
err := m.maybeSave()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.viper.Sub(protocolPrefix).Unmarshal(cfg)
|
err = m.viper.Sub(protocolPrefix).Unmarshal(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.validateObject(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) setDefaults(defaults map[string]interface{}, prefix string) {
|
func (m *Manager) setDefaultsForObject(obj interface{}, prefix string) {
|
||||||
for key, value := range defaults {
|
// Reflect on the object to traverse its fields
|
||||||
if prefix != "" {
|
objValue := reflect.ValueOf(obj)
|
||||||
key = fmt.Sprintf("%s.%s", prefix, key)
|
objType := reflect.TypeOf(obj)
|
||||||
|
|
||||||
|
// If the object is a pointer, we need to work with its element
|
||||||
|
if objValue.Kind() == reflect.Ptr {
|
||||||
|
objValue = objValue.Elem()
|
||||||
|
objType = objType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the object itself implements Defaults
|
||||||
|
if setter, ok := obj.(Defaults); ok {
|
||||||
|
m.applyDefaults(setter, prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively handle struct fields
|
||||||
|
for i := 0; i < objValue.NumField(); i++ {
|
||||||
|
field := objValue.Field(i)
|
||||||
|
fieldType := objType.Field(i)
|
||||||
|
mapstructureTag := fieldType.Tag.Get("mapstructure")
|
||||||
|
|
||||||
|
// Construct new prefix based on the mapstructure tag, if available
|
||||||
|
newPrefix := prefix
|
||||||
|
if mapstructureTag != "" && mapstructureTag != "-" {
|
||||||
|
if newPrefix != "" {
|
||||||
|
newPrefix += "."
|
||||||
|
}
|
||||||
|
newPrefix += mapstructureTag
|
||||||
}
|
}
|
||||||
if m.setDefault(key, value) {
|
|
||||||
|
// If field is a struct or pointer to a struct, recurse
|
||||||
|
if field.Kind() == reflect.Struct || (field.Kind() == reflect.Ptr && field.Elem().Kind() == reflect.Struct) {
|
||||||
|
if field.Kind() == reflect.Ptr && field.IsNil() {
|
||||||
|
// Initialize nil pointer to struct
|
||||||
|
field.Set(reflect.New(fieldType.Type.Elem()))
|
||||||
|
}
|
||||||
|
m.setDefaultsForObject(field.Interface(), newPrefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) validateObject(obj interface{}) error {
|
||||||
|
// Reflect on the object to traverse its fields
|
||||||
|
objValue := reflect.ValueOf(obj)
|
||||||
|
objType := reflect.TypeOf(obj)
|
||||||
|
|
||||||
|
// If the object is a pointer, we need to work with its element
|
||||||
|
if objValue.Kind() == reflect.Ptr {
|
||||||
|
objValue = objValue.Elem()
|
||||||
|
objType = objType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the object itself implements Defaults
|
||||||
|
if validator, ok := obj.(Validator); ok {
|
||||||
|
err := validator.Validate()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively handle struct fields
|
||||||
|
for i := 0; i < objValue.NumField(); i++ {
|
||||||
|
field := objValue.Field(i)
|
||||||
|
fieldType := objType.Field(i)
|
||||||
|
|
||||||
|
// If field is a struct or pointer to a struct, recurse
|
||||||
|
if field.Kind() == reflect.Struct || (field.Kind() == reflect.Ptr && field.Elem().Kind() == reflect.Struct) {
|
||||||
|
if field.Kind() == reflect.Ptr && field.IsNil() {
|
||||||
|
// Initialize nil pointer to struct
|
||||||
|
field.Set(reflect.New(fieldType.Type.Elem()))
|
||||||
|
}
|
||||||
|
err := m.validateObject(field.Interface())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) applyDefaults(setter Defaults, prefix string) {
|
||||||
|
defaults := setter.Defaults()
|
||||||
|
for key, value := range defaults {
|
||||||
|
fullKey := key
|
||||||
|
if prefix != "" {
|
||||||
|
fullKey = fmt.Sprintf("%s.%s", prefix, key)
|
||||||
|
}
|
||||||
|
if m.setDefault(fullKey, value) {
|
||||||
m.changes = true
|
m.changes = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -102,13 +207,7 @@ func (m *Manager) maybeSave() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) coreDefaults() map[string]interface{} {
|
func (m *Manager) coreDefaults() map[string]interface{} {
|
||||||
return map[string]interface{}{
|
return map[string]interface{}{}
|
||||||
"core.post_upload_limit": units.MiB * 100,
|
|
||||||
"core.log.level": "info",
|
|
||||||
"core.db.charset": "utf8mb4",
|
|
||||||
"core.db.port": 3306,
|
|
||||||
"core.db.name": "portal",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) Config() *Config {
|
func (m *Manager) Config() *Config {
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
package config
|
package config
|
||||||
|
|
||||||
|
import "github.com/docker/go-units"
|
||||||
|
|
||||||
|
var _ Defaults = (*CoreConfig)(nil)
|
||||||
|
|
||||||
type CoreConfig struct {
|
type CoreConfig struct {
|
||||||
DB DatabaseConfig `mapstructure:"db"`
|
DB DatabaseConfig `mapstructure:"db"`
|
||||||
Domain string `mapstructure:"domain"`
|
Domain string `mapstructure:"domain"`
|
||||||
|
@ -15,3 +19,9 @@ type CoreConfig struct {
|
||||||
Mail MailConfig `mapstructure:"mail"`
|
Mail MailConfig `mapstructure:"mail"`
|
||||||
Clustered *ClusterConfig `mapstructure:"clustered"`
|
Clustered *ClusterConfig `mapstructure:"clustered"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c CoreConfig) Defaults() map[string]interface{} {
|
||||||
|
return map[string]interface{}{
|
||||||
|
"post_upload_limit": units.MiB * 100,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -6,6 +6,8 @@ import (
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var _ Defaults = (*DatabaseConfig)(nil)
|
||||||
|
|
||||||
type DatabaseConfig struct {
|
type DatabaseConfig struct {
|
||||||
Charset string `mapstructure:"charset"`
|
Charset string `mapstructure:"charset"`
|
||||||
Host string `mapstructure:"host"`
|
Host string `mapstructure:"host"`
|
||||||
|
@ -16,6 +18,14 @@ type DatabaseConfig struct {
|
||||||
Cache *CacheConfig `mapstructure:"cache"`
|
Cache *CacheConfig `mapstructure:"cache"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d DatabaseConfig) Defaults() map[string]interface{} {
|
||||||
|
return map[string]interface{}{
|
||||||
|
"charset": "utf8mb4",
|
||||||
|
"port": 3306,
|
||||||
|
"name": "portal",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type CacheConfig struct {
|
type CacheConfig struct {
|
||||||
Mode string `mapstructure:"mode"`
|
Mode string `mapstructure:"mode"`
|
||||||
Options interface{} `mapstructure:"options"`
|
Options interface{} `mapstructure:"options"`
|
||||||
|
|
|
@ -2,6 +2,8 @@ package config
|
||||||
|
|
||||||
import "errors"
|
import "errors"
|
||||||
|
|
||||||
|
var _ Defaults = (*EtcdConfig)(nil)
|
||||||
|
|
||||||
type EtcdConfig struct {
|
type EtcdConfig struct {
|
||||||
Endpoints []string `mapstructure:"endpoints"`
|
Endpoints []string `mapstructure:"endpoints"`
|
||||||
DialTimeout int `mapstructure:"dial_timeout"`
|
DialTimeout int `mapstructure:"dial_timeout"`
|
||||||
|
@ -13,3 +15,9 @@ func (r *EtcdConfig) Validate() error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *EtcdConfig) Defaults() map[string]interface{} {
|
||||||
|
return map[string]interface{}{
|
||||||
|
"dial_timeout": 5,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,5 +1,13 @@
|
||||||
package config
|
package config
|
||||||
|
|
||||||
|
var _ Defaults = (*LogConfig)(nil)
|
||||||
|
|
||||||
type LogConfig struct {
|
type LogConfig struct {
|
||||||
Level string `mapstructure:"level"`
|
Level string `mapstructure:"level"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l LogConfig) Defaults() map[string]interface{} {
|
||||||
|
return map[string]interface{}{
|
||||||
|
"level": "info",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
package config
|
package config
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var _ Validator = (*MailConfig)(nil)
|
||||||
|
|
||||||
type MailConfig struct {
|
type MailConfig struct {
|
||||||
Host string
|
Host string
|
||||||
Port int
|
Port int
|
||||||
|
@ -8,3 +12,16 @@ type MailConfig struct {
|
||||||
Username string
|
Username string
|
||||||
Password string
|
Password string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m MailConfig) Validate() error {
|
||||||
|
if m.Host == "" {
|
||||||
|
return errors.New("host is required")
|
||||||
|
}
|
||||||
|
if m.Username == "" {
|
||||||
|
return errors.New("username is required")
|
||||||
|
}
|
||||||
|
if m.Password == "" {
|
||||||
|
return errors.New("password is required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
package config
|
package config
|
||||||
|
|
||||||
type ProtocolConfig interface {
|
type ProtocolConfig interface {
|
||||||
Defaults() map[string]interface{}
|
Defaults
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue