1
0
mirror of https://github.com/golang/go synced 2024-11-06 02:26:17 -07:00
go/astutil/imports_test.go

758 lines
8.8 KiB
Go
Raw Normal View History

// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package astutil
import (
"bytes"
"go/ast"
"go/format"
"go/parser"
"go/token"
"reflect"
"strconv"
"testing"
)
var fset = token.NewFileSet()
func parse(t *testing.T, name, in string) *ast.File {
file, err := parser.ParseFile(fset, name, in, parser.ParseComments)
if err != nil {
t.Fatalf("%s parse: %v", name, err)
}
return file
}
func print(t *testing.T, name string, f *ast.File) string {
var buf bytes.Buffer
if err := format.Node(&buf, fset, f); err != nil {
t.Fatalf("%s gofmt: %v", name, err)
}
return string(buf.Bytes())
}
type test struct {
name string
renamedPkg string
pkg string
in string
out string
broken bool // known broken
}
var addTests = []test{
{
name: "leave os alone",
pkg: "os",
in: `package main
import (
"os"
)
`,
out: `package main
import (
"os"
)
`,
},
{
name: "import.1",
pkg: "os",
in: `package main
`,
out: `package main
import "os"
`,
},
{
name: "import.2",
pkg: "os",
in: `package main
// Comment
import "C"
`,
out: `package main
// Comment
import "C"
import "os"
`,
},
{
name: "import.3",
pkg: "os",
in: `package main
// Comment
import "C"
import (
"io"
"utf8"
)
`,
out: `package main
// Comment
import "C"
import (
"io"
"os"
"utf8"
)
`,
},
{
name: "import.17",
pkg: "x/y/z",
in: `package main
// Comment
import "C"
import (
"a"
"b"
"x/w"
"d/f"
)
`,
out: `package main
// Comment
import "C"
import (
"a"
"b"
"x/w"
"x/y/z"
"d/f"
)
`,
},
{
name: "import into singular block",
pkg: "bytes",
in: `package main
import "os"
`,
out: `package main
import (
"bytes"
"os"
)
`,
},
{
name: "",
renamedPkg: "fmtpkg",
pkg: "fmt",
in: `package main
import "os"
`,
out: `package main
import (
fmtpkg "fmt"
"os"
)
`,
},
{
name: "struct comment",
pkg: "time",
in: `package main
// This is a comment before a struct.
type T struct {
t time.Time
}
`,
out: `package main
import "time"
// This is a comment before a struct.
type T struct {
t time.Time
}
`,
},
}
func TestAddImport(t *testing.T) {
for _, test := range addTests {
file := parse(t, test.name, test.in)
var before bytes.Buffer
ast.Fprint(&before, fset, file, nil)
AddNamedImport(fset, file, test.renamedPkg, test.pkg)
if got := print(t, test.name, file); got != test.out {
if test.broken {
t.Logf("%s is known broken:\ngot: %s\nwant: %s", test.name, got, test.out)
} else {
t.Errorf("%s:\ngot: %s\nwant: %s", test.name, got, test.out)
}
var after bytes.Buffer
ast.Fprint(&after, fset, file, nil)
t.Logf("AST before:\n%s\nAST after:\n%s\n", before.String(), after.String())
}
}
}
func TestDoubleAddImport(t *testing.T) {
file := parse(t, "doubleimport", "package main\n")
AddImport(fset, file, "os")
AddImport(fset, file, "bytes")
want := `package main
import (
"bytes"
"os"
)
`
if got := print(t, "doubleimport", file); got != want {
t.Errorf("got: %s\nwant: %s", got, want)
}
}
var deleteTests = []test{
{
name: "import.4",
pkg: "os",
in: `package main
import (
"os"
)
`,
out: `package main
`,
},
{
name: "import.5",
pkg: "os",
in: `package main
// Comment
import "C"
import "os"
`,
out: `package main
// Comment
import "C"
`,
},
{
name: "import.6",
pkg: "os",
in: `package main
// Comment
import "C"
import (
"io"
"os"
"utf8"
)
`,
out: `package main
// Comment
import "C"
import (
"io"
"utf8"
)
`,
},
{
name: "import.7",
pkg: "io",
in: `package main
import (
"io" // a
"os" // b
"utf8" // c
)
`,
out: `package main
import (
// a
"os" // b
"utf8" // c
)
`,
},
{
name: "import.8",
pkg: "os",
in: `package main
import (
"io" // a
"os" // b
"utf8" // c
)
`,
out: `package main
import (
"io" // a
// b
"utf8" // c
)
`,
},
{
name: "import.9",
pkg: "utf8",
in: `package main
import (
"io" // a
"os" // b
"utf8" // c
)
`,
out: `package main
import (
"io" // a
"os" // b
// c
)
`,
},
{
name: "import.10",
pkg: "io",
in: `package main
import (
"io"
"os"
"utf8"
)
`,
out: `package main
import (
"os"
"utf8"
)
`,
},
{
name: "import.11",
pkg: "os",
in: `package main
import (
"io"
"os"
"utf8"
)
`,
out: `package main
import (
"io"
"utf8"
)
`,
},
{
name: "import.12",
pkg: "utf8",
in: `package main
import (
"io"
"os"
"utf8"
)
`,
out: `package main
import (
"io"
"os"
)
`,
},
{
name: "handle.raw.quote.imports",
pkg: "os",
in: "package main\n\nimport `os`",
out: `package main
go.tools/astutil: fix edge case in DeleteImport causing merging of import sections. The issue occurs only when deleting an import that has a blank line immediately preceding, and other imports before that. Currently, DeleteImport assumes there's a blank line-sized hole left behind where the import was, and always deletes it. That blank line-sized hole is there in all cases except the above edge case. This fix checks for that edge case, and does not remove the blank line-sized hole. The CL also adds a previously failing test case that catches this scenario. After the change to DeleteImport, the new test passes (along with all other tests). Fixes golang/go#7679. Note that there is no attempt to ensure the result *ast.File and *token.FileSet are perfectly matching to what you would get if you printed the AST and parsed it back. This is how the rest of the package and the current tests work (i.e., they only check that printing the AST gives the correct output). Changing that is very hard, if not impossible, at least not without resorting to manipulating AST via printing, text manipulation and parsing. This is okay for most usages, but it does create potential problems. For example, astutil.Imports() currently only works correctly on freshly parsed AST. If that AST is manipulated via astutil funcs, then Imports() may not always generate correct output. However, thas is a separate issue and should be treated as such. LGTM=bradfitz R=golang-codereviews, gobot, adonovan, bradfitz CC=golang-codereviews https://golang.org/cl/92250045
2014-05-19 15:04:30 -06:00
`,
},
{
name: "import.13",
pkg: "io",
in: `package main
import (
"fmt"
"io"
"os"
"utf8"
"go/format"
)
`,
out: `package main
import (
"fmt"
"os"
"utf8"
"go/format"
)
`,
},
{
name: "import.14",
pkg: "io",
in: `package main
import (
"fmt" // a
"io" // b
"os" // c
"utf8" // d
"go/format" // e
)
`,
out: `package main
import (
"fmt" // a
// b
"os" // c
"utf8" // d
"go/format" // e
)
`,
},
{
name: "import.15",
pkg: "double",
in: `package main
import (
"double"
"double"
)
`,
out: `package main
`,
},
{
name: "import.16",
pkg: "bubble",
in: `package main
import (
"toil"
"bubble"
"bubble"
"trouble"
)
`,
out: `package main
import (
"toil"
"trouble"
)
`,
},
{
name: "import.17",
pkg: "quad",
in: `package main
import (
"quad"
"quad"
)
import (
"quad"
"quad"
)
`,
out: `package main
`,
},
}
func TestDeleteImport(t *testing.T) {
for _, test := range deleteTests {
file := parse(t, test.name, test.in)
DeleteImport(fset, file, test.pkg)
if got := print(t, test.name, file); got != test.out {
t.Errorf("%s:\ngot: %s\nwant: %s", test.name, got, test.out)
}
}
}
type rewriteTest struct {
name string
srcPkg string
dstPkg string
in string
out string
}
var rewriteTests = []rewriteTest{
{
name: "import.13",
srcPkg: "utf8",
dstPkg: "encoding/utf8",
in: `package main
import (
"io"
"os"
"utf8" // thanks ken
)
`,
out: `package main
import (
"encoding/utf8" // thanks ken
"io"
"os"
)
`,
},
{
name: "import.14",
srcPkg: "asn1",
dstPkg: "encoding/asn1",
in: `package main
import (
"asn1"
"crypto"
"crypto/rsa"
_ "crypto/sha1"
"crypto/x509"
"crypto/x509/pkix"
"time"
)
var x = 1
`,
out: `package main
import (
"crypto"
"crypto/rsa"
_ "crypto/sha1"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"time"
)
var x = 1
`,
},
{
name: "import.15",
srcPkg: "url",
dstPkg: "net/url",
in: `package main
import (
"bufio"
"net"
"path"
"url"
)
var x = 1 // comment on x, not on url
`,
out: `package main
import (
"bufio"
"net"
"net/url"
"path"
)
var x = 1 // comment on x, not on url
`,
},
{
name: "import.16",
srcPkg: "http",
dstPkg: "net/http",
in: `package main
import (
"flag"
"http"
"log"
"text/template"
)
var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
`,
out: `package main
import (
"flag"
"log"
"net/http"
"text/template"
)
var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
`,
},
}
func TestRewriteImport(t *testing.T) {
for _, test := range rewriteTests {
file := parse(t, test.name, test.in)
RewriteImport(fset, file, test.srcPkg, test.dstPkg)
if got := print(t, test.name, file); got != test.out {
t.Errorf("%s:\ngot: %s\nwant: %s", test.name, got, test.out)
}
}
}
var importsTests = []struct {
name string
in string
want [][]string
}{
{
name: "no packages",
in: `package foo
`,
want: nil,
},
{
name: "one group",
in: `package foo
import (
"fmt"
"testing"
)
`,
want: [][]string{{"fmt", "testing"}},
},
{
name: "four groups",
in: `package foo
import "C"
import (
"fmt"
"testing"
"appengine"
"myproject/mylib1"
"myproject/mylib2"
)
`,
want: [][]string{
{"C"},
{"fmt", "testing"},
{"appengine"},
{"myproject/mylib1", "myproject/mylib2"},
},
},
{
name: "multiple factored groups",
in: `package foo
import (
"fmt"
"testing"
"appengine"
)
import (
"reflect"
"bytes"
)
`,
want: [][]string{
{"fmt", "testing"},
{"appengine"},
{"reflect"},
{"bytes"},
},
},
}
func unquote(s string) string {
res, err := strconv.Unquote(s)
if err != nil {
return "could_not_unquote"
}
return res
}
func TestImports(t *testing.T) {
fset := token.NewFileSet()
for _, test := range importsTests {
f, err := parser.ParseFile(fset, "test.go", test.in, 0)
if err != nil {
t.Errorf("%s: %v", test.name, err)
continue
}
var got [][]string
for _, block := range Imports(fset, f) {
var b []string
for _, spec := range block {
b = append(b, unquote(spec.Path.Value))
}
got = append(got, b)
}
if !reflect.DeepEqual(got, test.want) {
t.Errorf("Imports(%s)=%v, want %v", test.name, got, test.want)
}
}
}