package defaults

import (
	"encoding"
	"encoding/json"
	"errors"
	"reflect"
	"strconv"
	"time"
)

var (
	// errInvalidType is returned by Set when its argument is not a non-nil
	// pointer to a struct.
	errInvalidType = errors.New("not a struct pointer")
)

const (
	// fieldName is the struct tag key that holds a field's default value.
	fieldName = "default"
)

// Set initializes members in a struct referenced by a pointer.
// Maps and slices are initialized by `make` and other primitive types are set with default values.
// `ptr` should be a struct pointer.
//
// Fields tagged `default:"-"` are skipped. Once every field has been processed,
// callSetter is invoked with ptr. Set returns errInvalidType when ptr is nil,
// is not a pointer, is a nil pointer, or does not point to a struct.
func Set(ptr interface{}) error {
	if ptr == nil {
		return errInvalidType
	}
	rv := reflect.ValueOf(ptr)
	// A typed nil pointer has Kind Ptr but nothing to populate; reject it
	// instead of panicking on Elem().Type().
	if rv.Kind() != reflect.Ptr || rv.IsNil() {
		return errInvalidType
	}

	v := rv.Elem()
	t := v.Type()
	if t.Kind() != reflect.Struct {
		return errInvalidType
	}

	for i := 0; i < t.NumField(); i++ {
		if defaultVal := t.Field(i).Tag.Get(fieldName); defaultVal != "-" {
			if err := setField(v.Field(i), defaultVal); err != nil {
				return err
			}
		}
	}

	callSetter(ptr)
	return nil
}

// MustSet function is a wrapper of Set function
// It will call Set and panic if err not equals nil.
func MustSet(ptr interface{}) {
	if err := Set(ptr); err != nil {
		panic(err)
	}
}

// setField applies defaultVal to field and then recursively initializes the
// values nested inside it (pointee, struct fields, slice elements, map values).
//
// Fields that cannot be set are left untouched. The default is written only
// when the field still holds the zero value of its type, so values assigned
// before Set was called are preserved. Unparseable defaults for scalar kinds
// are ignored; malformed JSON for slices, maps and structs is returned as an
// error.
func setField(field reflect.Value, defaultVal string) error {
	if !field.CanSet() {
		return nil
	}
	if !shouldInitializeField(field, defaultVal) {
		return nil
	}

	isInitial := isInitialValue(field)
	if isInitial {
		if err := assignDefault(field, defaultVal); err != nil {
			return err
		}
	}
	return initializeNested(field, defaultVal, isInitial)
}

