Skip to content

Commit cc152e3

Browse files
vankleefjimsivchari
authored andcommitted
also check function literals, fixes #19
1 parent a75b385 commit cc152e3

File tree

2 files changed

+55
-35
lines changed

2 files changed

+55
-35
lines changed

tenv.go

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -34,49 +34,58 @@ func run(pass *analysis.Pass) (interface{}, error) {
3434
inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
3535

3636
nodeFilter := []ast.Node{
37-
(*ast.File)(nil),
37+
(*ast.FuncDecl)(nil),
38+
(*ast.FuncLit)(nil),
3839
}
3940

4041
inspect.Preorder(nodeFilter, func(n ast.Node) {
4142
switch n := n.(type) {
42-
case *ast.File:
43-
for _, decl := range n.Decls {
44-
45-
funcDecl, ok := decl.(*ast.FuncDecl)
46-
if !ok {
47-
continue
48-
}
49-
checkFunc(pass, funcDecl, pass.Fset.File(n.Pos()).Name())
50-
}
43+
case *ast.FuncDecl:
44+
checkFuncDecl(pass, n, pass.Fset.File(n.Pos()).Name())
45+
case *ast.FuncLit:
46+
checkFuncLit(pass, n, pass.Fset.File(n.Pos()).Name())
5147
}
5248
})
5349

5450
return nil, nil
5551
}
5652

57-
func checkFunc(pass *analysis.Pass, n *ast.FuncDecl, fileName string) {
58-
argName, ok := targetRunner(n, fileName)
59-
if ok {
60-
for _, stmt := range n.Body.List {
61-
switch stmt := stmt.(type) {
62-
case *ast.ExprStmt:
63-
if !checkExprStmt(pass, stmt, n, argName) {
64-
continue
65-
}
66-
case *ast.IfStmt:
67-
if !checkIfStmt(pass, stmt, n, argName) {
68-
continue
69-
}
70-
case *ast.AssignStmt:
71-
if !checkAssignStmt(pass, stmt, n, argName) {
72-
continue
73-
}
53+
func checkFuncDecl(pass *analysis.Pass, f *ast.FuncDecl, fileName string) {
54+
argName, ok := targetRunner(f.Type.Params.List, fileName)
55+
if !ok {
56+
return
57+
}
58+
checkStmts(pass, f.Body.List, f.Name.Name, argName)
59+
}
60+
61+
func checkFuncLit(pass *analysis.Pass, f *ast.FuncLit, fileName string) {
62+
argName, ok := targetRunner(f.Type.Params.List, fileName)
63+
if !ok {
64+
return
65+
}
66+
checkStmts(pass, f.Body.List, "function literal", argName)
67+
}
68+
69+
func checkStmts(pass *analysis.Pass, stmts []ast.Stmt, funcName, argName string) {
70+
for _, stmt := range stmts {
71+
switch stmt := stmt.(type) {
72+
case *ast.ExprStmt:
73+
if !checkExprStmt(pass, stmt, funcName, argName) {
74+
continue
75+
}
76+
case *ast.IfStmt:
77+
if !checkIfStmt(pass, stmt, funcName, argName) {
78+
continue
79+
}
80+
case *ast.AssignStmt:
81+
if !checkAssignStmt(pass, stmt, funcName, argName) {
82+
continue
7483
}
7584
}
7685
}
7786
}
7887

79-
func checkExprStmt(pass *analysis.Pass, stmt *ast.ExprStmt, n *ast.FuncDecl, argName string) bool {
88+
func checkExprStmt(pass *analysis.Pass, stmt *ast.ExprStmt, funcName, argName string) bool {
8089
callExpr, ok := stmt.X.(*ast.CallExpr)
8190
if !ok {
8291
return false
@@ -94,12 +103,12 @@ func checkExprStmt(pass *analysis.Pass, stmt *ast.ExprStmt, n *ast.FuncDecl, arg
94103
if argName == "" {
95104
argName = "testing"
96105
}
97-
pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, n.Name.Name)
106+
pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, funcName)
98107
}
99108
return true
100109
}
101110

102-
func checkIfStmt(pass *analysis.Pass, stmt *ast.IfStmt, n *ast.FuncDecl, argName string) bool {
111+
func checkIfStmt(pass *analysis.Pass, stmt *ast.IfStmt, funcName, argName string) bool {
103112
assignStmt, ok := stmt.Init.(*ast.AssignStmt)
104113
if !ok {
105114
return false
@@ -121,12 +130,12 @@ func checkIfStmt(pass *analysis.Pass, stmt *ast.IfStmt, n *ast.FuncDecl, argName
121130
if argName == "" {
122131
argName = "testing"
123132
}
124-
pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, n.Name.Name)
133+
pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, funcName)
125134
}
126135
return true
127136
}
128137

129-
func checkAssignStmt(pass *analysis.Pass, stmt *ast.AssignStmt, n *ast.FuncDecl, argName string) bool {
138+
func checkAssignStmt(pass *analysis.Pass, stmt *ast.AssignStmt, funcName, argName string) bool {
130139
rhs, ok := stmt.Rhs[0].(*ast.CallExpr)
131140
if !ok {
132141
return false
@@ -144,13 +153,12 @@ func checkAssignStmt(pass *analysis.Pass, stmt *ast.AssignStmt, n *ast.FuncDecl,
144153
if argName == "" {
145154
argName = "testing"
146155
}
147-
pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, n.Name.Name)
156+
pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, funcName)
148157
}
149158
return true
150159
}
151160

152-
func targetRunner(funcDecl *ast.FuncDecl, fileName string) (string, bool) {
153-
params := funcDecl.Type.Params.List
161+
func targetRunner(params []*ast.Field, fileName string) (string, bool) {
154162
for _, p := range params {
155163
switch typ := p.Type.(type) {
156164
case *ast.StarExpr:

testdata/src/a/a_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,15 @@ func FuzzF(f *testing.F) {
5555
_ = err
5656
}
5757
}
58+
59+
func TestFunctionLiteral(t *testing.T) {
60+
testsetup()
61+
t.Run("test", func(t *testing.T) {
62+
os.Setenv("a", "b") // want "os\\.Setenv\\(\\) can be replaced by `t\\.Setenv\\(\\)` in function literal"
63+
err := os.Setenv("a", "b") // want "os\\.Setenv\\(\\) can be replaced by `t\\.Setenv\\(\\)` in function literal"
64+
_ = err
65+
if err := os.Setenv("a", "b"); err != nil { // want "os\\.Setenv\\(\\) can be replaced by `t\\.Setenv\\(\\)` in function literal"
66+
_ = err
67+
}
68+
})
69+
}

0 commit comments

Comments
 (0)