Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions cmd/gosh/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ import (
"fmt"
"io"
"os"
"os/signal"
"strings"
"syscall"

"golang.org/x/term"

Expand All @@ -35,48 +37,49 @@ func main() {
}

func runAll() error {
ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)

r, err := interp.New(interp.Interactive(true), interp.StdIO(os.Stdin, os.Stdout, os.Stderr))
if err != nil {
return err
}

if *command != "" {
return run(r, strings.NewReader(*command), "")
return run(ctx, r, strings.NewReader(*command), "")
}
if flag.NArg() == 0 {
if term.IsTerminal(int(os.Stdin.Fd())) {
return runInteractive(r, os.Stdin, os.Stdout, os.Stderr)
return runInteractive(ctx, r, os.Stdin, os.Stdout, os.Stderr)
}
return run(r, os.Stdin, "")
return run(ctx, r, os.Stdin, "")
}
for _, path := range flag.Args() {
if err := runPath(r, path); err != nil {
if err := runPath(ctx, r, path); err != nil {
return err
}
}
return nil
}

func run(r *interp.Runner, reader io.Reader, name string) error {
func run(ctx context.Context, r *interp.Runner, reader io.Reader, name string) error {
prog, err := syntax.NewParser().Parse(reader, name)
if err != nil {
return err
}
r.Reset()
ctx := context.Background()
return r.Run(ctx, prog)
}

func runPath(r *interp.Runner, path string) error {
func runPath(ctx context.Context, r *interp.Runner, path string) error {
f, err := os.Open(path)
if err != nil {
return err
}
defer f.Close()
return run(r, f, path)
return run(ctx, r, f, path)
}

func runInteractive(r *interp.Runner, stdin io.Reader, stdout, stderr io.Writer) error {
func runInteractive(ctx context.Context, r *interp.Runner, stdin io.Reader, stdout, stderr io.Writer) error {
parser := syntax.NewParser()
fmt.Fprintf(stdout, "$ ")
for stmts, err := range parser.InteractiveSeq(stdin) {
Expand All @@ -87,7 +90,6 @@ func runInteractive(r *interp.Runner, stdin io.Reader, stdout, stderr io.Writer)
fmt.Fprintf(stdout, "> ")
continue
}
ctx := context.Background()
for _, stmt := range stmts {
err := r.Run(ctx, stmt)
if r.Exited() {
Expand Down
5 changes: 3 additions & 2 deletions cmd/gosh/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package main

import (
"context"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -202,7 +203,7 @@ func TestInteractive(t *testing.T) {
runner, _ := interp.New(interp.Interactive(true), interp.StdIO(inReader, outWriter, outWriter))
errc := make(chan error, 1)
go func() {
errc <- runInteractive(runner, inReader, outWriter, outWriter)
errc <- runInteractive(context.Background(), runner, inReader, outWriter, outWriter)
// Discard the rest of the input.
io.Copy(io.Discard, inReader)
inReader.Close()
Expand Down Expand Up @@ -255,7 +256,7 @@ func TestInteractiveExit(t *testing.T) {
}()
w := io.Discard
runner, _ := interp.New(interp.Interactive(true), interp.StdIO(inReader, w, w))
if err := runInteractive(runner, inReader, w, w); err != nil {
if err := runInteractive(context.Background(), runner, inReader, w, w); err != nil {
t.Fatal("expected a nil error")
}
}
Expand Down
13 changes: 10 additions & 3 deletions interp/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,22 +138,29 @@ func DefaultExecHandler(killTimeout time.Duration) ExecHandlerFunc {
Stdout: hc.Stdout,
Stderr: hc.Stderr,
}
prepareCommand(&cmd)

err = cmd.Start()
if err == nil {
stopf := context.AfterFunc(ctx, func() {
if killTimeout <= 0 || runtime.GOOS == "windows" {
_ = cmd.Process.Signal(os.Kill)
_ = killCommand(&cmd)
return
}
_ = cmd.Process.Signal(os.Interrupt)
_ = interruptCommand(&cmd)
// TODO: don't sleep in this goroutine if the program
// stops itself with the interrupt above.
time.Sleep(killTimeout)
_ = cmd.Process.Signal(os.Kill)
_ = killCommand(&cmd)
})
defer stopf()

// Set the command's process group as foreground so that signals
// (e.g., SIGINT) are delivered to it.
setProcessForeground(cmd.Process.Pid, hc.runner.stdin.Fd())
// Restore the shell's process group as foreground when done.
defer restoreForeground(hc.runner.stdin.Fd())

err = cmd.Wait()
}

Expand Down
20 changes: 20 additions & 0 deletions interp/os_notunix.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package interp
import (
"context"
"fmt"
"os/exec"

"mvdan.cc/sh/v3/syntax"
)
Expand Down Expand Up @@ -56,3 +57,22 @@ type waitStatus struct{}

func (waitStatus) Signaled() bool { return false }
func (waitStatus) Signal() int { return 0 }

// prepareCommand is a no-op.
func prepareCommand(cmd *exec.Cmd) {}

// setProcessForeground is a no-op on non-Unix systems.
func setProcessForeground(pid int, fd uintptr) {}

// restoreForeground is a no-op on non-Unix systems.
func restoreForeground(fd uintptr) {}

// interruptCommand interrupts the process killing it.
func interruptCommand(cmd *exec.Cmd) error {
return cmd.Process.Kill()
}

// killCommand kills the process by killing it.
func killCommand(cmd *exec.Cmd) error {
return cmd.Process.Kill()
}
48 changes: 48 additions & 0 deletions interp/os_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ package interp

import (
"context"
"os/exec"
"os/signal"
"os/user"
"strconv"
"syscall"
"unsafe"

"golang.org/x/sys/unix"
"mvdan.cc/sh/v3/syntax"
Expand Down Expand Up @@ -46,3 +49,48 @@ func (r *Runner) unTestOwnOrGrp(ctx context.Context, op syntax.UnTestOperator, x
}

type waitStatus = syscall.WaitStatus

// prepareCommand sets the SysProcAttr for the command to create a new process group.
func prepareCommand(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
}

// setProcessForeground sets the given process group as the foreground process group
// on the terminal associated with fd. This ensures that signals like SIGINT are
// delivered to the process group.
func setProcessForeground(pid int, fd uintptr) {
signal.Ignore(syscall.SIGTTOU)
defer signal.Reset(syscall.SIGTTOU)

syscall.Syscall(
syscall.SYS_IOCTL,
fd,
syscall.TIOCSPGRP,
uintptr(unsafe.Pointer(&pid)),
)
}

// restoreForeground restores the shell's process group as the foreground process group
// on the terminal associated with fd.
func restoreForeground(fd uintptr) {
signal.Ignore(syscall.SIGTTOU)
defer signal.Reset(syscall.SIGTTOU)

shPgid, _ := syscall.Getpgid(0)
syscall.Syscall(
syscall.SYS_IOCTL,
fd,
syscall.TIOCSPGRP,
uintptr(unsafe.Pointer(&shPgid)),
)
}

// interruptCommand interrupts the whole process group.
func interruptCommand(cmd *exec.Cmd) error {
return unix.Kill(-cmd.Process.Pid, unix.SIGINT)
}

// killCommand kills the whole process group.
func killCommand(cmd *exec.Cmd) error {
return unix.Kill(-cmd.Process.Pid, unix.SIGKILL)
}