Skip to content

Commit

Permalink
feat: add WithMaxHTTPMaxBytes option to fetcher to limit HTTP response body size
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr committed Nov 13, 2023
1 parent e1d7bd3 commit 9909a38
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 11 deletions.
35 changes: 28 additions & 7 deletions fetcher/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,32 @@ import (

// Fetcher is able to load file contents from http, https, file, and base64 locations.
type Fetcher struct {
hc *retryablehttp.Client
hc *retryablehttp.Client
limit int64
}

type opts struct {
hc *retryablehttp.Client
hc *retryablehttp.Client
limit int64
}

var ErrUnknownScheme = stderrors.New("unknown scheme")

// WithClient sets the http.Client the fetcher uses.
func WithClient(hc *retryablehttp.Client) func(*opts) {
func WithClient(hc *retryablehttp.Client) Modifier {
return func(o *opts) {
o.hc = hc
}
}

// WithMaxHTTPMaxBytes caps the number of bytes the fetcher will read from an
// HTTP response body. When the body is larger than limit, the fetch fails with
// bytes.ErrTooLarge. A limit of zero or less disables the cap.
func WithMaxHTTPMaxBytes(limit int64) Modifier {
	setLimit := func(cfg *opts) { cfg.limit = limit }
	return setLimit
}

func newOpts() *opts {
return &opts{
hc: httpx.NewResilientClient(),
Expand All @@ -52,7 +62,7 @@ func NewFetcher(opts ...Modifier) *Fetcher {
for _, f := range opts {
f(o)
}
return &Fetcher{hc: o.hc}
return &Fetcher{hc: o.hc, limit: o.limit}
}

// Fetch fetches the file contents from the source.
Expand Down Expand Up @@ -94,7 +104,18 @@ func (f *Fetcher) fetchRemote(ctx context.Context, source string) (*bytes.Buffer
return nil, errors.Errorf("expected http response status code 200 but got %d when fetching: %s", res.StatusCode, source)
}

return f.decode(res.Body)
if f.limit > 0 {
var buf bytes.Buffer
n, err := io.Copy(&buf, io.LimitReader(res.Body, f.limit+1))
if n > f.limit {
return nil, bytes.ErrTooLarge
}
if err != nil {
return nil, err
}
return &buf, nil
}
return f.toBuffer(res.Body)
}

func (f *Fetcher) fetchFile(source string) (*bytes.Buffer, error) {
Expand All @@ -106,10 +127,10 @@ func (f *Fetcher) fetchFile(source string) (*bytes.Buffer, error) {
_ = fp.Close()
}()

return f.decode(fp)
return f.toBuffer(fp)
}

func (f *Fetcher) decode(r io.Reader) (*bytes.Buffer, error) {
func (f *Fetcher) toBuffer(r io.Reader) (*bytes.Buffer, error) {
var b bytes.Buffer
if _, err := io.Copy(&b, r); err != nil {
return nil, err
Expand Down
19 changes: 15 additions & 4 deletions fetcher/fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package fetcher

import (
"bytes"
"context"
"encoding/base64"
"fmt"
Expand All @@ -16,7 +17,6 @@ import (

"github.com/gobuffalo/httptest"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -67,18 +67,29 @@ func TestFetcher(t *testing.T) {

t.Run("case=returns proper error on unknown scheme", func(t *testing.T) {
_, err := NewFetcher().Fetch("unknown-scheme://foo")
require.NotNil(t, err)

assert.True(t, errors.Is(err, ErrUnknownScheme))
assert.ErrorIs(t, err, ErrUnknownScheme)
assert.Contains(t, err.Error(), "unknown-scheme")
})

t.Run("case=FetcherContext cancels the HTTP request", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
_, err := NewFetcher().FetchContext(ctx, "https://config.invalid")
require.NotNil(t, err)

assert.ErrorIs(t, err, context.DeadlineExceeded)
})

t.Run("case=with-limit", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write(bytes.Repeat([]byte("test"), 1000))
}))
t.Cleanup(srv.Close)

_, err := NewFetcher(WithMaxHTTPMaxBytes(3999)).Fetch(srv.URL)
assert.ErrorIs(t, err, bytes.ErrTooLarge)

_, err = NewFetcher(WithMaxHTTPMaxBytes(4000)).Fetch(srv.URL)
assert.NoError(t, err)
})
}

0 comments on commit 9909a38

Please sign in to comment.