This commit is contained in:
Manfred Touron
2017-05-18 18:54:23 +02:00
parent dc386661ca
commit 5448f25fd6
645 changed files with 55908 additions and 33297 deletions

View File

@@ -15,6 +15,7 @@ package acme
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
@@ -36,9 +37,6 @@ import (
"strings"
"sync"
"time"
"golang.org/x/net/context"
"golang.org/x/net/context/ctxhttp"
)
// LetsEncryptURL is the Directory endpoint of Let's Encrypt CA.
@@ -133,7 +131,7 @@ func (c *Client) Discover(ctx context.Context) (Directory, error) {
if dirURL == "" {
dirURL = LetsEncryptURL
}
res, err := ctxhttp.Get(ctx, c.HTTPClient, dirURL)
res, err := c.get(ctx, dirURL)
if err != nil {
return Directory{}, err
}
@@ -154,7 +152,7 @@ func (c *Client) Discover(ctx context.Context) (Directory, error) {
CAA []string `json:"caa-identities"`
}
}
if json.NewDecoder(res.Body).Decode(&v); err != nil {
if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
return Directory{}, err
}
c.dir = &Directory{
@@ -200,7 +198,7 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
req.NotAfter = now.Add(exp).Format(time.RFC3339)
}
res, err := c.postJWS(ctx, c.Key, c.dir.CertURL, req)
res, err := c.retryPostJWS(ctx, c.Key, c.dir.CertURL, req)
if err != nil {
return nil, "", err
}
@@ -216,7 +214,7 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
return cert, curl, err
}
// slurp issued cert and CA chain, if requested
cert, err := responseCert(ctx, c.HTTPClient, res, bundle)
cert, err := c.responseCert(ctx, res, bundle)
return cert, curl, err
}
@@ -231,13 +229,13 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
// and has expected features.
func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) {
for {
res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
res, err := c.get(ctx, url)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode == http.StatusOK {
return responseCert(ctx, c.HTTPClient, res, bundle)
return c.responseCert(ctx, res, bundle)
}
if res.StatusCode > 299 {
return nil, responseError(res)
@@ -275,7 +273,7 @@ func (c *Client) RevokeCert(ctx context.Context, key crypto.Signer, cert []byte,
if key == nil {
key = c.Key
}
res, err := c.postJWS(ctx, key, c.dir.RevokeURL, body)
res, err := c.retryPostJWS(ctx, key, c.dir.RevokeURL, body)
if err != nil {
return err
}
@@ -363,7 +361,7 @@ func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization,
Resource: "new-authz",
Identifier: authzID{Type: "dns", Value: domain},
}
res, err := c.postJWS(ctx, c.Key, c.dir.AuthzURL, req)
res, err := c.retryPostJWS(ctx, c.Key, c.dir.AuthzURL, req)
if err != nil {
return nil, err
}
@@ -387,7 +385,7 @@ func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization,
// If a caller needs to poll an authorization until its status is final,
// see the WaitAuthorization method.
func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) {
res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
res, err := c.get(ctx, url)
if err != nil {
return nil, err
}
@@ -421,7 +419,7 @@ func (c *Client) RevokeAuthorization(ctx context.Context, url string) error {
Status: "deactivated",
Delete: true,
}
res, err := c.postJWS(ctx, c.Key, url, req)
res, err := c.retryPostJWS(ctx, c.Key, url, req)
if err != nil {
return err
}
@@ -438,25 +436,11 @@ func (c *Client) RevokeAuthorization(ctx context.Context, url string) error {
//
// It returns a non-nil Authorization only if its Status is StatusValid.
// In all other cases WaitAuthorization returns an error.
// If the Status is StatusInvalid, the returned error is ErrAuthorizationFailed.
// If the Status is StatusInvalid, the returned error is of type *AuthorizationError.
func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorization, error) {
var count int
sleep := func(v string, inc int) error {
count += inc
d := backoff(count, 10*time.Second)
d = retryAfter(v, d)
wakeup := time.NewTimer(d)
defer wakeup.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-wakeup.C:
return nil
}
}
sleep := sleeper(ctx)
for {
res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
res, err := c.get(ctx, url)
if err != nil {
return nil, err
}
@@ -481,7 +465,7 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat
return raw.authorization(url), nil
}
if raw.Status == StatusInvalid {
return nil, ErrAuthorizationFailed
return nil, raw.error(url)
}
if err := sleep(retry, 0); err != nil {
return nil, err
@@ -493,7 +477,7 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat
//
// A client typically polls a challenge status using this method.
func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) {
res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
res, err := c.get(ctx, url)
if err != nil {
return nil, err
}
@@ -527,7 +511,7 @@ func (c *Client) Accept(ctx context.Context, chal *Challenge) (*Challenge, error
Type: chal.Type,
Auth: auth,
}
res, err := c.postJWS(ctx, c.Key, chal.URI, req)
res, err := c.retryPostJWS(ctx, c.Key, chal.URI, req)
if err != nil {
return nil, err
}
@@ -660,7 +644,7 @@ func (c *Client) doReg(ctx context.Context, url string, typ string, acct *Accoun
req.Contact = acct.Contact
req.Agreement = acct.AgreedTerms
}
res, err := c.postJWS(ctx, c.Key, url, req)
res, err := c.retryPostJWS(ctx, c.Key, url, req)
if err != nil {
return nil, err
}
@@ -697,6 +681,40 @@ func (c *Client) doReg(ctx context.Context, url string, typ string, acct *Accoun
}, nil
}
// retryPostJWS will retry calls to postJWS if there is a badNonce error,
// clearing the stored nonces after each error.
// If the response was 4XX-5XX, then responseError is called on the body,
// the body is closed, and the error returned.
func (c *Client) retryPostJWS(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, error) {
sleep := sleeper(ctx)
for {
res, err := c.postJWS(ctx, key, url, body)
if err != nil {
return nil, err
}
// handle errors 4XX-5XX with responseError
if res.StatusCode >= 400 && res.StatusCode <= 599 {
err := responseError(res)
res.Body.Close()
// according to spec badNonce is urn:ietf:params:acme:error:badNonce
// however, acme servers in the wild return their version of the error
// https://tools.ietf.org/html/draft-ietf-acme-acme-02#section-5.4
if ae, ok := err.(*Error); ok && strings.HasSuffix(strings.ToLower(ae.ProblemType), ":badnonce") {
// clear any nonces that we might've stored that might now be
// considered bad
c.clearNonces()
retry := res.Header.Get("retry-after")
if err := sleep(retry, 1); err != nil {
return nil, err
}
continue
}
return nil, err
}
return res, nil
}
}
// postJWS signs the body with the given key and POSTs it to the provided url.
// The body argument must be JSON-serializable.
func (c *Client) postJWS(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, error) {
@@ -708,7 +726,7 @@ func (c *Client) postJWS(ctx context.Context, key crypto.Signer, url string, bod
if err != nil {
return nil, err
}
res, err := ctxhttp.Post(ctx, c.HTTPClient, url, "application/jose+json", bytes.NewReader(b))
res, err := c.post(ctx, url, "application/jose+json", bytes.NewReader(b))
if err != nil {
return nil, err
}
@@ -722,7 +740,7 @@ func (c *Client) popNonce(ctx context.Context, url string) (string, error) {
c.noncesMu.Lock()
defer c.noncesMu.Unlock()
if len(c.nonces) == 0 {
return fetchNonce(ctx, c.HTTPClient, url)
return c.fetchNonce(ctx, url)
}
var nonce string
for nonce = range c.nonces {
@@ -732,6 +750,13 @@ func (c *Client) popNonce(ctx context.Context, url string) (string, error) {
return nonce, nil
}
// clearNonces clears any stored nonces
func (c *Client) clearNonces() {
c.noncesMu.Lock()
defer c.noncesMu.Unlock()
c.nonces = make(map[string]struct{})
}
// addNonce stores a nonce value found in h (if any) for future use.
func (c *Client) addNonce(h http.Header) {
v := nonceFromHeader(h)
@@ -749,8 +774,58 @@ func (c *Client) addNonce(h http.Header) {
c.nonces[v] = struct{}{}
}
func fetchNonce(ctx context.Context, client *http.Client, url string) (string, error) {
resp, err := ctxhttp.Head(ctx, client, url)
func (c *Client) httpClient() *http.Client {
if c.HTTPClient != nil {
return c.HTTPClient
}
return http.DefaultClient
}
func (c *Client) get(ctx context.Context, urlStr string) (*http.Response, error) {
req, err := http.NewRequest("GET", urlStr, nil)
if err != nil {
return nil, err
}
return c.do(ctx, req)
}
func (c *Client) head(ctx context.Context, urlStr string) (*http.Response, error) {
req, err := http.NewRequest("HEAD", urlStr, nil)
if err != nil {
return nil, err
}
return c.do(ctx, req)
}
func (c *Client) post(ctx context.Context, urlStr, contentType string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequest("POST", urlStr, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", contentType)
return c.do(ctx, req)
}
func (c *Client) do(ctx context.Context, req *http.Request) (*http.Response, error) {
res, err := c.httpClient().Do(req.WithContext(ctx))
if err != nil {
select {
case <-ctx.Done():
// Prefer the unadorned context error.
// (The acme package had tests assuming this, previously from ctxhttp's
// behavior, predating net/http supporting contexts natively)
// TODO(bradfitz): reconsider this in the future. But for now this
// requires no test updates.
return nil, ctx.Err()
default:
return nil, err
}
}
return res, nil
}
func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) {
resp, err := c.head(ctx, url)
if err != nil {
return "", err
}
@@ -769,7 +844,7 @@ func nonceFromHeader(h http.Header) string {
return h.Get("Replay-Nonce")
}
func responseCert(ctx context.Context, client *http.Client, res *http.Response, bundle bool) ([][]byte, error) {
func (c *Client) responseCert(ctx context.Context, res *http.Response, bundle bool) ([][]byte, error) {
b, err := ioutil.ReadAll(io.LimitReader(res.Body, maxCertSize+1))
if err != nil {
return nil, fmt.Errorf("acme: response stream: %v", err)
@@ -793,7 +868,7 @@ func responseCert(ctx context.Context, client *http.Client, res *http.Response,
return nil, errors.New("acme: rel=up link is too large")
}
for _, url := range up {
cc, err := chainCert(ctx, client, url, 0)
cc, err := c.chainCert(ctx, url, 0)
if err != nil {
return nil, err
}
@@ -807,14 +882,8 @@ func responseError(resp *http.Response) error {
// don't care if ReadAll returns an error:
// json.Unmarshal will fail in that case anyway
b, _ := ioutil.ReadAll(resp.Body)
e := struct {
Status int
Type string
Detail string
}{
Status: resp.StatusCode,
}
if err := json.Unmarshal(b, &e); err != nil {
e := &wireError{Status: resp.StatusCode}
if err := json.Unmarshal(b, e); err != nil {
// this is not a regular error response:
// populate detail with anything we received,
// e.Status will already contain HTTP response code value
@@ -823,12 +892,7 @@ func responseError(resp *http.Response) error {
e.Detail = resp.Status
}
}
return &Error{
StatusCode: e.Status,
ProblemType: e.Type,
Detail: e.Detail,
Header: resp.Header,
}
return e.error(resp.Header)
}
// chainCert fetches CA certificate chain recursively by following "up" links.
@@ -836,12 +900,12 @@ func responseError(resp *http.Response) error {
// if the recursion level reaches maxChainLen.
//
// First chainCert call starts with depth of 0.
func chainCert(ctx context.Context, client *http.Client, url string, depth int) ([][]byte, error) {
func (c *Client) chainCert(ctx context.Context, url string, depth int) ([][]byte, error) {
if depth >= maxChainLen {
return nil, errors.New("acme: certificate chain is too deep")
}
res, err := ctxhttp.Get(ctx, client, url)
res, err := c.get(ctx, url)
if err != nil {
return nil, err
}
@@ -863,7 +927,7 @@ func chainCert(ctx context.Context, client *http.Client, url string, depth int)
return nil, errors.New("acme: certificate chain is too large")
}
for _, up := range uplink {
cc, err := chainCert(ctx, client, up, depth+1)
cc, err := c.chainCert(ctx, up, depth+1)
if err != nil {
return nil, err
}
@@ -893,6 +957,28 @@ func linkHeader(h http.Header, rel string) []string {
return links
}
// sleeper returns a function that accepts the Retry-After HTTP header value
// and an increment that's used with backoff to increasingly sleep on
// consecutive calls until the context is done. If the Retry-After header
// cannot be parsed, then backoff is used with a maximum sleep time of 10
// seconds.
func sleeper(ctx context.Context) func(ra string, inc int) error {
var count int
return func(ra string, inc int) error {
count += inc
d := backoff(count, 10*time.Second)
d = retryAfter(ra, d)
wakeup := time.NewTimer(d)
defer wakeup.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-wakeup.C:
return nil
}
}
}
// retryAfter parses a Retry-After HTTP header value,
// trying to convert v into an int (seconds) or use http.ParseTime otherwise.
// It returns d if v cannot be parsed.
@@ -974,7 +1060,8 @@ func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) {
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageKeyEncipherment,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
}
tmpl.DNSNames = san

View File

@@ -6,6 +6,7 @@ package acme
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
@@ -23,8 +24,6 @@ import (
"strings"
"testing"
"time"
"golang.org/x/net/context"
)
// Decodes a JWS-encoded request and unmarshals the decoded JSON into a provided
@@ -544,6 +543,9 @@ func TestWaitAuthorizationInvalid(t *testing.T) {
if err == nil {
t.Error("err is nil")
}
if _, ok := err.(*AuthorizationError); !ok {
t.Errorf("err is %T; want *AuthorizationError", err)
}
}
}
@@ -981,7 +983,8 @@ func TestNonce_fetch(t *testing.T) {
defer ts.Close()
for ; i < len(tests); i++ {
test := tests[i]
n, err := fetchNonce(context.Background(), http.DefaultClient, ts.URL)
c := &Client{}
n, err := c.fetchNonce(context.Background(), ts.URL)
if n != test.nonce {
t.Errorf("%d: n=%q; want %q", i, n, test.nonce)
}
@@ -999,7 +1002,8 @@ func TestNonce_fetchError(t *testing.T) {
w.WriteHeader(http.StatusTooManyRequests)
}))
defer ts.Close()
_, err := fetchNonce(context.Background(), http.DefaultClient, ts.URL)
c := &Client{}
_, err := c.fetchNonce(context.Background(), ts.URL)
e, ok := err.(*Error)
if !ok {
t.Fatalf("err is %T; want *Error", err)
@@ -1064,6 +1068,44 @@ func TestNonce_postJWS(t *testing.T) {
}
}
func TestRetryPostJWS(t *testing.T) {
var count int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count++
w.Header().Set("replay-nonce", fmt.Sprintf("nonce%d", count))
if r.Method == "HEAD" {
// We expect the client to do 2 head requests to fetch
// nonces, one to start and another after getting badNonce
return
}
head, err := decodeJWSHead(r)
if err != nil {
t.Errorf("decodeJWSHead: %v", err)
} else if head.Nonce == "" {
t.Error("head.Nonce is empty")
} else if head.Nonce == "nonce1" {
// return a badNonce error to force the call to retry
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"type":"urn:ietf:params:acme:error:badNonce"}`))
return
}
// Make client.Authorize happy; we're not testing its result.
w.WriteHeader(http.StatusCreated)
w.Write([]byte(`{"status":"valid"}`))
}))
defer ts.Close()
client := Client{Key: testKey, dir: &Directory{AuthzURL: ts.URL}}
// This call will fail with badNonce, causing a retry
if _, err := client.Authorize(context.Background(), "example.com"); err != nil {
t.Errorf("client.Authorize 1: %v", err)
}
if count != 4 {
t.Errorf("total requests count: %d; want 4", count)
}
}
func TestLinkHeader(t *testing.T) {
h := http.Header{"Link": {
`<https://example.com/acme/new-authz>;rel="next"`,

View File

@@ -10,6 +10,7 @@ package autocert
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
@@ -30,9 +31,14 @@ import (
"time"
"golang.org/x/crypto/acme"
"golang.org/x/net/context"
)
// createCertRetryAfter is how much time to wait before removing a failed state
// entry due to an unsuccessful createCert call.
// This is a variable instead of a const for testing.
// TODO: Consider making it configurable or an exp backoff?
var createCertRetryAfter = time.Minute
// pseudoRand is safe for concurrent use.
var pseudoRand *lockedMathRand
@@ -41,8 +47,9 @@ func init() {
pseudoRand = &lockedMathRand{rnd: mathrand.New(src)}
}
// AcceptTOS always returns true to indicate the acceptance of a CA Terms of Service
// during account registration.
// AcceptTOS is a Manager.Prompt function that always returns true to
// indicate acceptance of the CA's Terms of Service during account
// registration.
func AcceptTOS(tosURL string) bool { return true }
// HostPolicy specifies which host names the Manager is allowed to respond to.
@@ -76,18 +83,6 @@ func defaultHostPolicy(context.Context, string) error {
// It obtains and refreshes certificates automatically,
// as well as providing them to a TLS server via tls.Config.
//
// A simple usage example:
//
// m := autocert.Manager{
// Prompt: autocert.AcceptTOS,
// HostPolicy: autocert.HostWhitelist("example.org"),
// }
// s := &http.Server{
// Addr: ":https",
// TLSConfig: &tls.Config{GetCertificate: m.GetCertificate},
// }
// s.ListenAndServeTLS("", "")
//
// To preserve issued certificates and improve overall performance,
// use a cache implementation of Cache. For instance, DirCache.
type Manager struct {
@@ -123,7 +118,7 @@ type Manager struct {
// RenewBefore optionally specifies how early certificates should
// be renewed before they expire.
//
// If zero, they're renewed 1 week before expiration.
// If zero, they're renewed 30 days before expiration.
RenewBefore time.Duration
// Client is used to perform low-level operations, such as account registration
@@ -173,10 +168,23 @@ type Manager struct {
// The error is propagated back to the caller of GetCertificate and is user-visible.
// This does not affect cached certs. See HostPolicy field description for more details.
func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
if m.Prompt == nil {
return nil, errors.New("acme/autocert: Manager.Prompt not set")
}
name := hello.ServerName
if name == "" {
return nil, errors.New("acme/autocert: missing server name")
}
if !strings.Contains(strings.Trim(name, "."), ".") {
return nil, errors.New("acme/autocert: server name component count invalid")
}
if strings.ContainsAny(name, `/\`) {
return nil, errors.New("acme/autocert: server name contains invalid character")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// check whether this is a token cert requested for TLS-SNI challenge
if strings.HasSuffix(name, ".acme.invalid") {
@@ -185,7 +193,7 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
if cert := m.tokenCert[name]; cert != nil {
return cert, nil
}
if cert, err := m.cacheGet(name); err == nil {
if cert, err := m.cacheGet(ctx, name); err == nil {
return cert, nil
}
// TODO: cache error results?
@@ -194,7 +202,7 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
// regular domain
name = strings.TrimSuffix(name, ".") // golang.org/issue/18114
cert, err := m.cert(name)
cert, err := m.cert(ctx, name)
if err == nil {
return cert, nil
}
@@ -203,7 +211,6 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
}
// first-time
ctx := context.Background() // TODO: use a deadline?
if err := m.hostPolicy()(ctx, name); err != nil {
return nil, err
}
@@ -211,14 +218,14 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
if err != nil {
return nil, err
}
m.cachePut(name, cert)
m.cachePut(ctx, name, cert)
return cert, nil
}
// cert returns an existing certificate either from m.state or cache.
// If a certificate is found in cache but not in m.state, the latter will be filled
// with the cached value.
func (m *Manager) cert(name string) (*tls.Certificate, error) {
func (m *Manager) cert(ctx context.Context, name string) (*tls.Certificate, error) {
m.stateMu.Lock()
if s, ok := m.state[name]; ok {
m.stateMu.Unlock()
@@ -227,7 +234,7 @@ func (m *Manager) cert(name string) (*tls.Certificate, error) {
return s.tlscert()
}
defer m.stateMu.Unlock()
cert, err := m.cacheGet(name)
cert, err := m.cacheGet(ctx, name)
if err != nil {
return nil, err
}
@@ -249,12 +256,11 @@ func (m *Manager) cert(name string) (*tls.Certificate, error) {
}
// cacheGet always returns a valid certificate, or an error otherwise.
func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) {
// If a cached certficate exists but is not valid, ErrCacheMiss is returned.
func (m *Manager) cacheGet(ctx context.Context, domain string) (*tls.Certificate, error) {
if m.Cache == nil {
return nil, ErrCacheMiss
}
// TODO: might want to define a cache timeout on m
ctx := context.Background()
data, err := m.Cache.Get(ctx, domain)
if err != nil {
return nil, err
@@ -263,7 +269,7 @@ func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) {
// private
priv, pub := pem.Decode(data)
if priv == nil || !strings.Contains(priv.Type, "PRIVATE") {
return nil, errors.New("acme/autocert: no private key found in cache")
return nil, ErrCacheMiss
}
privKey, err := parsePrivateKey(priv.Bytes)
if err != nil {
@@ -281,13 +287,14 @@ func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) {
pubDER = append(pubDER, b.Bytes)
}
if len(pub) > 0 {
return nil, errors.New("acme/autocert: invalid public key")
// Leftover content not consumed by pem.Decode. Corrupt. Ignore.
return nil, ErrCacheMiss
}
// verify and create TLS cert
leaf, err := validCert(domain, pubDER, privKey)
if err != nil {
return nil, err
return nil, ErrCacheMiss
}
tlscert := &tls.Certificate{
Certificate: pubDER,
@@ -297,7 +304,7 @@ func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) {
return tlscert, nil
}
func (m *Manager) cachePut(domain string, tlscert *tls.Certificate) error {
func (m *Manager) cachePut(ctx context.Context, domain string, tlscert *tls.Certificate) error {
if m.Cache == nil {
return nil
}
@@ -329,8 +336,6 @@ func (m *Manager) cachePut(domain string, tlscert *tls.Certificate) error {
}
}
// TODO: might want to define a cache timeout on m
ctx := context.Background()
return m.Cache.Put(ctx, domain, buf.Bytes())
}
@@ -370,6 +375,23 @@ func (m *Manager) createCert(ctx context.Context, domain string) (*tls.Certifica
der, leaf, err := m.authorizedCert(ctx, state.key, domain)
if err != nil {
// Remove the failed state after some time,
// making the manager call createCert again on the following TLS hello.
time.AfterFunc(createCertRetryAfter, func() {
defer testDidRemoveState(domain)
m.stateMu.Lock()
defer m.stateMu.Unlock()
// Verify the state hasn't changed and it's still invalid
// before deleting.
s, ok := m.state[domain]
if !ok {
return
}
if _, err := validCert(domain, s.cert, s.key); err == nil {
return
}
delete(m.state, domain)
})
return nil, err
}
state.cert = der
@@ -418,7 +440,6 @@ func (m *Manager) certState(domain string) (*certState, error) {
// authorizedCert starts domain ownership verification process and requests a new cert upon success.
// The key argument is the certificate private key.
func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, domain string) (der [][]byte, leaf *x509.Certificate, err error) {
// TODO: make m.verify retry or retry m.verify calls here
if err := m.verify(ctx, domain); err != nil {
return nil, nil, err
}
@@ -494,7 +515,7 @@ func (m *Manager) verify(ctx context.Context, domain string) error {
if err != nil {
return err
}
m.putTokenCert(name, &cert)
m.putTokenCert(ctx, name, &cert)
defer func() {
// verification has ended at this point
// don't need token cert anymore
@@ -512,14 +533,14 @@ func (m *Manager) verify(ctx context.Context, domain string) error {
// putTokenCert stores the cert under the named key in both m.tokenCert map
// and m.Cache.
func (m *Manager) putTokenCert(name string, cert *tls.Certificate) {
func (m *Manager) putTokenCert(ctx context.Context, name string, cert *tls.Certificate) {
m.tokenCertMu.Lock()
defer m.tokenCertMu.Unlock()
if m.tokenCert == nil {
m.tokenCert = make(map[string]*tls.Certificate)
}
m.tokenCert[name] = cert
m.cachePut(name, cert)
m.cachePut(ctx, name, cert)
}
// deleteTokenCert removes the token certificate for the specified domain name
@@ -644,10 +665,10 @@ func (m *Manager) hostPolicy() HostPolicy {
}
func (m *Manager) renewBefore() time.Duration {
if m.RenewBefore > maxRandRenew {
if m.RenewBefore > renewJitter {
return m.RenewBefore
}
return 7 * 24 * time.Hour // 1 week
return 720 * time.Hour // 30 days
}
// certState is ready when its mutex is unlocked for reading.
@@ -789,5 +810,10 @@ func (r *lockedMathRand) int63n(max int64) int64 {
return n
}
// for easier testing
var timeNow = time.Now
// For easier testing.
var (
timeNow = time.Now
// Called when a state is removed.
testDidRemoveState = func(domain string) {}
)

View File

@@ -5,6 +5,7 @@
package autocert
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
@@ -27,7 +28,6 @@ import (
"time"
"golang.org/x/crypto/acme"
"golang.org/x/net/context"
)
var discoTmpl = template.Must(template.New("disco").Parse(`{
@@ -150,7 +150,7 @@ func TestGetCertificate_ForceRSA(t *testing.T) {
hello := &tls.ClientHelloInfo{ServerName: "example.org"}
testGetCertificate(t, man, "example.org", hello)
cert, err := man.cacheGet("example.org")
cert, err := man.cacheGet(context.Background(), "example.org")
if err != nil {
t.Fatalf("man.cacheGet: %v", err)
}
@@ -159,9 +159,110 @@ func TestGetCertificate_ForceRSA(t *testing.T) {
}
}
// tests man.GetCertificate flow using the provided hello argument.
func TestGetCertificate_nilPrompt(t *testing.T) {
man := &Manager{}
defer man.stopRenew()
url, finish := startACMEServerStub(t, man, "example.org")
defer finish()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
man.Client = &acme.Client{
Key: key,
DirectoryURL: url,
}
hello := &tls.ClientHelloInfo{ServerName: "example.org"}
if _, err := man.GetCertificate(hello); err == nil {
t.Error("got certificate for example.org; wanted error")
}
}
func TestGetCertificate_expiredCache(t *testing.T) {
// Make an expired cert and cache it.
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "example.org"},
NotAfter: time.Now(),
}
pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &pk.PublicKey, pk)
if err != nil {
t.Fatal(err)
}
tlscert := &tls.Certificate{
Certificate: [][]byte{pub},
PrivateKey: pk,
}
man := &Manager{Prompt: AcceptTOS, Cache: newMemCache()}
defer man.stopRenew()
if err := man.cachePut(context.Background(), "example.org", tlscert); err != nil {
t.Fatalf("man.cachePut: %v", err)
}
// The expired cached cert should trigger a new cert issuance
// and return without an error.
hello := &tls.ClientHelloInfo{ServerName: "example.org"}
testGetCertificate(t, man, "example.org", hello)
}
func TestGetCertificate_failedAttempt(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
}))
defer ts.Close()
const example = "example.org"
d := createCertRetryAfter
f := testDidRemoveState
defer func() {
createCertRetryAfter = d
testDidRemoveState = f
}()
createCertRetryAfter = 0
done := make(chan struct{})
testDidRemoveState = func(domain string) {
if domain != example {
t.Errorf("testDidRemoveState: domain = %q; want %q", domain, example)
}
close(done)
}
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
man := &Manager{
Prompt: AcceptTOS,
Client: &acme.Client{
Key: key,
DirectoryURL: ts.URL,
},
}
defer man.stopRenew()
hello := &tls.ClientHelloInfo{ServerName: example}
if _, err := man.GetCertificate(hello); err == nil {
t.Error("GetCertificate: err is nil")
}
select {
case <-time.After(5 * time.Second):
t.Errorf("took too long to remove the %q state", example)
case <-done:
man.stateMu.Lock()
defer man.stateMu.Unlock()
if v, exist := man.state[example]; exist {
t.Errorf("state exists for %q: %+v", example, v)
}
}
}
// startACMEServerStub runs an ACME server
// The domain argument is the expected domain name of a certificate request.
func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.ClientHelloInfo) {
func startACMEServerStub(t *testing.T, man *Manager, domain string) (url string, finish func()) {
// echo token-02 | shasum -a 256
// then divide result in 2 parts separated by dot
tokenCertName := "4e8eb87631187e9ff2153b56b13a4dec.13a35d002e485d60ff37354b32f665d9.token.acme.invalid"
@@ -187,7 +288,7 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
// discovery
case "/":
if err := discoTmpl.Execute(w, ca.URL); err != nil {
t.Fatalf("discoTmpl: %v", err)
t.Errorf("discoTmpl: %v", err)
}
// client key registration
case "/new-reg":
@@ -197,7 +298,7 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
w.Header().Set("location", ca.URL+"/authz/1")
w.WriteHeader(http.StatusCreated)
if err := authzTmpl.Execute(w, ca.URL); err != nil {
t.Fatalf("authzTmpl: %v", err)
t.Errorf("authzTmpl: %v", err)
}
// accept tls-sni-02 challenge
case "/challenge/2":
@@ -215,14 +316,14 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
b, _ := base64.RawURLEncoding.DecodeString(req.CSR)
csr, err := x509.ParseCertificateRequest(b)
if err != nil {
t.Fatalf("new-cert: CSR: %v", err)
t.Errorf("new-cert: CSR: %v", err)
}
if csr.Subject.CommonName != domain {
t.Errorf("CommonName in CSR = %q; want %q", csr.Subject.CommonName, domain)
}
der, err := dummyCert(csr.PublicKey, domain)
if err != nil {
t.Fatalf("new-cert: dummyCert: %v", err)
t.Errorf("new-cert: dummyCert: %v", err)
}
chainUp := fmt.Sprintf("<%s/ca-cert>; rel=up", ca.URL)
w.Header().Set("link", chainUp)
@@ -232,14 +333,51 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
case "/ca-cert":
der, err := dummyCert(nil, "ca")
if err != nil {
t.Fatalf("ca-cert: dummyCert: %v", err)
t.Errorf("ca-cert: dummyCert: %v", err)
}
w.Write(der)
default:
t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path)
}
}))
defer ca.Close()
finish = func() {
ca.Close()
// make sure token cert was removed
cancel := make(chan struct{})
done := make(chan struct{})
go func() {
defer close(done)
tick := time.NewTicker(100 * time.Millisecond)
defer tick.Stop()
for {
hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
if _, err := man.GetCertificate(hello); err != nil {
return
}
select {
case <-tick.C:
case <-cancel:
return
}
}
}()
select {
case <-done:
case <-time.After(5 * time.Second):
close(cancel)
t.Error("token cert was not removed")
<-done
}
}
return ca.URL, finish
}
// tests man.GetCertificate flow using the provided hello argument.
// The domain argument is the expected domain name of a certificate request.
func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.ClientHelloInfo) {
url, finish := startACMEServerStub(t, man, domain)
defer finish()
// use EC key to run faster on 386
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
@@ -248,7 +386,7 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
}
man.Client = &acme.Client{
Key: key,
DirectoryURL: ca.URL,
DirectoryURL: url,
}
// simulate tls.Config.GetCertificate
@@ -279,23 +417,6 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
t.Errorf("cert.DNSNames = %v; want %q", cert.DNSNames, domain)
}
// make sure token cert was removed
done = make(chan struct{})
go func() {
for {
hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
if _, err := man.GetCertificate(hello); err != nil {
break
}
time.Sleep(100 * time.Millisecond)
}
close(done)
}()
select {
case <-time.After(5 * time.Second):
t.Error("token cert was not removed")
case <-done:
}
}
func TestAccountKeyCache(t *testing.T) {
@@ -335,10 +456,11 @@ func TestCache(t *testing.T) {
man := &Manager{Cache: newMemCache()}
defer man.stopRenew()
if err := man.cachePut("example.org", tlscert); err != nil {
ctx := context.Background()
if err := man.cachePut(ctx, "example.org", tlscert); err != nil {
t.Fatalf("man.cachePut: %v", err)
}
res, err := man.cacheGet("example.org")
res, err := man.cacheGet(ctx, "example.org")
if err != nil {
t.Fatalf("man.cacheGet: %v", err)
}
@@ -438,3 +560,47 @@ func TestValidCert(t *testing.T) {
}
}
}
type cacheGetFunc func(ctx context.Context, key string) ([]byte, error)
func (f cacheGetFunc) Get(ctx context.Context, key string) ([]byte, error) {
return f(ctx, key)
}
func (f cacheGetFunc) Put(ctx context.Context, key string, data []byte) error {
return fmt.Errorf("unsupported Put of %q = %q", key, data)
}
func (f cacheGetFunc) Delete(ctx context.Context, key string) error {
return fmt.Errorf("unsupported Delete of %q", key)
}
func TestManagerGetCertificateBogusSNI(t *testing.T) {
m := Manager{
Prompt: AcceptTOS,
Cache: cacheGetFunc(func(ctx context.Context, key string) ([]byte, error) {
return nil, fmt.Errorf("cache.Get of %s", key)
}),
}
tests := []struct {
name string
wantErr string
}{
{"foo.com", "cache.Get of foo.com"},
{"foo.com.", "cache.Get of foo.com"},
{`a\b.com`, "acme/autocert: server name contains invalid character"},
{`a/b.com`, "acme/autocert: server name contains invalid character"},
{"", "acme/autocert: missing server name"},
{"foo", "acme/autocert: server name component count invalid"},
{".foo", "acme/autocert: server name component count invalid"},
{"foo.", "acme/autocert: server name component count invalid"},
{"fo.o", "cache.Get of fo.o"},
}
for _, tt := range tests {
_, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: tt.name})
got := fmt.Sprint(err)
if got != tt.wantErr {
t.Errorf("GetCertificate(SNI = %q) = %q; want %q", tt.name, got, tt.wantErr)
}
}
}

View File

@@ -5,12 +5,11 @@
package autocert
import (
"context"
"errors"
"io/ioutil"
"os"
"path/filepath"
"golang.org/x/net/context"
)
// ErrCacheMiss is returned when a certificate is not found in cache.
@@ -78,12 +77,13 @@ func (d DirCache) Put(ctx context.Context, name string, data []byte) error {
if tmp, err = d.writeTempFile(name, data); err != nil {
return
}
// prevent overwriting the file if the context was cancelled
if ctx.Err() != nil {
return // no need to set err
select {
case <-ctx.Done():
// Don't overwrite the file if the context was canceled.
default:
newName := filepath.Join(string(d), name)
err = os.Rename(tmp, newName)
}
name = filepath.Join(string(d), name)
err = os.Rename(tmp, name)
}()
select {
case <-ctx.Done():

View File

@@ -5,13 +5,12 @@
package autocert
import (
"context"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"testing"
"golang.org/x/net/context"
)
// make sure DirCache satisfies Cache interface

View File

@@ -0,0 +1,34 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package autocert_test
import (
"crypto/tls"
"fmt"
"log"
"net/http"
"golang.org/x/crypto/acme/autocert"
)
func ExampleNewListener() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Hello, TLS user! Your config: %+v", r.TLS)
})
log.Fatal(http.Serve(autocert.NewListener("example.com"), mux))
}
func ExampleManager() {
m := autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist("example.org"),
}
s := &http.Server{
Addr: ":https",
TLSConfig: &tls.Config{GetCertificate: m.GetCertificate},
}
s.ListenAndServeTLS("", "")
}

153
vendor/golang.org/x/crypto/acme/autocert/listener.go generated vendored Normal file
View File

@@ -0,0 +1,153 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package autocert
import (
"crypto/tls"
"log"
"net"
"os"
"path/filepath"
"runtime"
"time"
)
// NewListener returns a net.Listener that listens on the standard TLS
// port (443) on all interfaces and returns *tls.Conn connections with
// LetsEncrypt certificates for the provided domain or domains.
//
// It enables one-line HTTPS servers:
//
// log.Fatal(http.Serve(autocert.NewListener("example.com"), handler))
//
// NewListener is a convenience function for a common configuration.
// More complex or custom configurations can use the autocert.Manager
// type instead.
//
// Use of this function implies acceptance of the LetsEncrypt Terms of
// Service. If domains is not empty, the provided domains are passed
// to HostWhitelist. If domains is empty, the listener will do
// LetsEncrypt challenges for any requested domain, which is not
// recommended.
//
// Certificates are cached in a "golang-autocert" directory under an
// operating system-specific cache or temp directory. This may not
// be suitable for servers spanning multiple machines.
//
// The returned Listener also enables TCP keep-alives on the accepted
// connections. The returned *tls.Conn are returned before their TLS
// handshake has completed.
func NewListener(domains ...string) net.Listener {
m := &Manager{
Prompt: AcceptTOS,
}
if len(domains) > 0 {
m.HostPolicy = HostWhitelist(domains...)
}
dir := cacheDir()
if err := os.MkdirAll(dir, 0700); err != nil {
log.Printf("warning: autocert.NewListener not using a cache: %v", err)
} else {
m.Cache = DirCache(dir)
}
return m.Listener()
}
// Listener listens on the standard TLS port (443) on all interfaces
// and returns a net.Listener returning *tls.Conn connections.
//
// The returned Listener also enables TCP keep-alives on the accepted
// connections. The returned *tls.Conn are returned before their TLS
// handshake has completed.
//
// Unlike NewListener, it is the caller's responsibility to initialize
// the Manager m's Prompt, Cache, HostPolicy, and other desired options.
func (m *Manager) Listener() net.Listener {
ln := &listener{
m: m,
conf: &tls.Config{
GetCertificate: m.GetCertificate, // bonus: panic on nil m
},
}
ln.tcpListener, ln.tcpListenErr = net.Listen("tcp", ":443")
return ln
}
type listener struct {
m *Manager
conf *tls.Config
tcpListener net.Listener
tcpListenErr error
}
func (ln *listener) Accept() (net.Conn, error) {
if ln.tcpListenErr != nil {
return nil, ln.tcpListenErr
}
conn, err := ln.tcpListener.Accept()
if err != nil {
return nil, err
}
tcpConn := conn.(*net.TCPConn)
// Because Listener is a convenience function, help out with
// this too. This is not possible for the caller to set once
// we return a *tcp.Conn wrapping an inaccessible net.Conn.
// If callers don't want this, they can do things the manual
// way and tweak as needed. But this is what net/http does
// itself, so copy that. If net/http changes, we can change
// here too.
tcpConn.SetKeepAlive(true)
tcpConn.SetKeepAlivePeriod(3 * time.Minute)
return tls.Server(tcpConn, ln.conf), nil
}
func (ln *listener) Addr() net.Addr {
if ln.tcpListener != nil {
return ln.tcpListener.Addr()
}
// net.Listen failed. Return something non-nil in case callers
// call Addr before Accept:
return &net.TCPAddr{IP: net.IP{0, 0, 0, 0}, Port: 443}
}
func (ln *listener) Close() error {
if ln.tcpListenErr != nil {
return ln.tcpListenErr
}
return ln.tcpListener.Close()
}
func homeDir() string {
if runtime.GOOS == "windows" {
return os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH")
}
if h := os.Getenv("HOME"); h != "" {
return h
}
return "/"
}
func cacheDir() string {
const base = "golang-autocert"
switch runtime.GOOS {
case "darwin":
return filepath.Join(homeDir(), "Library", "Caches", base)
case "windows":
for _, ev := range []string{"APPDATA", "CSIDL_APPDATA", "TEMP", "TMP"} {
if v := os.Getenv(ev); v != "" {
return filepath.Join(v, base)
}
}
// Worst case:
return filepath.Join(homeDir(), base)
}
if xdg := os.Getenv("XDG_CACHE_HOME"); xdg != "" {
return filepath.Join(xdg, base)
}
return filepath.Join(homeDir(), ".cache", base)
}

View File

@@ -5,15 +5,14 @@
package autocert
import (
"context"
"crypto"
"sync"
"time"
"golang.org/x/net/context"
)
// maxRandRenew is a maximum deviation from Manager.RenewBefore.
const maxRandRenew = time.Hour
// renewJitter is the maximum deviation from Manager.RenewBefore.
const renewJitter = time.Hour
// domainRenewal tracks the state used by the periodic timers
// renewing a single domain's cert.
@@ -65,7 +64,7 @@ func (dr *domainRenewal) renew() {
// TODO: rotate dr.key at some point?
next, err := dr.do(ctx)
if err != nil {
next = maxRandRenew / 2
next = renewJitter / 2
next += time.Duration(pseudoRand.int63n(int64(next)))
}
dr.timer = time.AfterFunc(next, dr.renew)
@@ -83,9 +82,9 @@ func (dr *domainRenewal) renew() {
func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
// a race is likely unavoidable in a distributed environment
// but we try nonetheless
if tlscert, err := dr.m.cacheGet(dr.domain); err == nil {
if tlscert, err := dr.m.cacheGet(ctx, dr.domain); err == nil {
next := dr.next(tlscert.Leaf.NotAfter)
if next > dr.m.renewBefore()+maxRandRenew {
if next > dr.m.renewBefore()+renewJitter {
return next, nil
}
}
@@ -103,7 +102,7 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
if err != nil {
return 0, err
}
dr.m.cachePut(dr.domain, tlscert)
dr.m.cachePut(ctx, dr.domain, tlscert)
dr.m.stateMu.Lock()
defer dr.m.stateMu.Unlock()
// m.state is guaranteed to be non-nil at this point
@@ -114,7 +113,7 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
func (dr *domainRenewal) next(expiry time.Time) time.Duration {
d := expiry.Sub(timeNow()) - dr.m.renewBefore()
// add a bit of randomness to renew deadline
n := pseudoRand.int63n(int64(maxRandRenew))
n := pseudoRand.int63n(int64(renewJitter))
d -= time.Duration(n)
if d < 0 {
return 0

View File

@@ -5,6 +5,7 @@
package autocert
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
@@ -31,7 +32,7 @@ func TestRenewalNext(t *testing.T) {
expiry time.Time
min, max time.Duration
}{
{now.Add(90 * 24 * time.Hour), 83*24*time.Hour - maxRandRenew, 83 * 24 * time.Hour},
{now.Add(90 * 24 * time.Hour), 83*24*time.Hour - renewJitter, 83 * 24 * time.Hour},
{now.Add(time.Hour), 0, 1},
{now, 0, 1},
{now.Add(-time.Hour), 0, 1},
@@ -127,7 +128,7 @@ func TestRenewFromCache(t *testing.T) {
t.Fatal(err)
}
tlscert := &tls.Certificate{PrivateKey: key, Certificate: [][]byte{cert}}
if err := man.cachePut(domain, tlscert); err != nil {
if err := man.cachePut(context.Background(), domain, tlscert); err != nil {
t.Fatal(err)
}
@@ -151,7 +152,7 @@ func TestRenewFromCache(t *testing.T) {
// ensure the new cert is cached
after := time.Now().Add(future)
tlscert, err := man.cacheGet(domain)
tlscert, err := man.cacheGet(context.Background(), domain)
if err != nil {
t.Fatalf("man.cacheGet: %v", err)
}

View File

@@ -134,7 +134,7 @@ func jwsHasher(key crypto.Signer) (string, crypto.Hash) {
return "ES256", crypto.SHA256
case "P-384":
return "ES384", crypto.SHA384
case "P-512":
case "P-521":
return "ES512", crypto.SHA512
}
}

View File

@@ -12,11 +12,13 @@ import (
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"testing"
)
const testKeyPEM = `
const (
testKeyPEM = `
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEA4xgZ3eRPkwoRvy7qeRUbmMDe0V+xH9eWLdu0iheeLlrmD2mq
WXfP9IeSKApbn34g8TuAS9g5zhq8ELQ3kmjr+KV86GAMgI6VAcGlq3QrzpTCf/30
@@ -46,10 +48,9 @@ EQeIP6dZtv8IMgtGIb91QX9pXvP0aznzQKwYIA8nZgoENCPfiMTPiEDT9e/0lObO
-----END RSA PRIVATE KEY-----
`
// This thumbprint is for the testKey defined above.
const testKeyThumbprint = "6nicxzh6WETQlrvdchkz-U3e3DOQZ4heJKU63rfqMqQ"
// This thumbprint is for the testKey defined above.
testKeyThumbprint = "6nicxzh6WETQlrvdchkz-U3e3DOQZ4heJKU63rfqMqQ"
const (
// openssl ecparam -name secp256k1 -genkey -noout
testKeyECPEM = `
-----BEGIN EC PRIVATE KEY-----
@@ -58,39 +59,78 @@ AwEHoUQDQgAE5lhEug5xK4xBDZ2nAbaxLtaLiv85bxJ7ePd1dkO23HThqIrvawF5
QAaS/RNouybCiRhRjI3EaxLkQwgrCw0gqQ==
-----END EC PRIVATE KEY-----
`
// 1. opnessl ec -in key.pem -noout -text
// openssl ecparam -name secp384r1 -genkey -noout
testKeyEC384PEM = `
-----BEGIN EC PRIVATE KEY-----
MIGkAgEBBDAQ4lNtXRORWr1bgKR1CGysr9AJ9SyEk4jiVnlUWWUChmSNL+i9SLSD
Oe/naPqXJ6CgBwYFK4EEACKhZANiAAQzKtj+Ms0vHoTX5dzv3/L5YMXOWuI5UKRj
JigpahYCqXD2BA1j0E/2xt5vlPf+gm0PL+UHSQsCokGnIGuaHCsJAp3ry0gHQEke
WYXapUUFdvaK1R2/2hn5O+eiQM8YzCg=
-----END EC PRIVATE KEY-----
`
// openssl ecparam -name secp521r1 -genkey -noout
testKeyEC512PEM = `
-----BEGIN EC PRIVATE KEY-----
MIHcAgEBBEIBSNZKFcWzXzB/aJClAb305ibalKgtDA7+70eEkdPt28/3LZMM935Z
KqYHh/COcxuu3Kt8azRAUz3gyr4zZKhlKUSgBwYFK4EEACOhgYkDgYYABAHUNKbx
7JwC7H6pa2sV0tERWhHhB3JmW+OP6SUgMWryvIKajlx73eS24dy4QPGrWO9/ABsD
FqcRSkNVTXnIv6+0mAF25knqIBIg5Q8M9BnOu9GGAchcwt3O7RDHmqewnJJDrbjd
GGnm6rb+NnWR9DIopM0nKNkToWoF/hzopxu4Ae/GsQ==
-----END EC PRIVATE KEY-----
`
// 1. openssl ec -in key.pem -noout -text
// 2. remove first byte, 04 (the header); the rest is X and Y
// 3. covert each with: echo <val> | xxd -r -p | base64 | tr -d '=' | tr '/+' '_-'
testKeyECPubX = "5lhEug5xK4xBDZ2nAbaxLtaLiv85bxJ7ePd1dkO23HQ"
testKeyECPubY = "4aiK72sBeUAGkv0TaLsmwokYUYyNxGsS5EMIKwsNIKk"
// 3. convert each with: echo <val> | xxd -r -p | base64 -w 100 | tr -d '=' | tr '/+' '_-'
testKeyECPubX = "5lhEug5xK4xBDZ2nAbaxLtaLiv85bxJ7ePd1dkO23HQ"
testKeyECPubY = "4aiK72sBeUAGkv0TaLsmwokYUYyNxGsS5EMIKwsNIKk"
testKeyEC384PubX = "MyrY_jLNLx6E1-Xc79_y-WDFzlriOVCkYyYoKWoWAqlw9gQNY9BP9sbeb5T3_oJt"
testKeyEC384PubY = "Dy_lB0kLAqJBpyBrmhwrCQKd68tIB0BJHlmF2qVFBXb2itUdv9oZ-TvnokDPGMwo"
testKeyEC512PubX = "AdQ0pvHsnALsfqlraxXS0RFaEeEHcmZb44_pJSAxavK8gpqOXHvd5Lbh3LhA8atY738AGwMWpxFKQ1VNeci_r7SY"
testKeyEC512PubY = "AXbmSeogEiDlDwz0Gc670YYByFzC3c7tEMeap7CckkOtuN0Yaebqtv42dZH0MiikzSco2ROhagX-HOinG7gB78ax"
// echo -n '{"crv":"P-256","kty":"EC","x":"<testKeyECPubX>","y":"<testKeyECPubY>"}' | \
// openssl dgst -binary -sha256 | base64 | tr -d '=' | tr '/+' '_-'
testKeyECThumbprint = "zedj-Bd1Zshp8KLePv2MB-lJ_Hagp7wAwdkA0NUTniU"
)
var (
testKey *rsa.PrivateKey
testKeyEC *ecdsa.PrivateKey
testKey *rsa.PrivateKey
testKeyEC *ecdsa.PrivateKey
testKeyEC384 *ecdsa.PrivateKey
testKeyEC512 *ecdsa.PrivateKey
)
func init() {
d, _ := pem.Decode([]byte(testKeyPEM))
if d == nil {
panic("no block found in testKeyPEM")
}
var err error
testKey, err = x509.ParsePKCS1PrivateKey(d.Bytes)
if err != nil {
panic(err.Error())
}
testKey = parseRSA(testKeyPEM, "testKeyPEM")
testKeyEC = parseEC(testKeyECPEM, "testKeyECPEM")
testKeyEC384 = parseEC(testKeyEC384PEM, "testKeyEC384PEM")
testKeyEC512 = parseEC(testKeyEC512PEM, "testKeyEC512PEM")
}
if d, _ = pem.Decode([]byte(testKeyECPEM)); d == nil {
panic("no block found in testKeyECPEM")
func decodePEM(s, name string) []byte {
d, _ := pem.Decode([]byte(s))
if d == nil {
panic("no block found in " + name)
}
testKeyEC, err = x509.ParseECPrivateKey(d.Bytes)
return d.Bytes
}
func parseRSA(s, name string) *rsa.PrivateKey {
b := decodePEM(s, name)
k, err := x509.ParsePKCS1PrivateKey(b)
if err != nil {
panic(err.Error())
panic(fmt.Sprintf("%s: %v", name, err))
}
return k
}
func parseEC(s, name string) *ecdsa.PrivateKey {
b := decodePEM(s, name)
k, err := x509.ParseECPrivateKey(b)
if err != nil {
panic(fmt.Sprintf("%s: %v", name, err))
}
return k
}
func TestJWSEncodeJSON(t *testing.T) {
@@ -141,50 +181,63 @@ func TestJWSEncodeJSON(t *testing.T) {
}
func TestJWSEncodeJSONEC(t *testing.T) {
claims := struct{ Msg string }{"Hello JWS"}
tt := []struct {
key *ecdsa.PrivateKey
x, y string
alg, crv string
}{
{testKeyEC, testKeyECPubX, testKeyECPubY, "ES256", "P-256"},
{testKeyEC384, testKeyEC384PubX, testKeyEC384PubY, "ES384", "P-384"},
{testKeyEC512, testKeyEC512PubX, testKeyEC512PubY, "ES512", "P-521"},
}
for i, test := range tt {
claims := struct{ Msg string }{"Hello JWS"}
b, err := jwsEncodeJSON(claims, test.key, "nonce")
if err != nil {
t.Errorf("%d: %v", i, err)
continue
}
var jws struct{ Protected, Payload, Signature string }
if err := json.Unmarshal(b, &jws); err != nil {
t.Errorf("%d: %v", i, err)
continue
}
b, err := jwsEncodeJSON(claims, testKeyEC, "nonce")
if err != nil {
t.Fatal(err)
}
var jws struct{ Protected, Payload, Signature string }
if err := json.Unmarshal(b, &jws); err != nil {
t.Fatal(err)
}
if b, err = base64.RawURLEncoding.DecodeString(jws.Protected); err != nil {
t.Fatalf("jws.Protected: %v", err)
}
var head struct {
Alg string
Nonce string
JWK struct {
Crv string
Kty string
X string
Y string
} `json:"jwk"`
}
if err := json.Unmarshal(b, &head); err != nil {
t.Fatalf("jws.Protected: %v", err)
}
if head.Alg != "ES256" {
t.Errorf("head.Alg = %q; want ES256", head.Alg)
}
if head.Nonce != "nonce" {
t.Errorf("head.Nonce = %q; want nonce", head.Nonce)
}
if head.JWK.Crv != "P-256" {
t.Errorf("head.JWK.Crv = %q; want P-256", head.JWK.Crv)
}
if head.JWK.Kty != "EC" {
t.Errorf("head.JWK.Kty = %q; want EC", head.JWK.Kty)
}
if head.JWK.X != testKeyECPubX {
t.Errorf("head.JWK.X = %q; want %q", head.JWK.X, testKeyECPubX)
}
if head.JWK.Y != testKeyECPubY {
t.Errorf("head.JWK.Y = %q; want %q", head.JWK.Y, testKeyECPubY)
b, err = base64.RawURLEncoding.DecodeString(jws.Protected)
if err != nil {
t.Errorf("%d: jws.Protected: %v", i, err)
}
var head struct {
Alg string
Nonce string
JWK struct {
Crv string
Kty string
X string
Y string
} `json:"jwk"`
}
if err := json.Unmarshal(b, &head); err != nil {
t.Errorf("%d: jws.Protected: %v", i, err)
}
if head.Alg != test.alg {
t.Errorf("%d: head.Alg = %q; want %q", i, head.Alg, test.alg)
}
if head.Nonce != "nonce" {
t.Errorf("%d: head.Nonce = %q; want nonce", i, head.Nonce)
}
if head.JWK.Crv != test.crv {
t.Errorf("%d: head.JWK.Crv = %q; want %q", i, head.JWK.Crv, test.crv)
}
if head.JWK.Kty != "EC" {
t.Errorf("%d: head.JWK.Kty = %q; want EC", i, head.JWK.Kty)
}
if head.JWK.X != test.x {
t.Errorf("%d: head.JWK.X = %q; want %q", i, head.JWK.X, test.x)
}
if head.JWK.Y != test.y {
t.Errorf("%d: head.JWK.Y = %q; want %q", i, head.JWK.Y, test.y)
}
}
}

View File

@@ -1,9 +1,15 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package acme
import (
"errors"
"fmt"
"net/http"
"strings"
"time"
)
// ACME server response statuses used to describe Authorization and Challenge states.
@@ -33,14 +39,8 @@ const (
CRLReasonAACompromise CRLReasonCode = 10
)
var (
// ErrAuthorizationFailed indicates that an authorization for an identifier
// did not succeed.
ErrAuthorizationFailed = errors.New("acme: identifier authorization failed")
// ErrUnsupportedKey is returned when an unsupported key type is encountered.
ErrUnsupportedKey = errors.New("acme: unknown key type; only RSA and ECDSA are supported")
)
// ErrUnsupportedKey is returned when an unsupported key type is encountered.
var ErrUnsupportedKey = errors.New("acme: unknown key type; only RSA and ECDSA are supported")
// Error is an ACME error, defined in Problem Details for HTTP APIs doc
// http://tools.ietf.org/html/draft-ietf-appsawg-http-problem.
@@ -53,6 +53,7 @@ type Error struct {
// Detail is a human-readable explanation specific to this occurrence of the problem.
Detail string
// Header is the original server error response headers.
// It may be nil.
Header http.Header
}
@@ -60,6 +61,50 @@ func (e *Error) Error() string {
return fmt.Sprintf("%d %s: %s", e.StatusCode, e.ProblemType, e.Detail)
}
// AuthorizationError indicates that an authorization for an identifier
// did not succeed.
// It contains all errors from Challenge items of the failed Authorization.
type AuthorizationError struct {
// URI uniquely identifies the failed Authorization.
URI string
// Identifier is an AuthzID.Value of the failed Authorization.
Identifier string
// Errors is a collection of non-nil error values of Challenge items
// of the failed Authorization.
Errors []error
}
func (a *AuthorizationError) Error() string {
e := make([]string, len(a.Errors))
for i, err := range a.Errors {
e[i] = err.Error()
}
return fmt.Sprintf("acme: authorization error for %s: %s", a.Identifier, strings.Join(e, "; "))
}
// RateLimit reports whether err represents a rate limit error and
// any Retry-After duration returned by the server.
//
// See the following for more details on rate limiting:
// https://tools.ietf.org/html/draft-ietf-acme-acme-05#section-5.6
func RateLimit(err error) (time.Duration, bool) {
e, ok := err.(*Error)
if !ok {
return 0, false
}
// Some CA implementations may return incorrect values.
// Use case-insensitive comparison.
if !strings.HasSuffix(strings.ToLower(e.ProblemType), ":ratelimited") {
return 0, false
}
if e.Header == nil {
return 0, true
}
return retryAfter(e.Header.Get("Retry-After"), 0), true
}
// Account is a user account. It is associated with a private key.
type Account struct {
// URI is the account unique ID, which is also a URL used to retrieve
@@ -118,6 +163,8 @@ type Directory struct {
}
// Challenge encodes a returned CA challenge.
// Its Error field may be non-nil if the challenge is part of an Authorization
// with StatusInvalid.
type Challenge struct {
// Type is the challenge type, e.g. "http-01", "tls-sni-02", "dns-01".
Type string
@@ -130,6 +177,11 @@ type Challenge struct {
// Status identifies the status of this challenge.
Status string
// Error indicates the reason for an authorization failure
// when this challenge was used.
// The type of a non-nil value is *Error.
Error error
}
// Authorization encodes an authorization response.
@@ -187,12 +239,26 @@ func (z *wireAuthz) authorization(uri string) *Authorization {
return a
}
func (z *wireAuthz) error(uri string) *AuthorizationError {
err := &AuthorizationError{
URI: uri,
Identifier: z.Identifier.Value,
}
for _, raw := range z.Challenges {
if raw.Error != nil {
err.Errors = append(err.Errors, raw.Error.error(nil))
}
}
return err
}
// wireChallenge is ACME JSON challenge representation.
type wireChallenge struct {
URI string `json:"uri"`
Type string
Token string
Status string
Error *wireError
}
func (c *wireChallenge) challenge() *Challenge {
@@ -205,5 +271,25 @@ func (c *wireChallenge) challenge() *Challenge {
if v.Status == "" {
v.Status = StatusPending
}
if c.Error != nil {
v.Error = c.Error.error(nil)
}
return v
}
// wireError is a subset of fields of the Problem Details object
// as described in https://tools.ietf.org/html/rfc7807#section-3.1.
type wireError struct {
Status int
Type string
Detail string
}
func (e *wireError) error(h http.Header) *Error {
return &Error{
StatusCode: e.Status,
ProblemType: e.Type,
Detail: e.Detail,
Header: h,
}
}

63
vendor/golang.org/x/crypto/acme/types_test.go generated vendored Normal file
View File

@@ -0,0 +1,63 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package acme
import (
"errors"
"net/http"
"testing"
"time"
)
func TestRateLimit(t *testing.T) {
now := time.Date(2017, 04, 27, 10, 0, 0, 0, time.UTC)
f := timeNow
defer func() { timeNow = f }()
timeNow = func() time.Time { return now }
h120, hTime := http.Header{}, http.Header{}
h120.Set("Retry-After", "120")
hTime.Set("Retry-After", "Tue Apr 27 11:00:00 2017")
err1 := &Error{
ProblemType: "urn:ietf:params:acme:error:nolimit",
Header: h120,
}
err2 := &Error{
ProblemType: "urn:ietf:params:acme:error:rateLimited",
Header: h120,
}
err3 := &Error{
ProblemType: "urn:ietf:params:acme:error:rateLimited",
Header: nil,
}
err4 := &Error{
ProblemType: "urn:ietf:params:acme:error:rateLimited",
Header: hTime,
}
tt := []struct {
err error
res time.Duration
ok bool
}{
{nil, 0, false},
{errors.New("dummy"), 0, false},
{err1, 0, false},
{err2, 2 * time.Minute, true},
{err3, 0, true},
{err4, time.Hour, true},
}
for i, test := range tt {
res, ok := RateLimit(test.err)
if ok != test.ok {
t.Errorf("%d: RateLimit(%+v): ok = %v; want %v", i, test.err, ok, test.ok)
continue
}
if res != test.res {
t.Errorf("%d: RateLimit(%+v) = %v; want %v", i, test.err, res, test.res)
}
}
}

View File

@@ -4,7 +4,7 @@
// Package blake2b implements the BLAKE2b hash algorithm as
// defined in RFC 7693.
package blake2b
package blake2b // import "golang.org/x/crypto/blake2b"
import (
"encoding/binary"

View File

@@ -22,7 +22,7 @@ func fromHex(s string) []byte {
func TestHashes(t *testing.T) {
defer func(sse4, avx, avx2 bool) {
useSSE4, useAVX, useAVX2 = sse4, useAVX, avx2
useSSE4, useAVX, useAVX2 = sse4, avx, avx2
}(useSSE4, useAVX, useAVX2)
if useAVX2 {

View File

@@ -4,7 +4,7 @@
// Package blake2s implements the BLAKE2s hash algorithm as
// defined in RFC 7693.
package blake2s
package blake2s // import "golang.org/x/crypto/blake2s"
import (
"encoding/binary"

View File

@@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.
// Package chacha20poly1305 implements the ChaCha20-Poly1305 AEAD as specified in RFC 7539.
package chacha20poly1305
package chacha20poly1305 // import "golang.org/x/crypto/chacha20poly1305"
import (
"crypto/cipher"

View File

@@ -14,13 +14,60 @@ func chacha20Poly1305Open(dst []byte, key []uint32, src, ad []byte) bool
//go:noescape
func chacha20Poly1305Seal(dst []byte, key []uint32, src, ad []byte)
//go:noescape
func haveSSSE3() bool
// cpuid is implemented in chacha20poly1305_amd64.s.
func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32)
var canUseASM bool
// xgetbv with ecx = 0 is implemented in chacha20poly1305_amd64.s.
func xgetbv() (eax, edx uint32)
var (
useASM bool
useAVX2 bool
)
func init() {
canUseASM = haveSSSE3()
detectCPUFeatures()
}
// detectCPUFeatures is used to detect if cpu instructions
// used by the functions implemented in assembler in
// chacha20poly1305_amd64.s are supported.
func detectCPUFeatures() {
maxID, _, _, _ := cpuid(0, 0)
if maxID < 1 {
return
}
_, _, ecx1, _ := cpuid(1, 0)
haveSSSE3 := isSet(9, ecx1)
useASM = haveSSSE3
haveOSXSAVE := isSet(27, ecx1)
osSupportsAVX := false
// For XGETBV, OSXSAVE bit is required and sufficient.
if haveOSXSAVE {
eax, _ := xgetbv()
// Check if XMM and YMM registers have OS support.
osSupportsAVX = isSet(1, eax) && isSet(2, eax)
}
haveAVX := isSet(28, ecx1) && osSupportsAVX
if maxID < 7 {
return
}
_, ebx7, _, _ := cpuid(7, 0)
haveAVX2 := isSet(5, ebx7) && haveAVX
haveBMI2 := isSet(8, ebx7)
useAVX2 = haveAVX2 && haveBMI2
}
// isSet checks if bit at bitpos is set in value.
func isSet(bitpos uint, value uint32) bool {
return value&(1<<bitpos) != 0
}
// setupState writes a ChaCha20 input matrix to state. See
@@ -47,7 +94,7 @@ func setupState(state *[16]uint32, key *[32]byte, nonce []byte) {
}
func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []byte {
if !canUseASM {
if !useASM {
return c.sealGeneric(dst, nonce, plaintext, additionalData)
}
@@ -60,7 +107,7 @@ func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []
}
func (c *chacha20poly1305) open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
if !canUseASM {
if !useASM {
return c.openGeneric(dst, nonce, ciphertext, additionalData)
}

View File

@@ -278,15 +278,8 @@ TEXT ·chacha20Poly1305Open(SB), 0, $288-97
MOVQ ad+72(FP), adp
// Check for AVX2 support
CMPB runtime·support_avx2(SB), $0
JE noavx2bmi2Open
// Check BMI2 bit for MULXQ.
// runtime·cpuid_ebx7 is always available here
// because it passed avx2 check
TESTL $(1<<8), runtime·cpuid_ebx7(SB)
JNE chacha20Poly1305Open_AVX2
noavx2bmi2Open:
CMPB ·useAVX2(SB), $1
JE chacha20Poly1305Open_AVX2
// Special optimization, for very short buffers
CMPQ inl, $128
@@ -1491,16 +1484,8 @@ TEXT ·chacha20Poly1305Seal(SB), 0, $288-96
MOVQ src_len+56(FP), inl
MOVQ ad+72(FP), adp
// Check for AVX2 support
CMPB runtime·support_avx2(SB), $0
JE noavx2bmi2Seal
// Check BMI2 bit for MULXQ.
// runtime·cpuid_ebx7 is always available here
// because it passed avx2 check
TESTL $(1<<8), runtime·cpuid_ebx7(SB)
JNE chacha20Poly1305Seal_AVX2
noavx2bmi2Seal:
CMPB ·useAVX2(SB), $1
JE chacha20Poly1305Seal_AVX2
// Special optimization, for very short buffers
CMPQ inl, $128
@@ -2709,13 +2694,21 @@ sealAVX2Tail512LoopB:
JMP sealAVX2SealHash
// func haveSSSE3() bool
TEXT ·haveSSSE3(SB), NOSPLIT, $0
XORQ AX, AX
INCL AX
// func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32)
TEXT ·cpuid(SB), NOSPLIT, $0-24
MOVL eaxArg+0(FP), AX
MOVL ecxArg+4(FP), CX
CPUID
SHRQ $9, CX
ANDQ $1, CX
MOVB CX, ret+0(FP)
MOVL AX, eax+8(FP)
MOVL BX, ebx+12(FP)
MOVL CX, ecx+16(FP)
MOVL DX, edx+20(FP)
RET
// func xgetbv() (eax, edx uint32)
TEXT ·xgetbv(SB),NOSPLIT,$0-8
MOVL $0, CX
XGETBV
MOVL AX, eax+0(FP)
MOVL DX, edx+4(FP)
RET

View File

@@ -5,7 +5,7 @@
// Package cryptobyte implements building and parsing of byte strings for
// DER-encoded ASN.1 and TLS messages. See the examples for the Builder and
// String types to get started.
package cryptobyte
package cryptobyte // import "golang.org/x/crypto/cryptobyte"
// String represents a string of bytes. It provides methods for parsing
// fixed-length and length-prefixed values from it.

View File

@@ -2,87 +2,64 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This code was translated into a form compatible with 6a from the public
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html
// +build amd64,!gccgo,!appengine
// func cswap(inout *[5]uint64, v uint64)
// func cswap(inout *[4][5]uint64, v uint64)
TEXT ·cswap(SB),7,$0
MOVQ inout+0(FP),DI
MOVQ v+8(FP),SI
CMPQ SI,$1
MOVQ 0(DI),SI
MOVQ 80(DI),DX
MOVQ 8(DI),CX
MOVQ 88(DI),R8
MOVQ SI,R9
CMOVQEQ DX,SI
CMOVQEQ R9,DX
MOVQ CX,R9
CMOVQEQ R8,CX
CMOVQEQ R9,R8
MOVQ SI,0(DI)
MOVQ DX,80(DI)
MOVQ CX,8(DI)
MOVQ R8,88(DI)
MOVQ 16(DI),SI
MOVQ 96(DI),DX
MOVQ 24(DI),CX
MOVQ 104(DI),R8
MOVQ SI,R9
CMOVQEQ DX,SI
CMOVQEQ R9,DX
MOVQ CX,R9
CMOVQEQ R8,CX
CMOVQEQ R9,R8
MOVQ SI,16(DI)
MOVQ DX,96(DI)
MOVQ CX,24(DI)
MOVQ R8,104(DI)
MOVQ 32(DI),SI
MOVQ 112(DI),DX
MOVQ 40(DI),CX
MOVQ 120(DI),R8
MOVQ SI,R9
CMOVQEQ DX,SI
CMOVQEQ R9,DX
MOVQ CX,R9
CMOVQEQ R8,CX
CMOVQEQ R9,R8
MOVQ SI,32(DI)
MOVQ DX,112(DI)
MOVQ CX,40(DI)
MOVQ R8,120(DI)
MOVQ 48(DI),SI
MOVQ 128(DI),DX
MOVQ 56(DI),CX
MOVQ 136(DI),R8
MOVQ SI,R9
CMOVQEQ DX,SI
CMOVQEQ R9,DX
MOVQ CX,R9
CMOVQEQ R8,CX
CMOVQEQ R9,R8
MOVQ SI,48(DI)
MOVQ DX,128(DI)
MOVQ CX,56(DI)
MOVQ R8,136(DI)
MOVQ 64(DI),SI
MOVQ 144(DI),DX
MOVQ 72(DI),CX
MOVQ 152(DI),R8
MOVQ SI,R9
CMOVQEQ DX,SI
CMOVQEQ R9,DX
MOVQ CX,R9
CMOVQEQ R8,CX
CMOVQEQ R9,R8
MOVQ SI,64(DI)
MOVQ DX,144(DI)
MOVQ CX,72(DI)
MOVQ R8,152(DI)
MOVQ DI,AX
MOVQ SI,DX
SUBQ $1, SI
NOTQ SI
MOVQ SI, X15
PSHUFD $0x44, X15, X15
MOVOU 0(DI), X0
MOVOU 16(DI), X2
MOVOU 32(DI), X4
MOVOU 48(DI), X6
MOVOU 64(DI), X8
MOVOU 80(DI), X1
MOVOU 96(DI), X3
MOVOU 112(DI), X5
MOVOU 128(DI), X7
MOVOU 144(DI), X9
MOVO X1, X10
MOVO X3, X11
MOVO X5, X12
MOVO X7, X13
MOVO X9, X14
PXOR X0, X10
PXOR X2, X11
PXOR X4, X12
PXOR X6, X13
PXOR X8, X14
PAND X15, X10
PAND X15, X11
PAND X15, X12
PAND X15, X13
PAND X15, X14
PXOR X10, X0
PXOR X10, X1
PXOR X11, X2
PXOR X11, X3
PXOR X12, X4
PXOR X12, X5
PXOR X13, X6
PXOR X13, X7
PXOR X14, X8
PXOR X14, X9
MOVOU X0, 0(DI)
MOVOU X2, 16(DI)
MOVOU X4, 32(DI)
MOVOU X6, 48(DI)
MOVOU X8, 64(DI)
MOVOU X1, 80(DI)
MOVOU X3, 96(DI)
MOVOU X5, 112(DI)
MOVOU X7, 128(DI)
MOVOU X9, 144(DI)
RET

View File

@@ -8,6 +8,10 @@
package curve25519
import (
"encoding/binary"
)
// This code is a port of the public domain, "ref10" implementation of
// curve25519 from SUPERCOP 20130419 by D. J. Bernstein.
@@ -50,17 +54,11 @@ func feCopy(dst, src *fieldElement) {
//
// Preconditions: b in {0,1}.
func feCSwap(f, g *fieldElement, b int32) {
var x fieldElement
b = -b
for i := range x {
x[i] = b & (f[i] ^ g[i])
}
for i := range f {
f[i] ^= x[i]
}
for i := range g {
g[i] ^= x[i]
t := b & (f[i] ^ g[i])
f[i] ^= t
g[i] ^= t
}
}
@@ -75,12 +73,7 @@ func load3(in []byte) int64 {
// load4 reads a 32-bit, little-endian value from in.
func load4(in []byte) int64 {
var r int64
r = int64(in[0])
r |= int64(in[1]) << 8
r |= int64(in[2]) << 16
r |= int64(in[3]) << 24
return r
return int64(binary.LittleEndian.Uint32(in))
}
func feFromBytes(dst *fieldElement, src *[32]byte) {

View File

@@ -27,3 +27,13 @@ func TestBaseScalarMult(t *testing.T) {
t.Errorf("incorrect result: got %s, want %s", result, expectedHex)
}
}
func BenchmarkScalarBaseMult(b *testing.B) {
var in, out [32]byte
in[0] = 1
b.SetBytes(32)
for i := 0; i < b.N; i++ {
ScalarBaseMult(&out, &in)
}
}

View File

@@ -180,9 +180,12 @@ func TestCert(t *testing.T) {
// therefore is buffered (net.Pipe deadlocks if both sides start with
// a write.)
func netPipe() (net.Conn, net.Conn, error) {
listener, err := net.Listen("tcp", ":0")
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, nil, err
listener, err = net.Listen("tcp", "[::1]:0")
if err != nil {
return nil, nil, err
}
}
defer listener.Close()
c1, err := net.Dial("tcp", listener.Addr().String())
@@ -200,6 +203,9 @@ func netPipe() (net.Conn, net.Conn, error) {
}
func TestAuth(t *testing.T) {
agent, _, cleanup := startAgent(t)
defer cleanup()
a, b, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
@@ -208,9 +214,6 @@ func TestAuth(t *testing.T) {
defer a.Close()
defer b.Close()
agent, _, cleanup := startAgent(t)
defer cleanup()
if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil {
t.Errorf("Add: %v", err)
}

View File

@@ -251,10 +251,18 @@ type CertChecker struct {
// for user certificates.
SupportedCriticalOptions []string
// IsAuthority should return true if the key is recognized as
// an authority. This allows for certificates to be signed by other
// certificates.
IsAuthority func(auth PublicKey) bool
// IsUserAuthority should return true if the key is recognized as an
// authority for the given user certificate. This allows for
// certificates to be signed by other certificates. This must be set
// if this CertChecker will be checking user certificates.
IsUserAuthority func(auth PublicKey) bool
// IsHostAuthority should report whether the key is recognized as
// an authority for this host. This allows for certificates to be
// signed by other keys, and for those other keys to only be valid
// signers for particular hostnames. This must be set if this
// CertChecker will be checking host certificates.
IsHostAuthority func(auth PublicKey, address string) bool
// Clock is used for verifying time stamps. If nil, time.Now
// is used.
@@ -356,7 +364,13 @@ func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
}
}
if !c.IsAuthority(cert.SignatureKey) {
// if this is a host cert, principal is the remote hostname as passed
// to CheckHostCert.
if cert.CertType == HostCert && !c.IsHostAuthority(cert.SignatureKey, principal) {
return fmt.Errorf("ssh: no authorities for hostname: %v", principal)
}
if cert.CertType == UserCert && !c.IsUserAuthority(cert.SignatureKey) {
return fmt.Errorf("ssh: certificate signed by unrecognized authority")
}

View File

@@ -104,7 +104,7 @@ func TestValidateCert(t *testing.T) {
t.Fatalf("got %v (%T), want *Certificate", key, key)
}
checker := CertChecker{}
checker.IsAuthority = func(k PublicKey) bool {
checker.IsUserAuthority = func(k PublicKey) bool {
return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal())
}
@@ -142,7 +142,7 @@ func TestValidateCertTime(t *testing.T) {
checker := CertChecker{
Clock: func() time.Time { return time.Unix(ts, 0) },
}
checker.IsAuthority = func(k PublicKey) bool {
checker.IsUserAuthority = func(k PublicKey) bool {
return bytes.Equal(k.Marshal(),
testPublicKeys["ecdsa"].Marshal())
}
@@ -160,7 +160,7 @@ func TestValidateCertTime(t *testing.T) {
func TestHostKeyCert(t *testing.T) {
cert := &Certificate{
ValidPrincipals: []string{"hostname", "hostname.domain"},
ValidPrincipals: []string{"hostname", "hostname.domain", "otherhost"},
Key: testPublicKeys["rsa"],
ValidBefore: CertTimeInfinity,
CertType: HostCert,
@@ -168,8 +168,8 @@ func TestHostKeyCert(t *testing.T) {
cert.SignCert(rand.Reader, testSigners["ecdsa"])
checker := &CertChecker{
IsAuthority: func(p PublicKey) bool {
return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal())
IsHostAuthority: func(p PublicKey, h string) bool {
return h == "hostname" && bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal())
},
}
@@ -178,7 +178,7 @@ func TestHostKeyCert(t *testing.T) {
t.Errorf("NewCertSigner: %v", err)
}
for _, name := range []string{"hostname", "otherhost"} {
for _, name := range []string{"hostname", "otherhost", "lasthost"} {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)

View File

@@ -14,7 +14,7 @@ import (
)
// Client implements a traditional SSH client that supports shells,
// subprocesses, port forwarding and tunneled dialing.
// subprocesses, TCP port/streamlocal forwarding and tunneled dialing.
type Client struct {
Conn
@@ -60,6 +60,7 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
conn.forwards.closeAll()
}()
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip"))
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
return conn
}

View File

@@ -179,31 +179,26 @@ func (cb publicKeyCallback) method() string {
}
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
// Authentication is performed in two stages. The first stage sends an
// enquiry to test if each key is acceptable to the remote. The second
// stage attempts to authenticate with the valid keys obtained in the
// first stage.
// Authentication is performed by sending an enquiry to test if a key is
// acceptable to the remote. If the key is acceptable, the client will
// attempt to authenticate with the valid key. If not the client will repeat
// the process with the remaining keys.
signers, err := cb()
if err != nil {
return false, nil, err
}
var validKeys []Signer
for _, signer := range signers {
if ok, err := validateKey(signer.PublicKey(), user, c); ok {
validKeys = append(validKeys, signer)
} else {
if err != nil {
return false, nil, err
}
}
}
// methods that may continue if this auth is not successful.
var methods []string
for _, signer := range validKeys {
pub := signer.PublicKey()
for _, signer := range signers {
ok, err := validateKey(signer.PublicKey(), user, c)
if err != nil {
return false, nil, err
}
if !ok {
continue
}
pub := signer.PublicKey()
pubKey := pub.Marshal()
sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{
User: user,
@@ -236,13 +231,29 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
if err != nil {
return false, nil, err
}
if success {
// If authentication succeeds or the list of available methods does not
// contain the "publickey" method, do not attempt to authenticate with any
// other keys. According to RFC 4252 Section 7, the latter can occur when
// additional authentication methods are required.
if success || !containsMethod(methods, cb.method()) {
return success, methods, err
}
}
return false, methods, nil
}
func containsMethod(methods []string, method string) bool {
for _, m := range methods {
if m == method {
return true
}
}
return false
}
// validateKey validates the key provided is acceptable to the server.
func validateKey(key PublicKey, user string, c packetConn) (bool, error) {
pubKey := key.Marshal()

View File

@@ -38,7 +38,7 @@ func tryAuth(t *testing.T, config *ClientConfig) error {
defer c2.Close()
certChecker := CertChecker{
IsAuthority: func(k PublicKey) bool {
IsUserAuthority: func(k PublicKey) bool {
return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal())
},
UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
@@ -76,8 +76,6 @@ func tryAuth(t *testing.T, config *ClientConfig) error {
}
return nil, errors.New("keyboard-interactive failed")
},
AuthLogCallback: func(conn ConnMetadata, method string, err error) {
},
}
serverConfig.AddHostKey(testSigners["rsa"])
@@ -482,3 +480,100 @@ func TestClientAuthNone(t *testing.T) {
t.Fatalf("server: got %q, want %q", serverConn.User(), user)
}
}
// Test if authentication attempts are limited on server when MaxAuthTries is set
func TestClientAuthMaxAuthTries(t *testing.T) {
user := "testuser"
serverConfig := &ServerConfig{
MaxAuthTries: 2,
PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) {
if conn.User() == "testuser" && string(pass) == "right" {
return nil, nil
}
return nil, errors.New("password auth failed")
},
}
serverConfig.AddHostKey(testSigners["rsa"])
expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{
Reason: 2,
Message: "too many authentication failures",
})
for tries := 2; tries < 4; tries++ {
n := tries
clientConfig := &ClientConfig{
User: user,
Auth: []AuthMethod{
RetryableAuthMethod(PasswordCallback(func() (string, error) {
n--
if n == 0 {
return "right", nil
} else {
return "wrong", nil
}
}), tries),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go newServer(c1, serverConfig)
_, _, _, err = NewClientConn(c2, "", clientConfig)
if tries > 2 {
if err == nil {
t.Fatalf("client: got no error, want %s", expectedErr)
} else if err.Error() != expectedErr.Error() {
t.Fatalf("client: got %s, want %s", err, expectedErr)
}
} else {
if err != nil {
t.Fatalf("client: got %s, want no error", err)
}
}
}
}
// Test if authentication attempts are correctly limited on server
// when more public keys are provided then MaxAuthTries
func TestClientAuthMaxAuthTriesPublicKey(t *testing.T) {
signers := []Signer{}
for i := 0; i < 6; i++ {
signers = append(signers, testSigners["dsa"])
}
validConfig := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(append([]Signer{testSigners["rsa"]}, signers...)...),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, validConfig); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{
Reason: 2,
Message: "too many authentication failures",
})
invalidConfig := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(append(signers, testSigners["rsa"])...),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, invalidConfig); err == nil {
t.Fatalf("client: got no error, want %s", expectedErr)
} else if err.Error() != expectedErr.Error() {
t.Fatalf("client: got %s, want %s", err, expectedErr)
}
}

View File

@@ -72,7 +72,7 @@ func TestHostKeyCheck(t *testing.T) {
_, _, _, err = NewClientConn(c2, "", &clientConf)
if err != nil {
if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) {
t.Errorf("%s: got error %q, missing %q", err.Error(), tt.wantError)
t.Errorf("%s: got error %q, missing %q", tt.name, err.Error(), tt.wantError)
}
} else if tt.wantError != "" {
t.Errorf("%s: succeeded, but want error string %q", tt.name, tt.wantError)

View File

@@ -9,6 +9,7 @@ import (
"crypto/rand"
"fmt"
"io"
"math"
"sync"
_ "crypto/sha1"
@@ -186,7 +187,7 @@ type Config struct {
// The maximum number of bytes sent or received after which a
// new key is negotiated. It must be at least 256. If
// unspecified, 1 gigabyte is used.
// unspecified, a size suitable for the chosen cipher is used.
RekeyThreshold uint64
// The allowed key exchanges algorithms. If unspecified then a
@@ -230,11 +231,12 @@ func (c *Config) SetDefaults() {
}
if c.RekeyThreshold == 0 {
// RFC 4253, section 9 suggests rekeying after 1G.
c.RekeyThreshold = 1 << 30
}
if c.RekeyThreshold < minRekeyThreshold {
// cipher specific default
} else if c.RekeyThreshold < minRekeyThreshold {
c.RekeyThreshold = minRekeyThreshold
} else if c.RekeyThreshold >= math.MaxInt64 {
// Avoid weirdness if somebody uses -1 as a threshold.
c.RekeyThreshold = math.MaxInt64
}
}

View File

@@ -107,6 +107,8 @@ func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion,
config: config,
}
t.resetReadThresholds()
t.resetWriteThresholds()
// We always start with a mandatory key exchange.
t.requestKex <- struct{}{}
@@ -237,6 +239,17 @@ func (t *handshakeTransport) requestKeyExchange() {
}
}
func (t *handshakeTransport) resetWriteThresholds() {
t.writePacketsLeft = packetRekeyThreshold
if t.config.RekeyThreshold > 0 {
t.writeBytesLeft = int64(t.config.RekeyThreshold)
} else if t.algorithms != nil {
t.writeBytesLeft = t.algorithms.w.rekeyBytes()
} else {
t.writeBytesLeft = 1 << 30
}
}
func (t *handshakeTransport) kexLoop() {
write:
@@ -285,12 +298,8 @@ write:
t.writeError = err
t.sentInitPacket = nil
t.sentInitMsg = nil
t.writePacketsLeft = packetRekeyThreshold
if t.config.RekeyThreshold > 0 {
t.writeBytesLeft = int64(t.config.RekeyThreshold)
} else if t.algorithms != nil {
t.writeBytesLeft = t.algorithms.w.rekeyBytes()
}
t.resetWriteThresholds()
// we have completed the key exchange. Since the
// reader is still blocked, it is safe to clear out
@@ -344,6 +353,17 @@ write:
// key exchange itself.
const packetRekeyThreshold = (1 << 31)
func (t *handshakeTransport) resetReadThresholds() {
t.readPacketsLeft = packetRekeyThreshold
if t.config.RekeyThreshold > 0 {
t.readBytesLeft = int64(t.config.RekeyThreshold)
} else if t.algorithms != nil {
t.readBytesLeft = t.algorithms.r.rekeyBytes()
} else {
t.readBytesLeft = 1 << 30
}
}
func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
p, err := t.conn.readPacket()
if err != nil {
@@ -391,12 +411,7 @@ func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
return nil, err
}
t.readPacketsLeft = packetRekeyThreshold
if t.config.RekeyThreshold > 0 {
t.readBytesLeft = int64(t.config.RekeyThreshold)
} else {
t.readBytesLeft = t.algorithms.r.rekeyBytes()
}
t.resetReadThresholds()
// By default, a key exchange is hidden from higher layers by
// translating it into msgIgnore.

View File

@@ -40,9 +40,12 @@ func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error
// therefore is buffered (net.Pipe deadlocks if both sides start with
// a write.)
func netPipe() (net.Conn, net.Conn, error) {
listener, err := net.Listen("tcp", ":0")
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, nil, err
listener, err = net.Listen("tcp", "[::1]:0")
if err != nil {
return nil, nil, err
}
}
defer listener.Close()
c1, err := net.Dial("tcp", listener.Addr().String())
@@ -526,3 +529,31 @@ func TestDisconnect(t *testing.T) {
t.Errorf("readPacket 3 succeeded")
}
}
func TestHandshakeRekeyDefault(t *testing.T) {
clientConf := &ClientConfig{
Config: Config{
Ciphers: []string{"aes128-ctr"},
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
trC, trS, err := handshakePair(clientConf, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
trC.writePacket([]byte{msgRequestSuccess, 0, 0})
trC.Close()
rgb := (1024 + trC.readBytesLeft) >> 30
wgb := (1024 + trC.writeBytesLeft) >> 30
if rgb != 64 {
t.Errorf("got rekey after %dG read, want 64G", rgb)
}
if wgb != 64 {
t.Errorf("got rekey after %dG write, want 64G", wgb)
}
}

View File

@@ -824,7 +824,7 @@ func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) {
// Implemented based on the documentation at
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key
func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) {
func parseOpenSSHPrivateKey(key []byte) (crypto.PrivateKey, error) {
magic := append([]byte("openssh-key-v1"), 0)
if !bytes.Equal(magic, key[0:len(magic)]) {
return nil, errors.New("ssh: invalid openssh private key format")
@@ -844,14 +844,15 @@ func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) {
return nil, err
}
if w.KdfName != "none" || w.CipherName != "none" {
return nil, errors.New("ssh: cannot decode encrypted private keys")
}
pk1 := struct {
Check1 uint32
Check2 uint32
Keytype string
Pub []byte
Priv []byte
Comment string
Pad []byte `ssh:"rest"`
Rest []byte `ssh:"rest"`
}{}
if err := Unmarshal(w.PrivKeyBlock, &pk1); err != nil {
@@ -862,24 +863,75 @@ func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) {
return nil, errors.New("ssh: checkint mismatch")
}
// we only handle ed25519 keys currently
if pk1.Keytype != KeyAlgoED25519 {
// we only handle ed25519 and rsa keys currently
switch pk1.Keytype {
case KeyAlgoRSA:
// https://github.com/openssh/openssh-portable/blob/master/sshkey.c#L2760-L2773
key := struct {
N *big.Int
E *big.Int
D *big.Int
Iqmp *big.Int
P *big.Int
Q *big.Int
Comment string
Pad []byte `ssh:"rest"`
}{}
if err := Unmarshal(pk1.Rest, &key); err != nil {
return nil, err
}
for i, b := range key.Pad {
if int(b) != i+1 {
return nil, errors.New("ssh: padding not as expected")
}
}
pk := &rsa.PrivateKey{
PublicKey: rsa.PublicKey{
N: key.N,
E: int(key.E.Int64()),
},
D: key.D,
Primes: []*big.Int{key.P, key.Q},
}
if err := pk.Validate(); err != nil {
return nil, err
}
pk.Precompute()
return pk, nil
case KeyAlgoED25519:
key := struct {
Pub []byte
Priv []byte
Comment string
Pad []byte `ssh:"rest"`
}{}
if err := Unmarshal(pk1.Rest, &key); err != nil {
return nil, err
}
if len(key.Priv) != ed25519.PrivateKeySize {
return nil, errors.New("ssh: private key unexpected length")
}
for i, b := range key.Pad {
if int(b) != i+1 {
return nil, errors.New("ssh: padding not as expected")
}
}
pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize))
copy(pk, key.Priv)
return &pk, nil
default:
return nil, errors.New("ssh: unhandled key type")
}
for i, b := range pk1.Pad {
if int(b) != i+1 {
return nil, errors.New("ssh: padding not as expected")
}
}
if len(pk1.Priv) != ed25519.PrivateKeySize {
return nil, errors.New("ssh: private key unexpected length")
}
pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize))
copy(pk, pk1.Priv)
return &pk, nil
}
// FingerprintLegacyMD5 returns the user presentation of the key's

546
vendor/golang.org/x/crypto/ssh/knownhosts/knownhosts.go generated vendored Normal file
View File

@@ -0,0 +1,546 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package knownhosts implements a parser for the OpenSSH
// known_hosts host key database.
package knownhosts
import (
"bufio"
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"os"
"strings"
"golang.org/x/crypto/ssh"
)
// See the sshd manpage
// (http://man.openbsd.org/sshd#SSH_KNOWN_HOSTS_FILE_FORMAT) for
// background.
type addr struct{ host, port string }
func (a *addr) String() string {
h := a.host
if strings.Contains(h, ":") {
h = "[" + h + "]"
}
return h + ":" + a.port
}
type matcher interface {
match([]addr) bool
}
type hostPattern struct {
negate bool
addr addr
}
func (p *hostPattern) String() string {
n := ""
if p.negate {
n = "!"
}
return n + p.addr.String()
}
type hostPatterns []hostPattern
func (ps hostPatterns) match(addrs []addr) bool {
matched := false
for _, p := range ps {
for _, a := range addrs {
m := p.match(a)
if !m {
continue
}
if p.negate {
return false
}
matched = true
}
}
return matched
}
// See
// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/addrmatch.c
// The matching of * has no regard for separators, unlike filesystem globs
func wildcardMatch(pat []byte, str []byte) bool {
for {
if len(pat) == 0 {
return len(str) == 0
}
if len(str) == 0 {
return false
}
if pat[0] == '*' {
if len(pat) == 1 {
return true
}
for j := range str {
if wildcardMatch(pat[1:], str[j:]) {
return true
}
}
return false
}
if pat[0] == '?' || pat[0] == str[0] {
pat = pat[1:]
str = str[1:]
} else {
return false
}
}
}
func (l *hostPattern) match(a addr) bool {
return wildcardMatch([]byte(l.addr.host), []byte(a.host)) && l.addr.port == a.port
}
type keyDBLine struct {
cert bool
matcher matcher
knownKey KnownKey
}
func serialize(k ssh.PublicKey) string {
return k.Type() + " " + base64.StdEncoding.EncodeToString(k.Marshal())
}
func (l *keyDBLine) match(addrs []addr) bool {
return l.matcher.match(addrs)
}
type hostKeyDB struct {
// Serialized version of revoked keys
revoked map[string]*KnownKey
lines []keyDBLine
}
func newHostKeyDB() *hostKeyDB {
db := &hostKeyDB{
revoked: make(map[string]*KnownKey),
}
return db
}
func keyEq(a, b ssh.PublicKey) bool {
return bytes.Equal(a.Marshal(), b.Marshal())
}
// IsAuthorityForHost can be used as a callback in ssh.CertChecker
func (db *hostKeyDB) IsHostAuthority(remote ssh.PublicKey, address string) bool {
h, p, err := net.SplitHostPort(address)
if err != nil {
return false
}
a := addr{host: h, port: p}
for _, l := range db.lines {
if l.cert && keyEq(l.knownKey.Key, remote) && l.match([]addr{a}) {
return true
}
}
return false
}
// IsRevoked can be used as a callback in ssh.CertChecker
func (db *hostKeyDB) IsRevoked(key *ssh.Certificate) bool {
_, ok := db.revoked[string(key.Marshal())]
return ok
}
const markerCert = "@cert-authority"
const markerRevoked = "@revoked"
func nextWord(line []byte) (string, []byte) {
i := bytes.IndexAny(line, "\t ")
if i == -1 {
return string(line), nil
}
return string(line[:i]), bytes.TrimSpace(line[i:])
}
func parseLine(line []byte) (marker, host string, key ssh.PublicKey, err error) {
if w, next := nextWord(line); w == markerCert || w == markerRevoked {
marker = w
line = next
}
host, line = nextWord(line)
if len(line) == 0 {
return "", "", nil, errors.New("knownhosts: missing host pattern")
}
// ignore the keytype as it's in the key blob anyway.
_, line = nextWord(line)
if len(line) == 0 {
return "", "", nil, errors.New("knownhosts: missing key type pattern")
}
keyBlob, _ := nextWord(line)
keyBytes, err := base64.StdEncoding.DecodeString(keyBlob)
if err != nil {
return "", "", nil, err
}
key, err = ssh.ParsePublicKey(keyBytes)
if err != nil {
return "", "", nil, err
}
return marker, host, key, nil
}
func (db *hostKeyDB) parseLine(line []byte, filename string, linenum int) error {
marker, pattern, key, err := parseLine(line)
if err != nil {
return err
}
if marker == markerRevoked {
db.revoked[string(key.Marshal())] = &KnownKey{
Key: key,
Filename: filename,
Line: linenum,
}
return nil
}
entry := keyDBLine{
cert: marker == markerCert,
knownKey: KnownKey{
Filename: filename,
Line: linenum,
Key: key,
},
}
if pattern[0] == '|' {
entry.matcher, err = newHashedHost(pattern)
} else {
entry.matcher, err = newHostnameMatcher(pattern)
}
if err != nil {
return err
}
db.lines = append(db.lines, entry)
return nil
}
func newHostnameMatcher(pattern string) (matcher, error) {
var hps hostPatterns
for _, p := range strings.Split(pattern, ",") {
if len(p) == 0 {
continue
}
var a addr
var negate bool
if p[0] == '!' {
negate = true
p = p[1:]
}
if len(p) == 0 {
return nil, errors.New("knownhosts: negation without following hostname")
}
var err error
if p[0] == '[' {
a.host, a.port, err = net.SplitHostPort(p)
if err != nil {
return nil, err
}
} else {
a.host, a.port, err = net.SplitHostPort(p)
if err != nil {
a.host = p
a.port = "22"
}
}
hps = append(hps, hostPattern{
negate: negate,
addr: a,
})
}
return hps, nil
}
// KnownKey represents a key declared in a known_hosts file.
type KnownKey struct {
Key ssh.PublicKey
Filename string
Line int
}
func (k *KnownKey) String() string {
return fmt.Sprintf("%s:%d: %s", k.Filename, k.Line, serialize(k.Key))
}
// KeyError is returned if we did not find the key in the host key
// database, or there was a mismatch. Typically, in batch
// applications, this should be interpreted as failure. Interactive
// applications can offer an interactive prompt to the user.
type KeyError struct {
// Want holds the accepted host keys. For each key algorithm,
// there can be one hostkey. If Want is empty, the host is
// unknown. If Want is non-empty, there was a mismatch, which
// can signify a MITM attack.
Want []KnownKey
}
func (u *KeyError) Error() string {
if len(u.Want) == 0 {
return "knownhosts: key is unknown"
}
return "knownhosts: key mismatch"
}
// RevokedError is returned if we found a key that was revoked.
type RevokedError struct {
Revoked KnownKey
}
func (r *RevokedError) Error() string {
return "knownhosts: key is revoked"
}
// check checks a key against the host database. This should not be
// used for verifying certificates.
func (db *hostKeyDB) check(address string, remote net.Addr, remoteKey ssh.PublicKey) error {
if revoked := db.revoked[string(remoteKey.Marshal())]; revoked != nil {
return &RevokedError{Revoked: *revoked}
}
host, port, err := net.SplitHostPort(remote.String())
if err != nil {
return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", remote, err)
}
addrs := []addr{
{host, port},
}
if address != "" {
host, port, err := net.SplitHostPort(address)
if err != nil {
return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", address, err)
}
addrs = append(addrs, addr{host, port})
}
return db.checkAddrs(addrs, remoteKey)
}
// checkAddrs checks if we can find the given public key for any of
// the given addresses. If we only find an entry for the IP address,
// or only the hostname, then this still succeeds.
func (db *hostKeyDB) checkAddrs(addrs []addr, remoteKey ssh.PublicKey) error {
// TODO(hanwen): are these the right semantics? What if there
// is just a key for the IP address, but not for the
// hostname?
// Algorithm => key.
knownKeys := map[string]KnownKey{}
for _, l := range db.lines {
if l.match(addrs) {
typ := l.knownKey.Key.Type()
if _, ok := knownKeys[typ]; !ok {
knownKeys[typ] = l.knownKey
}
}
}
keyErr := &KeyError{}
for _, v := range knownKeys {
keyErr.Want = append(keyErr.Want, v)
}
// Unknown remote host.
if len(knownKeys) == 0 {
return keyErr
}
// If the remote host starts using a different, unknown key type, we
// also interpret that as a mismatch.
if known, ok := knownKeys[remoteKey.Type()]; !ok || !keyEq(known.Key, remoteKey) {
return keyErr
}
return nil
}
// The Read function parses file contents.
func (db *hostKeyDB) Read(r io.Reader, filename string) error {
scanner := bufio.NewScanner(r)
lineNum := 0
for scanner.Scan() {
lineNum++
line := scanner.Bytes()
line = bytes.TrimSpace(line)
if len(line) == 0 || line[0] == '#' {
continue
}
if err := db.parseLine(line, filename, lineNum); err != nil {
return fmt.Errorf("knownhosts: %s:%d: %v", filename, lineNum, err)
}
}
return scanner.Err()
}
// New creates a host key callback from the given OpenSSH host key
// files. The returned callback is for use in
// ssh.ClientConfig.HostKeyCallback. Hashed hostnames are not supported.
func New(files ...string) (ssh.HostKeyCallback, error) {
db := newHostKeyDB()
for _, fn := range files {
f, err := os.Open(fn)
if err != nil {
return nil, err
}
defer f.Close()
if err := db.Read(f, fn); err != nil {
return nil, err
}
}
var certChecker ssh.CertChecker
certChecker.IsHostAuthority = db.IsHostAuthority
certChecker.IsRevoked = db.IsRevoked
certChecker.HostKeyFallback = db.check
return certChecker.CheckHostKey, nil
}
// Normalize normalizes an address into the form used in known_hosts
func Normalize(address string) string {
host, port, err := net.SplitHostPort(address)
if err != nil {
host = address
port = "22"
}
entry := host
if port != "22" {
entry = "[" + entry + "]:" + port
} else if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
entry = "[" + entry + "]"
}
return entry
}
// Line returns a line to add append to the known_hosts files.
func Line(addresses []string, key ssh.PublicKey) string {
var trimmed []string
for _, a := range addresses {
trimmed = append(trimmed, Normalize(a))
}
return strings.Join(trimmed, ",") + " " + serialize(key)
}
// HashHostname hashes the given hostname. The hostname is not
// normalized before hashing.
func HashHostname(hostname string) string {
// TODO(hanwen): check if we can safely normalize this always.
salt := make([]byte, sha1.Size)
_, err := rand.Read(salt)
if err != nil {
panic(fmt.Sprintf("crypto/rand failure %v", err))
}
hash := hashHost(hostname, salt)
return encodeHash(sha1HashType, salt, hash)
}
func decodeHash(encoded string) (hashType string, salt, hash []byte, err error) {
if len(encoded) == 0 || encoded[0] != '|' {
err = errors.New("knownhosts: hashed host must start with '|'")
return
}
components := strings.Split(encoded, "|")
if len(components) != 4 {
err = fmt.Errorf("knownhosts: got %d components, want 3", len(components))
return
}
hashType = components[1]
if salt, err = base64.StdEncoding.DecodeString(components[2]); err != nil {
return
}
if hash, err = base64.StdEncoding.DecodeString(components[3]); err != nil {
return
}
return
}
func encodeHash(typ string, salt []byte, hash []byte) string {
return strings.Join([]string{"",
typ,
base64.StdEncoding.EncodeToString(salt),
base64.StdEncoding.EncodeToString(hash),
}, "|")
}
// See https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
func hashHost(hostname string, salt []byte) []byte {
mac := hmac.New(sha1.New, salt)
mac.Write([]byte(hostname))
return mac.Sum(nil)
}
type hashedHost struct {
salt []byte
hash []byte
}
const sha1HashType = "1"
func newHashedHost(encoded string) (*hashedHost, error) {
typ, salt, hash, err := decodeHash(encoded)
if err != nil {
return nil, err
}
// The type field seems for future algorithm agility, but it's
// actually hardcoded in openssh currently, see
// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
if typ != sha1HashType {
return nil, fmt.Errorf("knownhosts: got hash type %s, must be '1'", typ)
}
return &hashedHost{salt: salt, hash: hash}, nil
}
func (h *hashedHost) match(addrs []addr) bool {
for _, a := range addrs {
if bytes.Equal(hashHost(Normalize(a.String()), h.salt), h.hash) {
return true
}
}
return false
}

View File

@@ -0,0 +1,329 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package knownhosts
import (
"bytes"
"fmt"
"net"
"reflect"
"testing"
"golang.org/x/crypto/ssh"
)
const edKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGBAarftlLeoyf+v+nVchEZII/vna2PCV8FaX4vsF5BX"
const alternateEdKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIIXffBYeYL+WVzVru8npl5JHt2cjlr4ornFTWzoij9sx"
const ecKeyStr = "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNLCu01+wpXe3xB5olXCN4SqU2rQu0qjSRKJO4Bg+JRCPU+ENcgdA5srTU8xYDz/GEa4dzK5ldPw4J/gZgSXCMs="
var ecKey, alternateEdKey, edKey ssh.PublicKey
var testAddr = &net.TCPAddr{
IP: net.IP{198, 41, 30, 196},
Port: 22,
}
var testAddr6 = &net.TCPAddr{
IP: net.IP{198, 41, 30, 196,
1, 2, 3, 4,
1, 2, 3, 4,
1, 2, 3, 4,
},
Port: 22,
}
func init() {
var err error
ecKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(ecKeyStr))
if err != nil {
panic(err)
}
edKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(edKeyStr))
if err != nil {
panic(err)
}
alternateEdKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(alternateEdKeyStr))
if err != nil {
panic(err)
}
}
func testDB(t *testing.T, s string) *hostKeyDB {
db := newHostKeyDB()
if err := db.Read(bytes.NewBufferString(s), "testdb"); err != nil {
t.Fatalf("Read: %v", err)
}
return db
}
func TestRevoked(t *testing.T) {
db := testDB(t, "\n\n@revoked * "+edKeyStr+"\n")
want := &RevokedError{
Revoked: KnownKey{
Key: edKey,
Filename: "testdb",
Line: 3,
},
}
if err := db.check("", &net.TCPAddr{
Port: 42,
}, edKey); err == nil {
t.Fatal("no error for revoked key")
} else if !reflect.DeepEqual(want, err) {
t.Fatalf("got %#v, want %#v", want, err)
}
}
func TestHostAuthority(t *testing.T) {
for _, m := range []struct {
authorityFor string
address string
good bool
}{
{authorityFor: "localhost", address: "localhost:22", good: true},
{authorityFor: "localhost", address: "localhost", good: false},
{authorityFor: "localhost", address: "localhost:1234", good: false},
{authorityFor: "[localhost]:1234", address: "localhost:1234", good: true},
{authorityFor: "[localhost]:1234", address: "localhost:22", good: false},
{authorityFor: "[localhost]:1234", address: "localhost", good: false},
} {
db := testDB(t, `@cert-authority `+m.authorityFor+` `+edKeyStr)
if ok := db.IsHostAuthority(db.lines[0].knownKey.Key, m.address); ok != m.good {
t.Errorf("IsHostAuthority: authority %s, address %s, wanted good = %v, got good = %v",
m.authorityFor, m.address, m.good, ok)
}
}
}
func TestBracket(t *testing.T) {
db := testDB(t, `[git.eclipse.org]:29418,[198.41.30.196]:29418 `+edKeyStr)
if err := db.check("git.eclipse.org:29418", &net.TCPAddr{
IP: net.IP{198, 41, 30, 196},
Port: 29418,
}, edKey); err != nil {
t.Errorf("got error %v, want none", err)
}
if err := db.check("git.eclipse.org:29419", &net.TCPAddr{
Port: 42,
}, edKey); err == nil {
t.Fatalf("no error for unknown address")
} else if ke, ok := err.(*KeyError); !ok {
t.Fatalf("got type %T, want *KeyError", err)
} else if len(ke.Want) > 0 {
t.Fatalf("got Want %v, want []", ke.Want)
}
}
func TestNewKeyType(t *testing.T) {
str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
db := testDB(t, str)
if err := db.check("", testAddr, ecKey); err == nil {
t.Fatalf("no error for unknown address")
} else if ke, ok := err.(*KeyError); !ok {
t.Fatalf("got type %T, want *KeyError", err)
} else if len(ke.Want) == 0 {
t.Fatalf("got empty KeyError.Want")
}
}
func TestSameKeyType(t *testing.T) {
str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
db := testDB(t, str)
if err := db.check("", testAddr, alternateEdKey); err == nil {
t.Fatalf("no error for unknown address")
} else if ke, ok := err.(*KeyError); !ok {
t.Fatalf("got type %T, want *KeyError", err)
} else if len(ke.Want) == 0 {
t.Fatalf("got empty KeyError.Want")
} else if got, want := ke.Want[0].Key.Marshal(), edKey.Marshal(); !bytes.Equal(got, want) {
t.Fatalf("got key %q, want %q", got, want)
}
}
func TestIPAddress(t *testing.T) {
str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
db := testDB(t, str)
if err := db.check("", testAddr, edKey); err != nil {
t.Errorf("got error %q, want none", err)
}
}
func TestIPv6Address(t *testing.T) {
str := fmt.Sprintf("%s %s", testAddr6, edKeyStr)
db := testDB(t, str)
if err := db.check("", testAddr6, edKey); err != nil {
t.Errorf("got error %q, want none", err)
}
}
func TestBasic(t *testing.T) {
str := fmt.Sprintf("#comment\n\nserver.org,%s %s\notherhost %s", testAddr, edKeyStr, ecKeyStr)
db := testDB(t, str)
if err := db.check("server.org:22", testAddr, edKey); err != nil {
t.Errorf("got error %q, want none", err)
}
want := KnownKey{
Key: edKey,
Filename: "testdb",
Line: 3,
}
if err := db.check("server.org:22", testAddr, ecKey); err == nil {
t.Errorf("succeeded, want KeyError")
} else if ke, ok := err.(*KeyError); !ok {
t.Errorf("got %T, want *KeyError", err)
} else if len(ke.Want) != 1 {
t.Errorf("got %v, want 1 entry", ke)
} else if !reflect.DeepEqual(ke.Want[0], want) {
t.Errorf("got %v, want %v", ke.Want[0], want)
}
}
func TestNegate(t *testing.T) {
str := fmt.Sprintf("%s,!server.org %s", testAddr, edKeyStr)
db := testDB(t, str)
if err := db.check("server.org:22", testAddr, ecKey); err == nil {
t.Errorf("succeeded")
} else if ke, ok := err.(*KeyError); !ok {
t.Errorf("got error type %T, want *KeyError", err)
} else if len(ke.Want) != 0 {
t.Errorf("got expected keys %d (first of type %s), want []", len(ke.Want), ke.Want[0].Key.Type())
}
}
func TestWildcard(t *testing.T) {
str := fmt.Sprintf("server*.domain %s", edKeyStr)
db := testDB(t, str)
want := &KeyError{
Want: []KnownKey{{
Filename: "testdb",
Line: 1,
Key: edKey,
}},
}
got := db.check("server.domain:22", &net.TCPAddr{}, ecKey)
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s, want %s", got, want)
}
}
func TestLine(t *testing.T) {
for in, want := range map[string]string{
"server.org": "server.org " + edKeyStr,
"server.org:22": "server.org " + edKeyStr,
"server.org:23": "[server.org]:23 " + edKeyStr,
"[c629:1ec4:102:304:102:304:102:304]:22": "[c629:1ec4:102:304:102:304:102:304] " + edKeyStr,
"[c629:1ec4:102:304:102:304:102:304]:23": "[c629:1ec4:102:304:102:304:102:304]:23 " + edKeyStr,
} {
if got := Line([]string{in}, edKey); got != want {
t.Errorf("Line(%q) = %q, want %q", in, got, want)
}
}
}
func TestWildcardMatch(t *testing.T) {
for _, c := range []struct {
pat, str string
want bool
}{
{"a?b", "abb", true},
{"ab", "abc", false},
{"abc", "ab", false},
{"a*b", "axxxb", true},
{"a*b", "axbxb", true},
{"a*b", "axbxbc", false},
{"a*?", "axbxc", true},
{"a*b*", "axxbxxxxxx", true},
{"a*b*c", "axxbxxxxxxc", true},
{"a*b*?", "axxbxxxxxxc", true},
{"a*b*z", "axxbxxbxxxz", true},
{"a*b*z", "axxbxxzxxxz", true},
{"a*b*z", "axxbxxzxxx", false},
} {
got := wildcardMatch([]byte(c.pat), []byte(c.str))
if got != c.want {
t.Errorf("wildcardMatch(%q, %q) = %v, want %v", c.pat, c.str, got, c.want)
}
}
}
// TODO(hanwen): test coverage for certificates.
const testHostname = "hostname"
// generated with keygen -H -f
const encodedTestHostnameHash = "|1|IHXZvQMvTcZTUU29+2vXFgx8Frs=|UGccIWfRVDwilMBnA3WJoRAC75Y="
func TestHostHash(t *testing.T) {
testHostHash(t, testHostname, encodedTestHostnameHash)
}
func TestHashList(t *testing.T) {
encoded := HashHostname(testHostname)
testHostHash(t, testHostname, encoded)
}
func testHostHash(t *testing.T, hostname, encoded string) {
typ, salt, hash, err := decodeHash(encoded)
if err != nil {
t.Fatalf("decodeHash: %v", err)
}
if got := encodeHash(typ, salt, hash); got != encoded {
t.Errorf("got encoding %s want %s", got, encoded)
}
if typ != sha1HashType {
t.Fatalf("got hash type %q, want %q", typ, sha1HashType)
}
got := hashHost(hostname, salt)
if !bytes.Equal(got, hash) {
t.Errorf("got hash %x want %x", got, hash)
}
}
func TestNormalize(t *testing.T) {
for in, want := range map[string]string{
"127.0.0.1:22": "127.0.0.1",
"[127.0.0.1]:22": "127.0.0.1",
"[127.0.0.1]:23": "[127.0.0.1]:23",
"127.0.0.1:23": "[127.0.0.1]:23",
"[a.b.c]:22": "a.b.c",
"[abcd:abcd:abcd:abcd]": "[abcd:abcd:abcd:abcd]",
"[abcd:abcd:abcd:abcd]:22": "[abcd:abcd:abcd:abcd]",
"[abcd:abcd:abcd:abcd]:23": "[abcd:abcd:abcd:abcd]:23",
} {
got := Normalize(in)
if got != want {
t.Errorf("Normalize(%q) = %q, want %q", in, got, want)
}
}
}
func TestHashedHostkeyCheck(t *testing.T) {
str := fmt.Sprintf("%s %s", HashHostname(testHostname), edKeyStr)
db := testDB(t, str)
if err := db.check(testHostname+":22", testAddr, edKey); err != nil {
t.Errorf("check(%s): %v", testHostname, err)
}
want := &KeyError{
Want: []KnownKey{{
Filename: "testdb",
Line: 1,
Key: edKey,
}},
}
if got := db.check(testHostname+":22", testAddr, alternateEdKey); !reflect.DeepEqual(got, want) {
t.Errorf("got error %v, want %v", got, want)
}
}

View File

@@ -45,6 +45,12 @@ type ServerConfig struct {
// authenticating.
NoClientAuth bool
// MaxAuthTries specifies the maximum number of authentication attempts
// permitted per connection. If set to a negative number, the number of
// attempts are unlimited. If set to zero, the number of attempts are limited
// to 6.
MaxAuthTries int
// PasswordCallback, if non-nil, is called when a user
// attempts to authenticate using a password.
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
@@ -143,6 +149,10 @@ type ServerConn struct {
func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
fullConf := *config
fullConf.SetDefaults()
if fullConf.MaxAuthTries == 0 {
fullConf.MaxAuthTries = 6
}
s := &connection{
sshConn: sshConn{conn: c},
}
@@ -267,8 +277,23 @@ func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, err
var cache pubKeyCache
var perms *Permissions
authFailures := 0
userAuthLoop:
for {
if authFailures >= config.MaxAuthTries && config.MaxAuthTries > 0 {
discMsg := &disconnectMsg{
Reason: 2,
Message: "too many authentication failures",
}
if err := s.transport.writePacket(Marshal(discMsg)); err != nil {
return nil, err
}
return nil, discMsg
}
var userAuthReq userAuthRequestMsg
if packet, err := s.transport.readPacket(); err != nil {
return nil, err
@@ -289,6 +314,11 @@ userAuthLoop:
if config.NoClientAuth {
authErr = nil
}
// allow initial attempt of 'none' without penalty
if authFailures == 0 {
authFailures--
}
case "password":
if config.PasswordCallback == nil {
authErr = errors.New("ssh: password auth not configured")
@@ -360,6 +390,7 @@ userAuthLoop:
if isQuery {
// The client can query if the given public key
// would be okay.
if len(payload) > 0 {
return nil, parseError(msgUserAuthRequest)
}
@@ -409,6 +440,8 @@ userAuthLoop:
break userAuthLoop
}
authFailures++
var failureMsg userAuthFailureMsg
if config.PasswordCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "password")

115
vendor/golang.org/x/crypto/ssh/streamlocal.go generated vendored Normal file
View File

@@ -0,0 +1,115 @@
package ssh
import (
"errors"
"io"
"net"
)
// streamLocalChannelOpenDirectMsg is a struct used for SSH_MSG_CHANNEL_OPEN message
// with "direct-streamlocal@openssh.com" string.
//
// See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235
type streamLocalChannelOpenDirectMsg struct {
socketPath string
reserved0 string
reserved1 uint32
}
// forwardedStreamLocalPayload is a struct used for SSH_MSG_CHANNEL_OPEN message
// with "forwarded-streamlocal@openssh.com" string.
type forwardedStreamLocalPayload struct {
SocketPath string
Reserved0 string
}
// streamLocalChannelForwardMsg is a struct used for SSH2_MSG_GLOBAL_REQUEST message
// with "streamlocal-forward@openssh.com"/"cancel-streamlocal-forward@openssh.com" string.
type streamLocalChannelForwardMsg struct {
socketPath string
}
// ListenUnix is similar to ListenTCP but uses a Unix domain socket.
func (c *Client) ListenUnix(socketPath string) (net.Listener, error) {
m := streamLocalChannelForwardMsg{
socketPath,
}
// send message
ok, _, err := c.SendRequest("streamlocal-forward@openssh.com", true, Marshal(&m))
if err != nil {
return nil, err
}
if !ok {
return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer")
}
ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"})
return &unixListener{socketPath, c, ch}, nil
}
func (c *Client) dialStreamLocal(socketPath string) (Channel, error) {
msg := streamLocalChannelOpenDirectMsg{
socketPath: socketPath,
}
ch, in, err := c.OpenChannel("direct-streamlocal@openssh.com", Marshal(&msg))
if err != nil {
return nil, err
}
go DiscardRequests(in)
return ch, err
}
type unixListener struct {
socketPath string
conn *Client
in <-chan forward
}
// Accept waits for and returns the next connection to the listener.
func (l *unixListener) Accept() (net.Conn, error) {
s, ok := <-l.in
if !ok {
return nil, io.EOF
}
ch, incoming, err := s.newCh.Accept()
if err != nil {
return nil, err
}
go DiscardRequests(incoming)
return &chanConn{
Channel: ch,
laddr: &net.UnixAddr{
Name: l.socketPath,
Net: "unix",
},
raddr: &net.UnixAddr{
Name: "@",
Net: "unix",
},
}, nil
}
// Close closes the listener.
func (l *unixListener) Close() error {
// this also closes the listener.
l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"})
m := streamLocalChannelForwardMsg{
l.socketPath,
}
ok, _, err := l.conn.SendRequest("cancel-streamlocal-forward@openssh.com", true, Marshal(&m))
if err == nil && !ok {
err = errors.New("ssh: cancel-streamlocal-forward@openssh.com failed")
}
return err
}
// Addr returns the listener's network address.
func (l *unixListener) Addr() net.Addr {
return &net.UnixAddr{
Name: l.socketPath,
Net: "unix",
}
}

View File

@@ -20,12 +20,20 @@ import (
// addr. Incoming connections will be available by calling Accept on
// the returned net.Listener. The listener must be serviced, or the
// SSH connection may hang.
// N must be "tcp", "tcp4", "tcp6", or "unix".
func (c *Client) Listen(n, addr string) (net.Listener, error) {
laddr, err := net.ResolveTCPAddr(n, addr)
if err != nil {
return nil, err
switch n {
case "tcp", "tcp4", "tcp6":
laddr, err := net.ResolveTCPAddr(n, addr)
if err != nil {
return nil, err
}
return c.ListenTCP(laddr)
case "unix":
return c.ListenUnix(addr)
default:
return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
}
return c.ListenTCP(laddr)
}
// Automatic port allocation is broken with OpenSSH before 6.0. See
@@ -116,7 +124,7 @@ func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
}
// Register this forward, using the port number we obtained.
ch := c.forwards.add(*laddr)
ch := c.forwards.add(laddr)
return &tcpListener{laddr, c, ch}, nil
}
@@ -131,7 +139,7 @@ type forwardList struct {
// forwardEntry represents an established mapping of a laddr on a
// remote ssh server to a channel connected to a tcpListener.
type forwardEntry struct {
laddr net.TCPAddr
laddr net.Addr
c chan forward
}
@@ -139,16 +147,16 @@ type forwardEntry struct {
// arguments to add/remove/lookup should be address as specified in
// the original forward-request.
type forward struct {
newCh NewChannel // the ssh client channel underlying this forward
raddr *net.TCPAddr // the raddr of the incoming connection
newCh NewChannel // the ssh client channel underlying this forward
raddr net.Addr // the raddr of the incoming connection
}
func (l *forwardList) add(addr net.TCPAddr) chan forward {
func (l *forwardList) add(addr net.Addr) chan forward {
l.Lock()
defer l.Unlock()
f := forwardEntry{
addr,
make(chan forward, 1),
laddr: addr,
c: make(chan forward, 1),
}
l.entries = append(l.entries, f)
return f.c
@@ -176,44 +184,69 @@ func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
func (l *forwardList) handleChannels(in <-chan NewChannel) {
for ch := range in {
var payload forwardedTCPPayload
if err := Unmarshal(ch.ExtraData(), &payload); err != nil {
ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
continue
}
var (
laddr net.Addr
raddr net.Addr
err error
)
switch channelType := ch.ChannelType(); channelType {
case "forwarded-tcpip":
var payload forwardedTCPPayload
if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
continue
}
// RFC 4254 section 7.2 specifies that incoming
// addresses should list the address, in string
// format. It is implied that this should be an IP
// address, as it would be impossible to connect to it
// otherwise.
laddr, err := parseTCPAddr(payload.Addr, payload.Port)
if err != nil {
ch.Reject(ConnectionFailed, err.Error())
continue
}
raddr, err := parseTCPAddr(payload.OriginAddr, payload.OriginPort)
if err != nil {
ch.Reject(ConnectionFailed, err.Error())
continue
}
// RFC 4254 section 7.2 specifies that incoming
// addresses should list the address, in string
// format. It is implied that this should be an IP
// address, as it would be impossible to connect to it
// otherwise.
laddr, err = parseTCPAddr(payload.Addr, payload.Port)
if err != nil {
ch.Reject(ConnectionFailed, err.Error())
continue
}
raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
if err != nil {
ch.Reject(ConnectionFailed, err.Error())
continue
}
if ok := l.forward(*laddr, *raddr, ch); !ok {
case "forwarded-streamlocal@openssh.com":
var payload forwardedStreamLocalPayload
if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
continue
}
laddr = &net.UnixAddr{
Name: payload.SocketPath,
Net: "unix",
}
raddr = &net.UnixAddr{
Name: "@",
Net: "unix",
}
default:
panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
}
if ok := l.forward(laddr, raddr, ch); !ok {
// Section 7.2, implementations MUST reject spurious incoming
// connections.
ch.Reject(Prohibited, "no forward for address")
continue
}
}
}
// remove removes the forward entry, and the channel feeding its
// listener.
func (l *forwardList) remove(addr net.TCPAddr) {
func (l *forwardList) remove(addr net.Addr) {
l.Lock()
defer l.Unlock()
for i, f := range l.entries {
if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port {
if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
l.entries = append(l.entries[:i], l.entries[i+1:]...)
close(f.c)
return
@@ -231,12 +264,12 @@ func (l *forwardList) closeAll() {
l.entries = nil
}
func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool {
func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
l.Lock()
defer l.Unlock()
for _, f := range l.entries {
if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port {
f.c <- forward{ch, &raddr}
if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
f.c <- forward{newCh: ch, raddr: raddr}
return true
}
}
@@ -262,7 +295,7 @@ func (l *tcpListener) Accept() (net.Conn, error) {
}
go DiscardRequests(incoming)
return &tcpChanConn{
return &chanConn{
Channel: ch,
laddr: l.laddr,
raddr: s.raddr,
@@ -277,7 +310,7 @@ func (l *tcpListener) Close() error {
}
// this also closes the listener.
l.conn.forwards.remove(*l.laddr)
l.conn.forwards.remove(l.laddr)
ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
if err == nil && !ok {
err = errors.New("ssh: cancel-tcpip-forward failed")
@@ -293,29 +326,52 @@ func (l *tcpListener) Addr() net.Addr {
// Dial initiates a connection to the addr from the remote host.
// The resulting connection has a zero LocalAddr() and RemoteAddr().
func (c *Client) Dial(n, addr string) (net.Conn, error) {
// Parse the address into host and numeric port.
host, portString, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
var ch Channel
switch n {
case "tcp", "tcp4", "tcp6":
// Parse the address into host and numeric port.
host, portString, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, err
}
ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
if err != nil {
return nil, err
}
// Use a zero address for local and remote address.
zeroAddr := &net.TCPAddr{
IP: net.IPv4zero,
Port: 0,
}
return &chanConn{
Channel: ch,
laddr: zeroAddr,
raddr: zeroAddr,
}, nil
case "unix":
var err error
ch, err = c.dialStreamLocal(addr)
if err != nil {
return nil, err
}
return &chanConn{
Channel: ch,
laddr: &net.UnixAddr{
Name: "@",
Net: "unix",
},
raddr: &net.UnixAddr{
Name: addr,
Net: "unix",
},
}, nil
default:
return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, err
}
// Use a zero address for local and remote address.
zeroAddr := &net.TCPAddr{
IP: net.IPv4zero,
Port: 0,
}
ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port))
if err != nil {
return nil, err
}
return &tcpChanConn{
Channel: ch,
laddr: zeroAddr,
raddr: zeroAddr,
}, nil
}
// DialTCP connects to the remote address raddr on the network net,
@@ -332,7 +388,7 @@ func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error)
if err != nil {
return nil, err
}
return &tcpChanConn{
return &chanConn{
Channel: ch,
laddr: laddr,
raddr: raddr,
@@ -366,26 +422,26 @@ type tcpChan struct {
Channel // the backing channel
}
// tcpChanConn fulfills the net.Conn interface without
// chanConn fulfills the net.Conn interface without
// the tcpChan having to hold laddr or raddr directly.
type tcpChanConn struct {
type chanConn struct {
Channel
laddr, raddr net.Addr
}
// LocalAddr returns the local network address.
func (t *tcpChanConn) LocalAddr() net.Addr {
func (t *chanConn) LocalAddr() net.Addr {
return t.laddr
}
// RemoteAddr returns the remote network address.
func (t *tcpChanConn) RemoteAddr() net.Addr {
func (t *chanConn) RemoteAddr() net.Addr {
return t.raddr
}
// SetDeadline sets the read and write deadlines associated
// with the connection.
func (t *tcpChanConn) SetDeadline(deadline time.Time) error {
func (t *chanConn) SetDeadline(deadline time.Time) error {
if err := t.SetReadDeadline(deadline); err != nil {
return err
}
@@ -396,12 +452,14 @@ func (t *tcpChanConn) SetDeadline(deadline time.Time) error {
// A zero value for t means Read will not time out.
// After the deadline, the error from Read will implement net.Error
// with Timeout() == true.
func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error {
func (t *chanConn) SetReadDeadline(deadline time.Time) error {
// for compatibility with previous version,
// the error message contains "tcpChan"
return errors.New("ssh: tcpChan: deadline not supported")
}
// SetWriteDeadline exists to satisfy the net.Conn interface
// but is not implemented by this type. It always returns an error.
func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error {
func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
return errors.New("ssh: tcpChan: deadline not supported")
}

View File

@@ -14,14 +14,12 @@ import (
// State contains the state of a terminal.
type State struct {
termios syscall.Termios
state *unix.Termios
}
// IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
// see: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libbc/libc/gen/common/isatty.c
var termio unix.Termio
err := unix.IoctlSetTermio(fd, unix.TCGETA, &termio)
_, err := unix.IoctlGetTermio(fd, unix.TCGETA)
return err == nil
}
@@ -71,3 +69,60 @@ func ReadPassword(fd int) ([]byte, error) {
return ret, nil
}
// MakeRaw puts the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
// see http://cr.illumos.org/~webrev/andy_js/1060/
func MakeRaw(fd int) (*State, error) {
oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil {
return nil, err
}
oldTermios := *oldTermiosPtr
newTermios := oldTermios
newTermios.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON
newTermios.Oflag &^= syscall.OPOST
newTermios.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN
newTermios.Cflag &^= syscall.CSIZE | syscall.PARENB
newTermios.Cflag |= syscall.CS8
newTermios.Cc[unix.VMIN] = 1
newTermios.Cc[unix.VTIME] = 0
if err := unix.IoctlSetTermios(fd, unix.TCSETS, &newTermios); err != nil {
return nil, err
}
return &State{
state: oldTermiosPtr,
}, nil
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, oldState *State) error {
return unix.IoctlSetTermios(fd, unix.TCSETS, oldState.state)
}
// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
func GetState(fd int) (*State, error) {
oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil {
return nil, err
}
return &State{
state: oldTermiosPtr,
}, nil
}
// GetSize returns the dimensions of the given terminal.
func GetSize(fd int) (width, height int, err error) {
ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ)
if err != nil {
return 0, 0, err
}
return int(ws.Col), int(ws.Row), nil
}

128
vendor/golang.org/x/crypto/ssh/test/dial_unix_test.go generated vendored Normal file
View File

@@ -0,0 +1,128 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !windows
package test
// direct-tcpip and direct-streamlocal functional tests
import (
"fmt"
"io"
"io/ioutil"
"net"
"strings"
"testing"
)
type dialTester interface {
TestServerConn(t *testing.T, c net.Conn)
TestClientConn(t *testing.T, c net.Conn)
}
func testDial(t *testing.T, n, listenAddr string, x dialTester) {
server := newServer(t)
defer server.Shutdown()
sshConn := server.Dial(clientConfig())
defer sshConn.Close()
l, err := net.Listen(n, listenAddr)
if err != nil {
t.Fatalf("Listen: %v", err)
}
defer l.Close()
testData := fmt.Sprintf("hello from %s, %s", n, listenAddr)
go func() {
for {
c, err := l.Accept()
if err != nil {
break
}
x.TestServerConn(t, c)
io.WriteString(c, testData)
c.Close()
}
}()
conn, err := sshConn.Dial(n, l.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
x.TestClientConn(t, conn)
defer conn.Close()
b, err := ioutil.ReadAll(conn)
if err != nil {
t.Fatalf("ReadAll: %v", err)
}
t.Logf("got %q", string(b))
if string(b) != testData {
t.Fatalf("expected %q, got %q", testData, string(b))
}
}
type tcpDialTester struct {
listenAddr string
}
func (x *tcpDialTester) TestServerConn(t *testing.T, c net.Conn) {
host := strings.Split(x.listenAddr, ":")[0]
prefix := host + ":"
if !strings.HasPrefix(c.LocalAddr().String(), prefix) {
t.Fatalf("expected to start with %q, got %q", prefix, c.LocalAddr().String())
}
if !strings.HasPrefix(c.RemoteAddr().String(), prefix) {
t.Fatalf("expected to start with %q, got %q", prefix, c.RemoteAddr().String())
}
}
func (x *tcpDialTester) TestClientConn(t *testing.T, c net.Conn) {
// we use zero addresses. see *Client.Dial.
if c.LocalAddr().String() != "0.0.0.0:0" {
t.Fatalf("expected \"0.0.0.0:0\", got %q", c.LocalAddr().String())
}
if c.RemoteAddr().String() != "0.0.0.0:0" {
t.Fatalf("expected \"0.0.0.0:0\", got %q", c.RemoteAddr().String())
}
}
func TestDialTCP(t *testing.T) {
x := &tcpDialTester{
listenAddr: "127.0.0.1:0",
}
testDial(t, "tcp", x.listenAddr, x)
}
type unixDialTester struct {
listenAddr string
}
func (x *unixDialTester) TestServerConn(t *testing.T, c net.Conn) {
if c.LocalAddr().String() != x.listenAddr {
t.Fatalf("expected %q, got %q", x.listenAddr, c.LocalAddr().String())
}
if c.RemoteAddr().String() != "@" {
t.Fatalf("expected \"@\", got %q", c.RemoteAddr().String())
}
}
func (x *unixDialTester) TestClientConn(t *testing.T, c net.Conn) {
if c.RemoteAddr().String() != x.listenAddr {
t.Fatalf("expected %q, got %q", x.listenAddr, c.RemoteAddr().String())
}
if c.LocalAddr().String() != "@" {
t.Fatalf("expected \"@\", got %q", c.LocalAddr().String())
}
}
func TestDialUnix(t *testing.T) {
addr, cleanup := newTempSocket(t)
defer cleanup()
x := &unixDialTester{
listenAddr: addr,
}
testDial(t, "unix", x.listenAddr, x)
}

View File

@@ -16,13 +16,17 @@ import (
"time"
)
func TestPortForward(t *testing.T) {
type closeWriter interface {
CloseWrite() error
}
func testPortForward(t *testing.T, n, listenAddr string) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
defer conn.Close()
sshListener, err := conn.Listen("tcp", "localhost:0")
sshListener, err := conn.Listen(n, listenAddr)
if err != nil {
t.Fatal(err)
}
@@ -41,14 +45,14 @@ func TestPortForward(t *testing.T) {
}()
forwardedAddr := sshListener.Addr().String()
tcpConn, err := net.Dial("tcp", forwardedAddr)
netConn, err := net.Dial(n, forwardedAddr)
if err != nil {
t.Fatalf("TCP dial failed: %v", err)
t.Fatalf("net dial failed: %v", err)
}
readChan := make(chan []byte)
go func() {
data, _ := ioutil.ReadAll(tcpConn)
data, _ := ioutil.ReadAll(netConn)
readChan <- data
}()
@@ -62,14 +66,14 @@ func TestPortForward(t *testing.T) {
for len(sent) < 1000*1000 {
// Send random sized chunks
m := rand.Intn(len(data))
n, err := tcpConn.Write(data[:m])
n, err := netConn.Write(data[:m])
if err != nil {
break
}
sent = append(sent, data[:n]...)
}
if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil {
t.Errorf("tcpConn.CloseWrite: %v", err)
if err := netConn.(closeWriter).CloseWrite(); err != nil {
t.Errorf("netConn.CloseWrite: %v", err)
}
read := <-readChan
@@ -86,19 +90,29 @@ func TestPortForward(t *testing.T) {
}
// Check that the forward disappeared.
tcpConn, err = net.Dial("tcp", forwardedAddr)
netConn, err = net.Dial(n, forwardedAddr)
if err == nil {
tcpConn.Close()
netConn.Close()
t.Errorf("still listening to %s after closing", forwardedAddr)
}
}
func TestAcceptClose(t *testing.T) {
func TestPortForwardTCP(t *testing.T) {
testPortForward(t, "tcp", "localhost:0")
}
func TestPortForwardUnix(t *testing.T) {
addr, cleanup := newTempSocket(t)
defer cleanup()
testPortForward(t, "unix", addr)
}
func testAcceptClose(t *testing.T, n, listenAddr string) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
sshListener, err := conn.Listen("tcp", "localhost:0")
sshListener, err := conn.Listen(n, listenAddr)
if err != nil {
t.Fatal(err)
}
@@ -124,13 +138,23 @@ func TestAcceptClose(t *testing.T) {
}
}
func TestAcceptCloseTCP(t *testing.T) {
testAcceptClose(t, "tcp", "localhost:0")
}
func TestAcceptCloseUnix(t *testing.T) {
addr, cleanup := newTempSocket(t)
defer cleanup()
testAcceptClose(t, "unix", addr)
}
// Check that listeners exit if the underlying client transport dies.
func TestPortForwardConnectionClose(t *testing.T) {
func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
sshListener, err := conn.Listen("tcp", "localhost:0")
sshListener, err := conn.Listen(n, listenAddr)
if err != nil {
t.Fatal(err)
}
@@ -158,3 +182,13 @@ func TestPortForwardConnectionClose(t *testing.T) {
t.Logf("quit as expected (error %v)", err)
}
}
func TestPortForwardConnectionCloseTCP(t *testing.T) {
testPortForwardConnectionClose(t, "tcp", "localhost:0")
}
func TestPortForwardConnectionCloseUnix(t *testing.T) {
addr, cleanup := newTempSocket(t)
defer cleanup()
testPortForwardConnectionClose(t, "unix", addr)
}

View File

@@ -1,46 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !windows
package test
// direct-tcpip functional tests
import (
"io"
"net"
"testing"
)
func TestDial(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
sshConn := server.Dial(clientConfig())
defer sshConn.Close()
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen: %v", err)
}
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
break
}
io.WriteString(c, c.RemoteAddr().String())
c.Close()
}
}()
conn, err := sshConn.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
}

View File

@@ -266,3 +266,13 @@ func newServer(t *testing.T) *server {
},
}
}
func newTempSocket(t *testing.T) (string, func()) {
dir, err := ioutil.TempDir("", "socket")
if err != nil {
t.Fatal(err)
}
deferFunc := func() { os.RemoveAll(dir) }
addr := filepath.Join(dir, "sock")
return addr, deferFunc
}

View File

@@ -48,6 +48,22 @@ AAAEAaYmXltfW6nhRo3iWGglRB48lYq0z0Q3I3KyrdutEr6j7d/uFLuDlRbBc4ZVOsx+Gb
HKuOrPtLHFvHsjWPwO+/AAAAE2dhcnRvbm1AZ2FydG9ubS14cHMBAg==
-----END OPENSSH PRIVATE KEY-----
`),
"rsa-openssh-format": []byte(`-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAlwAAAAdzc2gtcn
NhAAAAAwEAAQAAAIEAwa48yfWFi3uIdqzuf9X7C2Zxfea/Iaaw0zIwHudpF8U92WVIiC5l
oEuW1+OaVi3UWfIEjWMV1tHGysrHOwtwc34BPCJqJknUQO/KtDTBTJ4Pryhw1bWPC999Lz
a+yrCTdNQYBzoROXKExZgPFh9pTMi5wqpHDuOQ2qZFIEI3lT0AAAIQWL0H31i9B98AAAAH
c3NoLXJzYQAAAIEAwa48yfWFi3uIdqzuf9X7C2Zxfea/Iaaw0zIwHudpF8U92WVIiC5loE
uW1+OaVi3UWfIEjWMV1tHGysrHOwtwc34BPCJqJknUQO/KtDTBTJ4Pryhw1bWPC999Lza+
yrCTdNQYBzoROXKExZgPFh9pTMi5wqpHDuOQ2qZFIEI3lT0AAAADAQABAAAAgCThyTGsT4
IARDxVMhWl6eiB2ZrgFgWSeJm/NOqtppWgOebsIqPMMg4UVuVFsl422/lE3RkPhVkjGXgE
pWvZAdCnmLmApK8wK12vF334lZhZT7t3Z9EzJps88PWEHo7kguf285HcnUM7FlFeissJdk
kXly34y7/3X/a6Tclm+iABAAAAQE0xR/KxZ39slwfMv64Rz7WKk1PPskaryI29aHE3mKHk
pY2QA+P3QlrKxT/VWUMjHUbNNdYfJm48xu0SGNMRdKMAAABBAORh2NP/06JUV3J9W/2Hju
X1ViJuqqcQnJPVzpgSL826EC2xwOECTqoY8uvFpUdD7CtpksIxNVqRIhuNOlz0lqEAAABB
ANkaHTTaPojClO0dKJ/Zjs7pWOCGliebBYprQ/Y4r9QLBkC/XaWMS26gFIrjgC7D2Rv+rZ
wSD0v0RcmkITP1ZR0AAAAYcHF1ZXJuYUBMdWNreUh5ZHJvLmxvY2FsAQID
-----END OPENSSH PRIVATE KEY-----`),
"user": []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEILYCAeq8f7V4vSSypRw7pxy8yz3V5W4qg8kSC3zJhqpQoAoGCCqGSM49
AwEHoUQDQgAEYcO2xNKiRUYOLEHM7VYAp57HNyKbOdYtHD83Z4hzNPVC4tM5mdGD