diff --git a/src/pkg/http/client.go b/src/pkg/http/client.go index fdd53f33c85..7e1d65df30b 100644 --- a/src/pkg/http/client.go +++ b/src/pkg/http/client.go @@ -11,9 +11,7 @@ import ( "encoding/base64" "fmt" "io" - "io/ioutil" "os" - "strconv" "strings" ) @@ -228,23 +226,12 @@ func Post(url string, bodyType string, body io.Reader) (r *Response, err os.Erro // // Caller should close r.Body when done reading from it. func (c *Client) Post(url string, bodyType string, body io.Reader) (r *Response, err os.Error) { - var req Request - req.Method = "POST" - req.ProtoMajor = 1 - req.ProtoMinor = 1 - req.Close = true - req.Body = ioutil.NopCloser(body) - req.Header = Header{ - "Content-Type": {bodyType}, - } - req.TransferEncoding = []string{"chunked"} - - req.URL, err = ParseURL(url) + req, err := NewRequest("POST", url, body) if err != nil { return nil, err } - - return send(&req, c.Transport) + req.Header.Set("Content-Type", bodyType) + return send(req, c.Transport) } // PostForm issues a POST to the specified URL, @@ -262,25 +249,7 @@ func PostForm(url string, data map[string]string) (r *Response, err os.Error) { // // Caller should close r.Body when done reading from it. func (c *Client) PostForm(url string, data map[string]string) (r *Response, err os.Error) { - var req Request - req.Method = "POST" - req.ProtoMajor = 1 - req.ProtoMinor = 1 - req.Close = true - body := urlencode(data) - req.Body = ioutil.NopCloser(body) - req.Header = Header{ - "Content-Type": {"application/x-www-form-urlencoded"}, - "Content-Length": {strconv.Itoa(body.Len())}, - } - req.ContentLength = int64(body.Len()) - - req.URL, err = ParseURL(url) - if err != nil { - return nil, err - } - - return send(&req, c.Transport) + return c.Post(url, "application/x-www-form-urlencoded", urlencode(data)) } // TODO: remove this function when PostForm takes a multimap. diff --git a/src/pkg/http/client_test.go b/src/pkg/http/client_test.go index ba14e4e4d38..822a8889cad 100644 --- a/src/pkg/http/client_test.go +++ b/src/pkg/http/client_test.go @@ -78,6 +78,61 @@ func TestGetRequestFormat(t *testing.T) { } } +func TestPostRequestFormat(t *testing.T) { + tr := &recordingTransport{} + client := &Client{Transport: tr} + + url := "http://dummy.faketld/" + json := `{"key":"value"}` + b := strings.NewReader(json) + client.Post(url, "application/json", b) // Note: doesn't hit network + + if tr.req.Method != "POST" { + t.Errorf("got method %q, want %q", tr.req.Method, "POST") + } + if tr.req.URL.String() != url { + t.Errorf("got URL %q, want %q", tr.req.URL.String(), url) + } + if tr.req.Header == nil { + t.Fatalf("expected non-nil request Header") + } + if tr.req.Close { + t.Error("got Close true, want false") + } + if g, e := tr.req.ContentLength, int64(len(json)); g != e { + t.Errorf("got ContentLength %d, want %d", g, e) + } +} + +func TestPostFormRequestFormat(t *testing.T) { + tr := &recordingTransport{} + client := &Client{Transport: tr} + + url := "http://dummy.faketld/" + form := map[string]string{"foo": "bar"} + client.PostForm(url, form) // Note: doesn't hit network + + if tr.req.Method != "POST" { + t.Errorf("got method %q, want %q", tr.req.Method, "POST") + } + if tr.req.URL.String() != url { + t.Errorf("got URL %q, want %q", tr.req.URL.String(), url) + } + if tr.req.Header == nil { + t.Fatalf("expected non-nil request Header") + } + if g, e := tr.req.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; g != e { + t.Errorf("got Content-Type %q, want %q", g, e) + } + if tr.req.Close { + t.Error("got Close true, want false") + } + if g, e := tr.req.ContentLength, int64(len("foo=bar")); g != e { + t.Errorf("got ContentLength %d, want %d", g, e) + } + +} + func TestRedirects(t *testing.T) { var ts *httptest.Server ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { diff --git a/src/pkg/http/request.go b/src/pkg/http/request.go index 2f39de182b3..2f6b651c3e3 100644 --- a/src/pkg/http/request.go +++ b/src/pkg/http/request.go @@ -10,6 +10,7 @@ package http import ( "bufio" + "bytes" "crypto/tls" "container/vector" "encoding/base64" @@ -231,7 +232,7 @@ const defaultUserAgent = "Go http package" // Method (defaults to "GET") // UserAgent (defaults to defaultUserAgent) // Referer -// Header +// Header (only keys not already in this list) // Cookie // ContentLength // TransferEncoding @@ -256,6 +257,9 @@ func (req *Request) WriteProxy(w io.Writer) os.Error { func (req *Request) write(w io.Writer, usingProxy bool) os.Error { host := req.Host if host == "" { + if req.URL == nil { + return os.NewError("http: Request.Write on Request with no Host or URL set") + } host = req.URL.Host } @@ -475,6 +479,17 @@ func NewRequest(method, url string, body io.Reader) (*Request, os.Error) { Body: rc, Host: u.Host, } + if body != nil { + switch v := body.(type) { + case *strings.Reader: + req.ContentLength = int64(v.Len()) + case *bytes.Buffer: + req.ContentLength = int64(v.Len()) + default: + req.ContentLength = -1 // chunked + } + } + return req, nil } diff --git a/src/pkg/http/requestwrite_test.go b/src/pkg/http/requestwrite_test.go index beb51fb8d75..2889048a940 100644 --- a/src/pkg/http/requestwrite_test.go +++ b/src/pkg/http/requestwrite_test.go @@ -175,6 +175,35 @@ var reqWriteTests = []reqWriteTest{ "abcdef", }, + // HTTP/1.1 POST with Content-Length in headers + { + Request{ + Method: "POST", + RawURL: "http://example.com/", + Host: "example.com", + Header: Header{ + "Content-Length": []string{"10"}, // ignored + }, + ContentLength: 6, + }, + + []byte("abcdef"), + + "POST http://example.com/ HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go http package\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + + "POST http://example.com/ HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go http package\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + }, + // default to HTTP/1.1 { Request{ diff --git a/src/pkg/strings/reader.go b/src/pkg/strings/reader.go index cd424115d0f..10b0278e1c4 100644 --- a/src/pkg/strings/reader.go +++ b/src/pkg/strings/reader.go @@ -17,6 +17,12 @@ type Reader struct { prevRune int // index of previous rune; or < 0 } +// Len returns the number of bytes of the unread portion of the +// string. +func (r *Reader) Len() int { + return len(r.s) - r.i +} + func (r *Reader) Read(b []byte) (n int, err os.Error) { if r.i >= len(r.s) { return 0, os.EOF