From 9c9c1b0d8e0b1e14290de7e75f2f3659ab3a8691 Mon Sep 17 00:00:00 2001
From: Martin Asquino <martin.asquino@gmail.com>
Date: Tue, 13 Feb 2024 17:38:09 +0000
Subject: [PATCH] gopls/internal/golang: add extract interface code action

---
 gopls/internal/golang/codeaction.go           |  65 +++++---
 gopls/internal/golang/extract.go              |  34 +++++
 gopls/internal/golang/fix.go                  | 141 ++++++++++++++++++
 .../testdata/codeaction/extract_interface.txt |  82 ++++++++++
 .../codeaction/extract_interface_resolve.txt  |  81 ++++++++++
 5 files changed, 380 insertions(+), 23 deletions(-)
 create mode 100644 gopls/internal/test/marker/testdata/codeaction/extract_interface.txt
 create mode 100644 gopls/internal/test/marker/testdata/codeaction/extract_interface_resolve.txt

diff --git a/gopls/internal/golang/codeaction.go b/gopls/internal/golang/codeaction.go
index 99bfa3d75cf..acc746128fc 100644
--- a/gopls/internal/golang/codeaction.go
+++ b/gopls/internal/golang/codeaction.go
@@ -85,8 +85,13 @@ func CodeActions(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle,
 			}
 		}
 
+		pkg, pgf, err := NarrowestPackageForFile(ctx, snapshot, fh.URI())
+		if err != nil {
+			return nil, err
+		}
+
 		if want[protocol.RefactorExtract] {
-			extractions, err := getExtractCodeActions(pgf, rng, snapshot.Options())
+			extractions, err := getExtractCodeActions(pkg, pgf, rng, snapshot.Options())
 			if err != nil {
 				return nil, err
 			}
@@ -198,20 +203,18 @@ func fixedByImportFix(fix *imports.ImportFix, diagnostics []protocol.Diagnostic)
 }
 
 // getExtractCodeActions returns any refactor.extract code actions for the selection.
