refactor: generate plugin executors
This commit is contained in:
		
							
								
								
									
										390
									
								
								generators/plugin-generator/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										390
									
								
								generators/plugin-generator/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,390 @@ | ||||
| // ALR - Any Linux Repository | ||||
| // Copyright (C) 2025 The ALR Authors | ||||
| // | ||||
| // This program is free software: you can redistribute it and/or modify | ||||
| // it under the terms of the GNU General Public License as published by | ||||
| // the Free Software Foundation, either version 3 of the License, or | ||||
| // (at your option) any later version. | ||||
| // | ||||
| // This program is distributed in the hope that it will be useful, | ||||
| // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
| // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
| // GNU General Public License for more details. | ||||
| // | ||||
| // You should have received a copy of the GNU General Public License | ||||
| // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
|  | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	"go/ast" | ||||
| 	"go/format" | ||||
| 	"go/parser" | ||||
| 	"go/token" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| 	"text/template" | ||||
|  | ||||
| 	"golang.org/x/text/cases" | ||||
| 	"golang.org/x/text/language" | ||||
| ) | ||||
|  | ||||
| type MethodInfo struct { | ||||
| 	Name       string | ||||
| 	Params     []ParamInfo | ||||
| 	Results    []ResultInfo | ||||
| 	EntityName string | ||||
| } | ||||
|  | ||||
| type ParamInfo struct { | ||||
| 	Name string | ||||
| 	Type string | ||||
| } | ||||
|  | ||||
| type ResultInfo struct { | ||||
| 	Name  string | ||||
| 	Type  string | ||||
| 	Index int | ||||
| } | ||||
|  | ||||
| func extractImports(node *ast.File) []string { | ||||
| 	var imports []string | ||||
| 	for _, imp := range node.Imports { | ||||
| 		if imp.Path.Value != "" { | ||||
| 			imports = append(imports, imp.Path.Value) | ||||
| 		} | ||||
| 	} | ||||
| 	return imports | ||||
| } | ||||
|  | ||||
| func output(path string, buf bytes.Buffer) { | ||||
| 	formatted, err := format.Source(buf.Bytes()) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("formatting: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	outPath := strings.TrimSuffix(path, ".go") + "_gen.go" | ||||
| 	outFile, err := os.Create(outPath) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("create file: %v", err) | ||||
| 	} | ||||
| 	_, err = outFile.Write(formatted) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("writing output: %v", err) | ||||
| 	} | ||||
| 	outFile.Close() | ||||
| } | ||||
|  | ||||
| func main() { | ||||
| 	path := os.Getenv("GOFILE") | ||||
| 	if path == "" { | ||||
| 		log.Fatal("GOFILE must be set") | ||||
| 	} | ||||
|  | ||||
| 	if len(os.Args) < 2 { | ||||
| 		log.Fatal("At least one entity name must be provided") | ||||
| 	} | ||||
|  | ||||
| 	entityNames := os.Args[1:] | ||||
|  | ||||
| 	fset := token.NewFileSet() | ||||
| 	node, err := parser.ParseFile(fset, path, nil, parser.AllErrors) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("parsing file: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	packageName := node.Name.Name | ||||
|  | ||||
| 	// Find all specified entities | ||||
| 	entityData := make(map[string][]*ast.Field) | ||||
|  | ||||
| 	for _, decl := range node.Decls { | ||||
| 		genDecl, ok := decl.(*ast.GenDecl) | ||||
| 		if !ok || genDecl.Tok != token.TYPE { | ||||
| 			continue | ||||
| 		} | ||||
| 		for _, spec := range genDecl.Specs { | ||||
| 			typeSpec := spec.(*ast.TypeSpec) | ||||
| 			for _, entityName := range entityNames { | ||||
| 				if typeSpec.Name.Name == entityName { | ||||
| 					interfaceType, ok := typeSpec.Type.(*ast.InterfaceType) | ||||
| 					if !ok { | ||||
| 						log.Fatalf("entity %s is not an interface", entityName) | ||||
| 					} | ||||
| 					entityData[entityName] = interfaceType.Methods.List | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Verify all entities were found | ||||
| 	for _, entityName := range entityNames { | ||||
| 		if _, found := entityData[entityName]; !found { | ||||
| 			log.Fatalf("interface %s not found", entityName) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	var buf bytes.Buffer | ||||
|  | ||||
| 	buf.WriteString(` | ||||
| // DO NOT EDIT MANUALLY. This file is generated. | ||||
|  | ||||
| // ALR - Any Linux Repository | ||||
| // Copyright (C) 2025 The ALR Authors | ||||
| // | ||||
| // This program is free software: you can redistribute it and/or modify | ||||
| // it under the terms of the GNU General Public License as published by | ||||
| // the Free Software Foundation, either version 3 of the License, or | ||||
| // (at your option) any later version. | ||||
| // | ||||
| // This program is distributed in the hope that it will be useful, | ||||
| // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
| // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
| // GNU General Public License for more details. | ||||
| // | ||||
| // You should have received a copy of the GNU General Public License | ||||
| // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
|  | ||||
|  | ||||
| `) | ||||
|  | ||||
| 	buf.WriteString(fmt.Sprintf("package %s\n", packageName)) | ||||
|  | ||||
| 	// Generate base structures for all entities | ||||
| 	baseStructs(&buf, entityNames, extractImports(node)) | ||||
|  | ||||
| 	// Generate method-specific code for each entity | ||||
| 	for _, entityName := range entityNames { | ||||
| 		methods := parseMethodsFromFields(entityName, entityData[entityName]) | ||||
| 		argsGen(&buf, methods) | ||||
| 	} | ||||
|  | ||||
| 	output(path, buf) | ||||
| } | ||||
|  | ||||
| func parseMethodsFromFields(entityName string, fields []*ast.Field) []MethodInfo { | ||||
| 	var methods []MethodInfo | ||||
|  | ||||
| 	for _, field := range fields { | ||||
| 		if len(field.Names) == 0 { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		methodName := field.Names[0].Name | ||||
| 		funcType, ok := field.Type.(*ast.FuncType) | ||||
| 		if !ok { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		method := MethodInfo{ | ||||
| 			Name:       methodName, | ||||
| 			EntityName: entityName, | ||||
| 		} | ||||
|  | ||||
| 		// Parse parameters, excluding context.Context | ||||
| 		if funcType.Params != nil { | ||||
| 			for i, param := range funcType.Params.List { | ||||
| 				paramType := typeToString(param.Type) | ||||
| 				// Skip context.Context parameters | ||||
| 				if paramType == "context.Context" { | ||||
| 					continue | ||||
| 				} | ||||
| 				if len(param.Names) == 0 { | ||||
| 					method.Params = append(method.Params, ParamInfo{ | ||||
| 						Name: fmt.Sprintf("Arg%d", i), | ||||
| 						Type: paramType, | ||||
| 					}) | ||||
| 				} else { | ||||
| 					for _, name := range param.Names { | ||||
| 						method.Params = append(method.Params, ParamInfo{ | ||||
| 							Name: cases.Title(language.Und, cases.NoLower).String(name.Name), | ||||
| 							Type: paramType, | ||||
| 						}) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// Parse results | ||||
| 		if funcType.Results != nil { | ||||
| 			resultIndex := 0 | ||||
| 			for _, result := range funcType.Results.List { | ||||
| 				resultType := typeToString(result.Type) | ||||
| 				if resultType == "error" { | ||||
| 					continue // Skip error in response struct | ||||
| 				} | ||||
|  | ||||
| 				if len(result.Names) == 0 { | ||||
| 					method.Results = append(method.Results, ResultInfo{ | ||||
| 						Name:  fmt.Sprintf("Result%d", resultIndex), | ||||
| 						Type:  resultType, | ||||
| 						Index: resultIndex, | ||||
| 					}) | ||||
| 				} else { | ||||
| 					for _, name := range result.Names { | ||||
| 						method.Results = append(method.Results, ResultInfo{ | ||||
| 							Name:  cases.Title(language.Und, cases.NoLower).String(name.Name), | ||||
| 							Type:  resultType, | ||||
| 							Index: resultIndex, | ||||
| 						}) | ||||
| 					} | ||||
| 				} | ||||
| 				resultIndex++ | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		methods = append(methods, method) | ||||
| 	} | ||||
|  | ||||
| 	return methods | ||||
| } | ||||
|  | ||||
| func argsGen(buf *bytes.Buffer, methods []MethodInfo) { | ||||
| 	// Add template functions first | ||||
| 	funcMap := template.FuncMap{ | ||||
| 		"lowerFirst": func(s string) string { | ||||
| 			if len(s) == 0 { | ||||
| 				return s | ||||
| 			} | ||||
| 			return strings.ToLower(s[:1]) + s[1:] | ||||
| 		}, | ||||
| 		"zeroValue": func(typeName string) string { | ||||
| 			switch typeName { | ||||
| 			case "string": | ||||
| 				return "\"\"" | ||||
| 			case "int", "int8", "int16", "int32", "int64": | ||||
| 				return "0" | ||||
| 			case "uint", "uint8", "uint16", "uint32", "uint64": | ||||
| 				return "0" | ||||
| 			case "float32", "float64": | ||||
| 				return "0.0" | ||||
| 			case "bool": | ||||
| 				return "false" | ||||
| 			default: | ||||
| 				return "nil" | ||||
| 			} | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	argsTemplate := template.Must(template.New("args").Funcs(funcMap).Parse(` | ||||
| {{range .}} | ||||
| type {{.EntityName}}{{.Name}}Args struct { | ||||
| {{range .Params}}	{{.Name}} {{.Type}} | ||||
| {{end}}} | ||||
|  | ||||
| type {{.EntityName}}{{.Name}}Resp struct { | ||||
| {{range .Results}}	{{.Name}} {{.Type}} | ||||
| {{end}}} | ||||
|  | ||||
| func (s *{{.EntityName}}RPC) {{.Name}}(ctx context.Context, {{range $i, $p := .Params}}{{if $i}}, {{end}}{{lowerFirst $p.Name}} {{$p.Type}}{{end}}) ({{range $i, $r := .Results}}{{if $i}}, {{end}}{{$r.Type}}{{end}}{{if .Results}}, {{end}}error) { | ||||
| 	var resp *{{.EntityName}}{{.Name}}Resp | ||||
| 	err := s.client.Call("Plugin.{{.Name}}", &{{.EntityName}}{{.Name}}Args{ | ||||
| {{range .Params}}		{{.Name}}: {{lowerFirst .Name}}, | ||||
| {{end}}	}, &resp) | ||||
| 	if err != nil { | ||||
| 		return {{range $i, $r := .Results}}{{if $i}}, {{end}}{{zeroValue $r.Type}}{{end}}{{if .Results}}, {{end}}err | ||||
| 	} | ||||
| 	return {{range $i, $r := .Results}}{{if $i}}, {{end}}resp.{{$r.Name}}{{end}}{{if .Results}}, {{end}}nil | ||||
| } | ||||
|  | ||||
| func (s *{{.EntityName}}RPCServer) {{.Name}}(args *{{.EntityName}}{{.Name}}Args, resp *{{.EntityName}}{{.Name}}Resp) error { | ||||
| 	{{if .Results}}{{range $i, $r := .Results}}{{if $i}}, {{end}}{{lowerFirst $r.Name}}{{end}}, err := {{else}}err := {{end}}s.Impl.{{.Name}}(context.Background(),{{range $i, $p := .Params}}{{if $i}}, {{end}}args.{{$p.Name}}{{end}}) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	{{if .Results}}*resp = {{.EntityName}}{{.Name}}Resp{ | ||||
| {{range .Results}}		{{.Name}}: {{lowerFirst .Name}}, | ||||
| {{end}}	} | ||||
| 	{{else}}*resp = {{.EntityName}}{{.Name}}Resp{} | ||||
| 	{{end}}return nil | ||||
| } | ||||
| {{end}} | ||||
| `)) | ||||
|  | ||||
| 	err := argsTemplate.Execute(buf, methods) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("execute args template: %v", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func typeToString(expr ast.Expr) string { | ||||
| 	switch t := expr.(type) { | ||||
| 	case *ast.Ident: | ||||
| 		return t.Name | ||||
| 	case *ast.StarExpr: | ||||
| 		return "*" + typeToString(t.X) | ||||
| 	case *ast.ArrayType: | ||||
| 		return "[]" + typeToString(t.Elt) | ||||
| 	case *ast.SelectorExpr: | ||||
| 		xStr := typeToString(t.X) | ||||
| 		if xStr == "context" && t.Sel.Name == "Context" { | ||||
| 			return "context.Context" | ||||
| 		} | ||||
| 		return xStr + "." + t.Sel.Name | ||||
| 	case *ast.InterfaceType: | ||||
| 		return "interface{}" | ||||
| 	default: | ||||
| 		return "interface{}" | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func baseStructs(buf *bytes.Buffer, entityNames, imports []string) { | ||||
| 	// Ensure "context" is included in imports | ||||
| 	updatedImports := imports | ||||
| 	hasContext := false | ||||
| 	for _, imp := range imports { | ||||
| 		if strings.Contains(imp, `"context"`) { | ||||
| 			hasContext = true | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	if !hasContext { | ||||
| 		updatedImports = append(updatedImports, `"context"`) | ||||
| 	} | ||||
|  | ||||
| 	contentTemplate := template.Must(template.New("").Parse(` | ||||
| import ( | ||||
| 	"net/rpc" | ||||
|  | ||||
| 	"github.com/hashicorp/go-plugin" | ||||
| {{range .Imports}}	{{.}} | ||||
| {{end}} | ||||
| ) | ||||
|  | ||||
| {{range .EntityNames}} | ||||
| type {{ . }}Plugin struct { | ||||
| 	Impl {{ . }} | ||||
| } | ||||
|  | ||||
| type {{ . }}RPCServer struct { | ||||
| 	Impl {{ . }} | ||||
| } | ||||
|  | ||||
| type {{ . }}RPC struct { | ||||
| 	client *rpc.Client | ||||
| } | ||||
|  | ||||
| func (p *{{ . }}Plugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) { | ||||
| 	return &{{ . }}RPC{client: c}, nil | ||||
| } | ||||
|  | ||||
| func (p *{{ . }}Plugin) Server(*plugin.MuxBroker) (interface{}, error) { | ||||
| 	return &{{ . }}RPCServer{Impl: p.Impl}, nil | ||||
| } | ||||
|  | ||||
| {{end}} | ||||
| `)) | ||||
| 	err := contentTemplate.Execute(buf, struct { | ||||
| 		EntityNames []string | ||||
| 		Imports     []string | ||||
| 	}{ | ||||
| 		EntityNames: entityNames, | ||||
| 		Imports:     updatedImports, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("execute template: %v", err) | ||||
| 	} | ||||
| } | ||||
		Reference in New Issue
	
	Block a user