How can I compare two source code files/AST trees?

I'm generating some source code using the templates package (is there a better method?), and as part of the testing I need to check whether the output matches the expected source code.

  • I tried a string comparison, but it fails because of the extra spaces/newlines generated by the templates package. I've also tried format.Source, without success. (FAIL)
  • I tried parsing both sources into ASTs (see below), but the ASTs don't match either, even though the code is the same apart from newlines and spaces. (FAIL)

    package main

    import (
        "fmt"
        "go/parser"
        "go/token"
        "reflect"
    )
    
    func main() {
        stub1 := `package main
         func myfunc(s string) error {
            return nil  
        }`
        stub2 := `package main
    
         func myfunc(s string) error {
    
            return nil
    
        }`
        fset := token.NewFileSet()
        r1, err := parser.ParseFile(fset, "", stub1, parser.AllErrors)
        if err != nil {
            panic(err)
        }
        fset = token.NewFileSet()
        r2, err := parser.ParseFile(fset, "", stub2, parser.AllErrors)
        if err != nil {
            panic(err)
        }
        if !reflect.DeepEqual(r1, r2) {
            fmt.Printf("e %v, r %v\n", r1, r2)
        }
    }
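A small check suggests why `reflect.DeepEqual` reports a difference here: the two trees have the same shape and names, but every node stores a `token.Pos` offset into its source, and the extra blank lines in the second stub shift those offsets. This is a minimal sketch using the same two stubs written as interpreted strings; `parseFunc` is a hypothetical helper name.

```go
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
)

// parseFunc (hypothetical helper) parses src in a fresh FileSet and
// returns its first declaration, which is assumed to be a function.
func parseFunc(src string) *ast.FuncDecl {
	f, err := parser.ParseFile(token.NewFileSet(), "", src, parser.AllErrors)
	if err != nil {
		panic(err)
	}
	return f.Decls[0].(*ast.FuncDecl)
}

func main() {
	fn1 := parseFunc("package main\n     func myfunc(s string) error {\n        return nil\n    }")
	fn2 := parseFunc("package main\n\n     func myfunc(s string) error {\n\n        return nil\n\n    }")

	// Same structure and names...
	fmt.Println("same name:", fn1.Name.Name == fn2.Name.Name) // prints "same name: true"
	// ...but different positions, which reflect.DeepEqual compares too.
	fmt.Println("same position:", fn1.Pos() == fn2.Pos()) // prints "same position: false"
}
```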
    


Asked by themihai on May 12 '15 16:05

2 Answers

Well, one simple way to achieve this is to use the go/printer package, which gives you better control over output formatting and works much like running gofmt on the source. Printing both trees with the same configuration normalizes their layout:

    package main

    import (
        "bytes"
        "fmt"
        "go/ast"
        "go/parser"
        "go/printer"
        "go/token"
    )

    func main() {
        stub1 := `package main
         func myfunc(s string) error {
            return nil
        }`
        stub2 := `package main

         func myfunc(s string) error {

            return nil

        }`

        // Parse each source with its own FileSet.
        fset1 := token.NewFileSet()
        r1, err := parser.ParseFile(fset1, "", stub1, parser.AllErrors)
        if err != nil {
            panic(err)
        }
        fset2 := token.NewFileSet()
        r2, err := parser.ParseFile(fset2, "", stub2, parser.AllErrors)
        if err != nil {
            panic(err)
        }

        // We use the same printer config for both trees.
        conf := &printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}

        // render prints a tree against an empty FileSet. With no line
        // information the printer cannot reproduce the original blank lines,
        // so both trees come out in the same minimal layout. Printing with
        // fset1/fset2 instead would generally keep stub2's blank lines.
        // (No comments are lost: they were not parsed, see parser.ParseComments.)
        render := func(f *ast.File) string {
            var buf bytes.Buffer
            if err := conf.Fprint(&buf, token.NewFileSet(), f); err != nil {
                panic(err)
            }
            return buf.String()
        }
        out1, out2 := render(r1), render(r2)

        // They should be identical!
        if out1 != out2 {
            panic(out1 + "\n" + out2)
        }
        fmt.Println("A-OKAY!")
    }

Of course, this code could use some refactoring. Another approach is to skip reflect.DeepEqual and write your own tree comparison function that ignores irrelevant details, such as the position information every node carries.
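Such a comparison function could look roughly like the sketch below. It flattens each tree with `ast.Inspect` into a list of node descriptions (node type, identifier names, literal values, operator tokens) and never looks at `token.Pos` fields, so layout differences disappear. The names `fingerprint`, `sameAST` and `mustParse` are hypothetical, and the set of node cases is not exhaustive: any node kind whose meaning lives in a field not recorded here would compare equal when it should not.

```go
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"reflect"
)

// fingerprint flattens a tree into a pre-order list of node descriptions.
// It records types, names, literal values and operators, but no positions.
func fingerprint(root ast.Node) []string {
	var out []string
	ast.Inspect(root, func(n ast.Node) bool {
		if n == nil {
			// Inspect calls f(nil) after a node's children; recording it
			// keeps differently nested trees from producing the same list.
			out = append(out, ")")
			return false
		}
		desc := fmt.Sprintf("%T", n)
		switch x := n.(type) {
		case *ast.Ident:
			desc += " " + x.Name
		case *ast.BasicLit:
			desc += " " + x.Value
		case *ast.BinaryExpr:
			desc += " " + x.Op.String()
		case *ast.UnaryExpr:
			desc += " " + x.Op.String()
		case *ast.AssignStmt:
			desc += " " + x.Tok.String()
		case *ast.IncDecStmt:
			desc += " " + x.Tok.String()
		case *ast.BranchStmt:
			desc += " " + x.Tok.String()
		case *ast.RangeStmt:
			desc += " " + x.Tok.String()
		case *ast.GenDecl:
			desc += " " + x.Tok.String()
		case *ast.ChanType:
			desc += fmt.Sprintf(" %d", x.Dir)
		}
		out = append(out, desc)
		return true
	})
	return out
}

// sameAST reports whether two trees match once positions are ignored.
func sameAST(a, b ast.Node) bool {
	return reflect.DeepEqual(fingerprint(a), fingerprint(b))
}

// mustParse parses a complete Go file or panics.
func mustParse(src string) *ast.File {
	f, err := parser.ParseFile(token.NewFileSet(), "", src, parser.AllErrors)
	if err != nil {
		panic(err)
	}
	return f
}

func main() {
	stub1 := mustParse("package main\n     func myfunc(s string) error {\n        return nil\n    }")
	stub2 := mustParse("package main\n\n     func myfunc(s string) error {\n\n        return nil\n\n    }")
	fmt.Println(sameAST(stub1, stub2)) // true: only whitespace differs
}
```

Unlike string-based normalization, this leaves string literal contents untouched, at the cost of maintaining the list of node kinds yourself.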

Answered by Not_a_Golfer on Sep 27 '22 22:09


This was easier than I thought: all I had to do was remove the empty lines after formatting. Here is the code:

    package main

    import (
        "fmt"
        "go/format"
        "strings"
    )

    func main() {
        // Same stubs as in the question.
        stub1 := `package main
         func myfunc(s string) error {
            return nil
        }`
        stub2 := `package main

         func myfunc(s string) error {

            return nil

        }`

        a, err := fmtSource(stub1)
        if err != nil {
            panic(err)
        }
        b, err := fmtSource(stub2)
        if err != nil {
            panic(err)
        }
        if a != b {
            fmt.Printf("a %v, \n b %v", a, b)
        } else {
            fmt.Println("sources match")
        }
    }

    // fmtSource normalizes Go source so that snippets differing only in
    // whitespace produce the same string. Caveat: it also collapses spaces
    // inside string literals and comments, and drops empty lines inside raw
    // strings, so it only suits code where that doesn't matter.
    func fmtSource(source string) (string, error) {
        // Rough check: format.Source needs a package clause for a whole file.
        if !strings.Contains(source, "package") {
            source = "package main\n" + source
        }
        b, err := format.Source([]byte(source))
        if err != nil {
            return "", err
        }
        // Drop empty lines and collapse runs of spaces/tabs within each line.
        var lines []string
        for _, line := range strings.Split(string(b), "\n") {
            if fields := strings.Fields(line); len(fields) > 0 {
                lines = append(lines, strings.Join(fields, " "))
            }
        }
        // Make it pretty: format again so gofmt restores indentation and the
        // blank lines it requires. The trailing "\n" keeps a final newline.
        b, err = format.Source([]byte(strings.Join(lines, "\n") + "\n"))
        if err != nil {
            return "", err
        }
        return string(b), nil
    }
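Because `strings.Fields` also rewrites string literals (`"a  b"` and `"a b"` would compare equal after `fmtSource`), a different technique may be worth considering: compare the token streams produced by `go/scanner`. This is swapped in here, not taken from either answer. It is a minimal sketch that treats two sources as equal when they scan to the same tokens, so line breaks and indentation are ignored, comments are skipped (scanner mode 0), and literal contents still count. `tokenStream` and `sameTokens` are hypothetical names.

```go
package main

import (
	"fmt"
	"go/scanner"
	"go/token"
	"reflect"
)

// tokenStream lists the tokens of src, ignoring whitespace. Comments are
// skipped because the scanner mode is 0 (no scanner.ScanComments).
func tokenStream(src string) ([]string, error) {
	fset := token.NewFileSet()
	file := fset.AddFile("", fset.Base(), len(src))

	var errs scanner.ErrorList
	var s scanner.Scanner
	s.Init(file, []byte(src), func(pos token.Position, msg string) {
		errs.Add(pos, msg)
	}, 0)

	var toks []string
	for {
		_, tok, lit := s.Scan()
		if tok == token.EOF {
			break
		}
		// Semicolons inserted at line ends carry lit "\n"; treat them the
		// same as an explicit ";".
		if tok == token.SEMICOLON {
			lit = ";"
		}
		toks = append(toks, tok.String()+" "+lit)
	}
	return toks, errs.Err()
}

// sameTokens reports whether a and b scan to the same token sequence.
func sameTokens(a, b string) (bool, error) {
	ta, err := tokenStream(a)
	if err != nil {
		return false, err
	}
	tb, err := tokenStream(b)
	if err != nil {
		return false, err
	}
	return reflect.DeepEqual(ta, tb), nil
}

func main() {
	stub1 := "package main\n     func myfunc(s string) error {\n        return nil\n    }"
	stub2 := "package main\n\n     func myfunc(s string) error {\n\n        return nil\n\n    }"
	same, err := sameTokens(stub1, stub2)
	if err != nil {
		panic(err)
	}
	fmt.Println(same) // true
}
```

Blank lines never add extra tokens here because Go only inserts one automatic semicolon per line end, and only after tokens that can end a statement.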

Answered by themihai on Sep 27 '22 20:09