diff --git a/src/pkg/exp/ssh/client_auth.go b/src/pkg/exp/ssh/client_auth.go index 25f9e216225..1a382357b46 100644 --- a/src/pkg/exp/ssh/client_auth.go +++ b/src/pkg/exp/ssh/client_auth.go @@ -9,7 +9,7 @@ import ( "io" ) -// authenticate authenticates with the remote server. See RFC 4252. +// authenticate authenticates with the remote server. See RFC 4252. func (c *ClientConn) authenticate(session []byte) error { // initiate user auth session if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil { @@ -24,7 +24,7 @@ func (c *ClientConn) authenticate(session []byte) error { return err } // during the authentication phase the client first attempts the "none" method - // then any untried methods suggested by the server. + // then any untried methods suggested by the server. tried, remain := make(map[string]bool), make(map[string]bool) for auth := ClientAuth(new(noneAuth)); auth != nil; { ok, methods, err := auth.auth(session, c.config.User, c.transport, c.config.rand()) @@ -57,9 +57,9 @@ func (c *ClientConn) authenticate(session []byte) error { // A ClientAuth represents an instance of an RFC 4252 authentication method. type ClientAuth interface { - // auth authenticates user over transport t. + // auth authenticates user over transport t. // Returns true if authentication is successful. - // If authentication is not successful, a []string of alternative + // If authentication is not successful, a []string of alternative // method names is returned. auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) @@ -79,19 +79,7 @@ func (n *noneAuth) auth(session []byte, user string, t *transport, rand io.Reade return false, nil, err } - packet, err := t.readPacket() - if err != nil { - return false, nil, err - } - - switch packet[0] { - case msgUserAuthSuccess: - return true, nil, nil - case msgUserAuthFailure: - msg := decode(packet).(*userAuthFailureMsg) - return false, msg.Methods, nil - } - return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]} + return handleAuthResponse(t) } func (n *noneAuth) method() string { @@ -127,19 +115,7 @@ func (p *passwordAuth) auth(session []byte, user string, t *transport, rand io.R return false, nil, err } - packet, err := t.readPacket() - if err != nil { - return false, nil, err - } - - switch packet[0] { - case msgUserAuthSuccess: - return true, nil, nil - case msgUserAuthFailure: - msg := decode(packet).(*userAuthFailureMsg) - return false, msg.Methods, nil - } - return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]} + return handleAuthResponse(t) } func (p *passwordAuth) method() string { @@ -159,7 +135,7 @@ func ClientAuthPassword(impl ClientPassword) ClientAuth { // ClientKeyring implements access to a client key ring. type ClientKeyring interface { - // Key returns the i'th rsa.Publickey or dsa.Publickey, or nil if + // Key returns the i'th rsa.Publickey or dsa.Publickey, or nil if // no key exists at i. Key(i int) (key interface{}, err error) @@ -173,27 +149,28 @@ type publickeyAuth struct { ClientKeyring } +type publickeyAuthMsg struct { + User string + Service string + Method string + // HasSig indicates to the reciver packet that the auth request is signed and + // should be used for authentication of the request. + HasSig bool + Algoname string + Pubkey string + // Sig is defined as []byte so marshal will exclude it during validateKey + Sig []byte `ssh:"rest"` +} + func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) { - type publickeyAuthMsg struct { - User string - Service string - Method string - // HasSig indicates to the reciver packet that the auth request is signed and - // should be used for authentication of the request. - HasSig bool - Algoname string - Pubkey string - // Sig is defined as []byte so marshal will exclude it during the query phase - Sig []byte `ssh:"rest"` - } // Authentication is performed in two stages. The first stage sends an // enquiry to test if each key is acceptable to the remote. The second - // stage attempts to authenticate with the valid keys obtained in the + // stage attempts to authenticate with the valid keys obtained in the // first stage. var index int - // a map of public keys to their index in the keyring + // a map of public keys to their index in the keyring validKeys := make(map[int]interface{}) for { key, err := p.Key(index) @@ -204,33 +181,13 @@ func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io. // no more keys in the keyring break } - pubkey := serializePublickey(key) - algoname := algoName(key) - msg := publickeyAuthMsg{ - User: user, - Service: serviceSSH, - Method: p.method(), - HasSig: false, - Algoname: algoname, - Pubkey: string(pubkey), - } - if err := t.writePacket(marshal(msgUserAuthRequest, msg)); err != nil { - return false, nil, err - } - packet, err := t.readPacket() - if err != nil { - return false, nil, err - } - switch packet[0] { - case msgUserAuthPubKeyOk: - msg := decode(packet).(*userAuthPubKeyOkMsg) - if msg.Algo != algoname || msg.PubKey != string(pubkey) { - continue - } + + if ok, err := p.validateKey(key, user, t); ok { validKeys[index] = key - case msgUserAuthFailure: - default: - return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]} + } else { + if err != nil { + return false, nil, err + } } index++ } @@ -265,26 +222,63 @@ func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io. if err := t.writePacket(p); err != nil { return false, nil, err } - packet, err := t.readPacket() + success, methods, err := handleAuthResponse(t) if err != nil { return false, nil, err } - switch packet[0] { - case msgUserAuthSuccess: - return true, nil, nil - case msgUserAuthFailure: - msg := decode(packet).(*userAuthFailureMsg) - methods = msg.Methods - continue - case msgDisconnect: - return false, nil, io.EOF - default: - return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]} + if success { + return success, methods, err } } return false, methods, nil } +// validateKey validates the key provided it is acceptable to the server. +func (p *publickeyAuth) validateKey(key interface{}, user string, t *transport) (bool, error) { + pubkey := serializePublickey(key) + algoname := algoName(key) + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: p.method(), + HasSig: false, + Algoname: algoname, + Pubkey: string(pubkey), + } + if err := t.writePacket(marshal(msgUserAuthRequest, msg)); err != nil { + return false, err + } + + return p.confirmKeyAck(key, t) +} + +func (p *publickeyAuth) confirmKeyAck(key interface{}, t *transport) (bool, error) { + pubkey := serializePublickey(key) + algoname := algoName(key) + + for { + packet, err := t.readPacket() + if err != nil { + return false, err + } + switch packet[0] { + case msgUserAuthBanner: + // TODO(gpaul): add callback to present the banner to the user + case msgUserAuthPubKeyOk: + msg := decode(packet).(*userAuthPubKeyOkMsg) + if msg.Algo != algoname || msg.PubKey != string(pubkey) { + return false, nil + } + return true, nil + case msgUserAuthFailure: + return false, nil + default: + return false, UnexpectedMessageError{msgUserAuthSuccess, packet[0]} + } + } + panic("unreachable") +} + func (p *publickeyAuth) method() string { return "publickey" } @@ -293,3 +287,30 @@ func (p *publickeyAuth) method() string { func ClientAuthPublickey(impl ClientKeyring) ClientAuth { return &publickeyAuth{impl} } + +// handleAuthResponse returns whether the preceding authentication request succeeded +// along with a list of remaining authentication methods to try next and +// an error if an unexpected response was received. +func handleAuthResponse(t *transport) (bool, []string, error) { + for { + packet, err := t.readPacket() + if err != nil { + return false, nil, err + } + + switch packet[0] { + case msgUserAuthBanner: + // TODO: add callback to present the banner to the user + case msgUserAuthFailure: + msg := decode(packet).(*userAuthFailureMsg) + return false, msg.Methods, nil + case msgUserAuthSuccess: + return true, nil, nil + case msgDisconnect: + return false, nil, io.EOF + default: + return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]} + } + } + panic("unreachable") +}