nonce_test.go

53 lines
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
package security

import (
	"context"
	"net/http"
	"net/http/httptest"
	"testing"
)

// TestNonceFrom_Empty checks that NonceFrom yields the empty string for a
// context that was never populated by the nonce middleware.
func TestNonceFrom_Empty(t *testing.T) {
	ctx := context.Background()
	got := NonceFrom(ctx)
	if got != "" {
		t.Errorf("expected empty nonce, got %q", got)
	}
}

// TestNonce_SetsContext verifies that middleware built with WithNonce stores a
// nonce in the request context that the wrapped handler can read back via
// NonceFrom, and that the nonce is at least 10 characters long.
func TestNonce_SetsContext(t *testing.T) {
	var (
		called   bool
		captured string
	)
	mw := New(WithNonce())
	handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		captured = NonceFrom(r.Context())
	}))

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	handler.ServeHTTP(rec, req)

	// Distinguish "middleware short-circuited the request" from "nonce missing";
	// otherwise both surface as an empty captured value.
	if !called {
		t.Fatal("middleware did not invoke the wrapped handler")
	}
	// Fatal rather than Error: an empty nonce would otherwise also trip the
	// length check below and report the same failure twice.
	if captured == "" {
		t.Fatal("expected nonce to be set in context")
	}
	if len(captured) < 10 {
		t.Errorf("nonce too short: %q", captured)
	}
}

// TestNonce_UniquePerRequest sends several requests through the WithNonce
// middleware and checks that each one observed a non-empty nonce and that no
// nonce value was repeated across requests.
func TestNonce_UniquePerRequest(t *testing.T) {
	const requests = 5

	var nonces []string
	mw := New(WithNonce())
	handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		nonces = append(nonces, NonceFrom(r.Context()))
	}))

	for range requests {
		handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// Guard against a vacuous pass: if the middleware never reached the
	// handler, the uniqueness loop below would have nothing to check.
	if len(nonces) != requests {
		t.Fatalf("handler ran %d times, want %d", len(nonces), requests)
	}

	seen := make(map[string]struct{}, len(nonces))
	for _, n := range nonces {
		// Report a missing nonce directly instead of letting it show up as a
		// confusing "duplicate" of the empty string.
		if n == "" {
			t.Error("expected nonce to be set in context")
			continue
		}
		if _, dup := seen[n]; dup {
			t.Errorf("duplicate nonce: %q", n)
		}
		seen[n] = struct{}{}
	}
}