fix parsing overrides
This commit is contained in:
		| @@ -23,7 +23,6 @@ import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"log/slog" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
|  | ||||
| @@ -75,75 +74,99 @@ func New(info *distro.OSRelease, runner *interp.Runner) *Decoder { | ||||
| // DecodeVar decodes a variable to val using reflection. | ||||
| // Structs should use the "sh" struct tag. | ||||
| func (d *Decoder) DecodeVar(name string, val any) error { | ||||
| 	variable := d.getVar(name) | ||||
| 	if variable == nil { | ||||
| 		return VarNotFoundError{name} | ||||
| 	} | ||||
| 	origType := reflect.TypeOf(val).Elem() | ||||
| 	isOverridableField := strings.Contains(origType.String(), "OverridableField[") | ||||
|  | ||||
| 	dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ | ||||
| 		WeaklyTypedInput: true, | ||||
| 		DecodeHook: mapstructure.DecodeHookFuncValue(func(from, to reflect.Value) (interface{}, error) { | ||||
| 			if strings.Contains(to.Type().String(), "alrsh.OverridableField") { | ||||
| 				if to.Kind() != reflect.Ptr && to.CanAddr() { | ||||
| 					to = to.Addr() | ||||
| 				} | ||||
| 	if !isOverridableField { | ||||
| 		variable := d.getVarNoOverrides(name) | ||||
| 		if variable == nil { | ||||
| 			return VarNotFoundError{name} | ||||
| 		} | ||||
|  | ||||
| 				names, err := overrides.Resolve(d.info, overrides.DefaultOpts.WithName(name)) | ||||
| 				if err != nil { | ||||
| 					return nil, err | ||||
| 				} | ||||
|  | ||||
| 				isNotSet := true | ||||
|  | ||||
| 				setMethod := to.MethodByName("Set") | ||||
| 				setResolvedMethod := to.MethodByName("SetResolved") | ||||
|  | ||||
| 				for _, varName := range names { | ||||
| 					val := d.getVarNoOverrides(varName) | ||||
| 					if val == nil { | ||||
| 						continue | ||||
| 		dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ | ||||
| 			WeaklyTypedInput: true, | ||||
| 			Result:           val, // передаем указатель на новое значение | ||||
| 			TagName:          "sh", | ||||
| 			DecodeHook: mapstructure.DecodeHookFuncValue(func(from, to reflect.Value) (interface{}, error) { | ||||
| 				if from.Kind() == reflect.Slice && to.Kind() == reflect.String { | ||||
| 					s, ok := from.Interface().([]string) | ||||
| 					if ok && len(s) == 1 { | ||||
| 						return s[0], nil | ||||
| 					} | ||||
|  | ||||
| 					t := setMethod.Type().In(1) | ||||
|  | ||||
| 					newVal := from | ||||
|  | ||||
| 					if !newVal.Type().AssignableTo(t) { | ||||
| 						newVal = reflect.New(t) | ||||
| 						err = d.DecodeVar(name, newVal.Interface()) | ||||
| 						if err != nil { | ||||
| 							return nil, err | ||||
| 						} | ||||
| 						newVal = newVal.Elem() | ||||
| 					} | ||||
|  | ||||
| 					if isNotSet { | ||||
| 						setResolvedMethod.Call([]reflect.Value{newVal}) | ||||
| 					} | ||||
|  | ||||
| 					override := strings.TrimPrefix(strings.TrimPrefix(varName, name), "_") | ||||
| 					setMethod.Call([]reflect.Value{reflect.ValueOf(override), newVal}) | ||||
| 				} | ||||
| 				return from.Interface(), nil | ||||
| 			}), | ||||
| 		}) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 				return to, nil | ||||
| 		switch variable.Kind { | ||||
| 		case expand.Indexed: | ||||
| 			return dec.Decode(variable.List) | ||||
| 		case expand.Associative: | ||||
| 			return dec.Decode(variable.Map) | ||||
| 		default: | ||||
| 			return dec.Decode(variable.Str) | ||||
| 		} | ||||
| 	} else { | ||||
| 		vars := d.getVarsByPrefix(name) | ||||
|  | ||||
| 		if len(vars) == 0 { | ||||
| 			return VarNotFoundError{name} | ||||
| 		} | ||||
|  | ||||
| 		reflectVal := reflect.ValueOf(val) | ||||
| 		overridableVal := reflect.ValueOf(val).Elem() | ||||
|  | ||||
| 		dataField := overridableVal.FieldByName("data") | ||||
| 		if !dataField.IsValid() { | ||||
| 			return fmt.Errorf("data field not found in OverridableField") | ||||
| 		} | ||||
| 		mapType := dataField.Type() // map[string]T | ||||
| 		elemType := mapType.Elem()  // T | ||||
|  | ||||
| 		var overridablePtr reflect.Value | ||||
| 		if reflectVal.Kind() == reflect.Ptr { | ||||
| 			overridablePtr = reflectVal | ||||
| 		} else { | ||||
| 			if !reflectVal.CanAddr() { | ||||
| 				return fmt.Errorf("OverridableField value is not addressable") | ||||
| 			} | ||||
| 			return from.Interface(), nil | ||||
| 		}), | ||||
| 		Result:  val, | ||||
| 		TagName: "sh", | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		slog.Warn("err", "err", err) | ||||
| 		return err | ||||
| 	} | ||||
| 			overridablePtr = reflectVal.Addr() | ||||
| 		} | ||||
|  | ||||
| 	switch variable.Kind { | ||||
| 	case expand.Indexed: | ||||
| 		return dec.Decode(variable.List) | ||||
| 	case expand.Associative: | ||||
| 		return dec.Decode(variable.Map) | ||||
| 	default: | ||||
| 		return dec.Decode(variable.Str) | ||||
| 		setValue := overridablePtr.MethodByName("Set") | ||||
| 		if !setValue.IsValid() { | ||||
| 			return fmt.Errorf("method Set not found on OverridableField") | ||||
| 		} | ||||
|  | ||||
| 		for _, v := range vars { | ||||
| 			varName := v.Name | ||||
|  | ||||
| 			key := strings.TrimPrefix(strings.TrimPrefix(varName, name), "_") | ||||
| 			newVal := reflect.New(elemType) | ||||
|  | ||||
| 			if err := d.DecodeVar(varName, newVal.Interface()); err != nil { | ||||
| 				return err | ||||
| 			} | ||||
|  | ||||
| 			keyValue := reflect.ValueOf(key) | ||||
| 			setValue.Call([]reflect.Value{keyValue, newVal.Elem()}) | ||||
| 		} | ||||
|  | ||||
| 		resolveValue := overridablePtr.MethodByName("Resolve") | ||||
| 		if !resolveValue.IsValid() { | ||||
| 			return fmt.Errorf("method Resolve not found on OverridableField") | ||||
| 		} | ||||
|  | ||||
| 		names, err := overrides.Resolve(d.info, overrides.DefaultOpts) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		resolveValue.Call([]reflect.Value{reflect.ValueOf(names)}) | ||||
| 		return nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -284,23 +307,6 @@ func (d *Decoder) getFunc(name string) *syntax.Stmt { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // getVar gets a variable based on its name, taking into account | ||||
| // override variables and nameref variables. | ||||
| func (d *Decoder) getVar(name string) *expand.Variable { | ||||
| 	names, err := overrides.Resolve(d.info, overrides.DefaultOpts.WithName(name)) | ||||
| 	if err != nil { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	for _, varName := range names { | ||||
| 		res := d.getVarNoOverrides(varName) | ||||
| 		if res != nil { | ||||
| 			return res | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (d *Decoder) getVarNoOverrides(name string) *expand.Variable { | ||||
| 	val, ok := d.Runner.Vars[name] | ||||
| 	if ok { | ||||
| @@ -318,6 +324,32 @@ func (d *Decoder) getVarNoOverrides(name string) *expand.Variable { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| type vars struct { | ||||
| 	Name  string | ||||
| 	Value *expand.Variable | ||||
| } | ||||
|  | ||||
| func (d *Decoder) getVarsByPrefix(prefix string) []*vars { | ||||
| 	result := make([]*vars, 0) | ||||
| 	for name, val := range d.Runner.Vars { | ||||
| 		if !strings.HasPrefix(name, prefix) { | ||||
| 			continue | ||||
| 		} | ||||
| 		switch prefix { | ||||
| 		case "auto_req": | ||||
| 			if strings.HasPrefix(name, "auto_req_skiplist") { | ||||
| 				continue | ||||
| 			} | ||||
| 		case "auto_prov": | ||||
| 			if strings.HasPrefix(name, "auto_prov_skiplist") { | ||||
| 				continue | ||||
| 			} | ||||
| 		} | ||||
| 		result = append(result, &vars{name, &val}) | ||||
| 	} | ||||
| 	return result | ||||
| } | ||||
|  | ||||
| func IsTruthy(value string) bool { | ||||
| 	value = strings.ToLower(strings.TrimSpace(value)) | ||||
| 	return value == "true" || value == "yes" || value == "1" | ||||
|   | ||||
		Reference in New Issue
	
	Block a user