417 lines
10 KiB
Go
417 lines
10 KiB
Go
// ALR - Any Linux Repository
|
|
// Copyright (C) 2025 The ALR Authors
|
|
//
|
|
// This program is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// This program is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU General Public License
|
|
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/format"
|
|
"go/parser"
|
|
"go/token"
|
|
"log"
|
|
"os"
|
|
"strings"
|
|
"text/template"
|
|
"unicode"
|
|
|
|
"golang.org/x/text/cases"
|
|
"golang.org/x/text/language"
|
|
)
|
|
|
|
type MethodInfo struct {
|
|
Name string
|
|
Params []ParamInfo
|
|
Results []ResultInfo
|
|
EntityName string
|
|
}
|
|
|
|
// ParamInfo describes one method parameter as it appears in the generated
// Args struct and in the generated client method signature.
type ParamInfo struct {
	// Name is the exported struct field name (a title-cased parameter name
	// or a synthetic "ArgN"); Type is the parameter type as Go source text.
	Name, Type string
}
|
|
|
|
// ResultInfo describes one non-error method result as it appears in the
// generated Resp struct.
type ResultInfo struct {
	// Name is the exported struct field name (a title-cased result name or a
	// synthetic "ResultN"); Type is the result type as Go source text.
	Name, Type string
	// Index is the position of the declaring result field among the
	// method's non-error result fields; names declared together
	// (e.g. "a, b int") share the same Index.
	Index int
}
|
|
|
|
// extractImports returns the import specs of node rendered as they would
// appear inside an import block: the quoted path, prefixed by the local name
// when the import declares one (e.g. `foo "example.com/foo"`, `_ "embed"`,
// `. "pkg"`).
//
// Keeping the local name matters because the interface signatures copied
// into the generated file refer to packages by that name. A blank import
// would otherwise turn into an unused (uncompilable) plain import.
func extractImports(node *ast.File) []string {
	var imports []string
	for _, imp := range node.Imports {
		if imp.Path == nil || imp.Path.Value == "" {
			continue
		}
		spec := imp.Path.Value // already quoted, e.g. "\"fmt\""
		if imp.Name != nil {
			spec = imp.Name.Name + " " + spec
		}
		imports = append(imports, spec)
	}
	return imports
}
|
|
|
|
// output gofmt-formats the generated source held in buf and writes it next
// to the input file. The ".go" suffix of path is replaced with "_gen.go"
// (plugin.go -> plugin_gen.go), and an existing file is truncated and
// overwritten. Any failure is fatal.
func output(path string, buf bytes.Buffer) {
	formatted, err := format.Source(buf.Bytes())
	if err != nil {
		log.Fatalf("formatting: %v", err)
	}

	outPath := strings.TrimSuffix(path, ".go") + "_gen.go"
	// os.WriteFile creates/truncates the file like os.Create (0666 before
	// umask) and, unlike the previous Create/Write/Close sequence, also
	// reports the error from Close, so a failed final flush is not ignored.
	if err := os.WriteFile(outPath, formatted, 0o666); err != nil {
		log.Fatalf("writing output: %v", err)
	}
}
|
|
|
|
// main is the go:generate entry point. It parses the Go file named by the
// GOFILE environment variable (set by `go generate`), locates each interface
// named on the command line, and writes <file>_gen.go containing the
// hashicorp/go-plugin net/rpc scaffolding and per-method RPC glue for them.
//
// Usage from a //go:generate directive: <generator> Entity1 [Entity2 ...]
func main() {
	path := os.Getenv("GOFILE")
	if path == "" {
		log.Fatal("GOFILE must be set")
	}

	if len(os.Args) < 2 {
		log.Fatal("At least one entity name must be provided")
	}

	// Collapse repeated names (order preserved): emitting the same
	// scaffolding twice would produce duplicate type declarations.
	entityNames := make([]string, 0, len(os.Args)-1)
	wanted := make(map[string]bool, len(os.Args)-1)
	for _, name := range os.Args[1:] {
		if !wanted[name] {
			wanted[name] = true
			entityNames = append(entityNames, name)
		}
	}

	fset := token.NewFileSet()
	node, err := parser.ParseFile(fset, path, nil, parser.AllErrors)
	if err != nil {
		log.Fatalf("parsing file: %v", err)
	}

	packageName := node.Name.Name

	// Collect the method lists of the requested interfaces.
	entityData := make(map[string][]*ast.Field)
	for _, decl := range node.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.TYPE {
			continue
		}
		for _, spec := range genDecl.Specs {
			typeSpec := spec.(*ast.TypeSpec) // TYPE declarations only contain TypeSpecs
			name := typeSpec.Name.Name
			if !wanted[name] {
				continue
			}
			interfaceType, ok := typeSpec.Type.(*ast.InterfaceType)
			if !ok {
				log.Fatalf("entity %s is not an interface", name)
			}
			entityData[name] = interfaceType.Methods.List
		}
	}

	// Verify all entities were found.
	for _, entityName := range entityNames {
		if _, found := entityData[entityName]; !found {
			log.Fatalf("interface %s not found", entityName)
		}
	}

	var buf bytes.Buffer

	// The first line uses the standard generated-file marker
	// (^// Code generated .* DO NOT EDIT\.$, see `go help generate`) so that
	// gopls, linters and review tools recognise the output as generated.
	buf.WriteString(`// Code generated by go generate; DO NOT EDIT.

// ALR - Any Linux Repository
// Copyright (C) 2025 The ALR Authors
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

`)

	fmt.Fprintf(&buf, "package %s\n", packageName)

	// Generate base structures for all entities.
	baseStructs(&buf, entityNames, extractImports(node))

	// Generate method-specific code for each entity.
	for _, entityName := range entityNames {
		methods := parseMethodsFromFields(entityName, entityData[entityName])
		argsGen(&buf, methods)
	}

	output(path, buf)
}
|
|
|
|
// parseMethodsFromFields converts the method list of interface entityName
// into MethodInfo values for the argsGen template.
//
// Fields without names (embedded interfaces) and fields whose type is not a
// function are skipped. Parameters whose type renders as "context.Context"
// are dropped. Results whose type renders as "error" are dropped as well.
// Named parameters/results are title-cased (rest of the name untouched) so
// they can serve as exported struct fields. Unnamed ones get synthetic names:
//   - ArgN, where N is the field's index in the full parameter list;
//   - ResultN, where N counts only the non-error result fields.
func parseMethodsFromFields(entityName string, fields []*ast.Field) []MethodInfo {
	// exported upper-cases the first letter of an identifier and, thanks to
	// cases.NoLower, leaves the remaining letters as they are.
	exported := func(ident string) string {
		return cases.Title(language.Und, cases.NoLower).String(ident)
	}

	var methods []MethodInfo
	for _, field := range fields {
		funcType, isFunc := field.Type.(*ast.FuncType)
		if len(field.Names) == 0 || !isFunc {
			continue
		}

		method := MethodInfo{Name: field.Names[0].Name, EntityName: entityName}

		if funcType.Params != nil {
			for pos, param := range funcType.Params.List {
				typ := typeToString(param.Type)
				if typ == "context.Context" {
					continue // the generated code supplies its own context
				}
				if len(param.Names) == 0 {
					method.Params = append(method.Params, ParamInfo{
						Name: fmt.Sprintf("Arg%d", pos),
						Type: typ,
					})
					continue
				}
				for _, ident := range param.Names {
					method.Params = append(method.Params, ParamInfo{
						Name: exported(ident.Name),
						Type: typ,
					})
				}
			}
		}

		if funcType.Results != nil {
			next := 0
			for _, res := range funcType.Results.List {
				typ := typeToString(res.Type)
				if typ == "error" {
					continue // errors are not carried in the response struct
				}
				if len(res.Names) == 0 {
					method.Results = append(method.Results, ResultInfo{
						Name:  fmt.Sprintf("Result%d", next),
						Type:  typ,
						Index: next,
					})
				} else {
					for _, ident := range res.Names {
						method.Results = append(method.Results, ResultInfo{
							Name:  exported(ident.Name),
							Type:  typ,
							Index: next,
						})
					}
				}
				next++
			}
		}

		methods = append(methods, method)
	}

	return methods
}
|
|
|
|
// argsGen appends to buf, for every method in methods:
//   - the net/rpc message types <Entity><Method>Args and <Entity><Method>Resp;
//   - a client method on <Entity>RPC that sends Args via
//     client.Call("Plugin.<Method>", ...) and unpacks Resp;
//   - a server method on <Entity>RPCServer that calls Impl.<Method> with
//     context.Background() and fills Resp.
//
// The emitted text is not formatted; it is meant to be gofmt-ed afterwards.
// NOTE(review): interface parameters named s, resp or err collide with
// identifiers the template declares (receiver / locals) in generated code.
func argsGen(buf *bytes.Buffer, methods []MethodInfo) {
	funcMap := template.FuncMap{
		// lowerFirst turns an exported field name back into a parameter or
		// local variable name. It lower-cases the whole first rune rather
		// than its first byte, so non-ASCII identifiers stay valid UTF-8.
		"lowerFirst": func(s string) string {
			r := []rune(s)
			if len(r) == 0 {
				return s
			}
			r[0] = unicode.ToLower(r[0])
			return string(r)
		},

		// zeroValue returns a Go expression for the zero value of typeName.
		// It is used in the client's error return path.
		"zeroValue": func(typeName string) string {
			typeName = strings.TrimSpace(typeName)

			switch typeName {
			case "string":
				return "\"\""
			case "int", "int8", "int16", "int32", "int64",
				"uint", "uint8", "uint16", "uint32", "uint64", "uintptr",
				"byte", "rune", "complex64", "complex128":
				// byte/rune/uintptr/complex previously fell through to
				// "nil", which does not compile for these types.
				return "0"
			case "float32", "float64":
				return "0.0"
			case "bool":
				return "false"
			case "interface{}", "any", "error":
				return "nil"
			}

			switch {
			case strings.HasPrefix(typeName, "*"),
				strings.HasPrefix(typeName, "[]"),
				strings.HasPrefix(typeName, "map["),
				strings.HasPrefix(typeName, "chan "),
				strings.HasPrefix(typeName, "chan<- "),
				strings.HasPrefix(typeName, "<-chan "),
				strings.HasPrefix(typeName, "func("):
				return "nil"
			case strings.HasPrefix(typeName, "["):
				// Fixed-size array ([N]T): nil is not assignable to it,
				// but an empty composite literal is its zero value.
				return typeName + "{}"
			}

			// If external type: pkg.Type
			if strings.Contains(typeName, ".") {
				return typeName + "{}"
			}

			// If it starts with an upper-case letter it is likely a struct.
			// Check the first rune, not the first byte.
			if first := []rune(typeName); len(first) > 0 && unicode.IsUpper(first[0]) {
				return typeName + "{}"
			}

			return "nil"
		},
	}

	argsTemplate := template.Must(template.New("args").Funcs(funcMap).Parse(`
{{range .}}
type {{.EntityName}}{{.Name}}Args struct {
{{range .Params}}	{{.Name}} {{.Type}}
{{end}}}

type {{.EntityName}}{{.Name}}Resp struct {
{{range .Results}}	{{.Name}} {{.Type}}
{{end}}}

func (s *{{.EntityName}}RPC) {{.Name}}(ctx context.Context, {{range $i, $p := .Params}}{{if $i}}, {{end}}{{lowerFirst $p.Name}} {{$p.Type}}{{end}}) ({{range $i, $r := .Results}}{{if $i}}, {{end}}{{$r.Type}}{{end}}{{if .Results}}, {{end}}error) {
	var resp *{{.EntityName}}{{.Name}}Resp
	err := s.client.Call("Plugin.{{.Name}}", &{{.EntityName}}{{.Name}}Args{
{{range .Params}}		{{.Name}}: {{lowerFirst .Name}},
{{end}}	}, &resp)
	if err != nil {
		return {{range $i, $r := .Results}}{{if $i}}, {{end}}{{zeroValue $r.Type}}{{end}}{{if .Results}}, {{end}}err
	}
	return {{range $i, $r := .Results}}{{if $i}}, {{end}}resp.{{$r.Name}}{{end}}{{if .Results}}, {{end}}nil
}

func (s *{{.EntityName}}RPCServer) {{.Name}}(args *{{.EntityName}}{{.Name}}Args, resp *{{.EntityName}}{{.Name}}Resp) error {
	{{if .Results}}{{range $i, $r := .Results}}{{if $i}}, {{end}}{{lowerFirst $r.Name}}{{end}}, err := {{else}}err := {{end}}s.Impl.{{.Name}}(context.Background(),{{range $i, $p := .Params}}{{if $i}}, {{end}}args.{{$p.Name}}{{end}})
	if err != nil {
		return err
	}
	{{if .Results}}*resp = {{.EntityName}}{{.Name}}Resp{
{{range .Results}}		{{.Name}}: {{lowerFirst .Name}},
{{end}}	}
	{{else}}*resp = {{.EntityName}}{{.Name}}Resp{}
	{{end}}return nil
}
{{end}}
`))

	if err := argsTemplate.Execute(buf, methods); err != nil {
		log.Fatalf("execute args template: %v", err)
	}
}
|
|
|
|
// typeToString renders a type expression from the parsed interface back into
// Go source text so it can be copied into the generated code.
//
// Supported forms:
//   - identifiers and qualified identifiers (int, pkg.Type, context.Context);
//   - pointers and parenthesized types;
//   - slices and fixed-size arrays whose length is a literal or a
//     (qualified) constant name;
//   - maps and channels, with their direction preserved;
//   - generic instantiations (T[A], T[A, B]).
//
// Every other form falls back to "interface{}", as the generator always did
// for unsupported types. That includes interface literals, function types,
// struct literals, variadic ...T and computed array lengths.
func typeToString(expr ast.Expr) string {
	switch t := expr.(type) {
	case *ast.Ident:
		return t.Name
	case *ast.SelectorExpr:
		return typeToString(t.X) + "." + t.Sel.Name
	case *ast.StarExpr:
		return "*" + typeToString(t.X)
	case *ast.ParenExpr:
		return "(" + typeToString(t.X) + ")"
	case *ast.ArrayType:
		if t.Len == nil {
			return "[]" + typeToString(t.Elt)
		}
		// Keep the length so [4]byte does not degrade into []byte.
		var length string
		switch l := t.Len.(type) {
		case *ast.BasicLit:
			length = l.Value
		case *ast.Ident, *ast.SelectorExpr:
			length = typeToString(l)
		default:
			return "interface{}"
		}
		return "[" + length + "]" + typeToString(t.Elt)
	case *ast.MapType:
		return "map[" + typeToString(t.Key) + "]" + typeToString(t.Value)
	case *ast.ChanType:
		switch t.Dir {
		case ast.SEND:
			return "chan<- " + typeToString(t.Value)
		case ast.RECV:
			return "<-chan " + typeToString(t.Value)
		default:
			return "chan " + typeToString(t.Value)
		}
	case *ast.IndexExpr:
		return typeToString(t.X) + "[" + typeToString(t.Index) + "]"
	case *ast.IndexListExpr:
		args := make([]string, len(t.Indices))
		for i, idx := range t.Indices {
			args[i] = typeToString(idx)
		}
		return typeToString(t.X) + "[" + strings.Join(args, ", ") + "]"
	case *ast.InterfaceType:
		return "interface{}"
	default:
		return "interface{}"
	}
}
|
|
|
|
// baseStructs writes the generated file's import block and, for every name in
// entityNames, the go-plugin scaffolding:
//   - <Entity>Plugin, whose Client/Server methods take *plugin.MuxBroker and
//     *rpc.Client;
//   - <Entity>RPCServer, which wraps the real implementation in Impl;
//   - <Entity>RPC, the client stub holding the *rpc.Client.
//
// imports are import specs copied from the source file. They are emitted
// after the generator's own "net/rpc" and go-plugin imports, with two
// adjustments:
//   - specs identical to those two fixed imports are dropped, because
//     importing the same package twice under the same name does not compile;
//   - "context" is appended when it is not already imported under its default
//     name, since the generated method signatures spell it context.Context.
func baseStructs(buf *bytes.Buffer, entityNames, imports []string) {
	// Specs the template always emits itself.
	fixed := map[string]bool{
		`"net/rpc"`:                               true,
		`rpc "net/rpc"`:                           true,
		`"github.com/hashicorp/go-plugin"`:        true,
		`plugin "github.com/hashicorp/go-plugin"`: true,
	}

	updatedImports := make([]string, 0, len(imports)+1)
	seen := make(map[string]bool, len(imports))
	hasContext := false
	for _, imp := range imports {
		if fixed[imp] || seen[imp] {
			continue
		}
		seen[imp] = true
		// An aliased import (e.g. ctx "context") does not make the name
		// "context" available, so only these two forms count.
		if imp == `"context"` || imp == `context "context"` {
			hasContext = true
		}
		updatedImports = append(updatedImports, imp)
	}
	if !hasContext {
		updatedImports = append(updatedImports, `"context"`)
	}

	contentTemplate := template.Must(template.New("").Parse(`
import (
	"net/rpc"

	"github.com/hashicorp/go-plugin"
{{range .Imports}}	{{.}}
{{end}}
)

{{range .EntityNames}}
type {{ . }}Plugin struct {
	Impl {{ . }}
}

type {{ . }}RPCServer struct {
	Impl {{ . }}
}

type {{ . }}RPC struct {
	client *rpc.Client
}

func (p *{{ . }}Plugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) {
	return &{{ . }}RPC{client: c}, nil
}

func (p *{{ . }}Plugin) Server(*plugin.MuxBroker) (interface{}, error) {
	return &{{ . }}RPCServer{Impl: p.Impl}, nil
}

{{end}}
`))

	err := contentTemplate.Execute(buf, struct {
		EntityNames []string
		Imports     []string
	}{
		EntityNames: entityNames,
		Imports:     updatedImports,
	})
	if err != nil {
		log.Fatalf("execute template: %v", err)
	}
}
|