-func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *settings.Options) ([]protocol.CodeAction, error) {
-	if rng.Start == rng.End {
-		return nil, nil
-	}
-
+func getExtractCodeActions(pkg *cache.Package, pgf *parsego.File, rng protocol.Range, options *settings.Options) ([]protocol.CodeAction, error) {
 	start, end, err := pgf.RangePos(rng)
 	if err != nil {
 		return nil, err
 	}
+
 	puri := pgf.URI
 	var commands []protocol.Command
-	if _, ok, methodOk, _ := CanExtractFunction(pgf.Tok, start, end, pgf.Src, pgf.File); ok {
-		cmd, err := command.NewApplyFixCommand("Extract function", command.ApplyFixArgs{
-			Fix:          fixExtractFunction,
+
+	if _, _, ok, _ := CanExtractInterface(pkg, start, end, pgf.File); ok {
+		cmd, err := command.NewApplyFixCommand("Extract interface", command.ApplyFixArgs{
+			Fix:          fixExtractInterface,
 			URI:          puri,
 			Range:        rng,
 			ResolveEdits: supportsResolveEdits(options),
@@ -220,9 +223,12 @@ func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *setti
 			return nil, err
 		}
 		commands = append(commands, cmd)
-		if methodOk {
-			cmd, err := command.NewApplyFixCommand("Extract method", command.ApplyFixArgs{
-				Fix:          fixExtractMethod,
+	}
+
+	if rng.Start != rng.End {
+		if _, ok, methodOk, _ := CanExtractFunction(pgf.Tok, start, end, pgf.Src, pgf.File); ok {
+			cmd, err := command.NewApplyFixCommand("Extract function", command.ApplyFixArgs{
+				Fix:          fixExtractFunction,
 				URI:          puri,
 				Range:        rng,
 				ResolveEdits: supportsResolveEdits(options),
@@ -231,20 +237,33 @@ func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *setti
 				return nil, err
 			}
 			commands = append(commands, cmd)
+			if methodOk {
+				cmd, err := command.NewApplyFixCommand("Extract method", command.ApplyFixArgs{
+					Fix:          fixExtractMethod,
+					URI:          puri,
+					Range:        rng,
+					ResolveEdits: supportsResolveEdits(options),
+				})
+				if err != nil {
+					return nil, err
+				}
+				commands = append(commands, cmd)
+			}
 		}
-	}
-	if _, _, ok, _ := CanExtractVariable(start, end, pgf.File); ok {
-		cmd, err := command.NewApplyFixCommand("Extract variable", command.ApplyFixArgs{
-			Fix:          fixExtractVariable,
-			URI:          puri,
-			Range:        rng,
-			ResolveEdits: supportsResolveEdits(options),
-		})
-		if err != nil {
-			return nil, err
+		if _, _, ok, _ := CanExtractVariable(start, end, pgf.File); ok {
+			cmd, err := command.NewApplyFixCommand("Extract variable", command.ApplyFixArgs{
+				Fix:          fixExtractVariable,
+				URI:          puri,
+				Range:        rng,
+				ResolveEdits: supportsResolveEdits(options),
+			})
+			if err != nil {
+				return nil, err
+			}
+			commands = append(commands, cmd)
 		}
-		commands = append(commands, cmd)
 	}
+
 	var actions []protocol.CodeAction
 	for i := range commands {
 		actions = append(actions, newCodeAction(commands[i].Title, protocol.RefactorExtract, &commands[i], nil, options))
diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go
index c07faec1b7a..46ac3b29a33 100644
--- a/gopls/internal/golang/extract.go
+++ b/gopls/internal/golang/extract.go
@@ -18,6 +18,7 @@ import (
 
 	"golang.org/x/tools/go/analysis"
 	"golang.org/x/tools/go/ast/astutil"
+	"golang.org/x/tools/gopls/internal/cache"
 	"golang.org/x/tools/gopls/internal/util/bug"
 	"golang.org/x/tools/gopls/internal/util/safetoken"
 	"golang.org/x/tools/internal/analysisinternal"
@@ -127,6 +128,39 @@ func CanExtractVariable(start, end token.Pos, file *ast.File) (ast.Expr, []ast.N
 	return nil, nil, false, fmt.Errorf("cannot extract an %T to a variable", expr)
 }
 
+// CanExtractInterface reports whether the code in the given position is for a
+// type which can be represented as an interface.
+func CanExtractInterface(pkg *cache.Package, start, end token.Pos, file *ast.File) (ast.Expr, []ast.Node, bool, error) {
+	path, _ := astutil.PathEnclosingInterval(file, start, end)
+	if len(path) == 0 {
+		return nil, nil, false, fmt.Errorf("no path enclosing interval")
+	}
+
+	node := path[0]
+	expr, ok := node.(ast.Expr)
+	if !ok {
+		return nil, nil, false, fmt.Errorf("node is not an expression")
+	}
+
+	switch e := expr.(type) {
+	case *ast.Ident:
+		o, ok := pkg.TypesInfo().ObjectOf(e).(*types.TypeName)
+		if !ok {
+			return nil, nil, false, fmt.Errorf("cannot extract a %T to an interface", expr)
+		}
+
+		if _, ok := o.Type().(*types.Basic); ok {
+			return nil, nil, false, fmt.Errorf("cannot extract a basic type to an interface")
+		}
+
+		return expr, path, true, nil
+	case *ast.StarExpr, *ast.SelectorExpr:
+		return expr, path, true, nil
+	default:
+		return nil, nil, false, fmt.Errorf("cannot extract a %T to an interface", expr)
+	}
+}
+
 // Calculate indentation for insertion.
 // When inserting lines of code, we must ensure that the lines have consistent
 // formatting (i.e. the proper indentation). To do so, we observe the indentation on the
diff --git a/gopls/internal/golang/fix.go b/gopls/internal/golang/fix.go
index 2215da9b65e..a2d7748983f 100644
--- a/gopls/internal/golang/fix.go
+++ b/gopls/internal/golang/fix.go
@@ -5,13 +5,17 @@
 package golang
 
 import (
+	"bytes"
 	"context"
+	"errors"
 	"fmt"
 	"go/ast"
 	"go/token"
 	"go/types"
+	"slices"
 
 	"golang.org/x/tools/go/analysis"
+	"golang.org/x/tools/go/ast/astutil"
 	"golang.org/x/tools/gopls/internal/analysis/embeddirective"
 	"golang.org/x/tools/gopls/internal/analysis/fillstruct"
 	"golang.org/x/tools/gopls/internal/analysis/stubmethods"
@@ -22,6 +26,7 @@ import (
 	"golang.org/x/tools/gopls/internal/file"
 	"golang.org/x/tools/gopls/internal/protocol"
 	"golang.org/x/tools/gopls/internal/util/bug"
+	"golang.org/x/tools/gopls/internal/util/safetoken"
 	"golang.org/x/tools/internal/imports"
 )
 
@@ -61,6 +66,7 @@ func singleFile(fixer1 singleFileFixer) fixer {
 const (
 	fixExtractVariable   = "extract_variable"
 	fixExtractFunction   = "extract_function"
+	fixExtractInterface  = "extract_interface"
 	fixExtractMethod     = "extract_method"
 	fixInlineCall        = "inline_call"
 	fixInvertIfCondition = "invert_if_condition"
@@ -112,6 +118,7 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file
 
 		// Ad-hoc fixers: these are used when the command is
 		// constructed directly by logic in server/code_action.
+		fixExtractInterface:  extractInterface,
 		fixExtractFunction:   singleFile(extractFunction),
 		fixExtractMethod:     singleFile(extractMethod),
 		fixExtractVariable:   singleFile(extractVariable),
@@ -142,6 +149,140 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file
 	return suggestedFixToEdits(ctx, snapshot, fixFset, suggestion)
 }
 
+func extractInterface(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) {
+	path, _ := astutil.PathEnclosingInterval(pgf.File, start, end)
+
+	var field *ast.Field
+	var decl ast.Decl
+	for _, node := range path {
+		if f, ok := node.(*ast.Field); ok {
+			field = f
+			continue
+		}
+
+		// Record the node that starts the declaration of the type that contains
+		// the field we are creating the interface for.
+		if d, ok := node.(ast.Decl); ok {
+			decl = d
+			break // we have both the field and the declaration
+		}
+	}
+
+	if field == nil || decl == nil {
+		return nil, nil, nil
+	}
+
+	p := safetoken.StartPosition(pkg.FileSet(), field.Pos())
+	pos := protocol.Position{
+		Line:      uint32(p.Line - 1),   // Line is zero-based
+		Character: uint32(p.Column - 1), // Character is zero-based
+	}
+
+	fh, err := snapshot.ReadFile(ctx, pgf.URI)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	refs, err := references(ctx, snapshot, fh, pos, false)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	type method struct {
+		signature *types.Signature
+		name      string
+	}
+
+	var methods []method
+	for _, ref := range refs {
+		locPkg, locPgf, err := NarrowestPackageForFile(ctx, snapshot, ref.location.URI)
+		if err != nil {
+			return nil, nil, err
+		}
+
+		_, end, err := locPgf.RangePos(ref.location.Range)
+		if err != nil {
+			return nil, nil, err
+		}
+
+		// We are interested in the method call, so we need the node after the dot
+		rangeEnd := end + token.Pos(len("."))
+		path, _ := astutil.PathEnclosingInterval(locPgf.File, rangeEnd, rangeEnd)
+		id, ok := path[0].(*ast.Ident)
+		if !ok {
+			continue
+		}
+
+		obj := locPkg.TypesInfo().ObjectOf(id)
+		if obj == nil {
+			continue
+		}
+
+		sig, ok := obj.Type().(*types.Signature)
+		if !ok {
+			return nil, nil, errors.New("cannot extract interface with non-method accesses")
+		}
+
+		fc := method{signature: sig, name: obj.Name()}
+		if !slices.Contains(methods, fc) {
+			methods = append(methods, fc)
+		}
+	}
+
+	interfaceName := "I" + pkg.TypesInfo().ObjectOf(field.Names[0]).Name()
+	var buf bytes.Buffer
+	buf.WriteString("\ntype ")
+	buf.WriteString(interfaceName)
+	buf.WriteString(" interface {\n")
+	for _, fc := range methods {
+		buf.WriteString("\t")
+		buf.WriteString(fc.name)
+		types.WriteSignature(&buf, fc.signature, relativeTo(pkg.Types()))
+		buf.WriteByte('\n')
+	}
+	buf.WriteByte('}')
+	buf.WriteByte('\n')
+
+	interfacePos := decl.Pos() - 1
+	// Move the interface above the documentation comment if the type declaration
+	// includes one.
+	switch d := decl.(type) {
+	case *ast.GenDecl:
+		if d.Doc != nil {
+			interfacePos = d.Doc.Pos() - 1
+		}
+	case *ast.FuncDecl:
+		if d.Doc != nil {
+			interfacePos = d.Doc.Pos() - 1
+		}
+	}
+
+	return pkg.FileSet(), &analysis.SuggestedFix{
+		Message: "Extract interface",
+		TextEdits: []analysis.TextEdit{{
+			Pos:     interfacePos,
+			End:     interfacePos,
+			NewText: buf.Bytes(),
+		}, {
+			Pos:     field.Type.Pos(),
+			End:     field.Type.End(),
+			NewText: []byte(interfaceName),
+		}},
+	}, nil
+}
+
+func relativeTo(pkg *types.Package) types.Qualifier {
+	if pkg == nil {
+		return nil
+	}
+	return func(other *types.Package) string {
+		if pkg == other {
+			return "" // same package; unqualified
+		}
+		return other.Name()
+	}
+}
+
 // suggestedFixToEdits converts the suggestion's edits from analysis form into protocol form.
 func suggestedFixToEdits(ctx context.Context, snapshot *cache.Snapshot, fset *token.FileSet, suggestion *analysis.SuggestedFix) ([]protocol.TextDocumentEdit, error) {
 	editsPerFile := map[protocol.DocumentURI]*protocol.TextDocumentEdit{}
diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_interface.txt b/gopls/internal/test/marker/testdata/codeaction/extract_interface.txt
new file mode 100644
index 00000000000..485c1bbde7e
--- /dev/null
+++ b/gopls/internal/test/marker/testdata/codeaction/extract_interface.txt
@@ -0,0 +1,82 @@
+This test checks the behavior of the 'extract interface' code action.
+See extract_interface_resolve.txt for the same test with resolve support.
+
+-- flags --
+-ignore_extra_diags
+
+-- go.mod --
+module golang.org/lsptests/extract
+
+go 1.18
+
+-- b/b.go --
+package b
+
+type BFoo struct {}
+
+func (b BFoo) Bar() string {
+	return ""
+}
+
+func (b BFoo) Baz() int {
+	return 0
+}
+
+-- a.go --
+package extract
+
+import (
+	"golang.org/lsptests/extract/b"
+)
+
+// foo doc comment
+type foo struct {
+	fieldOne bar //@codeactionedit("bar", "refactor.extract", a1)
+	fieldTwo b.BFoo //@codeactionedit("BFoo", "refactor.extract", a2)
+}
+
+type bar struct {}
+
+func (b bar) baz() error {}
+func (b bar) qux(a string, b int, c func() string) {}
+
+func (f foo) quux() {
+	f.fieldTwo.Bar()
+	f.fieldOne.baz()
+}
+
+func (f foo) corge() {
+	f.fieldOne.qux("someString", 3, func() string { return "" })
+}
+
+func FuncThatUsesBar(b *bar) { //@codeactionedit("bar", "refactor.extract", a3)
+  b.qux()
+}
+
+-- @a1/a.go --
+@@ -7 +7,5 @@
++type IfieldOne interface {
++	baz() error
++	qux(a string, b int, c func() string)
++}
++
+@@ -9 +14 @@
+-	fieldOne bar //@codeactionedit("bar", "refactor.extract", a1)
++	fieldOne IfieldOne //@codeactionedit("bar", "refactor.extract", a1)
+-- @a2/a.go --
+@@ -7 +7,4 @@
++type IfieldTwo interface {
++	Bar() string
++}
++
+@@ -10 +14 @@
+-	fieldTwo b.BFoo //@codeactionedit("BFoo", "refactor.extract", a2)
++	fieldTwo IfieldTwo //@codeactionedit("BFoo", "refactor.extract", a2)
+-- @a3/a.go --
+@@ -27 +27,5 @@
+-func FuncThatUsesBar(b *bar) { //@codeactionedit("bar", "refactor.extract", a3)
++type Ib interface {
++	qux(a string, b int, c func() string)
++}
++
++func FuncThatUsesBar(b Ib) { //@codeactionedit("bar", "refactor.extract", a3)
diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_interface_resolve.txt b/gopls/internal/test/marker/testdata/codeaction/extract_interface_resolve.txt
new file mode 100644
index 00000000000..d4ee4099138
--- /dev/null
+++ b/gopls/internal/test/marker/testdata/codeaction/extract_interface_resolve.txt
@@ -0,0 +1,81 @@
+This test checks the behavior of the 'extract interface' code action.
+See extract_interface_resolve.txt for the same test with resolve support.
+
+-- capabilities.json --
+{
+	"textDocument": {
+		"codeAction": {
+			"dataSupport": true,
+			"resolveSupport": {
+				"properties": ["edit"]
+			}
+		}
+	}
+}
+-- flags --
+-ignore_extra_diags
+
+-- go.mod --
+module golang.org/lsptests/extract
+
+go 1.18
+
+-- b/b.go --
+package b
+
+type BFoo struct {}
+
+func (b BFoo) Bar() string {
+	return ""
+}
+
+func (b BFoo) Baz() int {
+	return 0
+}
+
+-- a.go --
+package extract
+
+import (
+	"golang.org/lsptests/extract/b"
+)
+
+// foo doc comment
+type foo struct {
+	fieldOne bar //@codeactionedit("bar", "refactor.extract", a1)
+	fieldTwo b.BFoo //@codeactionedit("BFoo", "refactor.extract", a2)
+}
+
+type bar struct {}
+
+func (b bar) baz() error {}
+func (b bar) qux(a string, b int, c func() string) {}
+
+func (f foo) quux() {
+	f.fieldTwo.Bar()
+	f.fieldOne.baz()
+}
+
+func (f foo) corge() {
+	f.fieldOne.qux("someString", 3, func() string { return "" })
+}
+
+-- @a1/a.go --
+@@ -7 +7,5 @@
++type IfieldOne interface {
++	baz() error
++	qux(a string, b int, c func() string)
++}
++
+@@ -9 +14 @@
+-	fieldOne bar //@codeactionedit("bar", "refactor.extract", a1)
++	fieldOne IfieldOne //@codeactionedit("bar", "refactor.extract", a1)
+-- @a2/a.go --
+@@ -7 +7,4 @@
++type IfieldTwo interface {
++	Bar() string
++}
++
+@@ -10 +14 @@
+-	fieldTwo b.BFoo //@codeactionedit("BFoo", "refactor.extract", a2)
++	fieldTwo IfieldTwo //@codeactionedit("BFoo", "refactor.extract", a2)