diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index 57b9dd95..a192ed8d 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -108,31 +108,12 @@ func (comp *compiler) warning(pos ast.Pos, msg string, args ...interface{}) { func (comp *compiler) check() { // TODO: check len in syscall arguments referring to parent. // TODO: incorrect name is referenced in len type - // TODO: infinite recursion via struct pointers (e.g. a linked list) // TODO: no constructor for a resource comp.checkNames() comp.checkFields() - - for _, decl := range comp.desc.Nodes { - switch n := decl.(type) { - case *ast.Resource: - comp.checkType(n.Base, false, true) - comp.checkResource(n) - case *ast.Struct: - for _, f := range n.Fields { - comp.checkType(f.Type, false, false) - } - comp.checkStruct(n) - case *ast.Call: - for _, a := range n.Args { - comp.checkType(a.Type, true, false) - } - if n.Ret != nil { - comp.checkType(n.Ret, true, false) - } - } - } + comp.checkTypes() + comp.checkRecursion() } func (comp *compiler) checkNames() { @@ -238,7 +219,40 @@ func (comp *compiler) checkFields() { } } -func (comp *compiler) checkResource(n *ast.Resource) { +func (comp *compiler) checkTypes() { + for _, decl := range comp.desc.Nodes { + switch n := decl.(type) { + case *ast.Resource: + comp.checkType(n.Base, false, true) + case *ast.Struct: + for _, f := range n.Fields { + comp.checkType(f.Type, false, false) + } + comp.checkStruct(n) + case *ast.Call: + for _, a := range n.Args { + comp.checkType(a.Type, true, false) + } + if n.Ret != nil { + comp.checkType(n.Ret, true, false) + } + } + } +} + +func (comp *compiler) checkRecursion() { + checked := make(map[string]bool) + for _, decl := range comp.desc.Nodes { + switch n := decl.(type) { + case *ast.Resource: + comp.checkResourceRecursion(n) + case *ast.Struct: + comp.checkStructRecursion(checked, n) + } + } +} + +func (comp *compiler) checkResourceRecursion(n *ast.Resource) { var seen []string for n != nil { if arrayContains(seen, n.Name.Name) { @@ -255,6 +269,65 @@ func (comp *compiler) checkResource(n *ast.Resource) { } } +type pathElem struct { + Pos ast.Pos + Struct string + Field string +} + +func (comp *compiler) checkStructRecursion(checked map[string]bool, n *ast.Struct) { + var path []pathElem + comp.checkStructRecursion1(checked, n, path) +} + +func (comp *compiler) checkStructRecursion1(checked map[string]bool, n *ast.Struct, path []pathElem) { + name := n.Name.Name + if checked[name] { + return + } + for i, elem := range path { + if elem.Struct != name { + continue + } + path = path[i:] + str := "" + for _, elem := range path { + str += fmt.Sprintf("%v.%v -> ", elem.Struct, elem.Field) + } + str += name + comp.error(path[0].Pos, "recursive declaration: %v (mark some pointers as opt)", str) + checked[name] = true + return + } + for _, f := range n.Fields { + path = append(path, pathElem{ + Pos: f.Pos, + Struct: name, + Field: f.Name.Name, + }) + comp.recurseField(checked, f.Type, path) + path = path[:len(path)-1] + } + checked[name] = true +} + +func (comp *compiler) recurseField(checked map[string]bool, t *ast.Type, path []pathElem) { + desc := comp.getTypeDesc(t) + if desc == typeStruct { + comp.checkStructRecursion1(checked, comp.structs[t.Ident], path) + return + } + _, args, base := comp.getArgsBase(t, "", sys.DirIn, false) + if desc == typePtr && base.IsOptional { + return // optional pointers prune recursion + } + for i, arg := range args { + if desc.Args[i].Type == typeArgType { + comp.recurseField(checked, arg, path) + } + } +} + func (comp *compiler) checkStruct(n *ast.Struct) { if n.IsUnion { comp.parseUnionAttrs(n) diff --git a/pkg/compiler/testdata/errors.txt b/pkg/compiler/testdata/errors.txt index 68c9895b..feb0fed3 100644 --- a/pkg/compiler/testdata/errors.txt +++ b/pkg/compiler/testdata/errors.txt @@ -152,3 +152,34 @@ define d1 `some C expression` define d2 some C expression define d2 SOMETHING ### duplicate define d2 define d3 1 + +sr1 { + f1 sr1 ### recursive declaration: sr1.f1 -> sr1 (mark some pointers as opt) +} + +sr2 { + f1 sr3 + f2 sr4 +} + +sr3 { + f1 ptr[in, sr3] ### recursive declaration: sr3.f1 -> sr3 (mark some pointers as opt) +} + +sr4 { + f1 ptr[in, sr3] + f2 array[ptr[in, sr5], 4] ### recursive declaration: sr4.f2 -> sr5.f2 -> sr6.f1 -> sr4 (mark some pointers as opt) +} + +sr5 [ + f1 int32 + f2 sr6 +] + +sr6 { + f1 sr4 +} + +sr7 { + f1 ptr[in, sr7, opt] +}