fix parsing overrides

This commit is contained in:
2025-06-22 12:44:21 +03:00
parent 85878f69d3
commit c4a92c67d4
15 changed files with 705 additions and 178 deletions

View File

@ -23,7 +23,6 @@ import (
"context"
"errors"
"fmt"
"log/slog"
"reflect"
"strings"
@ -75,75 +74,99 @@ func New(info *distro.OSRelease, runner *interp.Runner) *Decoder {
// DecodeVar decodes a variable to val using reflection.
// Structs should use the "sh" struct tag.
func (d *Decoder) DecodeVar(name string, val any) error {
variable := d.getVar(name)
if variable == nil {
return VarNotFoundError{name}
}
origType := reflect.TypeOf(val).Elem()
isOverridableField := strings.Contains(origType.String(), "OverridableField[")
dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
WeaklyTypedInput: true,
DecodeHook: mapstructure.DecodeHookFuncValue(func(from, to reflect.Value) (interface{}, error) {
if strings.Contains(to.Type().String(), "alrsh.OverridableField") {
if to.Kind() != reflect.Ptr && to.CanAddr() {
to = to.Addr()
}
if !isOverridableField {
variable := d.getVarNoOverrides(name)
if variable == nil {
return VarNotFoundError{name}
}
names, err := overrides.Resolve(d.info, overrides.DefaultOpts.WithName(name))
if err != nil {
return nil, err
}
isNotSet := true
setMethod := to.MethodByName("Set")
setResolvedMethod := to.MethodByName("SetResolved")
for _, varName := range names {
val := d.getVarNoOverrides(varName)
if val == nil {
continue
dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
WeaklyTypedInput: true,
Result: val, // передаем указатель на новое значение
TagName: "sh",
DecodeHook: mapstructure.DecodeHookFuncValue(func(from, to reflect.Value) (interface{}, error) {
if from.Kind() == reflect.Slice && to.Kind() == reflect.String {
s, ok := from.Interface().([]string)
if ok && len(s) == 1 {
return s[0], nil
}
t := setMethod.Type().In(1)
newVal := from
if !newVal.Type().AssignableTo(t) {
newVal = reflect.New(t)
err = d.DecodeVar(name, newVal.Interface())
if err != nil {
return nil, err
}
newVal = newVal.Elem()
}
if isNotSet {
setResolvedMethod.Call([]reflect.Value{newVal})
}
override := strings.TrimPrefix(strings.TrimPrefix(varName, name), "_")
setMethod.Call([]reflect.Value{reflect.ValueOf(override), newVal})
}
return from.Interface(), nil
}),
})
if err != nil {
return err
}
return to, nil
switch variable.Kind {
case expand.Indexed:
return dec.Decode(variable.List)
case expand.Associative:
return dec.Decode(variable.Map)
default:
return dec.Decode(variable.Str)
}
} else {
vars := d.getVarsByPrefix(name)
if len(vars) == 0 {
return VarNotFoundError{name}
}
reflectVal := reflect.ValueOf(val)
overridableVal := reflect.ValueOf(val).Elem()
dataField := overridableVal.FieldByName("data")
if !dataField.IsValid() {
return fmt.Errorf("data field not found in OverridableField")
}
mapType := dataField.Type() // map[string]T
elemType := mapType.Elem() // T
var overridablePtr reflect.Value
if reflectVal.Kind() == reflect.Ptr {
overridablePtr = reflectVal
} else {
if !reflectVal.CanAddr() {
return fmt.Errorf("OverridableField value is not addressable")
}
return from.Interface(), nil
}),
Result: val,
TagName: "sh",
})
if err != nil {
slog.Warn("err", "err", err)
return err
}
overridablePtr = reflectVal.Addr()
}
switch variable.Kind {
case expand.Indexed:
return dec.Decode(variable.List)
case expand.Associative:
return dec.Decode(variable.Map)
default:
return dec.Decode(variable.Str)
setValue := overridablePtr.MethodByName("Set")
if !setValue.IsValid() {
return fmt.Errorf("method Set not found on OverridableField")
}
for _, v := range vars {
varName := v.Name
key := strings.TrimPrefix(strings.TrimPrefix(varName, name), "_")
newVal := reflect.New(elemType)
if err := d.DecodeVar(varName, newVal.Interface()); err != nil {
return err
}
keyValue := reflect.ValueOf(key)
setValue.Call([]reflect.Value{keyValue, newVal.Elem()})
}
resolveValue := overridablePtr.MethodByName("Resolve")
if !resolveValue.IsValid() {
return fmt.Errorf("method Resolve not found on OverridableField")
}
names, err := overrides.Resolve(d.info, overrides.DefaultOpts)
if err != nil {
return err
}
resolveValue.Call([]reflect.Value{reflect.ValueOf(names)})
return nil
}
}
@ -284,23 +307,6 @@ func (d *Decoder) getFunc(name string) *syntax.Stmt {
return nil
}
// getVar gets a variable based on its name, taking into account
// override variables and nameref variables.
func (d *Decoder) getVar(name string) *expand.Variable {
names, err := overrides.Resolve(d.info, overrides.DefaultOpts.WithName(name))
if err != nil {
return nil
}
for _, varName := range names {
res := d.getVarNoOverrides(varName)
if res != nil {
return res
}
}
return nil
}
func (d *Decoder) getVarNoOverrides(name string) *expand.Variable {
val, ok := d.Runner.Vars[name]
if ok {
@ -318,6 +324,32 @@ func (d *Decoder) getVarNoOverrides(name string) *expand.Variable {
return nil
}
type vars struct {
Name string
Value *expand.Variable
}
func (d *Decoder) getVarsByPrefix(prefix string) []*vars {
result := make([]*vars, 0)
for name, val := range d.Runner.Vars {
if !strings.HasPrefix(name, prefix) {
continue
}
switch prefix {
case "auto_req":
if strings.HasPrefix(name, "auto_req_skiplist") {
continue
}
case "auto_prov":
if strings.HasPrefix(name, "auto_prov_skiplist") {
continue
}
}
result = append(result, &vars{name, &val})
}
return result
}
func IsTruthy(value string) bool {
value = strings.ToLower(strings.TrimSpace(value))
return value == "true" || value == "yes" || value == "1"