// assignDefault writes defaultVal into field, which the caller has verified
// holds its zero value. encoding.TextUnmarshaler and json.Unmarshaler
// implementations take precedence over the kind-based parsing below.
func assignDefault(field reflect.Value, defaultVal string) error {
	if unmarshalByInterface(field, defaultVal) {
		return nil
	}

	switch field.Kind() {
	case reflect.Bool:
		if val, err := strconv.ParseBool(defaultVal); err == nil {
			field.SetBool(val)
		}
	case reflect.Int64:
		// int64 is also the underlying kind of time.Duration, so duration
		// strings such as "1s" are tried before plain integers.
		if val, err := time.ParseDuration(defaultVal); err == nil {
			field.SetInt(int64(val))
		} else if val, err := strconv.ParseInt(defaultVal, 0, 64); err == nil {
			field.SetInt(val)
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
		// Bits() bounds the parse to the field's width, rejecting overflow.
		if val, err := strconv.ParseInt(defaultVal, 0, field.Type().Bits()); err == nil {
			field.SetInt(val)
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		if val, err := strconv.ParseUint(defaultVal, 0, field.Type().Bits()); err == nil {
			field.SetUint(val)
		}
	case reflect.Float32, reflect.Float64:
		if val, err := strconv.ParseFloat(defaultVal, field.Type().Bits()); err == nil {
			field.SetFloat(val)
		}
	case reflect.String:
		field.SetString(defaultVal)
	case reflect.Slice:
		// Decode into a scratch value so the field is only assigned when the
		// JSON is valid.
		ref := reflect.New(field.Type())
		ref.Elem().Set(reflect.MakeSlice(field.Type(), 0, 0))
		if defaultVal != "" && defaultVal != "[]" {
			if err := json.Unmarshal([]byte(defaultVal), ref.Interface()); err != nil {
				return err
			}
		}
		field.Set(ref.Elem())
	case reflect.Map:
		ref := reflect.New(field.Type())
		ref.Elem().Set(reflect.MakeMap(field.Type()))
		if defaultVal != "" && defaultVal != "{}" {
			if err := json.Unmarshal([]byte(defaultVal), ref.Interface()); err != nil {
				return err
			}
		}
		field.Set(ref.Elem())
	case reflect.Struct:
		if defaultVal != "" && defaultVal != "{}" {
			if err := json.Unmarshal([]byte(defaultVal), field.Addr().Interface()); err != nil {
				return err
			}
		}
	case reflect.Ptr:
		field.Set(reflect.New(field.Type().Elem()))
	}
	return nil
}

// initializeNested recurses into the contents of field: the pointee of a
// pointer, every field of a struct, every slice element and every map value.
// Slice elements and map values are visited with an empty default, so only
// their own struct tags apply.
func initializeNested(field reflect.Value, defaultVal string, isInitial bool) error {
	switch field.Kind() {
	case reflect.Ptr:
		// Recurse when the pointer was just allocated by assignDefault, or when
		// it already points at a struct whose own fields may need defaults.
		if isInitial || field.Elem().Kind() == reflect.Struct {
			if err := setField(field.Elem(), defaultVal); err != nil {
				return err
			}
			callSetter(field.Interface())
		}
	case reflect.Struct:
		return Set(field.Addr().Interface())
	case reflect.Slice:
		for j := 0; j < field.Len(); j++ {
			if err := setField(field.Index(j), ""); err != nil {
				return err
			}
		}
	case reflect.Map:
		for _, key := range field.MapKeys() {
			val := field.MapIndex(key)
			switch val.Kind() {
			case reflect.Ptr:
				// The pointee is addressable, so it can be updated in place.
				switch val.Elem().Kind() {
				case reflect.Struct, reflect.Slice, reflect.Map:
					if err := setField(val.Elem(), ""); err != nil {
						return err
					}
				}
			case reflect.Struct, reflect.Slice, reflect.Map:
				// Map values are not addressable: copy into an addressable
				// temporary, initialize it, then store it back under the key.
				tmp := reflect.New(val.Type()).Elem()
				tmp.Set(val)
				if err := setField(tmp, ""); err != nil {
					return err
				}
				field.SetMapIndex(key, tmp)
			}
		}
	}
	return nil
}

// unmarshalByInterface tries to decode defaultVal through field's
// encoding.TextUnmarshaler implementation first, then its json.Unmarshaler
// implementation. It reports whether either succeeded. field must be
// addressable.
func unmarshalByInterface(field reflect.Value, defaultVal string) bool {
	if defaultVal == "" {
		return false
	}
	addr := field.Addr().Interface()

	if asText, ok := addr.(encoding.TextUnmarshaler); ok {
		if err := asText.UnmarshalText([]byte(defaultVal)); err == nil {
			return true
		}
	}

	// Empty JSON literals are left to the kind-based path, which allocates
	// an empty container.
	if asJSON, ok := addr.(json.Unmarshaler); ok && defaultVal != "{}" && defaultVal != "[]" {
		if err := asJSON.UnmarshalJSON([]byte(defaultVal)); err == nil {
			return true
		}
	}
	return false
}

// isInitialValue reports whether field is deeply equal to the zero value of
// its type. field must be valid and exported (Interface is called on it).
func isInitialValue(field reflect.Value) bool {
	return reflect.DeepEqual(reflect.Zero(field.Type()).Interface(), field.Interface())
}

// shouldInitializeField reports whether setField needs to visit field, given
// its default tag. Structs and non-nil pointers to structs are always visited,
// so their nested tags apply. Non-empty slices and maps are visited so their
// elements are initialized. Anything else is visited only when a tag is present.
func shouldInitializeField(field reflect.Value, tag string) bool {
	switch field.Kind() {
	case reflect.Struct:
		return true
	case reflect.Ptr:
		if !field.IsNil() && field.Elem().Kind() == reflect.Struct {
			return true
		}
	case reflect.Slice, reflect.Map:
		return field.Len() > 0 || tag != ""
	}
	return tag != ""
}

// CanUpdate returns true when the given value is an initial value of its type.
// A nil interface is treated as an initial value.
func CanUpdate(v interface{}) bool {
	if v == nil {
		return true
	}
	return isInitialValue(reflect.ValueOf(v))
}