diff --git a/cmds.go b/cmds.go new file mode 100644 index 0000000..5b7bbe6 --- /dev/null +++ b/cmds.go @@ -0,0 +1,70 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + "os" + "os/exec" +) + +type Command struct { + Path string `json:"command_path"` + Args []string `json:"command_args"` + AllowExitCode int `json:"exit_code"` + Event string `json:"event"` + MsgFormat string `json:"msg_format"` +} + +func (c *Command) Run(fp string) bool { + cmd := &exec.Cmd{} + if len(c.Args) == 0 { + cmd = exec.Command(c.Path, fmt.Sprintf(c.MsgFormat, fp)) + } else { + cmd = exec.Command(c.Path, c.Args...) + } + + err := cmd.Start() + if err != nil { + log.Println(err) + return false + } + + err = cmd.Wait() + if err != nil { + exit, ok := err.(*exec.ExitError) + if !ok { + return false + } + if exit.ExitCode() == c.AllowExitCode { + return true + } + } + + return true +} + +type Commands []Command + +func (cs Commands) Get(event string) *Command { + for _, c := range cs { + if c.Event == event { + return &c + } + } + return nil +} + +func LoadCommands(p string) Commands { + cmds := Commands{} + data, err := os.ReadFile(p) + if err != nil { + log.Fatal(err) + } + err = json.Unmarshal(data, &cmds) + if err != nil { + log.Fatal(err) + } + + return cmds +} diff --git a/main.go b/main.go index ccaa896..855058b 100644 --- a/main.go +++ b/main.go @@ -1,15 +1,11 @@ package main import ( - "bufio" "flag" - "fmt" "log" "net" "net/http" "net/http/pprof" - "os" - "strings" "time" "golang.org/x/crypto/ssh" @@ -19,6 +15,7 @@ import ( func main() { dbg := flag.Bool("debug", false, "Enable pprof debugging") sock := flag.String("s", "/tmp/traygent", "Socket path to create") + cmdList := flag.String("c", "/etc/traygent.json", "List of commands to execute") flag.Parse() if *dbg { @@ -60,25 +57,36 @@ func main() { } }() + cmds := LoadCommands(*cmdList) + for { select { case added := <-tagent.addChan: - fmt.Printf("NOTICE: added %q\n", ssh.FingerprintSHA256(added)) - case rm := <-tagent.rmChan: - fmt.Printf("NOTICE: removed %q\n", rm) - case pub := <-tagent.sigReq: - r := bufio.NewReader(os.Stdin) - - fmt.Printf("NOTICE: Allow access to %q?: ", ssh.FingerprintSHA256(pub)) - resp, _ := r.ReadString('\n') - resp = strings.Trim(resp, "\n") - - if resp == "yes" { - go func() { tagent.sigResp <- true }() - } else { - go func() { tagent.sigResp <- false }() + fp := ssh.FingerprintSHA256(added) + log.Printf("NOTICE: added %q\n", fp) + c := cmds.Get("added") + if c != nil { + c.Run(fp) + } + case rm := <-tagent.rmChan: + log.Printf("NOTICE: removed %q\n", rm) + c := cmds.Get("removed") + if c != nil { + c.Run(rm) + } + case pub := <-tagent.sigReq: + fp := ssh.FingerprintSHA256(pub) + log.Printf("NOTICE: access request for: %q?\n", fp) + c := cmds.Get("sign") + if c != nil { + if c.Run(fp) { + go func() { tagent.sigResp <- true }() + } else { + go func() { tagent.sigResp <- false }() + } + } else { + panic("nope") } - fmt.Printf("%q\n", resp) } } }