diff --git a/src/cmd/fix/cryptotype.go b/src/cmd/fix/cryptotype.go new file mode 100644 index 0000000000..abcf7714a8 --- /dev/null +++ b/src/cmd/fix/cryptotype.go @@ -0,0 +1,36 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +var cryptotypeFix = fix{ + "cryptotype", + "2012-02-12", + renameFix(cryptotypeReplace), + `Rewrite uses of concrete cipher types to refer to the generic cipher.Block. + +http://codereview.appspot.com/5625045/ +`, +} + +var cryptotypeReplace = []rename{ + { + OldImport: "crypto/aes", + NewImport: "crypto/cipher", + Old: "*aes.Cipher", + New: "cipher.Block", + }, + { + OldImport: "crypto/des", + NewImport: "crypto/cipher", + Old: "*des.Cipher", + New: "cipher.Block", + }, + { + OldImport: "crypto/des", + NewImport: "crypto/cipher", + Old: "*des.TripleDESCipher", + New: "cipher.Block", + }, +} diff --git a/src/cmd/fix/cryptotype_test.go b/src/cmd/fix/cryptotype_test.go new file mode 100644 index 0000000000..7accceef3e --- /dev/null +++ b/src/cmd/fix/cryptotype_test.go @@ -0,0 +1,43 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +func init() { + addTestCases(cryptotypeTests, cryptotypeFix.f) +} + +var cryptotypeTests = []testCase{ + { + Name: "cryptotype.0", + In: `package main + +import ( + "crypto/aes" + "crypto/des" +) + +var ( + _ *aes.Cipher + _ *des.Cipher + _ *des.TripleDESCipher + _ = aes.New() +) +`, + Out: `package main + +import ( + "crypto/aes" + "crypto/cipher" +) + +var ( + _ cipher.Block + _ cipher.Block + _ cipher.Block + _ = aes.New() +) +`, + }, +} diff --git a/src/cmd/fix/fix.go b/src/cmd/fix/fix.go index 2c1be6942a..d2067cb51e 100644 --- a/src/cmd/fix/fix.go +++ b/src/cmd/fix/fix.go @@ -4,14 +4,6 @@ package main -/* -receiver named error -function named error -method on error -exiterror -slice of named type (go/scanner) -*/ - import ( "fmt" "go/ast" @@ -19,6 +11,7 @@ import ( "go/token" "os" "path" + "reflect" "strconv" "strings" ) @@ -750,5 +743,105 @@ func expr(s string) ast.Expr { if err != nil { panic("parsing " + s + ": " + err.Error()) } + // Remove position information to avoid spurious newlines. + killPos(reflect.ValueOf(x)) return x } + +var posType = reflect.TypeOf(token.Pos(0)) + +func killPos(v reflect.Value) { + switch v.Kind() { + case reflect.Ptr, reflect.Interface: + if !v.IsNil() { + killPos(v.Elem()) + } + case reflect.Slice: + n := v.Len() + for i := 0; i < n; i++ { + killPos(v.Index(i)) + } + case reflect.Struct: + n := v.NumField() + for i := 0; i < n; i++ { + f := v.Field(i) + if f.Type() == posType { + f.SetInt(0) + continue + } + killPos(f) + } + } +} + +// A Rename describes a single renaming. +type rename struct { + OldImport string // only apply rename if this import is present + NewImport string // add this import during rewrite + Old string // old name: p.T or *p.T + New string // new name: p.T or *p.T +} + +func renameFix(tab []rename) func(*ast.File) bool { + return func(f *ast.File) bool { + return renameFixTab(f, tab) + } +} + +func parseName(s string) (ptr bool, pkg, nam string) { + i := strings.Index(s, ".") + if i < 0 { + panic("parseName: invalid name " + s) + } + if strings.HasPrefix(s, "*") { + ptr = true + s = s[1:] + i-- + } + pkg = s[:i] + nam = s[i+1:] + return +} + +func renameFixTab(f *ast.File, tab []rename) bool { + fixed := false + added := map[string]bool{} + check := map[string]bool{} + for _, t := range tab { + if !imports(f, t.OldImport) { + continue + } + optr, opkg, onam := parseName(t.Old) + walk(f, func(n interface{}) { + np, ok := n.(*ast.Expr) + if !ok { + return + } + x := *np + if optr { + p, ok := x.(*ast.StarExpr) + if !ok { + return + } + x = p.X + } + if !isPkgDot(x, opkg, onam) { + return + } + if t.NewImport != "" && !added[t.NewImport] { + addImport(f, t.NewImport) + added[t.NewImport] = true + } + *np = expr(t.New) + check[t.OldImport] = true + fixed = true + }) + } + + for ipath := range check { + if !usesImport(f, ipath) { + deleteImport(f, ipath) + } + } + return fixed +}