// middleware.go

package application

import (
	"bufio"
	"encoding/json"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"strings"
	"time"
)

// requestLog represents a logged request: the record LoggingMiddleware
// JSON-encodes for each request when ENV is "production". Field order here is
// the key order in the encoded output.
//
//   - Timestamp is formatted with time.RFC3339 in UTC (second precision).
//   - Path is r.URL.Path, so the query string is not included.
//   - Duration is the time.Duration String() form (e.g. "1.234ms"); DurationMs
//     is the same span truncated to whole milliseconds.
//   - IP is a best-effort client address derived from the X-Forwarded-For
//     header or RemoteAddr; the header is client-supplied, so do not trust it.
//   - UserAgent is omitted from the JSON when empty.
type requestLog struct {
	Timestamp  string `json:"timestamp"`
	Method     string `json:"method"`
	Path       string `json:"path"`
	Status     int    `json:"status"`
	Duration   string `json:"duration"`
	DurationMs int64  `json:"duration_ms"`
	IP         string `json:"ip"`
	UserAgent  string `json:"user_agent,omitempty"`
}

// responseWriter wraps an http.ResponseWriter to record the status code sent
// to the client, so it can be logged after the handler returns.
//
// It forwards http.Flusher (SSE streaming) and http.Hijacker (WebSocket
// upgrades) to the wrapped writer when that writer supports them, and
// implements Unwrap so http.ResponseController can reach the wrapped writer.
type responseWriter struct {
	http.ResponseWriter

	// status is the final status code sent. Constructors should initialise it
	// to http.StatusOK, the code net/http sends when a handler writes a body
	// or returns without calling WriteHeader.
	status int

	// wroteHeader is set once the final (non-1xx) header has been committed.
	// net/http ignores later WriteHeader calls, so status must not change
	// after this point.
	wroteHeader bool
}

// Compile-time checks that the optional interfaces promised above are kept.
var (
	_ http.Flusher  = (*responseWriter)(nil)
	_ http.Hijacker = (*responseWriter)(nil)
)

// WriteHeader records the status code and delegates to the underlying
// ResponseWriter.
//
// Only the first final code is recorded. net/http discards superfluous
// WriteHeader calls, so recording a later one would log a status the client
// never received. Informational 1xx codes (other than 101 Switching Protocols)
// may precede the final header; they are forwarded but not recorded.
func (rw *responseWriter) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !rw.wroteHeader && !informational {
		rw.status = code
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write delegates to the underlying ResponseWriter. Writing a body before
// WriteHeader commits an implicit 200 OK. That code is recorded so a later
// WriteHeader call cannot overwrite it.
func (rw *responseWriter) Write(b []byte) (int, error) {
	rw.commitImplicitOK()
	return rw.ResponseWriter.Write(b)
}

// Flush implements http.Flusher for SSE streaming support. It is a no-op when
// the underlying ResponseWriter cannot flush.
func (rw *responseWriter) Flush() {
	if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
		// Flushing before WriteHeader commits an implicit 200 OK.
		rw.commitImplicitOK()
		flusher.Flush()
	}
}

// Hijack implements http.Hijacker for WebSocket upgrade support.
//
// NOTE(review): after a hijack the handler writes to the raw connection, so
// the recorded status reflects only what passed through WriteHeader/Write.
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if hijacker, ok := rw.ResponseWriter.(http.Hijacker); ok {
		return hijacker.Hijack()
	}
	return nil, nil, fmt.Errorf("upstream ResponseWriter does not implement http.Hijacker")
}

// Unwrap returns the underlying ResponseWriter. http.ResponseController
// (Go 1.20+) uses it to reach capabilities this wrapper does not forward
// itself, such as per-request read/write deadlines.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}

// commitImplicitOK records the implicit 200 OK that net/http sends when a body
// is written or flushed before any explicit WriteHeader call.
func (rw *responseWriter) commitImplicitOK() {
	if !rw.wroteHeader {
		rw.status = http.StatusOK
		rw.wroteHeader = true
	}
}

// LoggingMiddleware returns middleware that logs all requests, one line per
// request, after next has handled it.
//
// The ENV environment variable is read once, when the middleware is built.
// When it equals "production", each line is a JSON-encoded requestLog.
// Otherwise each line is a human-readable "METHOD PATH STATUS DURATION" string.
// Lines are written through the standard library's default logger.
func LoggingMiddleware(next http.Handler) http.Handler {
	isProduction := os.Getenv("ENV") == "production"

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()

		// Default to 200: net/http sends it implicitly when the handler never
		// calls WriteHeader.
		wrapped := &responseWriter{ResponseWriter: w, status: http.StatusOK}
		next.ServeHTTP(wrapped, r)
		duration := time.Since(start)

		if !isProduction {
			// Human-readable logging for development.
			log.Printf("%s %s %d %s", r.Method, r.URL.Path, wrapped.status, duration)
			return
		}

		// Structured JSON logging for production.
		entry := requestLog{
			// Stamp with the request's arrival time, the same instant the
			// duration is measured from.
			Timestamp:  start.UTC().Format(time.RFC3339),
			Method:     r.Method,
			Path:       r.URL.Path,
			Status:     wrapped.status,
			Duration:   duration.String(),
			DurationMs: duration.Milliseconds(),
			IP:         clientIPForLog(r),
			UserAgent:  r.UserAgent(),
		}
		data, err := json.Marshal(entry)
		if err != nil {
			log.Printf("logging middleware: encoding request log: %v", err)
			return
		}
		// NOTE(review): the default logger prefixes a date/time unless its
		// flags are cleared elsewhere, so this line may not be pure JSON.
		log.Println(string(data))
	})
}

// clientIPForLog returns a best-effort client address for r, for logging only.
//
// It uses the first (left-most) non-empty X-Forwarded-For entry when present.
// Otherwise it uses the host part of r.RemoteAddr, or RemoteAddr verbatim if it
// is not in host:port form. X-Forwarded-For is supplied by the client or an
// intermediate proxy and can be forged. Never use the result for
// authentication or rate limiting.
func clientIPForLog(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		// X-Forwarded-For: client, proxy1, proxy2 — the first entry is the client.
		first := forwarded
		if i := strings.IndexByte(forwarded, ','); i >= 0 {
			first = forwarded[:i]
		}
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	// Strip the port so the value is comparable with X-Forwarded-For entries.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}