listen.go

97 lines
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
// Package router provides domain-based reverse proxy, automatic TLS
// via Let's Encrypt, request logging, and server lifecycle management.
//
//	router.Proxy("app.example.com", "myapp:5000")
//	router.Listen(router.WithLogger(), security.New(), application.New(views))
package router

import (
	"cmp"
	"context"
	"errors"
	"log"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"
)

// Listen is the declarative entry point. It composes the middleware chain
// and starts the server. The last middleware is typically the application.
// Blocks until shutdown signal (SIGINT/SIGTERM).
//
// The first middleware is the outermost wrapper. The innermost handler, reached
// only if every middleware passes the request through, is http.NotFoundHandler.
// The composed chain is then wrapped by Handler for domain-based routing.
//
// The listen port comes from $PORT (default "5000"). TLS mode is enabled when
// any domains are configured, $CERT_DIR is set, or the port is 80 or 443. In
// TLS mode:
//   - an HTTPS server is started on :443 using certMgr's TLS config;
//   - a plain HTTP server, wrapped by certMgr.HTTPHandler, is started on $PORT.
//     When $PORT is 443 it moves to :80 so the two servers do not collide.
//
// Otherwise a single plain HTTP server listens on $PORT.
//
// Server start-up failures (e.g. a port already in use) are logged, not fatal;
// Listen still waits for a shutdown signal.
//
//	router.Listen(
//	    router.WithLogger(),
//	    security.New(security.WithNonce(), security.WithHeaders()),
//	    application.New(views, ...),
//	)
func Listen(middleware ...func(http.Handler) http.Handler) {
	// readHeaderTimeout bounds how long a client may take to send request
	// headers. Without it a client can hold a connection open indefinitely
	// by trickling header bytes (slowloris).
	const readHeaderTimeout = 10 * time.Second

	// Build handler chain: first middleware is outermost.
	var handler http.Handler = http.NotFoundHandler()
	for i := len(middleware) - 1; i >= 0; i-- {
		handler = middleware[i](handler)
	}

	// Wrap with domain-based routing.
	handler = Handler(handler)

	port := cmp.Or(os.Getenv("PORT"), "5000")
	addr := "0.0.0.0:" + port

	// serve runs listen in the background. It logs any error other than the
	// ErrServerClosed that ListenAndServe* returns after a graceful Shutdown.
	serve := func(label string, listen func() error) {
		go func() {
			if err := listen(); !errors.Is(err, http.ErrServerClosed) {
				log.Printf("%s server error: %v", label, err)
			}
		}()
	}

	// Decide HTTP vs HTTPS based on configured domains.
	domainsMu.RLock()
	hasDomains := len(domains) > 0
	domainsMu.RUnlock()

	useTLS := hasDomains || os.Getenv("CERT_DIR") != "" || port == "80" || port == "443"
	if !useTLS {
		srv := &http.Server{
			Addr:              addr,
			Handler:           handler,
			ReadHeaderTimeout: readHeaderTimeout,
		}
		log.Printf("HTTP server starting on http://%s", addr)
		serve("HTTP", srv.ListenAndServe)

		waitForShutdown(srv)
		return
	}

	ensureCertManager()

	httpsSrv := &http.Server{
		Addr:              ":443",
		Handler:           handler,
		TLSConfig:         certMgr.TLSConfig(),
		ReadHeaderTimeout: readHeaderTimeout,
	}
	log.Printf("HTTPS server starting on https://0.0.0.0:443")
	serve("HTTPS", func() error { return httpsSrv.ListenAndServeTLS("", "") })

	// The HTTPS server already owns :443. With PORT=443 the plain-HTTP server
	// would fail with "address already in use", so put it on :80 instead.
	httpAddr := addr
	if port == "443" {
		httpAddr = "0.0.0.0:80"
	}
	httpSrv := &http.Server{
		Addr:              httpAddr,
		Handler:           certMgr.HTTPHandler(handler),
		ReadHeaderTimeout: readHeaderTimeout,
	}
	log.Printf("HTTP server starting on http://%s (ACME)", httpAddr)
	serve("HTTP", httpSrv.ListenAndServe)

	waitForShutdown(httpsSrv, httpSrv)
}

// waitForShutdown blocks until the process receives SIGINT or SIGTERM. It then
// shuts down every server in servers, giving in-flight requests up to
// 10 seconds to finish.
//
// The signal handler is unregistered after the first signal. A second SIGINT
// or SIGTERM therefore takes the default action and terminates the process
// immediately, for example if graceful shutdown hangs.
func waitForShutdown(servers ...*http.Server) {
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	<-quit
	// Restore default handling so a second Ctrl-C forces exit instead of being
	// silently buffered while we drain connections.
	signal.Stop(quit)
	log.Println("Shutting down...")

	if err := shutdownServers(10*time.Second, servers...); err != nil {
		log.Printf("Shutdown error: %v", err)
	}
}

// shutdownServers gracefully shuts down all servers concurrently under one
// shared deadline of timeout, so a slow server cannot eat into the others'
// drain time.
//
// A server still draining when the deadline passes is closed forcibly, which
// drops its remaining connections. It returns every error encountered, joined
// with errors.Join, or nil if all servers stopped cleanly.
func shutdownServers(timeout time.Duration, servers ...*http.Server) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// Buffered so no goroutine blocks on send, regardless of collection order.
	results := make(chan error, len(servers))
	for _, s := range servers {
		go func(s *http.Server) {
			err := s.Shutdown(ctx)
			if errors.Is(err, context.DeadlineExceeded) {
				// The graceful drain ran out of time; close the remaining connections.
				err = errors.Join(err, s.Close())
			}
			results <- err
		}(s)
	}

	var errs []error
	for range servers {
		if err := <-results; err != nil {
			errs = append(errs, err)
		}
	}
	return errors.Join(errs...)
}