// Copyright 2014 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package main import ( "fmt" "go/ast" "go/token" "go/types" "sort" "golang.org/x/tools/cmd/guru/serial" "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/loader" "golang.org/x/tools/go/pointer" "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa/ssautil" ) var builtinErrorType = types.Universe.Lookup("error").Type() // whicherrs takes an position to an error and tries to find all types, constants // and global value which a given error can point to and which can be checked from the // scope where the error lives. // In short, it returns a list of things that can be checked against in order to handle // an error properly. // // TODO(dmorsing): figure out if fields in errors like *os.PathError.Err // can be queried recursively somehow. func whicherrs(q *Query) error { lconf := loader.Config{Build: q.Build} if err := setPTAScope(&lconf, q.Scope); err != nil { return err } // Load/parse/type-check the program. lprog, err := lconf.Load() if err != nil { return err } qpos, err := parseQueryPos(lprog, q.Pos, true) // needs exact pos if err != nil { return err } prog := ssautil.CreateProgram(lprog, ssa.GlobalDebug) ptaConfig, err := setupPTA(prog, lprog, q.PTALog, q.Reflection) if err != nil { return err } path, action := findInterestingNode(qpos.info, qpos.path) if action != actionExpr { return fmt.Errorf("whicherrs wants an expression; got %s", astutil.NodeDescription(qpos.path[0])) } var expr ast.Expr var obj types.Object switch n := path[0].(type) { case *ast.ValueSpec: // ambiguous ValueSpec containing multiple names return fmt.Errorf("multiple value specification") case *ast.Ident: obj = qpos.info.ObjectOf(n) expr = n case ast.Expr: expr = n default: return fmt.Errorf("unexpected AST for expr: %T", n) } typ := qpos.info.TypeOf(expr) if !types.Identical(typ, builtinErrorType) { return fmt.Errorf("selection is not an expression of type 'error'") } // Determine the ssa.Value for the expression. var value ssa.Value if obj != nil { // def/ref of func/var object value, _, err = ssaValueForIdent(prog, qpos.info, obj, path) } else { value, _, err = ssaValueForExpr(prog, qpos.info, path) } if err != nil { return err // e.g. trivially dead code } // Defer SSA construction till after errors are reported. prog.Build() globals := findVisibleErrs(prog, qpos) constants := findVisibleConsts(prog, qpos) res := &whicherrsResult{ qpos: qpos, errpos: expr.Pos(), } // TODO(adonovan): the following code is heavily duplicated // w.r.t. "pointsto". Refactor? // Find the instruction which initialized the // global error. If more than one instruction has stored to the global // remove the global from the set of values that we want to query. allFuncs := ssautil.AllFunctions(prog) for fn := range allFuncs { for _, b := range fn.Blocks { for _, instr := range b.Instrs { store, ok := instr.(*ssa.Store) if !ok { continue } gval, ok := store.Addr.(*ssa.Global) if !ok { continue } gbl, ok := globals[gval] if !ok { continue } // we already found a store to this global // The normal error define is just one store in the init // so we just remove this global from the set we want to query if gbl != nil { delete(globals, gval) } globals[gval] = store.Val } } } ptaConfig.AddQuery(value) for _, v := range globals { ptaConfig.AddQuery(v) } ptares := ptrAnalysis(ptaConfig) valueptr := ptares.Queries[value] if valueptr == (pointer.Pointer{}) { return fmt.Errorf("pointer analysis did not find expression (dead code?)") } for g, v := range globals { ptr, ok := ptares.Queries[v] if !ok { continue } if !ptr.MayAlias(valueptr) { continue } res.globals = append(res.globals, g) } pts := valueptr.PointsTo() dedup := make(map[*ssa.NamedConst]bool) for _, label := range pts.Labels() { // These values are either MakeInterfaces or reflect // generated interfaces. For the purposes of this // analysis, we don't care about reflect generated ones makeiface, ok := label.Value().(*ssa.MakeInterface) if !ok { continue } constval, ok := makeiface.X.(*ssa.Const) if !ok { continue } c := constants[*constval] if c != nil && !dedup[c] { dedup[c] = true res.consts = append(res.consts, c) } } concs := pts.DynamicTypes() concs.Iterate(func(conc types.Type, _ interface{}) { // go/types is a bit annoying here. // We want to find all the types that we can // typeswitch or assert to. This means finding out // if the type pointed to can be seen by us. // // For the purposes of this analysis, the type is always // either a Named type or a pointer to one. // There are cases where error can be implemented // by unnamed types, but in that case, we can't assert to // it, so we don't care about it for this analysis. var name *types.TypeName switch t := conc.(type) { case *types.Pointer: named, ok := t.Elem().(*types.Named) if !ok { return } name = named.Obj() case *types.Named: name = t.Obj() default: return } if !isAccessibleFrom(name, qpos.info.Pkg) { return } res.types = append(res.types, &errorType{conc, name}) }) sort.Sort(membersByPosAndString(res.globals)) sort.Sort(membersByPosAndString(res.consts)) sort.Sort(sorterrorType(res.types)) q.Output(lprog.Fset, res) return nil } // findVisibleErrs returns a mapping from each package-level variable of type "error" to nil. func findVisibleErrs(prog *ssa.Program, qpos *queryPos) map[*ssa.Global]ssa.Value { globals := make(map[*ssa.Global]ssa.Value) for _, pkg := range prog.AllPackages() { for _, mem := range pkg.Members { gbl, ok := mem.(*ssa.Global) if !ok { continue } gbltype := gbl.Type() // globals are always pointers if !types.Identical(deref(gbltype), builtinErrorType) { continue } if !isAccessibleFrom(gbl.Object(), qpos.info.Pkg) { continue } globals[gbl] = nil } } return globals } // findVisibleConsts returns a mapping from each package-level constant assignable to type "error", to nil. func findVisibleConsts(prog *ssa.Program, qpos *queryPos) map[ssa.Const]*ssa.NamedConst { constants := make(map[ssa.Const]*ssa.NamedConst) for _, pkg := range prog.AllPackages() { for _, mem := range pkg.Members { obj, ok := mem.(*ssa.NamedConst) if !ok { continue } consttype := obj.Type() if !types.AssignableTo(consttype, builtinErrorType) { continue } if !isAccessibleFrom(obj.Object(), qpos.info.Pkg) { continue } constants[*obj.Value] = obj } } return constants } type membersByPosAndString []ssa.Member func (a membersByPosAndString) Len() int { return len(a) } func (a membersByPosAndString) Less(i, j int) bool { cmp := a[i].Pos() - a[j].Pos() return cmp < 0 || cmp == 0 && a[i].String() < a[j].String() } func (a membersByPosAndString) Swap(i, j int) { a[i], a[j] = a[j], a[i] } type sorterrorType []*errorType func (a sorterrorType) Len() int { return len(a) } func (a sorterrorType) Less(i, j int) bool { cmp := a[i].obj.Pos() - a[j].obj.Pos() return cmp < 0 || cmp == 0 && a[i].typ.String() < a[j].typ.String() } func (a sorterrorType) Swap(i, j int) { a[i], a[j] = a[j], a[i] } type errorType struct { typ types.Type // concrete type N or *N that implements error obj *types.TypeName // the named type N } type whicherrsResult struct { qpos *queryPos errpos token.Pos globals []ssa.Member consts []ssa.Member types []*errorType } func (r *whicherrsResult) PrintPlain(printf printfFunc) { if len(r.globals) > 0 { printf(r.qpos, "this error may point to these globals:") for _, g := range r.globals { printf(g.Pos(), "\t%s", g.RelString(r.qpos.info.Pkg)) } } if len(r.consts) > 0 { printf(r.qpos, "this error may contain these constants:") for _, c := range r.consts { printf(c.Pos(), "\t%s", c.RelString(r.qpos.info.Pkg)) } } if len(r.types) > 0 { printf(r.qpos, "this error may contain these dynamic types:") for _, t := range r.types { printf(t.obj.Pos(), "\t%s", r.qpos.typeString(t.typ)) } } } func (r *whicherrsResult) JSON(fset *token.FileSet) []byte { we := &serial.WhichErrs{} we.ErrPos = fset.Position(r.errpos).String() for _, g := range r.globals { we.Globals = append(we.Globals, fset.Position(g.Pos()).String()) } for _, c := range r.consts { we.Constants = append(we.Constants, fset.Position(c.Pos()).String()) } for _, t := range r.types { var et serial.WhichErrsType et.Type = r.qpos.typeString(t.typ) et.Position = fset.Position(t.obj.Pos()).String() we.Types = append(we.Types, et) } return toJSON(we) }