From 95e2c9513b1eadbcf58253c429724a360d026cfa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20Erik=20Pedersen?= Date: Sat, 18 May 2024 11:58:56 +0200 Subject: [PATCH] Misc adjustments * Rework (some) tests * Add EnableETagPair option to add a pair of eTags, even if the server does not provide one * Add CacheKey option * Remove redundant nil check * Optionally allow caching for other HTTP methods than GET and HEAD --- go.mod | 9 + go.sum | 12 + httpcache.go | 122 ++++++++-- httpcache_test.go | 571 +++++++++++++++++----------------------------- 4 files changed, 332 insertions(+), 382 deletions(-) create mode 100644 go.sum diff --git a/go.mod b/go.mod index 222cca6..115ba59 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,12 @@ module github.com/gohugoio/httpcache go 1.22.2 + +require github.com/frankban/quicktest v1.14.6 + +require ( + github.com/google/go-cmp v0.5.9 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/rogpeppe/go-internal v1.9.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..ab408ab --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= diff --git a/httpcache.go b/httpcache.go index b44c99e..cd0e1f0 100644 --- a/httpcache.go +++ b/httpcache.go @@ -8,7 +8,10 @@ package httpcache import ( "bufio" "bytes" + "crypto/md5" + "encoding/hex" "errors" + "hash" "io" "net/http" "net/http/httputil" @@ -23,6 +26,15 @@ const ( transparent // XFromCache is the header added to responses that are returned from the cache XFromCache = "X-From-Cache" + + // xEtags is the prefix for the header with the custom etag pair set in the cached response. + xEtags = "X-Etags-" + + // XETag1 is the key for the first eTag value. + XETag1 = xEtags + "1" + + // XETag2 is the key for the second eTag value. + XETag2 = xEtags + "2" ) // A Cache interface is used by the Transport to store and retrieve responses. @@ -37,7 +49,16 @@ type Cache interface { } // cacheKey returns the cache key for req. -func cacheKey(req *http.Request) string { +func (t *Transport) cacheKey(req *http.Request) string { + if t.CacheKey != nil { + return t.CacheKey(req) + } + + cacheable := (req.Method != http.MethodHead || req.Method == "HEAD") && req.Header.Get("range") == "" + if !cacheable { + return "" + } + if req.Method == http.MethodGet { return req.URL.String() } else { @@ -47,8 +68,8 @@ func cacheKey(req *http.Request) string { // cachedResponse returns the cached http.Response for req if present, and nil // otherwise. -func cachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) { - cachedVal, ok := c.Get(cacheKey(req)) +func (t *Transport) cachedResponse(req *http.Request) (resp *http.Response, err error) { + cachedVal, ok := t.Cache.Get(t.cacheKey(req)) if !ok { return } @@ -63,6 +84,12 @@ type memoryCache struct { items map[string][]byte } +func (c *memoryCache) Size() int { + c.mu.RLock() + defer c.mu.RUnlock() + return len(c.items) +} + // Get returns the []byte representation of the response and true if present, false if not func (c *memoryCache) Get(key string) (resp []byte, ok bool) { c.mu.RLock() @@ -105,11 +132,21 @@ type Transport struct { // If true, responses returned from the cache will be given an extra header, X-From-Cache MarkCachedResponses bool + // if EnableETagPair is true, the Transport will store the pair of eTags in the response header. + // These are stored in the X-Etags-1 and X-Etags-2 headers. + // If these are different, the response has been modified. + // If the server does not return an eTag, the MD5 hash of the response body is used. + EnableETagPair bool + + // CacheKey is an optional func that returns the key to use to store the response. + // An empty string signals that this request should not be cached. + CacheKey func(req *http.Request) string + // Around is an optional func. // If set, the Transport will call Around at the start of RoundTrip // and defer the returned func until the end of RoundTrip. // Typically used to implement a lock that is held for the duration of the RoundTrip. - Around func(key string) func() + Around func(req *http.Request, key string) func() } // varyMatches will return false unless all of the cached values for the headers listed in Vary @@ -133,14 +170,18 @@ func varyMatches(cachedResp *http.Response, req *http.Request) bool { // to give the server a chance to respond with NotModified. If this happens, then the cached Response // will be returned. func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { - cacheKey := cacheKey(req) + cacheKey := t.cacheKey(req) if f := t.Around; f != nil { - defer f(cacheKey)() + defer f(req, cacheKey)() } - cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" + + var cachedXEtag string + + cacheable := cacheKey != "" + var cachedResp *http.Response if cacheable { - cachedResp, err = cachedResponse(t.Cache, req) + cachedResp, err = t.cachedResponse(req) } else { // Need to invalidate an existing value t.Cache.Delete(cacheKey) @@ -155,6 +196,9 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error if t.MarkCachedResponses { cachedResp.Header.Set(XFromCache, "1") } + if t.EnableETagPair { + cachedXEtag, _ = getXETags(cachedResp.Header) + } if varyMatches(cachedResp, req) { // Can only use cached value if the new request doesn't Vary significantly @@ -185,15 +229,16 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error } resp, err = transport.RoundTrip(req) - if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified { + + if err == nil && req.Method != http.MethodHead && resp.StatusCode == http.StatusNotModified { // Replace the 304 response with the one from cache, but update with some new headers endToEndHeaders := getEndToEndHeaders(resp.Header) for _, header := range endToEndHeaders { cachedResp.Header[header] = resp.Header[header] } resp = cachedResp - } else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) && - req.Method == "GET" && canStaleOnError(cachedResp.Header, req.Header) { + } else if (err != nil || resp.StatusCode >= 500) && + req.Method != http.MethodHead && canStaleOnError(cachedResp.Header, req.Header) { // In case of transport failure and stale-if-error activated, returns cached content // when available return cachedResp, nil @@ -227,11 +272,39 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error } } switch req.Method { - case "GET": - // Delay caching until EOF is reached. - resp.Body = &cachingReadCloser{ - R: resp.Body, + case http.MethodHead: + respBytes, err := httputil.DumpResponse(resp, true) + if err == nil { + t.Cache.Set(cacheKey, respBytes) + } + default: + var etagHash hash.Hash + r := resp.Body + if t.EnableETagPair { + if etag := resp.Header.Get("etag"); etag != "" { + resp.Header.Set(XETag1, etag) + resp.Header.Set(XETag2, cachedXEtag) + } else { + etagHash = md5.New() + r = struct { + io.Reader + io.Closer + }{ + io.TeeReader(r, etagHash), + resp.Body, + } + } + } + + r = &cachingReadCloser{ + R: r, OnEOF: func(r io.Reader) { + if etagHash != nil { + md5Str := hex.EncodeToString(etagHash.Sum(nil)) + resp.Header.Set(XETag1, md5Str) + resp.Header.Set(XETag2, cachedXEtag) + + } resp := *resp resp.Body = io.NopCloser(r) respBytes, err := httputil.DumpResponse(&resp, true) @@ -239,12 +312,11 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error t.Cache.Set(cacheKey, respBytes) } }, + buf: &bytes.Buffer{}, } - default: - respBytes, err := httputil.DumpResponse(resp, true) - if err == nil { - t.Cache.Set(cacheKey, respBytes) - } + // Delay caching until EOF is reached. + resp.Body = r + } } else { t.Cache.Delete(cacheKey) @@ -278,6 +350,10 @@ type timer interface { var clock timer = &realClock{} +func getXETags(h http.Header) (string, string) { + return h.Get(XETag1), h.Get(XETag2) +} + // getFreshness will return one of fresh/stale/transparent based on the cache-control // values of the request and the response // @@ -522,7 +598,7 @@ type cachingReadCloser struct { // OnEOF is called with a copy of the content of R when EOF is reached. OnEOF func(io.Reader) - buf bytes.Buffer // buf stores a copy of the content of R. + buf *bytes.Buffer // buf stores a copy of the content of R. } // Read reads the next len(p) bytes from R or until R is drained. The @@ -533,7 +609,7 @@ func (r *cachingReadCloser) Read(p []byte) (n int, err error) { n, err = r.R.Read(p) r.buf.Write(p[:n]) if err == io.EOF { - r.OnEOF(bytes.NewReader(r.buf.Bytes())) + r.OnEOF(r.buf) } return n, err } @@ -545,6 +621,6 @@ func (r *cachingReadCloser) Close() error { // newMemoryCacheTransport returns a new Transport using the in-memory cache implementation func newMemoryCacheTransport() *Transport { c := newMemoryCache() - t := &Transport{Cache: c, MarkCachedResponses: true} + t := &Transport{Cache: c} return t } diff --git a/httpcache_test.go b/httpcache_test.go index 79167d6..6359f01 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -11,6 +11,8 @@ import ( "strconv" "testing" "time" + + qt "github.com/frankban/quicktest" ) var s struct { @@ -55,6 +57,10 @@ func setup() { w.Write([]byte(r.Method)) })) + mux.HandleFunc("/helloheaderasbody", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(r.Header.Get("Hello"))) + })) + mux.HandleFunc("/range", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { lm := "Fri, 14 Dec 2010 01:01:50 GMT" if r.Header.Get("if-modified-since") == lm { @@ -164,8 +170,15 @@ func teardown() { s.server.Close() } +func cacheSize() int { + return s.transport.Cache.(*memoryCache).Size() +} + func resetTest() { s.transport.Cache = newMemoryCache() + s.transport.CacheKey = nil + s.transport.EnableETagPair = false + s.transport.MarkCachedResponses = false clock = &realClock{} } @@ -173,194 +186,120 @@ func resetTest() { // in cache and get incorrectly used for a following cacheable method request. func TestCacheableMethod(t *testing.T) { resetTest() + c := qt.New(t) { - req, err := http.NewRequest("POST", s.server.URL+"/method", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), "POST"; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) - } + + body, resp := doMethod(t, "POST", "/method", nil) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(body, qt.Equals, "POST") } { - req, err := http.NewRequest("GET", s.server.URL+"/method", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), "GET"; got != want { - t.Errorf("got wrong body %q, want %q", got, want) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) - } - if resp.Header.Get(XFromCache) != "" { - t.Errorf("XFromCache header isn't blank") - } + body, resp := doMethod(t, "GET", "/method", nil) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(body, qt.Equals, "GET") + c.Assert(resp.Header.Get(XFromCache), qt.Equals, "") + } } -func TestDontServeHeadResponseToGetRequest(t *testing.T) { +func TestCacheKey(t *testing.T) { resetTest() - url := s.server.URL + "/" - req, err := http.NewRequest(http.MethodHead, url, nil) - if err != nil { - t.Fatal(err) + c := qt.New(t) + s.transport.CacheKey = func(req *http.Request) string { + return "foo" + } + _, resp := doMethod(t, "GET", "/method", nil) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + _, ok := s.transport.Cache.Get("foo") + c.Assert(ok, qt.Equals, true) +} + +func TestEnableETagPair(t *testing.T) { + resetTest() + c := qt.New(t) + s.transport.EnableETagPair = true + + { + _, resp := doMethod(t, "GET", "/etag", nil) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XETag1), qt.Equals, "124567") + c.Assert(resp.Header.Get(XETag2), qt.Equals, "") } - _, err = s.client.Do(req) - if err != nil { - t.Fatal(err) + { + _, resp := doMethod(t, "GET", "/etag", nil) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XETag1), qt.Equals, "124567") + c.Assert(resp.Header.Get(XETag2), qt.Equals, "124567") } - req, err = http.NewRequest(http.MethodGet, url, nil) - if err != nil { - t.Fatal(err) + + // No HTTP caching in the following requests. + { + _, resp := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world1"}) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XETag1), qt.Equals, "48b21a691481958c34cc165011bdb9bc") + c.Assert(resp.Header.Get(XETag2), qt.Equals, "") } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) + { + _, resp := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world2"}) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XETag1), qt.Equals, "61b7d44bc024f189195b549bf094fbe8") + c.Assert(resp.Header.Get(XETag2), qt.Equals, "48b21a691481958c34cc165011bdb9bc") } - if resp.Header.Get(XFromCache) != "" { - t.Errorf("Cache should not match") +} + +func TestAround(t *testing.T) { + resetTest() + c := qt.New(t) + count := 0 + s.transport.Around = func(req *http.Request, key string) func() { + count++ + return func() { + count++ + } } + _, resp := doMethod(t, "GET", "/method", nil) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(count, qt.Equals, 2) +} + +func TestDontServeHeadResponseToGetRequest(t *testing.T) { + resetTest() + c := qt.New(t) + doMethod(t, http.MethodHead, "/", nil) + _, resp := doMethod(t, http.MethodGet, "/", nil) + c.Assert(resp.Header.Get(XFromCache), qt.Equals, "") } func TestDontStorePartialRangeInCache(t *testing.T) { resetTest() + c := qt.New(t) + s.transport.MarkCachedResponses = true + { - req, err := http.NewRequest("GET", s.server.URL+"/range", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("range", "bytes=4-9") - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), " text "; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusPartialContent { - t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) - } + body, resp := doMethod(t, "GET", "/range", map[string]string{"range": "bytes=4-9"}) + c.Assert(resp.StatusCode, qt.Equals, http.StatusPartialContent) + c.Assert(body, qt.Equals, " text ") + c.Assert(cacheSize(), qt.Equals, 0) } { - req, err := http.NewRequest("GET", s.server.URL+"/range", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), "Some text content"; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) - } - if resp.Header.Get(XFromCache) != "" { - t.Error("XFromCache header isn't blank") - } + body, resp := doMethod(t, "GET", "/range", nil) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(body, qt.Equals, "Some text content") + c.Assert(resp.Header.Get(XFromCache), qt.Equals, "") + c.Assert(cacheSize(), qt.Equals, 1) } { - req, err := http.NewRequest("GET", s.server.URL+"/range", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), "Some text content"; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) - } - if resp.Header.Get(XFromCache) != "1" { - t.Errorf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } + body, resp := doMethod(t, "GET", "/range", nil) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(body, qt.Equals, "Some text content") + c.Assert(resp.Header.Get(XFromCache), qt.Equals, "1") + c.Assert(cacheSize(), qt.Equals, 1) } { - req, err := http.NewRequest("GET", s.server.URL+"/range", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("range", "bytes=4-9") - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), " text "; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusPartialContent { - t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) - } + body, resp := doMethod(t, "GET", "/range", map[string]string{"range": "bytes=4-9"}) + c.Assert(resp.StatusCode, qt.Equals, http.StatusPartialContent) + c.Assert(body, qt.Equals, " text ") + c.Assert(cacheSize(), qt.Equals, 1) } } @@ -418,250 +357,132 @@ func TestOnlyReadBodyOnDemand(t *testing.T) { func TestGetOnlyIfCachedHit(t *testing.T) { resetTest() + c := qt.New(t) + s.transport.MarkCachedResponses = true { - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - _, err = io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } + _, resp := doMethod(t, "GET", "/", nil) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XFromCache), qt.Equals, "") + c.Assert(cacheSize(), qt.Equals, 1) } { - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("cache-control", "only-if-cached") - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - if resp.StatusCode != http.StatusOK { - t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) - } + _, resp := doMethod(t, "GET", "/", map[string]string{"cache-control": "only-if-cached"}) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XFromCache), qt.Equals, "1") + c.Assert(cacheSize(), qt.Equals, 1) } } func TestGetOnlyIfCachedMiss(t *testing.T) { resetTest() - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("cache-control", "only-if-cached") - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - if resp.StatusCode != http.StatusGatewayTimeout { - t.Fatalf("response status code isn't 504 GatewayTimeout: %v", resp.StatusCode) - } + s.transport.MarkCachedResponses = true + c := qt.New(t) + _, resp := doMethod(t, "GET", "/", map[string]string{"cache-control": "only-if-cached"}) + c.Assert(resp.StatusCode, qt.Equals, http.StatusGatewayTimeout) + c.Assert(resp.Header.Get(XFromCache), qt.Equals, "") + c.Assert(cacheSize(), qt.Equals, 1) } func TestGetNoStoreRequest(t *testing.T) { resetTest() - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("Cache-Control", "no-store") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } + s.transport.MarkCachedResponses = true + c := qt.New(t) + for i := 0; i < 2; i++ { + + _, resp := doMethod(t, "GET", "/", map[string]string{"cache-control": "no-store"}) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XFromCache), qt.Equals, "") + c.Assert(cacheSize(), qt.Equals, 0) + } } func TestGetNoStoreResponse(t *testing.T) { resetTest() - req, err := http.NewRequest("GET", s.server.URL+"/nostore", nil) - if err != nil { - t.Fatal(err) - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } + s.transport.MarkCachedResponses = true + c := qt.New(t) + for i := 0; i < 2; i++ { + _, resp := doMethod(t, "GET", "/nostore", nil) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XFromCache), qt.Equals, "") + c.Assert(cacheSize(), qt.Equals, 0) } } func TestGetWithEtag(t *testing.T) { resetTest() - req, err := http.NewRequest("GET", s.server.URL+"/etag", nil) - if err != nil { - t.Fatal(err) - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - _, err = io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } + s.transport.MarkCachedResponses = true + c := qt.New(t) + { + _, resp := doMethod(t, "GET", "/etag", nil) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XFromCache), qt.Equals, "") + c.Assert(cacheSize(), qt.Equals, 1) } { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - // additional assertions to verify that 304 response is converted properly - if resp.StatusCode != http.StatusOK { - t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) - } - if _, ok := resp.Header["Connection"]; ok { - t.Fatalf("Connection header isn't absent") - } + _, resp := doMethod(t, "GET", "/etag", nil) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XFromCache), qt.Equals, "1") + c.Assert(cacheSize(), qt.Equals, 1) + _, ok := resp.Header["Connection"] + c.Assert(ok, qt.IsFalse) } } func TestGetWithLastModified(t *testing.T) { resetTest() - req, err := http.NewRequest("GET", s.server.URL+"/lastmodified", nil) - if err != nil { - t.Fatal(err) - } + s.transport.MarkCachedResponses = true + c := qt.New(t) { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - _, err = io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } + _, resp := doMethod(t, "GET", "/lastmodified", nil) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XFromCache), qt.Equals, "") + c.Assert(cacheSize(), qt.Equals, 1) } { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } + _, resp := doMethod(t, "GET", "/lastmodified", nil) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XFromCache), qt.Equals, "1") + c.Assert(cacheSize(), qt.Equals, 1) } } func TestGetWithVary(t *testing.T) { resetTest() - req, err := http.NewRequest("GET", s.server.URL+"/varyaccept", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Accept", "text/plain") + s.transport.MarkCachedResponses = true + c := qt.New(t) { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get("Vary") != "Accept" { - t.Fatalf(`Vary header isn't "Accept": %v`, resp.Header.Get("Vary")) - } - _, err = io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } + _, resp := doMethod(t, "GET", "/varyaccept", map[string]string{"Accept": "text/plain"}) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XFromCache), qt.Equals, "") + c.Assert(cacheSize(), qt.Equals, 1) + c.Assert(resp.Header.Get("Vary"), qt.Equals, "Accept") } { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } + _, resp := doMethod(t, "GET", "/varyaccept", map[string]string{"Accept": "text/plain"}) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XFromCache), qt.Equals, "1") } - req.Header.Set("Accept", "text/html") { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } + _, resp := doMethod(t, "GET", "/varyaccept", map[string]string{"Accept": "text/html"}) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XFromCache), qt.Equals, "") + c.Assert(cacheSize(), qt.Equals, 1) + c.Assert(resp.Header.Get("Vary"), qt.Equals, "Accept") } - req.Header.Set("Accept", "") { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } + _, resp := doMethod(t, "GET", "/varyaccept", map[string]string{"Accept": ""}) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XFromCache), qt.Equals, "") + c.Assert(cacheSize(), qt.Equals, 1) + c.Assert(resp.Header.Get("Vary"), qt.Equals, "Accept") } } func TestGetWithDoubleVary(t *testing.T) { resetTest() + s.transport.MarkCachedResponses = true req, err := http.NewRequest("GET", s.server.URL+"/doublevary", nil) if err != nil { t.Fatal(err) @@ -718,6 +539,7 @@ func TestGetWithDoubleVary(t *testing.T) { func TestGetWith2VaryHeaders(t *testing.T) { resetTest() + s.transport.MarkCachedResponses = true // Tests that multiple Vary headers' comma-separated lists are // merged. See https://github.com/gregjones/httpcache/issues/27. const ( @@ -817,6 +639,7 @@ func TestGetWith2VaryHeaders(t *testing.T) { func TestGetVaryUnused(t *testing.T) { resetTest() + s.transport.MarkCachedResponses = true req, err := http.NewRequest("GET", s.server.URL+"/varyunused", nil) if err != nil { t.Fatal(err) @@ -850,6 +673,7 @@ func TestGetVaryUnused(t *testing.T) { func TestUpdateFields(t *testing.T) { resetTest() + s.transport.MarkCachedResponses = true req, err := http.NewRequest("GET", s.server.URL+"/updatefields", nil) if err != nil { t.Fatal(err) @@ -1472,3 +1296,32 @@ func TestClientTimeout(t *testing.T) { t.Error("client.Do took 2+ seconds, want < 2 seconds") } } + +func doMethod(t testing.TB, method string, p string, headers map[string]string) (string, *http.Response) { + t.Helper() + req, err := http.NewRequest(method, s.server.URL+p, nil) + if err != nil { + t.Fatal(err) + } + if len(headers) > 0 { + for k, v := range headers { + req.Header.Set(k, v) + } + } + + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + + return buf.String(), resp +}