policy.go

77 lines
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
// Package security provides HTTP security middleware, authentication,
// and rate limiting backed by a database.
//
//	handler := security.New(
//	    security.WithNonce(),
//	    security.WithHeaders(),
//	    security.WithRateLimit(db, 100, 60), // 100 requests per 60 seconds
//	)(mux)
package security

import (
	"net/http"

	"congo.gg/pkg/database"
)

// Policy composes security middleware into a single handler chain.
//
// Its fields are unexported: a Policy is configured only through Option
// values and turned into middleware by New.
type Policy struct {
	useNonce   bool // wrap the chain with nonce generation (WithNonce)
	useHeaders bool // wrap the chain with security headers (WithHeaders)

	// Rate-limit settings, all set by WithRateLimit. New installs the limiter
	// only when rateLimitMax > 0 and rateLimitDB is non-nil.
	rateLimitDB     *database.Database
	rateLimitMax    int // max requests per window
	rateLimitWindow int // window length in seconds
}

// Option is a functional option that mutates a Policy while New is building
// it. Options are applied in the order they are passed to New.
type Option func(p *Policy)

// WithNonce returns an Option that turns on per-request CSP nonce generation.
// When enabled, New places the nonce middleware outermost in the chain.
func WithNonce() Option {
	enable := func(p *Policy) {
		p.useNonce = true
	}
	return enable
}

// WithHeaders returns an Option that turns on security response headers
// (CSP, X-Frame-Options, etc). When enabled, New wraps the rate limiter (if
// any) with the headers middleware.
func WithHeaders() Option {
	return Option(func(p *Policy) {
		p.useHeaders = true
	})
}

// WithRateLimit returns an Option that enables IP-based rate limiting backed
// by db, allowing at most n requests per window of windowSeconds seconds.
//
// New installs the limiter only when n > 0 and db is non-nil; otherwise this
// option has no effect.
//
// NOTE(review): windowSeconds is stored as-is, without validation. Confirm
// how the rate limiter treats a window <= 0.
func WithRateLimit(db *database.Database, n, windowSeconds int) Option {
	return func(p *Policy) {
		p.rateLimitWindow = windowSeconds
		p.rateLimitMax = n
		p.rateLimitDB = db
	}
}

// New builds a middleware that composes every feature enabled by opts into a
// single handler chain. Options are applied in order; nil options are ignored.
//
// Wrappers are added innermost-first. On an incoming request, nonce (if
// enabled) runs first, then headers, then the rate limiter, and finally next.
// Nonce is outermost so that the request context it sets is visible to the
// headers middleware. The rate limiter is innermost, directly in front of
// next, so requests it rejects have still passed through nonce and headers.
func New(opts ...Option) func(http.Handler) http.Handler {
	p := &Policy{}
	for _, opt := range opts {
		// Tolerate nil so callers can build option lists conditionally
		// (e.g. a zero-valued Option variable) without panicking here.
		if opt == nil {
			continue
		}
		opt(p)
	}

	return func(next http.Handler) http.Handler {
		h := next

		// Rate limiting: wrapped first, so it is innermost and decides
		// whether the request reaches next at all.
		if p.rateLimitMax > 0 && p.rateLimitDB != nil {
			h = rateLimit(p.rateLimitDB, p.rateLimitMax, p.rateLimitWindow, h)
		}

		// Security headers wrap the limiter.
		if p.useHeaders {
			h = headers(h)
		}

		// Nonce is wrapped last, so it is outermost and runs before headers,
		// which relies on the context it sets.
		if p.useNonce {
			h = nonce(h)
		}

		return h
	}
}