I'm generating some source code using the templates package (is there a better method?), and as part of the testing I need to check whether the output matches the expected source code.
I tried parsing both sources into ASTs (see below), but the ASTs don't match even though the code is basically the same apart from newlines and spaces. (FAIL)
package main
import (
"fmt"
"go/parser"
"go/token"
"reflect"
)
func main() {
stub1 := `package main
func myfunc(s string) error {
return nil
}`
stub2 := `package main
func myfunc(s string) error {
return nil
}`
fset := token.NewFileSet()
r1, err := parser.ParseFile(fset, "", stub1, parser.AllErrors)
if err != nil {
panic(err)
}
fset = token.NewFileSet()
r2, err := parser.ParseFile(fset, "", stub2, parser.AllErrors)
if err != nil {
panic(err)
}
if !reflect.DeepEqual(r1, r2) {
fmt.Printf("e %v, r %v\n", r1, r2)
}
}
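A likely reason the comparison fails: every AST node stores token.Pos values, which encode where the node sits in the source text, so extra newlines or spaces shift those values even when the structure is identical. The sketch below illustrates this; firstDeclOffset is a hypothetical helper written for this example, not part of the standard library.

```go
package main

import (
	"fmt"
	"go/parser"
	"go/token"
)

// firstDeclOffset parses src and returns the byte offset at which the
// first top-level declaration starts. Hypothetical helper for illustration.
func firstDeclOffset(src string) (int, error) {
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "", src, parser.AllErrors)
	if err != nil {
		return 0, err
	}
	if len(f.Decls) == 0 {
		return 0, fmt.Errorf("source has no declarations")
	}
	// Pos() is an opaque token.Pos; the FileSet translates it to an offset.
	return fset.Position(f.Decls[0].Pos()).Offset, nil
}

func main() {
	compact := "package main\nfunc myfunc(s string) error {\nreturn nil\n}"
	spaced := "package main\n\n\nfunc myfunc(s string) error {\n\n\treturn nil\n}"

	a, err := firstDeclOffset(compact)
	if err != nil {
		panic(err)
	}
	b, err := firstDeclOffset(spaced)
	if err != nil {
		panic(err)
	}
	// The declaration is the same, but it starts at a different offset,
	// which is enough to make reflect.DeepEqual report a mismatch.
	fmt.Println("offsets differ:", a != b)
}
```

Because these position fields appear on nearly every node, reflect.DeepEqual on two parsed files only succeeds when the layout matches as well.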
Using an AST library, a = b + 3 and a = 3 + b both produce the same node type (a binary operation, BinOp in Python's ast), so you can check the assigned variable "a" and its node type. For each line of code, build the AST node, then compare its value, its node type and, if required, other attributes such as operator, operands, function name, class name or index.
An Abstract Syntax Tree, or AST, is a tree representation of the source code of a computer program that conveys the structure of the source code. Each node in the tree represents a construct occurring in the source code.
In Python, the abstract syntax itself might change with each release; the ast module helps to find out programmatically what the current grammar looks like. An abstract syntax tree can be generated by passing ast.PyCF_ONLY_AST as a flag to the compile() built-in function, or by using the parse() helper provided in that module.
Well, one simple way to achieve this is to use the go/printer library, which gives you better control of output formatting and is basically like running gofmt on the source, normalizing both trees:
package main
import (
"fmt"
"go/parser"
"go/token"
"go/printer"
//"reflect"
"bytes"
)
func main() {
stub1 := `package main
func myfunc(s string) error {
return nil
}`
stub2 := `package main
func myfunc(s string) error {
return nil
}`
fset1 := token.NewFileSet()
r1, err := parser.ParseFile(fset1, "", stub1, parser.AllErrors)
if err != nil {
panic(err)
}
fset2 := token.NewFileSet()
r2, err := parser.ParseFile(fset2, "", stub2, parser.AllErrors)
if err != nil {
panic(err)
}
// we create two output buffers for each source tree
out1 := bytes.NewBuffer(nil)
out2 := bytes.NewBuffer(nil)
// we use the same printer config for both
conf := &printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}
// print to both outputs
if err := conf.Fprint(out1, fset1, r1); err != nil {
panic(err)
}
if err := conf.Fprint(out2, fset2, r2); err != nil {
panic(err)
}
// they should be identical!
if out1.String() != out2.String() {
panic(out1.String() + "\n" + out2.String())
} else {
fmt.Println("A-OKAY!")
}
}
Of course, this code should be refactored to remove the duplication. Another approach, instead of using DeepEqual, is to write a tree comparison function yourself that skips irrelevant nodes and fields, such as position information.
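As a rough sketch of the second idea, one could flatten each tree into a list of position-free node descriptions and compare the lists. The names shape and sameShape are hypothetical, and the switch only covers a handful of node types that carry a token or a name; a complete comparison would need to handle more (for example IncDecStmt, BranchStmt and ChanType). Comments are ignored because the parser is not asked to keep them.

```go
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
)

// shape flattens src into node descriptions that carry no position
// information. Hypothetical helper, deliberately incomplete.
func shape(src string) ([]string, error) {
	f, err := parser.ParseFile(token.NewFileSet(), "", src, parser.AllErrors)
	if err != nil {
		return nil, err
	}
	var out []string
	ast.Inspect(f, func(n ast.Node) bool {
		switch x := n.(type) {
		case nil:
			// Inspect calls the function with nil after a node's children.
			out = append(out, "end")
		case *ast.Ident:
			out = append(out, "Ident "+x.Name)
		case *ast.BasicLit:
			out = append(out, "BasicLit "+x.Value)
		case *ast.BinaryExpr:
			out = append(out, "BinaryExpr "+x.Op.String())
		case *ast.UnaryExpr:
			out = append(out, "UnaryExpr "+x.Op.String())
		case *ast.AssignStmt:
			out = append(out, "AssignStmt "+x.Tok.String())
		case *ast.GenDecl:
			out = append(out, "GenDecl "+x.Tok.String())
		default:
			// For every other node, only the Go type name is recorded.
			out = append(out, fmt.Sprintf("%T", x))
		}
		return true
	})
	return out, nil
}

// sameShape reports whether two sources flatten to the same description list.
func sameShape(a, b string) (bool, error) {
	sa, err := shape(a)
	if err != nil {
		return false, err
	}
	sb, err := shape(b)
	if err != nil {
		return false, err
	}
	if len(sa) != len(sb) {
		return false, nil
	}
	for i := range sa {
		if sa[i] != sb[i] {
			return false, nil
		}
	}
	return true, nil
}

func main() {
	stub1 := "package main\nfunc myfunc(s string) error {\nreturn nil\n}"
	stub2 := "package main\n\nfunc myfunc(s string) error {\n\n\treturn nil\n}\n"
	ok, err := sameShape(stub1, stub2)
	if err != nil {
		panic(err)
	}
	fmt.Println("same structure:", ok)
}
```

Comparing flattened descriptions keeps the code short, at the cost of less precise failure messages than a node-by-node diff would give.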
This was easier than I thought. All I had to do was remove the empty lines (after formatting). Below is the code.
package main
import (
"fmt"
"go/format"
"strings"
)
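func main() {
// stub1 and stub2 hold the two sources to compare, as in the question.
stub1 := `package main
func myfunc(s string) error {
return nil
}`
stub2 := `package main
func myfunc(s string) error {
return nil
}`
a, err := fmtSource(stub1)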
if err != nil {
panic(err)
}
b, err := fmtSource(stub2)
if err != nil {
panic(err)
}
if a != b {
fmt.Printf("a %v, \n b %v", a, b)
}
}
func fmtSource(source string) (string, error) {
if !strings.Contains(source, "package") {
source = "package main\n" + source
}
b, err := format.Source([]byte(source))
if err != nil {
return "", err
}
// cleanLine trims the line and collapses every run of whitespace into a single space
cleanLine := func(s string) string {
sa := strings.Fields(s)
return strings.Join(sa, " ")
}
lines := strings.Split(string(b), "\n")
// Keep only the non-empty lines, compacting them to the front of the slice.
n := 0
for _, line := range lines {
if line != "" {
lines[n] = cleanLine(line)
n++
}
}
lines = lines[:n]
// Add final "" entry to get trailing newline from Join.
if n > 0 && lines[n-1] != "" {
lines = append(lines, "")
}
// Make it pretty
b, err = format.Source([]byte(strings.Join(lines, "\n")))
if err != nil {
return "", err
}
return string(b), nil
}
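A trimmed-down variant of the same idea, as a sketch: it assumes gofmt's own spacing within a line is already canonical, so it skips the per-line whitespace collapsing and only drops blank lines between two formatting passes. The name normalize is hypothetical. Note that blank lines inside raw string literals are dropped too, so this is only suitable when the compared sources do not rely on them.

```go
package main

import (
	"fmt"
	"go/format"
	"strings"
)

// normalize formats src, removes blank lines, and formats the result again.
// Hypothetical helper for comparing generated code against expected code.
func normalize(src string) (string, error) {
	formatted, err := format.Source([]byte(src))
	if err != nil {
		return "", err
	}
	var kept []string
	for _, line := range strings.Split(string(formatted), "\n") {
		if strings.TrimSpace(line) != "" {
			kept = append(kept, line)
		}
	}
	// The second pass restores the blank lines gofmt itself requires.
	formatted, err = format.Source([]byte(strings.Join(kept, "\n") + "\n"))
	if err != nil {
		return "", err
	}
	return string(formatted), nil
}

func main() {
	stub1 := "package main\nfunc myfunc(s string) error {\nreturn nil\n}"
	stub2 := "package main\n\n\nfunc myfunc(s string) error {\n\n\treturn nil\n}\n"
	a, err := normalize(stub1)
	if err != nil {
		panic(err)
	}
	b, err := normalize(stub2)
	if err != nil {
		panic(err)
	}
	fmt.Println("equal after normalizing:", a == b)
}
```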