From 35163f1638a28d8d3890c3d645b95859830b73e5 Mon Sep 17 00:00:00 2001 From: Shantanu Date: Tue, 14 Jan 2025 17:01:23 -0800 Subject: [PATCH] return custom error from authz roundtripper --- authz/roundtripper.go | 15 ++++++++++----- client/client.go | 2 +- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/authz/roundtripper.go b/authz/roundtripper.go index 764c5db..ffe8035 100644 --- a/authz/roundtripper.go +++ b/authz/roundtripper.go @@ -6,15 +6,20 @@ import ( "net/http" ) -var ErrInvalidClient = errors.New("invalid client") +var ErrInvalidAuthzClient = errors.New("invalid authz client") type AuthorizedRoundTripper struct { token string savvyVersion string + // wrap error returned by RoundTrip + wrapErr error } -func NewRoundTripper(token, savvyVersion string) *AuthorizedRoundTripper { - return &AuthorizedRoundTripper{token: token, savvyVersion: savvyVersion} +// NewRoundTripper returns a new AuthorizedRoundTripper +// +// Caller must provide non nil err to wrap the error returned by RoundTrip +func NewRoundTripper(token, savvyVersion string, err error) *AuthorizedRoundTripper { + return &AuthorizedRoundTripper{token: token, savvyVersion: savvyVersion, wrapErr: err} } func (a *AuthorizedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { @@ -26,14 +31,14 @@ func (a *AuthorizedRoundTripper) RoundTrip(req *http.Request) (*http.Response, e // Use the embedded Transport to perform the actual request res, err := http.DefaultTransport.RoundTrip(clonedReq) if err != nil { - err = fmt.Errorf("%w: %v", ErrInvalidClient, err) + err = fmt.Errorf("%w: %v", a.wrapErr, err) return nil, err } // If we get a 401 Unauthorized, then the token is expired // and we need to refresh it if res.StatusCode == http.StatusUnauthorized { - return nil, fmt.Errorf("%w: invalid token", ErrInvalidClient) + return nil, fmt.Errorf("%w: invalid token", a.wrapErr) } return res, err } diff --git a/client/client.go b/client/client.go index de65090..75ec4d1 100644 --- a/client/client.go +++ b/client/client.go @@ -89,7 +89,7 @@ func New() (Client, error) { } cl := &http.Client{ - Transport: authz.NewRoundTripper(cfg.Token, config.Version()), + Transport: authz.NewRoundTripper(cfg.Token, config.Version(), ErrInvalidClient), } c := &client{