Skip to content

Commit

Permalink
Add signing key rotation support (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
goodoldneon committed Apr 17, 2024
1 parent adaa74f commit 98c088c
Show file tree
Hide file tree
Showing 5 changed files with 332 additions and 51 deletions.
179 changes: 141 additions & 38 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 64,10 @@ type HandlerOpts struct {
// to os.Getenv("INNGEST_SIGNING_KEY").
SigningKey *string

// SigningKeyFallback is the fallback signing key for your app. If nil, this
// defaults to os.Getenv("INNGEST_SIGNING_KEY_FALLBACK").
SigningKeyFallback *string

// Env is the branch environment to deploy to. If nil, this uses
// os.Getenv("INNGEST_ENV"). This only deploys to branches if the
// signing key is a branch signing key.
Expand Down Expand Up @@ -99,6 103,19 @@ func (h HandlerOpts) GetSigningKey() string {
return *h.SigningKey
}

// GetSigningKeyFallback returns the signing key fallback defined within
// HandlerOpts, or the default defined within INNGEST_SIGNING_KEY_FALLBACK.
//
// This is the fallback private key used to register functions and communicate
// with the private API. If a request fails auth with the signing key then we'll
// try again with the fallback
func (h HandlerOpts) GetSigningKeyFallback() string {
if h.SigningKeyFallback == nil {
return os.Getenv("INNGEST_SIGNING_KEY_FALLBACK")
}
return *h.SigningKeyFallback
}

// GetEnv returns the env defined within HandlerOpts, or the default
// defined within INNGEST_ENV.
//
Expand Down Expand Up @@ -211,6 228,11 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
SetBasicResponseHeaders(w)

switch r.Method {
case http.MethodGet:
if err := h.introspect(w, r); err != nil {
_ = publicerr.WriteHTTP(w, err)
}
return
case http.MethodPost:
if err := h.invoke(w, r); err != nil {
_ = publicerr.WriteHTTP(w, err)
Expand Down Expand Up @@ -357,40 379,41 @@ func (h *handler) register(w http.ResponseWriter, r *http.Request) error {
registerURL = *h.RegisterURL
}

byt, err := json.Marshal(config)
if err != nil {
return fmt.Errorf("error marshalling function config: %w", err)
}
req, err := http.NewRequest(http.MethodPost, registerURL, bytes.NewReader(byt))
if err != nil {
return fmt.Errorf("error creating new request: %w", err)
}
if syncID != "" {
qp := req.URL.Query()
qp.Set("deployId", syncID)
req.URL.RawQuery = qp.Encode()
}
createRequest := func() (*http.Request, error) {
byt, err := json.Marshal(config)
if err != nil {
return nil, fmt.Errorf("error marshalling function config: %w", err)
}

// If the request specifies a server kind then include it as an expectation
// in the outgoing request
if r.Header.Get(HeaderKeyServerKind) != "" {
req.Header.Set(
HeaderKeyExpectedServerKind,
r.Header.Get(HeaderKeyServerKind),
)
}
req, err := http.NewRequest(http.MethodPost, registerURL, bytes.NewReader(byt))
if err != nil {
return nil, fmt.Errorf("error creating new request: %w", err)
}
if syncID != "" {
qp := req.URL.Query()
qp.Set("deployId", syncID)
req.URL.RawQuery = qp.Encode()
}

key, err := hashedSigningKey([]byte(h.GetSigningKey()))
if err != nil {
return fmt.Errorf("error creating signing key: %w", err)
}
req.Header.Add(HeaderKeyAuthorization, fmt.Sprintf("Bearer %s", string(key)))
if h.GetEnv() != "" {
req.Header.Add(HeaderKeyEnv, h.GetEnv())
// If the request specifies a server kind then include it as an expectation
// in the outgoing request
if r.Header.Get(HeaderKeyServerKind) != "" {
req.Header.Set(
HeaderKeyExpectedServerKind,
r.Header.Get(HeaderKeyServerKind),
)
}

SetBasicRequestHeaders(req)

return req, nil
}
SetBasicRequestHeaders(req)

resp, err := http.DefaultClient.Do(req)
resp, err := fetchWithAuthFallback(
createRequest,
h.GetSigningKey(),
h.GetSigningKeyFallback(),
)
if err != nil {
return fmt.Errorf("error performing registration request: %w", err)
}
Expand Down Expand Up @@ -447,14 470,17 @@ func (h *handler) invoke(w http.ResponseWriter, r *http.Request) error {
}
}

if !IsDev() {
// Validate the signature.
if valid, err := ValidateSignature(r.Context(), sig, []byte(h.GetSigningKey()), byt); !valid {
h.Logger.Error("unauthorized inngest invoke request", "error", err)
return publicerr.Error{
Message: "unauthorized",
Status: 401,
}
if valid, err := ValidateSignature(
r.Context(),
sig,
h.GetSigningKey(),
h.GetSigningKeyFallback(),
byt,
); !valid {
h.Logger.Error("unauthorized inngest invoke request", "error", err)
return publicerr.Error{
Message: "unauthorized",
Status: 401,
}
}

Expand Down Expand Up @@ -863,6 620,83 @@ func (h *handler) invoke(w http.ResponseWriter, r *http.Request) error {
return json.NewEncoder(w).Encode(resp)
}

type insecureIntrospection struct {
FunctionCount int `json:"function_count"`
HasEventKey bool `json:"has_event_key"`
HasSigningKey bool `json:"has_signing_key"`
Mode string `json:"mode"`
}

type secureIntrospection struct {
insecureIntrospection
SigningKeyFallbackHash *string `json:"signing_key_fallback_hash"`
SigningKeyHash *string `json:"signing_key_hash"`
}

func (h *handler) introspect(w http.ResponseWriter, r *http.Request) error {
defer r.Body.Close()

mode := "cloud"
if IsDev() {
mode = "dev"
}

sig := r.Header.Get(HeaderKeySignature)
valid, _ := ValidateSignature(
r.Context(),
sig,
h.GetSigningKey(),
h.GetSigningKeyFallback(),
[]byte{},
)
if valid {
var signingKeyHash *string
if h.GetSigningKey() != "" {
key, err := hashedSigningKey([]byte(h.GetSigningKey()))
if err != nil {
return fmt.Errorf("error hashing signing key: %w", err)
}
hash := string(key)
signingKeyHash = &hash
}

var signingKeyFallbackHash *string
if h.GetSigningKeyFallback() != "" {
key, err := hashedSigningKey([]byte(h.GetSigningKeyFallback()))
if err != nil {
return fmt.Errorf("error hashing signing key fallback: %w", err)
}
hash := string(key)
signingKeyFallbackHash = &hash
}

introspection := secureIntrospection{
insecureIntrospection: insecureIntrospection{
FunctionCount: len(h.funcs),
HasEventKey: os.Getenv("INNGEST_EVENT_KEY") != "",
HasSigningKey: h.GetSigningKey() != "",
Mode: mode,
},
SigningKeyFallbackHash: signingKeyFallbackHash,
SigningKeyHash: signingKeyHash,
}

w.Header().Set(HeaderKeyContentType, "application/json")
return json.NewEncoder(w).Encode(introspection)
}

introspection := insecureIntrospection{
FunctionCount: len(h.funcs),
HasEventKey: os.Getenv("INNGEST_EVENT_KEY") != "",
HasSigningKey: h.GetSigningKey() != "",
Mode: mode,
}

w.Header().Set(HeaderKeyContentType, "application/json")
return json.NewEncoder(w).Encode(introspection)

}

type StreamResponse struct {
StatusCode int `json:"status"`
Body any `json:"body"`
Expand Down
103 changes: 102 additions & 1 deletion handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 25,7 @@ import (

func init() {
os.Setenv("INNGEST_SIGNING_KEY", string(testKey))
os.Setenv("INNGEST_SIGNING_KEY_FALLBACK", string(testKeyFallback))
}

type EventA struct {
Expand Down Expand Up @@ -448,6 449,106 @@ func TestSteps(t *testing.T) {

}

func TestIntrospection(t *testing.T) {
fn := CreateFunction(
FunctionOpts{Name: "My servable function!"},
EventTrigger("test/event.a", nil),
func(ctx context.Context, input Input[any]) (any, error) {
return nil, nil
},
)
h := NewHandler("introspection", HandlerOpts{})
h.Register(fn)
server := httptest.NewServer(h)
defer server.Close()

t.Run("no signature", func(t *testing.T) {
// When the request has no signature, respond with the insecure
// introspection body

r := require.New(t)

reqBody := []byte("")
req, err := http.NewRequest(http.MethodGet, server.URL, bytes.NewReader(reqBody))
r.NoError(err)
resp, err := http.DefaultClient.Do(req)
r.Equal(http.StatusOK, resp.StatusCode)
r.NoError(err)

var respBody map[string]any
err = json.NewDecoder(resp.Body).Decode(&respBody)
r.NoError(err)

r.Equal(map[string]any{
"function_count": float64(1),
"has_event_key": false,
"has_signing_key": true,
"mode": "cloud",
}, respBody)
})

t.Run("valid signature", func(t *testing.T) {
// When the request has a valid signature, respond with the secure
// introspection body

r := require.New(t)

reqBody := []byte("")
sig := Sign(context.Background(), time.Now(), []byte(testKey), reqBody)
req, err := http.NewRequest(http.MethodGet, server.URL, bytes.NewReader(reqBody))
r.NoError(err)
req.Header.Set("X-Inngest-Signature", sig)
resp, err := http.DefaultClient.Do(req)
r.Equal(http.StatusOK, resp.StatusCode)
r.NoError(err)

var respBody map[string]any
err = json.NewDecoder(resp.Body).Decode(&respBody)
r.NoError(err)

signingKeyHash, err := hashedSigningKey([]byte(testKey))
r.NoError(err)
signingKeyFallbackHash, err := hashedSigningKey([]byte(testKeyFallback))
r.NoError(err)
r.Equal(map[string]any{
"function_count": float64(1),
"has_event_key": false,
"has_signing_key": true,
"mode": "cloud",
"signing_key_fallback_hash": string(signingKeyFallbackHash),
"signing_key_hash": string(signingKeyHash),
}, respBody)
})

t.Run("invalid signature", func(t *testing.T) {
// When the request has an invalid signature, respond with the insecure
// introspection body

r := require.New(t)

reqBody := []byte("")
invalidKey := "deadbeef"
sig := Sign(context.Background(), time.Now(), []byte(invalidKey), reqBody)
req, err := http.NewRequest(http.MethodGet, server.URL, bytes.NewReader(reqBody))
r.NoError(err)
req.Header.Set("X-Inngest-Signature", sig)
resp, err := http.DefaultClient.Do(req)
r.Equal(http.StatusOK, resp.StatusCode)
r.NoError(err)

var respBody map[string]any
err = json.NewDecoder(resp.Body).Decode(&respBody)
r.NoError(err)

r.Equal(map[string]any{
"function_count": float64(1),
"has_event_key": false,
"has_signing_key": true,
"mode": "cloud",
}, respBody)
})
}

func createRequest(t *testing.T, evt any) *sdkrequest.Request {
t.Helper()

Expand Down Expand Up @@ -487,7 588,7 @@ func handlerPost(t *testing.T, url string, r *sdkrequest.Request) *http.Response
t.Helper()

body := marshalRequest(t, r)
sig := Sign(context.Background(), time.Now(), testKey, body)
sig := Sign(context.Background(), time.Now(), []byte(testKey), body)

req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
require.NoError(t, err)
Expand Down
Loading

0 comments on commit 98c088c

Please sign in to comment.