From 3d4b55ad5b99def7f7e3f8694e590c1b44a1be97 Mon Sep 17 00:00:00 2001 From: Robert Griesemer Date: Wed, 13 Apr 2011 13:57:53 -0700 Subject: [PATCH] gofmt: minor refactor to permit easy testing R=rsc CC=golang-dev https://golang.org/cl/4397046 --- src/cmd/gofmt/gofmt.go | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/src/cmd/gofmt/gofmt.go b/src/cmd/gofmt/gofmt.go index 1ed5b9b9931..ce274aa21b1 100644 --- a/src/cmd/gofmt/gofmt.go +++ b/src/cmd/gofmt/gofmt.go @@ -13,6 +13,7 @@ import ( "go/printer" "go/scanner" "go/token" + "io" "io/ioutil" "os" "path/filepath" @@ -86,14 +87,23 @@ func isGoFile(f *os.FileInfo) bool { } -func processFile(f *os.File) os.Error { - src, err := ioutil.ReadAll(f) +// If in == nil, the source is the contents of the file with the given filename. +func processFile(filename string, in io.Reader, out io.Writer) os.Error { + if in == nil { + f, err := os.Open(filename) + if err != nil { + return err + } + defer f.Close() + in = f + } + + src, err := ioutil.ReadAll(in) if err != nil { return err } - file, err := parser.ParseFile(fset, f.Name(), src, parserMode) - + file, err := parser.ParseFile(fset, filename, src, parserMode) if err != nil { return err } @@ -116,10 +126,10 @@ func processFile(f *os.File) os.Error { if !bytes.Equal(src, res) { // formatting has changed if *list { - fmt.Fprintln(os.Stdout, f.Name()) + fmt.Fprintln(out, filename) } if *write { - err = ioutil.WriteFile(f.Name(), res, 0) + err = ioutil.WriteFile(filename, res, 0) if err != nil { return err } @@ -127,23 +137,13 @@ func processFile(f *os.File) os.Error { } if !*list && !*write { - _, err = os.Stdout.Write(res) + _, err = out.Write(res) } return err } -func processFileByName(filename string) os.Error { - file, err := os.Open(filename) - if err != nil { - return err - } - defer file.Close() - return processFile(file) -} - - type fileVisitor chan os.Error func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool { @@ -154,7 +154,7 @@ func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool { func (v fileVisitor) VisitFile(path string, f *os.FileInfo) { if isGoFile(f) { v <- nil // synchronize error handler - if err := processFileByName(path); err != nil { + if err := processFile(path, nil, os.Stdout); err != nil { v <- err } } @@ -210,9 +210,10 @@ func gofmtMain() { initRewrite() if flag.NArg() == 0 { - if err := processFile(os.Stdin); err != nil { + if err := processFile("", os.Stdin, os.Stdout); err != nil { report(err) } + return } for i := 0; i < flag.NArg(); i++ { @@ -221,7 +222,7 @@ func gofmtMain() { case err != nil: report(err) case dir.IsRegular(): - if err := processFileByName(path); err != nil { + if err := processFile(path, nil, os.Stdout); err != nil { report(err) } case dir.IsDirectory():