traygent/agent.go

306 lines
5.5 KiB
Go

package main
import (
"bytes"
"crypto/rand"
"crypto/subtle"
"errors"
"fmt"
"log"
"net"
"sync"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
var (
	// errLocked is reported when the agent's locked state prevents an
	// operation (including locking an agent that is already locked).
	errLocked = errors.New("agent is locked")
)

const (
	// expFormat is the time layout used to render key expiry times.
	expFormat = "Mon Jan 2 15:04:05 MST 2006"
)
// Traygent extends x/crypto/ssh/agent to hook into fyne for various tasks:
// - notifications
// - allowing UI elements to represent keys
//
// It implements agent.Agent (NewTraygent returns it as one) and reports key
// activity over the channels below so a UI can react to it. mu is held while
// keys, passphrase and locked are modified.
type Traygent struct {
// NOTE(review): set to 360 by NewTraygent but not read anywhere in this
// file; unit and purpose unknown — confirm.
expire uint32
// NOTE(review): not used in this file; presumably the socket the agent is
// served on — confirm.
listener net.Listener
// mu guards keys, passphrase and locked.
mu sync.RWMutex
// keys are the currently loaded keys.
keys []privKey
// passphrase is the secret given to Lock; Unlock compares against it.
passphrase []byte
// locked is true between a successful Lock and the matching Unlock.
locked bool
// addChan receives the public key of every key added via Add.
addChan chan ssh.PublicKey
// rmChan receives the SHA256 fingerprint of each removed key, or "all"
// after RemoveAll.
rmChan chan string
// sigReq receives the key Sign has been asked to use; the receiver is
// expected to answer on sigResp.
sigReq chan ssh.PublicKey
// sigResp carries the answer to a sigReq: true allows the signature.
sigResp chan bool
}
// log formats msgFmt with msg and writes the result through the standard
// logger. title labels the event (e.g. "Key added") for callers but is not
// currently used by this implementation.
func (t *Traygent) log(title, msgFmt string, msg ...any) {
	line := fmt.Sprintf(msgFmt, msg...)
	log.Println(line)
}
// remove deletes every entry in t.keys whose public key equals key. Each
// removal is logged with a human-readable reason and announced on rmChan
// with the key's SHA256 fingerprint.
//
// reason must be "expired" or "request"; any other value is a programming
// error and is reported as an error instead of terminating the process.
// It returns an error if no matching key was loaded.
//
// remove does not lock t.mu itself; its callers in this file hold it for
// writing.
func (t *Traygent) remove(key ssh.PublicKey, reason string) error {
	var strReason string
	switch reason {
	case "expired":
		strReason = "key expired"
	case "request":
		strReason = "user requested key be removed"
	default:
		return fmt.Errorf("unknown removal reason: %q", reason)
	}

	want := key.Marshal()
	fp := ssh.FingerprintSHA256(key)
	found := false
	for i := 0; i < len(t.keys); {
		if !bytes.Equal(t.keys[i].signer.PublicKey().Marshal(), want) {
			i++
			continue
		}
		found = true
		// Swap-remove: move the last entry into slot i and shrink. Order of
		// t.keys is not preserved, and slot i is re-checked on the next pass.
		last := len(t.keys) - 1
		t.keys[i] = t.keys[last]
		t.keys = t.keys[:last]
		t.log("Key removed", "removed key: %q (%s)", fp, strReason)
		// Send asynchronously so the caller, which holds t.mu, never blocks
		// on a slow or absent receiver.
		go func() { t.rmChan <- fp }()
	}
	if !found {
		return errors.New("key not found")
	}
	return nil
}
// RemoveLocked removes every key whose expiry time has passed. Keys with no
// expiry (nil expire) are kept.
//
// Despite its name it is unrelated to the agent's locked state. It acquires
// t.mu for writing itself, so callers must not already hold it.
func (t *Traygent) RemoveLocked() {
	t.mu.Lock()
	defer t.mu.Unlock()

	// Collect first, then remove: remove() swap-deletes from t.keys, so
	// mutating the slice while ranging over it is fragile.
	now := time.Now()
	var expired []ssh.PublicKey
	for _, k := range t.keys {
		if k.expire != nil && now.After(*k.expire) {
			expired = append(expired, k.signer.PublicKey())
		}
	}
	for _, pub := range expired {
		// remove deletes every copy of a key, so for duplicate entries a later
		// call reports "key not found"; that is expected and safe to ignore.
		_ = t.remove(pub, "expired")
	}
}
// List returns the public half of every loaded key. Each comment is
// suffixed with the key's expiry time formatted with expFormat when the key
// has one. A locked agent lists no keys (nil, nil), matching x/crypto's
// in-memory keyring.
func (t *Traygent) List() ([]*agent.Key, error) {
	t.mu.RLock()
	defer t.mu.RUnlock()

	if t.locked {
		return nil, nil
	}

	var pubKeys []*agent.Key
	for _, k := range t.keys {
		comment := k.comment
		// expire is a pointer and may be nil (RemoveLocked treats nil as
		// "never expires"); formatting it unconditionally would panic.
		if k.expire != nil {
			comment = fmt.Sprintf("%s [%s]", k.comment, k.expire.Format(expFormat))
		}
		pubKeys = append(pubKeys, &agent.Key{
			Blob:    k.pubKey.Marshal(),
			Comment: comment,
			Format:  k.pubKey.Type(),
		})
	}
	return pubKeys, nil
}
// Lock locks the agent with passphrase; it must later be passed to Unlock.
// It returns errLocked if the agent is already locked.
//
// The locked check happens under t.mu so concurrent Lock/Unlock calls
// cannot race on t.locked.
func (t *Traygent) Lock(passphrase []byte) error {
	t.mu.Lock()
	defer t.mu.Unlock()

	if t.locked {
		return errLocked
	}
	t.passphrase = passphrase
	t.locked = true
	// Log only once the agent is actually locked.
	t.log("Agent locked", "locking agent")
	return nil
}
// Unlock undoes a previous Lock when passphrase matches the one given to
// Lock. It returns an error if the agent is not locked or if the passphrase
// is wrong. In the latter case the agent stays locked.
//
// The comparison is constant-time so the stored passphrase cannot be probed
// via timing.
func (t *Traygent) Unlock(passphrase []byte) error {
	log.Println("unlocking agent")

	t.mu.Lock()
	defer t.mu.Unlock()

	// The original condition was inverted, making a locked agent impossible
	// to unlock.
	if !t.locked {
		return errors.New("not locked")
	}
	if subtle.ConstantTimeCompare(passphrase, t.passphrase) != 1 {
		return errors.New("incorrect passphrase")
	}
	t.passphrase = nil
	t.locked = false
	return nil
}
// Sign asks for confirmation before signing data with key. It sends key on
// sigReq and waits for a reply on sigResp. A true reply signs via
// SignWithFlags with no flags; a false reply yields a "not allowed" error.
//
// A locked agent is rejected up front with errLocked so the user is never
// prompted for a signature that could not be produced anyway.
//
// NOTE(review): concurrent Sign calls share one sigResp channel, so a reply
// could be matched to the wrong request; consider a per-request reply
// channel. Both sends block until something services the channels.
func (t *Traygent) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
	t.mu.RLock()
	locked := t.locked
	t.mu.RUnlock()
	if locked {
		return nil, errLocked
	}

	// Sending synchronously avoids leaking a goroutine per request: Sign has
	// to wait for the reply on sigResp regardless.
	t.sigReq <- key
	if allowed := <-t.sigResp; !allowed {
		return nil, errors.New("not allowed")
	}
	return t.SignWithFlags(key, data, 0)
}
// SignWithFlags signs data with the loaded key matching key.
//
// With flags == 0 the key's default algorithm is used. With
// SignatureFlagRsaSha256/512, the signer must implement ssh.AlgorithmSigner
// and the matching rsa-sha2 algorithm is requested. Any other flag value is
// rejected. Expired keys are pruned first. It returns errLocked while locked
// and a "not found" error if no loaded key matches.
//
// NOTE(review): this method performs no user confirmation. If Traygent also
// implements agent.ExtendedAgent, ServeAgent calls SignWithFlags directly
// and Sign's confirmation is bypassed — confirm that is intended.
func (t *Traygent) SignWithFlags(key ssh.PublicKey, data []byte, flags agent.SignatureFlags) (*ssh.Signature, error) {
	// RemoveLocked acquires t.mu itself, so it must run before we lock.
	t.RemoveLocked()

	t.mu.Lock()
	defer t.mu.Unlock()

	// Checked under the lock; reading t.locked unguarded was a data race.
	if t.locked {
		return nil, errLocked
	}

	want := key.Marshal()
	for _, k := range t.keys {
		if !bytes.Equal(k.signer.PublicKey().Marshal(), want) {
			continue
		}
		if flags == 0 {
			return k.signer.Sign(rand.Reader, data)
		}

		algSigner, ok := k.signer.(ssh.AlgorithmSigner)
		if !ok {
			return nil, fmt.Errorf("%T is not supported", k.signer)
		}
		var alg string
		switch flags {
		case agent.SignatureFlagRsaSha256:
			alg = ssh.KeyAlgoRSASHA256
		case agent.SignatureFlagRsaSha512:
			alg = ssh.KeyAlgoRSASHA512
		default:
			return nil, fmt.Errorf("unsupported signature flags: %d", flags)
		}
		return algSigner.SignWithAlgorithm(rand.Reader, data, alg)
	}
	return nil, errors.New("not found")
}
// Signers returns a signer for every loaded, unexpired key. Expired keys are
// pruned first. It returns errLocked while the agent is locked.
func (t *Traygent) Signers() ([]ssh.Signer, error) {
	log.Println("signers")

	// RemoveLocked acquires t.mu itself, so it must run before we lock.
	t.RemoveLocked()

	t.mu.RLock()
	defer t.mu.RUnlock()

	// Checked under the lock; reading t.locked unguarded was a data race.
	if t.locked {
		return nil, errLocked
	}

	signers := make([]ssh.Signer, 0, len(t.keys))
	for _, k := range t.keys {
		signers = append(signers, k.signer)
	}
	return signers, nil
}
// getMaxes scans the loaded keys and returns, in order:
// - the longest key type,
// - the longest checksum,
// - the longest comment,
// - the expiry time layout expFormat.
//
// Length is measured in bytes, and on ties the first entry seen wins. It
// holds t.mu for reading.
func (t *Traygent) getMaxes() (string, string, string, string) {
	var longestType, longestSum, longestComment string

	t.mu.RLock()
	defer t.mu.RUnlock()

	for _, k := range t.keys {
		if typ := k.GetType(); len(typ) > len(longestType) {
			longestType = typ
		}
		if sum := k.GetSum(); len(sum) > len(longestSum) {
			longestSum = sum
		}
		if c := k.GetComment(); len(c) > len(longestComment) {
			longestComment = c
		}
	}
	return longestType, longestSum, longestComment, expFormat
}
// Add loads key into the agent. It returns the error from
// ssh.NewSignerFromKey if the private key is unusable, and errLocked while
// the agent is locked.
//
// Re-adding a key whose public key is already loaded replaces the existing
// entry in place rather than storing a duplicate, as x/crypto's in-memory
// keyring does. Only genuinely new keys are announced on addChan, presumably
// because the UI already represents that public key.
func (t *Traygent) Add(key agent.AddedKey) error {
	signer, err := ssh.NewSignerFromKey(key.PrivateKey)
	if err != nil {
		return err
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	if t.locked {
		return errLocked
	}

	p := NewPrivKey(signer, key)
	want := p.signer.PublicKey().Marshal()
	for i := range t.keys {
		if bytes.Equal(t.keys[i].signer.PublicKey().Marshal(), want) {
			t.keys[i] = p
			t.log("Key added", "updated %q in agent", p.fingerPrint)
			return nil
		}
	}

	t.keys = append(t.keys, p)
	t.log("Key added", "added %q to agent", p.fingerPrint)
	// Notify asynchronously so Add never blocks while holding t.mu.
	go func() { t.addChan <- p.pubKey }()
	return nil
}
// RemoveAll drops every loaded key, logs how many were removed and announces
// "all" on rmChan. It returns errLocked while the agent is locked.
func (t *Traygent) RemoveAll() error {
	t.mu.Lock()
	defer t.mu.Unlock()

	// Checked under the lock; reading t.locked unguarded was a data race.
	if t.locked {
		return errLocked
	}

	removed := len(t.keys)
	t.keys = nil
	t.log("All keys removed", "removed %d keys from agent", removed)
	// Notify asynchronously so RemoveAll never blocks while holding t.mu.
	go func() { t.rmChan <- "all" }()
	return nil
}
// Remove unloads every entry matching key at the user's request. It returns
// errLocked while the agent is locked, or an error if the key was not
// loaded.
func (t *Traygent) Remove(key ssh.PublicKey) error {
	t.mu.Lock()
	defer t.mu.Unlock()

	// Checked under the lock; reading t.locked unguarded was a data race.
	if t.locked {
		return errLocked
	}
	if err := t.remove(key, "request"); err != nil {
		return err
	}
	// Only log once something was actually removed.
	t.log("Key removed", "remove key from agent")
	return nil
}
// NewTraygent returns an empty, unlocked agent with all notification
// channels created. The channels are unbuffered, so something must service
// them:
// - Sign blocks on sigReq/sigResp until a reply arrives; these were
//   previously left nil, which made Sign block forever.
// - Add and the remove paths send on addChan/rmChan from goroutines.
//
// NOTE(review): expire is set to 360 but not read in this file; unit
// unknown — confirm.
func NewTraygent() agent.Agent {
	return &Traygent{
		expire:  360,
		addChan: make(chan ssh.PublicKey),
		rmChan:  make(chan string),
		sigReq:  make(chan ssh.PublicKey),
		sigResp: make(chan bool),
	}
}