1
0
mirror of https://github.com/golang/go synced 2024-11-20 05:34:40 -07:00

go/ast: fixed bug in NotNilFilter, added test

- fixed a couple of comments
- cleanups after reflect change

R=rsc
CC=golang-dev
https://golang.org/cl/4389041
This commit is contained in:
Robert Griesemer 2011-04-13 09:37:13 -07:00
parent dcf32a24a0
commit 7c270aef08
2 changed files with 114 additions and 20 deletions

View File

@ -26,7 +26,7 @@ func NotNilFilter(_ string, v reflect.Value) bool {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return !v.IsNil() return !v.IsNil()
} }
return false return true
} }
@ -80,7 +80,7 @@ type printer struct {
output io.Writer output io.Writer
fset *token.FileSet fset *token.FileSet
filter FieldFilter filter FieldFilter
ptrmap map[interface{}]int // *reflect.PtrValue -> line number ptrmap map[interface{}]int // *T -> line number
written int // number of bytes written to output written int // number of bytes written to output
indent int // current indentation level indent int // current indentation level
last byte // the last byte processed by Write last byte // the last byte processed by Write
@ -141,6 +141,11 @@ func (p *printer) printf(format string, args ...interface{}) {
// Implementation note: Print is written for AST nodes but could be // Implementation note: Print is written for AST nodes but could be
// used to print arbitrary data structures; such a version should // used to print arbitrary data structures; such a version should
// probably be in a different package. // probably be in a different package.
//
// Note: This code detects (some) cycles created via pointers but
// not cycles that are created via slices or maps containing the
// same slice or map. Code for general data structures probably
// should catch those as well.
func (p *printer) print(x reflect.Value) { func (p *printer) print(x reflect.Value) {
if !NotNilFilter("", x) { if !NotNilFilter("", x) {
@ -148,17 +153,17 @@ func (p *printer) print(x reflect.Value) {
return return
} }
switch v := x; v.Kind() { switch x.Kind() {
case reflect.Interface: case reflect.Interface:
p.print(v.Elem()) p.print(x.Elem())
case reflect.Map: case reflect.Map:
p.printf("%s (len = %d) {\n", x.Type().String(), v.Len()) p.printf("%s (len = %d) {\n", x.Type().String(), x.Len())
p.indent++ p.indent++
for _, key := range v.MapKeys() { for _, key := range x.MapKeys() {
p.print(key) p.print(key)
p.printf(": ") p.printf(": ")
p.print(v.MapIndex(key)) p.print(x.MapIndex(key))
p.printf("\n") p.printf("\n")
} }
p.indent-- p.indent--
@ -169,24 +174,24 @@ func (p *printer) print(x reflect.Value) {
// type-checked ASTs may contain cycles - use ptrmap // type-checked ASTs may contain cycles - use ptrmap
// to keep track of objects that have been printed // to keep track of objects that have been printed
// already and print the respective line number instead // already and print the respective line number instead
ptr := v.Interface() ptr := x.Interface()
if line, exists := p.ptrmap[ptr]; exists { if line, exists := p.ptrmap[ptr]; exists {
p.printf("(obj @ %d)", line) p.printf("(obj @ %d)", line)
} else { } else {
p.ptrmap[ptr] = p.line p.ptrmap[ptr] = p.line
p.print(v.Elem()) p.print(x.Elem())
} }
case reflect.Slice: case reflect.Slice:
if s, ok := v.Interface().([]byte); ok { if s, ok := x.Interface().([]byte); ok {
p.printf("%#q", s) p.printf("%#q", s)
return return
} }
p.printf("%s (len = %d) {\n", x.Type().String(), v.Len()) p.printf("%s (len = %d) {\n", x.Type().String(), x.Len())
p.indent++ p.indent++
for i, n := 0, v.Len(); i < n; i++ { for i, n := 0, x.Len(); i < n; i++ {
p.printf("%d: ", i) p.printf("%d: ", i)
p.print(v.Index(i)) p.print(x.Index(i))
p.printf("\n") p.printf("\n")
} }
p.indent-- p.indent--
@ -195,10 +200,10 @@ func (p *printer) print(x reflect.Value) {
case reflect.Struct: case reflect.Struct:
p.printf("%s {\n", x.Type().String()) p.printf("%s {\n", x.Type().String())
p.indent++ p.indent++
t := v.Type() t := x.Type()
for i, n := 0, t.NumField(); i < n; i++ { for i, n := 0, t.NumField(); i < n; i++ {
name := t.Field(i).Name name := t.Field(i).Name
value := v.Field(i) value := x.Field(i)
if p.filter == nil || p.filter(name, value) { if p.filter == nil || p.filter(name, value) {
p.printf("%s: ", name) p.printf("%s: ", name)
p.print(value) p.print(value)
@ -209,11 +214,20 @@ func (p *printer) print(x reflect.Value) {
p.printf("}") p.printf("}")
default: default:
value := x.Interface() v := x.Interface()
switch v := v.(type) {
case string:
// print strings in quotes
p.printf("%q", v)
return
case token.Pos:
// position values can be printed nicely if we have a file set // position values can be printed nicely if we have a file set
if pos, ok := value.(token.Pos); ok && p.fset != nil { if p.fset != nil {
value = p.fset.Position(pos) p.printf("%s", p.fset.Position(v))
} return
p.printf("%v", value) }
}
// default
p.printf("%v", v)
} }
} }

View File

@ -0,0 +1,80 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ast
import (
"bytes"
"strings"
"testing"
)
var tests = []struct {
x interface{} // x is printed as s
s string
}{
// basic types
{nil, "0 nil"},
{true, "0 true"},
{42, "0 42"},
{3.14, "0 3.14"},
{1 + 2.718i, "0 (1+2.718i)"},
{"foobar", "0 \"foobar\""},
// maps
{map[string]int{"a": 1, "b": 2},
`0 map[string] int (len = 2) {
1 . "a": 1
2 . "b": 2
3 }`},
// pointers
{new(int), "0 *0"},
// slices
{[]int{1, 2, 3},
`0 []int (len = 3) {
1 . 0: 1
2 . 1: 2
3 . 2: 3
4 }`},
// structs
{struct{ x, y int }{42, 991},
`0 struct { x int; y int } {
1 . x: 42
2 . y: 991
3 }`},
}
// Split s into lines, trim whitespace from all lines, and return
// the concatenated non-empty lines.
func trim(s string) string {
lines := strings.Split(s, "\n", -1)
i := 0
for _, line := range lines {
line = strings.TrimSpace(line)
if line != "" {
lines[i] = line
i++
}
}
return strings.Join(lines[0:i], "\n")
}
func TestPrint(t *testing.T) {
var buf bytes.Buffer
for _, test := range tests {
buf.Reset()
if _, err := Fprint(&buf, nil, test.x, nil); err != nil {
t.Errorf("Fprint failed: %s", err)
}
if s, ts := trim(buf.String()), trim(test.s); s != ts {
t.Errorf("got:\n%s\nexpected:\n%s\n", s, ts)
}
}
}