mirror of
https://github.com/golang/go
synced 2024-10-04 21:11:22 -06:00
126 lines
2.6 KiB
Go
126 lines
2.6 KiB
Go
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package main
|
||
|
|
||
|
import (
|
||
|
"go/ast"
|
||
|
"go/token"
|
||
|
)
|
||
|
|
||
|
// httpserverFix is the fix value for the "httpserver" rewrite. Its fields are
// given positionally: the name "httpserver", the httpserver function that
// performs the rewrite, and a description that lists the code reviews behind
// each http API change the rewrite adapts to.
// NOTE(review): the fix type is declared elsewhere; the meaning of each field
// is presumed from the values supplied here.
var httpserverFix = fix{
	"httpserver",
	httpserver,
	`Adapt http server methods and functions to changes
made to the http ResponseWriter interface.

http://codereview.appspot.com/4245064 Hijacker
http://codereview.appspot.com/4239076 Header
http://codereview.appspot.com/4239077 Flusher
http://codereview.appspot.com/4248075 RemoteAddr, UsingTLS
`,
}
|
||
|
|
||
|
// init passes httpserverFix to register during package initialization.
func init() {
	register(httpserverFix)
}
|
||
|
|
||
|
func httpserver(f *ast.File) bool {
|
||
|
if !imports(f, "http") {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
fixed := false
|
||
|
for _, decl := range f.Decls {
|
||
|
fn, ok := decl.(*ast.FuncDecl)
|
||
|
if !ok {
|
||
|
continue
|
||
|
}
|
||
|
w, req, ok := isServeHTTP(fn)
|
||
|
if !ok {
|
||
|
continue
|
||
|
}
|
||
|
rewrite(fn.Body, func(n interface{}) {
|
||
|
// Want to replace expression sometimes,
|
||
|
// so record pointer to it for updating below.
|
||
|
ptr, ok := n.(*ast.Expr)
|
||
|
if ok {
|
||
|
n = *ptr
|
||
|
}
|
||
|
|
||
|
// Look for w.UsingTLS() and w.Remoteaddr().
|
||
|
call, ok := n.(*ast.CallExpr)
|
||
|
if !ok || len(call.Args) != 0 {
|
||
|
return
|
||
|
}
|
||
|
sel, ok := call.Fun.(*ast.SelectorExpr)
|
||
|
if !ok {
|
||
|
return
|
||
|
}
|
||
|
if !refersTo(sel.X, w) {
|
||
|
return
|
||
|
}
|
||
|
switch sel.Sel.String() {
|
||
|
case "Hijack":
|
||
|
// replace w with w.(http.Hijacker)
|
||
|
sel.X = &ast.TypeAssertExpr{
|
||
|
X: sel.X,
|
||
|
Type: ast.NewIdent("http.Hijacker"),
|
||
|
}
|
||
|
fixed = true
|
||
|
case "Flush":
|
||
|
// replace w with w.(http.Flusher)
|
||
|
sel.X = &ast.TypeAssertExpr{
|
||
|
X: sel.X,
|
||
|
Type: ast.NewIdent("http.Flusher"),
|
||
|
}
|
||
|
fixed = true
|
||
|
case "UsingTLS":
|
||
|
if ptr == nil {
|
||
|
// can only replace expression if we have pointer to it
|
||
|
break
|
||
|
}
|
||
|
// replace with req.TLS != nil
|
||
|
*ptr = &ast.BinaryExpr{
|
||
|
X: &ast.SelectorExpr{
|
||
|
X: ast.NewIdent(req.String()),
|
||
|
Sel: ast.NewIdent("TLS"),
|
||
|
},
|
||
|
Op: token.NEQ,
|
||
|
Y: ast.NewIdent("nil"),
|
||
|
}
|
||
|
fixed = true
|
||
|
case "RemoteAddr":
|
||
|
if ptr == nil {
|
||
|
// can only replace expression if we have pointer to it
|
||
|
break
|
||
|
}
|
||
|
// replace with req.RemoteAddr
|
||
|
*ptr = &ast.SelectorExpr{
|
||
|
X: ast.NewIdent(req.String()),
|
||
|
Sel: ast.NewIdent("RemoteAddr"),
|
||
|
}
|
||
|
fixed = true
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
return fixed
|
||
|
}
|
||
|
|
||
|
func isServeHTTP(fn *ast.FuncDecl) (w, req *ast.Ident, ok bool) {
|
||
|
for _, field := range fn.Type.Params.List {
|
||
|
if isPkgDot(field.Type, "http", "ResponseWriter") {
|
||
|
w = field.Names[0]
|
||
|
continue
|
||
|
}
|
||
|
if isPtrPkgDot(field.Type, "http", "Request") {
|
||
|
req = field.Names[0]
|
||
|
continue
|
||
|
}
|
||
|
}
|
||
|
|
||
|
ok = w != nil && req != nil
|
||
|
return
|
||
|
}
|