diff --git a/src/cmd/gofmt/rewrite.go b/src/cmd/gofmt/rewrite.go index fe35bfb08de..94f7912a38b 100644 --- a/src/cmd/gofmt/rewrite.go +++ b/src/cmd/gofmt/rewrite.go @@ -37,21 +37,12 @@ func initRewrite() { // but there are problems with preserving formatting and also // with what a wildcard for a statement looks like. func parseExpr(s string, what string) ast.Expr { - stmts, err := parser.ParseStmtList("input", s) + x, err := parser.ParseExpr("input", s) if err != nil { fmt.Fprintf(os.Stderr, "parsing %s %s: %s\n", what, s, err) os.Exit(2) } - if len(stmts) != 1 { - fmt.Fprintf(os.Stderr, "%s must be single expression\n", what) - os.Exit(2) - } - x, ok := stmts[0].(*ast.ExprStmt) - if !ok { - fmt.Fprintf(os.Stderr, "%s must be single expression\n", what) - os.Exit(2) - } - return x.X + return x } @@ -147,6 +138,9 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { switch p := p.(type) { case *reflect.SliceValue: v := v.(*reflect.SliceValue) + if p.Len() != v.Len() { + return false + } for i := 0; i < p.Len(); i++ { if !match(m, p.Elem(i), v.Elem(i)) { return false @@ -156,6 +150,9 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { case *reflect.StructValue: v := v.(*reflect.StructValue) + if p.NumField() != v.NumField() { + return false + } for i := 0; i < p.NumField(); i++ { if !match(m, p.Field(i), v.Field(i)) { return false