mirror of
https://github.com/golang/go
synced 2024-11-17 09:04:44 -07:00
net/http: switch HTTP1 to ASCII equivalents of string functions
The current implementation uses UTF-aware functions like strings.EqualFold and strings.ToLower. This could, in some cases, cause http smuggling. Change-Id: I0e76a993470a1e1b1b472f4b2859ea0a2b22ada0 Reviewed-on: https://go-review.googlesource.com/c/go/+/308009 Run-TryBot: Filippo Valsorda <filippo@golang.org> TryBot-Result: Go Bot <gobot@golang.org> Trust: Roberto Clapis <roberto@golang.org> Reviewed-by: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
parent
dc50683bf7
commit
5c489514bc
@ -440,7 +440,7 @@ var depsRules = `
|
||||
# HTTP, King of Dependencies.
|
||||
|
||||
FMT
|
||||
< golang.org/x/net/http2/hpack, net/http/internal;
|
||||
< golang.org/x/net/http2/hpack, net/http/internal, net/http/internal/ascii;
|
||||
|
||||
FMT, NET, container/list, encoding/binary, log
|
||||
< golang.org/x/text/transform
|
||||
@ -458,6 +458,7 @@ var depsRules = `
|
||||
golang.org/x/net/http/httpproxy,
|
||||
golang.org/x/net/http2/hpack,
|
||||
net/http/internal,
|
||||
net/http/internal/ascii,
|
||||
net/http/httptrace,
|
||||
mime/multipart,
|
||||
log
|
||||
@ -468,7 +469,7 @@ var depsRules = `
|
||||
encoding/json, net/http
|
||||
< expvar;
|
||||
|
||||
net/http
|
||||
net/http, net/http/internal/ascii
|
||||
< net/http/cookiejar, net/http/httputil;
|
||||
|
||||
net/http, flag
|
||||
|
@ -17,6 +17,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http/internal/ascii"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"sort"
|
||||
@ -547,7 +548,10 @@ func urlErrorOp(method string) string {
|
||||
if method == "" {
|
||||
return "Get"
|
||||
}
|
||||
return method[:1] + strings.ToLower(method[1:])
|
||||
if lowerMethod, ok := ascii.ToLower(method); ok {
|
||||
return method[:1] + lowerMethod[1:]
|
||||
}
|
||||
return method
|
||||
}
|
||||
|
||||
// Do sends an HTTP request and returns an HTTP response, following
|
||||
|
@ -7,6 +7,7 @@ package http
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http/internal/ascii"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -93,15 +94,23 @@ func readSetCookies(h Header) []*Cookie {
|
||||
if j := strings.Index(attr, "="); j >= 0 {
|
||||
attr, val = attr[:j], attr[j+1:]
|
||||
}
|
||||
lowerAttr := strings.ToLower(attr)
|
||||
lowerAttr, isASCII := ascii.ToLower(attr)
|
||||
if !isASCII {
|
||||
continue
|
||||
}
|
||||
val, ok = parseCookieValue(val, false)
|
||||
if !ok {
|
||||
c.Unparsed = append(c.Unparsed, parts[i])
|
||||
continue
|
||||
}
|
||||
|
||||
switch lowerAttr {
|
||||
case "samesite":
|
||||
lowerVal := strings.ToLower(val)
|
||||
lowerVal, ascii := ascii.ToLower(val)
|
||||
if !ascii {
|
||||
c.SameSite = SameSiteDefaultMode
|
||||
continue
|
||||
}
|
||||
switch lowerVal {
|
||||
case "lax":
|
||||
c.SameSite = SameSiteLaxMode
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/internal/ascii"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
@ -296,7 +297,6 @@ func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) {
|
||||
// host name.
|
||||
func canonicalHost(host string) (string, error) {
|
||||
var err error
|
||||
host = strings.ToLower(host)
|
||||
if hasPort(host) {
|
||||
host, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
@ -307,7 +307,13 @@ func canonicalHost(host string) (string, error) {
|
||||
// Strip trailing dot from fully qualified domain names.
|
||||
host = host[:len(host)-1]
|
||||
}
|
||||
return toASCII(host)
|
||||
encoded, err := toASCII(host)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// We know this is ascii, no need to check.
|
||||
lower, _ := ascii.ToLower(encoded)
|
||||
return lower, nil
|
||||
}
|
||||
|
||||
// hasPort reports whether host contains a port number. host may be a host
|
||||
@ -469,7 +475,12 @@ func (j *Jar) domainAndType(host, domain string) (string, bool, error) {
|
||||
// both are illegal.
|
||||
return "", false, errMalformedDomain
|
||||
}
|
||||
domain = strings.ToLower(domain)
|
||||
|
||||
domain, isASCII := ascii.ToLower(domain)
|
||||
if !isASCII {
|
||||
// Received non-ASCII domain, e.g. "perché.com" instead of "xn--perch-fsa.com"
|
||||
return "", false, errMalformedDomain
|
||||
}
|
||||
|
||||
if domain[len(domain)-1] == '.' {
|
||||
// We received stuff like "Domain=www.example.com.".
|
||||
|
@ -8,6 +8,7 @@ package cookiejar
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http/internal/ascii"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
@ -133,12 +134,12 @@ const acePrefix = "xn--"
|
||||
// toASCII("bücher.example.com") is "xn--bcher-kva.example.com", and
|
||||
// toASCII("golang") is "golang".
|
||||
func toASCII(s string) (string, error) {
|
||||
if ascii(s) {
|
||||
if ascii.Is(s) {
|
||||
return s, nil
|
||||
}
|
||||
labels := strings.Split(s, ".")
|
||||
for i, label := range labels {
|
||||
if !ascii(label) {
|
||||
if !ascii.Is(label) {
|
||||
a, err := encode(acePrefix, label)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@ -148,12 +149,3 @@ func toASCII(s string) (string, error) {
|
||||
}
|
||||
return strings.Join(labels, "."), nil
|
||||
}
|
||||
|
||||
func ascii(s string) bool {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] >= utf8.RuneSelf {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ package http
|
||||
import (
|
||||
"io"
|
||||
"net/http/httptrace"
|
||||
"net/http/internal/ascii"
|
||||
"net/textproto"
|
||||
"sort"
|
||||
"strings"
|
||||
@ -251,7 +252,7 @@ func hasToken(v, token string) bool {
|
||||
if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(v[sp:sp+len(token)], token) {
|
||||
if ascii.EqualFold(v[sp:sp+len(token)], token) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -62,15 +62,6 @@ func isNotToken(r rune) bool {
|
||||
return !httpguts.IsTokenRune(r)
|
||||
}
|
||||
|
||||
func isASCII(s string) bool {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] >= utf8.RuneSelf {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// stringContainsCTLByte reports whether s contains any ASCII control character.
|
||||
func stringContainsCTLByte(s string) bool {
|
||||
for i := 0; i < len(s); i++ {
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/internal/ascii"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strings"
|
||||
@ -242,6 +243,10 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
outreq.Close = false
|
||||
|
||||
reqUpType := upgradeType(outreq.Header)
|
||||
if !ascii.IsPrint(reqUpType) {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
|
||||
return
|
||||
}
|
||||
removeConnectionHeaders(outreq.Header)
|
||||
|
||||
// Remove hop-by-hop headers to the backend. Especially
|
||||
@ -538,13 +543,16 @@ func upgradeType(h http.Header) string {
|
||||
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(h.Get("Upgrade"))
|
||||
return h.Get("Upgrade")
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
|
||||
reqUpType := upgradeType(req.Header)
|
||||
resUpType := upgradeType(res.Header)
|
||||
if reqUpType != resUpType {
|
||||
if !ascii.IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
|
||||
}
|
||||
if !ascii.EqualFold(reqUpType, resUpType) {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
|
||||
return
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/internal/ascii"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
@ -1183,7 +1184,7 @@ func TestReverseProxyWebSocket(t *testing.T) {
|
||||
t.Errorf("Header(XHeader) = %q; want %q", got, want)
|
||||
}
|
||||
|
||||
if upgradeType(res.Header) != "websocket" {
|
||||
if !ascii.EqualFold(upgradeType(res.Header), "websocket") {
|
||||
t.Fatalf("not websocket upgrade; got %#v", res.Header)
|
||||
}
|
||||
rwc, ok := res.Body.(io.ReadWriteCloser)
|
||||
@ -1300,7 +1301,7 @@ func TestReverseProxyWebSocketCancelation(t *testing.T) {
|
||||
t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w)
|
||||
}
|
||||
|
||||
if g, w := upgradeType(res.Header), "websocket"; g != w {
|
||||
if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) {
|
||||
t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w)
|
||||
}
|
||||
|
||||
|
61
src/net/http/internal/ascii/print.go
Normal file
61
src/net/http/internal/ascii/print.go
Normal file
@ -0,0 +1,61 @@
|
||||
// Copyright 2021 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ascii
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// EqualFold is strings.EqualFold, ASCII only. It reports whether s and t
|
||||
// are equal, ASCII-case-insensitively.
|
||||
func EqualFold(s, t string) bool {
|
||||
if len(s) != len(t) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(s); i++ {
|
||||
if lower(s[i]) != lower(t[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// lower returns the ASCII lowercase version of b.
|
||||
func lower(b byte) byte {
|
||||
if 'A' <= b && b <= 'Z' {
|
||||
return b + ('a' - 'A')
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// IsPrint returns whether s is ASCII and printable according to
|
||||
// https://tools.ietf.org/html/rfc20#section-4.2.
|
||||
func IsPrint(s string) bool {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] < ' ' || s[i] > '~' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Is returns whether s is ASCII.
|
||||
func Is(s string) bool {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] > unicode.MaxASCII {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ToLower returns the lowercase version of s if s is ASCII and printable.
|
||||
func ToLower(s string) (lower string, ok bool) {
|
||||
if !IsPrint(s) {
|
||||
return "", false
|
||||
}
|
||||
return strings.ToLower(s), true
|
||||
}
|
95
src/net/http/internal/ascii/print_test.go
Normal file
95
src/net/http/internal/ascii/print_test.go
Normal file
@ -0,0 +1,95 @@
|
||||
// Copyright 2021 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ascii
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestEqualFold(t *testing.T) {
|
||||
var tests = []struct {
|
||||
name string
|
||||
a, b string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "simple match",
|
||||
a: "CHUNKED",
|
||||
b: "chunked",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "same string",
|
||||
a: "chunked",
|
||||
b: "chunked",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Unicode Kelvin symbol",
|
||||
a: "chunKed", // This "K" is 'KELVIN SIGN' (\u212A)
|
||||
b: "chunked",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := EqualFold(tt.a, tt.b); got != tt.want {
|
||||
t.Errorf("AsciiEqualFold(%q,%q): got %v want %v", tt.a, tt.b, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPrint(t *testing.T) {
|
||||
var tests = []struct {
|
||||
name string
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "ASCII low",
|
||||
in: "This is a space: ' '",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "ASCII high",
|
||||
in: "This is a tilde: '~'",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "ASCII low non-print",
|
||||
in: "This is a unit separator: \x1F",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Ascii high non-print",
|
||||
in: "This is a Delete: \x7F",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Unicode letter",
|
||||
in: "Today it's 280K outside: it's freezing!", // This "K" is 'KELVIN SIGN' (\u212A)
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Unicode emoji",
|
||||
in: "Gophers like 🧀",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := IsPrint(tt.in); got != tt.want {
|
||||
t.Errorf("IsASCIIPrint(%q): got %v want %v", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -19,6 +19,7 @@ import (
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http/httptrace"
|
||||
"net/http/internal/ascii"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
urlpkg "net/url"
|
||||
@ -723,7 +724,7 @@ func idnaASCII(v string) (string, error) {
|
||||
// version does not.
|
||||
// Note that for correct ASCII IDNs ToASCII will only do considerably more
|
||||
// work, but it will not cause an allocation.
|
||||
if isASCII(v) {
|
||||
if ascii.Is(v) {
|
||||
return v, nil
|
||||
}
|
||||
return idna.Lookup.ToASCII(v)
|
||||
@ -948,7 +949,7 @@ func (r *Request) BasicAuth() (username, password string, ok bool) {
|
||||
func parseBasicAuth(auth string) (username, password string, ok bool) {
|
||||
const prefix = "Basic "
|
||||
// Case insensitive prefix match. See Issue 22736.
|
||||
if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
|
||||
if len(auth) < len(prefix) || !ascii.EqualFold(auth[:len(prefix)], prefix) {
|
||||
return
|
||||
}
|
||||
c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
|
||||
@ -1456,5 +1457,5 @@ func requestMethodUsuallyLacksBody(method string) bool {
|
||||
// an HTTP/1 connection.
|
||||
func (r *Request) requiresHTTP1() bool {
|
||||
return hasToken(r.Header.Get("Connection"), "upgrade") &&
|
||||
strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
|
||||
ascii.EqualFold(r.Header.Get("Upgrade"), "websocket")
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"io"
|
||||
"net/http/httptrace"
|
||||
"net/http/internal"
|
||||
"net/http/internal/ascii"
|
||||
"net/textproto"
|
||||
"reflect"
|
||||
"sort"
|
||||
@ -638,7 +639,7 @@ func (t *transferReader) parseTransferEncoding() error {
|
||||
if len(raw) != 1 {
|
||||
return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)}
|
||||
}
|
||||
if strings.ToLower(textproto.TrimString(raw[0])) != "chunked" {
|
||||
if !ascii.EqualFold(textproto.TrimString(raw[0]), "chunked") {
|
||||
return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])}
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http/httptrace"
|
||||
"net/http/internal/ascii"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"os"
|
||||
@ -2185,7 +2186,7 @@ func (pc *persistConn) readLoop() {
|
||||
}
|
||||
|
||||
resp.Body = body
|
||||
if rc.addedGzip && strings.EqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
|
||||
if rc.addedGzip && ascii.EqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
|
||||
resp.Body = &gzipReader{body: body}
|
||||
resp.Header.Del("Content-Encoding")
|
||||
resp.Header.Del("Content-Length")
|
||||
|
Loading…
Reference in New Issue
Block a user