// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package app provides a managed lifecycle for command-line applications,
// ensuring graceful shutdown on OS signals.
//
// The [Run] function is the main entry point. It wraps your application
// components ([Runnable]), executing them concurrently. It listens for
// interrupt signals (like SIGINT/SIGTERM) and propagates a cancellation signal
// via a [context.Context]. This allows your application to perform cleanup
// tasks before exiting.
//
// # Usage
//
// A typical use case involves starting workers or servers that run until
// interrupted. The [Run] function handles signal trapping, concurrency, and
// timeouts, letting you focus on the business logic.
//
// Example:
//
// func main() {
// // 1. Configure a logger.
// logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
//
// // 2. Define the application components.
// // These functions block until ctx is canceled or an error occurs.
// worker := func(ctx context.Context) error {
// logger.Info("Worker started")
//
// // Simulate a task that runs periodically.
// ticker := time.NewTicker(1 * time.Second)
// defer ticker.Stop()
//
// for {
// select {
// case <-ctx.Done():
// // Context canceled (signal received or sibling component failed).
// logger.Info("Worker stopping...")
//
// // Perform cleanup (e.g., closing DB connections).
// time.Sleep(500 * time.Millisecond)
// return nil
//
// case t := <-ticker.C:
// logger.Info("Working...", "time", t.Format(time.TimeOnly))
// }
// }
// }
//
// server := func(ctx context.Context) error {
// logger.Info("Server started")
// <-ctx.Done()
// logger.Info("Server stopping...")
// return nil
// }
//
// // 3. Run the application components concurrently.
// err := app.RunAll(
// []app.Runnable{worker, server},
// app.WithLogger(logger),
// )
// if err != nil {
// logger.Error("Application failed", "error", err)
// os.Exit(1)
// }
// }
package app
import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"os/signal"
"runtime/debug"
"syscall"
"time"
"golang.org/x/sync/errgroup"
)
// DefaultTimeout is the default duration to wait for the application to
// gracefully shut down after receiving a termination signal. Override it
// per run with [WithTimeout].
const DefaultTimeout = 10 * time.Second

// Runnable defines a function that can be executed by the application runner.
// It receives a [context.Context] that is canceled when a shutdown signal is
// received, or if another concurrently running [Runnable] returns an error.
// The function should perform its cleanup and return when the context is done.
// A Runnable is expected to block for its entire lifetime; returning nil
// signals normal completion, while returning an error triggers shutdown of
// all sibling Runnables.
type Runnable func(ctx context.Context) error
// config holds the internal settings for the application runner, including
// logging, timeouts, signal handling, and parent context. It is populated
// with defaults in [RunAll] and then mutated by [Option] functions.
type config struct {
	// logger receives lifecycle messages; defaults to slog.Default.
	logger *slog.Logger
	// timeout bounds the graceful-shutdown grace period; defaults to
	// DefaultTimeout.
	timeout time.Duration
	// signals lists the OS signals that trigger a shutdown.
	signals []os.Signal
	// ctx is the parent context from which the signal-aware run context is
	// derived.
	ctx context.Context
}

// Option is a function that configures the application runner [config].
type Option func(*config)
// WithLogger supplies a custom [slog.Logger] for the application runner.
// When this option is omitted, the runner falls back to [slog.Default].
// Passing nil is a no-op and leaves the current logger in place.
func WithLogger(l *slog.Logger) Option {
	return func(c *config) {
		if l == nil {
			return
		}
		c.logger = l
	}
}
// WithTimeout sets a custom [time.Duration] timeout for the graceful shutdown
// process. If the application components take longer than this duration to
// return after a shutdown signal is received, the runner will exit with an
// error. Durations less than or equal to zero are ignored, keeping the
// [DefaultTimeout] in effect.
func WithTimeout(d time.Duration) Option {
	return func(c *config) {
		if d <= 0 {
			return
		}
		c.timeout = d
	}
}
// WithSignals allows customization of which [os.Signal] triggers a shutdown.
// If not used, it defaults to [syscall.SIGTERM] and [syscall.SIGINT].
// Calling it with no arguments is a no-op.
func WithSignals(signals ...os.Signal) Option {
	return func(c *config) {
		if len(signals) == 0 {
			return
		}
		c.signals = signals
	}
}
// WithContext sets a parent [context.Context] for the runner. The runner's
// main context is derived from this parent, so canceling the parent triggers
// a graceful shutdown. When this option is omitted, [context.Background] is
// used. Passing nil is a no-op.
func WithContext(ctx context.Context) Option {
	return func(c *config) {
		if ctx == nil {
			return
		}
		c.ctx = ctx
	}
}
// Run provides a managed execution environment for a single [Runnable].
// It launches the [Runnable] in a separate goroutine and blocks until it
// completes, an OS interrupt signal is caught, or the parent context is
// canceled. For running multiple components concurrently, see [RunAll].
//
// Run is a thin convenience wrapper around [RunAll] with a one-element slice;
// it accepts the same [Option] values and returns the same errors.
func Run(runnable Runnable, opts ...Option) error {
	return RunAll([]Runnable{runnable}, opts...)
}
// RunAll provides a managed execution environment for multiple [Runnable]
// functions. It launches each [Runnable] in a separate goroutine and blocks
// until they all complete on their own, an OS interrupt signal is caught, the
// parent context (if specified via [WithContext]) is canceled, or any single
// [Runnable] returns an error.
//
// Upon receiving a signal or encountering an error in any [Runnable], it
// cancels the context passed to all Runnables and waits for the specified
// shutdown timeout. The Runnables are expected to honor the context
// cancellation and perform any necessary cleanup before returning. [RunAll]
// returns any error from the Runnables themselves, or an error if the shutdown
// process times out.
func RunAll(runnables []Runnable, opts ...Option) error {
	// Start from defaults; each option may override one field.
	cfg := config{
		logger:  slog.Default(),
		timeout: DefaultTimeout,
		signals: []os.Signal{syscall.SIGTERM, syscall.SIGINT},
		ctx:     context.Background(),
	}
	for _, opt := range opts {
		opt(&cfg)
	}
	// Create a context that cancels on OS signals.
	ctx, cancel := signal.NotifyContext(cfg.ctx, cfg.signals...)
	defer cancel()
	// Use errgroup to manage concurrent runnables. The group context will be
	// canceled if the base context is canceled, or if any goroutine returns an
	// error.
	g, gCtx := errgroup.WithContext(ctx)
	cfg.logger.Info("Application started", slog.Int("components", len(runnables)))
	// failureCh is signaled (capacity 1, non-blocking send) when any runnable
	// returns an error or panics, letting the select below distinguish a
	// component failure from an OS signal.
	failureCh := make(chan struct{}, 1)
	for _, fn := range runnables {
		g.Go(func() (err error) {
			defer func() {
				// Convert a panic into an error so one faulty component cannot
				// crash the process without giving siblings a chance to clean up.
				if r := recover(); r != nil {
					stack := string(debug.Stack())
					err = fmt.Errorf("application panic: %v\nstack: %s", r, stack)
				}
				if err != nil {
					// Non-blocking send: only the first failure needs to wake
					// the supervisor; later failures are still surfaced via
					// g.Wait's combined error.
					select {
					case failureCh <- struct{}{}:
					default:
					}
				}
			}()
			return fn(gCtx)
		})
	}
	// Wait for the group in a separate goroutine so the supervisor can select
	// on completion, signals, and failures simultaneously. errCh is buffered
	// so this goroutine never blocks forever.
	// NOTE(review): if the shutdown below times out, this goroutine (and any
	// still-running runnables) are abandoned; the process is expected to exit
	// shortly afterward.
	errCh := make(chan error, 1)
	go func() {
		errCh <- g.Wait()
	}()
	select {
	case err := <-errCh:
		// The application exited naturally.
		if err != nil {
			return fmt.Errorf("application exited with error: %w", err)
		}
		cfg.logger.Info("Application stopped")
		return nil
	case <-ctx.Done():
		// A signal was received (or parent context canceled).
		cfg.logger.Info("Shutdown signal received, initiating graceful shutdown")
	case <-failureCh:
		// A component failed, triggering a shutdown of the others.
		cfg.logger.Info("Component failure detected, initiating graceful shutdown")
	}
	// Grant the remaining runnables a bounded grace period to honor the
	// context cancellation before giving up.
	timer := time.NewTimer(cfg.timeout)
	defer timer.Stop()
	select {
	case err := <-errCh:
		// We consider context.Canceled as a natural byproduct of the shutdown.
		if err != nil && !errors.Is(err, context.Canceled) {
			return fmt.Errorf("application exited with error: %w", err)
		}
		cfg.logger.Info("Shutdown completed successfully")
		return nil
	case <-timer.C:
		return fmt.Errorf("shutdown timed out after %v", cfg.timeout)
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package auth provides JWT-based authentication and authorization middleware
// for the router ecosystem.
//
// It defines a Guard that intercepts incoming requests, extracts and verifies
// a Bearer token, and evaluates a set of authorization rules. Successfully
// parsed claims are injected into the request context for downstream handlers.
//
// # Usage
//
// To secure your API, configure a Guard with a JWT verifier and apply it as
// middleware using the router's chaining capabilities. You can define custom
// rules or use the provided role-based rules.
//
// Example:
//
// // 1. Initialize the JWT verifier and the auth guard.
// verifier := jwt.NewVerifier[*auth.Claims](keySet)
// guard := auth.NewGuard(verifier)
//
// // 2. Setup the router.
// r := router.New()
//
// // 3. Protect a route requiring authentication and specific roles.
// r.HandleFunc(
// "POST /admin/users",
// createUser,
// guard.Secure(auth.HasRole[*auth.Claims]("admin")),
// )
//
// // Inside your handler, retrieve the claims:
// func createUser(e *router.Exchange) error {
// claims, ok := auth.From[*auth.Claims](e)
// if !ok {
// return errors.New("claims missing")
// }
// // ... handle request ...
// }
package auth
import (
"context"
"errors"
"fmt"
"net/http"
"slices"
"github.com/deep-rent/nexus/header"
"github.com/deep-rent/nexus/jose/jwt"
"github.com/deep-rent/nexus/router"
)
// Scheme defines the expected authentication scheme for the Authorization
// header. It is used to extract the JWT token from the request.
const Scheme = "Bearer"

// Machine-readable reason codes attached to [*router.Error] responses
// produced by [Guard.Secure].
const (
	// ReasonMissingToken indicates that the Authorization header was either
	// missing or did not contain a valid Bearer token.
	ReasonMissingToken = "missing_token"
	// ReasonInvalidToken indicates that the token was provided but failed
	// verification (e.g., expired, mismatched signature, or malformed).
	ReasonInvalidToken = "invalid_token"
	// ReasonInsufficientPrivileges indicates that the token is valid, but the
	// associated claims failed to satisfy the required authorization rules.
	ReasonInsufficientPrivileges = "insufficient_privileges"
)

const (
	// RoleAdmin represents an elevated user role with full administrative access.
	RoleAdmin = "admin"
)

// contextKey is an unexported, zero-size key type that prevents collisions
// with context values set by other packages.
type contextKey struct{}

// claimsKey is the internal context key used to store and retrieve parsed
// JWT claims.
var claimsKey contextKey
// FromContext retrieves the parsed claims from a standard [context.Context].
// The boolean result reports whether claims of type T were present; when it
// is false, the returned value is the zero value of T.
func FromContext[T jwt.Claims](ctx context.Context) (T, bool) {
	v := ctx.Value(claimsKey)
	c, ok := v.(T)
	return c, ok
}
// FromRequest retrieves the parsed claims directly from an [*http.Request].
// It returns the claims and a boolean indicating whether they were found.
// It is a convenience wrapper over [FromContext] using the request's context.
func FromRequest[T jwt.Claims](req *http.Request) (T, bool) {
	return FromContext[T](req.Context())
}

// From retrieves the parsed claims from a [*router.Exchange].
// This is the preferred method for extracting claims within route handlers,
// since [Guard.Secure] stores the claims on the exchange's request context.
func From[T jwt.Claims](e *router.Exchange) (T, bool) {
	return FromContext[T](e.Context())
}
// RoleClaims defines an interface for JWT claims that include role-based
// access control capabilities. It is the constraint required by the
// role-based rules [HasRole], [AnyRole], and [AllRoles].
type RoleClaims interface {
	jwt.Claims
	// HasRole reports whether the provided role exists within the claims.
	HasRole(name string) bool
}
// Claims is a standard implementation of [RoleClaims]. It embeds the standard
// JWT reserved claims and adds a custom "rol" slice for roles.
type Claims struct {
	jwt.Reserved
	// Rol is a slice of strings representing the assigned permissions or
	// groups. It is serialized as the "rol" JSON claim and omitted when
	// empty.
	Rol []string `json:"rol,omitempty"`
}

// HasRole returns true if the specified role is present in the Rol slice.
// The comparison is exact (case-sensitive).
func (c *Claims) HasRole(name string) bool {
	return slices.Contains(c.Rol, name)
}

// Ensure Claims implements RoleClaims.
var _ RoleClaims = (*Claims)(nil)
// Rule defines an authorization condition that must be met for a request to
// proceed. Rules are evaluated after the JWT has been successfully verified.
type Rule[T jwt.Claims] interface {
	// Eval evaluates the rule against the current context and parsed claims.
	// Returning an error indicates the rule failed, and access is denied.
	// Returning a [*router.Error] lets the rule customize the HTTP response.
	Eval(ctx context.Context, claims T) error
}

// RuleFunc is an adapter to allow the use of ordinary functions as security
// rules. If f is a function with the appropriate signature, RuleFunc(f) is a
// [Rule] that calls f.
type RuleFunc[T jwt.Claims] func(context.Context, T) error

// Eval calls f(ctx, claims) to implement the [Rule] interface.
func (f RuleFunc[T]) Eval(ctx context.Context, claims T) error {
	return f(ctx, claims)
}
// HasRole creates a [Rule] that enforces the presence of a specific single
// role. The rule fails with a descriptive error when the role is absent.
func HasRole[T RoleClaims](role string) Rule[T] {
	return RuleFunc[T](func(_ context.Context, c T) error {
		if c.HasRole(role) {
			return nil
		}
		return fmt.Errorf("requires role %q", role)
	})
}
// AnyRole creates a [Rule] that passes if the user possesses at least one of
// the specified roles. With an empty role list the rule always fails.
func AnyRole[T RoleClaims](roles ...string) Rule[T] {
	return RuleFunc[T](func(_ context.Context, c T) error {
		for _, r := range roles {
			if c.HasRole(r) {
				return nil
			}
		}
		return fmt.Errorf("requires at least one of the roles: %v", roles)
	})
}
// AllRoles creates a [Rule] that mandates the presence of all specified
// roles. The error names the first role found to be missing; with an empty
// role list the rule always passes.
func AllRoles[T RoleClaims](roles ...string) Rule[T] {
	return RuleFunc[T](func(_ context.Context, c T) error {
		missing := slices.IndexFunc(roles, func(r string) bool {
			return !c.HasRole(r)
		})
		if missing >= 0 {
			return fmt.Errorf("missing required role %q", roles[missing])
		}
		return nil
	})
}
// Guard is responsible for intercepting HTTP requests, validating their JWT
// authentication, and enforcing defined authorization rules. Construct one
// with [NewGuard] and attach it to routes via [Guard.Secure].
type Guard[T jwt.Claims] struct {
	// verifier is the internal [jwt.Verifier] used to process tokens.
	verifier jwt.Verifier[T]
}

// NewGuard creates a new [Guard] using the provided JWT verifier.
// The verifier determines how token signatures and claims of type T are
// validated.
func NewGuard[T jwt.Claims](v jwt.Verifier[T]) *Guard[T] {
	return &Guard[T]{
		verifier: v,
	}
}
// Secure produces a [router.Middleware] that protects routes.
//
// It extracts a Bearer token from the Authorization header, verifies its
// signature and validity, and ensures all provided rules pass. If any step
// fails, it returns a structured [*router.Error] and halts the middleware
// chain.
//
// On success, the verified claims are injected into the request context so
// that downstream handlers can retrieve them via [From], [FromRequest], or
// [FromContext].
func (g *Guard[T]) Secure(rules ...Rule[T]) router.Middleware {
	return func(next router.Handler) router.Handler {
		return router.HandlerFunc(func(e *router.Exchange) error {
			// Pull the credential portion of "Authorization: Bearer <token>".
			token := header.Credentials(e.R.Header, Scheme)
			if token == "" {
				return &router.Error{
					Status: http.StatusUnauthorized,
					Reason: ReasonMissingToken,
					Description: "missing or malformed bearer token",
				}
			}
			// Verify the token's signature and claims via the configured verifier.
			claims, err := g.verifier.Verify([]byte(token))
			if err != nil {
				return &router.Error{
					Status: http.StatusUnauthorized,
					Reason: ReasonInvalidToken,
					Description: "the provided token is invalid or expired",
					Cause: err,
				}
			}
			// Rules are evaluated in order; the first failure denies access.
			for _, rule := range rules {
				if err := rule.Eval(e.Context(), claims); err != nil {
					// If a rule explicitly returns a router.Error, pass it
					// through to allow custom API error responses.
					if re, ok := errors.AsType[*router.Error](err); ok {
						return re
					}
					// Otherwise, wrap it in a standard 403 Forbidden error.
					return &router.Error{
						Status: http.StatusForbidden,
						Reason: ReasonInsufficientPrivileges,
						Description: "access denied by security policy",
						Cause: err,
					}
				}
			}
			// Embed the verified claims into the request context and update
			// the Exchange so downstream handlers have access.
			e.R = e.R.WithContext(context.WithValue(e.Context(), claimsKey, claims))
			return next.ServeHTTP(e)
		})
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package backoff provides customizable strategies for retrying operations
// with increasing delays.
//
// The core of the package is the [Strategy] interface, which computes the next
// backoff duration. Implementations are stateful; the [Strategy.Next] method
// returns progressively longer durations with each call. Once the retried
// operation is successful or abandoned, [Strategy.Done] must be called to reset
// the strategy's internal state.
//
// # Usage
//
// A default exponential backoff strategy with jitter can be created using the
// [New] function. The behavior can be customized using various [Option]
// functions, such as [WithMinDelay], [WithMaxDelay], [WithGrowthFactor], and
// [WithJitterAmount]. Jitter is added by default to prevent multiple clients
// from retrying in sync (the "thundering herd" problem), which can overwhelm a
// recovering service.
//
// Example:
//
// s := backoff.New(
// backoff.WithMinDelay(500*time.Millisecond),
// backoff.WithMaxDelay(30*time.Second),
// )
// defer s.Done()
//
// for {
// err := doWork()
// if err == nil {
// break
// }
// time.Sleep(s.Next())
// }
package backoff
import (
"math"
"sync/atomic"
"time"
"github.com/deep-rent/nexus/internal/jitter"
)
// Default parameters used by [New] when the corresponding option is omitted.
const (
	// DefaultMinDelay is the default minimum time between consecutive retries.
	DefaultMinDelay = 1 * time.Second
	// DefaultMaxDelay is the default maximum time between consecutive retries.
	DefaultMaxDelay = 1 * time.Minute
	// DefaultGrowthFactor is the default growth factor in exponential backoff.
	DefaultGrowthFactor float64 = 2.0
	// DefaultJitterAmount is the default amount of jitter applied, as a
	// fraction of the computed delay.
	DefaultJitterAmount float64 = 0.5
)

// Strategy defines the contract for a backoff algorithm.
// Implementations of this interface are expected to be safe for concurrent use.
type Strategy interface {
	// Next returns the backoff duration for the upcoming retry attempt.
	// This method is stateful and returns incrementally larger durations based
	// on the number of times it has been called since the last call to Done.
	// The returned duration is bounded by [Strategy.MinDelay] and
	// [Strategy.MaxDelay].
	Next() time.Duration
	// Done resets the strategy's internal state, such as its attempt counter.
	// This must be called after the retried operation succeeds or is abandoned.
	Done()
	// MinDelay returns the lower bound for the backoff duration returned by Next.
	MinDelay() time.Duration
	// MaxDelay returns the upper bound for the backoff duration returned by Next.
	MaxDelay() time.Duration
}
// constant is a [Strategy] that yields the same delay on every attempt.
type constant struct {
	// delay is the fixed duration returned by [constant.Next].
	delay time.Duration
}

// Constant produces a [Strategy] that always yields the same delay duration.
// If the provided delay is negative, it is treated as zero (meaning no delay).
func Constant(delay time.Duration) Strategy {
	if delay < 0 {
		delay = 0
	}
	return &constant{delay: delay}
}

// Next returns the fixed delay for this [constant] strategy.
func (c *constant) Next() time.Duration { return c.delay }

// Done implements [Strategy.Done]; it is a no-op because [constant] keeps no
// state between attempts.
func (c *constant) Done() {}

// MinDelay returns the fixed delay duration.
func (c *constant) MinDelay() time.Duration { return c.delay }

// MaxDelay returns the fixed delay duration.
func (c *constant) MaxDelay() time.Duration { return c.delay }

var _ Strategy = (*constant)(nil)
// linear is a [Strategy] implementation whose delay grows linearly with the
// attempt count: attempt n waits roughly n * minDelay, capped at maxDelay.
type linear struct {
	// minDelay is the base step for the linear increment.
	minDelay time.Duration
	// maxDelay is the ceiling for the backoff duration.
	maxDelay time.Duration
	// attempts counts calls to [linear.Next] since the last [linear.Done].
	attempts atomic.Int64
}

// Next returns the backoff duration for the upcoming retry attempt.
func (l *linear) Next() time.Duration {
	n := l.attempts.Add(1)
	// For very large attempt counts the multiplication below would overflow,
	// so short-circuit to the ceiling once n exceeds maxDelay/minDelay.
	if l.minDelay > 0 && n > int64(l.maxDelay/l.minDelay) {
		return l.maxDelay
	}
	d := time.Duration(n) * l.minDelay
	if d > l.maxDelay {
		d = l.maxDelay
	}
	if d < l.minDelay {
		d = l.minDelay
	}
	return d
}
// Done resets the attempt counter for the [linear] strategy, so the next call
// to Next starts from the base delay again.
func (l *linear) Done() { l.attempts.Store(0) }

// MinDelay returns the minimum delay configured for this [linear] strategy.
func (l *linear) MinDelay() time.Duration { return l.minDelay }

// MaxDelay returns the maximum delay configured for this [linear] strategy.
func (l *linear) MaxDelay() time.Duration { return l.maxDelay }

// Compile-time check that linear satisfies Strategy.
var _ Strategy = (*linear)(nil)
// exponential is a [Strategy] implementation that increases delay
// exponentially with each attempt: attempt n waits roughly
// minDelay * growthFactor^n, capped at maxDelay.
type exponential struct {
	// minDelay is the initial base duration before growth.
	minDelay time.Duration
	// maxDelay is the ceiling for the backoff duration.
	maxDelay time.Duration
	// growthFactor is the multiplier applied to each subsequent attempt.
	growthFactor float64
	// attempts tracks the number of times [exponential.Next] has been called.
	attempts atomic.Int64
}

// Next returns the backoff duration for the upcoming retry attempt.
//
// The clamp against maxDelay is performed in floating point BEFORE converting
// to time.Duration: for large attempt counts the product exceeds the int64
// range (or becomes +Inf), and converting such an out-of-range float64 to an
// integer type is implementation-defined in Go — on common platforms it wraps
// to a negative value, which would previously collapse the delay back down to
// minDelay instead of pinning it at maxDelay.
func (e *exponential) Next() time.Duration {
	n := e.attempts.Add(1)
	f := float64(e.minDelay) * math.Pow(e.growthFactor, float64(n))
	if f >= float64(e.maxDelay) {
		return e.maxDelay
	}
	d := time.Duration(f)
	return max(e.minDelay, d)
}
// Done resets the attempt counter for the [exponential] strategy, so the next
// call to Next starts from the base delay again.
func (e *exponential) Done() { e.attempts.Store(0) }

// MinDelay returns the minimum delay configured for this [exponential] strategy.
func (e *exponential) MinDelay() time.Duration { return e.minDelay }

// MaxDelay returns the maximum delay configured for this [exponential] strategy.
func (e *exponential) MaxDelay() time.Duration { return e.maxDelay }

// Compile-time check that exponential satisfies Strategy.
var _ Strategy = (*exponential)(nil)
// spread decorates a [Strategy] with jittering capabilities in order to spread
// out retry attempts over time. It delegates all state to the wrapped
// strategy and only perturbs the durations it produces.
type spread struct {
	// s is the underlying [Strategy] being jittered.
	s Strategy
	// j is the jitter implementation used to modify durations.
	j *jitter.Jitter
}

// Next returns the backoff duration from the underlying strategy after
// applying jitter.
func (s *spread) Next() time.Duration {
	return s.j.Apply(s.s.Next())
}

// Done resets the underlying strategy's state.
func (s *spread) Done() {
	s.s.Done()
}

// MinDelay returns the jittered lower bound of the underlying [Strategy].
// NOTE(review): the semantics of jitter.Floor's second argument are not
// visible here — presumably it computes the lowest value Apply can produce
// for the given duration; confirm against the jitter package.
func (s *spread) MinDelay() time.Duration {
	return s.j.Floor(s.s.MinDelay(), 1)
}

// MaxDelay returns the maximum delay of the underlying [Strategy].
// Jitter does not affect the maximum delay.
func (s *spread) MaxDelay() time.Duration {
	return s.s.MaxDelay()
}

var _ Strategy = (*spread)(nil)
// config holds the parameters for building a [Strategy] via [New]. It is
// populated with the package defaults and then mutated by [Option] functions.
type config struct {
	// minDelay is the minimum duration between retries.
	minDelay time.Duration
	// maxDelay is the maximum duration between retries.
	maxDelay time.Duration
	// growthFactor is the exponential multiplier; values <= 1 select linear
	// backoff in [New].
	growthFactor float64
	// jitterAmount is the fraction of jitter to apply; values <= 0 disable
	// jitter entirely in [New].
	jitterAmount float64
	// r is the source of randomness for jitter; nil selects the jitter
	// package's default source.
	r jitter.Rand
}

// Option customizes the behavior of a backoff [Strategy].
type Option func(*config)
// WithMinDelay sets the minimum time between consecutive retries.
// It is capped at zero (meaning no delay) if a negative duration is provided.
// If equal to or greater than the maximum delay, the backoff delays remain
// constant at the maximum delay. If not customized, the [DefaultMinDelay]
// is used.
//
// When jitter is introduced, the minimum delay is effectively reduced
// proportional to the jitter amount. Thus, the strategy might return a delay
// shorter than the configured minimum delay, depending on the random output.
func WithMinDelay(d time.Duration) Option {
	return func(c *config) {
		if d < 0 {
			d = 0
		}
		c.minDelay = d
	}
}
// WithMaxDelay sets the maximum time between consecutive retries.
// It is capped at zero (meaning no delay) if a negative duration is provided.
// If less than or equal to the minimum delay, the backoff delays remain
// constant at the maximum delay. If not customized, the [DefaultMaxDelay]
// is used.
func WithMaxDelay(d time.Duration) Option {
	return func(c *config) {
		if d < 0 {
			d = 0
		}
		c.maxDelay = d
	}
}
// WithGrowthFactor determines the growth factor (multiplier) for exponential
// backoff. A factor equal to one results in linear backoff, where the minimum
// delay becomes the step size. Any factor less than one is treated as one.
// If not customized, the [DefaultGrowthFactor] is used.
func WithGrowthFactor(factor float64) Option {
	return func(c *config) {
		c.growthFactor = factor
	}
}
// WithJitterAmount specifies the amount of random jitter to apply to the
// backoff delays. It is expressed as a fraction of the delay, where 0 means no
// jitter and 1 means full jitter. The given number is capped between 0 and 1.
// If not customized, the [DefaultJitterAmount] is used.
//
// Jitter scatters the retry attempts in time, which aims to mitigate the
// thundering herd problem, where many clients retry simultaneously.
func WithJitterAmount(p float64) Option {
	return func(c *config) {
		// Clamp into [0, 1] as documented. The previous min(1, p) alone let
		// negative values through; they behaved like 0 only because New
		// disables jitter for amounts <= 0.
		c.jitterAmount = min(1, max(0, p))
	}
}
// WithRand sets the source of randomness for jittering. If not specified or
// nil, a default source will be seeded with the current system time.
func WithRand(src jitter.Rand) Option {
	return func(c *config) {
		if src == nil {
			return
		}
		c.r = src
	}
}
// New creates a new backoff [Strategy] based on the provided options.
//
// The concrete implementation is chosen from the resulting configuration:
// a degenerate min/max range yields a constant delay, a growth factor of at
// most one yields linear backoff, and anything else yields exponential
// backoff, optionally wrapped with jitter.
func New(opts ...Option) Strategy {
	cfg := config{
		minDelay:     DefaultMinDelay,
		maxDelay:     DefaultMaxDelay,
		growthFactor: DefaultGrowthFactor,
		jitterAmount: DefaultJitterAmount,
	}
	for _, apply := range opts {
		apply(&cfg)
	}
	switch {
	case cfg.minDelay >= cfg.maxDelay:
		// The range is empty or inverted: the delay can never grow.
		return &constant{delay: cfg.maxDelay}
	case cfg.growthFactor <= 1:
		// No exponential growth requested; fall back to linear steps.
		return &linear{minDelay: cfg.minDelay, maxDelay: cfg.maxDelay}
	}
	exp := &exponential{
		minDelay:     cfg.minDelay,
		maxDelay:     cfg.maxDelay,
		growthFactor: cfg.growthFactor,
	}
	if cfg.jitterAmount <= 0 {
		return exp
	}
	return &spread{s: exp, j: jitter.New(cfg.jitterAmount, cfg.r)}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package cache provides a generic, auto-refreshing in-memory cache for a
// resource fetched from a URL.
//
// The core of the package is the [Controller], a [scheduler.Tick] that
// periodically fetches a remote resource, parses it, and caches it in memory.
// It is designed to be resilient, with a built-in, configurable retry
// mechanism for handling transient network failures.
//
// The refresh interval is intelligently determined by the resource's caching
// headers (e.g., Cache-Control, Expires), but can be clamped within a
// specified min/max range. The controller also handles conditional requests
// using ETag and Last-Modified headers to reduce bandwidth and server load.
//
// # Usage
//
// A typical use case involves creating a scheduler, defining a [Mapper]
// function to parse the HTTP response, creating and configuring a [Controller],
// and then dispatching it to run in the background.
//
// Example:
//
// type Resource struct {
// // fields for the parsed data
// }
//
// // 1. Create a scheduler to manage the refresh ticks.
// sched := scheduler.New(context.Background())
// defer sched.Shutdown()
//
// // 2. Define a mapper to parse the response body into your target type.
// var mapper cache.Mapper[Resource] = func(r *cache.Response) (Resource, error) {
// var data Resource
// err := json.Unmarshal(r.Body, &data)
// return data, err
// }
//
// // 3. Create and configure the cache controller.
// ctrl := cache.NewController(
// "https://api.example.com/resource",
// mapper,
// cache.WithMinInterval(5*time.Minute),
// cache.WithHeader("Authorization", "Bearer *****"),
// )
//
// // 4. Dispatch the controller to start fetching in the background.
// sched.Dispatch(ctrl)
//
// // 5. You can wait for the first successful fetch.
// <-ctrl.Ready()
//
// // 6. Get the cached data.
// if data, ok := ctrl.Get(); ok {
// fmt.Printf("Successfully fetched and cached data: %+v\n", data)
// }
package cache
import (
"context"
"crypto/tls"
"errors"
"io"
"log/slog"
"net"
"net/http"
"sync"
"time"
"github.com/deep-rent/nexus/header"
"github.com/deep-rent/nexus/retry"
"github.com/deep-rent/nexus/scheduler"
)
// Defaults applied by [NewController] when the corresponding option is
// omitted.
const (
	// DefaultTimeout is the default timeout for a single HTTP request.
	DefaultTimeout = 30 * time.Second
	// DefaultMinInterval is the default lower bound for the refresh interval.
	DefaultMinInterval = 15 * time.Minute
	// DefaultMaxInterval is the default upper bound for the refresh interval.
	DefaultMaxInterval = 24 * time.Hour
)

// Mapper is a function that parses a response's raw response body into the
// target type T. It is responsible for decoding the data (e.g., from JSON or
// XML) and returning the structured result. An error should be returned if
// parsing fails fatally. For warnings or debug information, invoke the logger
// contained in the [Response]. If the mapping takes a considerable amount of
// time, it should generally respect the context contained in the [Response].
type Mapper[T any] func(r *Response) (T, error)

// Response provides contextual information to a [Mapper] function,
// including the response body, request context, and a logger.
type Response struct {
	// Body is the raw response payload to be mapped.
	Body []byte
	// Ctx is the context controlling the HTTP exchange; long-running mappers
	// should check it for cancellation.
	Ctx context.Context
	// Logger is the logger instance inherited from the [Controller].
	Logger *slog.Logger
}
// Controller manages the lifecycle of a cached resource. It implements
// [scheduler.Tick], allowing it to be run by a scheduler to periodically
// refresh the resource from a URL. Construct instances with [NewController].
type Controller[T any] interface {
	scheduler.Tick
	// Get retrieves the currently cached resource. The boolean return value is
	// true if the cache has been successfully populated at least once; until
	// then the returned T is its zero value.
	Get() (T, bool)
	// Ready returns a channel that is closed once the first successful fetch of
	// the resource is complete. This allows consumers to block until the cache
	// is warmed up.
	Ready() <-chan struct{}
}
// NewController creates and configures a new cache [Controller].
//
// It requires a URL for the resource to fetch and a [Mapper] function to parse
// the response. If no [http.Client] is provided via options, it creates a
// default one with a sensible timeout and a retry transport.
func NewController[T any](
	url string,
	mapper Mapper[T],
	opts ...Option,
) Controller[T] {
	// Apply defaults first; options may override any field.
	cfg := config{
		client: nil,
		timeout: DefaultTimeout,
		headers: make([]header.Header, 0, 3),
		minInterval: DefaultMinInterval,
		maxInterval: DefaultMaxInterval,
		logger: slog.Default(),
	}
	for _, opt := range opts {
		opt(&cfg)
	}
	client := cfg.client
	if client == nil {
		// Build a conservative single-shot transport. Keep-alives are
		// disabled because refreshes are minutes to hours apart, so idle
		// connections would only go stale between fetches. The sub-timeouts
		// split the overall budget heuristically: roughly a third each for
		// dialing and the TLS handshake, and most of the budget for awaiting
		// response headers.
		d := &net.Dialer{
			Timeout: cfg.timeout / 3,
			KeepAlive: 0,
		}
		var t http.RoundTripper = &http.Transport{
			Proxy: http.ProxyFromEnvironment,
			DialContext: d.DialContext,
			TLSClientConfig: cfg.tls,
			TLSHandshakeTimeout: cfg.timeout / 3,
			ResponseHeaderTimeout: cfg.timeout * 9 / 10,
			ExpectContinueTimeout: 1 * time.Second,
			DisableKeepAlives: true,
		}
		// Header injection wraps the base transport, and retries wrap both,
		// so retried requests also carry the configured headers.
		t = retry.NewTransport(header.NewTransport(t, cfg.headers...), cfg.retry...)
		client = &http.Client{
			Timeout: cfg.timeout,
			Transport: t,
		}
	}
	return &controller[T]{
		url: url,
		mapper: mapper,
		client: client,
		minInterval: cfg.minInterval,
		maxInterval: cfg.maxInterval,
		logger: cfg.logger,
		readyChan: make(chan struct{}),
	}
}
// controller is the internal implementation of the [Controller] interface.
type controller[T any] struct {
	// url is the endpoint from which the resource is fetched.
	url string
	// mapper is the user-provided function to parse the raw body.
	mapper Mapper[T]
	// client is the HTTP client used for fetching.
	client *http.Client
	// minInterval is the minimum wait time between refreshes.
	minInterval time.Duration
	// maxInterval is the maximum wait time between refreshes.
	maxInterval time.Duration
	// now is an internal hook for time mocking, used when deriving the cache
	// lifetime from response headers in delay.
	now func() time.Time
	// logger handles internal logging of fetch cycles.
	logger *slog.Logger
	// readyOnce ensures the ready channel is closed only once.
	readyOnce sync.Once
	// readyChan is closed upon the first successful fetch.
	readyChan chan struct{}
	// mu protects the following cached fields.
	mu sync.RWMutex
	// resource stores the most recently successfully parsed T.
	resource T
	// ok indicates if resource has been populated.
	ok bool
	// etag stores the ETag from the last successful response.
	etag string
	// lastModified stores the Last-Modified header from the last response.
	lastModified string
}
// Get returns the most recently cached resource. The second return value
// reports whether the cache has ever been populated.
func (c *controller[T]) Get() (T, bool) {
	c.mu.RLock()
	resource, populated := c.resource, c.ok
	c.mu.RUnlock()
	return resource, populated
}
// Ready returns a channel that is closed when the cache is first populated.
// Receiving from it blocks until warm-up completes.
func (c *controller[T]) Ready() <-chan struct{} {
	return c.readyChan
}

// ready ensures the ready channel is closed exactly once, no matter how many
// successful fetch cycles invoke it.
func (c *controller[T]) ready() {
	c.readyOnce.Do(func() { close(c.readyChan) })
}
// Run executes a single fetch-and-cache cycle. It implements the
// [scheduler.Tick] interface. It handles conditional requests (ETag and
// Last-Modified validators), response parsing, and caching, and returns the
// duration to wait before the next run.
func (c *controller[T]) Run(ctx context.Context) time.Duration {
	c.logger.Debug("Fetching resource")
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil)
	if err != nil {
		// This is a non-retriable error in request creation.
		c.logger.Error("Failed to create request", slog.Any("error", err))
		return c.minInterval // Back off by the minimum interval before retrying.
	}
	// Add conditional headers if we have them from a previous request.
	c.mu.RLock()
	if c.etag != "" {
		req.Header.Set("If-None-Match", c.etag)
	}
	if c.lastModified != "" {
		req.Header.Set("If-Modified-Since", c.lastModified)
	}
	c.mu.RUnlock()
	res, err := c.client.Do(req)
	if err != nil {
		// Context cancellation is an orderly shutdown, not a failure worth
		// logging at error level.
		if !errors.Is(err, context.Canceled) {
			c.logger.Error(
				"HTTP request failed after retries",
				slog.Any("error", err),
			)
		}
		return c.minInterval
	}
	defer func() {
		if err := res.Body.Close(); err != nil {
			c.logger.Warn("Failed to close response body", slog.Any("error", err))
		}
	}()
	switch code := res.StatusCode; code {
	case http.StatusNotModified:
		// The cached copy is still fresh; just reschedule.
		c.logger.Debug("Resource unchanged", slog.String("etag", c.etag))
		c.ready()
		return c.delay(res.Header)
	case http.StatusOK:
		body, err := io.ReadAll(res.Body)
		if err != nil {
			c.logger.Error("Failed to read response body", slog.Any("error", err))
			return c.minInterval
		}
		resource, err := c.mapper(&Response{
			Body:   body,
			Ctx:    req.Context(),
			Logger: c.logger,
		})
		if err != nil {
			c.logger.Error("Couldn't parse response body", slog.Any("error", err))
			return c.minInterval
		}
		// Publish the new resource and its validators under the write lock.
		c.mu.Lock()
		c.resource = resource
		c.etag = res.Header.Get("ETag")
		c.lastModified = res.Header.Get("Last-Modified")
		c.ok = true
		c.mu.Unlock()
		c.logger.Info("Resource updated successfully")
		c.ready()
		return c.delay(res.Header)
	default:
		c.logger.Error(
			"Received a non-retriable HTTP status code",
			slog.Int("status", code),
		)
		return c.minInterval
	}
}
// delay computes how long to wait before the next fetch. The base duration
// comes from the response caching headers and is clamped to the configured
// interval bounds; the maximum bound is checked first.
func (c *controller[T]) delay(h http.Header) time.Duration {
	lifetime := header.Lifetime(h, c.now)
	switch {
	case lifetime > c.maxInterval:
		return c.maxInterval
	case lifetime < c.minInterval:
		return c.minInterval
	default:
		return lifetime
	}
}
// Compile-time check that controller satisfies the Controller interface.
var _ Controller[any] = (*controller[any])(nil)

// config holds the internal configuration for the cache controller,
// populated by the functional [Option] values.
type config struct {
	// client is an optional custom HTTP client.
	client *http.Client
	// timeout is the default request timeout for the default client.
	timeout time.Duration
	// headers are static headers applied to every request.
	headers []header.Header
	// tls is the TLS configuration for the default client.
	tls *tls.Config
	// minInterval is the floor for refresh delays.
	minInterval time.Duration
	// maxInterval is the ceiling for refresh delays.
	maxInterval time.Duration
	// retry are options for the default transport's retry logic.
	retry []retry.Option
	// logger is the destination for internal logs.
	logger *slog.Logger
}

// Option is a function that configures the cache [Controller].
type Option func(*config)
// WithClient provides a custom [http.Client] to be used for requests. This is
// useful for advanced configurations, such as custom transports or connection
// pooling. A nil client is ignored; if none is provided, a default client
// with retry logic is created.
func WithClient(client *http.Client) Option {
	return func(c *config) {
		if client == nil {
			return // Keep the default behavior.
		}
		c.client = client
	}
}
// WithTimeout sets the total timeout for a single HTTP fetch attempt,
// including connection, redirects, and reading the response body.
// Non-positive durations are ignored. This is ignored if a custom client is
// provided via [WithClient].
func WithTimeout(d time.Duration) Option {
	return func(c *config) {
		if d <= 0 {
			return // Keep the default timeout.
		}
		c.timeout = d
	}
}
// WithHeader adds a static header to every request sent by the controller. This
// can be called multiple times to add multiple headers. Headers are applied by
// the default client's transport and are therefore ignored when a custom
// client is supplied via [WithClient].
func WithHeader(k, v string) Option {
	return func(c *config) {
		c.headers = append(c.headers, header.New(k, v))
	}
}
// WithTLSConfig provides a custom [tls.Config] for the default HTTP transport.
// This is ignored if a custom client is provided via [WithClient].
func WithTLSConfig(tlsConfig *tls.Config) Option {
	// The parameter is deliberately not named "tls" to avoid shadowing the
	// crypto/tls package inside the closure.
	return func(c *config) {
		c.tls = tlsConfig
	}
}
// WithMinInterval sets the minimum duration between refresh attempts. The
// refresh delay, typically determined by caching headers, will not be shorter
// than this. Non-positive values are ignored.
func WithMinInterval(d time.Duration) Option {
	return func(c *config) {
		if d <= 0 {
			return
		}
		c.minInterval = d
	}
}
// WithMaxInterval sets the maximum duration between refresh attempts. The
// refresh delay will not be longer than this value. Non-positive values are
// ignored.
func WithMaxInterval(d time.Duration) Option {
	return func(c *config) {
		if d <= 0 {
			return
		}
		c.maxInterval = d
	}
}
// WithRetryOptions configures the retry mechanism for the default HTTP client.
// Successive calls accumulate their options. These options are ignored if a
// custom client is provided via [WithClient].
func WithRetryOptions(opts ...retry.Option) Option {
	return func(c *config) {
		c.retry = append(c.retry, opts...)
	}
}
// WithLogger provides a custom [slog.Logger] for the controller. A nil logger
// is ignored; if none is provided, [slog.Default] is used.
func WithLogger(logger *slog.Logger) Option {
	return func(c *config) {
		if logger == nil {
			return // Keep slog.Default.
		}
		c.logger = logger
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package di provides a type-safe, concurrent dependency injection container
// for Go applications.
//
// The core concepts are:
// - [Injector]: The main container that holds all service bindings.
// - [Container]: An interface passed to providers to resolve nested
// dependencies.
// - [Slot]: A unique, typed key used to register and retrieve services.
// - [Provider]: A factory function that creates an instance of a service.
// - [Resolver]: A strategy that defines the lifecycle of a service.
//
// # Usage
//
// Let's explore how to use this container by modeling a simple chemical
// reaction: forming a Salt. This example shows how to use a single interface
// (Ion) for dependencies that fulfill different roles (Cation and Anion).
//
// Step 1: Define a Reusable Abstraction (The Ion Role)
//
// Instead of separate Cation and Anion interfaces, we can define a single,
// reusable Ion interface. Our Salt struct will now depend on two instances
// of this same interface type.
//
// Example:
//
// // Ion represents any particle with a symbol and a charge.
// type Ion interface {
// Symbol() string
// Charge() int
// }
//
// // Salt is our final product, which depends on two Ions.
// type Salt struct {
// cation Ion
// anion Ion
// }
//
// func (s Salt) Formula() string {
// return s.cation.Symbol() + s.anion.Symbol()
// }
//
// Step 2: Create Slots for Roles (The Unique Labels)
//
// The key insight here is that slots distinguish dependencies by their role,
// not just their type. Even though both slots below are for the Ion type,
// they are unique keys. This allows us to inject the right ion into the right
// place. The tags ("ion", "cation", etc.) are optional but help with debugging.
//
// Example:
//
// var (
// SlotCation = di.NewSlot[Ion]("ion", "cation")
// SlotAnion = di.NewSlot[Ion]("ion", "anion")
// SlotSalt = di.NewSlot[Salt]("compound", "salt")
// )
//
// Step 3: Write Providers (The Recipes)
//
// Providers now return the generic Ion interface. The Salt provider can
// then request two different Ions by using their distinct role-based slots.
//
// Example:
//
// // Sodium is a concrete Ion fulfilling the Cation role.
// type Sodium struct{}
//
// func (na Sodium) Symbol() string { return "Na" }
// func (na Sodium) Charge() int    { return 1 }
//
// // ProvideSodium provides a concrete Ion to fulfill the Cation role.
// func ProvideSodium(c di.Container) (Ion, error) {
//     return Sodium{}, nil
// }
//
// // Chloride is a concrete Ion fulfilling the Anion role.
// type Chloride struct{}
//
// func (cl Chloride) Symbol() string { return "Cl" }
// func (cl Chloride) Charge() int    { return -1 }
//
// // ProvideChloride provides a concrete Ion to fulfill the Anion role.
// func ProvideChloride(c di.Container) (Ion, error) {
//     return Chloride{}, nil
// }
//
// // ProvideSalt requests dependencies by their role-specific slots.
// func ProvideSalt(c di.Container) (Salt, error) {
// // Request the Ion fulfilling the "Cation" role.
// cation := di.Required[Ion](c, SlotCation)
// // Request the Ion fulfilling the "Anion" role.
// anion := di.Required[Ion](c, SlotAnion)
// return Salt{cation: cation, anion: anion}, nil
// }
//
// Step 4: Assemble the Solution (Configure the Injector)
//
// Now, create an Injector and bind the concrete providers to their
// respective role slots. We use Transient scope to obtain fresh ions
// each time we form a new salt molecule.
//
// Example:
//
// // 1. Create the injector.
// solution := di.NewInjector()
//
// // 2. Bind concrete providers to their roles.
// di.Bind(solution, SlotCation, ProvideSodium, di.Transient())
// di.Bind(solution, SlotAnion, ProvideChloride, di.Transient())
//
// // 3. Bind the provider for the final product.
// di.Bind(solution, SlotSalt, ProvideSalt, di.Singleton())
//
// Step 5: Trigger the Reaction (Resolve the Final Product)
//
// When we ask for the Salt, the injector provides the previously registered
// atoms (dependencies) to the Salt provider to form the final molecule.
//
// Example:
//
// // This call triggers the entire dependency chain.
// salt := di.Required[Salt](solution, SlotSalt)
//
// fmt.Printf("Successfully formed: %s\n", salt.Formula())
// // Output: Successfully formed: NaCl
package di
import (
"context"
"errors"
"fmt"
"reflect"
"strings"
"sync"
)
// Sentinel errors for standard DI failure modes. They are returned wrapped
// with contextual detail; match them using [errors.Is].
var (
	// ErrCycle indicates that a circular dependency was found.
	ErrCycle = errors.New("circular dependency")
	// ErrUnbound indicates that a slot has no provider bound to it.
	ErrUnbound = errors.New("unbound slot")
)

// Slot is an abstract, typed symbol for an injectable service.
// It is a unique pointer that acts as a map key within the [Injector],
// while the generic type T provides compile-time type safety.
type Slot[T any] *struct {
	// _ ensures a non-zero size, guaranteeing unique memory addresses.
	_ byte
}

// slots is a global, concurrent map that stores the debug tag for each slot.
// Keys are slot pointers; values are their formatted tag strings.
var slots = &sync.Map{}

// Reset clears all registered slot tags from the internal global map.
// [Tag] lookups for previously created slots then fall back to pointer
// addresses.
func Reset() {
	slots = &sync.Map{}
}
// NewSlot creates a new, unique [Slot] for a given type T.
//
// The optional keys form a descriptive name used in debugging and error
// messages; multiple keys are joined with dots. This is useful to group
// related services, e.g., by package or feature. The resulting tag has the
// shape "name@type" (or just "@type" when no keys are given) and can be
// retrieved later using the [Tag] function.
func NewSlot[T any](keys ...string) Slot[T] {
	// Allocate a fresh non-zero-size value so every slot has a unique
	// address, and convert to the named generic type before storing.
	s := Slot[T](new(struct{ _ byte }))
	// strings.Join yields "" for an empty key list, so both the named
	// ("a.b.c@int") and unnamed ("@string") forms fall out of one expression.
	tag := strings.Join(keys, ".") + "@" + reflect.TypeFor[T]().String()
	slots.Store(s, tag)
	return s
}
// Tag returns the pre-formatted debug string for a slot.
//
// The tag is of the form "name@type", where "name" is the optional name
// assigned during slot creation, and "type" is the Go type of the slot. If no
// name was provided, it returns just the type prefixed with "@". If the slot
// is unknown, it falls back to the pointer address.
func Tag(slot any) string {
	v, ok := slots.Load(slot)
	if !ok {
		// Unknown slot: fall back to its pointer address.
		return fmt.Sprintf("%p", slot)
	}
	return v.(string)
}
// Container represents an interface for resolving dependencies.
// Both the top-level [Injector] and the internal resolution state implement
// this interface.
type Container interface {
	// Context returns the context associated with this resolution container.
	Context() context.Context
	// Resolve performs the lookup for the given slot.
	Resolve(slot any) (any, error)
}

// Provider defines the function signature for a service factory.
//
// When a service is requested, its provider is called with a [Container], which
// it can then use to resolve any of its own dependencies (e.g., by calling
// Use). How often the provider is called depends on the number of injection
// sites and the resolution strategy used when binding the provider to a slot.
type Provider[T any] func(c Container) (T, error)

// binding holds a provider and its associated resolution strategy.
type binding struct {
	// provider is the internal wrapped function that returns any. It erases
	// the generic type of the user-supplied [Provider].
	provider func(c Container) (any, error)
	// resolver is the lifecycle strategy for this binding.
	resolver Resolver
}

// config holds configuration options for an [Injector].
type config struct {
	// ctx is the base context for the injector.
	ctx context.Context
}

// Option configures an [Injector].
type Option func(*config)
// WithContext sets the root context for the [Injector]. A nil ctx is ignored,
// leaving the default background context in place.
func WithContext(ctx context.Context) Option {
	return func(cfg *config) {
		if ctx == nil {
			return
		}
		cfg.ctx = ctx
	}
}
// Injector is the main dependency injection container.
// It holds all service bindings and manages their lifecycle. An [Injector] is
// safe for concurrent reads (e.g., using [Use], [Required]), but is not safe
// for concurrent writes (e.g., using [Bind], [Override]). Bindings should be
// configured once at application startup. Create instances with
// [NewInjector], which initializes the bindings map.
type Injector struct {
	// ctx is the root context for all resolutions.
	ctx context.Context
	// bindings maps slots to their respective provider and resolver.
	bindings map[any]*binding
	// mu protects the bindings map from concurrent access.
	mu sync.RWMutex
}
// NewInjector creates and returns a new, empty [Injector] with the given
// options. If no options are provided, it defaults to using
// [context.Background].
func NewInjector(opts ...Option) *Injector {
	cfg := config{ctx: context.Background()}
	for _, apply := range opts {
		apply(&cfg)
	}
	in := &Injector{ctx: cfg.ctx}
	in.bindings = make(map[any]*binding)
	return in
}
// Context returns the injector's root context, as configured via
// [WithContext] (defaulting to [context.Background]).
func (in *Injector) Context() context.Context {
	return in.ctx
}
// Bind registers a provider and its resolver for a specific service slot.
// It is typically called during application initialization and panics if the
// slot already has a binding; use [Override] to replace one intentionally.
func Bind[T any](
	in *Injector,
	slot Slot[T],
	provider Provider[T],
	resolver Resolver,
) {
	in.mu.Lock()
	defer in.mu.Unlock()
	if _, bound := in.bindings[slot]; bound {
		panic(fmt.Sprintf("slot %s is already bound", Tag(slot)))
	}
	// Erase the generic type so all bindings share one internal shape.
	wrapped := func(c Container) (any, error) { return provider(c) }
	in.bindings[slot] = &binding{provider: wrapped, resolver: resolver}
}
// Use resolves a service from the [Container] for a given slot. It is the
// primary method for retrieving dependencies when an error is an expected
// outcome. It panics only if a provider returned a value of the wrong type,
// which indicates a programming error.
func Use[T any](c Container, slot Slot[T]) (T, error) {
	var zero T
	v, err := c.Resolve(slot)
	if err != nil {
		return zero, err
	}
	if v == nil {
		// A nil result is legal; callers decide whether it is acceptable.
		return zero, nil
	}
	if t, ok := v.(T); ok {
		return t, nil
	}
	panic(fmt.Sprintf("provider returned %T for slot %s", v, Tag(slot)))
}
// Optional resolves a service and panics if any resolution error occurs.
// However, unlike [Required], it allows the provider to return a nil value
// without panicking. It is useful for dependencies that are truly optional.
func Optional[T any](c Container, slot Slot[T]) T {
	result, err := Use(c, slot)
	if err != nil {
		panic(err)
	}
	return result
}
// Required resolves a service and panics if an error occurs OR if the
// resolved value is nil. This should be used for critical dependencies that
// must always be present.
func Required[T any](c Container, slot Slot[T]) T {
	v := Optional(c, slot)
	rv := reflect.ValueOf(v)
	switch rv.Kind() {
	case reflect.Pointer, reflect.Interface, reflect.Slice,
		reflect.Map, reflect.Chan, reflect.Func:
		// Only these kinds can carry a nil; IsNil panics on anything else.
		if rv.IsNil() {
			panic(fmt.Sprintf(
				"required dependency for slot %s is nil",
				Tag(slot),
			))
		}
	}
	return v
}
// Override registers a provider for a slot, replacing any existing binding.
// Unlike [Bind], it never panics on an already-bound slot.
func Override[T any](
	in *Injector,
	slot Slot[T],
	provider Provider[T],
	resolver Resolver,
) {
	in.mu.Lock()
	defer in.mu.Unlock()
	// Erase the generic type so all bindings share one internal shape.
	wrapped := func(c Container) (any, error) { return provider(c) }
	in.bindings[slot] = &binding{provider: wrapped, resolver: resolver}
}
// Resolve is a non-generic method to resolve a dependency from a slot. It
// implements the [Container] interface. In most cases, the type-safe
// functions ([Use], [Optional], [Required]) should be preferred.
func (in *Injector) Resolve(slot any) (any, error) {
	// This is a top-level call, so create a fresh map for cycle detection.
	return in.resolve(slot, make(map[any]bool))
}
// resolve is the internal, recursive implementation for dependency
// resolution. The visiting map tracks slots currently being resolved so that
// circular dependencies are detected.
func (in *Injector) resolve(slot any, visiting map[any]bool) (any, error) {
	if visiting[slot] {
		return nil, fmt.Errorf(
			"%w detected while resolving slot %s",
			ErrCycle, Tag(slot),
		)
	}
	visiting[slot] = true
	// Clean up on every exit path, including the ErrUnbound return below.
	// Skipping the cleanup there would leave a stale entry that misreports a
	// later lookup of the same unbound slot (within one top-level resolution)
	// as a cycle instead of an unbound slot.
	defer delete(visiting, slot)
	in.mu.RLock()
	b, ok := in.bindings[slot]
	in.mu.RUnlock()
	if !ok {
		return nil, fmt.Errorf(
			"%w: no provider bound for slot %s",
			ErrUnbound, Tag(slot),
		)
	}
	return b.resolver.Resolve(in, b.provider, slot, visiting)
}
// Resolver defines a strategy for managing a service's lifecycle.
type Resolver interface {
	// Resolve determines how the provider should be invoked and if the result
	// should be cached. The visiting map carries cycle-detection state and
	// must be threaded through to any nested resolutions.
	Resolve(
		in *Injector,
		provider func(c Container) (any, error),
		slot any,
		visiting map[any]bool,
	) (any, error)
}

// singleton is a [Resolver] that caches the service instance.
type singleton struct {
	// instance is the cached value returned by the provider.
	instance any
	// err is the cached error returned by the provider.
	err error
	// once ensures the provider is only called once.
	once sync.Once
}

// Resolve implements [Resolver.Resolve] by caching the provider output.
// Both the value and the error are cached, so a failing provider is not
// retried on subsequent requests.
func (s *singleton) Resolve(
	in *Injector,
	provider func(c Container) (any, error),
	slot any,
	visiting map[any]bool,
) (any, error) {
	s.once.Do(func() {
		s.instance, s.err = provide(in, provider, slot, visiting)
	})
	return s.instance, s.err
}

// Singleton returns a [Resolver] that creates an instance once per injector and
// reuses it for all subsequent requests. The cache lives inside the returned
// resolver value itself, so each binding needs its own resolver.
func Singleton() Resolver {
	return &singleton{}
}
// transient is a [Resolver] that always creates a new service instance.
type transient struct{}

// Resolve implements [Resolver.Resolve] by calling the provider every time.
// Nothing is cached; errors surface directly to the caller.
func (transient) Resolve(
	in *Injector,
	provider func(c Container) (any, error),
	slot any,
	visiting map[any]bool,
) (any, error) {
	return provide(in, provider, slot, visiting)
}

// Transient returns a [Resolver] that creates a new instance of the service
// every time it is requested. The resolver itself is stateless.
func Transient() Resolver {
	return transient{}
}
// resolutionState is a lightweight container passed down during graph
// traversal. It tracks cycles seamlessly without inflating context trees or
// duplicating injectors.
type resolutionState struct {
	// injector is the parent container.
	injector *Injector
	// visiting is the map used for circular dependency detection.
	visiting map[any]bool
}

// Context implements [Container.Context] by delegating to the parent injector.
func (r *resolutionState) Context() context.Context {
	return r.injector.Context()
}

// Resolve implements [Container.Resolve]. Unlike [Injector.Resolve], it
// reuses the in-flight visiting map so that nested lookups participate in
// cycle detection.
func (r *resolutionState) Resolve(slot any) (any, error) {
	return r.injector.resolve(slot, r.visiting)
}

// statePool minimizes allocations during deep dependency tree resolutions.
var statePool = sync.Pool{
	New: func() any { return &resolutionState{} },
}
// provide is an internal helper that safely invokes a provider function.
// It borrows a pooled resolutionState so the provider can resolve its own
// dependencies while sharing the cycle-detection map, and converts any panic
// raised by the provider into a returned error.
func provide(
	in *Injector,
	provider func(c Container) (any, error),
	slot any,
	visiting map[any]bool,
) (instance any, err error) {
	defer func() {
		if r := recover(); r != nil {
			// Wrap error panics with %w so callers can still unwrap them.
			if e, ok := r.(error); ok {
				err = fmt.Errorf(
					"panic during provider call for slot %s: %w",
					Tag(slot), e,
				)
			} else {
				err = fmt.Errorf(
					"panic during provider call for slot %s: %v",
					Tag(slot), r,
				)
			}
		}
	}()
	// Grab a reusable state object to track the resolution cycle.
	state := statePool.Get().(*resolutionState)
	state.injector = in
	state.visiting = visiting
	// Guarantee the state is scrubbed and returned to the pool, even if the
	// provider panics.
	defer func() {
		state.injector = nil
		state.visiting = nil
		statePool.Put(state)
	}()
	return provider(state)
}
// scopedKey is the context key for the scoped dependency cache.
type scopedKey struct{}

// NewScope creates a new context that carries a cache for scoped dependencies.
// This should be called at the beginning of an operation that defines a scope,
// such as a new HTTP request. The returned context should be passed to a new
// or child injector via [WithContext]. Each call creates an independent cache.
func NewScope(ctx context.Context) context.Context {
	return context.WithValue(ctx, scopedKey{}, &sync.Map{})
}

// scopedEntry tracks the lifecycle of a dependency within a specific scope.
type scopedEntry struct {
	// once ensures single execution per scope.
	once sync.Once
	// val is the cached scope instance.
	val any
	// err is the cached scope error.
	err error
}

// scoped is a [Resolver] that ties the service lifecycle to a context scope.
type scoped struct{}
// Resolve implements [Resolver.Resolve] using a context-based cache. It
// requires that the injector's context was prepared with [NewScope];
// otherwise resolution fails with an error.
func (s scoped) Resolve(
	in *Injector,
	provider func(c Container) (any, error),
	slot any,
	visiting map[any]bool,
) (any, error) {
	val := in.Context().Value(scopedKey{})
	cache, ok := val.(*sync.Map)
	if !ok || cache == nil {
		return nil, fmt.Errorf(
			"no scope cache found in context for scoped slot %s", Tag(slot),
		)
	}
	// Retrieve or create the synchronization entry for this slot.
	act, _ := cache.LoadOrStore(slot, &scopedEntry{})
	ent := act.(*scopedEntry)
	// Ensure the provider runs exactly once per scope; both the value and
	// the error are cached in the entry.
	ent.once.Do(func() {
		ent.val, ent.err = provide(in, provider, slot, visiting)
	})
	return ent.val, ent.err
}
// Scoped returns a [Resolver] that ties the lifecycle of a service to a
// [context.Context]. A new instance is created once per scope, defined by a
// call to [NewScope]. It requires that the injector's context was created
// via [NewScope]; otherwise resolution fails with an error.
func Scoped() Resolver {
	return scoped{}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package env provides functionality for unmarshaling environment variables
// into Go structs.
//
// By default, all exported fields in a struct are mapped to environment
// variables. The variable name is derived by converting the field's name to
// uppercase SNAKE_CASE (e.g., a field named APIKey maps to API_KEY).
// This behavior can be customized or disabled on a per-field basis using
// struct tags.
//
// # Usage
//
// Define a struct to hold your configuration. Only exported fields will be
// considered. The code snippet below showcases various field types and
// struct tag options:
//
// Example:
//
// type Config struct {
// Host string `env:",required"`
// Port int `env:",default:8080"`
// Timeout time.Duration `env:",unit:s"`
// Debug bool
// Proxy ProxyConfig `env:",prefix:'HTTP_PROXY_'"`
// Roles []string `env:",split:';'"`
// Internal int `env:"-"`
// internal int
// }
//
// var cfg Config
// if err := env.Unmarshal(&cfg); err != nil {
// log.Fatalf("failed to unmarshal config: %v", err)
// }
// // Use the configuration to bootstrap your application...
//
// # Options
//
// The behavior of the unmarshaler is controlled by the env struct field tag.
// The tag is a comma-separated string of options.
//
// The first value is the name of the environment variable. If it is omitted,
// the field's name is used as the base for the variable name.
//
// DatabaseURL string `env:"MY_DATABASE_URL"`
//
// The subsequent parts of the tag are options, which can be in a key:value
// format or be boolean flags.
//
// Option "default": Sets a default value to be used if the environment
// variable is not set.
//
// Port int `env:",default:8080"`
//
// Option "required": Marks the variable as required. [Unmarshal] will return
// an error if the variable is not set and no default is provided.
//
// APIKey string `env:",required"`
//
// Option "prefix": For nested struct fields, this overrides the default
// prefix. By default, the prefix is the field's name in SNAKE_CASE followed by
// an underscore. It can be set to an empty string to omit the prefix entirely.
//
// DBConfig `env:",prefix:DB_"`
//
// Option "inline": When applied to an anonymous struct field, it flattens the
// struct, effectively treating its fields as if they belonged to the parent
// struct.
//
// Nested `env:",inline"`
//
// Option "split": For slice types, this specifies the delimiter to split the
// environment variable string. The default separator is a comma.
//
// Hosts []string `env:",split:';'"`
//
// Option "format": Provides a format specifier for special types. For
// [time.Time] it can be a Go-compliant layout string (e.g., "2006-01-02") or
// one of the predefined constants "unix", "dateTime", "date", and "time".
// Defaults to the RFC 3339 format. For []byte, it can be "hex", "base32", or
// "base64" to alter the encoding format.
//
// StartDate time.Time `env:",format:date"`
//
// Option "unit": Specifies the unit for [time.Time] or [time.Duration] when
// parsing from an integer. For [time.Duration]: "ns", "us" (or "μs"), "ms",
// "s", "m", "h". For [time.Time] (with format:unix): "s", "ms", "us" (or "μs").
//
// CacheTTL time.Duration `env:",unit:m,default:5"`
package env
import (
"encoding/base32"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"net/url"
"os"
"reflect"
"strconv"
"strings"
"time"
"github.com/deep-rent/nexus/internal/pointer"
"github.com/deep-rent/nexus/internal/snake"
"github.com/deep-rent/nexus/internal/tag"
)
// Lookup is a function that retrieves the value of an environment variable.
// It follows the signature of [os.LookupEnv], returning the value and a boolean
// indicating whether the variable was present. This type allows for custom
// lookup mechanisms, such as reading from sources other than the actual
// environment, which is especially useful for testing.
type Lookup func(key string) (string, bool)

// Unmarshaler is an interface that can be implemented by types to provide their
// own custom logic for parsing an environment variable string.
type Unmarshaler interface {
	// UnmarshalEnv unmarshals the string value of an environment variable.
	// The value is the raw string from the environment or a default value
	// declared in the struct tag. It returns an error if the value cannot be
	// parsed.
	UnmarshalEnv(value string) error
}

// Option is a function that configures the behavior of the [Unmarshal] and
// [Expand] functions. It follows the functional options pattern.
type Option func(*config)
// WithPrefix returns an [Option] that adds a common prefix to all environment
// variable keys looked up during unmarshaling. For example, WithPrefix("APP_")
// would cause a field with the env tag "PORT" to look for the "APP_PORT"
// variable.
func WithPrefix(prefix string) Option {
	return func(cfg *config) {
		cfg.Prefix = prefix
	}
}
// WithLookup returns an [Option] that sets a custom [Lookup] function for
// retrieving environment variable values. A nil lookup is ignored, keeping
// the default [os.LookupEnv]. This is useful for testing or if you need to
// load environment variables from alternative sources.
func WithLookup(lookup Lookup) Option {
	return func(cfg *config) {
		if lookup == nil {
			return
		}
		cfg.Lookup = lookup
	}
}
// Unmarshal populates the fields of a struct with values from environment
// variables. The given value v must be a non-nil pointer to a struct.
//
// By default, [Unmarshal] processes all exported fields. A field's environment
// variable name is derived from its name, converted to uppercase SNAKE_CASE.
// To ignore a field, tag it with `env:"-"`. Unexported fields are always
// excluded. If a variable is not set, the field remains unchanged unless a
// default value is specified in the struct tag, or it is marked as required.
func Unmarshal(v any, opts ...Option) error {
	err := unmarshal(v, opts...)
	if err == nil {
		return nil
	}
	// Prefix all failures with the package name for easier attribution.
	return fmt.Errorf("env: %w", err)
}
// Expand substitutes environment variables in a string.
//
// It replaces references to environment variables in the formats ${KEY} or $KEY
// with their corresponding values. A literal dollar sign can be escaped with $$
// (double dollar sign). If a referenced variable is not found in the
// environment, the function returns an error. Its behavior can be adjusted
// through functional options such as [WithPrefix] and [WithLookup].
func Expand(s string, opts ...Option) (string, error) {
	cfg := config{
		Lookup: os.LookupEnv,
	}
	for _, opt := range opts {
		opt(&cfg)
	}
	var b strings.Builder
	// Preallocate assuming expansion roughly doubles the input size.
	b.Grow(len(s) * 2)
	i := 0
	for i < len(s) {
		// Find the next dollar sign.
		start := strings.IndexByte(s[i:], '$')
		if start == -1 {
			// No more variables; copy the remaining tail verbatim.
			b.WriteString(s[i:])
			break
		}
		// Append the text before the dollar sign.
		b.WriteString(s[i : i+start])
		// Move our main index to the location of the dollar sign.
		i += start
		// Check what follows the dollar sign.
		switch {
		case i+1 < len(s) && s[i+1] == '$':
			// Case 1: Escaped dollar sign ($$) produces a literal '$'.
			b.WriteByte('$')
			i += 2 // Skip both signs.
		case i+1 < len(s) && s[i+1] == '{':
			// Case 2: Bracketed variable expansion (${KEY}).
			end := strings.IndexByte(s[i+2:], '}')
			if end == -1 {
				return "", errors.New("env: variable bracket not closed")
			}
			// Extract the variable name.
			key := cfg.Prefix + s[i+2:i+2+end]
			val, ok := cfg.Lookup(key)
			if !ok {
				return "", fmt.Errorf("env: variable %q is not set", key)
			}
			b.WriteString(val)
			// Move the index past the processed variable `${KEY}`.
			i += 2 + end + 1
		default:
			// Case 3: Standard variable expansion ($KEY) or lone dollar sign.
			// Scan ahead for valid identifier characters. The first character
			// must be a letter or underscore; subsequent characters can include
			// digits.
			n := 0
			for j := i + 1; j < len(s); j++ {
				c := s[j]
				// Allow if it's a letter/underscore, OR if it's a digit but NOT
				// the first character.
				if (c == '_' || ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z')) ||
					(n > 0 && ('0' <= c && c <= '9')) {
					n++
				} else {
					break
				}
			}
			if n == 0 {
				// No valid identifier characters found (e.g., "$5", "$!").
				// Treat as a literal dollar sign.
				b.WriteByte('$')
				i++
			} else {
				// Extract the unbracketed variable name.
				key := cfg.Prefix + s[i+1:i+1+n]
				val, ok := cfg.Lookup(key)
				if !ok {
					return "", fmt.Errorf("env: variable %q is not set", key)
				}
				b.WriteString(val)
				// Move the index past the processed variable `$KEY`.
				i += 1 + n
			}
		}
	}
	return b.String(), nil
}
// flags encapsulates the options parsed from an `env` struct tag.
type flags struct {
	// Name is the name of the environment variable. If empty, the field name
	// converted to uppercase SNAKE_CASE is used instead.
	Name string
	// Prefix is an optional prefix override for nested structs. A nil value
	// means "not specified"; the nested prefix then defaults to the key
	// followed by an underscore.
	Prefix *string
	// Split is the delimiter used to separate slice elements (default ",").
	Split string
	// Unit is the time unit applied to time.Time (with format "unix") or
	// time.Duration values.
	Unit string
	// Format is the format specifier for special types (time.Time, []byte).
	Format string
	// Default is the fallback value if the variable is not found.
	Default string
	// Inline indicates whether to inline an anonymous struct field, i.e.,
	// process its fields without adding a key segment.
	Inline bool
	// Required indicates whether the variable must be set.
	Required bool
}
// config holds configuration options for environment variable processing,
// shared by [Unmarshal] and [Expand].
type config struct {
	// Prefix is a common prefix prepended to all environment variable keys.
	Prefix string
	// Lookup is the injectable callback for variable lookup; it defaults to
	// os.LookupEnv when not customized.
	Lookup Lookup
}
// Cache reflect.Type values for the types with special unmarshaling logic, so
// setValue can compare against them cheaply on every field.
var (
	typeTime        = reflect.TypeFor[time.Time]()
	typeDuration    = reflect.TypeFor[time.Duration]()
	typeLocation    = reflect.TypeFor[time.Location]()
	typeURL         = reflect.TypeFor[url.URL]()
	typeUnmarshaler = reflect.TypeFor[Unmarshaler]()
)
// unmarshal is the internal implementation that orchestrates the unmarshaling:
// it validates the target, applies the options, and walks the struct.
func unmarshal(v any, opts ...Option) error {
	rv := reflect.ValueOf(v)
	// The target must be addressable, hence a non-nil pointer.
	if rv.Kind() != reflect.Pointer || rv.IsNil() {
		return errors.New("expected a non-nil pointer to a struct")
	}
	elem := rv.Elem()
	if k := elem.Kind(); k != reflect.Struct {
		return fmt.Errorf("expected a pointer to a struct, but got pointer to %v", k)
	}
	cfg := config{
		Lookup: os.LookupEnv,
	}
	for _, apply := range opts {
		apply(&cfg)
	}
	return process(elem, cfg.Prefix, cfg.Lookup)
}
// process recursively walks through the struct fields of rv, resolving each
// field's environment variable (prefix + key) via lookup and assigning the
// parsed value. It is invoked once per (possibly nested) struct level.
func process(rv reflect.Value, prefix string, lookup Lookup) error {
	rt := rv.Type()
	for i := 0; i < rt.NumField(); i++ {
		ft := rt.Field(i)
		fv := rv.Field(i)
		// Skip unexported or otherwise unassignable fields.
		if !ft.IsExported() || !fv.CanSet() {
			continue
		}
		tagValue := ft.Tag.Get("env")
		// A "-" tag explicitly excludes the field.
		if tagValue == "-" {
			continue
		}
		opts, err := parse(tagValue)
		if err != nil {
			return fmt.Errorf("failed to parse tag for field %q: %w", ft.Name, err)
		}
		// Inlined anonymous fields are processed at the current prefix level,
		// i.e., without contributing a key segment of their own.
		if ft.Anonymous && opts.Inline {
			// Dereference and allocate in case the inline field is a pointer.
			embedded := pointer.Deref(fv)
			if err := process(embedded, prefix, lookup); err != nil {
				return err
			}
			continue
		}
		key := opts.Name
		if key == "" {
			// No explicit name: derive one from the field name (SNAKE_CASE).
			key = snake.ToUpper(ft.Name)
		}
		// Check for true embedded structs, which are recursed into with an
		// extended prefix.
		if isEmbedded(ft, fv) {
			nested := prefix
			if opts.Prefix != nil {
				nested += *opts.Prefix
			} else {
				nested += key + "_"
			}
			// Dereference and allocate in case the field is a pointer.
			embedded := pointer.Deref(fv)
			if err := process(embedded, nested, lookup); err != nil {
				return err
			}
			continue
		}
		key = prefix + key
		val, ok := lookup(key)
		// A variable is "set" even if it is empty. We only trigger the strictly
		// missing variable logic if 'ok' is false.
		if !ok {
			switch {
			case opts.Default != "":
				// NOTE: the default takes precedence over "required", so a
				// required field that also has a default never errors here.
				val = opts.Default
			case opts.Required:
				return fmt.Errorf("required variable %q is not set", key)
			default:
				// Not set, no default, not required: leave the field unchanged.
				continue
			}
		} else if val == "" && opts.Default != "" {
			// If the variable is explicitly set to empty (""), but a default
			// exists in the tags, fall back to the default value.
			val = opts.Default
		}
		// If a field is required and set to "", it bypasses the errors above.
		// For strings, setValue will correctly assign the empty string. For
		// types like int or bool, setValue will return a natural parsing error.
		if err := setValue(fv, val, opts); err != nil {
			return fmt.Errorf(
				"error setting field %q from variable %q: %w",
				ft.Name, key, err,
			)
		}
	}
	return nil
}
// setValue assigns a string value to a reflect.Value based on its type. A
// custom [Unmarshaler] implementation takes precedence over the built-in
// handling for time, duration, location, URL, and primitive types.
func setValue(rv reflect.Value, v string, f *flags) error {
	if u, ok := asUnmarshaler(rv); ok {
		// Delegate entirely to the custom unmarshaler.
		return u.UnmarshalEnv(v)
	}
	rv = pointer.Deref(rv)
	t := rv.Type()
	switch {
	case t == typeTime:
		return setTime(rv, v, f)
	case t == typeDuration:
		return setDuration(rv, v, f)
	case t == typeLocation:
		return setLocation(rv, v)
	case t == typeURL:
		return setURL(rv, v)
	}
	return setOther(rv, v, f)
}
// setOther handles all "regular" (primitive and slice) types by delegating to
// the appropriate parsing logic based on the reflective kind. Slices are
// forwarded to setSlice; every other supported kind is parsed from v with the
// matching strconv routine and stored into rv.
func setOther(rv reflect.Value, v string, f *flags) error {
	kind := rv.Kind()
	switch kind {
	case reflect.Slice:
		return setSlice(rv, v, f)
	case reflect.String:
		rv.SetString(v)
	case reflect.Bool:
		parsed, err := strconv.ParseBool(v)
		if err != nil {
			return fmt.Errorf("%q is not a bool", v)
		}
		rv.SetBool(parsed)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		bits := rv.Type().Bits()
		parsed, err := strconv.ParseInt(v, 10, bits)
		if err != nil {
			return fmt.Errorf("%q is not an int%d", v, bits)
		}
		rv.SetInt(parsed)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		bits := rv.Type().Bits()
		parsed, err := strconv.ParseUint(v, 10, bits)
		if err != nil {
			return fmt.Errorf("%q is not a uint%d", v, bits)
		}
		rv.SetUint(parsed)
	case reflect.Float32, reflect.Float64:
		bits := rv.Type().Bits()
		parsed, err := strconv.ParseFloat(v, bits)
		if err != nil {
			return fmt.Errorf("%q is not a float%d", v, bits)
		}
		rv.SetFloat(parsed)
	case reflect.Complex64, reflect.Complex128:
		bits := rv.Type().Bits()
		parsed, err := strconv.ParseComplex(v, bits)
		if err != nil {
			return fmt.Errorf("%q is not a complex%d", v, bits)
		}
		rv.SetComplex(parsed)
	default:
		return fmt.Errorf("unsupported type: %s", kind)
	}
	return nil
}
// setTime parses and sets a [time.Time] value based on the provided format and
// unit options. An empty format defaults to RFC 3339; "unix" interprets the
// value as an integer timestamp scaled by the unit option; any other format
// string is handed to time.Parse as a reference layout.
func setTime(rv reflect.Value, v string, f *flags) error {
	// parseUnix converts an integer timestamp using the configured unit.
	parseUnix := func() (time.Time, error) {
		i, err := strconv.ParseInt(v, 10, 64)
		if err != nil {
			return time.Time{}, err
		}
		switch f.Unit {
		case "s", "":
			return time.Unix(i, 0), nil
		case "ms":
			return time.UnixMilli(i), nil
		case "us", "μs":
			return time.UnixMicro(i), nil
		default:
			return time.Time{}, fmt.Errorf("invalid time unit: %q", f.Unit)
		}
	}
	var t time.Time
	var err error
	switch f.Format {
	case "unix":
		t, err = parseUnix()
	case "dateTime":
		t, err = time.Parse(time.DateTime, v)
	case "date":
		t, err = time.Parse(time.DateOnly, v)
	case "time":
		t, err = time.Parse(time.TimeOnly, v)
	case "":
		t, err = time.Parse(time.RFC3339, v)
	default:
		t, err = time.Parse(f.Format, v)
	}
	if err != nil {
		return err
	}
	rv.Set(reflect.ValueOf(t))
	return nil
}
// setDuration parses and sets a [time.Duration] value. Without a unit option,
// v is parsed with time.ParseDuration (e.g. "1h30m"); with a unit, v must be
// a plain integer that is scaled accordingly.
func setDuration(rv reflect.Value, v string, f *flags) error {
	unit := f.Unit
	if unit == "" {
		d, err := time.ParseDuration(v)
		if err != nil {
			return err
		}
		rv.SetInt(int64(d))
		return nil
	}
	i, err := strconv.ParseInt(v, 10, 64)
	if err != nil {
		return err
	}
	var scale time.Duration
	switch unit {
	case "ns":
		scale = time.Nanosecond
	case "us", "μs":
		scale = time.Microsecond
	case "ms":
		scale = time.Millisecond
	case "s":
		scale = time.Second
	case "m":
		scale = time.Minute
	case "h":
		scale = time.Hour
	default:
		return fmt.Errorf("invalid duration unit: %q", unit)
	}
	rv.SetInt(int64(time.Duration(i) * scale))
	return nil
}
// setLocation parses and sets a [time.Location] value.
func setLocation(rv reflect.Value, v string) error {
loc, err := time.LoadLocation(v)
if err != nil {
return err
}
rv.Set(reflect.ValueOf(*loc))
return nil
}
// setURL parses and sets a [url.URL] value.
func setURL(rv reflect.Value, v string) error {
u, err := url.Parse(v)
if err != nil {
return err
}
rv.Set(reflect.ValueOf(*u))
return nil
}
// setSlice parses and sets a slice value. []byte gets dedicated decoding based
// on the format option (raw, hex, base32, or base64); any other element type
// is produced by splitting the input on the configured delimiter and parsing
// each part with setValue.
func setSlice(rv reflect.Value, v string, f *flags) error {
	if rv.Type().Elem().Kind() == reflect.Uint8 {
		var (
			decoded []byte
			err     error
		)
		switch f.Format {
		case "":
			decoded = []byte(v)
		case "hex":
			decoded, err = hex.DecodeString(v)
		case "base32":
			decoded, err = base32.StdEncoding.DecodeString(v)
		case "base64":
			decoded, err = base64.StdEncoding.DecodeString(v)
		default:
			err = fmt.Errorf("unsupported format for []byte: %q", f.Format)
		}
		if err != nil {
			return err
		}
		rv.SetBytes(decoded)
		return nil
	}
	items := strings.Split(v, f.Split)
	// An empty input yields an empty (but non-nil) slice.
	if len(items) == 1 && items[0] == "" {
		rv.Set(reflect.MakeSlice(rv.Type(), 0, 0))
		return nil
	}
	out := reflect.MakeSlice(rv.Type(), len(items), len(items))
	for i, item := range items {
		if err := setValue(out.Index(i), item, f); err != nil {
			return fmt.Errorf("failed to parse slice element at index %d: %w", i, err)
		}
	}
	rv.Set(out)
	return nil
}
// parse parses the `env` tag string into a flags struct. It supports quoted
// values for options to allow commas within them, e.g., `default:'a,b,c'`.
// Duplicate and unknown options are rejected.
func parse(s string) (*flags, error) {
	parsed := tag.Parse(s)
	f := &flags{Name: parsed.Name, Split: ","}
	seen := map[string]bool{}
	for key, val := range parsed.Opts() {
		if seen[key] {
			return nil, fmt.Errorf("duplicate option: %q", key)
		}
		seen[key] = true
		switch key {
		case "format":
			f.Format = val
		case "prefix":
			f.Prefix = &val
		case "split":
			f.Split = val
		case "unit":
			f.Unit = val
		case "default":
			f.Default = val
		case "inline":
			f.Inline = true
		case "required":
			f.Required = true
		default:
			return nil, fmt.Errorf("unknown option: %q", key)
		}
	}
	return f, nil
}
// isEmbedded reports whether a struct field is a true embedded struct that
// should be processed recursively, rather than a leaf value with dedicated
// parsing logic.
func isEmbedded(f reflect.StructField, rv reflect.Value) bool {
	t := f.Type
	// Unwrap pointer(s) to inspect the underlying type.
	for t.Kind() == reflect.Pointer {
		t = t.Elem()
	}
	// 1. It must resolve to a struct.
	if t.Kind() != reflect.Struct {
		return false
	}
	// 2. It must not be one of the special struct types handled directly.
	switch t {
	case typeTime, typeURL, typeLocation:
		return false
	}
	// 3. It must not implement the Unmarshaler interface.
	_, custom := asUnmarshaler(rv)
	return !custom
}
// asUnmarshaler checks if the given [reflect.Value] implements the
// [Unmarshaler] interface, either directly or via a pointer receiver. If it
// does, it returns the type-casted [Unmarshaler] and true; otherwise nil and
// false.
func asUnmarshaler(rv reflect.Value) (Unmarshaler, bool) {
	// Case 1: The field's type directly implements Unmarshaler. This covers
	// pointer types and value types with value receivers.
	if rv.Type().Implements(typeUnmarshaler) {
		if rv.Kind() == reflect.Pointer && rv.IsNil() {
			// A nil pointer must be allocated first so the interface method
			// is not invoked on a nil receiver.
			pointer.Alloc(rv)
		}
		return rv.Interface().(Unmarshaler), true
	}
	// Case 2: A pointer to the field's type implements Unmarshaler. This
	// covers value types with pointer receivers.
	if rv.CanAddr() {
		if addr := rv.Addr(); addr.Type().Implements(typeUnmarshaler) {
			return addr.Interface().(Unmarshaler), true
		}
	}
	return nil, false
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package event provides a high-performance, in-memory event bus system.
//
// It relies on a lock-free ring buffer for low-latency event publishing and an
// atomic copy-on-write mechanism for thread-safe subscriber management. The
// package offers both standalone event streams ([Bus]) and a centralized topic
// manager ([Broker]) for safely routing different event types across an
// application.
//
// # Usage
//
// A typical setup involves initializing a [Broker], retrieving a typed [Bus]
// for a topic, and subscribing to or publishing events.
//
// Example:
//
// type UserCreated struct {
// Email string
// }
//
// // 1. Initialize the central broker with options.
// broker := event.NewBroker(event.WithSyncDispatch())
// defer broker.Close()
//
// // 2. Retrieve a typed bus for a specific topic.
// bus := event.Topic[UserCreated](broker, "users.created")
//
// // 3. Subscribe to the event stream.
// unsub := bus.Subscribe(func(e UserCreated) {
// fmt.Println("New user registered:", e.Email)
// })
// defer unsub()
//
// // 4. Publish an event.
// bus.Publish(UserCreated{Email: "alice@example.com"})
package event
import (
"fmt"
"log/slog"
"runtime"
"runtime/debug"
"sync"
"sync/atomic"
"time"
"github.com/deep-rent/nexus/internal/ring"
)
// OverflowMode determines how the bus behaves when the internal buffer is
// full. It is an alias of the ring buffer's policy type.
type OverflowMode = ring.Policy

const (
	// Block waits until space is available in the buffer.
	Block = ring.Block
	// DropOldest removes the oldest unread event to make room for the new one.
	DropOldest = ring.DropOldest
	// DropNewest discards the incoming event if the buffer is full.
	DropNewest = ring.DropNewest
)

const (
	// DefaultSize is the default capacity of the internal ring buffer.
	// It is automatically rounded up to the nearest power of 2.
	DefaultSize = 1024
	// DefaultOverflowMode is the default overflow mode ([Block]).
	DefaultOverflowMode = Block
)
// Subscriber is a callback function that handles events of type T.
type Subscriber[T any] func(T)

// WaitStrategy defines the idling behavior of the background processor when the
// ring buffer is empty.
type WaitStrategy interface {
	// Snooze is called when the buffer is empty. The idle parameter represents
	// the number of consecutive empty polls and lets implementations back off
	// progressively.
	Snooze(idle int)
	// Signal awakens the processor from a Snooze when a new event arrives. It
	// is called from publisher goroutines while the processor may be snoozing.
	Signal()
}
// adaptiveWait employs a spin-yield-sleep sequence to minimize latency while
// preventing constant CPU burn during idle periods. It is stateless, so the
// zero value is ready to use.
type adaptiveWait struct{}
// Snooze scales the waiting mechanism based on how long the bus has been idle:
// first it yields the scheduler, then sleeps for a microsecond, and finally
// settles on millisecond sleeps for near-zero CPU consumption.
func (adaptiveWait) Snooze(idle int) {
	const (
		spinLimit  = 1000 // Spin-yield limit
		sleepLimit = 5000 // Short-sleep limit
	)
	if idle < spinLimit {
		// Low latency mode: yield the processor but stay actively scheduled.
		runtime.Gosched()
		return
	}
	if idle < sleepLimit {
		// Cooldown mode: drop CPU usage significantly while keeping response
		// times fast.
		time.Sleep(time.Microsecond)
		return
	}
	// Deep idle mode: near 0% CPU consumption.
	time.Sleep(time.Millisecond)
}
// Signal is a no-op because the polling loop actively wakes itself up after
// each Snooze interval.
func (adaptiveWait) Signal() {}
// blockingWait uses a semaphore channel to park the goroutine entirely when
// idle, saving CPU cycles at the cost of a slight wakeup latency.
type blockingWait struct {
	// sem is a 1-buffered channel acting as a non-blocking signaling
	// mechanism; a redundant wakeup token is simply dropped by Signal.
	sem chan struct{}
}
// Snooze parks the goroutine until a wakeup token arrives on the semaphore
// channel; the idle counter is ignored.
func (w *blockingWait) Snooze(_ int) { <-w.sem }
// Signal attempts to send a wakeup token. If the channel already holds a
// token, the send is dropped so the publisher never blocks.
func (w *blockingWait) Signal() {
	select {
	case w.sem <- struct{}{}:
	default:
	}
}
// handler pairs a unique identifier with a subscriber function for internal
// dispatching. The identifier allows for constant-time unsubscription without
// relying on function pointers.
type handler[T any] struct {
	// id is a unique identifier for the subscriber, assigned from the bus's
	// incrementing counter.
	id uint64
	// fn is the callback function to be executed for each event.
	fn Subscriber[T]
}
// dispatcher defines the internal strategy for delivering events to
// subscribers, either sequentially or concurrently.
type dispatcher[T any] interface {
	// dispatch delivers the event to the provided list of handlers.
	dispatch(event T, handlers []handler[T])
}
// basicDispatcher delivers events sequentially on the background worker's
// goroutine, so a slow subscriber delays the subscribers after it.
type basicDispatcher[T any] struct {
	// logger records any panics triggered by a subscriber function.
	logger *slog.Logger
}
// dispatch iterates through all handlers and executes them sequentially. Each
// call is isolated so a panic in one subscriber cannot crash the background
// processor or skip the remaining subscribers.
func (d basicDispatcher[T]) dispatch(event T, handlers []handler[T]) {
	invoke := func(fn Subscriber[T]) {
		defer func() {
			if r := recover(); r != nil {
				d.logger.Error(
					"Subscriber panicked",
					slog.Any("panic", r),
					slog.String("stack", string(debug.Stack())),
				)
			}
		}()
		fn(event)
	}
	for _, h := range handlers {
		invoke(h.fn)
	}
}
// asyncDispatcher delivers events concurrently by spawning a goroutine per
// subscriber, so subscribers cannot delay one another.
type asyncDispatcher[T any] struct {
	// logger records any panics triggered by a subscriber function.
	logger *slog.Logger
}
// dispatch executes all handlers in parallel, one goroutine per subscriber.
// Each goroutine recovers its own panics so a failing subscriber never takes
// down the process.
func (d asyncDispatcher[T]) dispatch(event T, handlers []handler[T]) {
	for _, h := range handlers {
		fn := h.fn
		go func() {
			defer func() {
				if r := recover(); r != nil {
					d.logger.Error(
						"Subscriber panicked",
						slog.Any("panic", r),
						slog.String("stack", string(debug.Stack())),
					)
				}
			}()
			fn(event)
		}()
	}
}
// Option configures the [Bus] during initialization.
type Option func(*config)

// config aggregates all user-defined settings for the [Bus].
type config struct {
	// size is the internal buffer capacity (rounded up to a power of 2 by the
	// ring buffer).
	size int
	// mode is the behavior on buffer overflow.
	mode OverflowMode
	// sync determines if dispatching is sequential rather than concurrent.
	sync bool
	// wait is the idling strategy for the background worker.
	wait WaitStrategy
	// logger is used for reporting errors and panics; nil falls back to
	// slog.Default in NewBus.
	logger *slog.Logger
}
// WithSize sets the buffer capacity (rounded up to the nearest power of 2).
// Defaults to [DefaultSize]. Non-positive values are ignored.
func WithSize(size int) Option {
	return func(c *config) {
		if size <= 0 {
			return // keep the previous capacity
		}
		c.size = size
	}
}
// WithOverflowMode defines how the bus deals with backpressure on buffer
// exhaustion. Defaults to [DefaultOverflowMode].
func WithOverflowMode(mode OverflowMode) Option {
	return func(c *config) { c.mode = mode }
}
// WithSyncDispatch forces sequential event delivery. If omitted, the bus
// defaults to asynchronous parallel delivery.
func WithSyncDispatch() Option {
	return func(c *config) { c.sync = true }
}
// WithAdaptiveWait uses a low-latency spin-yield-sleep strategy. This is the
// default.
func WithAdaptiveWait() Option {
	return func(c *config) { c.wait = adaptiveWait{} }
}
// WithBlockingWait uses a semaphore to park the CPU when idle. Ideal for
// multi-tenant setups.
func WithBlockingWait() Option {
	return func(c *config) {
		c.wait = &blockingWait{sem: make(chan struct{}, 1)}
	}
}
// WithCustomWaitStrategy injects a user-defined idling strategy. Nil values
// are ignored. If passed to a [Broker], the instance is shared across all
// buses the broker creates.
func WithCustomWaitStrategy(strategy WaitStrategy) Option {
	return func(c *config) {
		if strategy == nil {
			return
		}
		c.wait = strategy
	}
}
// WithLogger sets the structured logger for recording subscriber panics. If
// not provided (or nil), it defaults to [slog.Default].
func WithLogger(logger *slog.Logger) Option {
	return func(c *config) {
		if logger == nil {
			return
		}
		c.logger = logger
	}
}
// Bus is a high-performance, strictly-typed event stream. Events are buffered
// in a lock-free ring and delivered to subscribers by a single background
// goroutine started in [NewBus].
type Bus[T any] struct {
	// evts is the underlying lock-free ring buffer.
	evts *ring.Buffer[T]
	// disp is the configured strategy for calling subscriber functions.
	disp dispatcher[T]
	// wait dictates how the processor idles when the buffer is empty.
	wait WaitStrategy
	// subs is a copy-on-write pointer holding the active list of subscribers:
	// writers clone the slice under mu, readers load it atomically.
	subs atomic.Pointer[[]handler[T]]
	// closed indicates whether the bus has been shut down.
	closed atomic.Bool
	// mu protects write operations to the active subscriber list.
	mu sync.Mutex
	// id is an incrementing counter providing unique keys for new subscribers.
	id uint64
	// wg tracks the lifecycle of the background processor goroutine.
	wg sync.WaitGroup
}
// NewBus initializes a [Bus] with the provided options and starts its
// background processor goroutine. Call [Bus.Close] to stop it.
func NewBus[T any](opts ...Option) *Bus[T] {
	cfg := config{
		size: DefaultSize,
		mode: DefaultOverflowMode,
		sync: false,
		wait: adaptiveWait{},
	}
	for _, apply := range opts {
		apply(&cfg)
	}
	if cfg.logger == nil {
		cfg.logger = slog.Default()
	}
	// Pick the delivery strategy according to the sync flag.
	var disp dispatcher[T]
	if cfg.sync {
		disp = basicDispatcher[T]{logger: cfg.logger}
	} else {
		disp = asyncDispatcher[T]{logger: cfg.logger}
	}
	b := &Bus[T]{
		evts: ring.New[T](cfg.size, cfg.mode),
		disp: disp,
		wait: cfg.wait,
	}
	// Seed the atomic pointer with an empty slice so the first load never
	// dereferences nil.
	initial := make([]handler[T], 0)
	b.subs.Store(&initial)
	// Spin up the background processor.
	b.wg.Add(1)
	go b.process()
	return b
}
// Subscribe adds a callback to the bus. It returns an unsubscribe function
// that removes the callback when invoked; calling it more than once is safe.
func (b *Bus[T]) Subscribe(fn Subscriber[T]) (unsubscribe func()) {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.id++
	id := b.id
	// Copy-on-write: clone the current list into a larger backing array and
	// append the new handler.
	old := *b.subs.Load()
	updated := make([]handler[T], 0, len(old)+1)
	updated = append(updated, old...)
	updated = append(updated, handler[T]{id: id, fn: fn})
	// Atomically publish the new slice for the lock-free reader.
	b.subs.Store(&updated)
	// Guarantee the teardown logic only runs exactly once.
	var once sync.Once
	return func() {
		once.Do(func() { b.detach(id) })
	}
}
// detach removes the subscriber matching the given ID from the active list.
func (b *Bus[T]) detach(id uint64) {
	b.mu.Lock()
	defer b.mu.Unlock()
	old := *b.subs.Load()
	// Rebuild into a fresh backing array so the previous array (and the
	// removed handler's function value) becomes garbage-collectable.
	kept := make([]handler[T], 0, len(old))
	for _, sub := range old {
		if sub.id == id {
			continue
		}
		kept = append(kept, sub)
	}
	b.subs.Store(&kept)
}
// Publish pushes an event to the bus. It returns false if the buffer is full
// (and the [DropNewest] policy is active) or if the bus is already closed.
func (b *Bus[T]) Publish(event T) bool {
	// Guard against publishing to a stopped bus.
	if b.closed.Load() {
		return false
	}
	// Attempt to push to the lock-free ring buffer.
	if !b.evts.Push(event) {
		return false
	}
	// Awaken the processor if it happens to be snoozing.
	b.wait.Signal()
	return true
}
// Close signals the background processor to drain remaining events and stop.
// Further calls to [Bus.Publish] will immediately return false.
//
// Note: Producers must be externally synchronized to stop calling Publish
// before Close is invoked to prevent stranded events.
func (b *Bus[T]) Close() {
	// Give straggling producers a few microseconds to finish their push before
	// we officially close the gates.
	time.Sleep(50 * time.Microsecond)
	// Atomically flip to closed; a second Close is a no-op.
	if b.closed.Swap(true) {
		return
	}
	// Wake the processor in case it is parked on a semaphore so it can perform
	// its final drain, then wait for it to exit.
	b.wait.Signal()
	b.wg.Wait()
}
// process continuously polls the ring buffer for new events and dispatches
// them until the bus is closed and fully drained. It runs on its own goroutine
// (started by NewBus) and is the only consumer of the ring buffer.
func (b *Bus[T]) process() {
	defer b.wg.Done()
	idle := 0
	for {
		// Fast path: attempt to pop an event off the lock-free queue.
		if evt, ok := b.evts.Pop(); ok {
			idle = 0 // Reset the backoff counter on success
			// Load a read-only snapshot of the subscribers.
			if handlers := *b.subs.Load(); len(handlers) > 0 {
				b.disp.dispatch(evt, handlers)
			}
		} else {
			// Slow path: queue is empty.
			if b.closed.Load() {
				// The bus was closed. Perform one final exhaustive drain check
				// in case events were published just before the close signal.
				for {
					if final, ok := b.evts.Pop(); ok {
						if handlers := *b.subs.Load(); len(handlers) > 0 {
							b.disp.dispatch(final, handlers)
						}
					} else {
						// Queue is truly empty and bus is closed; exit.
						return
					}
				}
			}
			// Backoff and yield to prevent spinning the CPU at 100% capacity.
			b.wait.Snooze(idle)
			idle++
		}
	}
}
// closer is an internal interface that allows the [Broker] to shut down buses
// without knowing their generic type parameters.
type closer interface {
	// Close signals the resource to shut down.
	Close()
}
// Broker manages a collection of typed event buses segregated by topic
// strings. It is safe for concurrent use.
type Broker struct {
	// mu protects the buses map.
	mu sync.RWMutex
	// buses maps topic names to their underlying typed Bus instances, stored
	// behind the type-erased closer interface.
	buses map[string]closer
	// opts are the default options applied to all buses created by this broker.
	opts []Option
}
// NewBroker initializes an empty event [Broker] with options applied to all
// subsequently created buses.
func NewBroker(opts ...Option) *Broker {
	b := &Broker{opts: opts}
	b.buses = make(map[string]closer)
	return b
}
// Topic retrieves an existing [Bus] for the given topic or creates a new one
// using the broker's configured options. It panics if the topic already exists
// but is registered to a different event type.
func Topic[T any](b *Broker, name string) *Bus[T] {
	// cast asserts the type-erased bus back to the requested generic type.
	cast := func(c closer) *Bus[T] {
		bus, ok := c.(*Bus[T])
		if !ok {
			panic(fmt.Sprintf(
				"event: topic %q exists but expects a different event type",
				name,
			))
		}
		return bus
	}
	// Fast path: read-only lookup.
	b.mu.RLock()
	existing, found := b.buses[name]
	b.mu.RUnlock()
	if found {
		return cast(existing)
	}
	// Slow path: take the write lock to initialize.
	b.mu.Lock()
	defer b.mu.Unlock()
	// Double-checked locking in case another goroutine initialized the topic
	// while we were waiting for the write lock.
	if existing, found = b.buses[name]; found {
		return cast(existing)
	}
	// Create and store the new typed bus.
	bus := NewBus[T](b.opts...)
	b.buses[name] = bus
	return bus
}
// Close gracefully shuts down all buses managed by the broker.
func (b *Broker) Close() {
	// Detach the current bus set and reset the map under the lock so new
	// Topic calls see a fresh broker, then release the lock before closing.
	b.mu.Lock()
	detached := b.buses
	b.buses = make(map[string]closer)
	b.mu.Unlock()
	// Close all buses concurrently so their shutdown grace periods overlap.
	var wg sync.WaitGroup
	wg.Add(len(detached))
	for _, c := range detached {
		go func(c closer) {
			defer wg.Done()
			c.Close()
		}(c)
	}
	wg.Wait()
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package header provides a collection of utility functions for parsing,
// interpreting, and manipulating HTTP headers.
//
// The package includes helpers for common header-related tasks, such as:
// - Parsing comma-separated directives (e.g., "max-age=3600").
// - Parsing wildcard-aware content negotiation headers with q-factors.
// - Parsing RFC 5988 Link headers to extract relations for API pagination.
// - Extracting credentials from an Authorization header.
// - Calculating cache lifetime from Cache-Control and Expires headers.
// - Determining throttle delays from Retry-After and X-Ratelimit-* headers.
//
// It also provides a convenient [http.RoundTripper] implementation for
// automatically attaching a static set of headers to all outgoing requests.
package header
import (
"iter"
"mime"
"net/http"
"strconv"
"strings"
"time"
)
// Directives parses a comma-separated header value into an iterator of
// key-value pairs.
//
// For example, parsing "no-cache, max-age=3600" would yield twice: first
// "no-cache", "" and then "max-age", "3600".
func Directives(s string) iter.Seq2[string, string] {
return func(yield func(string, string) bool) {
for kv := range strings.SplitSeq(s, ",") {
k, v, ok := strings.Cut(strings.TrimSpace(kv), "=")
k = strings.ToLower(strings.TrimSpace(k))
if ok {
v = strings.TrimSpace(v)
}
if !yield(k, v) {
return
}
}
}
}
// Throttle determines the required delay before the next request based on
// rate-limiting headers in the response. It accepts a clock function to
// calculate relative times. If no throttling is indicated, it returns a
// duration of 0.
func Throttle(h http.Header, now func() time.Time) time.Duration {
if v := h.Get("Retry-After"); v != "" {
if d, err := strconv.ParseInt(v, 10, 64); err == nil && d > 0 {
return time.Duration(d) * time.Second
}
if t, err := http.ParseTime(v); err == nil {
if d := t.Sub(now()); d > 0 {
return d
}
}
}
if h.Get("X-Ratelimit-Remaining") == "0" {
if v := h.Get("X-Ratelimit-Reset"); v != "" {
if t, err := strconv.ParseInt(v, 10, 64); err == nil && t > 0 {
if d := time.Unix(t, 0).Sub(now()); d > 0 {
return d
}
}
}
}
return 0
}
// Lifetime determines the cache lifetime of a response based on caching
// headers. It accepts a clock function to calculate relative times. It returns
// a duration of 0 if the response is not cacheable or does not carry any
// caching information.
//
// Per RFC 9111, "no-cache" and "no-store" always win over "max-age",
// regardless of the order in which the directives appear.
func Lifetime(h http.Header, now func() time.Time) time.Duration {
	// Cache-Control takes precedence over Expires.
	if cc := h.Get("Cache-Control"); cc != "" {
		var (
			maxAge    time.Duration
			hasMaxAge bool
		)
		// Scan ALL directives before deciding: returning on the first
		// "max-age" would wrongly cache "max-age=60, no-store" responses.
		for k, v := range Directives(cc) {
			switch k {
			case "no-cache", "no-store":
				return 0
			case "max-age":
				// Record the first parsable max-age; ignore malformed values.
				if d, err := strconv.ParseInt(v, 10, 64); err == nil && !hasMaxAge {
					maxAge = time.Duration(d) * time.Second
					hasMaxAge = true
				}
			}
		}
		if hasMaxAge {
			return maxAge
		}
	}
	if v := h.Get("Expires"); v != "" {
		if t, err := http.ParseTime(v); err == nil {
			if d := t.Sub(now()); d > 0 {
				return d
			}
		}
	}
	return 0
}
// Credentials extracts the credentials from the Authorization header of an
// HTTP request for a specific authentication scheme (e.g., "Basic", "Bearer").
//
// It returns the raw credentials as-is, or an empty string if the header is
// not present, not well-formed, or does not match the specified scheme. The
// scheme comparison is case-insensitive.
func Credentials(h http.Header, scheme string) string {
auth := h.Get("Authorization")
if auth == "" {
return ""
}
prefix, credentials, ok := strings.Cut(auth, " ")
if !ok || !strings.EqualFold(prefix, scheme) {
return ""
}
return credentials
}
// Preferences parses a header value with quality factors (e.g., Accept,
// Accept-Encoding, Accept-Language) into an iterator quality factors (q-value)
// by name (media range). The values are yielded in the order they appear in the
// header, not sorted by quality. Values without an explicit q-factor are
// assigned a default quality of 1.0. Malformed q-factors are also treated as
// 1.0, while out-of-range values are clamped into the [0.0, 1.0] interval.
func Preferences(s string) iter.Seq2[string, float64] {
return func(yield func(string, float64) bool) {
for part := range strings.SplitSeq(s, ",") {
part = strings.TrimSpace(part)
if part == "" {
continue
}
params := strings.Split(part, ";")
q := 1.0
for i := 1; i < len(params); i++ {
p := strings.TrimSpace(params[i])
k, v, found := strings.Cut(p, "=")
if found && strings.TrimSpace(k) == "q" {
v = strings.TrimSpace(v)
if f, err := strconv.ParseFloat(v, 64); err == nil {
q = min(1.0, max(0.0, f))
}
break
}
}
if !yield(strings.TrimSpace(params[0]), q) {
return
}
}
}
}
// Accepts reports whether the given key is accepted according to a header
// value with quality factors (e.g., Accept, Accept-Encoding, or
// Accept-Language). Exact matches outrank partial wildcards (e.g., "text/*"),
// which in turn outrank global wildcards ("*/*" or "*"); the result is true
// if the best-ranked match carries a q-value greater than zero.
func Accepts(s, key string) bool {
	const (
		matchNone = iota
		matchGlobal
		matchPartial
		matchExact
	)
	// Split off the major type (e.g. "text" from "text/html") so partial
	// wildcards like "text/*" can be recognized.
	major, _, compound := strings.Cut(key, "/")
	wildcard := major + "/*"

	best := matchNone
	bestQ := 0.0
	for name, q := range Preferences(s) {
		rank := matchNone
		switch {
		case name == key:
			rank = matchExact
		case compound && name == wildcard:
			rank = matchPartial
		case name == "*/*" || name == "*":
			rank = matchGlobal
		}
		// Keep the quality of the most specific match seen so far.
		if rank > best {
			best = rank
			bestQ = q
		}
	}
	return best > matchNone && bestQ > 0
}
// MediaType extracts and returns the media type from a Content-Type header.
// It returns the media type in lowercase, trimmed of whitespace. If the header
// is empty or malformed, it returns an empty string.
//
// This function is similar to [mime.ParseMediaType] but does not return any
// parameters and ignores parsing errors.
func MediaType(h http.Header) string {
v := h.Get("Content-Type")
if v == "" {
return ""
}
i := strings.IndexByte(v, ';')
if i != -1 {
v = v[:i]
}
return strings.ToLower(strings.TrimSpace(v))
}
// Links parses an RFC 5988 Link header into an iterator of relation types (rel)
// and their corresponding URLs.
//
// If a link has multiple space-separated relations (e.g., rel="next archive"),
// it yields the URL for each relation separately.
func Links(s string) iter.Seq2[string, string] {
return func(yield func(string, string) bool) {
for part := range strings.SplitSeq(s, ",") {
sidx := strings.IndexByte(part, '<')
eidx := strings.IndexByte(part, '>')
// Ensure the URL brackets are present and valid.
if sidx == -1 || eidx == -1 || sidx >= eidx {
continue
}
url := part[sidx+1 : eidx]
// Parse the parameters following the URL.
params := strings.SplitSeq(part[eidx+1:], ";")
for p := range params {
p = strings.TrimSpace(p)
k, v, found := strings.Cut(p, "=")
if found && strings.ToLower(strings.TrimSpace(k)) == "rel" {
// Remove optional quotes around the relation value.
v = strings.Trim(strings.TrimSpace(v), `"`)
// A single link can have multiple relation types.
for rel := range strings.FieldsSeq(v) {
if !yield(strings.ToLower(rel), url) {
return
}
}
}
}
}
}
}
// Link extracts the URL for a specific relation (e.g., "next" or "last") from
// a Link header. It returns an empty string if the relation is not found.
func Link(s, rel string) string {
	want := strings.ToLower(rel)
	for relation, target := range Links(s) {
		if relation == want {
			return target
		}
	}
	return ""
}
// Filename extracts the intended filename from a Content-Disposition header.
//
// It automatically handles both the standard "filename" parameter and the
// RFC 6266 "filename*" parameter, which is used for non-ASCII (UTF-8) names.
// It returns an empty string if the header is missing, malformed, or does
// not contain a filename.
func Filename(h http.Header) string {
v := h.Get("Content-Disposition")
if v == "" {
return ""
}
_, params, err := mime.ParseMediaType(v)
if err != nil {
return ""
}
// The filename* parameter is decoded automatically.
return params["filename"]
}
// Header represents a single HTTP header key-value pair.
type Header struct {
	// Key is the canonicalized header name.
	Key string
	// Value is the raw value of the header.
	Value string
}

// String formats the header as "Key: Value".
func (h Header) String() string {
	var b strings.Builder
	b.WriteString(h.Key)
	b.WriteString(": ")
	b.WriteString(h.Value)
	return b.String()
}

// New creates a new [Header] with the given key and value. The key is
// automatically canonicalized to the standard HTTP header format.
func New(key, value string) Header {
	canonical := http.CanonicalHeaderKey(key)
	return Header{Key: canonical, Value: value}
}

// UserAgent constructs a User-Agent header with the specified name, version,
// and an optional comment, formatted as "name/version (comment)". The leading
// part is the product token; the parenthesized section carries supplementary
// client information. For external calls, it is best practice to include
// maintainer contact details (such as a URL or email address) in the comment.
func UserAgent(name, version, comment string) Header {
	var b strings.Builder
	b.WriteString(name)
	b.WriteByte('/')
	b.WriteString(version)
	if comment != "" {
		b.WriteString(" (")
		b.WriteString(comment)
		b.WriteByte(')')
	}
	return Header{Key: "User-Agent", Value: b.String()}
}
// transport is an internal [http.RoundTripper] that injects static headers.
//
// Instances are created by [NewTransport]; the zero value is unusable because
// wrapped would be nil.
type transport struct {
	// wrapped is the underlying RoundTripper that performs the actual call.
	wrapped http.RoundTripper
	// headers are the static headers to be injected into each request.
	headers []Header
}
// RoundTrip clones the request and adds static headers before delegating.
// Working on a copy keeps the caller's request unmodified, as required by
// the [http.RoundTripper] contract.
func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
	out := req.Clone(req.Context())
	for _, header := range t.headers {
		out.Header.Set(header.Key, header.Value)
	}
	return t.wrapped.RoundTrip(out)
}

// Ensure transport satisfies the [http.RoundTripper] interface.
var _ http.RoundTripper = (*transport)(nil)
// NewTransport wraps a base transport and sets a static set of headers on
// each outgoing request. If no headers are provided, the base transport is
// returned unmodified. The function creates a defensive copy of the provided
// headers, so mutating the caller's slice afterwards does not affect the
// transport. The resulting transport clones the request before delegating to
// the base transport, so the original request is not changed.
func NewTransport(
	t http.RoundTripper,
	headers ...Header,
) http.RoundTripper {
	if len(headers) == 0 {
		return t
	}
	// Copy the variadic slice so callers cannot mutate the injected headers
	// after construction; the previous implementation stored the caller's
	// slice directly despite documenting a defensive copy.
	copied := make([]Header, len(headers))
	copy(copied, headers)
	return &transport{
		wrapped: t,
		headers: copied,
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package check provides a collection of standard health check constructors for
// common infrastructure dependencies.
//
// It includes implementations for TCP connectivity, HTTP responsiveness, DNS
// resolution, and database pings. These functions return a [health.CheckFunc]
// that can be registered with a [health.Monitor] to automate dependency
// monitoring.
//
// # Usage
//
// The constructors in this package are designed to be passed directly into the
// Attach method of a [health.Monitor].
//
// Example:
//
// monitor := health.NewMonitor()
//
// // Check a Redis instance via TCP
// monitor.Attach(
// "redis",
// 2*time.Second,
// check.TCP("localhost:6379", 1*time.Second),
// )
//
// // Check an external API with a custom HTTP client
// monitor.Attach(
// "stripe",
// 10*time.Second,
// check.HTTP(client, "https://api.stripe.com/health"),
// )
package check
import (
"context"
"fmt"
"io"
"net"
"net/http"
"time"
"github.com/deep-rent/nexus/health"
)
// TCP returns a health check that attempts to establish a TCP connection
// to the specified address.
//
// It reports [health.StatusSick] when no connection can be established within
// the provided timeout.
func TCP(addr string, timeout time.Duration) health.CheckFunc {
	return func(ctx context.Context) (health.Status, error) {
		dialer := net.Dialer{Timeout: timeout}
		conn, err := dialer.DialContext(ctx, "tcp", addr)
		if err != nil {
			return health.StatusSick, fmt.Errorf("tcp dial %s: %w", addr, err)
		}
		// A successful dial is the whole check; release the connection
		// immediately.
		_ = conn.Close()
		return health.StatusHealthy, nil
	}
}
// HTTP returns a health check that performs a GET request to the specified URL.
//
// A nil client falls back to [http.DefaultClient]. The check:
//
//  1. applies a 10-second fallback timeout when neither the client nor the
//     request context enforces a deadline;
//  2. drains and closes the response body so the underlying TCP connection
//     can be reused;
//  3. treats any 2xx or 3xx status code as healthy.
func HTTP(client *http.Client, url string) health.CheckFunc {
	const defaultTimeout = 10 * time.Second
	if client == nil {
		client = http.DefaultClient
	}
	return func(ctx context.Context) (health.Status, error) {
		reqCtx := ctx
		// Enforce the fallback timeout only when no other deadline applies,
		// scoping it to this single check execution via the context.
		_, hasDeadline := ctx.Deadline()
		if !hasDeadline && client.Timeout == 0 {
			var cancel context.CancelFunc
			reqCtx, cancel = context.WithTimeout(ctx, defaultTimeout)
			defer cancel()
		}
		req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, url, nil)
		if err != nil {
			return health.StatusSick, fmt.Errorf("http request %s: %w", url, err)
		}
		res, err := client.Do(req)
		if err != nil {
			return health.StatusSick, fmt.Errorf("http get %s: %w", url, err)
		}
		// Drain the body fully before closing so the connection is reusable.
		defer func() {
			_, _ = io.Copy(io.Discard, res.Body)
			_ = res.Body.Close()
		}()
		if code := res.StatusCode; code < http.StatusOK || code >= http.StatusBadRequest {
			return health.StatusSick, fmt.Errorf(
				"http get %s: unexpected status code: %d",
				url, code,
			)
		}
		return health.StatusHealthy, nil
	}
}
// Pinger is an interface for types that support context-aware connectivity
// checks.
//
// This is most commonly satisfied by [*database/sql.DB] from the standard
// library.
type Pinger interface {
	// PingContext verifies a connection to the target system is still alive.
	// A non-nil error is interpreted as an unhealthy dependency by [Ping].
	PingContext(ctx context.Context) error
}
// Ping returns a health check that calls PingContext on the provided [Pinger].
//
// It is ideal for monitoring the health of SQL database connections.
func Ping(p Pinger) health.CheckFunc {
	return func(ctx context.Context) (health.Status, error) {
		err := p.PingContext(ctx)
		if err != nil {
			return health.StatusSick, err
		}
		return health.StatusHealthy, nil
	}
}
// DNS returns a health check that verifies the provided host resolves
// to at least one IP address using the default system resolver.
func DNS(host string) health.CheckFunc {
	return func(ctx context.Context) (health.Status, error) {
		if _, err := net.DefaultResolver.LookupHost(ctx, host); err != nil {
			return health.StatusSick, fmt.Errorf("dns lookup %s: %w", host, err)
		}
		return health.StatusHealthy, nil
	}
}
// Wrap converts a simple function that returns an error into a health check
// callback.
//
// The resulting check is not context-aware: the context passed at execution
// time is ignored.
func Wrap(fn func() error) health.CheckFunc {
	return func(_ context.Context) (health.Status, error) {
		err := fn()
		if err != nil {
			return health.StatusSick, err
		}
		return health.StatusHealthy, nil
	}
}
// WrapContext converts a context-aware function into a health check callback.
//
// Use this for custom checks that must respect timeouts or cancellation
// signals provided by the [health.Monitor].
func WrapContext(fn func(context.Context) error) health.CheckFunc {
	return func(ctx context.Context) (health.Status, error) {
		err := fn(ctx)
		if err != nil {
			return health.StatusSick, err
		}
		return health.StatusHealthy, nil
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package health provides a registry and HTTP handlers for application health
// monitoring.
//
// It allows for the registration of pluggable health checks with built-in
// TTL-based caching to prevent overloading downstream dependencies. The package
// handles the orchestration of these checks, providing thread-safe execution
// and aggregation of results into standardized [Report] formats suitable for
// automated monitoring systems and human inspection.
//
// # Usage
//
// To use the health monitor, create a new instance, attach your dependency
// checks, and mount the handlers to your router.
//
// Example:
//
// monitor := health.NewMonitor()
//
// // Register a check with a 5-second minimum delay between invocations.
// monitor.Attach("database", 5*time.Second, check.Ping(db))
//
// // Mount the standard endpoints (/health, /health/live, /health/ready)
// // to a router.Router instance.
// monitor.Mount(r)
package health
import (
"context"
"encoding/json/v2"
"fmt"
"net/http"
"sync"
"time"
"github.com/deep-rent/nexus/router"
)
// Status enumerates the operational states of a dependency.
//
// Note that statuses are ranked by severity: StatusHealthy > StatusDegraded >
// StatusSick. This allows for direct comparison using standard operators.
type Status int

const (
	// StatusSick indicates the dependency is non-functional.
	// It is also the zero value of [Status].
	StatusSick Status = iota
	// StatusDegraded indicates the dependency is functioning but with
	// issues (e.g., high latency).
	StatusDegraded
	// StatusHealthy indicates the dependency is functioning normally.
	StatusHealthy
)
// String returns the human-readable representation of the [Status].
func (s Status) String() string {
	switch s {
	case StatusSick:
		return "sick"
	case StatusDegraded:
		return "degraded"
	case StatusHealthy:
		return "healthy"
	}
	// Out-of-range values (e.g. produced by a bad cast) end up here.
	return "unknown"
}
// MarshalJSON implements the [json.Marshaler] interface, ensuring that the
// status is represented by its string name in JSON output rather than its
// underlying integer value.
func (s Status) MarshalJSON() ([]byte, error) {
return json.Marshal(s.String())
}
// UnmarshalJSON implements the [json.Unmarshaler] interface. It converts a
// JSON string back into the corresponding [Status] integer constant. It returns
// an error if the string is not a recognized status.
func (s *Status) UnmarshalJSON(data []byte) error {
var v string
if err := json.Unmarshal(data, &v); err != nil {
return err
}
switch v {
case "healthy":
*s = StatusHealthy
case "degraded":
*s = StatusDegraded
case "sick":
*s = StatusSick
default:
return fmt.Errorf("invalid status: %s", v)
}
return nil
}
// Result holds the outcome of a health check execution.
type Result struct {
	// Status is the state of the check.
	Status Status `json:"status"`
	// Error contains a descriptive error message if the check failed.
	// It is omitted from JSON output when empty.
	Error string `json:"error,omitempty"`
	// Timestamp records when this check was actually executed (not when the
	// cached result was served). It is serialized as a Unix timestamp.
	Timestamp time.Time `json:"timestamp,format:unix"`
}
// Report represents the aggregated outcome of all registered health checks.
type Report struct {
	// Status is the overall health state of the application: the most
	// severe status among all checks.
	Status Status `json:"status"`
	// Checks maps the name of each registered check to its specific [Result].
	Checks map[string]Result `json:"checks"`
}
// CheckFunc defines the signature for a pluggable health check. It receives
// the request context to allow for cancellation and should return the
// perceived [Status] and an error if applicable.
//
// When a non-nil error is returned with any status other than
// [StatusDegraded], the monitor records the result as [StatusSick].
type CheckFunc func(ctx context.Context) (Status, error)
// check wraps a registered check with its caching state and mutex.
type check struct {
	// name is the identifier for the health check.
	name string
	// fn is the [CheckFunc] to execute.
	fn CheckFunc
	// ttl is the duration for which the [Result] is considered fresh.
	ttl time.Duration
	// mu protects access to the cached last result.
	mu sync.RWMutex
	// last is the most recently recorded [Result]. Its zero Timestamp
	// ensures the very first run call always executes fn.
	last Result
}
// run executes the check or returns the cached result if the TTL hasn't
// expired. It protects against panics in the callback.
func (c *check) run(ctx context.Context) (res Result) {
	// Fast path: serve the cached result under a read lock while it is
	// still fresh. The value is copied before releasing the lock.
	c.mu.RLock()
	if time.Since(c.last.Timestamp) < c.ttl {
		res := c.last
		c.mu.RUnlock()
		return res
	}
	c.mu.RUnlock()
	c.mu.Lock()
	defer c.mu.Unlock()
	// Double-check the cache after acquiring the write lock: another
	// goroutine may have refreshed the result in the meantime.
	if time.Since(c.last.Timestamp) < c.ttl {
		return c.last
	}
	// Convert a panicking callback into a sick result instead of crashing
	// the process. Assigning the named return res here takes effect even
	// though the panic aborted the normal control flow.
	defer func() {
		if r := recover(); r != nil {
			c.last = Result{
				Status:    StatusSick,
				Error:     fmt.Sprintf("health check panicked: %v", r),
				Timestamp: time.Now(),
			}
			res = c.last
		}
	}()
	status, err := c.fn(ctx)
	msg := ""
	if err != nil {
		msg = err.Error()
		// Default to sick if an error occurs but the status wasn't explicitly set
		// to degraded.
		if status != StatusDegraded {
			status = StatusSick
		}
	}
	// Cache the fresh result; the write lock is still held here.
	c.last = Result{
		Status:    status,
		Error:     msg,
		Timestamp: time.Now(),
	}
	return c.last
}
// Monitor manages the registry of health checks and provides the
// [router]-compatible handlers. It is safe for concurrent use.
//
// Use [NewMonitor] to create instances; the zero value has a nil checks map
// and would panic on [Monitor.Attach].
type Monitor struct {
	// mu protects access to the internal map of checks.
	mu sync.RWMutex
	// checks stores registered health checks indexed by name.
	checks map[string]*check
}
// NewMonitor creates a fresh [Monitor] instance with an empty check registry.
func NewMonitor() *Monitor {
	m := new(Monitor)
	m.checks = make(map[string]*check)
	return m
}
// Attach registers a new health check under the given name, replacing any
// existing check with the same name.
//
// Names should use snake_case (e.g., "redis_primary"). The TTL (Time-To-Live)
// defines the minimum duration between consecutive executions of the
// [CheckFunc]; calls within that window serve the cached [Result] so the
// dependency is not overloaded.
func (m *Monitor) Attach(name string, ttl time.Duration, fn CheckFunc) {
	entry := &check{name: name, fn: fn, ttl: ttl}
	m.mu.Lock()
	m.checks[name] = entry
	m.mu.Unlock()
}
// Detach unregisters a health check by name. Detaching an unknown name is a
// no-op.
func (m *Monitor) Detach(name string) {
	m.mu.Lock()
	delete(m.checks, name)
	m.mu.Unlock()
}
// run runs all registered checks concurrently and compiles the overall
// [Status] from the gathered results.
func (m *Monitor) run(ctx context.Context) (Status, map[string]Result) {
	// Snapshot the registry under the read lock so checks can execute
	// without holding it.
	m.mu.RLock()
	snapshot := make([]*check, 0, len(m.checks))
	for _, c := range m.checks {
		snapshot = append(snapshot, c)
	}
	m.mu.RUnlock()

	var (
		wg      sync.WaitGroup
		mu      sync.Mutex
		overall = StatusHealthy
		results = make(map[string]Result, len(snapshot))
	)
	for _, c := range snapshot {
		wg.Add(1)
		go func(ck *check) {
			defer wg.Done()
			r := ck.run(ctx)
			mu.Lock()
			defer mu.Unlock()
			results[ck.name] = r
			// The overall status is the most severe (lowest) one observed.
			if r.Status < overall {
				overall = r.Status
			}
		}(c)
	}
	wg.Wait()
	return overall, results
}
// Live returns a handler that indicates if the application process is alive.
// It always reports [StatusHealthy] with HTTP 200 and never consults the
// registered checks, serving as a basic liveness probe to detect process
// hangs.
func (m *Monitor) Live() router.HandlerFunc {
	// The payload is constant, so it can be built once up front.
	report := Report{Status: StatusHealthy}
	return func(e *router.Exchange) error {
		return e.JSON(http.StatusOK, report)
	}
}
// Ready returns a handler that evaluates all registered checks. It responds
// with HTTP 503 (Service Unavailable) if any check results in [StatusSick],
// and HTTP 200 otherwise.
func (m *Monitor) Ready() router.HandlerFunc {
	return func(e *router.Exchange) error {
		status, checks := m.run(e.Context())
		code := http.StatusOK
		if status == StatusSick {
			code = http.StatusServiceUnavailable
		}
		report := Report{Status: status, Checks: checks}
		return e.JSON(code, report)
	}
}
// Handler is an alias for [Monitor.Ready]. It provides a detailed JSON
// breakdown of all checks, suitable for monitoring scrapers and dashboards.
// [Monitor.Mount] serves it at GET /health.
func (m *Monitor) Handler() router.HandlerFunc {
	return m.Ready()
}
// Mount registers the standard health check routes on the provided
// [router.Router].
//
// It exposes:
//   - GET /health: Detailed summary of all checks.
//   - GET /health/live: Shallow liveness probe.
//   - GET /health/ready: Deep readiness probe.
func (m *Monitor) Mount(r *router.Router) {
	routes := []struct {
		pattern string
		handler router.HandlerFunc
	}{
		{"GET /health", m.Handler()},
		{"GET /health/live", m.Live()},
		{"GET /health/ready", m.Ready()},
	}
	for _, route := range routes {
		r.HandleFunc(route.pattern, route.handler)
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package ascii provides fast, rune-based classification and conversion
// functions specifically for ASCII characters.
//
// It is designed as a lightweight alternative to the standard [unicode] package
// for cases where only basic ASCII support is required. By focusing strictly on
// the ASCII range, it avoids the overhead of large Unicode lookup tables,
// making it suitable for high-performance parsing and validation tasks.
//
// # Usage
//
// You can use the classification functions to validate runes or conversion
// functions to shift casing.
//
// Example:
//
// r := 'A'
// if ascii.IsUpper(r) {
// lower := ascii.ToLower(r) // 'a'
// }
package ascii
// IsUpper reports whether the rune is an uppercase ASCII letter
// ('A' through 'Z').
func IsUpper(c rune) bool { return 'A' <= c && c <= 'Z' }

// IsLower reports whether the rune is a lowercase ASCII letter
// ('a' through 'z').
func IsLower(c rune) bool { return 'a' <= c && c <= 'z' }

// IsDigit reports whether the rune is an ASCII decimal digit
// ('0' through '9').
func IsDigit(c rune) bool { return '0' <= c && c <= '9' }

// IsAlpha reports whether the rune is an ASCII letter (uppercase or lowercase).
func IsAlpha(c rune) bool { return IsLower(c) || IsUpper(c) }

// IsAlphaNum reports whether the rune is an ASCII letter or decimal digit.
func IsAlphaNum(c rune) bool { return IsDigit(c) || IsAlpha(c) }

// IsHex reports whether the given rune is a hexadecimal character.
func IsHex(c rune) bool {
	switch {
	case IsDigit(c):
		return true
	case 'a' <= c && c <= 'f':
		return true
	case 'A' <= c && c <= 'F':
		return true
	}
	return false
}

// IsWord reports whether the rune is an ASCII letter, digit, or underscore
// ('_').
//
// This is commonly used for validating variable names or identifiers.
func IsWord(c rune) bool { return c == '_' || IsAlphaNum(c) }

// IsSlug reports whether the rune is an ASCII letter, digit, or hyphen ('-').
//
// This is commonly used for validating URL path components.
func IsSlug(c rune) bool { return c == '-' || IsAlphaNum(c) }

// caseDelta is the fixed distance between corresponding lowercase and
// uppercase ASCII letters.
const caseDelta = 'a' - 'A'

// ToLower converts an uppercase ASCII rune to lowercase.
//
// Runes that are not uppercase letters pass through unchanged.
func ToLower(c rune) rune {
	if !IsUpper(c) {
		return c
	}
	return c + caseDelta
}

// ToUpper converts a lowercase ASCII rune to uppercase.
//
// Runes that are not lowercase letters pass through unchanged.
func ToUpper(c rune) rune {
	if !IsLower(c) {
		return c
	}
	return c - caseDelta
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package buffer provides a sync.Pool-backed implementation of
// [httputil.BufferPool] for reusing byte slices.
//
// It is designed to save memory and reduce GC pressure when dealing with large
// response bodies by recycling memory buffers. This implementation helps
// stabilize heap usage in high-throughput proxies or servers that frequently
// allocate temporary buffers for I/O operations.
//
// # Usage
//
// Create a new [Pool] and pass it to a reverse proxy or any component requiring
// a [httputil.BufferPool].
//
// Example:
//
// // Create a pool with 32KB initial buffers, capped at 1MB for reuse.
// pool := buffer.NewPool(32*1024, 1024*1024)
//
// proxy := &httputil.ReverseProxy{
// Director: director,
// BufferPool: pool,
// }
package buffer
import (
"net/http/httputil"
"sync"
)
// Pool implements [httputil.BufferPool] backed by a [sync.Pool] internally.
//
// Reusing byte slices for large response bodies cuts down on allocations and
// GC pressure.
type Pool struct {
	// pool is the underlying [sync.Pool] storing buffer pointers.
	pool sync.Pool
	// size is the maximum capacity of a buffer allowed back into the pool.
	size int
}

// NewPool creates a new [Pool] that returns buffers of at least minSize bytes.
//
// Buffers that grow beyond maxSize are discarded during [Pool.Put]. Both
// numbers must be positive, or else the function panics; minSize is clamped
// by maxSize.
func NewPool(minSize, maxSize int) *Pool {
	switch {
	case minSize <= 0:
		panic("buffer: minSize must be positive")
	case maxSize <= 0:
		panic("buffer: maxSize must be positive")
	}
	if minSize > maxSize {
		minSize = maxSize
	}
	p := &Pool{size: maxSize}
	// Pointers to slices are stored so no extra allocation occurs when the
	// slice header travels through the interface-typed pool.
	p.pool.New = func() any {
		b := make([]byte, minSize)
		return &b
	}
	return p
}

// Get returns a reusable byte slice from the [Pool].
func (b *Pool) Get() []byte {
	return *b.pool.Get().(*[]byte)
}

// Put returns the buffer to the [Pool] unless it grew beyond the size limit.
//
// Buffers whose capacity exceeds the maxSize given at construction are
// dropped so the GC can reclaim them, preventing the pool from pinning
// excessively large slices.
func (b *Pool) Put(buf []byte) {
	if cap(buf) > b.size {
		// Let oversized buffers be collected instead of hoarding them.
		return
	}
	b.pool.Put(&buf)
}

// Ensure Pool satisfies the [httputil.BufferPool] interface.
var _ httputil.BufferPool = (*Pool)(nil)
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package jitter provides functionality for adding random variation (jitter) to
// time durations.
//
// This package is designed to help distributed systems avoid "thundering herd"
// problems by desynchronizing retry attempts or periodic jobs. The jitter
// implementation is "subtractive". It calculates a duration randomly chosen
// between [d * (1 - p), d], where p is the jitter percentage. This ensures that
// the returned duration never exceeds the input duration, allowing strict
// adherence to maximum delay limits (e.g., in backoff strategies).
//
// # Usage
//
// Create a [Jitter] instance with a specific percentage and apply it to your
// base durations.
//
// Example:
//
// // Create a jitterer with 20% randomness.
// j := jitter.New(0.2, nil)
//
// // A 10s duration will result in a random value between 8s and 10s.
// d := j.Apply(10 * time.Second)
package jitter
import (
"math/rand/v2"
"time"
)
// Rand serves as a minimal facade over [rand.Rand] to ease mocking.
type Rand interface {
// Float64 generates a pseudo-random number in [0.0, 1.0).
Float64() float64
}
// Ensure compliance with parent interface.
var _ Rand = (*rand.Rand)(nil)
// seeded is a pre-seeded [Rand] instance for default use.
//
// Note: Go 1.20+ auto-seeds the global RNG, which spares us from time-based
// seeding.
var seeded Rand = rand.New(rand.NewPCG(rand.Uint64(), rand.Uint64())) //nolint:gosec
// Jitter applies subtractive random jitter to a duration.
type Jitter struct {
	// p is the jitter percentage between 0.0 and 1.0.
	p float64
	// r is the random number generator source.
	r Rand
}

// New creates a new [Jitter] instance with the given percentage p (0.0 to 1.0)
// and source of randomness r.
//
// If r is nil, a default thread-safe, seeded generator is used.
func New(p float64, r Rand) *Jitter {
	j := &Jitter{p: p, r: r}
	if j.r == nil {
		// Fall back to the package-level default source.
		j.r = seeded
	}
	return j
}

// Apply returns the duration d damped by a random amount based on the jitter
// percentage.
//
// The result is guaranteed to be in the range [[Jitter.Floor](d, 1.0), d].
func (j *Jitter) Apply(d time.Duration) time.Duration {
	f := j.r.Float64()
	return j.Floor(d, f)
}

// Floor returns the minimum possible duration that [Jitter.Apply] could return
// for the given input d when provided a random factor f.
//
// While typically used internally with f as a random float, passing f = 1.0
// yields the absolute lower bound of the jittered duration.
func (j *Jitter) Floor(d time.Duration, f float64) time.Duration {
	scale := 1 - f*j.p
	return time.Duration(float64(d) * scale)
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package pointer provides reflection-based helpers for working with pointers.
//
// These functions are useful for dynamically allocating and dereferencing
// variables in contexts like configuration loading or data mapping. By using
// the [reflect] package, this package allows for the manipulation of types
// where the concrete structure is not known at compile time.
//
// # Usage
//
// The primary helpers allow for safe allocation and deep dereferencing of
// [reflect.Value] types.
//
// Example:
//
// var str *string
// rv := reflect.ValueOf(&str).Elem()
//
// // Allocates a new string and sets the pointer
// pointer.Alloc(rv)
//
// // Deeply dereferences even nested pointers like **int
// final := pointer.Deref(rv)
package pointer
import "reflect"
// Alloc allocates a new value for a nil pointer and sets the pointer to it.
//
// This function causes a panic if rv is not a settable pointer. It uses
// [reflect.New] to create a zero value of the pointer's element type and
// applies it to the provided [reflect.Value].
func Alloc(rv reflect.Value) {
rv.Set(reflect.New(rv.Type().Elem()))
}
// Deref follows pointers until it reaches a non-pointer, allocating if nil.
//
// If a nil pointer is encountered along the way, [Deref] will attempt to
// allocate a new value for it using [Alloc]. If it encounters an un-settable
// nil pointer (e.g., one within an unexported struct field), it stops and
// returns that pointer to prevent a panic. The final, non-pointer value is
// returned as a [reflect.Value].
func Deref(rv reflect.Value) reflect.Value {
// Loop through multi-level pointers to handle cases like **int.
for rv.Kind() == reflect.Pointer {
if rv.IsNil() {
// If the pointer is nil but cannot be set, we must stop
// here to avoid a panic.
if !rv.CanSet() {
break
}
Alloc(rv)
}
rv = rv.Elem()
}
return rv
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package quote provides utility functions for working with quoted strings.
//
// It offers a suite of tools for detecting, applying, and stripping single or
// double quotes from string data. These utilities are particularly helpful
// when parsing configuration files, processing CLI arguments, or normalizing
// user input where string literals may be wrapped in various quote styles.
//
// # Usage
//
// The package supports both single-layer operations and recursive unquoting.
//
// Example:
//
// // Remove a single layer
// s := quote.Remove(`"hello"`) // returns: hello
//
// // Remove nested layers
// s = quote.RemoveAll(`"'nested'"`) // returns: nested
//
// // Wrap content
// s = quote.Double("content") // returns: "content"
package quote
// Remove strips a single layer of surrounding single or double quotes from a
// string.
//
// If the string is not quoted or is too short to contain a matching pair, it is
// returned unchanged.
func Remove(s string) string {
	// A quoted string needs at least an opening and a closing character.
	if len(s) < 2 {
		return s
	}
	first, last := s[0], s[len(s)-1]
	// Only strip when both ends carry the same quote character.
	if (first == '"' || first == '\'') && first == last {
		return s[1 : len(s)-1]
	}
	// No matching pair: hand back the input untouched.
	return s
}
// RemoveAll strips all layers of surrounding quotes from a string, regardless
// of quote type mixing (e.g., "'hello'" becomes hello).
//
// It applies [Remove] repeatedly until the result stops changing.
func RemoveAll(s string) string {
	for {
		next := Remove(s)
		if next == s {
			// Fixed point reached: nothing left to strip.
			return s
		}
		s = next
	}
}
// Has returns true if the string is surrounded by a matching pair of single or
// double quotes.
func Has(s string) bool {
	// Need room for both the opening and the closing quote.
	if len(s) >= 2 && (s[0] == '"' || s[0] == '\'') {
		return s[len(s)-1] == s[0]
	}
	return false
}
// Wrap surrounds the given string with the specified quote character.
//
// Note: It does not escape existing quotes inside the string; the quote rune
// is simply concatenated on both sides of the content.
func Wrap(s string, q rune) string {
	return string(q) + s + string(q)
}
// Double wraps a string in double quotes without escaping its contents.
func Double(s string) string {
	return "\"" + s + "\""
}
// Single wraps a string in single quotes without escaping its contents.
func Single(s string) string {
	return "'" + s + "'"
}
// Is checks if the given rune is a single or double quote character.
func Is(c rune) bool {
	switch c {
	case '"', '\'':
		return true
	}
	return false
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package ring provides a generic, lock-free ring buffer designed for
// high-throughput concurrent queues.
//
// It relies on atomic compare-and-swap operations to manage read and write
// positions, completely avoiding mutex bottlenecks during high-load scenarios.
// The buffer's capacity is strictly enforced as a power of two, allowing for
// highly efficient bitwise operations when calculating array indices.
//
// # Usage
//
// To use the ring buffer, initialize it with a size and an overflow [Policy],
// then use Push and Pop for concurrent data exchange.
//
// Example:
//
// rb := ring.New[int](64, ring.DropOldest)
//
// // Add an item to the queue
// rb.Push(42)
//
// // Retrieve the item
// if val, ok := rb.Pop(); ok {
// fmt.Println(val) // Output: 42
// }
package ring
import (
"math/bits"
"runtime"
"sync/atomic"
)
// Policy dictates how the buffer behaves when a producer attempts to push into
// a queue that has reached its maximum capacity (overflow).
//
// The zero value is [Block].
type Policy int
const (
	// Block causes the producer to yield the processor to other goroutines (via
	// [runtime.Gosched]) until space becomes available.
	Block Policy = iota
	// DropOldest forcefully advances the read pointer, discarding the oldest
	// unread item in the buffer to make room for the newly pushed item.
	DropOldest
	// DropNewest immediately discards the incoming item being pushed, returning
	// false and leaving the existing buffer contents unchanged.
	DropNewest
)
// Buffer represents a bounded, lock-free, strongly-typed concurrent queue.
//
// Instances must be created with [New]; the zero value has no storage and is
// not usable. Invariant: len(data) == len(seq) == mask+1, a power of two.
type Buffer[T any] struct {
	// data holds the underlying circular storage for the buffer items.
	// Its length is always a power of two.
	data []T
	// seq holds the sequence numbers for each slot to prevent read-before-write
	// race conditions in concurrent MPMC scenarios. A slot at index i is
	// published once seq[i] equals the claiming tail counter plus one.
	seq []uint64
	// head is a monotonically increasing counter representing the read index.
	// The actual array index is calculated as (head & mask).
	head uint64
	// tail is a monotonically increasing counter representing the write index.
	// The actual array index is calculated as (tail & mask).
	tail uint64
	// mask is used to perform a bitwise AND operation (tail & mask) to
	// wrap the counters around the buffer size efficiently.
	// It is equal to (capacity - 1).
	mask uint64
	// policy defines the behavior of the [Buffer.Push] operation when the
	// difference between tail and head reaches the buffer capacity.
	policy Policy
}
// New creates a [Buffer] configured with the requested size and overflow
// [Policy].
//
// Sizes below 2 are raised to 2. The final capacity is always rounded up to
// the next power of two so that index wrapping reduces to a single bitwise
// AND with [Buffer.mask].
func New[T any](size int, policy Policy) *Buffer[T] {
	if size < 2 {
		size = 2
	}
	// bits.Len(size-1) is the exponent of the smallest power of two >= size.
	capacity := uint(1) << bits.Len(uint(size-1))
	return &Buffer[T]{
		data:   make([]T, capacity),
		seq:    make([]uint64, capacity),
		mask:   uint64(capacity - 1),
		policy: policy,
	}
}
// Push adds an item to the tail of the buffer using atomic operations.
//
// It returns true if the item was successfully written. If the buffer is full
// and configured with the [DropNewest] policy, it safely discards the item and
// returns false. For the [Block] policy, it will wait for space by calling
// [runtime.Gosched].
func (b *Buffer[T]) Push(item T) bool {
	for {
		// Snapshot both counters. They are loaded separately, so the pair
		// may be slightly stale; the CAS below keeps the operation safe by
		// forcing a retry if tail moved in the meantime.
		head := atomic.LoadUint64(&b.head)
		tail := atomic.LoadUint64(&b.tail)
		capacity := b.mask + 1
		// 1. Check if the buffer is full.
		if tail-head >= capacity {
			switch b.policy {
			case DropNewest:
				return false // Discard the incoming event
			case DropOldest:
				// Try to advance head to invalidate the oldest item.
				// If CAS fails, another goroutine already changed head; loop and retry.
				// NOTE(review): a consumer may be mid-read of the dropped
				// slot at this moment; confirm this race is acceptable for
				// the intended use cases.
				atomic.CompareAndSwapUint64(&b.head, head, head+1)
				continue
			case Block:
				// Yield execution to the scheduler to allow consumers to catch up.
				runtime.Gosched()
				continue
			}
		}
		// 2. Try to claim the tail slot.
		if atomic.CompareAndSwapUint64(&b.tail, tail, tail+1) {
			// 3. Write data to the claimed slot. Consumers will not read it
			// until the sequence number below is published.
			b.data[tail&b.mask] = item
			// 4. Publish the write by updating the sequence number.
			atomic.StoreUint64(&b.seq[tail&b.mask], tail+1)
			return true
		}
		// CAS failed: another producer claimed the slot first; loop and retry.
	}
}
// Pop retrieves and removes the oldest item from the head of the buffer.
//
// It returns the generic item and true on success. If the buffer is currently
// empty, it returns the zero-value of type T and false. This method is safe
// for concurrent use by multiple consumers.
func (b *Buffer[T]) Pop() (T, bool) {
	var zero T // Used to return a zero-value on failure
	for {
		head := atomic.LoadUint64(&b.head)
		tail := atomic.LoadUint64(&b.tail)
		// 1. Check if the buffer is empty.
		if head == tail {
			return zero, false
		}
		// 2. Ensure the producer has finished writing to this slot.
		// If the sequence doesn't match head+1, it means the producer
		// claimed the tail but hasn't published the write yet, or we
		// are reading a stale head.
		if atomic.LoadUint64(&b.seq[head&b.mask]) != head+1 {
			runtime.Gosched()
			continue
		}
		// 3. Read the data BEFORE advancing the head pointer.
		// NOTE(review): with DropOldest producers, the slot could in theory
		// be overwritten between this read and the CAS below; verify this
		// tear is tolerable for the intended workloads.
		item := b.data[head&b.mask]
		// 4. Try to commit the read. Success makes this consumer the sole
		// owner of the item just copied out.
		if atomic.CompareAndSwapUint64(&b.head, head, head+1) {
			return item, true
		}
		// CAS failed: another consumer popped the item first; loop and retry.
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package rotor provides a thread-safe, generic type for rotating through a
// slice of items in a round-robin fashion.
//
// This package is intended for load-balancing scenarios, such as selecting
// backends, rotating API keys, or distributing tasks across a pool of workers.
// The implementation uses atomic operations to ensure high performance under
// concurrent access without the need for heavy-weight mutexes.
//
// # Usage
//
// Initialize a rotor with a slice of items and call Next to retrieve the next
// element in the sequence.
//
// Example:
//
// backends := []string{"srv-1", "srv-2", "srv-3"}
// r := rotor.New(backends)
//
// // Each call returns the next item in the sequence, wrapping around
// // at the end.
// s1 := r.Next() // "srv-1"
// s2 := r.Next() // "srv-2"
package rotor
import "sync/atomic"
// Rotor provides thread-safe round-robin access to a slice of items.
//
// It must be initialized with the [New] function. The interface allows for
// optimized internal implementations depending on the number of items provided
// (a single-item rotation avoids atomic bookkeeping entirely).
type Rotor[E any] interface {
	// Next returns the next item in the rotation, wrapping around to the
	// first item after the last one has been returned.
	// This method is safe for concurrent use by multiple goroutines.
	Next() E
}
// singleton is a [Rotor] that contains only a single item. It needs no
// synchronization because its state is immutable after construction.
type singleton[E any] struct {
	// item is the solitary element in this rotation.
	item E
}
// Next implements the [Rotor] interface, always returning the same item.
func (s *singleton[E]) Next() E {
	return s.item
}
// rotor is a generic implementation of the [Rotor] interface for multiple items.
type rotor[E any] struct {
	// items is the immutable slice of elements to rotate through.
	// It is a private copy made by [New], so callers cannot mutate it.
	items []E
	// index tracks the current position in the rotation using atomic
	// operations; it is always kept in the range [0, len(items)).
	index atomic.Uint64
}
// New creates a new [Rotor].
//
// It makes a defensive copy of the provided items slice to ensure immutability.
// This function panics if the items slice is empty. If the slice contains exactly
// one item, an optimized [Rotor] implementation is returned.
func New[E any](items []E) Rotor[E] {
	switch len(items) {
	case 0:
		panic("rotor: items slice must not be empty")
	case 1:
		// Single element: no rotation state needed at all.
		return &singleton[E]{item: items[0]}
	}
	// Defensive copy keeps the rotation immune to caller-side mutation.
	cloned := append(make([]E, 0, len(items)), items...)
	return &rotor[E]{items: cloned}
}
// Next implements the [Rotor] interface.
//
// A compare-and-swap loop advances the internal index modulo the number of
// items, so every successful caller observes a unique position in the
// sequence until the cycle repeats.
func (r *rotor[E]) Next() E {
	size := uint64(len(r.items))
	for {
		cur := r.index.Load()
		// Claim position cur by advancing the index; retry on contention.
		if r.index.CompareAndSwap(cur, (cur+1)%size) {
			return r.items[cur]
		}
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package schema provides utilities for parsing and manipulating database
// schema definitions and migration scripts.
//
// Its primary responsibility is to safely split raw SQL scripts into individual
// statements so they can be executed sequentially by a database driver. The
// parsing logic is designed to be aware of database-specific syntax, such as
// string literals, comments, and dollar-quoted strings in PostgreSQL, to
// prevent false positives when splitting on statement terminators like
// semicolons.
//
// # Usage
//
// Use the provided [Parser] implementations to break down SQL migration files
// into executable chunks.
//
// Example:
//
// script := []byte(
// "CREATE TABLE users (id int); -- comment\nINSERT INTO users VALUES (1);",
// )
// statements := schema.Postgres(script)
// // returns ["CREATE TABLE users (id int)", "INSERT INTO users VALUES (1)"]
package schema
import (
"bytes"
"github.com/deep-rent/nexus/internal/ascii"
)
// Parser is a function that splits a raw SQL script into a slice of individual
// executable statements.
//
// Implementations should handle database-specific syntax rules to ensure
// statements are not prematurely split when terminators appear inside quotes or
// comments. See [Postgres] for the PostgreSQL implementation.
type Parser func(script []byte) []string
// Postgres is a [Parser] implementation tailored for PostgreSQL scripts.
//
// It safely splits the script by semicolons (';'), while strictly ignoring
// semicolons that appear within:
//   - Single-line comments ("-- ...")
//   - Multi-line block comments ("/* ... */"), supporting nested blocks.
//   - Single-quoted string literals ('...')
//   - Double-quoted identifiers ("...")
//   - PostgreSQL dollar-quoted strings ($tag$...$tag$).
//
// To optimize performance, it uses [bytes] package operations to fast-forward
// through recognized blocks and pre-allocates the statement slice based on
// semicolon frequency.
func Postgres(script []byte) []string {
	trimmed := bytes.TrimSpace(script)
	// Nothing to do for empty or whitespace-only input.
	if len(trimmed) == 0 {
		return nil
	}
	// Semicolon count bounds the number of statements, so pre-size the slice.
	count := bytes.Count(trimmed, []byte{';'})
	parser := postgres{
		script: trimmed,
		stmts:  make([]string, 0, count+1),
	}
	return parser.parse()
}
// postgres holds the internal state machine for parsing a PostgreSQL script.
//
// At most one of the "in..." states (or a non-zero depth, or a non-empty tag)
// is active at any time; when none is active the cursor is in plain SQL text.
type postgres struct {
	// script is the raw SQL script being parsed.
	script []byte
	// i is the current cursor position (byte index) in the script.
	i int
	// start is the byte index where the current statement begins.
	start int
	// stmts is the collected slice of individual statements.
	stmts []string
	// inSingleQuotes is true if the cursor is within a single-quoted string.
	inSingleQuotes bool
	// inDoubleQuotes is true if the cursor is within a double-quoted string.
	inDoubleQuotes bool
	// inComment is true if the cursor is within a single-line comment.
	inComment bool
	// depth tracks the nesting depth of multi-line block comments.
	depth int
	// tag is the active dollar-quote tag (e.g., "$BODY$") when inside one.
	tag []byte
}
// parse iterates through the script, updating the state machine
// and splitting statements when a valid, top-level semicolon is encountered.
//
// It returns the collected statements, including a trailing statement that
// is not terminated by a semicolon.
func (p *postgres) parse() []string {
	n := len(p.script)
	for p.i < n {
		// 1. Prioritize state checks and fast-forward using bytes operations
		switch {
		case p.inComment:
			// Single-line comment: skip everything up to the newline.
			idx := bytes.IndexByte(p.script[p.i:], '\n')
			if idx == -1 {
				p.i = n // EOF
				break
			}
			p.i += idx + 1
			p.inComment = false
			continue
		case p.depth > 0:
			// Fast-forward to next possible block comment boundary
			idx := bytes.IndexAny(p.script[p.i:], "/*")
			if idx == -1 {
				p.i = n // EOF
				break
			}
			p.i += idx
			c := p.script[p.i]
			if c == '/' && p.i+1 < n && p.script[p.i+1] == '*' {
				p.depth++ // A nested block comment opens.
				p.i++
			} else if c == '*' && p.i+1 < n && p.script[p.i+1] == '/' {
				p.depth-- // The current block comment closes.
				p.i++
			}
			p.i++
			continue
		case len(p.tag) != 0:
			// Fast-forward to the exact matching dollar-tag
			idx := bytes.Index(p.script[p.i:], p.tag)
			if idx == -1 {
				p.i = n // EOF
				break
			}
			p.i += idx + len(p.tag)
			p.tag = nil
			continue
		case p.inSingleQuotes:
			idx := bytes.IndexByte(p.script[p.i:], '\'')
			if idx == -1 {
				p.i = n // EOF
				break
			}
			p.i += idx
			if p.i+1 < n && p.script[p.i+1] == '\'' {
				p.i += 2 // Skip escaped quote ('' inside a literal)
			} else {
				p.inSingleQuotes = false
				p.i++
			}
			continue
		case p.inDoubleQuotes:
			idx := bytes.IndexByte(p.script[p.i:], '"')
			if idx == -1 {
				p.i = n // EOF
				break
			}
			p.i += idx
			if p.i+1 < n && p.script[p.i+1] == '"' {
				p.i += 2 // Skip escaped quote ("" inside an identifier)
			} else {
				p.inDoubleQuotes = false
				p.i++
			}
			continue
		}
		// 2. We are in normal SQL text. Fast-forward to the next relevant
		// character (comment starters, quote openers, or a terminator).
		idx := bytes.IndexAny(p.script[p.i:], "-/$'\";")
		if idx == -1 {
			p.i = n // No more special characters, jump to end
			break
		}
		p.i += idx
		c := p.script[p.i]
		// 3. Isolated value-based switch for compiler optimization
		switch c {
		case '-':
			if p.i+1 < n && p.script[p.i+1] == '-' {
				p.inComment = true
				p.i++
			}
		case '/':
			if p.i+1 < n && p.script[p.i+1] == '*' {
				p.depth++
				p.i++
			}
		case '$':
			p.dollar(n)
		case '\'':
			p.inSingleQuotes = true
		case '"':
			p.inDoubleQuotes = true
		case ';':
			// Top-level terminator: close off the current statement.
			p.flush()
		}
		p.i++
	}
	// Add the final statement if the script does not end with a semicolon.
	p.flush()
	return p.stmts
}
// dollar scans ahead to parse a PostgreSQL dollar-quote tag (e.g., "$tag$").
//
// It uses [ascii.IsWord] to validate the characters within the tag. If a valid
// tag is found, the state is updated to track this [postgres.tag] block.
func (p *postgres) dollar(n int) {
end := -1
for j := p.i + 1; j < n; j++ {
nc := p.script[j]
if nc == '$' {
end = j
break
}
if !ascii.IsWord(rune(nc)) {
break
}
}
if end != -1 {
p.tag = p.script[p.i : end+1]
p.i = end
}
}
// flush extracts the current statement from the script buffer.
//
// The slice between [postgres.start] and the cursor is trimmed with
// [bytes.TrimSpace]; non-empty results are appended to the statement list.
// Afterwards the start marker is advanced past the current terminator. A
// start position beyond the script is a no-op.
func (p *postgres) flush() {
	if p.start < len(p.script) {
		stmt := bytes.TrimSpace(p.script[p.start:p.i])
		if len(stmt) > 0 {
			p.stmts = append(p.stmts, string(stmt))
		}
		p.start = p.i + 1
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package snake provides functions for converting strings between camelCase and
// snake_case formats.
//
// It handles transitions between lowercase letters, uppercase letters, and
// digits to produce idiomatic snake_case or SCREAMING_SNAKE_CASE strings. The
// implementation is specifically tuned for ASCII character sets and manages
// acronyms by detecting transitions from sequences of uppercase letters to a
// new word.
//
// # Usage
//
// Use [ToLower] for standard snake_case and [ToUpper] for constant-style
// uppercase snake_case.
//
// Example:
//
// low := snake.ToLower("JSONData") // "json_data"
// up := snake.ToUpper("myVariable") // "MY_VARIABLE"
package snake
import (
"strings"
"github.com/deep-rent/nexus/internal/ascii"
)
// ToUpper converts a camelCase string to an uppercase SNAKE_CASE string.
//
// For example, "fooBar" is converted to "FOO_BAR", and so is "FOOBar". Note
// that digits do not induce transitions, so "foo1" becomes "FOO1". Only ASCII
// characters are supported. This function internally uses [transform] with
// [ascii.ToUpper].
func ToUpper(s string) string {
	return transform(s, ascii.ToUpper)
}
// ToLower converts a camelCase string to a lowercase snake_case string.
//
// For example, "fooBar" is converted to "foo_bar", and so is "FOOBar". Note
// that digits do not induce transitions, so "foo1" becomes "foo1". Only ASCII
// characters are supported. This function internally uses [transform] with
// [ascii.ToLower].
func ToLower(s string) string {
	return transform(s, ascii.ToLower)
}
// transform is a helper function that performs the actual text conversion.
//
// It walks the runes of s, applies toCase to each one, and injects an
// underscore wherever a word boundary is detected: either a lowercase letter
// followed by an uppercase letter or digit, or the end of an uppercase
// acronym followed by a new capitalized word.
func transform(s string, toCase func(rune) rune) string {
	var out strings.Builder
	// Reserve a little extra room for the injected underscores.
	out.Grow(len(s) + 5)
	for i, r := range s {
		if i > 0 {
			prev := rune(s[i-1])
			// Boundary 1: lowercase followed by uppercase/digit ("myVar", "myVar1").
			lowerBoundary := ascii.IsLower(prev) &&
				(ascii.IsUpper(r) || ascii.IsDigit(r))
			// Boundary 2: acronym giving way to a new word ("MYVar").
			acronymBoundary := ascii.IsUpper(prev) && ascii.IsUpper(r) &&
				i+1 < len(s) && ascii.IsLower(rune(s[i+1]))
			if lowerBoundary || acronymBoundary {
				out.WriteRune('_')
			}
		}
		out.WriteRune(toCase(r))
	}
	return out.String()
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package tag provides utility for parsing Go struct tags that follow a
// comma-separated key-value option format.
//
// This format is similar to the `json` tag used in the standard library. The
// package handles complex cases where options may contain nested values or
// quoted strings, ensuring that commas within quotes do not break the parsing
// logic.
//
// # Usage
//
// Use [Parse] to initialize a tag and the [Tag.Opts] iterator to process
// individual options.
//
// Example:
//
// const raw = "user_id,omitempty,default:'anonymous,guest'"
// t := tag.Parse(raw)
// // t.Name is "user_id"
//
// for k, v := range t.Opts() {
// // Yields:
// // "omitempty", ""
// // "default", "anonymous,guest"
// }
package tag
import (
"iter"
"strings"
"unicode"
"github.com/deep-rent/nexus/internal/quote"
)
// Tag represents a parsed struct tag, separating the primary name from the
// additional options. Instances are created by [Parse].
type Tag struct {
	// Name is the primary identifier of the tag (the part before the first
	// comma).
	Name string
	// opts is the raw string containing the remaining comma-separated options;
	// it is lazily tokenized by [Tag.Opts].
	opts string
}
// Opts returns an iterator sequence over the tag's options.
//
// Each element yielded is a key-value pair. If an option does not have an
// explicit value (e.g., "omitempty"), the value string will be empty. Keys and
// values are trimmed of surrounding whitespace. Values that were quoted in the
// source string (e.g., `key:"value"`) will have the quotes removed via
// [quote.Remove].
//
// Commas inside quoted values are preserved and not treated as option
// separators (e.g., `key:"val1,val2"` is treated as one option).
func (t *Tag) Opts() iter.Seq2[string, string] {
	return func(yield func(string, string) bool) {
		rest := t.opts
		// Scan through the rest of the string until it's completely consumed.
		for rest != "" {
			// Trim leading space from the rest of the string.
			rest = strings.TrimLeftFunc(rest, unicode.IsSpace)
			if rest == "" {
				break
			}
			// Find the end of the current option part by finding the next
			// comma that is not inside quotes.
			end := -1
			inQuote := false
			var q rune
		scan:
			for i, r := range rest {
				switch {
				case r == q:
					// Closing quote matching the opening one.
					inQuote = false
					q = 0
				case !inQuote && quote.Is(r):
					inQuote = true
					q = r
				case !inQuote && r == ',':
					end = i
					break scan
				}
			}
			var part string
			if end == -1 {
				// This is the last option part.
				part = rest
				rest = ""
			} else {
				part = rest[:end]
				rest = rest[end+1:]
			}
			// Now, parse the individual part (e.g., "default:'foo,bar'").
			k, v, found := strings.Cut(part, ":")
			if found {
				// Trim whitespace around the value before unquoting so that
				// `key: 'a,b'` yields "a,b" as documented; previously the
				// untrimmed value kept its spaces and never lost its quotes.
				v = quote.Remove(strings.TrimSpace(v))
			}
			if !yield(strings.TrimRightFunc(k, unicode.IsSpace), v) {
				return
			}
		}
	}
}
// Parse takes a raw tag string and separates it into the primary name and the
// options string.
//
// For a string like `json:opt1,opt2:val`, it identifies the content before the
// first comma as the [Tag.Name]; everything after that comma becomes the raw
// options consumed by [Tag.Opts].
func Parse(s string) *Tag {
	name, rest, _ := strings.Cut(s, ",")
	t := Tag{Name: name, opts: rest}
	return &t
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package jwa provides implementations for asymmetric JSON Web Algorithms (JWA)
// as defined in RFC 7518. It provides a unified interface for signature
// verification using public keys and signature creation using [crypto.Signer].
// This abstraction handles algorithm-specific complexities such as hash
// function selection, padding schemes (e.g., PSS vs PKCS1v15), and signature
// format transcoding (e.g., converting ECDSA ASN.1 DER to raw concatenation).
//
// Note: Symmetric algorithms (such as HMAC) are not supported.
//
// # Usage
//
// This package is typically used to verify or sign JWS payloads by selecting
// a specific [Algorithm] instance like [RS256] or [EdDSA].
//
// Example:
//
// // Verify a signature using RS256
// valid := jwa.RS256.Verify(publicKey, message, signature)
package jwa
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"encoding/asn1"
"errors"
"fmt"
"hash"
"math/big"
"sync"
"github.com/cloudflare/circl/sign/ed448"
)
// Algorithm represents an asymmetric JSON Web Algorithm (JWA) used for
// verifying and calculating signatures. The type parameter T specifies the type
// of public key that the algorithm works with.
type Algorithm[T crypto.PublicKey] interface {
	// String provides the standard JWA name for the algorithm (e.g. "RS256").
	fmt.Stringer
	// Verify checks a signature against a message using the provided public key.
	// It returns true if the signature is valid, and false otherwise.
	// None of the parameters may be nil.
	Verify(key T, msg, sig []byte) bool
	// Sign creates a signature for the message using the provided signer.
	// The signer must be capable of using the algorithm's specific hash
	// and padding scheme.
	Sign(signer crypto.Signer, msg []byte) ([]byte, error)
}
// rs implements the RSASSA-PKCS1-v1_5 family of algorithms (RSxxx).
type rs struct {
	// name is the JWA identifier (e.g. "RS256").
	name string
	// pool is the internal hash pool for thread-safe operations.
	pool *hashPool
}
// newRS creates a new [Algorithm] for RSASSA-PKCS1-v1_5 signatures
// with the given JWA name and hash function.
func newRS(name string, hash crypto.Hash) Algorithm[*rsa.PublicKey] {
	a := rs{name: name, pool: newHashPool(hash)}
	return &a
}
// Verify checks an RSASSA-PKCS1-v1_5 signature.
//
// It returns true only when sig is a valid signature over msg for key.
func (a *rs) Verify(key *rsa.PublicKey, msg, sig []byte) bool {
	// Hash the message using a pooled hash instance.
	h := a.pool.Get()
	// The plain defer matches Sign's style and avoids a closure allocation;
	// h is evaluated immediately, so behavior is unchanged.
	defer a.pool.Put(h)
	h.Write(msg)
	digest := h.Sum(nil)
	// VerifyPKCS1v15 returns nil only for a valid signature.
	return rsa.VerifyPKCS1v15(key, a.pool.Hash, digest, sig) == nil
}
// Sign creates an RSASSA-PKCS1-v1_5 signature using the provided [crypto.Signer].
//
// The message is hashed with the algorithm's configured hash function and the
// digest is handed to the signer together with that hash as SignerOpts.
func (a *rs) Sign(signer crypto.Signer, msg []byte) ([]byte, error) {
	h := a.pool.Get()
	defer a.pool.Put(h)
	h.Write(msg)
	return signer.Sign(rand.Reader, h.Sum(nil), a.pool.Hash)
}
// String returns the JWA algorithm name (e.g. "RS256").
func (a *rs) String() string {
	return a.name
}
// RS256 represents the RSASSA-PKCS1-v1_5 signature algorithm using SHA-256.
// It is safe for concurrent use; hash state is drawn from an internal pool.
var RS256 = newRS("RS256", crypto.SHA256)
// RS384 represents the RSASSA-PKCS1-v1_5 signature algorithm using SHA-384.
var RS384 = newRS("RS384", crypto.SHA384)
// RS512 represents the RSASSA-PKCS1-v1_5 signature algorithm using SHA-512.
var RS512 = newRS("RS512", crypto.SHA512)
// ps implements the RSASSA-PSS family of algorithms (PSxxx).
type ps struct {
	// name is the JWA identifier (e.g. "PS256").
	name string
	// pool is the internal hash pool for thread-safe operations.
	pool *hashPool
}
// newPS creates a new [Algorithm] for RSASSA-PSS signatures
// with the given JWA name and hash function.
func newPS(name string, hash crypto.Hash) Algorithm[*rsa.PublicKey] {
	a := ps{name: name, pool: newHashPool(hash)}
	return &a
}
// Verify checks an RSASSA-PSS signature.
//
// It returns true only when sig is a valid signature over msg for key.
func (a *ps) Verify(key *rsa.PublicKey, msg, sig []byte) bool {
	h := a.pool.Get()
	// The plain defer matches Sign's style and avoids a closure allocation;
	// h is evaluated immediately, so behavior is unchanged.
	defer a.pool.Put(h)
	h.Write(msg)
	digest := h.Sum(nil)
	// The salt length is set to match the hash size.
	opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash}
	return rsa.VerifyPSS(key, a.pool.Hash, digest, sig, opts) == nil
}
// Sign creates an RSASSA-PSS signature using the provided [crypto.Signer].
//
// The PSS options use a salt length equal to the hash size and carry the
// algorithm's hash function so the signer applies the correct scheme.
func (a *ps) Sign(signer crypto.Signer, msg []byte) ([]byte, error) {
	h := a.pool.Get()
	defer a.pool.Put(h)
	h.Write(msg)
	pssOpts := rsa.PSSOptions{
		SaltLength: rsa.PSSSaltLengthEqualsHash,
		Hash:       a.pool.Hash,
	}
	return signer.Sign(rand.Reader, h.Sum(nil), &pssOpts)
}
// String returns the JWA algorithm name (e.g. "PS256").
func (a *ps) String() string {
	return a.name
}
// PS256 represents the RSASSA-PSS signature algorithm using SHA-256.
// It is safe for concurrent use; hash state is drawn from an internal pool.
var PS256 = newPS("PS256", crypto.SHA256)
// PS384 represents the RSASSA-PSS signature algorithm using SHA-384.
var PS384 = newPS("PS384", crypto.SHA384)
// PS512 represents the RSASSA-PSS signature algorithm using SHA-512.
var PS512 = newPS("PS512", crypto.SHA512)
// es implements the ECDSA signature algorithms (ES256/ES384/ES512).
type es struct {
	name string    // JWA identifier
	pool *hashPool // pooled hash instances for concurrent use
}

// newES constructs an ECDSA [Algorithm] with the given JWA name and
// underlying hash function.
func newES(name string, hash crypto.Hash) Algorithm[*ecdsa.PublicKey] {
	return &es{name: name, pool: newHashPool(hash)}
}

// Verify reports whether sig is a valid raw-format (R || S) ECDSA signature
// of msg under key.
func (a *es) Verify(key *ecdsa.PublicKey, msg, sig []byte) bool {
	// Raw JWS ECDSA signatures concatenate two integers, each padded to
	// the byte size of the curve order.
	size := (key.Curve.Params().BitSize + 7) / 8
	if len(sig) != 2*size {
		return false
	}
	hasher := a.pool.Get()
	defer a.pool.Put(hasher)
	hasher.Write(msg)
	digest := hasher.Sum(nil)
	// Split the raw signature into its R and S halves.
	r := new(big.Int).SetBytes(sig[:size])
	s := new(big.Int).SetBytes(sig[size:])
	return ecdsa.Verify(key, digest, r, s)
}

// Sign creates an ECDSA signature and transcodes it from the ASN.1 DER
// encoding emitted by the signer to the raw R || S format required by JWS.
func (a *es) Sign(signer crypto.Signer, msg []byte) ([]byte, error) {
	hasher := a.pool.Get()
	defer a.pool.Put(hasher)
	hasher.Write(msg)
	der, err := signer.Sign(rand.Reader, hasher.Sum(nil), nil)
	if err != nil {
		return nil, err
	}
	// The signer returns an ASN.1 DER SEQUENCE of the integers R and S.
	var parsed struct{ R, S *big.Int }
	if _, err := asn1.Unmarshal(der, &parsed); err != nil {
		return nil, fmt.Errorf("failed to parse ECDSA signature: %w", err)
	}
	pub, ok := signer.Public().(*ecdsa.PublicKey)
	if !ok {
		return nil, errors.New("signer public key is not ECDSA")
	}
	// Left-pad both integers to the fixed width of the curve order.
	size := (pub.Curve.Params().BitSize + 7) / 8
	sig := make([]byte, 2*size)
	parsed.R.FillBytes(sig[:size])
	parsed.S.FillBytes(sig[size:])
	return sig, nil
}

// String returns the JWA algorithm name.
func (a *es) String() string { return a.name }

// ES256 represents the ECDSA signature algorithm using P-256 and SHA-256.
var ES256 = newES("ES256", crypto.SHA256)

// ES384 represents the ECDSA signature algorithm using P-384 and SHA-384.
var ES384 = newES("ES384", crypto.SHA384)

// ES512 represents the ECDSA signature algorithm using P-521 and SHA-512.
var ES512 = newES("ES512", crypto.SHA512)
// ed implements the EdDSA signature algorithm for Ed25519 and Ed448.
type ed struct{}

// Verify reports whether sig is a valid EdDSA signature of msg. The curve is
// inferred from the public key length.
func (a *ed) Verify(key, msg, sig []byte) bool {
	switch len(key) {
	case ed25519.PublicKeySize:
		return ed25519.Verify(ed25519.PublicKey(key), msg, sig)
	case ed448.PublicKeySize:
		// Per RFC 8037, the JWS "EdDSA" algorithm corresponds to the
		// "pure" EdDSA variant, which uses an empty context string.
		return ed448.Verify(ed448.PublicKey(key), msg, sig, "")
	default:
		// Neither an Ed25519 nor an Ed448 public key.
		return false
	}
}

// Sign produces an EdDSA signature over msg with the given [crypto.Signer].
// EdDSA signs the message directly, so no pre-hashing is requested.
func (a *ed) Sign(signer crypto.Signer, msg []byte) ([]byte, error) {
	return signer.Sign(rand.Reader, msg, crypto.Hash(0))
}

// String returns the JWA algorithm name.
func (a *ed) String() string { return "EdDSA" }

// EdDSA represents the EdDSA signature algorithm. It supports both Ed25519
// and Ed448 curves; the curve is determined by the size of the public key.
var EdDSA Algorithm[[]byte] = &ed{}
// hashPool manages a pool of [hash.Hash] objects to reduce allocations.
type hashPool struct {
// Hash is the underlying hash identifier.
Hash crypto.Hash
// pool is the [sync.Pool] containing initialized [hash.Hash] instances.
pool *sync.Pool
}
// newHashPool creates a new [hashPool] for the given hash function.
func newHashPool(hash crypto.Hash) *hashPool {
pool := &sync.Pool{
New: func() any {
return hash.New()
},
}
return &hashPool{
Hash: hash,
pool: pool,
}
}
// Get retrieves a [hash.Hash] from the pool.
func (p *hashPool) Get() hash.Hash {
h := p.pool.Get()
return h.(hash.Hash)
}
// Put returns a [hash.Hash] to the pool after resetting it.
func (p *hashPool) Put(h hash.Hash) {
h.Reset()
p.pool.Put(h)
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package jwk provides functionality to parse, manage, and marshal JSON Web
// Keys (JWK) and JSON Web Key Sets (JWKS), as defined in RFC 7517.
//
// # Verification
//
// The package is primarily designed to consume public keys from a remote JWKS
// endpoint for the purpose of verifying JWT signatures.
//
// # Signing
//
// While JWKS parsing focuses on public keys, this package also supports the
// creation of signing keys via the [KeyBuilder]. These keys wrap a
// [crypto.Signer] (e.g., hardware modules, KMS, or standard library keys) to
// support token issuance operations.
//
// # Encoding
//
// The package supports serializing keys back to JSON. This is useful for
// services that need to expose their own public keys via a JWKS endpoint or
// for persisting key sets. The marshaling logic is strict: it only outputs
// public key material and adheres to RFC 7518 fixed-width requirements for
// elliptic curve coordinates.
//
// # Eligible Keys
//
// Keys that are not intended for signature verification are considered
// ineligible and will be skipped during parsing of a JWKS. A key is eligible
// if it meets at least one of the following criteria:
//
// - The "use" (Public Key Use) parameter is set to "sig".
// - The "key_ops" (Key Operations) parameter includes "verify".
//
// # Key Selection
//
// This implementation deliberately deviates from the RFC for robustness and
// simplicity:
//
// 1. The "alg" (Algorithm) parameter, optional in the standard, is treated as
// mandatory for all eligible keys. Enforcing this is a best practice that
// mitigates algorithm confusion attacks.
// 2. For key selection, either "kid" (Key ID) or "x5t#S256" (SHA-256
// Thumbprint) must be defined. The "x5t" (SHA-1 Thumbprint) parameter is
// explicitly ignored as it is considered outdated. No other lookup
// mechanism is supported.
//
// # Usage
//
// Parse a JWKS from a remote endpoint and look up a key for verification.
//
// Example:
//
// set, err := jwk.ParseSet(jsonData)
// if err != nil {
// log.Fatal(err)
// }
// key := set.Find(header)
package jwk
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"encoding/json/jsontext"
"encoding/json/v2"
"errors"
"fmt"
"iter"
"log/slog"
"math/big"
"slices"
"time"
"github.com/cloudflare/circl/sign/ed448"
"github.com/deep-rent/nexus/cache"
"github.com/deep-rent/nexus/jose/jwa"
"github.com/deep-rent/nexus/scheduler"
)
// Hint represents a reference to a [Key], carrying the minimum information
// needed to look one up in a [Set]. It abstracts the JWS header fields used
// to select a key for signature verification.
type Hint interface {
	// Algorithm returns the JWA algorithm name that the key is intended
	// for. This must match the "alg" parameter in the JWS header.
	Algorithm() string
	// KeyID returns the unique identifier for the key, or an empty string
	// if absent. This must match the "kid" parameter in the JWS header.
	// At least one of "kid" or "x5t#S256" must be present; if both are,
	// "kid" takes precedence during lookups.
	KeyID() string
	// Thumbprint returns the base64url-encoded SHA-256 digest of the
	// DER-encoded X.509 certificate associated with the key, or an empty
	// string if absent. This must match the "x5t#S256" parameter in the
	// JWS header. At least one of "kid" or "x5t#S256" must be present; if
	// both are, "kid" takes precedence during lookups.
	Thumbprint() string
}
// Key represents a public JSON Web Key (JWK) used for signature verification.
type Key interface {
	Hint
	// Verify reports whether sig is a valid signature of msg under this
	// key's material and its associated algorithm.
	Verify(msg, sig []byte) bool
	// Material returns the raw cryptographic public key for encoding
	// purposes. Private key material is never exposed.
	Material() any
}
// newKey creates a new [Key] programmatically from its constituent parts.
// The type parameter T must match the public key type expected by the
// provided algorithm (e.g., [*rsa.PublicKey] for [jwa.RS256]).
func newKey[T crypto.PublicKey](
	alg jwa.Algorithm[T],
	kid string,
	x5t string,
	mat T,
) Key {
	return &key[T]{
		alg: alg,
		kid: kid,
		x5t: x5t,
		mat: mat,
	}
}

// key is the concrete [Key] implementation, generic over the public key type.
type key[T crypto.PublicKey] struct {
	alg jwa.Algorithm[T] // JWA implementation bound to this key
	kid string           // unique key identifier ("kid")
	x5t string           // SHA-256 certificate thumbprint ("x5t#S256")
	mat T                // raw cryptographic public key material
}

// Algorithm implements [Hint].
func (k *key[T]) Algorithm() string {
	return k.alg.String()
}

// KeyID implements [Hint].
func (k *key[T]) KeyID() string {
	return k.kid
}

// Thumbprint implements [Hint].
func (k *key[T]) Thumbprint() string {
	return k.x5t
}

// Material implements [Key].
func (k *key[T]) Material() any {
	return k.mat
}

// Verify implements [Key] by delegating to the bound algorithm.
func (k *key[T]) Verify(msg, sig []byte) bool {
	return k.alg.Verify(k.mat, msg, sig)
}
// KeyPair represents a JSON Web Key that is capable of both verification and
// signing. It embeds the public [Key] interface and wraps a [crypto.Signer]
// for the private key operations.
type KeyPair interface {
	Key
	// Sign generates a signature for the given message.
	Sign(msg []byte) ([]byte, error)
}

// keyPair is the concrete implementation of [KeyPair].
type keyPair[T crypto.PublicKey] struct {
	// key is the embedded public half; its algorithm also drives signing.
	key[T]
	// signer is the private key handle (e.g. hardware module, KMS, or a
	// standard library key).
	signer crypto.Signer
}

// Sign implements [KeyPair] by delegating to the key's algorithm with the
// wrapped signer.
func (s *keyPair[T]) Sign(msg []byte) ([]byte, error) {
	return s.alg.Sign(s.signer, msg)
}
// KeyBuilder assists in the programmatic construction of [Key] and [KeyPair]
// instances. It ensures that the resulting keys possess the required metadata
// and that the cryptographic material matches the intended algorithm.
type KeyBuilder[T crypto.PublicKey] struct {
	// alg is the algorithm for the resulting key.
	alg jwa.Algorithm[T]
	// kid is the key identifier to assign ("kid").
	kid string
	// x5t is the certificate thumbprint to assign ("x5t#S256").
	x5t string
}

// NewKeyBuilder starts the construction of a key for the specified algorithm.
// The generic type T determines the expected public key material (e.g.,
// [*rsa.PublicKey]).
func NewKeyBuilder[T crypto.PublicKey](alg jwa.Algorithm[T]) *KeyBuilder[T] {
	return &KeyBuilder[T]{alg: alg}
}

// Algorithm returns the JWA associated with this builder.
func (b *KeyBuilder[T]) Algorithm() jwa.Algorithm[T] { return b.alg }

// KeyID returns the currently configured key identifier, or an empty string.
func (b *KeyBuilder[T]) KeyID() string { return b.kid }

// Thumbprint returns the currently configured certificate thumbprint, or an
// empty string.
func (b *KeyBuilder[T]) Thumbprint() string { return b.x5t }

// WithKeyID sets the "kid" (Key ID) parameter and returns the builder for
// chaining.
func (b *KeyBuilder[T]) WithKeyID(kid string) *KeyBuilder[T] {
	b.kid = kid
	return b
}

// WithThumbprint sets the "x5t#S256" (SHA-256 Certificate Thumbprint)
// parameter and returns the builder for chaining.
func (b *KeyBuilder[T]) WithThumbprint(x5t string) *KeyBuilder[T] {
	b.x5t = x5t
	return b
}
// Build creates a verification-only [Key] from the given public key material.
// It panics unless a key id or a thumbprint has been configured.
func (b *KeyBuilder[T]) Build(mat T) Key {
	return b.build(mat)
}

// BuildPair creates a signing-capable [KeyPair] around the provided signer.
//
// It panics if:
// 1. The signer's public key cannot be cast to type T.
// 2. Neither a Key ID nor a Thumbprint has been configured.
func (b *KeyBuilder[T]) BuildPair(signer crypto.Signer) KeyPair {
	pub, ok := signer.Public().(T)
	if !ok {
		panic("signer public key type does not match key builder type")
	}
	return &keyPair[T]{key: *b.build(pub), signer: signer}
}

// build assembles the public key half shared by Build and BuildPair.
func (b *KeyBuilder[T]) build(mat T) *key[T] {
	if b.kid == "" && b.x5t == "" {
		panic("either key id or thumbprint must be set")
	}
	return &key[T]{alg: b.alg, kid: b.kid, x5t: b.x5t, mat: mat}
}
// ErrIneligibleKey indicates that a key may be syntactically valid but should
// not be used for signature verification according to its "use" or "key_ops"
// parameters.
var ErrIneligibleKey = errors.New("ineligible for signature verification")
// Parse parses a single [Key] from the provided JSON input.
//
// It first checks if the key is eligible for signature verification. If not,
// it returns [ErrIneligibleKey]. Otherwise, it proceeds to validate the
// presence of required parameters ("kty" and "alg"), whether the algorithm is
// supported, and the integrity of the key material itself.
func Parse(in []byte) (Key, error) {
var raw raw
if err := json.Unmarshal(in, &raw); err != nil {
return nil, fmt.Errorf("invalid json format: %w", err)
}
// Per RFC 7517, a key's purpose is determined by the union of "use" and
// "key_ops". We perform this check first for efficiency, as we only care
// about signature verification keys.
if raw.Use != "sig" && !slices.Contains(raw.Ops, "verify") {
return nil, ErrIneligibleKey
}
if raw.Kty == "" {
return nil, errors.New("undefined key type")
}
if raw.Alg == "" {
return nil, errors.New("algorithm not specified")
}
read := readers[raw.Alg]
if read == nil {
return nil, fmt.Errorf("unknown algorithm %q", raw.Alg)
}
key, err := read(&raw)
if err != nil {
return nil, fmt.Errorf("read %s key material: %w", raw.Kty, err)
}
return key, nil
}
// Set stores an immutable collection of [Key] instances, typically parsed
// from a JWKS. It provides efficient lookups of keys for signature
// verification.
type Set interface {
	// Keys returns an iterator over all keys in this set.
	Keys() iter.Seq[Key]
	// Len returns the number of keys in this set.
	Len() int
	// Find looks up a key using the specified hint. A key is returned only
	// if both its identifier (key id or thumbprint) and its algorithm
	// match the hint exactly. Otherwise, it returns nil.
	Find(hint Hint) Key
}
// newSet creates a new, empty [set] sized for the specified number of keys.
func newSet(n int) *set {
	return &set{
		keys: make([]Key, 0, n),
		kid: make(map[string]int, n),
		x5t: make(map[string]int, n),
	}
}

// set is the concrete implementation of the [Set] interface. It indexes keys
// by both identifiers for O(1) average-time lookups.
type set struct {
	// keys is the slice of keys in the set.
	keys []Key
	// kid maps a key id to its index in the keys slice.
	kid map[string]int
	// x5t maps a thumbprint to its index in the keys slice.
	x5t map[string]int
}
// Keys implements [Set].
func (s *set) Keys() iter.Seq[Key] { return slices.Values(s.keys) }

// Len implements [Set].
func (s *set) Len() int { return len(s.keys) }

// Find implements [Set]. Lookup is by key id first, then by thumbprint; the
// candidate is returned only if its algorithm also matches the hint.
func (s *set) Find(hint Hint) Key {
	if hint == nil {
		return nil
	}
	i, ok := s.kid[hint.KeyID()]
	if !ok {
		if i, ok = s.x5t[hint.Thumbprint()]; !ok {
			return nil
		}
	}
	if k := s.keys[i]; k.Algorithm() == hint.Algorithm() {
		return k
	}
	return nil
}
// emptySet represents a [Set] containing no keys.
type emptySet struct{}

// Keys implements [Set] for [emptySet]; the iterator yields nothing.
func (e emptySet) Keys() iter.Seq[Key] { return func(func(Key) bool) {} }

// Len implements [Set] for [emptySet]; it is always zero.
func (e emptySet) Len() int { return 0 }

// Find implements [Set] for [emptySet]; no hint ever matches.
func (e emptySet) Find(Hint) Key { return nil }

// empty is a singleton instance of an empty [Set], shared wherever no keys
// are available.
var empty Set = emptySet{}
// singletonSet is an adapter that wraps a single [Key] as a [Set].
type singletonSet struct {
	// key is the sole member of the set.
	key Key
}

// Keys implements [Set] for [singletonSet].
func (s *singletonSet) Keys() iter.Seq[Key] {
	return func(f func(Key) bool) { f(s.key) }
}

// Len implements [Set] for [singletonSet].
func (s *singletonSet) Len() int { return 1 }

// Find implements [Set] for [singletonSet]. It mirrors the semantics of
// set.Find: the algorithm must match exactly, and the key is selected by key
// id or thumbprint.
func (s *singletonSet) Find(hint Hint) Key {
	// Guard against nil hints for consistency with set.Find; otherwise the
	// behavior of Find would depend on the concrete Set implementation.
	if hint == nil {
		return nil
	}
	if s.key.Algorithm() != hint.Algorithm() {
		return nil
	}
	if kid := hint.KeyID(); kid != "" && s.key.KeyID() == kid {
		return s.key
	}
	if x5t := hint.Thumbprint(); x5t != "" && s.key.Thumbprint() == x5t {
		return s.key
	}
	return nil
}
// ParseSet parses a [Set] from a JWKS JSON input.
//
// If the top-level JSON structure is malformed, it returns an empty set and
// a fatal error. Otherwise, it iterates through the "keys" array, parsing
// each key individually. Keys that are invalid, unsupported, or occur multiple
// times, result in non-fatal errors. Ineligible keys (e.g., those meant for
// encryption) are silently skipped. If any non-fatal errors occurred, a joined
// error is returned alongside the set of successfully parsed keys.
func ParseSet(in []byte) (Set, error) {
var raw struct {
// Defer unmarshaling of individual keys to safely skip ineligible ones.
Keys []jsontext.Value `json:"keys"`
}
if err := json.Unmarshal(in, &raw); err != nil {
return empty, fmt.Errorf("invalid format: %w", err)
}
n := len(raw.Keys)
if n == 0 {
return empty, nil
}
s := newSet(n)
var errs []error
for i, v := range raw.Keys {
k, err := Parse(v)
if err != nil {
if errors.Is(err, ErrIneligibleKey) {
continue
}
err = fmt.Errorf("key at index %d: %w", i, err)
errs = append(errs, err)
continue
}
kid := k.KeyID()
x5t := k.Thumbprint()
if kid == "" && x5t == "" {
errs = append(errs, fmt.Errorf(
"key at index %d: missing both key id and thumbprint", i,
))
continue
}
// Check for duplicates before mutating the set.
if kid != "" {
if _, ok := s.kid[kid]; ok {
errs = append(errs, fmt.Errorf(
"key at index %d: duplicate key id %q", i, kid,
))
continue
}
}
if x5t != "" {
if _, ok := s.x5t[x5t]; ok {
errs = append(errs, fmt.Errorf(
"key at index %d: duplicate thumbprint %q", i, x5t,
))
continue
}
}
// Determines the index in the keys'slice where this new key will be
// stored. This is safe because we are appending linearly.
idx := len(s.keys)
// Append the key exactly once.
s.keys = append(s.keys, k)
// Update the lookup maps.
if kid != "" {
s.kid[kid] = idx
}
if x5t != "" {
s.x5t[x5t] = idx
}
}
return s, errors.Join(errs...)
}
// Write marshals a single [Key] into its JSON Web Key representation.
//
// It populates the standard JWK fields ("kty", "alg", "use", "kid", "x5t#S256")
// and the algorithm-specific public key parameters (e.g., "n" and "e" for RSA).
// The output is strictly compliant with RFC 7517 and RFC 7518, ensuring that
// elliptic curve coordinates are padded to the correct fixed width.
func Write(k Key) ([]byte, error) {
r, err := toRaw(k)
if err != nil {
return nil, err
}
return json.Marshal(r)
}
// WriteSet marshals a [Set] into a JSON Web Key Set (JWKS) document.
//
// The resulting JSON corresponds to the standard JWKS structure:
//
// {
// "keys": [ ... ]
// }
//
// This function efficiently iterates over the keys in the set, converting them
// to their raw JSON representation before marshaling the entire collection.
func WriteSet(s Set) ([]byte, error) {
// We marshal into a slice of raw structs directly.
// This is more efficient than calling Write() loop, which would
// result in double-marshaling.
keys := make([]raw, 0, s.Len())
for k := range s.Keys() {
r, err := toRaw(k)
if err != nil {
return nil, fmt.Errorf("encode key %q: %w", k.KeyID(), err)
}
keys = append(keys, *r)
}
return json.Marshal(struct {
Keys []raw `json:"keys"`
}{
Keys: keys,
})
}
// toRaw converts a [Key] object into the [raw] DTO used for marshaling.
func toRaw(k Key) (*raw, error) {
	alg := k.Algorithm()
	write, ok := writers[alg]
	if !ok {
		return nil, fmt.Errorf("unsupported algorithm %q", alg)
	}
	// Standard metadata first; only public verification keys are written.
	r := &raw{
		Alg: alg,
		Kid: k.KeyID(),
		X5t: k.Thumbprint(),
		Use: "sig",
	}
	// Delegate the algorithm-specific fields to the registered encoder.
	if err := write(k.Material(), r); err != nil {
		return nil, err
	}
	return r, nil
}

// Singleton creates a [Set] that contains only the provided [Key].
func Singleton(key Key) Set {
	return &singletonSet{key: key}
}
// CacheSet extends the [Set] interface with [scheduler.Tick], creating a
// component that can be deployed to a scheduler for automatic refreshing of
// a remote JWKS view in the background. The default implementation is backed
// by a [cache.Controller].
type CacheSet interface {
	Set
	scheduler.Tick
}

// cacheSet is the concrete implementation of the [CacheSet] interface.
type cacheSet struct {
	// ctrl manages the lifecycle and fetching of the remote JWKS.
	ctrl cache.Controller[Set]
}
// get returns the most recent [Set] produced by the cache controller. While
// the cache is still unpopulated (e.g., after an initial network failure),
// it falls back to the static empty set so that delegated operations like
// Find never panic. This makes the [Set] resilient to transient startup
// issues.
func (s *cacheSet) get() Set {
	set, ok := s.ctrl.Get()
	if !ok {
		return empty
	}
	return set
}

// Keys implements [Set].
func (s *cacheSet) Keys() iter.Seq[Key] { return s.get().Keys() }

// Len implements [Set].
func (s *cacheSet) Len() int { return s.get().Len() }

// Find implements [Set].
func (s *cacheSet) Find(hint Hint) Key { return s.get().Find(hint) }

// Run implements [scheduler.Tick].
func (s *cacheSet) Run(ctx context.Context) time.Duration {
	return s.ctrl.Run(ctx)
}

// Compile-time proof that cacheSet satisfies CacheSet.
var _ CacheSet = (*cacheSet)(nil)
// mapper adapts the [ParseSet] function to the [cache.Mapper] interface.
//
// Parsing is deliberately lenient: partial failures are only logged at debug
// level, and the fetch is rejected solely when no usable key was produced at
// all. In that case the parse error (if any) is wrapped so the cause is not
// lost.
var mapper cache.Mapper[Set] = func(r *cache.Response) (Set, error) {
	set, err := ParseSet(r.Body)
	if set.Len() == 0 {
		// Preserve the underlying reason instead of swallowing it; a
		// fatal ParseSet error would otherwise be indistinguishable
		// from a well-formed but empty JWKS.
		if err != nil {
			return nil, fmt.Errorf("no valid keys found: %w", err)
		}
		return nil, errors.New("no valid keys found")
	}
	if err != nil && r.Logger.Enabled(r.Ctx, slog.LevelDebug) {
		r.Logger.Debug("Some keys could not be parsed", slog.Any("error", err))
	}
	// Don't complain unless there are no keys available at all.
	return set, nil
}
// NewCacheSet creates a new [CacheSet] that stays in sync with a remote JWKS
// endpoint. It must be deployed to a [scheduler.Scheduler] to begin the
// background fetching and refreshing process.
//
// The provided [cache.Option] values can configure behaviors like refresh
// interval, request timeouts, and error handling. Parsing of retrieved key
// sets is extremely lenient: it only fails when no valid keys are found at
// all.
func NewCacheSet(url string, opts ...cache.Option) CacheSet {
	return &cacheSet{ctrl: cache.NewController(url, mapper, opts...)}
}
// raw is the wire-format DTO holding the JWK parameters, including the key
// material. Member names follow RFC 7517 (metadata) and RFC 7518 (material).
type raw struct {
	Kty string `json:"kty"` // key type: "RSA", "EC", or "OKP"
	Alg string `json:"alg"` // JWA algorithm name
	Use string `json:"use,omitempty"` // public key use, e.g. "sig"
	Ops []string `json:"key_ops,omitempty"` // permitted key operations
	Kid string `json:"kid,omitempty"` // key identifier
	X5t string `json:"x5t#S256,omitempty"` // SHA-256 certificate thumbprint
	N string `json:"n,omitempty"` // RSA modulus (base64url)
	E string `json:"e,omitempty"` // RSA public exponent (base64url)
	Crv string `json:"crv,omitempty"` // EC/OKP curve name
	X string `json:"x,omitempty"` // EC x coordinate or OKP public key
	Y string `json:"y,omitempty"` // EC y coordinate
}
// reader decodes the key material of a [raw] JWK into a concrete [Key].
type reader func(r *raw) (Key, error)

// readers associates each supported JWA algorithm name with the function
// responsible for parsing its key material.
var readers map[string]reader

// addReader registers a parsing function for alg in a type-safe manner.
func addReader[T crypto.PublicKey](alg jwa.Algorithm[T], dec decoder[T]) {
	read := func(r *raw) (Key, error) {
		material, err := dec(r)
		if err != nil {
			return nil, err
		}
		return newKey(alg, r.Kid, r.X5t, material), nil
	}
	readers[alg.String()] = read
}
// decoder decodes the key material for a specific key type T.
type decoder[T crypto.PublicKey] func(*raw) (T, error)

// decodeRSA parses the material for an RSA public key from the "n" and "e"
// parameters.
func decodeRSA(r *raw) (*rsa.PublicKey, error) {
	if r.Kty != "RSA" {
		return nil, fmt.Errorf("incompatible key type %q", r.Kty)
	}
	if r.N == "" {
		return nil, errors.New("missing modulus")
	}
	if r.E == "" {
		return nil, errors.New("missing public exponent")
	}
	modulus, err := base64.RawURLEncoding.DecodeString(r.N)
	if err != nil {
		return nil, fmt.Errorf("decode modulus: %w", err)
	}
	exponent, err := base64.RawURLEncoding.DecodeString(r.E)
	if err != nil {
		return nil, fmt.Errorf("decode public exponent: %w", err)
	}
	// Exponents beyond 2^31-1 are extremely rare and not recommended.
	if len(exponent) > 4 {
		return nil, errors.New("public exponent exceeds 32 bits")
	}
	// Big-endian accumulation; safe given the length check above.
	var e int
	for _, b := range exponent {
		e = e<<8 | int(b)
	}
	return &rsa.PublicKey{
		N: new(big.Int).SetBytes(modulus),
		E: e,
	}, nil
}
// decodeECDSA creates a [decoder] bound to the specified elliptic curve.
func decodeECDSA(crv elliptic.Curve) decoder[*ecdsa.PublicKey] {
	return func(r *raw) (*ecdsa.PublicKey, error) {
		if r.Kty != "EC" {
			return nil, fmt.Errorf("incompatible key type %q", r.Kty)
		}
		if r.Crv != crv.Params().Name {
			return nil, fmt.Errorf("incompatible curve %q", r.Crv)
		}
		if len(r.X) == 0 {
			return nil, errors.New("missing x coordinate")
		}
		if len(r.Y) == 0 {
			return nil, errors.New("missing y coordinate")
		}
		x, err := base64.RawURLEncoding.DecodeString(r.X)
		if err != nil {
			return nil, fmt.Errorf("decode x coordinate: %w", err)
		}
		y, err := base64.RawURLEncoding.DecodeString(r.Y)
		if err != nil {
			return nil, fmt.Errorf("decode y coordinate: %w", err)
		}
		// Coordinates may be at most the curve size in bytes.
		size := (crv.Params().BitSize + 7) / 8
		if len(x) > size || len(y) > size {
			return nil, errors.New("coordinate length exceeds curve size")
		}
		// Assemble the SEC 1 uncompressed point (0x04 || X || Y),
		// left-padding each coordinate to its fixed width.
		point := make([]byte, 1+2*size)
		point[0] = 4
		copy(point[1+size-len(x):1+size], x)
		copy(point[1+2*size-len(y):], y)
		pub, err := ecdsa.ParseUncompressedPublicKey(crv, point)
		if err != nil {
			return nil, fmt.Errorf("parse public key: %w", err)
		}
		return pub, nil
	}
}
// decodeEdDSA parses the material for an EdDSA public key. The curve is
// taken from the "crv" parameter, the key itself from "x".
func decodeEdDSA(r *raw) ([]byte, error) {
	if r.Kty != "OKP" {
		return nil, fmt.Errorf("incompatible key type %q", r.Kty)
	}
	var want int
	switch r.Crv {
	case "Ed25519":
		want = ed25519.PublicKeySize
	case "Ed448":
		want = ed448.PublicKeySize
	default:
		return nil, fmt.Errorf("unsupported curve %q", r.Crv)
	}
	key, err := base64.RawURLEncoding.DecodeString(r.X)
	if err != nil {
		return nil, fmt.Errorf("decode x coordinate: %w", err)
	}
	// The decoded key must match the curve's public key size exactly.
	if len(key) != want {
		return nil, fmt.Errorf(
			"illegal key size for %s curve: got %d, want %d", r.Crv, len(key), want,
		)
	}
	return key, nil
}
// writers associates each supported JWA algorithm name with the function
// responsible for encoding its key material.
var writers map[string]writer

// writer encodes key material into the marshallable [raw] JWK form.
type writer func(mat any, r *raw) error

// addWriter registers an encoding function for alg in a type-safe manner.
func addWriter[T crypto.PublicKey](alg jwa.Algorithm[T], enc encoder[T]) {
	write := func(mat any, r *raw) error {
		pub, ok := mat.(T)
		if !ok {
			return fmt.Errorf("invalid key for algorithm %q", alg.String())
		}
		return enc(pub, r)
	}
	writers[alg.String()] = write
}

// encoder populates the [raw] JWK parameters from the algorithm-specific key
// material.
type encoder[T crypto.PublicKey] func(mat T, r *raw) error
// encodeRSA populates the RSA-specific fields ("n", "e") in the [raw] JWK.
// The exponent is emitted as a minimal big-endian byte string, as required
// for the "e" member.
func encodeRSA(key *rsa.PublicKey, r *raw) error {
	r.Kty = "RSA"
	r.N = base64.RawURLEncoding.EncodeToString(key.N.Bytes())
	e := key.E
	if e == 0 {
		return errors.New("RSA public exponent is zero")
	}
	// A negative exponent is invalid key material; without this guard the
	// encoding loop below would produce an empty "e" member, silently
	// emitting a malformed JWK.
	if e < 0 {
		return errors.New("RSA public exponent is negative")
	}
	// Minimal big-endian encoding; common exponents (e.g. 65537) take
	// three bytes, and the int range caps the result at eight.
	eBytes := make([]byte, 0, 4)
	for ; e > 0; e >>= 8 {
		eBytes = append([]byte{byte(e)}, eBytes...)
	}
	r.E = base64.RawURLEncoding.EncodeToString(eBytes)
	return nil
}
// encodeECDSA populates the ECDSA-specific fields ("crv", "x", "y"). It
// enforces fixed-width padding for coordinates as required by RFC 7518.
func encodeECDSA(key *ecdsa.PublicKey, r *raw) error {
	r.Kty = "EC"
	r.Crv = key.Params().Name
	// key.Bytes yields the SEC 1 uncompressed encoding: 0x04 || X || Y.
	encoded, err := key.Bytes()
	if err != nil {
		return fmt.Errorf("encode ecdsa key: %w", err)
	}
	if len(encoded) < 1 || encoded[0] != 4 {
		return errors.New("invalid public key format")
	}
	// The coordinate width follows from the slice length, which already
	// carries the fixed-width padding the JWK format demands.
	size := (len(encoded) - 1) / 2
	r.X = base64.RawURLEncoding.EncodeToString(encoded[1 : 1+size])
	r.Y = base64.RawURLEncoding.EncodeToString(encoded[1+size : 1+2*size])
	return nil
}
// encodeEdDSA populates the EdDSA-specific fields ("crv", "x"). The curve
// name is derived from the key length.
func encodeEdDSA(key []byte, r *raw) error {
	r.Kty = "OKP"
	switch len(key) {
	case ed448.PublicKeySize:
		r.Crv = "Ed448"
	case ed25519.PublicKeySize:
		r.Crv = "Ed25519"
	default:
		return fmt.Errorf("invalid EdDSA key length: %d", len(key))
	}
	r.X = base64.RawURLEncoding.EncodeToString(key)
	return nil
}
// init populates the readers and writers lookup tables with every supported
// algorithm. All RSA-based algorithms (RSxxx and PSxxx) share one codec, as
// their JWK key material is identical; each ECDSA reader is bound to its
// curve.
func init() {
	// size matches the number of registrations below.
	const size = 10
	readers = make(map[string]reader, size)
	addReader(jwa.RS256, decodeRSA)
	addReader(jwa.RS384, decodeRSA)
	addReader(jwa.RS512, decodeRSA)
	addReader(jwa.PS256, decodeRSA)
	addReader(jwa.PS384, decodeRSA)
	addReader(jwa.PS512, decodeRSA)
	addReader(jwa.ES256, decodeECDSA(elliptic.P256()))
	addReader(jwa.ES384, decodeECDSA(elliptic.P384()))
	addReader(jwa.ES512, decodeECDSA(elliptic.P521()))
	addReader(jwa.EdDSA, decodeEdDSA)
	writers = make(map[string]writer, size)
	addWriter(jwa.RS256, encodeRSA)
	addWriter(jwa.RS384, encodeRSA)
	addWriter(jwa.RS512, encodeRSA)
	addWriter(jwa.PS256, encodeRSA)
	addWriter(jwa.PS384, encodeRSA)
	addWriter(jwa.PS512, encodeRSA)
	addWriter(jwa.ES256, encodeECDSA)
	addWriter(jwa.ES384, encodeECDSA)
	addWriter(jwa.ES512, encodeECDSA)
	addWriter(jwa.EdDSA, encodeEdDSA)
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package jwt provides tools for parsing, verifying, and signing JSON Web
// Tokens (JWTs).
//
// This package uses generics to allow users to define their own custom claims
// structures. A common pattern is to embed the provided [Reserved] claims
// struct and add extra fields for any other claims present in the token.
//
// # Basic Verification
//
// Start by defining custom claims:
//
// type Claims struct {
// jwt.Reserved
// Scope string `json:"scp"`
// Extra map[string]any `json:",unknown"`
// }
//
// The top-level [Verify] function can be used for simple, one-off signature
// verification without claim validation:
//
// keySet, err := jwk.ParseSet([]byte(`{"keys": [...]}`))
// if err != nil { /* handle parsing error */ }
// claims, err := jwt.Verify[Claims](keySet, []byte("eyJhb..."))
//
// # Advanced Validation
//
// For advanced validation of claims like issuer, audience, and token age,
// create a reusable [Verifier] with the desired configuration using functional
// options:
//
// verifier := jwt.NewVerifier[Claims](
// keySet,
// jwt.WithIssuers("foo", "bar"),
// jwt.WithAudiences("baz"),
// jwt.WithLeeway(1 * time.Minute),
// jwt.WithMaxAge(1 * time.Hour),
// )
//
// claims, err := verifier.Verify([]byte("eyJhb..."))
// if err != nil { /* handle validation error */ }
// fmt.Println("Scope:", claims.Scope)
//
// # Basic Signing
//
// The top-level [Sign] function can be used to create signed tokens from any
// JSON-serializable struct or map. This is useful for simple tokens where
// you manually handle all claims:
//
// // keyPair must be a jwk.KeyPair (containing a private key)
// claims := map[string]any{"sub": "user_123", "admin": true}
// token, err := jwt.Sign(keyPair, claims)
//
// # Advanced Signing
//
// To enforce policies like expiration or consistent issuers, create a reusable
// [Signer]. Your claims struct must implement [MutableClaims] (embedding
// [Reserved] handles this automatically).
//
// signer := jwt.NewSigner(
// []jwk.KeyPair{keyPair},
// jwt.WithIssuer("https://api.example.com"),
// jwt.WithLifetime(1 * time.Hour),
// )
//
// // The signer will automatically set "iss", "iat", and "exp" on the struct.
// claims := &MyClaims{
// Reserved: jwt.Reserved{Subject: "user_123"},
// Scope: "admin",
// }
// token, err := signer.Sign(claims)
//
// # Usage
//
// Verify a JWT using a custom claims struct and standard validation rules.
//
// Example:
//
// verifier := jwt.NewVerifier[MyClaims](keySet, jwt.WithIssuers("trusted"))
// claims, err := verifier.Verify(tokenBytes)
package jwt
import (
"bytes"
"encoding/base64"
"encoding/json/jsontext"
"encoding/json/v2"
"errors"
"fmt"
"slices"
"time"
"github.com/deep-rent/nexus/internal/rotor"
"github.com/deep-rent/nexus/jose/jwk"
)
// Header provides access to the metadata associated with a JWT, such as the
// cryptographic algorithm used to sign the token and identifiers for the
// signing key.
//
// It is an alias for [jwk.Hint], allowing it to be passed directly to a
// [jwk.Set]'s Find method to locate the appropriate verification key.
type Header jwk.Hint

// header is the concrete implementation of the [Header] interface, providing
// JSON tags for standard JWS header parameters.
type header struct {
	// Typ is the media type of the JWT (the "typ" parameter).
	Typ string `json:"typ,omitempty"`
	// Alg is the JWA algorithm identifier (the "alg" parameter).
	Alg string `json:"alg"`
	// Kid is the key identifier (the "kid" parameter).
	Kid string `json:"kid,omitempty"`
	// X5t is the SHA-256 thumbprint of the X.509 certificate
	// (the "x5t#S256" parameter).
	X5t string `json:"x5t#S256,omitempty"`
}

// Type returns the "typ" parameter from the header.
func (h *header) Type() string { return h.Typ }

// Algorithm implements [jwk.Hint] by returning the "alg" parameter.
func (h *header) Algorithm() string { return h.Alg }

// KeyID implements [jwk.Hint] by returning the "kid" parameter.
func (h *header) KeyID() string { return h.Kid }

// Thumbprint implements [jwk.Hint] by returning the "x5t#S256" parameter.
func (h *header) Thumbprint() string { return h.X5t }

// Ensure header implements the Header interface.
var _ Header = (*header)(nil)
var (
	// ErrKeyNotFound is returned when no matching key is found in the JWK set.
	ErrKeyNotFound = errors.New("no matching key found")
	// ErrInvalidSignature is returned when the token's signature differs from
	// the computed signature.
	ErrInvalidSignature = errors.New("invalid signature")
)

// Token represents a parsed, but not necessarily verified, JWT.
// The generic type T is the user-defined claims structure.
type Token[T Claims] interface {
	// Header returns the token's header parameters.
	Header() Header
	// Claims returns the token's payload claims.
	Claims() T
	// Verify checks the token's signature using the provided JWK set.
	// It returns [ErrKeyNotFound] if no matching key is found or
	// [ErrInvalidSignature] if the signature is incorrect.
	Verify(set jwk.Set) error
}
// token is the internal implementation of the [Token] interface.
type token[T Claims] struct {
	// header contains the JWS header fields.
	header Header
	// claims contains the unmarshaled payload.
	claims T
	// msg is the raw signing input: the Base64URL-encoded JWS Protected
	// Header and JWS Payload, joined by a dot.
	msg []byte
	// sig is the decoded JWS Signature.
	sig []byte
}

// Header implements [Token].
func (t *token[T]) Header() Header { return t.header }

// Claims implements [Token].
func (t *token[T]) Claims() T { return t.claims }

// Verify implements [Token]. It locates a verification key via the header
// hints and checks the signature over the signing input.
func (t *token[T]) Verify(set jwk.Set) error {
	key := set.Find(t.header)
	if key == nil {
		return ErrKeyNotFound
	}
	if !key.Verify(t.msg, t.sig) {
		return ErrInvalidSignature
	}
	return nil
}

// Ensure token implements the Token interface.
var _ Token[Claims] = (*token[Claims])(nil)
// audience represents the "aud" (Audience) claim of a JWT as defined in
// RFC 7519, Section 4.1.3.
//
// The "aud" claim may appear either as a single case-sensitive string or as
// an array of such strings, so this type carries custom JSON unmarshaling
// logic that normalizes both shapes to a string slice.
type audience []string

// UnmarshalJSON handles the polymorphic nature of the "aud" claim.
func (a *audience) UnmarshalJSON(b []byte) error {
	var many []string
	if json.Unmarshal(b, &many) == nil {
		*a = audience(many)
		return nil
	}
	var one string
	if json.Unmarshal(b, &one) == nil {
		*a = audience{one}
		return nil
	}
	return errors.New("expected a string or an array of strings")
}
// Claims provides access to the standard JWT claims.
// It is used by [Verifier] for claim validation.
type Claims interface {
	// ID returns the "jti" (JWT ID) claim, or an empty string if absent.
	ID() string
	// Subject returns the "sub" (Subject) claim, or an empty string if absent.
	Subject() string
	// Issuer returns the "iss" (Issuer) claim, or an empty string if absent.
	Issuer() string
	// Audience returns the "aud" (Audience) claim, or nil if absent.
	Audience() []string
	// IssuedAt returns the "iat" (Issued At) claim, or the zero time if absent.
	IssuedAt() time.Time
	// ExpiresAt returns the "exp" (Expires At) claim, or the zero time if absent.
	ExpiresAt() time.Time
	// NotBefore returns the "nbf" (Not Before) claim, or the zero time if absent.
	NotBefore() time.Time
}

// MutableClaims extends [Claims] with setters for standard JWT claims.
// [Signer] uses these setters to stamp configured claims onto a token.
//
// The setter methods are not safe for concurrent use and should only be called
// during token creation.
type MutableClaims interface {
	Claims
	// SetID sets the "jti" (JWT ID) claim.
	SetID(id string)
	// SetSubject sets the "sub" (Subject) claim.
	SetSubject(sub string)
	// SetIssuer sets the "iss" (Issuer) claim.
	SetIssuer(iss string)
	// SetAudience sets the "aud" (Audience) claim.
	SetAudience(aud []string)
	// SetIssuedAt sets the "iat" (Issued At) claim.
	SetIssuedAt(t time.Time)
	// SetExpiresAt sets the "exp" (Expires At) claim.
	SetExpiresAt(t time.Time)
	// SetNotBefore sets the "nbf" (Not Before) claim.
	SetNotBefore(t time.Time)
}
// Reserved contains the standard registered claims for a JWT. It implements
// the [Claims] and [MutableClaims] interfaces and should be embedded in
// custom claims structs to enable standard claim handling.
//
// The temporal claims are serialized as Unix timestamps (seconds), per the
// `format:unix` JSON tags.
type Reserved struct {
	Jti string    `json:"jti,omitempty"`            // JWT ID
	Sub string    `json:"sub,omitempty"`            // Subject
	Iss string    `json:"iss,omitempty"`            // Issuer
	Aud audience  `json:"aud,omitempty"`            // Audience
	Iat time.Time `json:"iat,omitzero,format:unix"` // Issued At
	Exp time.Time `json:"exp,omitzero,format:unix"` // Expires At
	Nbf time.Time `json:"nbf,omitzero,format:unix"` // Not Before
}

// ID implements [Claims].
func (r *Reserved) ID() string { return r.Jti }

// SetID implements [MutableClaims].
func (r *Reserved) SetID(id string) { r.Jti = id }

// Subject implements [Claims].
func (r *Reserved) Subject() string { return r.Sub }

// SetSubject implements [MutableClaims].
func (r *Reserved) SetSubject(sub string) { r.Sub = sub }

// Issuer implements [Claims].
func (r *Reserved) Issuer() string { return r.Iss }

// SetIssuer implements [MutableClaims].
func (r *Reserved) SetIssuer(iss string) { r.Iss = iss }

// Audience implements [Claims].
func (r *Reserved) Audience() []string { return r.Aud }

// SetAudience implements [MutableClaims].
func (r *Reserved) SetAudience(aud []string) { r.Aud = aud }

// IssuedAt implements [Claims].
func (r *Reserved) IssuedAt() time.Time { return r.Iat }

// SetIssuedAt implements [MutableClaims].
func (r *Reserved) SetIssuedAt(t time.Time) { r.Iat = t }

// ExpiresAt implements [Claims].
func (r *Reserved) ExpiresAt() time.Time { return r.Exp }

// SetExpiresAt implements [MutableClaims].
func (r *Reserved) SetExpiresAt(t time.Time) { r.Exp = t }

// NotBefore implements [Claims].
func (r *Reserved) NotBefore() time.Time { return r.Nbf }

// SetNotBefore implements [MutableClaims].
func (r *Reserved) SetNotBefore(t time.Time) { r.Nbf = t }

// Ensure Reserved implements the MutableClaims interface.
var _ MutableClaims = (*Reserved)(nil)
// DynamicClaims represents a standard JWT payload extended with arbitrary
// custom claims. It embeds the standard [Reserved] claims and captures any
// unmapped JSON properties into the Other map.
//
// By applying [jsontext.Value] and the `json:",inline"` tag from the
// encoding/json/v2 package, custom claims are retained as raw JSON bytes.
// This defers parsing until the exact target type is known, avoiding the
// common pitfalls of default map[string]any unmarshaling (such as all
// numbers defaulting to float64). Use [Get] to decode individual entries.
type DynamicClaims struct {
	// Reserved contains the standard registered JWT claims.
	Reserved
	// Other captures all custom claims as raw JSON.
	Other map[string]jsontext.Value `json:",inline"`
}
// Get retrieves a specific custom claim by key from the [DynamicClaims]
// payload and unmarshals it into the requested type T.
//
// It safely handles nil pointers, missing keys, and parsing errors: the zero
// value of T and false are returned when c is nil, when the Other map is
// uninitialized, when the key is absent, or when the raw JSON does not
// unmarshal into T. Otherwise the parsed value and true are returned.
func Get[T any](c *DynamicClaims, key string) (T, bool) {
	var out T
	if c == nil || c.Other == nil {
		return out, false
	}
	raw, ok := c.Other[key]
	if !ok {
		return out, false
	}
	if err := json.Unmarshal(raw, &out); err != nil {
		// Return a fresh zero value so a partially populated out never leaks.
		var zero T
		return zero, false
	}
	return out, true
}
// dot is the byte value of the character separating JWS segments.
const dot = byte('.')

// Parse decodes a JWT from its compact serialization format into a [Token]
// without verifying the signature. The type parameter T specifies the target
// struct for the token's claims. If the token is malformed or the payload does
// not unmarshal into T (using encoding/json/v2), an error is returned.
func Parse[T Claims](in []byte) (Token[T], error) {
	// Locate the two dots delimiting header, payload, and signature.
	first := bytes.IndexByte(in, dot)
	last := bytes.LastIndexByte(in, dot)
	if first <= 0 || first == last || last == len(in)-1 {
		return nil, errors.New("expected three dot-separated segments")
	}
	rawHeader, err := decode(in[:first])
	if err != nil {
		return nil, fmt.Errorf("failed to decode header: %w", err)
	}
	hdr := new(header)
	if err := json.Unmarshal(rawHeader, hdr); err != nil {
		return nil, fmt.Errorf("failed to unmarshal header: %w", err)
	}
	// An absent "typ" is tolerated; any explicit value must be "JWT".
	if typ := hdr.Typ; typ != "" && typ != "JWT" {
		return nil, fmt.Errorf("unexpected token type %q", typ)
	}
	rawClaims, err := decode(in[first+1 : last])
	if err != nil {
		return nil, fmt.Errorf("failed to decode claims: %w", err)
	}
	var claims T
	if err := json.Unmarshal(rawClaims, &claims); err != nil {
		return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
	}
	sig, err := decode(in[last+1:])
	if err != nil {
		return nil, fmt.Errorf("failed to decode signature: %w", err)
	}
	// The signing input is everything before the final dot.
	return &token[T]{
		header: hdr,
		claims: claims,
		msg:    in[:last],
		sig:    sig,
	}, nil
}
// decode is a helper for Base64URL decoding without padding.
func decode(src []byte) ([]byte, error) {
	dst := make([]byte, base64.RawURLEncoding.DecodedLen(len(src)))
	n, err := base64.RawURLEncoding.Decode(dst, src)
	if err != nil {
		return nil, err
	}
	// Trim to the actual number of decoded bytes.
	return dst[:n], nil
}
// Verify first parses a JWT and then verifies its signature against a given key
// set. The type parameter T specifies the target struct for the token's claims.
//
// This function only checks the cryptographic signature, not the content of the
// claims. For claim validation (e.g., issuer, audience, expiration), create and
// configure a [Verifier]. It is a shorthand for [Parse] followed by calling
// [Token.Verify] on the resulting [Token].
func Verify[T Claims](set jwk.Set, in []byte) (T, error) {
	var zero T
	tok, err := Parse[T](in)
	if err != nil {
		return zero, err
	}
	if err := tok.Verify(set); err != nil {
		return zero, err
	}
	return tok.Claims(), nil
}
var (
	// ErrInvalidIssuer signals that the "iss" claim did not match any of the
	// expected issuers.
	ErrInvalidIssuer = errors.New("invalid issuer")
	// ErrInvalidAudience signals that the "aud" claim did not match any of the
	// expected audiences.
	ErrInvalidAudience = errors.New("invalid audience")
	// ErrTokenExpired signals that the "exp" claim is in the past.
	ErrTokenExpired = errors.New("token is expired")
	// ErrTokenNotYetActive signals that the "nbf" claim is in the future.
	ErrTokenNotYetActive = errors.New("token not yet active")
	// ErrTokenTooOld signals that the "iat" claim is further in the past than
	// the configured maximum age.
	ErrTokenTooOld = errors.New("token is too old")
)

// Verifier defines the interface for a configured, reusable JWT verifier. The
// type parameter T is the user-defined struct for the token's claims. It must
// implement the [Claims] interface, or else verification will always fail.
type Verifier[T Claims] interface {
	// Verify parses a token from its compact serialization, verifies its
	// signature against the verifier's key set, and validates its claims
	// according to the verifier's configuration.
	Verify(in []byte) (T, error)
}

// VerifierOption defines a functional option for configuring a [Verifier].
type VerifierOption func(*verifierConfig)

// verifierConfig holds the configuration options for a [Verifier].
type verifierConfig struct {
	// issuers is the list of trusted issuers; "iss" must match one if set.
	issuers []string
	// audiences is the list of trusted audiences; "aud" must contain one if set.
	audiences []string
	// leeway is the clock skew tolerance for temporal checks.
	leeway time.Duration
	// age is the maximum allowed token age based on "iat".
	age time.Duration
	// now is the time source used during validation.
	now func() time.Time
}
// WithIssuers adds one or more trusted issuers to the verifier. A token whose
// "iss" claim is missing or matches none of these values is rejected. The
// option may be used repeatedly to append further values. By default, no
// issuer validation is performed.
func WithIssuers(iss ...string) VerifierOption {
	return func(cfg *verifierConfig) {
		cfg.issuers = append(cfg.issuers, iss...)
	}
}

// WithAudiences adds one or more trusted audiences to the verifier. A token
// whose "aud" claim is missing or contains none of these values is rejected.
// The option may be used repeatedly to append further values. By default, no
// audience validation is performed.
func WithAudiences(aud ...string) VerifierOption {
	return func(cfg *verifierConfig) {
		cfg.audiences = append(cfg.audiences, aud...)
	}
}

// WithLeeway sets a grace period to allow for clock skew in temporal
// validations of the "exp", "nbf", and "iat" claims. It is subtracted from or
// added to the current time as appropriate. The default is zero, meaning no
// leeway. Non-positive values are ignored.
func WithLeeway(d time.Duration) VerifierOption {
	return func(cfg *verifierConfig) {
		if d <= 0 {
			return
		}
		cfg.leeway = d
	}
}

// WithMaxAge sets the maximum age for tokens based on their "iat" claim.
// Tokens without an "iat" claim will no longer be accepted. The default is
// zero, meaning no age validation. Non-positive values are ignored.
func WithMaxAge(d time.Duration) VerifierOption {
	return func(cfg *verifierConfig) {
		if d <= 0 {
			return
		}
		cfg.age = d
	}
}

// WithVerifierClock sets the function used to retrieve the current time during
// validation. This is useful for deterministic testing or synchronizing with
// an external time source. The default is [time.Now]. Nil is ignored.
func WithVerifierClock(now func() time.Time) VerifierOption {
	return func(cfg *verifierConfig) {
		if now == nil {
			return
		}
		cfg.now = now
	}
}
// verifier is the default implementation of the [Verifier] interface.
type verifier[T Claims] struct {
	// set holds the JWK set used for signature verification.
	set jwk.Set
	// cfg holds the claim-validation rules.
	cfg verifierConfig
}

// Ensure verifier implements the Verifier interface.
var _ Verifier[Claims] = (*verifier[Claims])(nil)

// NewVerifier creates a new [Verifier] bound to a specific JWK set.
// The type parameter T is the user-defined struct for the token's claims.
func NewVerifier[T Claims](set jwk.Set, opts ...VerifierOption) Verifier[T] {
	cfg := verifierConfig{now: time.Now}
	for _, apply := range opts {
		apply(&cfg)
	}
	return &verifier[T]{set: set, cfg: cfg}
}
// Verify implements the [Verifier] interface. It checks the token's signature
// against the verifier's key set, then validates the claims according to the
// configured rules:
//
//   - "iss" must match one of the configured issuers (if any were set).
//   - "aud" must contain at least one configured audience (if any were set).
//   - "nbf" must not be in the future and "exp" must not be in the past,
//     each allowing for the configured leeway.
//   - If a maximum age is configured, "iat" must be present and must not lie
//     further in the past than that age (again allowing for leeway).
func (v *verifier[T]) Verify(in []byte) (T, error) {
	var zero T
	c, err := Verify[T](v.set, in)
	if err != nil {
		return zero, err
	}
	now := v.cfg.now()
	if len(v.cfg.issuers) > 0 && !slices.Contains(v.cfg.issuers, c.Issuer()) {
		return zero, ErrInvalidIssuer
	}
	if len(v.cfg.audiences) > 0 {
		found := false
		for _, aud := range v.cfg.audiences {
			if slices.Contains(c.Audience(), aud) {
				found = true
				break
			}
		}
		if !found {
			return zero, ErrInvalidAudience
		}
	}
	if nbf := c.NotBefore(); !nbf.IsZero() && now.Add(v.cfg.leeway).Before(nbf) {
		return zero, ErrTokenNotYetActive
	}
	if exp := c.ExpiresAt(); !exp.IsZero() && now.Add(-v.cfg.leeway).After(exp) {
		return zero, ErrTokenExpired
	}
	if v.cfg.age > 0 {
		// WithMaxAge documents that tokens without an "iat" claim are
		// rejected when a maximum age is configured. Previously a missing
		// "iat" silently bypassed the age check, allowing arbitrarily old
		// tokens through; treat a zero "iat" as too old instead.
		iat := c.IssuedAt()
		if iat.IsZero() || iat.Add(v.cfg.age).Before(now.Add(-v.cfg.leeway)) {
			return zero, ErrTokenTooOld
		}
	}
	return c, nil
}
// Sign creates a new signed JWT using the provided [jwk.KeyPair] and claims.
//
// It marshals the claims using encoding/json/v2, creates a header based on
// the key's properties, and signs the payload. The claims argument can be
// any type that serializes to a JSON object.
func Sign(k jwk.KeyPair, claims any) ([]byte, error) {
	// Build the JWS Protected Header from the key's properties.
	hdr := &header{
		Typ: "JWT",
		Alg: k.Algorithm(),
		Kid: k.KeyID(),
		X5t: k.Thumbprint(),
	}
	rawHeader, err := json.Marshal(hdr)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal header: %w", err)
	}
	rawClaims, err := json.Marshal(claims)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal claims: %w", err)
	}
	encHeader := encode(rawHeader)
	encClaims := encode(rawClaims)
	// The signing input is "base64url(header) '.' base64url(claims)".
	msg := make([]byte, 0, len(encHeader)+1+len(encClaims))
	msg = append(msg, encHeader...)
	msg = append(msg, dot)
	msg = append(msg, encClaims...)
	sig, err := k.Sign(msg)
	if err != nil {
		return nil, fmt.Errorf("failed to sign token: %w", err)
	}
	encSig := encode(sig)
	// The compact serialization appends "'.' base64url(signature)".
	out := make([]byte, 0, len(msg)+1+len(encSig))
	out = append(out, msg...)
	out = append(out, dot)
	out = append(out, encSig...)
	return out, nil
}
// encode is a helper for Base64URL encoding without padding.
func encode(src []byte) []byte {
dst := make([]byte, base64.RawURLEncoding.EncodedLen(len(src)))
base64.RawURLEncoding.Encode(dst, src)
return dst
}
// Signer defines the interface for a configured, reusable JWT creator.
type Signer interface {
	// Sign applies the signer's configuration (issuer, audience, and temporal
	// validity) directly to the mutable claims object, then signs it.
	Sign(claims MutableClaims) ([]byte, error)
}

// SignerOption defines a functional option for configuring a [Signer].
type SignerOption func(*signerConfig)

// signerConfig holds the configuration options for a [Signer].
type signerConfig struct {
	// iat determines if "iat" should be added automatically.
	iat bool
	// iss is the fixed issuer to set on every token.
	iss string
	// aud is the fixed audience list to set on every token.
	aud []string
	// ttl is the token lifetime used to derive "exp".
	ttl time.Duration
	// now is the time source used for timestamping.
	now func() time.Time
}
// WithIssuedAt enables or disables automatic setting of the "iat" (Issued At)
// claim for all tokens created by this signer. It is enabled by default and
// stamps each token with the current time at signing.
func WithIssuedAt(use bool) SignerOption {
	return func(cfg *signerConfig) {
		cfg.iat = use
	}
}

// WithIssuer sets the "iss" (Issuer) claim for all tokens created by this
// signer. Any issuer already present in the user-provided claims is
// overwritten.
func WithIssuer(iss string) SignerOption {
	return func(cfg *signerConfig) {
		cfg.iss = iss
	}
}

// WithAudience sets the "aud" (Audience) claim. Any audience already present
// in the user-provided claims is overwritten.
func WithAudience(aud ...string) SignerOption {
	return func(cfg *signerConfig) {
		cfg.aud = aud
	}
}

// WithLifetime sets the duration for which tokens are valid. The "exp"
// (Expires At) claim is derived by adding this duration to the signing time.
// If zero (default), no "exp" claim is added unless provided in the input
// claims. Non-positive values are ignored.
func WithLifetime(d time.Duration) SignerOption {
	return func(cfg *signerConfig) {
		if d <= 0 {
			return
		}
		cfg.ttl = d
	}
}

// WithSignerClock sets the function used to retrieve the current time when
// timestamping tokens ("iat", "nbf", "exp"). This is useful for deterministic
// testing. The default is [time.Now]. Nil is ignored.
func WithSignerClock(now func() time.Time) SignerOption {
	return func(cfg *signerConfig) {
		if now == nil {
			return
		}
		cfg.now = now
	}
}
// signer is the default implementation of the [Signer] interface.
type signer struct {
	// rot hands out key pairs in round-robin order.
	rot rotor.Rotor[jwk.KeyPair]
	// cfg contains the generation rules.
	cfg signerConfig
}

// Ensure signer implements the Signer interface.
var _ Signer = (*signer)(nil)
// NewSigner creates a new [Signer] that uses the provided key pool for
// signing. At least one key pair must be provided in the slice; otherwise, it
// panics. If multiple keys are given, they are rotated through in a
// round-robin fashion to ensure even usage across the key pool.
func NewSigner(keys []jwk.KeyPair, opts ...SignerOption) Signer {
	if len(keys) == 0 {
		panic("jwt: at least one key pair is required to create a signer")
	}
	cfg := signerConfig{iat: true, now: time.Now}
	for _, apply := range opts {
		apply(&cfg)
	}
	return &signer{rot: rotor.New(keys), cfg: cfg}
}
// Sign implements the [Signer] interface. It stamps the configured claims
// ("iat", "iss", "aud", and "exp") onto the given claims object, picks the
// next key from the rotation pool, and delegates to the low-level [Sign].
func (s *signer) Sign(claims MutableClaims) ([]byte, error) {
	now := s.cfg.now()
	if s.cfg.iat {
		claims.SetIssuedAt(now)
	}
	if iss := s.cfg.iss; iss != "" {
		claims.SetIssuer(iss)
	}
	if aud := s.cfg.aud; len(aud) > 0 {
		claims.SetAudience(aud)
	}
	if ttl := s.cfg.ttl; ttl > 0 {
		claims.SetExpiresAt(now.Add(ttl))
	}
	return Sign(s.rot.Next(), claims)
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package log provides a configurable constructor for the standard
// [slog.Logger], allowing for easy setup using the functional options pattern.
// It simplifies the creation of a structured logger by abstracting away the
// handler setup and providing flexible options for setting the level, format,
// and output from common types like strings.
//
// # Usage
//
// Create a logger that outputs JSON at a debug level to standard error:
//
// Example:
//
// logger := log.New(
// log.WithLevel("debug"),
// log.WithFormat("json"),
// log.WithWriter(os.Stderr),
// log.WithAddSource(true), // Include file and line number.
// )
//
// slog.SetDefault(logger)
// slog.Debug("This is a debug message")
//
// Create a multi-target logger using Combine and NewHandler:
//
// Example:
//
// h1 := log.NewHandler(
// log.WithLevel("debug"),
// log.WithFormat("text"),
// log.WithWriter(os.Stdout),
// )
// h2 := log.NewHandler(
// log.WithLevel("error"),
// log.WithFormat("json"),
// log.WithWriter(os.Stderr),
// )
// multiLogger := log.Combine(h1, h2)
//
// slog.SetDefault(multiLogger)
// slog.Debug("This is a debug message")
package log
import (
"fmt"
"io"
"log/slog"
"os"
"strings"
)
// Default configuration values for a new logger or handler.
const (
	// DefaultLevel is the level used when none is specified.
	DefaultLevel = slog.LevelInfo
	// DefaultAddSource is the default setting for including source information.
	DefaultAddSource = false
	// DefaultFormat is the format used when none is specified.
	DefaultFormat = FormatText
)
// Format defines the log output format, such as JSON or plain text.
type Format uint8

const (
	// FormatText produces human-readable text format.
	FormatText Format = iota
	// FormatJSON produces JSON format suitable for machine parsing.
	FormatJSON
)

// String returns the lower-case string representation of the log format.
// Any unrecognized value falls back to "text".
func (f Format) String() string {
	if f == FormatJSON {
		return "json"
	}
	return "text"
}
// New creates and configures a new [slog.Logger]. By default, it logs at
// [slog.LevelInfo] in plain text to [os.Stdout], without source information.
// These defaults can be overridden by passing in one or more [Option] functions.
func New(opts ...Option) *slog.Logger {
	return slog.New(NewHandler(opts...))
}

// Combine creates a new [slog.Logger] that broadcasts log records to multiple
// provided [slog.Handler] instances simultaneously using [slog.NewMultiHandler].
// Use [NewHandler] to build differently configured handlers to combine.
func Combine(handlers ...slog.Handler) *slog.Logger {
	return slog.New(slog.NewMultiHandler(handlers...))
}
// NewHandler creates and configures a new [slog.Handler]. By default, it
// sets up a text handler logging at [slog.LevelInfo] to [os.Stdout].
// These defaults can be overridden by passing in one or more [Option] functions.
func NewHandler(opts ...Option) slog.Handler {
	cfg := config{
		Level:     DefaultLevel,
		AddSource: DefaultAddSource,
		Format:    DefaultFormat,
		Writer:    os.Stdout,
	}
	for _, apply := range opts {
		apply(&cfg)
	}
	ho := &slog.HandlerOptions{
		Level:     cfg.Level,
		AddSource: cfg.AddSource,
	}
	if cfg.Format == FormatJSON {
		return slog.NewJSONHandler(cfg.Writer, ho)
	}
	return slog.NewTextHandler(cfg.Writer, ho)
}
// config holds the configuration settings for the logger.
type config struct {
	// Level is the minimum log level enabled.
	Level slog.Level
	// AddSource determines if file/line information is included.
	AddSource bool
	// Format determines the output encoding (text or JSON).
	Format Format
	// Writer is the output destination.
	Writer io.Writer
}

// Option defines a function that modifies the logger configuration.
type Option func(*config)
// WithLevel sets the minimum log level. It accepts either a [slog.Level]
// constant (e.g., [slog.LevelDebug]) or a case-insensitive string (e.g.,
// "debug") as handled by [ParseLevel]. If an invalid string or type is
// provided, the option is a no-op.
func WithLevel(v any) Option {
	return func(c *config) {
		if lvl, ok := v.(slog.Level); ok {
			c.Level = lvl
			return
		}
		if s, ok := v.(string); ok {
			if lvl, err := ParseLevel(s); err == nil {
				c.Level = lvl
			}
		}
	}
}

// WithFormat sets the log output format. It accepts either a [Format] constant
// ([FormatText] or [FormatJSON]) or a case-insensitive string ("text" or
// "json") as handled by [ParseFormat]. If an invalid string or type is
// provided, the option is a no-op.
func WithFormat(v any) Option {
	return func(c *config) {
		if f, ok := v.(Format); ok {
			c.Format = f
			return
		}
		if s, ok := v.(string); ok {
			if f, err := ParseFormat(s); err == nil {
				c.Format = f
			}
		}
	}
}
// WithAddSource configures the logger to include the source code position (file
// and line number) in each log entry.
//
// Note that this has a performance cost and should be used judiciously, often
// enabled only during development or at debug levels.
func WithAddSource(add bool) Option {
	return func(c *config) {
		c.AddSource = add
	}
}

// WithWriter returns an [Option] that sets the output destination for the logs.
// If the provided [io.Writer] is nil, it is ignored.
func WithWriter(w io.Writer) Option {
	return func(c *config) {
		if w != nil {
			c.Writer = w
		}
	}
}
// ParseLevel converts a case-insensitive string into a [slog.Level].
// It accepts standard level names like "debug", "info", "warn", and "error".
// It returns an error if the string is not a valid level.
func ParseLevel(s string) (slog.Level, error) {
	var level slog.Level
	if level.UnmarshalText([]byte(s)) != nil {
		return level, fmt.Errorf("invalid log level %q", s)
	}
	return level, nil
}
// ParseFormat converts a case-insensitive string into a [Format].
// Valid inputs are "text" and "json". It returns an error for any other value.
func ParseFormat(s string) (Format, error) {
	switch strings.ToLower(s) {
	case "json":
		return FormatJSON, nil
	case "text":
		return FormatText, nil
	}
	// The zero value of Format (FormatText) accompanies the error.
	var format Format
	return format, fmt.Errorf("invalid log format %q", s)
}
// Silent creates a logger that discards all output.
func Silent() *slog.Logger {
	// A level well above slog.LevelError disables every record; io.Discard
	// drops anything that might still slip through.
	const LevelSilent = slog.Level(100)
	return New(
		WithLevel(LevelSilent),
		WithWriter(io.Discard),
	)
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package mail provides abstractions for sending transactional emails.
//
// It defines a generic payload model ([Message]) and a common [Sender]
// interface for email delivery. This decouples the application's
// business logic from the underlying mechanism. By default, this package
// provides a production-ready Twilio SendGrid implementation initialized via
// [NewSender].
//
// # Usage
//
// Typically, you initialize a [Sender] at application startup, construct
// a [Message] using the fluent API, and pass it to the sender.
//
// Example:
//
// // 1. Initialize the default SendGrid sender with a custom User-Agent.
// sender := mail.NewSender("your-api-key", mail.WithUserAgent("MyApp/1.0"))
//
// // 2. Construct the email message.
// msg := mail.NewMessage(
// mail.New("no-reply@example.com", "My App"),
// "template-id-123",
// mail.NewRecipient(mail.New("user@example.com", "Alice")).
// AddTemplateData("name", "Alice"),
// )
//
// // 3. Dispatch the email.
// err := sender.Send(context.Background(), msg)
package mail
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"time"
"github.com/deep-rent/nexus/retry"
)
// Default configuration values for the SendGrid-backed sender.
const (
	// DefaultBaseURL is the standard API endpoint for SendGrid v3.
	DefaultBaseURL = "https://api.sendgrid.com/v3"
	// DefaultTimeout is the default timeout for API requests (5 seconds).
	DefaultTimeout = 5 * time.Second
)
var (
	// ErrNilMessage is returned when a nil [Message] is validated.
	ErrNilMessage = errors.New("mail: message cannot be nil")
	// ErrMissingRecipients is returned when an email has no recipients.
	ErrMissingRecipients = errors.New("mail: at least one recipient is required")
	// ErrMissingTemplateID is returned when an email has no template ID.
	ErrMissingTemplateID = errors.New("mail: template ID is required")
	// ErrMissingFrom is returned when an email has no sender address.
	ErrMissingFrom = errors.New("mail: from address is required")
	// ErrDispatchFailed is returned when the underlying provider rejects the
	// payload.
	ErrDispatchFailed = errors.New("mail: dispatching failed")
)

// APIError describes a non-success response from the email provider's API.
type APIError struct {
	// Status is the HTTP status code returned by the provider.
	Status int
	// Body is the raw response body returned by the provider.
	Body string
}

// Error implements the [error] interface.
func (e *APIError) Error() string {
	return fmt.Sprintf("mail: api returned status %d: %s", e.Status, e.Body)
}

// Unwrap allows [errors.Is] to match against [ErrDispatchFailed].
func (e *APIError) Unwrap() error { return ErrDispatchFailed }

var _ error = (*APIError)(nil)
// Mail represents an email address and an optional display name.
type Mail struct {
	// Addr is the actual email address (e.g., "alice@example.com").
	Addr string `json:"email"`
	// Name is an optional display name (e.g., "Alice Smith").
	Name string `json:"name,omitzero"`
}

// New creates a new [Mail] with an optional display name. Pass an empty
// name to use the bare address only.
func New(addr, name string) Mail {
	return Mail{Addr: addr, Name: name}
}

// String implements [fmt.Stringer]. It renders "Name <addr>" when a display
// name is set and just the address otherwise.
func (m Mail) String() string {
	if name := m.Name; name != "" {
		return fmt.Sprintf("%s <%s>", name, m.Addr)
	}
	return m.Addr
}
// Recipient represents a single intended recipient or group of receivers,
// along with the specific template data to be used for them.
type Recipient struct {
	// To contains the primary [Mail] recipients.
	To []Mail `json:"to"`
	// CC contains the carbon copy [Mail] recipients.
	CC []Mail `json:"cc,omitzero"`
	// TemplateData holds the key-value pairs used to populate the template
	// variables for this specific recipient group.
	TemplateData map[string]any `json:"dynamic_template_data,omitzero"`
}

// NewRecipient creates a new [Recipient] group with the required primary
// destinations.
func NewRecipient(to ...Mail) *Recipient {
	return &Recipient{To: to}
}

// AddTo appends one or more [Mail] recipients to the To list and returns the
// receiver for chaining.
func (r *Recipient) AddTo(mails ...Mail) *Recipient {
	r.To = append(r.To, mails...)
	return r
}

// AddCC appends one or more [Mail] recipients to the CC list and returns the
// receiver for chaining.
func (r *Recipient) AddCC(mails ...Mail) *Recipient {
	r.CC = append(r.CC, mails...)
	return r
}

// AddTemplateData adds or updates a key-value pair in the TemplateData map,
// lazily allocating the map on first use.
func (r *Recipient) AddTemplateData(key string, value any) *Recipient {
	data := r.TemplateData
	if data == nil {
		data = make(map[string]any)
		r.TemplateData = data
	}
	data[key] = value
	return r
}

// SetTemplateData replaces the entire TemplateData map for the [Recipient].
func (r *Recipient) SetTemplateData(data map[string]any) *Recipient {
	r.TemplateData = data
	return r
}

// Validate checks if the [Recipient] group has at least one primary
// destination. A nil receiver is treated as invalid.
func (r *Recipient) Validate() error {
	if r != nil && len(r.To) > 0 {
		return nil
	}
	return ErrMissingRecipients
}
// Message represents a transactional email payload designed for dynamic
// templates.
type Message struct {
	// From is the sender's [Mail] address.
	From Mail `json:"from"`
	// Recipients contains groups of receivers and their specific template data.
	Recipients []*Recipient `json:"personalizations"`
	// ReplyTo is an optional [Mail] address where replies should be directed.
	ReplyTo *Mail `json:"reply_to,omitzero"`
	// TemplateID is the provider-specific identifier of the dynamic template to
	// use.
	TemplateID string `json:"template_id"`
}

// NewMessage creates a new [Message] with the required fields.
func NewMessage(from Mail, templateID string, recipients ...*Recipient) *Message {
	return &Message{From: from, TemplateID: templateID, Recipients: recipients}
}

// AddRecipient appends a [Recipient] group to the [Message] and returns the
// receiver for chaining.
func (m *Message) AddRecipient(r *Recipient) *Message {
	m.Recipients = append(m.Recipients, r)
	return m
}

// WithReplyTo sets an optional ReplyTo [Mail] address on the [Message].
func (m *Message) WithReplyTo(mail Mail) *Message {
	m.ReplyTo = &mail
	return m
}

// Validate checks if the [Message] has the minimum required fields for
// sending: a from address, at least one valid recipient group, and a
// template ID. A nil receiver yields [ErrNilMessage].
func (m *Message) Validate() error {
	switch {
	case m == nil:
		return ErrNilMessage
	case m.From.Addr == "":
		return ErrMissingFrom
	case len(m.Recipients) == 0:
		return ErrMissingRecipients
	}
	for _, r := range m.Recipients {
		if err := r.Validate(); err != nil {
			return err
		}
	}
	if m.TemplateID == "" {
		return ErrMissingTemplateID
	}
	return nil
}
// Sender is the interface that wraps the Send method.
//
// Implementations of this interface are expected to be safe for concurrent
// use by multiple goroutines. They should respect the provided context for
// timeouts and cancellation.
type Sender interface {
	// Send dispatches the provided [Message] payload to the underlying provider.
	// It returns an error if the email is invalid (see [Message.Validate]),
	// if the network request fails, or if the provider rejects the payload.
	Send(ctx context.Context, msg *Message) error
}
// sender is a SendGrid email client that implements the [Sender] interface.
//
// It manages the HTTP client and authentication state required to interact
// with the SendGrid API. Once initialized via [NewSender], a [sender] is safe
// for concurrent use by multiple goroutines: its fields are set once during
// construction and never mutated afterwards.
type sender struct {
	// auth stores the pre-computed "Bearer <key>" Authorization header value.
	auth string
	// url is the fully resolved "mail/send" endpoint requests are sent to.
	url string
	// userAgent is the value sent in the User-Agent header of requests; if
	// empty, no User-Agent header is set explicitly.
	userAgent string
	// client holds the configured [http.Client].
	client *http.Client
	// logger is used for structured diagnostic output.
	logger *slog.Logger
	// retry contains the retry configuration options.
	retry []retry.Option
}
// Compile-time check that *sender satisfies [Sender].
var _ Sender = (*sender)(nil)
// config holds the optional configuration for the [sender], accumulated by
// applying [Option] functions inside [NewSender].
type config struct {
	// client holds a custom [http.Client], if provided via [WithClient].
	client *http.Client
	// baseURL overrides the default SendGrid API endpoint.
	baseURL string
	// userAgent defines the User-Agent header value for outgoing requests.
	userAgent string
	// timeout sets the maximum [time.Duration] for HTTP requests; only used
	// when no custom client is supplied.
	timeout time.Duration
	// retry stores options for the HTTP transport retry mechanism.
	retry []retry.Option
	// logger specifies the custom structured [slog.Logger].
	logger *slog.Logger
}
// Option defines the functional option pattern for configuring the [sender].
// Options are applied in the order they are passed to [NewSender].
type Option func(*config)
// WithClient allows passing a custom [http.Client] to the [sender].
// If provided, it overrides the [WithTimeout] setting. Nil values will be
// ignored.
func WithClient(client *http.Client) Option {
	return func(c *config) {
		if client == nil {
			return
		}
		c.client = client
	}
}

// WithBaseURL allows overriding the SendGrid API base URL for testing or
// mocking.
func WithBaseURL(url string) Option {
	return func(c *config) { c.baseURL = url }
}

// WithUserAgent configures a custom User-Agent header for the outbound
// API requests.
func WithUserAgent(v string) Option {
	return func(c *config) { c.userAgent = v }
}

// WithTimeout configures the timeout for the default [http.Client].
func WithTimeout(d time.Duration) Option {
	return func(c *config) { c.timeout = d }
}

// WithRetryOptions configures the retry mechanism for the default HTTP client.
// If a custom HTTP client is provided via [WithClient], these options are
// ignored. Calls accumulate; each call appends to previously added options.
func WithRetryOptions(opts ...retry.Option) Option {
	return func(c *config) {
		c.retry = append(c.retry, opts...)
	}
}

// WithLogger injects a structured [slog.Logger] into the sender.
// Nil values will be ignored.
func WithLogger(logger *slog.Logger) Option {
	return func(c *config) {
		if logger == nil {
			return
		}
		c.logger = logger
	}
}
// NewSender creates a configured SendGrid client implementing the [Sender]
// interface.
//
// Defaults are [DefaultBaseURL], [DefaultTimeout], and [slog.Default]; all
// can be overridden by passing one or more [Option] functions. Unless a
// custom [http.Client] is supplied via [WithClient], an internal client is
// built with connection pooling and automatic retry capabilities. It panics
// if the API key is empty or the base URL cannot be joined into a valid
// endpoint.
func NewSender(apiKey string, opts ...Option) Sender {
	if apiKey == "" {
		panic("mail: API key is required")
	}
	cfg := config{
		baseURL: DefaultBaseURL,
		timeout: DefaultTimeout,
		logger:  slog.Default(),
	}
	for _, apply := range opts {
		apply(&cfg)
	}
	endpoint, err := url.JoinPath(cfg.baseURL, "mail/send")
	if err != nil {
		panic(fmt.Errorf("mail: invalid base URL: %w", err))
	}
	s := &sender{
		auth:      "Bearer " + apiKey,
		url:       endpoint,
		userAgent: cfg.userAgent,
		logger:    cfg.logger,
		retry:     cfg.retry,
	}
	if cfg.client != nil {
		// An injected client brings its own transport, so the retry
		// options cannot be wired into it.
		if len(cfg.retry) > 0 {
			cfg.logger.Warn("Custom client provided; retry options will be ignored")
		}
		s.client = cfg.client
		return s
	}
	// Build the default client. Roughly a third of the total timeout is
	// budgeted for dialing and for the TLS handshake each, and most of it
	// for receiving the response headers.
	dialer := &net.Dialer{Timeout: cfg.timeout / 3}
	transport := &http.Transport{
		Proxy:                 http.ProxyFromEnvironment,
		DialContext:           dialer.DialContext,
		TLSHandshakeTimeout:   cfg.timeout / 3,
		ResponseHeaderTimeout: cfg.timeout * 9 / 10,
		ExpectContinueTimeout: 1 * time.Second,
		MaxIdleConns:          100,
		MaxIdleConnsPerHost:   100,
		IdleConnTimeout:       90 * time.Second,
	}
	s.client = &http.Client{
		Timeout:   cfg.timeout,
		Transport: retry.NewTransport(transport, cfg.retry...),
	}
	return s
}
// Send executes the HTTP request to the SendGrid v3 API.
//
// It maps the domain [Message] payload into SendGrid's expected JSON
// structure and dispatches the request. It respects the provided
// [context.Context] for timeouts and cancellation. If the API responds with
// an HTTP status code >= 400, it returns an [*APIError] carrying the status
// code and the (size-capped) response body.
func (s *sender) Send(ctx context.Context, msg *Message) error {
	if err := msg.Validate(); err != nil {
		return err
	}
	var buf bytes.Buffer
	if err := json.NewEncoder(&buf).Encode(msg); err != nil {
		return fmt.Errorf("mail: failed to encode payload: %w", err)
	}
	req, err := http.NewRequestWithContext(
		ctx,
		http.MethodPost,
		s.url,
		&buf,
	)
	if err != nil {
		return fmt.Errorf("mail: failed to create request: %w", err)
	}
	req.Header.Set("Authorization", s.auth)
	if s.userAgent != "" {
		req.Header.Set("User-Agent", s.userAgent)
	}
	req.Header.Set("Content-Type", "application/json")
	s.logger.DebugContext(ctx, "Dispatching message to provider",
		slog.String("template_id", msg.TemplateID),
		slog.Int("recipients", len(msg.Recipients)),
	)
	start := time.Now()
	res, err := s.client.Do(req)
	if err != nil {
		return fmt.Errorf("mail: request failed: %w", err)
	}
	delta := time.Since(start)
	defer func() {
		// Drain body to ensure connection reuse:
		_, _ = io.Copy(io.Discard, res.Body)
		err := res.Body.Close()
		if err != nil {
			s.logger.WarnContext(
				ctx,
				"Failed to close response body",
				slog.Any("error", err),
			)
		}
	}()
	if code := res.StatusCode; code >= 400 {
		// Cap the captured error payload at 1 MiB to avoid unbounded memory
		// use on misbehaving responses; the read error is deliberately
		// ignored (best effort, the status code alone is already an error).
		body, _ := io.ReadAll(io.LimitReader(res.Body, 1<<20))
		return &APIError{
			Status: code,
			Body: string(body),
		}
	}
	s.logger.DebugContext(
		ctx,
		"Message dispatched",
		slog.Duration("duration", delta),
	)
	return nil
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package cors provides a configurable CORS (Cross-Origin Resource Sharing)
// middleware for [http.Handler] instances. It automatically handles preflight
// (OPTIONS) requests and injects the appropriate CORS headers into responses
// for actual requests.
//
// # Usage
//
// The [New] function creates the middleware pipe, which can be configured with
// functional options (e.g., [WithAllowedOrigins], [WithAllowedMethods]).
//
// Example:
//
// // Configure CORS to allow requests from a specific origin with
// // restricted methods and additional headers.
// pipe := cors.New(
// cors.WithAllowedOrigins("https://example.com"),
// cors.WithAllowedMethods(http.MethodGet, http.MethodOptions),
// cors.WithAllowedHeaders("Authorization", "Content-Type"),
// cors.WithMaxAge(12*time.Hour),
// )
//
// handler := http.HandlerFunc( ... )
// // Apply the CORS middleware as one of the first layers.
// chainedHandler := middleware.Chain(handler, pipe)
//
// http.ListenAndServe(":8080", chainedHandler)
package cors
import (
"net/http"
"slices"
"strconv"
"strings"
"time"
"github.com/deep-rent/nexus/middleware"
)
// wildcard is a special value that can be passed in configuration to allow
// requests from any origin.
const wildcard = "*"

// config stores the pre-computed configuration for internal use.
type config struct {
	// allowedOrigins is the whitelist of permitted Origin values; nil means
	// wildcard mode (any origin).
	allowedOrigins map[string]struct{}
	// allowedMethods is the pre-joined string for Access-Control-Allow-Methods.
	allowedMethods string
	// allowedHeaders is the pre-joined string for Access-Control-Allow-Headers.
	allowedHeaders string
	// exposedHeaders is the pre-joined string for Access-Control-Expose-Headers.
	exposedHeaders string
	// allowCredentials maps to the Access-Control-Allow-Credentials header.
	allowCredentials bool
	// maxAge is the string representation of Access-Control-Max-Age in seconds.
	maxAge string
}

// Option is a function that configures the CORS middleware.
type Option func(*config)

// WithAllowedOrigins sets the allowed origins for CORS requests.
//
// By default, all origins are allowed. The same behavior can be achieved by
// leaving the list empty or by manually including the special wildcard "*". In
// other cases, this option restricts requests to a specific whitelist. If
// credentials are enabled via [WithAllowCredentials], browsers forbid a
// wildcard origin, and this middleware will dynamically reflect the request's
// Origin header if it is in the allowed list.
func WithAllowedOrigins(origins ...string) Option {
	return func(c *config) {
		// An empty list or an explicit "*" keeps the nil map (wildcard mode).
		if len(origins) == 0 || slices.Contains(origins, wildcard) {
			return
		}
		set := make(map[string]struct{}, len(origins))
		for _, origin := range origins {
			set[origin] = struct{}{}
		}
		c.allowedOrigins = set
	}
}

// WithAllowedMethods sets the allowed HTTP methods for CORS requests.
//
// If no methods are provided, this header is omitted by default, and only
// simple methods (GET, POST, HEAD) are implicitly allowed by browsers for
// non-preflighted requests. It is recommended to list all methods your API
// supports, including OPTIONS.
func WithAllowedMethods(methods ...string) Option {
	return func(c *config) {
		if len(methods) == 0 {
			return
		}
		c.allowedMethods = strings.Join(methods, ", ")
	}
}

// WithAllowedHeaders sets the allowed HTTP headers for CORS requests.
//
// This is necessary for any non-standard headers the client needs to send,
// such as "Authorization" or custom "X-" headers. If not set, browsers will
// only permit requests with CORS-safelisted request headers.
func WithAllowedHeaders(headers ...string) Option {
	return func(c *config) {
		if len(headers) == 0 {
			return
		}
		c.allowedHeaders = strings.Join(headers, ", ")
	}
}

// WithExposedHeaders sets the HTTP headers safe to expose to the API.
//
// By default, client-side scripts can only access a limited set of simple
// response headers. This option lists additional headers (like a custom
// "X-Pagination-Total" header) that should be made accessible to the script.
func WithExposedHeaders(headers ...string) Option {
	return func(c *config) {
		if len(headers) == 0 {
			return
		}
		c.exposedHeaders = strings.Join(headers, ", ")
	}
}

// WithAllowCredentials indicates if the response can be exposed with
// credentials.
//
// When used as part of a response to a preflight request, it indicates that
// the actual request can include cookies and other user credentials. This
// option defaults to false. Note that browsers require a specific origin (not
// a wildcard) in the Access-Control-Allow-Origin header when this is enabled.
func WithAllowCredentials(allow bool) Option {
	return func(c *config) { c.allowCredentials = allow }
}

// WithMaxAge indicates how long preflight results can be cached, in seconds.
//
// If set to 0 (the default), the header is omitted. Be aware that browsers
// have a default internal limit (usually 5 seconds) when this header is
// missing. This results in a preflight request for almost every API call,
// which can double the traffic to your server. It is recommended to set this
// to a higher value (e.g., 10 minutes) for stable APIs to reduce latency.
func WithMaxAge(d time.Duration) Option {
	return func(c *config) {
		// Non-positive durations leave the header disabled.
		if d > 0 {
			c.maxAge = strconv.FormatInt(int64(d.Seconds()), 10)
		}
	}
}
// New creates a middleware [middleware.Pipe] that handles CORS requests.
//
// The middleware distinguishes between preflight and actual requests.
// Preflight (OPTIONS) requests are intercepted and terminated with a 204 No
// Content response. For actual requests, it adds the necessary CORS headers
// to the response before passing control to the next handler. Non-CORS
// requests are passed through without modification.
func New(opts ...Option) middleware.Pipe {
	var cfg config
	for _, opt := range opts {
		opt(&cfg)
	}
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !handle(&cfg, w, r) {
				// The request (a preflight) was fully answered by handle.
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
// handle processes CORS headers for the given request.
//
// It returns true if the request should be passed to the next handler. It
// returns false if the request has been fully handled, such as in a preflight
// request.
func handle(cfg *config, w http.ResponseWriter, r *http.Request) bool {
	origin := r.Header.Get("Origin")
	// Pass through non-CORS requests.
	if origin == "" {
		return true
	}
	// Apply this header immediately to ensure caches respect the difference
	// between allowed and disallowed origin responses.
	h := w.Header()
	h.Add("Vary", "Origin")
	preflight := r.Method == http.MethodOptions
	// Pass through invalid preflight requests: a genuine preflight always
	// carries Access-Control-Request-Method; a plain OPTIONS call does not.
	if preflight && r.Header.Get("Access-Control-Request-Method") == "" {
		return true
	}
	// Validate origin if not in wildcard mode (nil map means "allow any").
	if cfg.allowedOrigins != nil {
		if _, ok := cfg.allowedOrigins[origin]; !ok {
			return true // Let non-matching origins pass through without CORS headers.
		}
	}
	// In wildcard mode without credentials, advertise "*" rather than
	// echoing the caller's origin.
	if !cfg.allowCredentials && cfg.allowedOrigins == nil {
		origin = wildcard
	}
	// NOTE(review): when credentials are enabled but no origin whitelist is
	// configured, the request's Origin is reflected verbatim here, which
	// effectively allows credentialed requests from any site — confirm this
	// is the intended policy.
	h.Set("Access-Control-Allow-Origin", origin)
	if cfg.allowCredentials {
		h.Set("Access-Control-Allow-Credentials", "true")
	}
	// Handle preflight requests: advertise the negotiated capabilities and
	// terminate the chain with 204 No Content.
	if preflight {
		if cfg.allowedMethods != "" {
			h.Set("Access-Control-Allow-Methods", cfg.allowedMethods)
		}
		if cfg.allowedHeaders != "" {
			h.Set("Access-Control-Allow-Headers", cfg.allowedHeaders)
		}
		if cfg.maxAge != "" {
			h.Set("Access-Control-Max-Age", cfg.maxAge)
		}
		w.WriteHeader(http.StatusNoContent)
		return false // Terminate request chain.
	}
	// Handle actual requests.
	if cfg.exposedHeaders != "" {
		h.Set("Access-Control-Expose-Headers", cfg.exposedHeaders)
	}
	return true
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package gzip provides an HTTP middleware for compressing response bodies
// using the gzip algorithm. It automatically adds the "Content-Encoding: gzip"
// header and compresses the payload for clients that support it (indicated by
// the "Accept-Encoding" request header).
//
// # Usage
//
// The middleware is designed to be efficient. It pools [gzip.Writer] instances
// to reduce memory allocations and gracefully skips compression for responses
// that already have a "Content-Encoding" header set.
//
// Example:
//
// // Create the final handler.
// handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// w.Header().Set("Content-Type", "text/plain")
// w.Write([]byte("This is a long string that will be compressed."))
// })
//
// // Create a gzip middleware pipe with the highest level of compression.
// pipe := gzip.New(
// gzip.WithCompressionLevel(gzip.BestCompression),
// gzip.WithExcludeMimeTypes("text/*", "application/font-woff"),
// )
//
// // Apply the middleware as one of the first layers.
// chainedHandler := middleware.Chain(handler, pipe)
//
// http.ListenAndServe(":8080", chainedHandler)
package gzip
import (
"bufio"
"compress/gzip"
"errors"
"io"
"net"
"net/http"
"strings"
"sync"
"github.com/deep-rent/nexus/header"
"github.com/deep-rent/nexus/middleware"
)
// Mirror constants from the [compress/gzip] package for easy access without
// requiring an extra import.
const (
	// BestCompression provides the highest level of compression (9).
	BestCompression = gzip.BestCompression
	// BestSpeed provides the fastest compression time (1).
	BestSpeed = gzip.BestSpeed
	// DefaultCompression provides a balance between speed and ratio (-1).
	DefaultCompression = gzip.DefaultCompression
	// NoCompression disables compression entirely (0).
	NoCompression = gzip.NoCompression
)
// defaultExcludeList lists common media types that are already compressed.
// Entries are lowercase; a trailing "*" acts as a prefix wildcard (see
// [WithExcludeMimeTypes] for the matching rules).
var defaultExcludeList = []string{
	// Media
	"image/*",
	"video/*",
	"audio/*",
	// Fonts
	"font/*",
	// Archives & Documents
	"application/zip",
	"application/gzip",
	"application/pdf",
	"application/wasm",
}
// interceptor wraps an [http.ResponseWriter] to compress the response body.
//
// It transparently compresses the response body with gzip. It also implements
// [http.Hijacker] and [http.Flusher] to support protocol upgrades and
// streaming. An interceptor is per-response state and is not safe for
// concurrent use.
type interceptor struct {
	// ResponseWriter is the underlying writer being wrapped.
	http.ResponseWriter
	// gz is the active gzip writer for the current response; nil until the
	// first (non-skipped) WriteHeader and again after Close.
	gz *gzip.Writer
	// level is the compression level to use.
	level int
	// exclude is the list of MIME types to skip.
	exclude []string
	// pool is the sync.Pool used for gzip writer reuse.
	pool *sync.Pool
	// wrote tracks if WriteHeader has been called.
	wrote bool
	// hijacked tracks if the connection has been hijacked.
	hijacked bool
	// skip determines whether to skip compression for this response.
	skip bool
}
// WriteHeader decides whether to compress the response, sets the
// Content-Encoding header, and deletes Content-Length.
//
// Deleting Content-Length is crucial, as the size of the compressed content
// is unknown until it is fully written. Compression is skipped when another
// layer already set a Content-Encoding or when the response media type is on
// the exclusion list.
func (w *interceptor) WriteHeader(statusCode int) {
	if w.wrote {
		return
	}
	w.wrote = true
	h := w.Header()
	// Never double-compress a body another layer has already encoded.
	if h.Get("Content-Encoding") != "" {
		w.skip = true
	}
	if mime := header.MediaType(h); mime != "" {
		for _, pattern := range w.exclude {
			var matched bool
			if prefix, isWildcard := strings.CutSuffix(pattern, "*"); isWildcard {
				matched = strings.HasPrefix(mime, prefix)
			} else {
				matched = mime == pattern
			}
			if matched {
				w.skip = true
				break
			}
		}
	}
	if !w.skip {
		h.Set("Content-Encoding", "gzip")
		h.Del("Content-Length")
		w.gz = w.pool.Get().(*gzip.Writer)
		w.gz.Reset(w.ResponseWriter)
	}
	w.ResponseWriter.WriteHeader(statusCode)
}
// Write compresses the data and writes it to the underlying
// [http.ResponseWriter].
//
// The first call triggers an implicit WriteHeader(200), which is where the
// compression decision is made; excluded responses are written through
// untouched.
func (w *interceptor) Write(b []byte) (int, error) {
	if !w.wrote {
		w.WriteHeader(http.StatusOK)
	}
	if !w.skip {
		return w.gz.Write(b)
	}
	return w.ResponseWriter.Write(b)
}
// Close flushes buffered data, closes the gzip writer, and returns it to the
// pool. It is a no-op when no gzip writer was ever acquired.
func (w *interceptor) Close() {
	gz := w.gz
	if gz == nil {
		return
	}
	w.gz = nil
	// On a hijacked connection the byte stream now belongs to the caller;
	// writing the gzip footer would corrupt it, so only recycle the writer.
	if !w.hijacked {
		_ = gz.Close()
	}
	gz.Reset(io.Discard)
	w.pool.Put(gz)
}
// Hijack implements the [http.Hijacker] interface.
//
// It allows the underlying connection to be taken over for protocol upgrades
// like WebSockets. The hijacked flag is recorded so Close will not emit a
// gzip footer onto the surrendered connection.
func (w *interceptor) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	h, ok := w.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, errors.New("hijacking not supported")
	}
	w.hijacked = true
	return h.Hijack()
}
// Flush implements the [http.Flusher] interface.
//
// It enables incremental flushing of the response body, which is useful for
// streaming data. Before flushing, it commits the response headers through
// the interceptor's own WriteHeader: flushing the underlying writer while no
// header has been written would otherwise send an implicit 200 response
// without the negotiated Content-Encoding, while subsequent writes would
// still be gzip-compressed — corrupting the stream for the client.
func (w *interceptor) Flush() {
	flusher, ok := w.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	// Route the first header write through our WriteHeader so the
	// compression decision (and headers) land before bytes hit the wire.
	if !w.wrote {
		w.WriteHeader(http.StatusOK)
	}
	if w.gz != nil {
		_ = w.gz.Flush()
	}
	flusher.Flush()
}
// Ensure interceptor implements the necessary contracts; the HTTP stack
// discovers Hijacker/Flusher support via type assertions at runtime.
var (
	_ http.ResponseWriter = (*interceptor)(nil)
	_ http.Hijacker = (*interceptor)(nil)
	_ http.Flusher = (*interceptor)(nil)
)
// New creates a middleware [middleware.Pipe] that compresses HTTP responses.
//
// The middleware is a no-op if the client does not send an Accept-Encoding
// header including "gzip" or if the response already has a non-empty
// Content-Encoding header. It adds the "Vary: Accept-Encoding" header to
// responses to prevent cache poisoning.
func New(opts ...Option) middleware.Pipe {
	cfg := config{
		level:   DefaultCompression,
		exclude: defaultExcludeList,
	}
	for _, apply := range opts {
		apply(&cfg)
	}
	// Pool gzip writers across requests to avoid repeated allocations.
	writers := &sync.Pool{
		New: func() any {
			// The level was validated by the option, so the error from
			// NewWriterLevel cannot occur here.
			gw, _ := gzip.NewWriterLevel(io.Discard, cfg.level)
			return gw
		},
	}
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			acceptable := header.Accepts(r.Header.Get("Accept-Encoding"), "gzip")
			if !acceptable || w.Header().Get("Content-Encoding") != "" {
				// The client cannot handle gzip, or another layer already
				// encoded the body: pass through untouched.
				next.ServeHTTP(w, r)
				return
			}
			gzw := &interceptor{
				ResponseWriter: w,
				level:          cfg.level,
				exclude:        cfg.exclude,
				pool:           writers,
			}
			defer gzw.Close()
			// Caches must key this resource on Accept-Encoding.
			gzw.Header().Add("Vary", "Accept-Encoding")
			next.ServeHTTP(gzw, r)
		})
	}
}
// config holds the middleware configuration accumulated by [Option] values.
type config struct {
	// level is the compression level passed to [gzip.NewWriterLevel].
	level int
	// exclude is the list of MIME types to skip; see defaultExcludeList for
	// the pattern format.
	exclude []string
}
// Option is a function that configures the middleware.
type Option func(*config)
// WithCompressionLevel sets the compression level.
//
// It accepts [DefaultCompression] (-1) as well as values ranging from
// [NoCompression] (0) and [BestSpeed] (1) up to [BestCompression] (9). Any
// other value is ignored, leaving the previously configured level in place
// (the middleware default is [DefaultCompression], a good balance between
// speed and compression ratio).
func WithCompressionLevel(level int) Option {
	return func(c *config) {
		// Accept the -1 sentinel explicitly: the former guard
		// (level >= NoCompression) rejected it, so passing
		// WithCompressionLevel(DefaultCompression) could never reset a
		// level chosen by an earlier option.
		if level == DefaultCompression ||
			(level >= NoCompression && level <= BestCompression) {
			c.level = level
		}
	}
}
// WithExcludeMimeTypes adds MIME types to the exclusion list.
//
// This option is additive and can be called multiple times; it appends to the
// default exclusion list rather than replacing it. Entries are normalized to
// lowercase. The matching logic supports two formats:
//
//   - Exact: Provide the full MIME type (e.g., "application/pdf").
//   - Prefix: End the MIME type with a wildcard "*" (e.g., "image/*")
//     to exclude all subtypes for that primary type.
func WithExcludeMimeTypes(types ...string) Option {
	return func(c *config) {
		lowered := make([]string, 0, len(types))
		for _, t := range types {
			lowered = append(lowered, strings.ToLower(t))
		}
		c.exclude = append(c.exclude, lowered...)
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package middleware provides a standard approach for chaining and composing
// HTTP transport middleware. This package focuses on low-level HTTP operations
// (like logging, CORS, and compression) that operate directly on
// [http.Handler].
//
// For higher-level business logic that requires structured error handling and
// API contexts, see the Middleware definitions in the router package. The
// router provides an adapter to seamlessly integrate the transport pipes
// defined here within its richer handler ecosystem.
//
// # Usage
//
// The core type is [Pipe], an adapter that wraps an [http.Handler] to add
// functionality. The [Chain] function composes these pipes into a single
// handler.
//
// Example:
//
// handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// w.Write([]byte("OK"))
// })
//
// // Chain middleware around the final handler.
// // Order matters: Recover must be first (outermost).
// logger := slog.Default()
// chainedHandler := middleware.Chain(handler,
// middleware.Recover(logger),
// middleware.RequestID(),
// middleware.Log(logger),
// )
//
// http.ListenAndServe(":8080", chainedHandler)
package middleware
import (
"context"
"crypto/rand"
"encoding/hex"
"log/slog"
"net/http"
"runtime/debug"
"strconv"
"strings"
"time"
)
// Pipe is a middleware function.
//
// Pipe is an adapter that takes an [http.Handler] and returns a new
// [http.Handler], allowing functionality to be composed in layers.
type Pipe func(http.Handler) http.Handler

// Chain combines a handler with multiple middleware Pipes.
//
// The pipes are applied in reverse order, meaning the first pipe in the list
// is the outermost and executes first. For example, Chain(h, A, B, C) results
// in a handler equivalent to A(B(C(h))). Any nil pipes in the list are safely
// ignored.
func Chain(h http.Handler, pipes ...Pipe) http.Handler {
	wrapped := h
	// Wrap from the innermost (last) pipe outwards.
	for i := len(pipes); i > 0; i-- {
		pipe := pipes[i-1]
		if pipe == nil {
			continue
		}
		wrapped = pipe(wrapped)
	}
	return wrapped
}
// Recover produces a middleware [Pipe] that catches panics in downstream
// handlers.
//
// It uses the provided logger to report the exception with a stack trace and
// returns an empty response with status code 500 to the client. The log entry
// also pinpoints the request method and URL that caused the panic. For
// maximum effectiveness, this should be the first (outermost) middleware in
// the chain.
func Recover(logger *slog.Logger) Pipe {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			defer func() {
				rec := recover()
				if rec == nil {
					return
				}
				logger.Error(
					"Panic caught by middleware",
					slog.String("method", r.Method),
					slog.String("url", r.URL.String()),
					slog.Any("panic", rec),
					slog.String("stack", string(debug.Stack())),
				)
				w.WriteHeader(http.StatusInternalServerError)
			}()
			next.ServeHTTP(w, r)
		})
	}
}
// contextKey is an unexported, zero-sized key type that prevents collisions
// with context values stored by other packages.
type contextKey struct{}
// requestIDKey is the key under which the request ID is stored in the request
// context.
var requestIDKey contextKey
// RequestID returns a middleware [Pipe] that injects a unique ID into each
// request.
//
// It adds the ID to the response via the "X-Request-ID" header and to the
// request's context for downstream use. Downstream handlers and other
// middleware can retrieve the ID using [GetRequestID]. If a unique ID cannot
// be generated from the random source, this middleware does nothing and
// passes the request to the next handler.
func RequestID() Pipe {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			raw := make([]byte, 16)
			if _, err := rand.Read(raw); err != nil {
				// Without entropy we cannot mint an ID; serve the request
				// untagged rather than failing it.
				next.ServeHTTP(w, r)
				return
			}
			id := hex.EncodeToString(raw)
			w.Header().Set("X-Request-ID", id)
			ctx := SetRequestID(r.Context(), id)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// GetRequestID retrieves the request ID from a given context.
//
// It returns an empty string if the ID is not found.
func GetRequestID(ctx context.Context) string {
	if id, ok := ctx.Value(requestIDKey).(string); ok {
		return id
	}
	return ""
}

// SetRequestID sets the request ID in the provided context.
//
// It returns a new context that carries the ID.
func SetRequestID(ctx context.Context, id string) context.Context {
	return context.WithValue(ctx, requestIDKey, id)
}
// interceptor is used to wrap the original [http.ResponseWriter] to capture
// the status code.
type interceptor struct {
// ResponseWriter is the original writer.
http.ResponseWriter
// statusCode is the captured HTTP response code.
statusCode int
}
// WriteHeader captures the status code before calling the original WriteHeader.
func (i *interceptor) WriteHeader(code int) {
i.statusCode = code
i.ResponseWriter.WriteHeader(code)
}
// Log returns a middleware [Pipe] that logs a summary of each HTTP request.
//
// It captures the final HTTP status code by wrapping the [http.ResponseWriter].
// After the request has been handled, a debug-level entry is emitted containing
// the method, URL, status code, duration, and other common attributes. Place
// this middleware after [RequestID] in the chain if the log entry should carry
// a request ID.
func Log(logger *slog.Logger) Pipe {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			began := time.Now()
			// Default to 200, since handlers that never call WriteHeader
			// implicitly respond with http.StatusOK.
			wrapped := &interceptor{w, http.StatusOK}
			next.ServeHTTP(wrapped, r)
			logger.Debug(
				"HTTP request handled",
				slog.String("id", GetRequestID(r.Context())),
				slog.String("method", r.Method),
				slog.String("url", r.URL.String()),
				slog.String("remote", r.RemoteAddr),
				slog.String("user_agent", r.UserAgent()),
				slog.Int("status", wrapped.statusCode),
				slog.Duration("duration", time.Since(began)),
			)
		})
	}
}
// Volatile returns a middleware [Pipe] that prevents caching of the response.
//
// It sets standard HTTP headers (Cache-Control, Pragma, Expires) to ensure
// clients and proxies always fetch a fresh copy of the resource. The header
// value is assembled once, outside the per-request closure.
func Volatile() Pipe {
	directives := []string{
		"no-store",
		"no-cache",
		"must-revalidate",
		"proxy-revalidate",
	}
	control := strings.Join(directives, ", ")
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			headers := w.Header()
			headers.Set("Cache-Control", control)
			headers.Set("Pragma", "no-cache")
			headers.Set("Expires", "0")
			next.ServeHTTP(w, r)
		})
	}
}
// SecurityConfig defines the headers applied by the [Secure] middleware.
//
// The zero value emits no optional headers; see [DefaultSecurityConfig] for a
// sensible baseline.
type SecurityConfig struct {
	// STSMaxAge is the maximum age for HSTS in seconds. If 0, the header is
	// not set.
	STSMaxAge int64
	// STSIncludeSubdomains adds the "includeSubDomains" directive to HSTS.
	// It has no effect unless STSMaxAge is positive.
	STSIncludeSubdomains bool
	// FrameOptions sets the X-Frame-Options header (e.g., "DENY",
	// "SAMEORIGIN"). If empty, the header is not set.
	FrameOptions string
	// NoSniff sets X-Content-Type-Options to "nosniff" if true. This helps
	// prevent MIME type sniffing by browsers.
	NoSniff bool
	// CSP sets the Content-Security-Policy header. If empty, it is not set.
	CSP string
	// ReferrerPolicy sets the Referrer-Policy header. If empty, it is not set.
	ReferrerPolicy string
	// PermissionsPolicy sets the Permissions-Policy header. Example:
	// "geolocation=(), microphone=()". If empty, the header is not set.
	PermissionsPolicy string
	// CrossOriginOpenerPolicy sets the Cross-Origin-Opener-Policy header.
	// Recommended: "same-origin". If empty, the header is not set.
	CrossOriginOpenerPolicy string
}

// DefaultSecurityConfig provides a baseline configuration.
//
// It enables HSTS for 1 year (including subdomains), disables MIME sniffing,
// protects against clickjacking by denying framing, locks down powerful
// browser features, and isolates the browsing context.
var DefaultSecurityConfig = SecurityConfig{
	STSMaxAge:               31536000,
	STSIncludeSubdomains:    true,
	FrameOptions:            "DENY",
	NoSniff:                 true,
	PermissionsPolicy:       "geolocation=(),microphone=(),camera=(),payment=()",
	CrossOriginOpenerPolicy: "same-origin",
}
// Secure returns a middleware [Pipe] that sets security-related HTTP headers.
//
// Headers are set based on the provided configuration; optional headers are
// only emitted when their corresponding field is non-empty.
func Secure(cfg SecurityConfig) Pipe {
	// Pre-calculate HSTS header to avoid string allocation on every request.
	hsts := ""
	if cfg.STSMaxAge > 0 {
		hsts = "max-age=" + strconv.FormatInt(cfg.STSMaxAge, 10)
		if cfg.STSIncludeSubdomains {
			hsts += "; includeSubDomains"
		}
	}
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			headers := w.Header()
			// set writes a header only when its value is non-empty.
			set := func(name, value string) {
				if value != "" {
					headers.Set(name, value)
				}
			}
			set("Strict-Transport-Security", hsts)
			if cfg.NoSniff {
				headers.Set("X-Content-Type-Options", "nosniff")
			}
			set("X-Frame-Options", cfg.FrameOptions)
			set("Content-Security-Policy", cfg.CSP)
			set("Referrer-Policy", cfg.ReferrerPolicy)
			set("Permissions-Policy", cfg.PermissionsPolicy)
			set("Cross-Origin-Opener-Policy", cfg.CrossOriginOpenerPolicy)
			// Always harden against legacy Flash/PDF cross-domain loading.
			headers.Set("X-Permitted-Cross-Domain-Policies", "none")
			next.ServeHTTP(w, r)
		})
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package mock provides an in-memory implementation of the [migrate.Driver]
// interface designed strictly for unit testing. It is safe for concurrent use
// and allows injecting errors for every operation to test the Migrator's error
// handling and rollback logic.
//
// # Usage
//
// Create a new mock driver and use its exported fields to assert that specific
// database operations (like locking or initialization) were performed by the
// migrator.
//
// Example:
//
// drv := mock.New()
// m := migrate.New(migrate.WithDriver(drv), migrate.WithSource(src))
// _ = m.Up(ctx)
//
// if !drv.IsInit {
// t.Error("expected driver to be initialized")
// }
package mock
import (
"cmp"
"context"
"errors"
"maps"
"slices"
"sync"
"github.com/deep-rent/nexus/internal/schema"
"github.com/deep-rent/nexus/migrate"
)
// Driver is an in-memory implementation of [migrate.Driver].
//
// It is safe for concurrent use and allows injecting errors for every operation
// to test the Migrator's error handling and rollback logic.
//
// NOTE(review): the exported flag fields are written under the internal mutex,
// but reads by test code are unsynchronized; inspect them only after the
// migrator has finished.
type Driver struct {
	// mu protects the internal state of the mock driver.
	mu sync.Mutex
	// records stores the simulated migration state, keyed by version.
	records map[uint64]migrate.Record
	// IsLocked indicates if the mock advisory lock is currently held.
	IsLocked bool
	// IsClosed indicates if the driver has been closed.
	IsClosed bool
	// IsInit indicates if the tracking table initialization was called.
	IsInit bool
	// ParserFunc allows injecting a custom statement parser.
	ParserFunc schema.Parser
	// InitErr is returned by the Init method if non-nil.
	InitErr error
	// LockErr is returned by the Lock method if non-nil.
	LockErr error
	// UnlockErr is returned by the Unlock method if non-nil.
	UnlockErr error
	// AppliedErr is returned by the Applied method if non-nil.
	AppliedErr error
	// ForceErr is returned by the Force method if non-nil.
	ForceErr error
	// ExecuteErr is returned by the Execute method if non-nil.
	ExecuteErr error
	// CloseErr is returned by the Close method if non-nil.
	CloseErr error
}
// New creates a new in-memory [Driver] with an empty state.
func New() *Driver {
	// The default parser treats the whole script as a single statement;
	// empty scripts yield no statements at all.
	parse := func(script []byte) []string {
		if len(script) == 0 {
			return nil
		}
		return []string{string(script)}
	}
	return &Driver{
		records:    make(map[uint64]migrate.Record),
		ParserFunc: parse,
	}
}
// Set writes a [migrate.Record] to the in-memory table.
//
// Any existing record for the same version is overwritten.
func (d *Driver) Set(rec migrate.Record) {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.records[rec.Version] = rec
}

// Get reads a [migrate.Record] from the in-memory table.
//
// The boolean result reports whether a record exists for the given version.
func (d *Driver) Get(version uint64) (migrate.Record, bool) {
	d.mu.Lock()
	defer d.mu.Unlock()
	rec, ok := d.records[version]
	return rec, ok
}
// State returns a copy of the in-memory table.
//
// The returned map is detached from the driver and may be mutated freely.
func (d *Driver) State() map[uint64]migrate.Record {
	d.mu.Lock()
	defer d.mu.Unlock()
	snapshot := make(map[uint64]migrate.Record, len(d.records))
	maps.Copy(snapshot, d.records)
	return snapshot
}

// Parser returns the injected [Driver.ParserFunc].
func (d *Driver) Parser() schema.Parser {
	return d.ParserFunc
}
// Init simulates creating the tracking table.
//
// It records the call in [Driver.IsInit] unless [Driver.InitErr] is set.
func (d *Driver) Init(ctx context.Context) error {
	d.mu.Lock()
	defer d.mu.Unlock()
	if err := d.InitErr; err != nil {
		return err
	}
	d.IsInit = true
	return nil
}

// Lock simulates acquiring an exclusive distributed lock.
//
// Acquiring an already-held lock is an error.
func (d *Driver) Lock(ctx context.Context) error {
	d.mu.Lock()
	defer d.mu.Unlock()
	switch {
	case d.LockErr != nil:
		return d.LockErr
	case d.IsLocked:
		return errors.New("mock: already locked")
	}
	d.IsLocked = true
	return nil
}

// Unlock simulates releasing the exclusive lock.
//
// Releasing a lock that is not held is an error.
func (d *Driver) Unlock(ctx context.Context) error {
	d.mu.Lock()
	defer d.mu.Unlock()
	switch {
	case d.UnlockErr != nil:
		return d.UnlockErr
	case !d.IsLocked:
		return errors.New("mock: not locked")
	}
	d.IsLocked = false
	return nil
}
// Applied returns all successfully applied migration records, sorted by
// version.
func (d *Driver) Applied(ctx context.Context) ([]migrate.Record, error) {
	d.mu.Lock()
	defer d.mu.Unlock()
	if d.AppliedErr != nil {
		return nil, d.AppliedErr
	}
	records := make([]migrate.Record, 0, len(d.records))
	for _, rec := range d.records {
		records = append(records, rec)
	}
	// Map iteration order is random; sort for a deterministic result.
	slices.SortFunc(records, func(x, y migrate.Record) int {
		return cmp.Compare(x.Version, y.Version)
	})
	return records, nil
}
// Force manually sets the database version.
//
// It clears the dirty flag for that version and removes any records greater
// than it.
func (d *Driver) Force(ctx context.Context, version uint64) error {
	d.mu.Lock()
	defer d.mu.Unlock()
	if d.ForceErr != nil {
		return d.ForceErr
	}
	if rec, ok := d.records[version]; ok {
		rec.Dirty = false
		d.records[version] = rec
	}
	// Drop every record beyond the forced version.
	maps.DeleteFunc(d.records, func(v uint64, _ migrate.Record) bool {
		return v > version
	})
	return nil
}
// Execute simulates running a migration script.
//
// If [Driver.ExecuteErr] is set, it correctly simulates a failure by leaving
// the target version in a dirty state.
func (d *Driver) Execute(
	ctx context.Context,
	script migrate.ParsedScript,
) error {
	d.mu.Lock()
	defer d.mu.Unlock()
	failed := d.ExecuteErr != nil
	switch script.Direction {
	case migrate.Up:
		// An up migration always writes a record; it remains dirty when the
		// simulated execution failed.
		d.records[script.Version] = migrate.Record{
			Version:  script.Version,
			Checksum: script.Checksum,
			Dirty:    failed,
		}
	case migrate.Down:
		if failed {
			// A failed down migration leaves any existing record dirty.
			if rec, ok := d.records[script.Version]; ok {
				rec.Dirty = true
				d.records[script.Version] = rec
			}
		} else {
			// A successful down migration removes the record entirely.
			delete(d.records, script.Version)
		}
	}
	if failed {
		return d.ExecuteErr
	}
	return nil
}
// Close simulates cleaning up driver resources.
//
// It records the call in [Driver.IsClosed] unless [Driver.CloseErr] is set.
func (d *Driver) Close() error {
	d.mu.Lock()
	defer d.mu.Unlock()
	if d.CloseErr != nil {
		return d.CloseErr
	}
	d.IsClosed = true
	return nil
}

// Ensure Driver satisfies the migrate.Driver interface at compile time.
var _ migrate.Driver = (*Driver)(nil)
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Explicitly allow SQL string concatenation:
// #nosec G202
// Package postgres provides a PostgreSQL-specific driver for the migrate
// package.
//
// It executes database migrations, manages the state of applied
// migrations, and ensures concurrent safety using PostgreSQL advisory locks.
// The driver supports configurable schema and table names for state tracking,
// structured logging, and transactional execution of migration scripts.
//
// # Usage
//
// Initialize the driver with an existing [*sql.DB] connection and optional
// configuration functions.
//
// Example:
//
// db, _ := sql.Open("postgres", "postgres://user:pass@localhost:5432/db")
// drv := postgres.New(db,
// postgres.WithSchema("public"),
// postgres.WithTable("migrations"),
// )
package postgres
import (
"context"
"crypto/rand"
"database/sql"
"encoding/binary"
"errors"
"fmt"
"log/slog"
"strings"
"time"
"github.com/deep-rent/nexus/internal/quote"
"github.com/deep-rent/nexus/internal/schema"
"github.com/deep-rent/nexus/migrate"
)
const (
	// DefaultTable is the default name for the migration tracking table.
	DefaultTable = "migrations"
	// DefaultSchema is the default PostgreSQL schema where the tracking table
	// resides.
	DefaultSchema = "public"
)

// config holds the internal configuration options for the PostgreSQL driver.
//
// Instances are populated by applying [Option] functions in [New].
type config struct {
	// table is the name of the migration tracking table.
	table string
	// schema is the PostgreSQL schema containing the tracking table.
	schema string
	// lockID is an optional fixed identifier for advisory locks; nil means a
	// random identifier is generated.
	lockID *int64
	// lockTimeout is the maximum wait time for acquiring the advisory lock.
	// Zero means wait indefinitely.
	lockTimeout time.Duration
	// stmtTimeout is the maximum execution time for a single SQL statement.
	// Zero means no per-statement timeout.
	stmtTimeout time.Duration
	// logger is the structured logger for driver activity.
	logger *slog.Logger
}

// Option configures a PostgreSQL [Driver] instance.
type Option func(*config)
// WithTable sets a custom name for the migration tracking table.
//
// Empty string values are ignored.
func WithTable(name string) Option {
	return func(c *config) {
		if name == "" {
			return
		}
		c.table = name
	}
}

// WithSchema sets a custom database schema for the tracking table.
//
// Empty string values are ignored.
func WithSchema(name string) Option {
	return func(c *config) {
		if name == "" {
			return
		}
		c.schema = name
	}
}

// WithLockID sets a static identifier for the PostgreSQL advisory lock.
//
// If not provided, a random 64-bit identifier is securely generated. This is
// primarily intended for testing lock contention.
func WithLockID(id int64) Option {
	return func(c *config) { c.lockID = &id }
}

// WithLockTimeout sets the maximum duration to wait for the advisory lock.
//
// If 0, it waits indefinitely (the default behavior).
func WithLockTimeout(timeout time.Duration) Option {
	return func(c *config) { c.lockTimeout = timeout }
}

// WithStatementTimeout sets a maximum duration for individual SQL statements.
//
// If 0, no timeout is applied (the default behavior).
func WithStatementTimeout(timeout time.Duration) Option {
	return func(c *config) { c.stmtTimeout = timeout }
}

// WithLogger injects a structured logger to record driver operations.
//
// Nil values are ignored, falling back to [slog.Default].
func WithLogger(logger *slog.Logger) Option {
	return func(c *config) {
		if logger == nil {
			return
		}
		c.logger = logger
	}
}
// Driver implements the [migrate.Driver] interface for PostgreSQL.
//
// It manages the database connection, distributed locks, and the execution of
// migration statements. A Driver is not safe for concurrent use; the migrator
// serializes access through Lock/Unlock.
type Driver struct {
	// db is the underlying database connection pool.
	db *sql.DB
	// lock is a dedicated connection held while the advisory lock is active.
	// It is nil whenever the lock is not held.
	lock *sql.Conn
	// table is the unquoted name of the tracking table.
	table string
	// schema is the unquoted database schema containing the table.
	schema string
	// ident is the precomputed, safely quoted schema and table identifier.
	ident string
	// lockID is the identifier used for pg_advisory_lock.
	lockID int64
	// lockTimeout is the maximum wait time for lock acquisition. Zero means
	// wait indefinitely.
	lockTimeout time.Duration
	// stmtTimeout is the maximum duration for a single statement. Zero means
	// no per-statement timeout.
	stmtTimeout time.Duration
	// logger records driver operations.
	logger *slog.Logger
}
// New creates a new PostgreSQL migration driver.
//
// It uses the provided database connection and options. It generates a unique
// cryptographic identifier for advisory locks to prevent concurrent migration
// conflicts. It panics if the lock identifier generation fails.
func New(db *sql.DB, opts ...Option) *Driver {
	cfg := &config{
		table:  DefaultTable,
		schema: DefaultSchema,
		logger: slog.Default(),
	}
	for _, opt := range opts {
		opt(cfg)
	}
	var lockID int64
	if cfg.lockID != nil {
		lockID = *cfg.lockID
	} else {
		// Generate a random identifier for the table lock.
		var b [8]byte
		if _, err := rand.Read(b[:]); err != nil {
			panic(fmt.Sprintf("postgres: failed to generate random lock ID: %v", err))
		}
		// Clear the sign bit so the identifier is a non-negative int64.
		lockID = int64(binary.BigEndian.Uint64(b[:]) & 0x7FFFFFFFFFFFFFFF)
	}
	return &Driver{
		db:          db,
		table:       cfg.table,
		schema:      cfg.schema,
		ident:       ident(cfg.schema, cfg.table),
		lockID:      lockID,
		lockTimeout: cfg.lockTimeout,
		stmtTimeout: cfg.stmtTimeout,
		logger:      cfg.logger,
	}
}
// Parser returns the PostgreSQL-specific statement parser.
//
// This parser safely splits scripts while ignoring semicolons inside string
// literals, comments, and dollar-quoted blocks.
func (d *Driver) Parser() schema.Parser {
	return schema.Postgres
}
// Lock acquires an exclusive distributed lock using pg_advisory_lock.
//
// This prevents multiple migrator instances from running concurrently on the
// same database. It holds a dedicated connection for the duration of the lock
// and respects the configured lock timeout. The lock is not reentrant: calling
// Lock again before Unlock returns an error.
func (d *Driver) Lock(ctx context.Context) error {
	if d.lock != nil {
		return errors.New("already locked")
	}
	d.logger.Debug("Acquiring advisory lock", slog.Int64("id", d.lockID))
	// A dedicated connection is required because advisory locks are bound to
	// the session that acquired them.
	conn, err := d.db.Conn(ctx)
	if err != nil {
		return fmt.Errorf("failed to acquire database connection: %w", err)
	}
	// Apply lock timeout if configured. The shortened context governs only
	// the acquisition; it does not limit how long the lock is held.
	if d.lockTimeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, d.lockTimeout)
		defer cancel()
	}
	if _, err := conn.ExecContext(
		ctx,
		"SELECT pg_advisory_lock($1)",
		d.lockID,
	); err != nil {
		// Return the connection to the pool; keeping it would leak it.
		if e := conn.Close(); e != nil {
			d.logger.Error(
				"Failed to close connection after lock failure",
				slog.Any("error", e),
			)
		}
		return fmt.Errorf("failed to acquire advisory lock: %w", err)
	}
	d.lock = conn
	d.logger.Info("Advisory lock acquired")
	return nil
}
// Unlock releases the advisory lock and returns the connection to the pool.
//
// It releases the lock acquired via pg_advisory_unlock. The dedicated
// connection is closed regardless of whether the unlock statement succeeds;
// the unlock error takes precedence over the close error in the return value.
func (d *Driver) Unlock(ctx context.Context) error {
	if d.lock == nil {
		return errors.New("not locked")
	}
	d.logger.Debug("Releasing advisory lock", slog.Int64("id", d.lockID))
	_, err := d.lock.ExecContext(
		ctx,
		"SELECT pg_advisory_unlock($1)",
		d.lockID,
	)
	// Close the connection even if the unlock statement failed; advisory
	// locks are released by the server when their session ends.
	e := d.lock.Close()
	d.lock = nil
	if err != nil {
		return fmt.Errorf("failed to release advisory lock: %w", err)
	}
	d.logger.Info("Advisory lock released")
	return e
}
// Init ensures that the tracking table exists in the target schema.
//
// It creates the table with columns for version, checksum, dirty state, and
// application timestamp if it is not already present. The statement is
// idempotent, so repeated calls are safe.
func (d *Driver) Init(ctx context.Context) error {
	d.logger.Debug(
		"Initializing migration table if missing",
		slog.String("name", d.table),
		slog.String("schema", d.schema),
	)
	// d.ident is pre-quoted via ident()/escape(), so interpolating it here
	// cannot inject SQL.
	query := fmt.Sprintf(`
	CREATE TABLE IF NOT EXISTS %s (
		version BIGINT PRIMARY KEY,
		checksum BYTEA NOT NULL DEFAULT '\x',
		dirty BOOLEAN NOT NULL DEFAULT false,
		applied_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
	);`, d.ident)
	_, err := d.db.ExecContext(ctx, query)
	return err
}
// Applied retrieves all successfully applied migration records.
//
// The records are ordered by their version in ascending order. Each record
// includes its dirty flag, so callers can detect interrupted migrations.
func (d *Driver) Applied(ctx context.Context) ([]migrate.Record, error) {
	d.logger.Debug("Fetching applied migrations")
	query := "SELECT version, checksum, dirty FROM " +
		d.ident + " ORDER BY version ASC"
	rows, err := d.db.QueryContext(ctx, query)
	if err != nil {
		return nil, err
	}
	defer func() {
		if e := rows.Close(); e != nil {
			d.logger.Error("Failed to close rows", slog.Any("error", e))
		}
	}()
	var records []migrate.Record
	for rows.Next() {
		var rec migrate.Record
		var checksum []byte
		if err := rows.Scan(&rec.Version, &checksum, &rec.Dirty); err != nil {
			return nil, err
		}
		// Copy into the fixed-size array; a shorter stored checksum leaves
		// the trailing bytes zeroed.
		copy(rec.Checksum[:], checksum)
		records = append(records, rec)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return records, nil
}
// withTx is an internal helper that manages serializable transactions.
//
// It begins a transaction, invokes fn, and commits on success. The deferred
// rollback becomes a no-op after a successful commit, which is why
// [sql.ErrTxDone] is filtered out of the rollback error log.
func (d *Driver) withTx(ctx context.Context, fn func(tx *sql.Tx) error) error {
	tx, err := d.db.BeginTx(ctx, &sql.TxOptions{
		Isolation: sql.LevelSerializable,
	})
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	defer func() {
		if e := tx.Rollback(); e != nil && !errors.Is(e, sql.ErrTxDone) {
			d.logger.Error("Failed to rollback transaction", slog.Any("error", e))
		}
	}()
	if err := fn(tx); err != nil {
		return err
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	return nil
}
// Force manually sets the database to the specified version.
//
// It clears the dirty flag for the target version and deletes any migration
// records with a version strictly greater than the target. This is typically
// used to recover from a dirty database state after human intervention. Both
// statements run in a single serializable transaction.
func (d *Driver) Force(ctx context.Context, version uint64) error {
	d.logger.Info("Forcing database version", slog.Uint64("version", version))
	return d.withTx(ctx, func(tx *sql.Tx) error {
		queryUpdate := "UPDATE " + d.ident + " SET dirty = false WHERE version = $1"
		if _, err := tx.ExecContext(ctx, queryUpdate, version); err != nil {
			return fmt.Errorf("failed to clear dirty flag: %w", err)
		}
		queryDelete := "DELETE FROM " + d.ident + " WHERE version > $1"
		if _, err := tx.ExecContext(ctx, queryDelete, version); err != nil {
			return fmt.Errorf("failed to delete newer versions: %w", err)
		}
		return nil
	})
}
// Execute runs the provided migration script against the database.
//
// It marks the version as dirty, executes the statements, and clears the dirty
// state upon success. If execution fails, the database remains marked as dirty
// to prevent further automated migrations.
func (d *Driver) Execute(
	ctx context.Context,
	script migrate.ParsedScript,
) error {
	d.logger.Info(
		"Executing migration",
		slog.Uint64("version", script.Version),
		slog.String("direction", script.Direction.String()),
	)
	if err := d.setDirty(
		ctx,
		script.Version,
		script.Direction,
		script.Checksum,
	); err != nil {
		return fmt.Errorf("failed to mark migration as dirty: %w", err)
	}
	// Run the statements, optionally inside a serializable transaction.
	var err error
	if script.Tx {
		d.logger.Debug("Running migration in transaction")
		err = d.withTx(ctx, func(tx *sql.Tx) error {
			return d.execAll(ctx, tx, script.Statements)
		})
	} else {
		d.logger.Debug("Running migration without transaction")
		err = d.execAll(ctx, d.db, script.Statements)
	}
	if err != nil {
		// Leave the version dirty so automated runs stop here.
		return err
	}
	if err := d.setClean(ctx, script.Version, script.Direction); err != nil {
		return fmt.Errorf("failed to clear dirty state: %w", err)
	}
	return nil
}
// runner is the subset of database operations needed to execute statements.
// It is satisfied by both [*sql.DB] and [*sql.Tx], which lets execAll run the
// same code inside or outside a transaction.
type runner interface {
	// ExecContext executes a query without returning any rows.
	ExecContext(
		ctx context.Context,
		query string,
		args ...any,
	) (sql.Result, error)
}
// execAll iterates through SQL statements and runs them sequentially,
// stopping at the first failure.
func (d *Driver) execAll(
	ctx context.Context,
	run runner,
	statements []string,
) error {
	for i, stmt := range statements {
		n := i + 1
		d.logger.Debug("Executing statement", slog.Int("index", n))
		if err := d.execOne(ctx, run, stmt); err != nil {
			return fmt.Errorf("statement %d failed: %w", n, err)
		}
	}
	return nil
}

// execOne isolates the execution of a single statement, applying the
// configured statement timeout if any.
func (d *Driver) execOne(ctx context.Context, run runner, stmt string) error {
	if d.stmtTimeout > 0 {
		timed, cancel := context.WithTimeout(ctx, d.stmtTimeout)
		defer cancel()
		ctx = timed
	}
	_, err := run.ExecContext(ctx, stmt)
	return err
}
// setDirty records a migration attempt in the tracking table.
//
// For upward migrations, it inserts or updates a row (upsert), so re-running a
// previously failed version is possible. For downward migrations, it updates
// the existing row in place.
func (d *Driver) setDirty(
	ctx context.Context,
	version uint64,
	direction migrate.Direction,
	checksum [32]byte,
) error {
	d.logger.Debug("Marking migration as dirty", slog.Uint64("version", version))
	switch direction {
	case migrate.Up:
		// Upsert: the row may already exist after a prior failed attempt.
		query := "INSERT INTO " + d.ident +
			" (version, checksum, dirty) VALUES ($1, $2, true) " +
			"ON CONFLICT (version) DO UPDATE SET dirty = true"
		if _, err := d.db.ExecContext(
			ctx,
			query,
			version,
			checksum[:],
		); err != nil {
			return fmt.Errorf("failed to mark migration as dirty: %w", err)
		}
	case migrate.Down:
		query := "UPDATE " + d.ident + " SET dirty = true WHERE version = $1"
		if _, err := d.db.ExecContext(
			ctx,
			query,
			version,
		); err != nil {
			return fmt.Errorf("failed to mark migration as dirty: %w", err)
		}
	}
	return nil
}

// setClean finalizes a successful migration by removing the dirty state.
//
// For upward migrations, it sets dirty to false, leaving the record as proof
// of application. For downward migrations, it removes the version record
// entirely.
func (d *Driver) setClean(
	ctx context.Context,
	version uint64,
	direction migrate.Direction,
) error {
	d.logger.Debug("Clearing dirty state", slog.Uint64("version", version))
	switch direction {
	case migrate.Up:
		query := "UPDATE " + d.ident + " SET dirty = false WHERE version = $1"
		if _, err := d.db.ExecContext(
			ctx,
			query,
			version,
		); err != nil {
			return fmt.Errorf("failed to clear dirty state: %w", err)
		}
	case migrate.Down:
		query := "DELETE FROM " + d.ident + " WHERE version = $1"
		if _, err := d.db.ExecContext(
			ctx,
			query,
			version,
		); err != nil {
			return fmt.Errorf("failed to remove migration record: %w", err)
		}
	}
	return nil
}
// Close gracefully closes the underlying database connection.
func (d *Driver) Close() error {
	d.logger.Debug("Closing database driver")
	return d.db.Close()
}

// ident assembles a fully qualified, safely quoted PostgreSQL identifier,
// e.g. "public"."migrations".
func ident(schema, table string) string {
	return escape(schema) + "." + escape(table)
}

// escape safely wraps PostgreSQL identifiers in double quotes, doubling any
// embedded quote characters per SQL identifier quoting rules.
func escape(name string) string {
	return quote.Double(strings.ReplaceAll(name, `"`, `""`))
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package migrate provides the core orchestration logic for database
// migrations.
//
// It manages the loading, sorting, verification, and execution of
// migration files against a database driver. It ensures that migrations are
// applied in a consistent, reproducible order and tracks the state of the
// database to prevent duplicate or conflicting changes.
//
// # Usage
//
// To perform migrations, initialize a source and a driver, then create a new
// [Migrator].
//
// Example:
//
// src := file.New(os.DirFS("./migrations"))
// drv := postgres.New(db)
//
// m := migrate.New(
// migrate.WithSource(src),
// migrate.WithDriver(drv),
// )
//
// if err := m.Up(context.Background()); err != nil {
// log.Fatal("Migration failed:", err)
// }
package migrate
import (
"cmp"
"context"
"crypto/sha256"
"fmt"
"log/slog"
"slices"
"github.com/deep-rent/nexus/internal/schema"
)
// Direction signals whether a migration is being applied or reverted.
type Direction int

const (
	// Up indicates a migration that applies changes to the database schema.
	Up Direction = iota
	// Down indicates a migration that undoes changes made by an Up migration.
	Down
)

// String implements [fmt.Stringer].
//
// Values other than Up and Down are rendered as "unknown".
func (d Direction) String() string {
	if d == Up {
		return "up"
	}
	if d == Down {
		return "down"
	}
	return "unknown"
}
// Record represents a successfully applied migration stored in the database.
type Record struct {
	// Version is the unique sequence number of the applied migration.
	Version uint64
	// Checksum is the SHA-256 hash of the migration's content, used to detect
	// tampering or accidental modification of historical migration files.
	Checksum [32]byte
	// Dirty indicates if the migration failed mid-execution, leaving the
	// database in a potentially inconsistent state that requires manual
	// intervention (see Driver.Force).
	Dirty bool
}
// Driver is the interface that database-specific backends must implement.
//
// It abstracts all direct database interactions, locking mechanisms, and state
// tracking away from the core migration logic. Implementations are expected to
// be used by a single migrator at a time, guarded by Lock/Unlock.
type Driver interface {
	// Parser returns a database-specific statement parser to safely split raw
	// SQL scripts into individual executable statements.
	Parser() schema.Parser
	// Init ensures the migration tracking table exists in the database. It
	// must be idempotent.
	Init(ctx context.Context) error
	// Lock acquires an exclusive, distributed lock to prevent concurrent
	// migrator instances from causing race conditions.
	Lock(ctx context.Context) error
	// Unlock releases the exclusive distributed lock.
	Unlock(ctx context.Context) error
	// Applied returns all successfully applied migration records from the
	// tracking table, ordered by version in ascending order.
	Applied(ctx context.Context) ([]Record, error)
	// Force sets the database to the specified version and clears the dirty
	// state, effectively ignoring any migrations past that point.
	Force(ctx context.Context, version uint64) error
	// Execute runs the parsed migration statements and records the state update
	// within the tracking table.
	Execute(ctx context.Context, script ParsedScript) error
	// Close cleans up driver resources and closes database connections.
	Close() error
}
// Source provides migrations from an external system (e.g., filesystem).
type Source interface {
	// List returns a list of all available migration files.
	// The Migrator will handle hashing the content and sorting the results,
	// so implementations need not return scripts in any particular order.
	List() ([]SourceScript, error)
}

// SourceScript represents an unhashed migration script retrieved from a
// [Source].
type SourceScript struct {
	// Version is the unique sequence number of the migration.
	Version uint64
	// Description is a human-readable summary of the migration's intent.
	Description string
	// Direction indicates whether this script applies (Up) or reverts (Down)
	// changes.
	Direction Direction
	// Path is the location identifier of the script within the source.
	Path string
	// Content contains the raw, unparsed SQL script.
	Content []byte
	// Tx specifies whether the script should be executed within a transaction.
	Tx bool
}
// ParsedScript holds the parameters required to execute a migration.
//
// It is the unit of work handed to [Driver.Execute] after the raw script has
// been split into statements and hashed.
type ParsedScript struct {
	// Version is the unique sequence number of the migration.
	Version uint64
	// Direction indicates whether this script applies (Up) or reverts (Down)
	// changes.
	Direction Direction
	// Checksum is the cryptographic hash of the original script content
	// (SHA-256, computed by the Migrator).
	Checksum [32]byte
	// Statements contains the individually parsed SQL statements ready for
	// execution.
	Statements []string
	// Tx specifies whether the statements should be executed within a
	// transaction.
	Tx bool
}
// Migration represents a fully parsed and hashed migration file.
//
// It is the internal counterpart of [SourceScript] after checksumming.
type Migration struct {
	// Version is the unique sequence number of the migration.
	Version uint64
	// Description is a human-readable description of the migration.
	Description string
	// Direction indicates whether this script applies (Up) or reverts (Down)
	// changes.
	Direction Direction
	// Path is the location identifier of the script within the source from
	// which the migration stems.
	Path string
	// Checksum is the SHA-256 hash of the content.
	Checksum [32]byte
	// Content is the raw SQL content of the migration, which can contain
	// multiple statements.
	Content []byte
	// Tx indicates whether to run all statements together in a transaction.
	Tx bool
}
// Compare returns an integer comparing two migrations to establish a strict
// ordering.
//
// The result will be 0 if m == other, -1 if m < other, and +1 if m > other.
// Migrations are ordered primarily by version in ascending order. If two
// migrations share the same version, they are secondarily ordered by direction
// (so "up" comes before "down") to guarantee deterministic sorting.
func (m Migration) Compare(other Migration) int {
	// Version is the primary sort key; direction only breaks ties.
	if m.Version != other.Version {
		return cmp.Compare(m.Version, other.Version)
	}
	return cmp.Compare(m.Direction, other.Direction)
}
// Migrator orchestrates the execution of database migrations.
//
// Construct instances via [New]; the zero value is not usable because both
// a [Source] and a [Driver] are required.
type Migrator struct {
	// source is the provider for migration files.
	source Source
	// driver is the backend database implementation.
	driver Driver
	// dryRun determines if execution should be skipped.
	dryRun bool
	// logger is the structured logger for migration events.
	logger *slog.Logger
}

// Option configures a [Migrator] instance.
type Option func(*Migrator)
// WithSource sets the migration source.
//
// This option is mandatory; [New] panics when it is omitted.
func WithSource(source Source) Option {
	return func(mig *Migrator) { mig.source = source }
}
// WithDriver sets the database driver.
//
// This option is mandatory; [New] panics when it is omitted.
func WithDriver(driver Driver) Option {
	return func(mig *Migrator) { mig.driver = driver }
}
// WithDryRun enables a mode where the [Migrator] computes checksums and logs.
//
// It logs the parsed statements without executing them against the database.
func WithDryRun(enabled bool) Option {
	return func(mig *Migrator) { mig.dryRun = enabled }
}
// WithLogger sets the logger for the migrator.
//
// A nil value will be ignored, leaving the default logger in place.
func WithLogger(logger *slog.Logger) Option {
	return func(mig *Migrator) {
		if logger == nil {
			return
		}
		mig.logger = logger
	}
}
// New creates a new [Migrator] instance.
//
// It panics if the required dependencies ([Source] and [Driver]) are not
// provided via [WithSource] and [WithDriver].
func New(opts ...Option) *Migrator {
	mig := &Migrator{logger: slog.Default()}
	for _, apply := range opts {
		apply(mig)
	}
	// Missing mandatory dependencies are programmer errors, hence panic.
	switch {
	case mig.source == nil:
		panic("migrate: source is required")
	case mig.driver == nil:
		panic("migrate: driver is required")
	}
	return mig
}
// lock is a helper that acquires the driver lock and initializes tracking.
//
// It ensures the tracking table is initialized, executes the provided
// function, and guarantees the lock is released afterward.
func (m *Migrator) lock(
	ctx context.Context,
	fn func(context.Context) error,
) error {
	err := m.driver.Lock(ctx)
	if err != nil {
		return fmt.Errorf("failed to acquire lock: %w", err)
	}
	defer func() {
		// Unlock with a fresh context so a canceled ctx cannot prevent the
		// lock from being released.
		if uerr := m.driver.Unlock(context.Background()); uerr != nil {
			m.logger.Error("Failed to release lock", slog.Any("error", uerr))
		}
	}()
	if err = m.driver.Init(ctx); err != nil {
		return fmt.Errorf("failed to initialize driver: %w", err)
	}
	return fn(ctx)
}
// files fetches migrations from the source and prepares them.
//
// It calculates cryptographic checksums, maps them to domain objects, and
// strictly sorts them by version and direction.
func (m *Migrator) files() ([]Migration, error) {
	scripts, err := m.source.List()
	if err != nil {
		return nil, fmt.Errorf("failed to list source files: %w", err)
	}
	out := make([]Migration, len(scripts))
	for i, s := range scripts {
		out[i] = Migration{
			Version:     s.Version,
			Description: s.Description,
			Direction:   s.Direction,
			Path:        s.Path,
			Checksum:    sha256.Sum256(s.Content),
			Content:     s.Content,
			Tx:          s.Tx,
		}
	}
	slices.SortFunc(out, Migration.Compare)
	return out, nil
}
// filter is a helper to fetch either pending (up == false) or applied
// (up == true) migrations, considering only "Up" scripts.
func (m *Migrator) filter(ctx context.Context, up bool) ([]Migration, error) {
	records, files, err := m.load(ctx)
	if err != nil {
		return nil, err
	}
	seen := toLookup(records)
	matches := make([]Migration, 0, len(files))
	for _, mig := range files {
		if mig.Direction != Up {
			continue
		}
		if seen[mig.Version] == up {
			matches = append(matches, mig)
		}
	}
	return matches, nil
}
// Up applies all pending migrations in ascending order.
//
// It acquires an exclusive database lock before delegating to the internal
// up implementation.
func (m *Migrator) Up(ctx context.Context) error {
	return m.lock(ctx, func(c context.Context) error {
		return m.up(c)
	})
}
// up is the internal implementation of [Migrator.Up].
//
// It determines which migrations are pending and executes them sequentially.
// It assumes the caller has already acquired the necessary database locks.
func (m *Migrator) up(ctx context.Context) error {
	// Determine which migrations have not yet been applied.
	pending, err := m.Pending(ctx)
	if err != nil {
		return err
	}
	// Nothing to do: the database is already fully up to date.
	if len(pending) == 0 {
		m.logger.Info("Migrations are up to date")
		return nil
	}
	m.logger.Info(
		"Applying pending migrations",
		slog.Int("count", len(pending)),
	)
	// Execute each pending migration in strict ascending order, stopping at
	// the first failure so the dirty state can be reviewed.
	for _, mig := range pending {
		if err = m.run(ctx, mig); err != nil {
			return err
		}
	}
	m.logger.Info("All migrations applied successfully")
	return nil
}
// Down reverts the most recently applied migration.
//
// It acquires an exclusive database lock before delegating to the internal
// down implementation.
func (m *Migrator) Down(ctx context.Context) error {
	return m.lock(ctx, func(c context.Context) error {
		return m.down(c)
	})
}
// down is the internal implementation of [Migrator.Down].
//
// It identifies the most recently applied migration, locates its corresponding
// down script from the source, and executes it. It assumes the caller has
// already acquired the database lock.
func (m *Migrator) down(ctx context.Context) error {
	// Fetch the historical record of all applied migrations.
	applied, err := m.Applied(ctx)
	if err != nil {
		return err
	}
	// Pristine database: there is nothing to revert.
	if len(applied) == 0 {
		m.logger.Info("No applied migrations to revert")
		return nil
	}
	// The rollback target is the last migration that was applied.
	target := applied[len(applied)-1].Version
	// Load all parsed files from the source to find the matching down script.
	files, err := m.files()
	if err != nil {
		return err
	}
	// Scan the available files for the exact version and down direction.
	for _, f := range files {
		if f.Version != target || f.Direction != Down {
			continue
		}
		if err := m.run(ctx, f); err != nil {
			return err
		}
		m.logger.Info(
			"Migration reverted successfully",
			slog.Uint64("version", f.Version),
		)
		return nil
	}
	// The database claims this version is applied, but the source lacks the
	// file required to safely revert it.
	return fmt.Errorf(
		"down migration file not found for version %d",
		target,
	)
}
// Force manually sets the database version and clears the dirty flag.
//
// It should be used to resolve a dirty state after human intervention.
func (m *Migrator) Force(ctx context.Context, version uint64) error {
	return m.lock(ctx, func(c context.Context) error {
		if err := m.driver.Force(c, version); err != nil {
			return fmt.Errorf("failed to force version: %w", err)
		}
		m.logger.Info(
			"Successfully forced migration version",
			slog.Uint64("version", version),
		)
		return nil
	})
}
// MigrateTo applies or reverts migrations to reach the target version.
//
// Applied migrations above the target are reverted in descending order, then
// pending migrations at or below the target are applied in ascending order.
func (m *Migrator) MigrateTo(ctx context.Context, target uint64) error {
	return m.lock(ctx, func(c context.Context) error {
		records, files, err := m.load(c)
		if err != nil {
			return err
		}
		applied := toLookup(records)
		// Index down scripts by version, keeping the first occurrence to
		// mirror the sorted file order.
		downs := make(map[uint64]Migration)
		for _, f := range files {
			if f.Direction != Down {
				continue
			}
			if _, ok := downs[f.Version]; !ok {
				downs[f.Version] = f
			}
		}
		// Revert applied migrations strictly greater than the target version
		// in descending order.
		for i := len(records) - 1; i >= 0; i-- {
			v := records[i].Version
			if v <= target {
				continue
			}
			f, ok := downs[v]
			if !ok {
				return fmt.Errorf("down migration file not found for version %d", v)
			}
			if err := m.run(c, f); err != nil {
				return err
			}
			applied[v] = false
		}
		// Apply pending migrations less than or equal to the target version
		// in ascending order.
		for _, f := range files {
			if f.Direction == Up && f.Version <= target && !applied[f.Version] {
				if err := m.run(c, f); err != nil {
					return err
				}
				applied[f.Version] = true
			}
		}
		return nil
	})
}
// Pending returns a list of "Up" migrations that have not yet been applied.
func (m *Migrator) Pending(ctx context.Context) ([]Migration, error) {
	// false selects migrations without a matching applied record.
	return m.filter(ctx, false)
}

// Applied returns a list of "Up" migrations that have already been executed.
func (m *Migrator) Applied(ctx context.Context) ([]Migration, error) {
	// true selects migrations with a matching applied record.
	return m.filter(ctx, true)
}
// run reads the migration payload and executes it via the driver.
//
// If dry run is enabled, it logs the statements and skips execution.
func (m *Migrator) run(ctx context.Context, migration Migration) error {
	m.logger.Info(
		"Running migration",
		slog.Uint64("version", migration.Version),
		slog.String("description", migration.Description),
		slog.String("direction", migration.Direction.String()),
	)
	// Split the raw script into individually executable statements.
	parse := m.driver.Parser()
	stmts := parse(migration.Content)
	if m.dryRun {
		// Report what would have been executed, then bail out.
		m.logger.Info(
			"Dry run: skipping execution",
			slog.Int("statements", len(stmts)),
		)
		for i, stmt := range stmts {
			m.logger.Debug(
				"Dry run statement",
				slog.Int("index", i+1),
				slog.String("query", stmt),
			)
		}
		return nil
	}
	script := ParsedScript{
		Version:    migration.Version,
		Direction:  migration.Direction,
		Checksum:   migration.Checksum,
		Statements: stmts,
		Tx:         migration.Tx,
	}
	if err := m.driver.Execute(ctx, script); err != nil {
		wrapped := fmt.Errorf("migration %d failed: %w", migration.Version, err)
		m.logger.Error("Migration failed", slog.Any("error", wrapped))
		return wrapped
	}
	m.logger.Info(
		"Migration completed",
		slog.Uint64("version", migration.Version),
	)
	return nil
}
// load loads applied records and available files.
//
// It ensures that there are no missing files or checksum mismatches for
// previously applied migrations, and that the database is not dirty.
func (m *Migrator) load(ctx context.Context) ([]Record, []Migration, error) {
	files, err := m.files()
	if err != nil {
		return nil, nil, err
	}
	records, err := m.driver.Applied(ctx)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get applied versions: %w", err)
	}
	// Index the "Up" scripts by version for O(1) reconciliation below.
	ups := make(map[uint64]Migration, len(files))
	for _, f := range files {
		if f.Direction == Up {
			ups[f.Version] = f
		}
	}
	// Validate every applied record against the source files.
	for _, rec := range records {
		if rec.Dirty {
			return nil, nil, fmt.Errorf(
				"database is dirty at version %d; manual intervention required",
				rec.Version,
			)
		}
		f, ok := ups[rec.Version]
		if !ok {
			return nil, nil, fmt.Errorf(
				"applied migration %d is missing from source files",
				rec.Version,
			)
		}
		if rec.Checksum != f.Checksum {
			return nil, nil, fmt.Errorf(
				"checksum mismatch for migration %d: database has %x, file has %x",
				rec.Version,
				rec.Checksum,
				f.Checksum,
			)
		}
	}
	return records, files, nil
}
// toLookup converts a slice of migration records to a map for quick lookups.
// The resulting map reports true for every version present in records.
func toLookup(records []Record) map[uint64]bool {
	set := make(map[uint64]bool, len(records))
	for i := range records {
		set[records[i].Version] = true
	}
	return set
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package file provides a file system-based source for database migrations.
//
// It implements the [migrate.Source] interface by reading migration
// scripts from any implementation of the standard library's [fs.FS] interface
// (such as [os.DirFS] or [embed.FS]). It parses filenames to extract the
// migration version, description, execution direction (up/down), and whether
// the migration should run inside a transaction.
//
// # Filename Format
//
// The expected filename format is:
//
// <version>_<description>.<direction>[_notx]<extension>
//
// # Usage
//
// Initialize a source by providing a filesystem and optional configuration.
//
// Example:
//
// // Using an embedded filesystem
// //go:embed sql/*.sql
// var fs embed.FS
//
// src := file.New(fs, file.WithExtension(".sql"), file.WithLogger(logger))
// migrations, err := src.List()
package file
import (
"errors"
"fmt"
"io/fs"
"log/slog"
"strconv"
"strings"
"github.com/deep-rent/nexus/migrate"
)
const (
	// DefaultExtension is the default file extension used when searching for
	// migration scripts in the file system.
	DefaultExtension = ".sql"
)

// Errors explaining why the [Source.Parse] method has failed. List skips any
// file whose name triggers one of these and logs the reason at debug level.
var (
	// ErrExtension is returned when a filename does not end with the configured
	// file extension.
	ErrExtension = errors.New("extension mismatch")
	// ErrMissingDirection is returned when a filename lacks the dot separator
	// preceding the direction segment.
	ErrMissingDirection = errors.New("missing direction segment")
	// ErrIllegalDirection is returned when the direction segment is neither
	// "up" nor "down".
	ErrIllegalDirection = errors.New("illegal direction")
	// ErrMissingSeparator is returned when a filename lacks the underscore
	// separating the version from the description.
	ErrMissingSeparator = errors.New("missing underscore separator")
	// ErrInvalidDescription is returned when the description segment of the
	// filename is empty.
	ErrInvalidDescription = errors.New("invalid description")
	// ErrInvalidVersion is returned when the version segment is empty or cannot
	// be parsed into an unsigned integer.
	ErrInvalidVersion = errors.New("invalid version")
)
// config holds the internal configuration options for the file source.
// It is populated by [Option] functions before a [Source] is constructed.
type config struct {
	// ext is the file extension to filter for (always with a leading dot).
	ext string
	// logger is the structured logger for reporting skipped files.
	logger *slog.Logger
}

// Option configures a [Source] instance.
type Option func(*config)
// WithExtension sets a custom file extension for migration files.
//
// It automatically prepends a leading dot if one is missing. Empty string
// values are ignored.
func WithExtension(ext string) Option {
	return func(c *config) {
		switch {
		case ext == "":
			// Keep the previous value for empty input.
		case strings.HasPrefix(ext, "."):
			c.ext = ext
		default:
			c.ext = "." + ext
		}
	}
}
// WithLogger injects a structured logger used to report files skipped during
// parsing.
//
// Nil values are ignored, falling back to [slog.Default].
func WithLogger(logger *slog.Logger) Option {
	return func(c *config) {
		if logger == nil {
			return
		}
		c.logger = logger
	}
}
// Source implements the [migrate.Source] interface for an [fs.FS].
//
// It scans the file system to discover and parse migration files. Construct
// instances via [New].
type Source struct {
	// dir is the filesystem containing the migration scripts.
	dir fs.FS
	// ext is the file extension used to filter relevant scripts.
	ext string
	// logger is the logger used for debugging missed conventions.
	logger *slog.Logger
}
// New creates a new [Source] instance that reads from the provided [fs.FS].
//
// Options can be provided to customize behavior, such as changing the expected
// file extension.
func New(dir fs.FS, opts ...Option) *Source {
	cfg := config{
		ext:    DefaultExtension,
		logger: slog.Default(),
	}
	for _, apply := range opts {
		apply(&cfg)
	}
	return &Source{dir: dir, ext: cfg.ext, logger: cfg.logger}
}
// Directory returns the underlying file system used by the source.
func (s *Source) Directory() fs.FS {
	return s.dir
}

// Extension returns the configured file extension used to identify
// migration scripts, including the leading dot (e.g. ".sql").
func (s *Source) Extension() string {
	return s.ext
}
// Parse extracts metadata from a given filename.
//
// It extracts the version, description, execution direction, and transaction
// flag. It returns an error if the filename does not match the strict
// <version>_<description>.<direction>[_notx]<extension> format.
//
// On error, all non-error results are zero values. (Previously two error
// paths returned a direction of -1 while the rest returned 0; the zero value
// is now used consistently, and callers must only rely on err.)
func (s *Source) Parse(name string) (
	version uint64,
	desc string,
	direction migrate.Direction,
	tx bool,
	err error,
) {
	// Default to transactional execution unless explicitly disabled.
	tx = true
	// Strip the configured file extension (e.g., ".sql").
	base, found := strings.CutSuffix(name, s.ext)
	if !found {
		return 0, "", 0, false, ErrExtension
	}
	// Locate the dot that separates the version/description from the
	// direction. It must not be the first character, otherwise the prefix
	// would be empty.
	dot := strings.LastIndexByte(base, '.')
	if dot <= 0 {
		return 0, "", 0, false, ErrMissingDirection
	}
	// Extract the direction segment (e.g., "up", "down", or "up_notx").
	seg := base[dot+1:]
	// Check for the "_notx" suffix to determine if transactions should be
	// disabled.
	if trimmed, found := strings.CutSuffix(seg, "_notx"); found {
		tx = false
		seg = trimmed
	}
	// Map the direction string to the internal direction type.
	switch seg {
	case "up":
		direction = migrate.Up
	case "down":
		direction = migrate.Down
	default:
		return 0, "", 0, false, ErrIllegalDirection
	}
	// Move the cursor back to the prefix (version and description).
	base = base[:dot]
	// Split the remaining string into the version and the description.
	// The first underscore is treated as the separator.
	rawVersion, rawDesc, found := strings.Cut(base, "_")
	if !found {
		return 0, "", 0, false, ErrMissingSeparator
	}
	// Ensure neither the version nor the description segments are empty.
	if rawVersion == "" {
		return 0, "", 0, false, ErrInvalidVersion
	}
	if rawDesc == "" {
		return 0, "", 0, false, ErrInvalidDescription
	}
	// Parse the version segment into an unsigned 64-bit integer.
	v, e := strconv.ParseUint(rawVersion, 10, 64)
	if e != nil {
		return 0, "", 0, false, ErrInvalidVersion
	}
	// Finalize the version and sanitize the description by restoring spaces.
	version = v
	desc = strings.ReplaceAll(rawDesc, "_", " ")
	return version, desc, direction, tx, nil
}
// List reads the underlying file system and returns all valid migrations.
//
// It parses all files matching the configured extension. Files that do not
// match the naming convention are skipped and logged at the debug level.
func (s *Source) List() ([]migrate.SourceScript, error) {
	var scripts []migrate.SourceScript
	walk := func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}
		if d.IsDir() {
			return nil
		}
		name := d.Name()
		version, desc, direction, tx, perr := s.Parse(name)
		if perr != nil {
			// Files that miss the naming convention are skipped, not fatal.
			s.logger.Debug(
				"Skipping file in migration directory",
				slog.String("name", name),
				slog.String("reason", perr.Error()),
			)
			return nil
		}
		content, rerr := fs.ReadFile(s.dir, path)
		if rerr != nil {
			return fmt.Errorf("failed to read migration file %q: %w", path, rerr)
		}
		scripts = append(scripts, migrate.SourceScript{
			Version:     version,
			Description: desc,
			Direction:   direction,
			Path:        path,
			Content:     content,
			Tx:          tx,
		})
		return nil
	}
	if err := fs.WalkDir(s.dir, ".", walk); err != nil {
		return nil, fmt.Errorf("failed to traverse migration directory: %w", err)
	}
	return scripts, nil
}
// Ensure Source satisfies the migrate.Source interface at compile time.
var _ migrate.Source = (*Source)(nil)
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package mock provides an in-memory implementation of migrate.Source.
//
// It implements the [migrate.Source]
// interface designed strictly for unit testing. It allows developers to
// simulate various migration scenarios, including success paths and error
// conditions, without requiring access to a physical filesystem or external
// storage.
//
// # Usage
//
// Initialize the source with a slice of predefined scripts and use it in place
// of a real source in your migration tests.
//
// Example:
//
// scripts := []migrate.SourceScript{
// {Version: 1, Description: "init", Direction: migrate.Up, Content: []byte("...")},
// }
// src := mock.New(scripts...)
package mock
import (
"github.com/deep-rent/nexus/migrate"
)
// Source is an in-memory implementation of [migrate.Source].
//
// It allows you to pre-define the list of scripts and optionally inject an
// error to test failure paths.
type Source struct {
	// Scripts contains the pre-defined migration scripts and should be treated
	// as read-only after initialization.
	Scripts []migrate.SourceScript
	// ListErr is an injectable error used to test failure paths during the List
	// operation. When non-nil, List returns it instead of Scripts.
	ListErr error
}
// New creates a new in-memory [Source] with the provided scripts.
func New(scripts ...migrate.SourceScript) *Source {
	src := &Source{Scripts: scripts}
	return src
}
// List returns the pre-configured scripts or the injected [Source.ListErr].
func (s *Source) List() ([]migrate.SourceScript, error) {
	if s.ListErr != nil {
		return nil, s.ListErr
	}
	// Hand out a shallow copy so callers cannot mutate the fixture slice.
	dup := make([]migrate.SourceScript, len(s.Scripts))
	copy(dup, s.Scripts)
	return dup, nil
}
// Ensure Source satisfies the migrate.Source interface at compile time.
var _ migrate.Source = (*Source)(nil)
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package proxy provides a configurable reverse proxy handler.
//
// It constructs an [httputil.ReverseProxy], starting with sensible
// defaults, integrating a reusable buffer pool, structured logging, and robust
// error handling via a functional options API.
//
// # Usage
//
// Create a new proxy handler by providing a target URL and optional
// configuration functions.
//
// Example:
//
// target, _ := url.Parse("https://backend.internal")
// proxyHandler := proxy.NewHandler(target,
// proxy.WithFlushInterval(100*time.Millisecond),
// proxy.WithMaxBufferSize(512<<10),
// )
//
// http.ListenAndServe(":8080", proxyHandler)
package proxy
import (
"context"
"errors"
"log/slog"
"net/http"
"net/http/httputil"
"net/url"
"time"
"github.com/deep-rent/nexus/internal/buffer"
)
const (
	// DefaultMinBufferSize is the default minimum size of pooled buffers (32 KiB).
	DefaultMinBufferSize = 32 << 10
	// DefaultMaxBufferSize is the default maximum size of pooled buffers (256 KiB).
	DefaultMaxBufferSize = 256 << 10
)

// Handler is an alias of [http.Handler] representing a reverse proxy.
type Handler = http.Handler
// NewHandler creates a new reverse proxy handler that routes to the target URL.
//
// The behavior of the proxy can be customized through the given options. It
// avoids the deprecated Director hook in favor of the modern Rewrite API.
func NewHandler(target *url.URL, opts ...HandlerOption) Handler {
	cfg := handlerConfig{
		transport:       http.DefaultTransport.(*http.Transport).Clone(),
		minBufferSize:   DefaultMinBufferSize,
		maxBufferSize:   DefaultMaxBufferSize,
		newRewrite:      NewRewrite,
		newErrorHandler: NewErrorHandler,
		logger:          slog.Default(),
	}
	for _, apply := range opts {
		apply(&cfg)
	}
	// Keep the buffer bounds consistent even under conflicting options.
	if cfg.minBufferSize > cfg.maxBufferSize {
		cfg.minBufferSize = cfg.maxBufferSize
	}
	// Base rewrite: set the X-Forwarded-* headers and point at the target.
	base := func(pr *httputil.ProxyRequest) {
		pr.SetXForwarded()
		pr.SetURL(target)
	}
	// Construct ReverseProxy directly to avoid the deprecated Director hook
	// set by NewSingleHostReverseProxy.
	return &httputil.ReverseProxy{
		Rewrite:       cfg.newRewrite(base),
		ErrorHandler:  cfg.newErrorHandler(cfg.logger),
		Transport:     cfg.transport,
		BufferPool:    buffer.NewPool(cfg.minBufferSize, cfg.maxBufferSize),
		FlushInterval: cfg.flushInterval,
	}
}
// RewriteFunc defines a function to modify requests before they go upstream.
//
// The signature matches [httputil.ReverseProxy.Rewrite].
type RewriteFunc func(*httputil.ProxyRequest)

// RewriteFactory creates a [RewriteFunc] using the provided original
// [RewriteFunc].
//
// The returned [RewriteFunc] may call original to retain its behavior.
type RewriteFactory = func(original RewriteFunc) RewriteFunc

// NewRewrite is the default [RewriteFactory] for the proxy.
//
// It returns the original [RewriteFunc] unmodified. The default rewrite already
// sets X-Forwarded-Host, X-Forwarded-Proto, and X-Forwarded-For headers, and
// correctly rewrites the Host header to match the target.
func NewRewrite(original RewriteFunc) RewriteFunc {
	return original
}

// ErrorHandler defines a function for handling proxy operation errors.
//
// The signature matches [httputil.ReverseProxy.ErrorHandler].
type ErrorHandler = func(http.ResponseWriter, *http.Request, error)

// ErrorHandlerFactory creates an [ErrorHandler] using the provided logger.
//
// It receives the configured logger to be used for error reporting.
type ErrorHandlerFactory = func(*slog.Logger) ErrorHandler
// NewErrorHandler is the default [ErrorHandlerFactory] for the proxy.
//
// It creates an error handler that logs upstream errors and maps them to
// appropriate HTTP status codes, while silencing client-initiated disconnects.
func NewErrorHandler(logger *slog.Logger) ErrorHandler {
	return func(w http.ResponseWriter, r *http.Request, err error) {
		if errors.Is(err, context.Canceled) {
			// The client disconnected; there is nothing useful to send back.
			return
		}
		method, uri := r.Method, r.RequestURI
		timedOut := errors.Is(err, context.DeadlineExceeded) ||
			errors.Is(err, http.ErrHandlerTimeout)
		if timedOut {
			logger.ErrorContext(
				r.Context(),
				"Upstream request timed out",
				slog.String("method", method),
				slog.String("uri", uri),
			)
			w.WriteHeader(http.StatusGatewayTimeout)
			return
		}
		logger.ErrorContext(
			r.Context(),
			"Upstream request failed",
			slog.String("method", method),
			slog.String("uri", uri),
			slog.Any("error", err),
		)
		w.WriteHeader(http.StatusBadGateway)
	}
}
// handlerConfig holds the configurable settings for the proxy handler.
// It is populated by [HandlerOption] functions inside [NewHandler].
type handlerConfig struct {
	// transport handles the network communication with the upstream.
	transport *http.Transport
	// flushInterval is the periodic flush interval for response body copying.
	// Zero disables periodic flushing; negative flushes after every write.
	flushInterval time.Duration
	// minBufferSize is the minimum size of pooled buffers.
	minBufferSize int
	// maxBufferSize is the maximum size of pooled buffers.
	maxBufferSize int
	// newRewrite is the factory for creating the request rewrite function.
	newRewrite RewriteFactory
	// newErrorHandler is the factory for creating the error handling function.
	newErrorHandler ErrorHandlerFactory
	// logger is the structured logger for error reporting.
	logger *slog.Logger
}

// HandlerOption defines a function for setting reverse proxy options.
type HandlerOption func(*handlerConfig)
// WithTransport sets the [http.Transport] for upstream requests.
//
// Use this option to tune connection pooling, timeouts, and keep-alives. If nil
// is given, this option is ignored.
func WithTransport(t *http.Transport) HandlerOption {
	return func(cfg *handlerConfig) {
		if t == nil {
			return
		}
		cfg.transport = t
	}
}
// WithFlushInterval specifies the periodic flush interval for the response.
//
// A zero value (default) disables periodic flushing. A negative value tells the
// proxy to flush immediately after each write. Adjust this if you observe high
// latencies for responses buffered by the proxy.
func WithFlushInterval(d time.Duration) HandlerOption {
	return func(cfg *handlerConfig) { cfg.flushInterval = d }
}
// WithMinBufferSize specifies the minimum size of pooled buffers.
//
// Non-positive values are ignored. The value will be capped at MaxBufferSize.
// Adapt this if you know from profiling that most responses are larger than the
// default 32 KiB.
func WithMinBufferSize(n int) HandlerOption {
	return func(cfg *handlerConfig) {
		if n <= 0 {
			return
		}
		cfg.minBufferSize = n
	}
}
// WithMaxBufferSize specifies the maximum size of buffers to be pooled.
//
// Buffers that grow larger than this size will be discarded after use to
// prevent memory bloat. If your P95 response size is larger than this value,
// the pool will be ineffective.
func WithMaxBufferSize(n int) HandlerOption {
	return func(cfg *handlerConfig) {
		if n <= 0 {
			return
		}
		cfg.maxBufferSize = n
	}
}
// WithRewrite provides a custom [RewriteFactory] for the proxy.
//
// If nil is given, this option is ignored. By default, [NewRewrite] is used.
func WithRewrite(f RewriteFactory) HandlerOption {
	return func(cfg *handlerConfig) {
		if f == nil {
			return
		}
		cfg.newRewrite = f
	}
}
// WithErrorHandler provides a custom [ErrorHandlerFactory] for the proxy.
//
// If nil is given, this option is ignored. By default, [NewErrorHandler] is
// used.
func WithErrorHandler(f ErrorHandlerFactory) HandlerOption {
	return func(cfg *handlerConfig) {
		if f == nil {
			return
		}
		cfg.newErrorHandler = f
	}
}
// WithLogger sets the logger to be used by the proxy's [ErrorHandler].
//
// If nil is given, this option is ignored. The default error handler uses this
// logger for capturing upstream errors.
func WithLogger(logger *slog.Logger) HandlerOption {
	return func(cfg *handlerConfig) {
		if logger == nil {
			return
		}
		cfg.logger = logger
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package retry is an http.RoundTripper middleware for automatic retries.
//
// It provides an [http.RoundTripper] middleware that performs
// automatic, policy-driven retries for HTTP requests. It wraps an existing
// [http.RoundTripper] (such as [http.DefaultTransport]) and intercepts requests
// to apply retry logic. The decision to retry is controlled by a [Policy], and
// the delay between attempts is determined by a [backoff.Strategy].
//
// # Usage
//
// A new transport is created with [NewTransport], configured with functional
// options like [WithAttemptLimit] and [WithBackoff].
//
// Example:
//
// // Retry up to 3 times with exponential backoff starting at 1 second.
// transport := retry.NewTransport(
// http.DefaultTransport,
// retry.WithAttemptLimit(3),
// retry.WithBackoff(backoff.New(
// backoff.WithMinDelay(1*time.Second),
// )),
// )
//
// client := &http.Client{Transport: transport}
//
// // This request will be retried automatically on temporary failures.
// res, err := client.Get("http://example.com/flaky")
// if err != nil {
// slog.Error("Request failed after all retries", "error", err)
// return
// }
// defer res.Body.Close()
package retry
import (
"context"
"errors"
"io"
"log/slog"
"net"
"net/http"
"time"
"github.com/deep-rent/nexus/backoff"
"github.com/deep-rent/nexus/header"
)
// Attempt encapsulates the state of a single HTTP request attempt.
//
// It is passed to a [Policy] to determine if a retry is warranted.
type Attempt struct {
// Request is the original HTTP request being attempted.
Request *http.Request
// Response is the result of the attempt, if one was received.
Response *http.Response
// Error is the error returned by the transport, if any.
Error error
// Count is the number of the current attempt (starting at 1).
Count int
}
// Idempotent reports whether the request can be safely retried.
//
// It considers standard HTTP methods that are idempotent according to RFC 7231,
// such as GET, HEAD, OPTIONS, TRACE, PUT, and DELETE.
func (a Attempt) Idempotent() bool {
switch a.Request.Method {
case
http.MethodGet,
http.MethodHead,
http.MethodOptions,
http.MethodTrace,
http.MethodPut,
http.MethodDelete:
return true
default:
return false
}
}
// Temporary reports whether the response indicates a server-side temporary
// failure.
//
// This is determined by specific HTTP status codes that suggest the request
// might succeed if retried, such as 408, 429, 500, 502, 503, and 504.
func (a Attempt) Temporary() bool {
if a.Response != nil {
switch a.Response.StatusCode {
case
http.StatusRequestTimeout, // 408
http.StatusTooManyRequests, // 429
http.StatusInternalServerError, // 500
http.StatusBadGateway, // 502
http.StatusServiceUnavailable, // 503
http.StatusGatewayTimeout: // 504
return true
}
}
return false
}
// Transient reports whether the error suggests a temporary network-level issue.
//
// It returns true for network timeouts and unexpected EOF errors. It returns
// false for context cancellations ([context.Canceled],
// [context.DeadlineExceeded]), as these should not be retried.
func (a Attempt) Transient() bool {
if a.Error == nil ||
errors.Is(a.Error, context.Canceled) ||
errors.Is(a.Error, context.DeadlineExceeded) {
return false
}
if errors.Is(a.Error, io.ErrUnexpectedEOF) || errors.Is(a.Error, io.EOF) {
return true
}
var err net.Error
return errors.As(a.Error, &err) && err.Timeout()
}
// Policy is the decision-making function that determines whether to retry.
//
// It is invoked after each attempt with the corresponding [Attempt] details. It
// returns true to schedule a retry or false to stop and return the last result.
type Policy func(a Attempt) bool

// LimitAttempts decorates a [Policy] to enforce a maximum attempt limit.
//
// It short-circuits the decision, returning false once the attempt count has
// reached the limit n; otherwise it delegates to the wrapped policy. A limit
// of 1 disables retries, and n <= 0 leaves the policy unlimited.
func (p Policy) LimitAttempts(n int) Policy {
	if n <= 0 {
		return p // no limit requested
	}
	return func(a Attempt) bool {
		if a.Count >= n {
			return false
		}
		return p(a)
	}
}

// DefaultPolicy provides a safe and sensible default retry strategy.
//
// It enters the retry loop only for idempotent requests that have resulted in a
// temporary server error or a transient network error such as a timeout.
func DefaultPolicy() Policy {
	return func(a Attempt) bool {
		if !a.Idempotent() {
			return false
		}
		return a.Temporary() || a.Transient()
	}
}
// transport wraps an underlying [http.RoundTripper] to provide automatic
// retries.
//
// Instances are created by [NewTransport]; fields are set once there and are
// not mutated afterwards by this package.
type transport struct {
	// next is the underlying transport used to send requests.
	next http.RoundTripper
	// policy determines if a retry is allowed.
	policy Policy
	// backoff calculates the delay between retry attempts.
	backoff backoff.Strategy
	// logger handles debug and warning log messages.
	logger *slog.Logger
	// now is the time source for calculating delays (injectable for tests).
	now func() time.Time
}
// RoundTrip executes an HTTP transaction with retry logic.
//
// For a request to be retryable, its body must be rewindable via
// [http.Request.GetBody]. On intermediary failed attempts, the response body is
// fully read and closed to allow connection reuse. The loop respects context
// cancellation while waiting between attempts.
//
// On exit it returns the last response and error as produced by the wrapped
// transport, except when the body cannot be rewound or the context is
// cancelled during a delay, in which case only an error is returned.
func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
	var (
		res   *http.Response
		err   error
		count int
	)
	// Let the backoff strategy release any resources once we are done.
	defer t.backoff.Done()
	rewindable := req.GetBody != nil
	for {
		count++
		// If this is a retry and the body is rewindable, obtain a new reader.
		if count > 1 && rewindable {
			var e error
			req.Body, e = req.GetBody()
			if e != nil {
				// Cannot rewind the body, so we must stop here.
				return nil, e
			}
		}
		res, err = t.next.RoundTrip(req)
		// Ask the policy if we should retry.
		if !t.policy(Attempt{
			Request:  req,
			Response: res,
			Error:    err,
			Count:    count,
		}) {
			break // Success or policy decided to exit
		}
		// Check if the request body is rewindable. If not, we must stop here.
		// This is checked after the policy to ensure the policy still gets notified
		// of the attempt.
		if req.Body != nil && !rewindable {
			break
		}
		// If retrying, drain and close the previous response body so the
		// underlying connection can be reused for the next attempt.
		if res != nil && res.Body != nil {
			if _, err := io.Copy(io.Discard, res.Body); err != nil {
				t.logger.Warn("Failed to drain response body", slog.Any("error", err))
			}
			if err := res.Body.Close(); err != nil {
				t.logger.Warn("Failed to close response body", slog.Any("error", err))
			}
		}
		delay := t.backoff.Next()
		if res != nil {
			// header.Throttle derives a server-requested delay (e.g. from
			// throttling headers) — presumably Retry-After; see that package.
			if d := header.Throttle(res.Header, t.now); d != 0 {
				// Use the longer of the two delays to respect both the server
				// and our own backoff policy.
				delay = max(delay, d)
			}
		}
		if ctx := req.Context(); t.logger.Enabled(ctx, slog.LevelDebug) {
			attrs := []any{
				slog.Int("attempt", count),
				slog.Duration("delay", delay),
				slog.String("method", req.Method),
				slog.String("url", req.URL.String()),
			}
			if err != nil {
				attrs = append(attrs, slog.Any("error", err))
			}
			if res != nil {
				attrs = append(attrs, slog.Int("status", res.StatusCode))
			}
			t.logger.DebugContext(ctx, "Request attempt failed, retrying", attrs...)
		}
		if delay <= 0 {
			continue // Retry without delay
		}
		// Wait for the delay, respecting context cancellation.
		select {
		case <-time.After(delay):
			continue // Proceed to next attempt
		case <-req.Context().Done():
			return nil, req.Context().Err()
		}
	}
	return res, err
}

// Ensure transport satisfies the http.RoundTripper interface.
var _ http.RoundTripper = (*transport)(nil)
// NewTransport creates and returns a new retrying [http.RoundTripper].
//
// It wraps an existing transport and retries requests based on the configured
// policy and backoff strategy. With no options, it uses [DefaultPolicy], no
// attempt limit, zero delay between attempts, and [slog.Default].
func NewTransport(
	next http.RoundTripper,
	opts ...Option,
) http.RoundTripper {
	cfg := config{
		policy:  DefaultPolicy(),
		backoff: backoff.Constant(0),
		logger:  slog.Default(),
		now:     time.Now,
	}
	for _, apply := range opts {
		apply(&cfg)
	}
	return &transport{
		next:    next,
		policy:  cfg.policy.LimitAttempts(cfg.limit),
		backoff: cfg.backoff,
		logger:  cfg.logger,
		now:     cfg.now,
	}
}
// config holds the configuration parameters supplied via functional options.
//
// Zero values are replaced with defaults in [NewTransport] before use.
type config struct {
	// policy is the base retry logic.
	policy Policy
	// limit is the maximum number of attempts (0 or less means unlimited).
	limit int
	// backoff is the strategy for inter-attempt delays.
	backoff backoff.Strategy
	// logger is the structured logger.
	logger *slog.Logger
	// now is the clock used for timing calculations.
	now func() time.Time
}
// Option is a function that configures the retry transport.
type Option func(*config)

// WithPolicy sets the retry policy used by the transport.
//
// If not provided, [DefaultPolicy] is used. A nil value is ignored.
func WithPolicy(policy Policy) Option {
	return func(cfg *config) {
		if policy == nil {
			return
		}
		cfg.policy = policy
	}
}

// WithAttemptLimit sets the maximum number of attempts for a request.
//
// This includes the initial attempt. A value of 3 means one initial attempt and
// up to two retries. If the value is 0 or less, no limit is enforced.
func WithAttemptLimit(n int) Option {
	return func(cfg *config) {
		cfg.limit = n
	}
}

// WithBackoff sets the strategy for calculating the delay between retries.
//
// If not provided, there is no delay between attempts. A nil value is ignored.
func WithBackoff(strategy backoff.Strategy) Option {
	return func(cfg *config) {
		if strategy == nil {
			return
		}
		cfg.backoff = strategy
	}
}

// WithLogger sets the logger for debug messages.
//
// If not provided, [slog.Default] is used. A nil value is ignored.
func WithLogger(logger *slog.Logger) Option {
	return func(cfg *config) {
		if logger == nil {
			return
		}
		cfg.logger = logger
	}
}

// WithClock provides a custom time source, primarily for testing.
//
// If not provided, [time.Now] is used. A nil value is ignored.
func WithClock(now func() time.Time) Option {
	return func(cfg *config) {
		if now == nil {
			return
		}
		cfg.now = now
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package router provides a lightweight, JSON-centric wrapper around Go's
// native [http.ServeMux].
//
// It simplifies building JSON APIs by offering a consolidated "Exchange" object
// for handling requests and responses, standardized error formatting, and a
// middleware chaining mechanism.
//
// # Basic Usage
//
// 1. Setup the router with options:
//
// Example:
//
// logger := log.New()
// r := router.New(
// router.WithLogger(logger),
// router.WithMiddleware(router.Log(logger)),
// )
//
// 2. Define a handler:
//
// Example:
//
// r.HandleFunc("POST /users", func(e *router.Exchange) error {
// var req CreateUserRequest
// if err := e.BindJSON(&req); err != nil {
// return err
// }
// return e.JSON(http.StatusCreated, UserResponse{ID: "123"})
// })
//
// 3. Start the server:
//
// Example:
//
// http.ListenAndServe(":8080", r)
package router
import (
"context"
"encoding/json/v2"
"errors"
"fmt"
"log/slog"
"net/http"
"net/url"
"github.com/deep-rent/nexus/header"
"github.com/deep-rent/nexus/middleware"
"github.com/deep-rent/nexus/middleware/cors"
"github.com/deep-rent/nexus/middleware/gzip"
"github.com/deep-rent/nexus/valid"
)
// Standard error reasons used for machine-readable error codes.
//
// These values appear in the Reason field of [Error] payloads so that clients
// can branch on them without parsing the human-readable description.
const (
	// ReasonWrongType indicates that the request had an unsupported content type.
	ReasonWrongType = "wrong_type"
	// ReasonEmptyBody indicates that the request body was empty.
	ReasonEmptyBody = "empty_body"
	// ReasonParseJSON indicates that there was an error parsing the JSON body.
	ReasonParseJSON = "parse_json"
	// ReasonParseForm indicates that there was an error parsing form data.
	ReasonParseForm = "parse_form"
	// ReasonValidationFailed indicates that input validation failed.
	ReasonValidationFailed = "validation_failed"
	// ReasonServerError indicates that an unexpected internal error occurred.
	ReasonServerError = "server_error"
)

// Standard media types used in the Content-Type header.
const (
	// MediaTypeJSON is the media type for JSON content.
	MediaTypeJSON = "application/json"
	// MediaTypeForm is the media type for URL-encoded form data.
	MediaTypeForm = "application/x-www-form-urlencoded"
)
// ResponseWriter extends [http.ResponseWriter] with introspection capabilities.
//
// It allows handlers and middleware to check if the response headers have
// already been written, which is crucial for robust error handling.
type ResponseWriter interface {
	http.ResponseWriter
	// Status returns the HTTP status code written, or 0 if not written yet.
	Status() int
	// Closed reports whether the headers have already been written.
	Closed() bool
	// Unwrap returns the underlying [http.ResponseWriter].
	Unwrap() http.ResponseWriter
}

// NewResponseWriter wraps an [http.ResponseWriter] into a [ResponseWriter].
func NewResponseWriter(w http.ResponseWriter) ResponseWriter {
	return &responseWriter{ResponseWriter: w}
}

// responseWriter is the concrete implementation of [ResponseWriter].
type responseWriter struct {
	// ResponseWriter is the underlying standard writer.
	http.ResponseWriter
	// code records the committed HTTP status; 0 means uncommitted.
	code int
}

// WriteHeader commits the status code exactly once; later calls are no-ops.
func (rw *responseWriter) WriteHeader(status int) {
	if rw.code == 0 {
		rw.code = status
		rw.ResponseWriter.WriteHeader(status)
	}
}

// Write sends body bytes, committing HTTP 200 first if nothing was written yet.
func (rw *responseWriter) Write(p []byte) (int, error) {
	if rw.code == 0 {
		rw.WriteHeader(http.StatusOK)
	}
	return rw.ResponseWriter.Write(p)
}

// Status implements [ResponseWriter].
func (rw *responseWriter) Status() int { return rw.code }

// Closed implements [ResponseWriter].
func (rw *responseWriter) Closed() bool { return rw.code != 0 }

// Unwrap implements [ResponseWriter].
func (rw *responseWriter) Unwrap() http.ResponseWriter { return rw.ResponseWriter }

var _ http.ResponseWriter = (*responseWriter)(nil)
// Error describes the standardized shape of API errors returned to clients.
//
// Handlers can return this struct directly to control the HTTP status code
// and error details. If a handler returns a standard Go error, the [Router]
// will wrap it in a generic internal server error.
type Error struct {
	// Status is the HTTP status code (e.g., 400, 404, 500).
	Status int `json:"status"`
	// Reason is a short string identifying the error type.
	Reason string `json:"reason"`
	// Description is a human-readable explanation of the error cause.
	Description string `json:"description"`
	// ID is a unique identifier of the specific occurrence for tracing.
	ID string `json:"id,omitempty"`
	// Context contains arbitrary additional data about the error.
	Context any `json:"context,omitempty"`
	// Cause is the underlying error that triggered this error.
	Cause error `json:"-"`
}

// Error implements the standard [error] interface.
//
// The message has the form "<reason>: <description>".
func (e *Error) Error() string {
	reason, detail := e.Reason, e.Description
	return reason + ": " + detail
}
// Exchange acts as a context object for a single HTTP request/response cycle.
//
// It wraps the underlying [*http.Request] and [http.ResponseWriter] to provide
// convenient helper methods for common API tasks, such as parsing JSON,
// reading parameters, and writing structured responses. One Exchange is
// created per request by [Router.Handle].
type Exchange struct {
	// R is the incoming HTTP request.
	R *http.Request
	// W is a writer for the outgoing HTTP response.
	W ResponseWriter
	// jsonOpts is inherited from the parent Router and applied to all JSON
	// encoding/decoding performed through this exchange.
	jsonOpts []json.Options
	// errorHandler allows middlewares to trigger standardized error resolution.
	errorHandler ErrorHandler
}
// Context returns the request's context.
func (e *Exchange) Context() context.Context { return e.R.Context() }

// Method returns the HTTP method (GET, POST, etc.) of the request.
func (e *Exchange) Method() string { return e.R.Method }

// URL returns the full URL of the request.
func (e *Exchange) URL() *url.URL { return e.R.URL }

// Path returns the URL path of the request.
func (e *Exchange) Path() string { return e.R.URL.Path }

// Param retrieves a path parameter by name.
//
// This relies on Go 1.22+ routing patterns (e.g., "GET /users/{id}").
func (e *Exchange) Param(name string) string { return e.R.PathValue(name) }

// Query parses the URL query parameters of the request.
func (e *Exchange) Query() url.Values { return e.R.URL.Query() }

// Header returns the HTTP headers of the request.
func (e *Exchange) Header() http.Header { return e.R.Header }

// GetHeader retrieves a specific header value from the request.
func (e *Exchange) GetHeader(key string) string { return e.R.Header.Get(key) }

// SetHeader sets a specific header value in the response.
func (e *Exchange) SetHeader(key, value string) { e.W.Header().Set(key, value) }
// BindJSON decodes the request body into v.
//
// This method verifies that the media type is "application/json", checks that
// the payload is not empty, unmarshals the JSON, and validates the input using
// [valid.Test].
func (e *Exchange) BindJSON(v any) *Error {
if t := header.MediaType(e.R.Header); t != MediaTypeJSON {
return &Error{
Status: http.StatusUnsupportedMediaType,
Reason: ReasonWrongType,
Description: "content-type must be " + MediaTypeJSON,
}
}
if e.R.Body == nil || e.R.Body == http.NoBody {
return &Error{
Status: http.StatusBadRequest,
Reason: ReasonEmptyBody,
Description: "empty request body",
}
}
if err := json.UnmarshalRead(e.R.Body, v, e.jsonOpts...); err != nil {
return &Error{
Status: http.StatusBadRequest,
Reason: ReasonParseJSON,
Description: "could not parse JSON body",
}
}
if err := valid.Test(v); err != nil {
if ctx, ok := errors.AsType[valid.Error](err); ok {
return &Error{
Status: http.StatusBadRequest,
Reason: ReasonValidationFailed,
Description: fmt.Sprintf("input violates %d constraints", len(ctx)),
Context: ctx,
}
}
return &Error{
Status: http.StatusInternalServerError,
Reason: ReasonServerError,
Description: "an unexpected error occurred during input validation",
Cause: err,
}
}
return nil
}
// ReadForm parses the request body as URL-encoded form data.
//
// Unlike standard [http.Request.FormValue], this strictly accesses PostForm
// (the request body), ignoring URL query parameters.
func (e *Exchange) ReadForm() (url.Values, *Error) {
	mt := header.MediaType(e.R.Header)
	if mt != MediaTypeForm {
		return nil, &Error{
			Status:      http.StatusUnsupportedMediaType,
			Reason:      ReasonWrongType,
			Description: "content-type must be " + MediaTypeForm,
		}
	}
	err := e.R.ParseForm()
	if err != nil {
		return nil, &Error{
			Status:      http.StatusBadRequest,
			Reason:      ReasonParseForm,
			Description: "malformed form data",
		}
	}
	return e.R.PostForm, nil
}
// JSON encodes v as JSON and writes it to the response.
//
// It automatically sets the Content-Type header to [MediaTypeJSON] if it has
// not already been set.
func (e *Exchange) JSON(code int, v any) error {
buf, err := json.Marshal(v, e.jsonOpts...)
if err != nil {
return err
}
if e.W.Header().Get("Content-Type") == "" {
e.SetHeader("Content-Type", MediaTypeJSON)
}
e.Status(code)
_, err = e.W.Write(buf)
return err
}
// Form writes the values as URL-encoded form data.
//
// It automatically sets the Content-Type header to [MediaTypeForm] if it has
// not already been set.
func (e *Exchange) Form(code int, v url.Values) error {
	hdr := e.W.Header()
	if hdr.Get("Content-Type") == "" {
		hdr.Set("Content-Type", MediaTypeForm)
	}
	e.Status(code)
	_, err := e.W.Write([]byte(v.Encode()))
	return err
}
// Status sends an HTTP response header with the provided status code.
//
// Note: Calling this commits the response headers. It is primarily used for
// empty responses like HTTP 204 (No Content).
func (e *Exchange) Status(code int) {
	e.W.WriteHeader(code)
}

// NoContent sends a HTTP 204 No Content response.
func (e *Exchange) NoContent() {
	e.Status(http.StatusNoContent)
}

// Redirect replies to the request with a redirect to url.
//
// It always returns nil so it can be used directly as a handler's
// return value.
func (e *Exchange) Redirect(url string, code int) error {
	http.Redirect(e.W, e.R, url, code)
	return nil
}
// RedirectTo constructs a URL with query parameters and redirects the client.
//
// The given params are merged into any query parameters already present in
// base. It returns an [*Error] if base is not a parseable URL.
func (e *Exchange) RedirectTo(base string, params url.Values, code int) error {
	target, err := url.Parse(base)
	if err != nil {
		return &Error{
			Status:      http.StatusInternalServerError,
			Reason:      ReasonServerError,
			Description: "invalid redirect target",
		}
	}
	query := target.Query()
	for key, values := range params {
		for _, value := range values {
			query.Add(key, value)
		}
	}
	target.RawQuery = query.Encode()
	http.Redirect(e.W, e.R, target.String(), code)
	return nil
}
// Handler defines the interface for HTTP request handlers used by the [Router].
type Handler interface {
	// ServeHTTP processes an HTTP request encapsulated in the Exchange object.
	// A non-nil return value is resolved by the router's [ErrorHandler].
	ServeHTTP(e *Exchange) error
}

// HandlerFunc defines the function signature for HTTP request handlers.
type HandlerFunc func(e *Exchange) error

// ServeHTTP satisfies the [Handler] interface.
func (f HandlerFunc) ServeHTTP(e *Exchange) error { return f(e) }

// Ensure HandlerFunc implements Handler.
var _ Handler = HandlerFunc(nil)

// ErrorHandler defines a function that handles errors returned by routes.
type ErrorHandler func(e *Exchange, err error)

// Middleware defines a function that wraps a [Handler].
//
// It allows custom logic to be executed before and/or after the next handler.
// Unlike standard HTTP middleware, this natively supports returning API errors.
type Middleware func(Handler) Handler
// Chain combines a handler with multiple [Middleware] functions.
//
// The functions are applied in reverse order, meaning the first middleware in
// the list is the outermost and executes first. Nil entries are skipped.
func Chain(h Handler, mws ...Middleware) Handler {
	wrapped := h
	for i := len(mws); i > 0; i-- {
		if mw := mws[i-1]; mw != nil {
			wrapped = mw(wrapped)
		}
	}
	return wrapped
}
// Wrap converts a standard [http.Handler] into a router [Handler].
//
// The resulting handler never returns an error.
func Wrap(h http.Handler) Handler {
	fn := func(e *Exchange) error {
		h.ServeHTTP(e.W, e.R)
		return nil
	}
	return HandlerFunc(fn)
}
// Adapt converts a standard [middleware.Pipe] into a [Middleware].
//
// This bridges low-level HTTP transport middlewares into the router's
// ecosystem. It ensures that any modifications made to the request or response
// writer by the transport middleware are preserved on the exchange.
func Adapt(pipe middleware.Pipe) Middleware {
	return func(next Handler) Handler {
		return HandlerFunc(func(e *Exchange) error {
			var nextErr error
			h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// Propagate any request/writer replacements made by the pipe
				// back onto the exchange before invoking the inner handler.
				e.R = r
				if rw, ok := w.(ResponseWriter); ok {
					e.W = rw
				} else {
					e.W = NewResponseWriter(w)
				}
				nextErr = next.ServeHTTP(e)
				// Resolve the error immediately so transport middlewares
				// (like Logger) observe the correct HTTP status code.
				if nextErr != nil && e.errorHandler != nil {
					e.errorHandler(e, nextErr)
					nextErr = nil // Prevent double-handling upstream
				}
			})
			pipe(h).ServeHTTP(e.W, e.R)
			return nextErr
		})
	}
}
// Recover mirrors [middleware.Recover] for use in the router.
func Recover(logger *slog.Logger) Middleware {
	return Adapt(middleware.Recover(logger))
}

// RequestID mirrors [middleware.RequestID] for use in the router.
func RequestID() Middleware {
	return Adapt(middleware.RequestID())
}

// Log mirrors [middleware.Log] for use in the router.
func Log(logger *slog.Logger) Middleware {
	return Adapt(middleware.Log(logger))
}

// Volatile mirrors [middleware.Volatile] for use in the router.
func Volatile() Middleware {
	return Adapt(middleware.Volatile())
}

// Secure mirrors [middleware.Secure] for use in the router.
func Secure(cfg middleware.SecurityConfig) Middleware {
	return Adapt(middleware.Secure(cfg))
}

// CORS mirrors the middleware created by [cors.New] for use in the router.
func CORS(opts ...cors.Option) Middleware {
	return Adapt(cors.New(opts...))
}

// Gzip mirrors the middleware created by [gzip.New] for use in the router.
func Gzip(opts ...gzip.Option) Middleware {
	return Adapt(gzip.New(opts...))
}
// Option defines a functional configuration option for the [Router].
type Option func(*Router)
// WithMiddleware adds global middleware to the [Router].
func WithMiddleware(mws ...Middleware) Option {
return func(r *Router) {
r.mws = append(r.mws, mws...)
}
}
// WithMaxBodySize sets the maximum allowed size for request bodies.
func WithMaxBodySize(bytes int64) Option {
return func(r *Router) {
r.maxBytes = bytes
}
}
// WithJSONOptions sets custom JSON options for the [Router].
func WithJSONOptions(opts ...json.Options) Option {
return func(r *Router) {
r.jsonOpts = opts
}
}
// WithErrorHandler sets a custom error handler.
func WithErrorHandler(h ErrorHandler) Option {
return func(r *Router) {
if h != nil {
r.errorHandler = h
}
}
}
// WithLogger updates the default error handler to use the given logger.
func WithLogger(logger *slog.Logger) Option {
return func(r *Router) {
if logger != nil {
r.errorHandler = defaultErrorHandler(logger)
}
}
}
// Router represents an HTTP request router with middleware support.
//
// Construct it with [New]; the zero value is not usable because Mux and
// errorHandler would be unset.
type Router struct {
	// Mux is the underlying standard [*http.ServeMux].
	Mux *http.ServeMux
	// mws is the global slice of middleware applied to every route.
	mws []Middleware
	// maxBytes is the maximum request body size limit (0 disables the limit).
	maxBytes int64
	// jsonOpts are the standard JSON options used for I/O.
	jsonOpts []json.Options
	// errorHandler processes errors returned by handlers.
	errorHandler ErrorHandler
}
// New creates a new [Router] instance with the provided options.
//
// Unless overridden, errors are resolved by the default handler logging to
// [slog.Default].
func New(opts ...Option) *Router {
	router := &Router{
		Mux:          http.NewServeMux(),
		errorHandler: defaultErrorHandler(slog.Default()),
	}
	for _, apply := range opts {
		apply(router)
	}
	return router
}

// ServeHTTP satisfies the [http.Handler] interface by delegating to the
// underlying mux.
func (r *Router) ServeHTTP(res http.ResponseWriter, req *http.Request) {
	r.Mux.ServeHTTP(res, req)
}
// Handle registers a new route with a pattern and handler.
//
// The pattern must follow Go 1.22+ syntax. The handler is wrapped with the
// Router's global middleware (outermost) followed by any local middleware
// provided.
func (r *Router) Handle(
	pattern string,
	handler Handler,
	mws ...Middleware,
) {
	// Global middleware first so it forms the outermost layers of the chain.
	all := make([]Middleware, 0, len(r.mws)+len(mws))
	all = append(all, r.mws...)
	all = append(all, mws...)
	wrapped := Chain(handler, all...)
	r.Mux.Handle(pattern, http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
		// Enforce the configured body size limit, if any.
		if r.maxBytes > 0 {
			req.Body = http.MaxBytesReader(res, req.Body, r.maxBytes)
		}
		e := &Exchange{
			R:            req,
			W:            NewResponseWriter(res),
			jsonOpts:     r.jsonOpts,
			errorHandler: r.errorHandler,
		}
		// Any error escaping the chain is resolved centrally.
		if err := wrapped.ServeHTTP(e); err != nil {
			r.errorHandler(e, err)
		}
	}))
}
// HandleFunc is a convenience wrapper for [Router.Handle].
func (r *Router) HandleFunc(
	pattern string,
	fn func(*Exchange) error,
	mws ...Middleware,
) {
	r.Handle(pattern, HandlerFunc(fn), mws...)
}

// Mount registers a standard [http.Handler] under a pattern, applying the
// router's global middleware to it.
func (r *Router) Mount(pattern string, handler http.Handler) {
	r.Handle(pattern, Wrap(handler))
}
// defaultErrorHandler centralizes error processing.
//
// If the response has already been committed, the error is only logged.
// Otherwise, an [*Error] is rendered to the client as JSON; any other error is
// logged and replaced by a generic internal server error payload.
func defaultErrorHandler(logger *slog.Logger) ErrorHandler {
	return func(e *Exchange, err error) {
		if e.W.Closed() {
			logger.Error(
				"Handler returned error after writing response",
				slog.Any("err", err),
			)
			return
		}
		// Use errors.AsType for consistency with BindJSON and to avoid the
		// needless allocation of a scratch *Error before unwrapping.
		ae, ok := errors.AsType[*Error](err)
		if !ok {
			logger.Error("An internal server error occurred", slog.Any("err", err))
			ae = &Error{
				Status:      http.StatusInternalServerError,
				Reason:      ReasonServerError,
				Description: "internal server error",
			}
		}
		if we := e.JSON(ae.Status, ae); we != nil {
			logger.Warn("Failed to write error response", slog.Any("err", we))
		}
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package scheduler provides a flexible framework for running recurring tasks.
//
// It manages the lifecycle of concurrent, scheduled jobs. The
// basic unit of work is a [Task], which can be adapted into a schedulable
// [Tick]. A [Tick] is a self-repeating job that determines its own next run
// time by returning a duration after each execution.
//
// # Usage
//
// Helpers like [Every] and [After] are provided to easily convert a simple
// [Task] into a [Tick] with common scheduling patterns:
//
// - [Every]: Creates a drift-free Tick that runs at a fixed cadence,
// accounting for the task's own execution time.
// - [After]: Creates a drifting Tick that waits for a fixed duration after
// the previous run completes.
//
// Example:
//
// s := scheduler.New(context.Background())
// defer s.Shutdown()
//
// task := scheduler.TaskFn(func(context.Context) {
// slog.Info("Tick!")
// })
//
// tick := scheduler.Every(2*time.Second, task)
// s.Dispatch(tick)
//
// // Let the scheduler run for a while.
// time.Sleep(5 * time.Second)
package scheduler
import (
"context"
"sync"
"time"
)
// Tick represents a unit of work that can be scheduled to run repeatedly.
type Tick interface {
// Run executes the job and returns the duration to wait before the next
// execution. It accepts a context that is cancelled when the scheduler
// is shut down.
//
// If the returned duration is zero or negative, the next run is scheduled
// immediately.
Run(ctx context.Context) time.Duration
}
// TickFn is an adapter to allow the use of ordinary functions as [Tick]s.
type TickFn func(ctx context.Context) time.Duration
// Run implements [Tick].
func (f TickFn) Run(ctx context.Context) time.Duration { return f(ctx) }
// Task represents a unit of work to be executed in a scheduler loop.
//
// Helpers like [After] and [Every] adapt a [Task] into a [Tick].
type Task interface {
// Run executes the job. It accepts a context for cancellation and
// timeout control.
Run(ctx context.Context)
}
// TaskFn is an adapter to allow the use of ordinary functions as [Task]s.
type TaskFn func(ctx context.Context)
// Run implements [Task].
func (f TaskFn) Run(ctx context.Context) { f(ctx) }
// After creates a drifting [Tick] that runs after a fixed delay.
//
// The scheduler waits for the full delay after the task has completed, so
// the effective cadence will vary based on the task's execution time.
func After(d time.Duration, task Task) Tick {
return TickFn(func(ctx context.Context) time.Duration {
task.Run(ctx)
return d
})
}
// Every creates a drift-free [Tick] that runs at a fixed interval.
//
// The wrapper measures the [Task] execution time and subtracts it from the
// specified interval, ensuring the task starts at a consistent cadence. If a
// task's execution time exceeds the interval, the next run starts immediately.
func Every(d time.Duration, task Task) Tick {
return TickFn(func(ctx context.Context) time.Duration {
start := time.Now()
task.Run(ctx)
elapsed := time.Since(start)
return max(0, d-elapsed)
})
}
// Scheduler manages the non-blocking execution of [Tick]s at their intervals.
type Scheduler interface {
	// Context returns the scheduler's context. This context is cancelled when
	// Shutdown is called. Users can select on this context's Done channel to
	// coordinate with the scheduler's termination.
	Context() context.Context
	// Dispatch executes the given tick in a separate goroutine. The tick will
	// run immediately and then repeat according to the duration it returns
	// until the scheduler is shut down. Multiple ticks can be dispatched
	// concurrently without blocking each other.
	Dispatch(tick Tick)
	// Shutdown gracefully stops the scheduler. It cancels the scheduler's
	// context and waits for all its pending tasks to complete. Shutdown blocks
	// until all dispatched goroutines have finished.
	Shutdown()
}
// New creates a new [Scheduler] tied to the provided parent context.
//
// Cancelling this context will also cause the scheduler to shut down.
func New(ctx context.Context) Scheduler {
	// Derive a child context so Shutdown can cancel independently of the
	// parent, while parent cancellation still propagates.
	child, stop := context.WithCancel(ctx)
	return &scheduler{ctx: child, cancel: stop}
}
// scheduler is the concrete implementation of the [Scheduler] interface.
//
// It must not be copied after first use because it embeds a [sync.WaitGroup].
type scheduler struct {
	// ctx is the internal lifecycle context shared with all dispatched
	// goroutines.
	ctx context.Context
	// cancel stops all dispatched goroutines.
	cancel context.CancelFunc
	// wg tracks active task goroutines so Shutdown can wait for them.
	wg sync.WaitGroup
}
// Context implements [Scheduler].
//
// The returned context is canceled when Shutdown is called or when the
// parent context passed to [New] is canceled.
func (s *scheduler) Context() context.Context {
	return s.ctx
}
// Dispatch implements [Scheduler].
//
// It starts a dedicated goroutine that runs the tick immediately and then
// re-arms a timer with whatever delay the tick returns, stopping once the
// scheduler's context is canceled. The goroutine is registered with the
// wait group so Shutdown can block until it exits.
func (s *scheduler) Dispatch(tick Tick) {
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		// A zero-duration timer fires immediately, triggering the first run.
		timer := time.NewTimer(0)
		defer timer.Stop()
		for {
			select {
			case <-s.ctx.Done():
				return
			case <-timer.C:
				// Run synchronously, then re-arm with the delay the tick
				// requested. Reset is safe here because the timer has just
				// fired and its channel is drained.
				wait := tick.Run(s.ctx)
				timer.Reset(wait)
			}
		}
	}()
}
// Shutdown implements [Scheduler].
//
// It cancels the scheduler's context to signal all dispatched goroutines,
// then blocks until every one of them has returned.
func (s *scheduler) Shutdown() {
	s.cancel()
	s.wg.Wait()
}

// Compile-time check that *scheduler satisfies [Scheduler].
var _ Scheduler = (*scheduler)(nil)
// Once creates a synchronous [Scheduler] that runs each [Tick] exactly once.
//
// Its [Scheduler.Dispatch] method is blocking and runs the [Tick] in the
// calling goroutine. This implementation is useful for testing or executing a
// task without true background scheduling.
//
// Unlike [New], Once does not wrap ctx in a cancelable child context, and
// its Shutdown is a no-op.
func Once(ctx context.Context) Scheduler {
	return &once{ctx: ctx}
}
// once is a [Scheduler] implementation for synchronous, single execution.
type once struct {
	// ctx is the context passed verbatim to executed ticks.
	ctx context.Context
}
// Context implements [Scheduler]. It returns the context supplied at
// construction; this implementation never cancels it.
func (o *once) Context() context.Context { return o.ctx }

// Dispatch implements [Scheduler]. It runs the tick exactly once,
// synchronously, in the calling goroutine; the interval returned by the
// tick is discarded.
func (o *once) Dispatch(tick Tick) { tick.Run(o.ctx) }

// Shutdown implements [Scheduler]. It is a no-op because Dispatch never
// starts background work.
func (o *once) Shutdown() {}

// Compile-time check that *once satisfies [Scheduler].
var _ Scheduler = (*once)(nil)
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package build provides test helpers for compiling Go binaries.
//
// Package build offers utilities for compiling Go source code during tests. It
// ensures that build artifacts are isolated and automatically cleaned up after
// the tests finish by leveraging the testing framework's temporary directory
// management.
//
// # Usage
//
// Call the [Binary] function within a test to compile a target program.
//
// Example:
//
// func TestIntegration(t *testing.T) {
// exe := build.Binary(t, "./cmd/app", "app-bin")
// cmd := exec.Command(exe)
// // ... run and test the binary ...
// }
package build
import (
"os/exec"
"path/filepath"
"runtime"
"testing"
)
// Binary compiles Go source code and returns the path to the executable.
//
// It compiles the code located at the src directory and writes the resulting
// executable to dst within a temporary directory. The test framework
// automatically removes the executable and its directory when the test
// completes. It appends the ".exe" suffix on Windows systems.
func Binary(t testing.TB, src, dst string) string {
t.Helper()
exe := filepath.Join(t.TempDir(), dst)
if runtime.GOOS == "windows" {
exe += ".exe"
}
// Compile the current directory but execute the command inside the src
// directory. This ensures Go respects the module context of the target
// program.
cmd := exec.Command("go", "build", "-o", exe, ".") //nolint:gosec
cmd.Dir = src
if out, err := cmd.CombinedOutput(); err != nil {
t.Fatalf("failed to build %s: %v\n%s", dst, err, out)
}
return exe
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package ports provides integration test helpers for network port management.
//
// Package ports offers utilities to find available network ports and block
// until those ports begin accepting connections. These helpers are primarily
// intended for integration testing scenarios where services are started on
// dynamic ports to avoid collisions.
//
// # Usage
//
// Find a free port and wait for a service to become ready on it.
//
// Example:
//
// port := ports.FreeT(t)
// go startService(port)
// ports.WaitT(t, "127.0.0.1", port)
package ports
import (
"context"
"net"
"strconv"
"testing"
"time"
)
// Free asks the kernel for a free, open port that is ready to use.
//
// It opens a temporary TCP listener on a random port, records the assigned
// port number, and closes the listener before returning. Note the port is
// only *likely* free afterwards; another process could grab it in between.
func Free() (int, error) {
	lis, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return 0, err
	}
	port := lis.Addr().(*net.TCPAddr).Port
	// Best-effort close; the port number is already captured.
	_ = lis.Close()
	return port, nil
}
// FreeT is a test helper that wraps [Free].
//
// It fails the test immediately if a free port cannot be found.
func FreeT(t testing.TB) int {
	t.Helper()
	p, err := Free()
	if err != nil {
		t.Fatalf("failed to get free port: %v", err)
	}
	return p
}
// Wait blocks until a TCP connection is accepted at the specified address.
//
// It periodically attempts to dial the host and port until it succeeds or the
// provided [context.Context] is canceled or its deadline is exceeded.
func Wait(ctx context.Context, host string, port int) error {
addr := net.JoinHostPort(host, strconv.Itoa(port))
var d net.Dialer
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
// DialContext respects the context for the dial operation itself.
conn, err := d.DialContext(ctx, "tcp", addr)
if err == nil {
_ = conn.Close()
return nil
}
select {
case <-ctx.Done():
// The context was canceled or its deadline exceeded.
return ctx.Err()
case <-ticker.C:
// Wait for the next tick before trying again.
}
}
}
// WaitT is a test helper that wraps [Wait].
//
// It fails the test immediately if the port does not become available before
// the test context expires. The wait is bounded by t.Context(), which is
// canceled when the test finishes.
func WaitT(t testing.TB, host string, port int) {
	t.Helper()
	if err := Wait(t.Context(), host, port); err != nil {
		t.Fatalf("failed waiting for port %d on %s: %v", port, host, err)
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package updater provides functionality to check for newer GitHub releases.
//
// Package updater queries the GitHub Releases API to retrieve the latest
// release and compares its tag against the current application version using
// semantic versioning.
//
// # Usage
//
// Initialize a [Config] and use the [Check] function to look for new versions.
//
// Example:
//
// cfg := &updater.Config{
// Owner: "deep-rent",
// Repository: "vouch",
// Current: "v1.0.0",
// UserAgent: "Vouch/1.0.0",
// }
//
// // Check for updates.
// rel, err := updater.Check(context.Background(), cfg)
// if err != nil {
// log.Printf("Failed to check for updates: %v", err)
// } else if rel != nil {
// log.Printf("New version available: %s (see %s)", rel.Version, rel.URL)
// }
package updater
import (
"context"
"encoding/json/v2"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"golang.org/x/mod/semver"
)
// Default configuration values for the updater. These are applied by [New]
// when the corresponding [Config] fields are left at their zero values.
const (
	// DefaultBaseURL is the default GitHub API base URL.
	DefaultBaseURL = "https://api.github.com"
	// DefaultTimeout is the default timeout for HTTP requests (5 seconds).
	DefaultTimeout = 5 * time.Second
)
// Release represents a published release on GitHub. The JSON tags map the
// fields onto the GitHub REST API release object.
type Release struct {
	// Version is the tag name of the release (e.g., "v1.0.0").
	Version string `json:"tag_name"`
	// URL is the web address to view the release on GitHub.
	URL string `json:"html_url"`
	// Published is the timestamp the release was published on GitHub.
	Published time.Time `json:"published_at"`
	// Notes contains the release notes or description.
	Notes string `json:"body"`
}
// Config holds the configuration for the [Updater]. Required fields are
// validated by [New], which panics if they are missing.
type Config struct {
	// BaseURL is the base URL for the GitHub API. It defaults to
	// [DefaultBaseURL] if not set.
	BaseURL string
	// Owner is the GitHub repository owner (required).
	Owner string
	// Repository is the name of the GitHub repository (required).
	Repository string
	// Current is the current version string of the application (required).
	// A missing "v" prefix is tolerated and added automatically.
	Current string
	// UserAgent is the value for the User-Agent header sent with requests.
	UserAgent string
	// Timeout is the time limit for requests made by the updater. It defaults
	// to [DefaultTimeout] if not set.
	Timeout time.Duration
}
// Updater checks for updates on GitHub for a specific repository.
// Construct it with [New], which validates and normalizes the configuration.
type Updater struct {
	// baseURL is the API endpoint for release lookups.
	baseURL string
	// owner is the GitHub user or organization.
	owner string
	// repository is the GitHub project name.
	repository string
	// current is the normalized ("v"-prefixed) current application version.
	current string
	// userAgent is the identifying string sent in the HTTP header.
	userAgent string
	// client is the HTTP client for making API requests.
	client *http.Client
}
// New creates a new [Updater] with the given configuration.
//
// It initializes the HTTP client with the specified timeout. It panics if the
// configuration is missing required fields or if the current version string is
// not a valid semantic version.
func New(cfg *Config) *Updater {
	// Required-field validation; misconfiguration is a programmer error.
	switch {
	case cfg.Owner == "":
		panic("updater: owner is required")
	case cfg.Repository == "":
		panic("updater: repository is required")
	case cfg.Current == "":
		panic("updater: current version is required")
	}
	version := normalize(cfg.Current)
	if !semver.IsValid(version) {
		panic(fmt.Sprintf(
			"updater: current version %q is not a valid semver",
			cfg.Current,
		))
	}
	// Fall back to package defaults for unset optional fields.
	base := cfg.BaseURL
	if base == "" {
		base = DefaultBaseURL
	}
	limit := cfg.Timeout
	if limit == 0 {
		limit = DefaultTimeout
	}
	return &Updater{
		baseURL:    base,
		owner:      cfg.Owner,
		repository: cfg.Repository,
		current:    version,
		userAgent:  cfg.UserAgent,
		client:     &http.Client{Timeout: limit},
	}
}
// Check queries the GitHub Releases API to determine if a newer version exists.
//
// It compares the latest release tag against the current version using semantic
// versioning. It returns a [Release] if a newer version is found. It returns
// nil if the current version is up-to-date or if the latest release is older or
// equal.
func (u *Updater) Check(ctx context.Context) (*Release, error) {
	// Build .../repos/{owner}/{repo}/releases/latest on the configured base.
	endpoint, err := url.JoinPath(
		u.baseURL,
		"repos",
		u.owner,
		u.repository,
		"releases",
		"latest",
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	// Pin the GitHub REST API v3 media type.
	req.Header.Set("Accept", "application/vnd.github.v3+json")
	if u.userAgent != "" {
		req.Header.Set("User-Agent", u.userAgent)
	}
	res, err := u.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch latest release: %w", err)
	}
	defer func() {
		_ = res.Body.Close()
	}()
	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status from github api: %s", res.Status)
	}
	var r Release
	if err := json.UnmarshalRead(res.Body, &r); err != nil {
		return nil, fmt.Errorf("failed to decode response body: %w", err)
	}
	// Normalize the tag (e.g. "1.2.3" -> "v1.2.3") before comparing.
	latest := normalize(r.Version)
	if !semver.IsValid(latest) {
		return nil, fmt.Errorf("latest version %q is not a valid semver", r.Version)
	}
	// Only strictly newer releases are reported.
	if semver.Compare(latest, u.current) > 0 {
		return &r, nil
	}
	return nil, nil
}
// Check is a convenience function to check for updates in a single call.
//
// It creates a temporary [Updater] with the provided config and calls its
// [Updater.Check] method. Because it goes through [New], it panics on an
// invalid configuration.
func Check(ctx context.Context, cfg *Config) (*Release, error) {
	return New(cfg).Check(ctx)
}
// normalize ensures the version string has a "v" prefix for the semver package.
func normalize(v string) string {
	if !strings.HasPrefix(v, "v") {
		v = "v" + v
	}
	return v
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package uuid provides an implementation of Version 7 UUIDs.
//
// Package uuid provides an implementation of Version 7 (Time-ordered)
// Universally Unique Identifiers (UUID) as defined in RFC 4122 and RFC 9562.
//
// # Migration Note (v4 -> v7)
//
// We migrated from UUIDv4 (fully random) to UUIDv7 (time-ordered) to improve
// database performance. UUIDv4 causes significant index fragmentation and
// random I/O in B-Tree structures (standard database primary keys) due to its
// lack of locality.
//
// UUIDv7 solves this by being strictly monotonic while retaining global
// uniqueness. This results in "append-only" index behavior, higher write
// throughput, and better cache locality. It also aligns with native support
// arriving in PostgreSQL 18+.
//
// # Usage
//
// Generate a new time-ordered identifier or parse an existing string.
//
// Example:
//
// id := uuid.New()
// fmt.Println(id.String())
package uuid
import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
"sync"
"time"
)
// UUIDv7 is a 128-bit time-ordered identifier (16 bytes), stored big-endian.
//
// Layout:
//   - 48 bits: Unix Timestamp (milliseconds)
//   - 4 bits:  Version (0111)
//   - 12 bits: Random Data A (used here as a monotonic sequence)
//   - 2 bits:  Variant (10)
//   - 62 bits: Random Data B
type UUIDv7 [16]byte
// String returns the canonical hyphenated string representation of the
// [UUIDv7].
//
// Format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
func (u UUIDv7) String() string {
	var dst [36]byte
	// Place the four group separators first, then hex-encode each group of
	// bytes into the gaps (8-4-4-4-12 characters).
	dst[8], dst[13], dst[18], dst[23] = '-', '-', '-', '-'
	hex.Encode(dst[0:8], u[0:4])
	hex.Encode(dst[9:13], u[4:6])
	hex.Encode(dst[14:18], u[6:8])
	hex.Encode(dst[19:23], u[8:10])
	hex.Encode(dst[24:], u[10:])
	return string(dst[:])
}
// MarshalJSON transforms the UUIDv7 into a JSON string. Plain quoting is
// sufficient because the canonical form consists only of hex digits and
// hyphens, which never need JSON escaping.
func (u UUIDv7) MarshalJSON() ([]byte, error) {
	return []byte(`"` + u.String() + `"`), nil
}
// UnmarshalJSON parses a JSON string into a UUIDv7.
//
// Per the encoding/json convention, unmarshaling the JSON literal null is a
// no-op that leaves the receiver unchanged. Any other input must be a quoted,
// canonical, hyphenated UUID string.
func (u *UUIDv7) UnmarshalJSON(b []byte) error {
	// By convention, unmarshalers treat "null" as a no-op.
	if string(b) == "null" {
		return nil
	}
	// Expect a quoted JSON string; reject anything else.
	if len(b) < 2 || b[0] != '"' || b[len(b)-1] != '"' {
		return errors.New("uuid: invalid json string")
	}
	parsed, err := Parse(string(b[1 : len(b)-1]))
	if err != nil {
		return err
	}
	*u = parsed
	return nil
}
// EqualString checks if the [UUIDv7] matches the provided hyphenated string.
// It is highly optimized to avoid allocations and exits early on a mismatch.
func (u UUIDv7) EqualString(s string) bool {
	// Canonical form is exactly 36 chars with hyphens at fixed positions.
	if len(s) != 36 {
		return false
	}
	if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
		return false
	}
	// Compare group by group (8-4-4-4-12 hex chars -> 4-2-2-2-6 bytes);
	// short-circuiting keeps mismatches cheap.
	return compareHex(s[0:8], u[0:4]) &&
		compareHex(s[9:13], u[4:6]) &&
		compareHex(s[14:18], u[6:8]) &&
		compareHex(s[19:23], u[8:10]) &&
		compareHex(s[24:36], u[10:16])
}
// New generates a strictly monotonic [UUIDv7] with sub-millisecond precision.
//
// It fills the timestamp and sequence fields using a global monotonic counter
// derived from the system clock, ensuring that IDs generated within the same
// millisecond are ordered. The remaining bits are filled with cryptographically
// secure random data. It panics only if the system's secure random source
// fails.
func New() UUIDv7 {
	ms, seq := tick()
	var u UUIDv7
	// Bytes 0-5: 48-bit millisecond timestamp, big endian.
	// Byte 6: version nibble (0x7) in the high bits plus the top 4 sequence
	// bits; byte 7: the low 8 sequence bits.
	u[0] = byte(ms >> 40) //nolint:gosec
	u[1] = byte(ms >> 32) //nolint:gosec
	u[2] = byte(ms >> 24) //nolint:gosec
	u[3] = byte(ms >> 16) //nolint:gosec
	u[4] = byte(ms >> 8) //nolint:gosec
	u[5] = byte(ms) //nolint:gosec
	u[6] = 0x70 | byte(seq>>8) //nolint:gosec
	u[7] = byte(seq) //nolint:gosec
	// Fill bytes 8 to 15 with random data.
	if _, err := io.ReadFull(rand.Reader, u[8:]); err != nil {
		panic(fmt.Errorf("uuid: failed to read random bytes: %w", err))
	}
	// Overwrite the two high bits of byte 8 with the RFC 4122 variant (10).
	u[8] = (u[8] & 0x3f) | 0x80
	return u
}
// Parse converts a 36-character hyphenated string into a [UUIDv7].
//
// It strictly validates that the UUID is Version 7 and Variant 1 (RFC 4122).
func Parse(s string) (UUIDv7, error) {
	var u UUIDv7
	if len(s) != 36 {
		return u, fmt.Errorf("uuid: invalid length (%d)", len(s))
	}
	if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
		return u, errors.New("uuid: invalid format")
	}
	// Strip the hyphens so the remaining 32 hex chars decode into 16 bytes.
	h := s[0:8] + s[9:13] + s[14:18] + s[19:23] + s[24:]
	if _, err := hex.Decode(u[:], []byte(h)); err != nil {
		return u, fmt.Errorf("uuid: invalid characters: %w", err)
	}
	// High nibble of byte 6 carries the version; must be 0b0111 (7).
	if (u[6] & 0xf0) != 0x70 {
		return UUIDv7{}, errors.New("uuid: invalid version: expected v7")
	}
	// Two high bits of byte 8 carry the variant; must be 0b10.
	if (u[8] & 0xc0) != 0x80 {
		return UUIDv7{}, errors.New("uuid: invalid variant: expected RFC 4122")
	}
	return u, nil
}
// ParseBytes parses a 16-byte raw slice into a [UUIDv7].
//
// It strictly validates that the byte slice is exactly 16 bytes and conforms
// to Version 7 and Variant 1. This function creates a complete copy of the
// data.
func ParseBytes(b []byte) (UUIDv7, error) {
	var u UUIDv7
	if n := len(b); n != 16 {
		return u, fmt.Errorf("uuid: invalid length (%d)", n)
	}
	copy(u[:], b)
	switch {
	case u[6]&0xf0 != 0x70:
		// Version nibble (high bits of byte 6) must be 7.
		return UUIDv7{}, errors.New("uuid: invalid version: expected v7")
	case u[8]&0xc0 != 0x80:
		// Variant bits (high bits of byte 8) must be 0b10 (RFC 4122).
		return UUIDv7{}, errors.New("uuid: invalid variant: expected RFC 4122")
	}
	return u, nil
}
// Global state for the monotonic generator shared by all calls to [New].
var (
	// mu protects the last generated timestamp state.
	mu sync.Mutex
	// last is the combined scalar (48-bit ms << 12 | 12-bit seq) of the last
	// generated timestamp and sequence; see tick.
	last int64
)
// tick implements Method 3 from the UUIDv7 specification (RFC 9562, Section
// 6.2).
//
// It returns a timestamp (ms) and a strictly increasing sequence (seq). The
// sequence holds fractional nanoseconds scaled to fit into 12 bits. Calls are
// serialized by mu, so the packed (ms, seq) pair is strictly increasing across
// goroutines.
func tick() (ms, seq int64) {
	mu.Lock()
	defer mu.Unlock()
	// 1. Get current time components.
	ns := time.Now().UnixNano()
	ms = ns / 1_000_000
	// 2. Calculate the sequence number.
	// We have 1,000,000 nanoseconds in a millisecond.
	// We have 12 bits for the sequence (max 4096).
	// Dividing by 256 (>> 8) maps 1,000,000 to ~3906, which fits in 12 bits.
	seq = (ns - ms*1_000_000) >> 8
	// 3. Pack into a comparable scalar (48 bits MS + 12 bits SEQ).
	ts := ms<<12 + seq
	// 4. Enforce monotonicity: if the clock did not advance past the last
	// issued value (same tick, or clock went backwards), bump by one.
	if ts <= last {
		ts = last + 1
		// Unpack the scalar back into components.
		// If seq overflowed 12 bits, it automatically increments ms.
		ms = ts >> 12
		seq = ts & 0xfff
	}
	last = ts
	return ms, seq
}
// compareHex reports whether the hex-encoded string segment s spells out
// exactly the bytes in b. Invalid hex characters cause a mismatch.
func compareHex(s string, b []byte) bool {
	for i, want := range b {
		// Decode the two hex chars corresponding to byte i.
		hi := decodeHex(s[2*i])
		lo := decodeHex(s[2*i+1])
		if hi == 0xff || lo == 0xff {
			return false
		}
		if hi<<4|lo != want {
			return false
		}
	}
	return true
}
// decodeHex converts a single hex character to its byte value.
// It returns 0xff if the character is not a valid hex digit.
func decodeHex(c byte) byte {
	if c >= '0' && c <= '9' {
		return c - '0'
	}
	if c >= 'a' && c <= 'f' {
		return c - 'a' + 10
	}
	if c >= 'A' && c <= 'F' {
		return c - 'A' + 10
	}
	return 0xff
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package valid provides utility functions for validating common formats and
// data types.
//
// The valid package offers a comprehensive suite of validation tools designed
// to simplify data integrity checks in Go applications. It includes standalone
// predicate functions for format verification and a stateful [Validator] for
// aggregating errors across complex, nested structures.
//
// # Usage
//
// You can use standalone functions for simple checks or the [Validator] type
// for struct validation.
//
// Standalone Validation:
//
// Direct check of a single value using predicate functions.
//
// Example:
//
// isValid := valid.Email("user@example.com")
//
// Struct Validation:
//
// Implementing the [Validatable] interface to perform complex checks.
//
// Example:
//
// type User struct {
// Email string
// Age int
// }
//
// func (u *User) Validate(v *valid.Validator) {
// v.Email("email", u.Email)
// v.BetweenInt("age", u.Age, 18, 99)
// }
//
// usr := &User{Email: "user@example.com", Age: 25}
// err := valid.Test(usr)
// if err != nil {
// // Handle validation errors
// }
package valid
import (
"encoding/json/jsontext"
"mime"
"net"
"net/netip"
"net/url"
"regexp"
"strings"
"golang.org/x/mod/semver"
"github.com/deep-rent/nexus/uuid"
"github.com/deep-rent/nexus/internal/ascii"
)
// Precompiled validation patterns. Note: rxBase64URL's optional tail accepts
// an unpadded remainder of 2 or 3 characters only — an encoded length of
// 4k+1 is impossible in base64 (RFC 4648), so a single leftover character is
// rejected.
var (
	rxBase64    = regexp.MustCompile(`^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$`)
	rxBase64URL = regexp.MustCompile(`^(?:[A-Za-z0-9_-]{4})*(?:[A-Za-z0-9_-]{2}==|[A-Za-z0-9_-]{3}=|[A-Za-z0-9_-]{2,3})?$`)
	rxURN       = regexp.MustCompile(`^(?i)urn:[a-z0-9][a-z0-9-]{0,31}:[a-z0-9()+,\-.:=@;$_!*'%/?#]+$`)
	rxHostname  = regexp.MustCompile(`^(?:[a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])(?:\.(?:[a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]{0,61}[a-zA-Z0-9]))*$`)
	rxFQDN      = regexp.MustCompile(`^(?:[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\.?$`)
	rxEmail     = regexp.MustCompile(`^[a-zA-Z0-9.!#$%&'*+/=?^_{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$`)
	rxBIC       = regexp.MustCompile(`^[A-Z]{6}[A-Z2-9][A-NP-Z0-9](?:[A-Z0-9]{3})?$`)
	rxBCP47     = regexp.MustCompile(`(?i)^(?:(?:[a-z]{2,3}(?:-[a-z]{3}){0,3})|[a-z]{4}|[a-z]{5,8})(?:-[a-z]{4})?(?:-(?:[a-z]{2}|[0-9]{3}))?(?:-(?:[a-z0-9]{5,8}|[0-9][a-z0-9]{3}))*(?:-[0-9a-wy-z](?:-[a-z0-9]{2,8})+)*(?:-x(?:-[a-z0-9]{1,8})+)?$`)
)
// CIDR checks if the string is a valid Classless Inter-Domain Routing (CIDR)
// block. A valid CIDR block is an IP address followed by a slash and a
// prefix length.
func CIDR(s string) bool {
	if _, err := netip.ParsePrefix(s); err != nil {
		return false
	}
	return true
}

// CIDRv4 checks if the string is a valid IPv4 CIDR block.
func CIDRv4(s string) bool {
	p, err := netip.ParsePrefix(s)
	if err != nil {
		return false
	}
	return p.Addr().Is4()
}

// CIDRv6 checks if the string is a valid IPv6 CIDR block.
func CIDRv6(s string) bool {
	p, err := netip.ParsePrefix(s)
	if err != nil {
		return false
	}
	return p.Addr().Is6()
}
// Hostname checks if the string is a valid hostname according to RFC 952 and
// RFC 1123. The hostname must be at most 253 characters long.
func Hostname(s string) bool {
	if len(s) == 0 || len(s) > 253 {
		return false
	}
	return rxHostname.MatchString(s)
}
// Port checks if the number represents a valid network port number.
// Port numbers must be between 1 and 65535 inclusive; 0 is rejected.
func Port(n int) bool {
	return 1 <= n && n <= 65535
}
// IP checks if the string is a valid IP address (either IPv4 or IPv6).
func IP(s string) bool {
	if _, err := netip.ParseAddr(s); err != nil {
		return false
	}
	return true
}

// IPv4 checks if the string is a valid IPv4 address.
func IPv4(s string) bool {
	a, err := netip.ParseAddr(s)
	if err != nil {
		return false
	}
	return a.Is4()
}

// IPv6 checks if the string is a valid IPv6 address.
func IPv6(s string) bool {
	a, err := netip.ParseAddr(s)
	if err != nil {
		return false
	}
	return a.Is6()
}
// FQDN checks if the string is a Fully Qualified Domain Name (FQDN).
// An FQDN must have at least one valid top-level domain. It allows for an
// optional trailing dot.
func FQDN(s string) bool {
	if len(s) == 0 || len(s) > 253 {
		return false
	}
	return rxFQDN.MatchString(s)
}
// URI checks if the string is a valid URI (Uniform Resource Identifier)
// according to RFC 3986, as accepted by [url.ParseRequestURI].
func URI(s string) bool {
	if _, err := url.ParseRequestURI(s); err != nil {
		return false
	}
	return true
}
// URL checks if the string is a valid URL with a non-empty scheme and host.
func URL(s string) bool {
	u, err := url.ParseRequestURI(s)
	if err != nil {
		return false
	}
	return u.Scheme != "" && u.Host != ""
}
// URN checks if the string is a valid URN (Uniform Resource Name) according to
// RFC 2141. Matching is case-insensitive (the pattern carries an (?i) flag).
func URN(s string) bool {
	return rxURN.MatchString(s)
}
// Alpha checks if the string contains only alphabetical characters (a-z, A-Z).
// An empty string returns false.
func Alpha(s string) bool {
	if len(s) == 0 {
		return false
	}
	for _, c := range []byte(s) {
		if !ascii.IsAlpha(rune(c)) {
			return false
		}
	}
	return true
}

// AlphaNum checks if the string contains only alphanumeric characters (a-z,
// A-Z, 0-9). An empty string returns false.
func AlphaNum(s string) bool {
	if len(s) == 0 {
		return false
	}
	for _, c := range []byte(s) {
		if !ascii.IsAlphaNum(rune(c)) {
			return false
		}
	}
	return true
}
// ASCII checks if the string contains only ASCII characters (bytes <= 0x7F).
// An empty string returns false.
func ASCII(s string) bool {
	if len(s) == 0 {
		return false
	}
	for _, c := range []byte(s) {
		if c > 0x7F {
			return false
		}
	}
	return true
}
// Slug checks if the string is a valid URL slug.
// A slug consists of lowercase letters, numbers, and hyphens, and cannot
// start or end with a hyphen or contain consecutive hyphens.
func Slug(s string) bool {
	if s == "" || s[0] == '-' || s[len(s)-1] == '-' {
		return false
	}
	hyphen := false // whether the previous character was a hyphen
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case c == '-':
			if hyphen {
				// Consecutive hyphens are forbidden.
				return false
			}
			hyphen = true
		case ascii.IsLower(rune(c)) || ascii.IsDigit(rune(c)):
			hyphen = false
		default:
			// Uppercase and any other character are not allowed.
			return false
		}
	}
	return true
}
// Upper checks if the string contains only uppercase characters (A-Z).
// An empty string returns false.
func Upper(s string) bool {
	if len(s) == 0 {
		return false
	}
	for _, c := range []byte(s) {
		if !ascii.IsUpper(rune(c)) {
			return false
		}
	}
	return true
}

// Lower checks if the string contains only lowercase characters (a-z).
// An empty string returns false.
func Lower(s string) bool {
	if len(s) == 0 {
		return false
	}
	for _, c := range []byte(s) {
		if !ascii.IsLower(rune(c)) {
			return false
		}
	}
	return true
}
// Base64 checks if the string is a valid Base64 encoded string.
// It allows standard padding characters. Note the empty string is accepted
// (zero encoded groups).
func Base64(s string) bool {
	return rxBase64.MatchString(s)
}

// Base64URL checks if the string is a valid Base64URL encoded string.
// Padding characters are supported but optional.
func Base64URL(s string) bool {
	return rxBase64URL.MatchString(s)
}
// MAC checks if the string is a valid IEEE 802 MAC address, in any of the
// notations accepted by [net.ParseMAC].
func MAC(s string) bool {
	if _, err := net.ParseMAC(s); err != nil {
		return false
	}
	return true
}
// Lang checks if the string is a valid BCP 47 language tag.
// It strictly follows RFC 5646; matching is case-insensitive.
func Lang(s string) bool {
	return rxBCP47.MatchString(s)
}

// JSON checks if the string is a valid JSON document.
// It validates without fully decoding, via [jsontext.Value.IsValid].
func JSON(s string) bool {
	return jsontext.Value(s).IsValid()
}
// MIME checks if the string is a valid Media Type (MIME type) according to
// RFC 2045 and RFC 2046.
func MIME(s string) bool {
	mediaType, _, err := mime.ParseMediaType(s)
	if err != nil {
		return false
	}
	// ParseMediaType also accepts bare tokens; require a type/subtype pair.
	return strings.Contains(mediaType, "/")
}
// CreditCard checks if the string is a valid credit card number using the Luhn
// algorithm. Spaces and hyphens are skipped before the checksum is computed.
func CreditCard(s string) bool {
	var (
		sum    int  // Luhn running total
		digits int  // count of digits seen
		double bool // double every second digit from the right
	)
	for i := len(s) - 1; i >= 0; i-- {
		c := s[i]
		if c == ' ' || c == '-' {
			// Separators are ignored and do not affect doubling.
			continue
		}
		if c < '0' || c > '9' {
			return false
		}
		d := int(c - '0')
		if double {
			d *= 2
			if d > 9 {
				d -= 9
			}
		}
		sum += d
		digits++
		double = !double
	}
	// Card numbers are 13-19 digits and must pass the mod-10 check.
	return digits >= 13 && digits <= 19 && sum%10 == 0
}
// Email checks if the string is a valid email address according to the W3C
// HTML5 specification. No RFC 5321 length limits are enforced by the pattern.
func Email(s string) bool {
	return rxEmail.MatchString(s)
}
// Hex checks if the string is a valid hexadecimal number.
// The string may optionally be prefixed with "0x" or "0X".
func Hex(s string) bool {
	// Strip the prefix only if at least one digit follows it.
	if len(s) > 2 && s[0] == '0' && (s[1] == 'x' || s[1] == 'X') {
		s = s[2:]
	}
	if len(s) == 0 {
		return false
	}
	for _, c := range []byte(s) {
		if !ascii.IsHex(rune(c)) {
			return false
		}
	}
	return true
}

// HexColor checks if the string is a valid hex color code.
// The string may optionally be prefixed with "#". It must be exactly 3 or 6
// hexadecimal characters long.
func HexColor(s string) bool {
	if len(s) == 0 {
		return false
	}
	if s[0] == '#' {
		s = s[1:]
	}
	if n := len(s); n != 3 && n != 6 {
		return false
	}
	for _, c := range []byte(s) {
		if !ascii.IsHex(rune(c)) {
			return false
		}
	}
	return true
}
// ISSN checks if the string is a valid International Standard Serial Number
// (ISSN) of the form "1234-567X". Only the shape is verified; the ISSN
// checksum digit is not recomputed. Only an uppercase 'X' check character is
// accepted.
func ISSN(s string) bool {
	if len(s) != 9 {
		return false
	}
	return ascii.IsDigit(rune(s[0])) &&
		ascii.IsDigit(rune(s[1])) &&
		ascii.IsDigit(rune(s[2])) &&
		ascii.IsDigit(rune(s[3])) &&
		s[4] == '-' &&
		ascii.IsDigit(rune(s[5])) &&
		ascii.IsDigit(rune(s[6])) &&
		ascii.IsDigit(rune(s[7])) &&
		(ascii.IsDigit(rune(s[8])) || s[8] == 'X')
}
// ISBN10 checks if the string is a valid ISBN-10.
// It strips hyphens before validation. Only the shape (nine digits plus a
// final digit or 'X') is verified; the checksum is not recomputed.
func ISBN10(s string) bool {
	count := 0
	for i := 0; i < len(s); i++ {
		switch c := s[i]; {
		case c == '-':
			// Hyphens are separators and do not count.
		case count == 9 && c == 'X':
			// 'X' is only valid as the tenth (check) character.
			count++
		case ascii.IsDigit(rune(c)):
			count++
		default:
			return false
		}
	}
	return count == 10
}

// ISBN13 checks if the string is a valid ISBN-13.
// It strips hyphens before validation. Only the digit count is verified; the
// checksum is not recomputed.
func ISBN13(s string) bool {
	count := 0
	for i := 0; i < len(s); i++ {
		switch c := s[i]; {
		case c == '-':
			// Hyphens are separators and do not count.
		case ascii.IsDigit(rune(c)):
			count++
		default:
			return false
		}
	}
	return count == 13
}

// ISBN checks if the string is a valid ISBN (10 or 13).
func ISBN(s string) bool {
	return ISBN10(s) || ISBN13(s)
}
// Country2 checks if the string is a valid ISO 3166-1 alpha-2
// country code (e.g., "US"). Only the shape (two uppercase ASCII letters) is
// checked; membership in the ISO table is not verified.
func Country2(s string) bool {
	return len(s) == 2 && ascii.IsUpper(rune(s[0])) &&
		ascii.IsUpper(rune(s[1]))
}

// Country3 checks if the string is a valid ISO 3166-1 alpha-3
// country code (e.g., "USA"). Shape check only, as with [Country2].
func Country3(s string) bool {
	return len(s) == 3 && ascii.IsUpper(rune(s[0])) &&
		ascii.IsUpper(rune(s[1])) &&
		ascii.IsUpper(rune(s[2]))
}

// CountryN checks if the string is a valid ISO 3166-1 numeric
// country code (e.g., "840"). Shape check only (exactly three digits).
func CountryN(s string) bool {
	return len(s) == 3 && ascii.IsDigit(rune(s[0])) &&
		ascii.IsDigit(rune(s[1])) && ascii.IsDigit(rune(s[2]))
}

// Currency checks if the string is a valid ISO 4217 currency code (e.g.,
// "EUR", "USD"). Shape check only (three uppercase ASCII letters).
func Currency(s string) bool {
	return len(s) == 3 && ascii.IsUpper(rune(s[0])) &&
		ascii.IsUpper(rune(s[1])) &&
		ascii.IsUpper(rune(s[2]))
}
// UUIDv7 checks if the string is a valid Version 7 UUID as defined in RFC 4122
// and RFC 9562.
//
// uuid.Parse alone accepts a UUID of any version, so the version field is
// verified explicitly to honor the "Version 7" contract of this function.
func UUIDv7(s string) bool {
	u, err := uuid.Parse(s)
	return err == nil && u.Version() == 7
}
// Lat checks if the number is a valid latitude coordinate, i.e. a value in
// the inclusive range [-90, 90]. NaN is not a valid latitude.
func Lat(f float32) bool {
	return -90 <= f && f <= 90
}
// Lon checks if the number is a valid longitude coordinate, i.e. a value in
// the inclusive range [-180, 180]. NaN is not a valid longitude.
func Lon(f float32) bool {
	return -180 <= f && f <= 180
}
// MD5 checks if the string is a valid MD5 hash (32 hex characters).
// Note: only the length and character set are checked, not whether the
// value is a plausible digest.
func MD5(s string) bool {
	return isHash(s, 32)
}

// SHA256 checks if the string is a valid SHA256 hash (64 hex characters).
func SHA256(s string) bool {
	return isHash(s, 64)
}

// SHA384 checks if the string is a valid SHA384 hash (96 hex characters).
func SHA384(s string) bool {
	return isHash(s, 96)
}

// SHA512 checks if the string is a valid SHA512 hash (128 hex characters).
func SHA512(s string) bool {
	return isHash(s, 128)
}
// SemVer checks if the string is a valid Semantic Versioning 2.0.0 string.
// Note that the "v" prefix is mandatory: "v1.2.3" is accepted while
// "1.2.3" is not.
func SemVer(s string) bool {
	return semver.IsValid(s)
}
// Phone checks if the string is a valid E.164 formatted phone number.
// The string must start with a '+' and be followed by 2 to 15 digits,
// the first of which must be non-zero (country codes never start with 0).
func Phone(s string) bool {
	// '+' plus 2..15 digits means a total length of 3..16 bytes.
	if len(s) < 3 || len(s) > 16 {
		return false
	}
	if s[0] != '+' {
		return false
	}
	if first := s[1]; first < '1' || first > '9' {
		return false
	}
	for _, c := range []byte(s[2:]) {
		if c < '0' || c > '9' {
			return false
		}
	}
	return true
}
// BIC checks if the string is a valid Business Identifier Code (ISO 9362).
// The exact accepted format is defined by the rxBIC regular expression
// declared elsewhere in this package.
func BIC(s string) bool {
	return rxBIC.MatchString(s)
}
// IBAN checks if the string is a valid International Bank Account Number.
// It ignores spaces and performs the modulo 97 check.
//
// Validation steps (ISO 13616):
//  1. Strip spaces; the compact form must be 15 to 34 alphanumerics.
//  2. The layout must start with two letters (country code) followed by
//     two digits (check digits).
//  3. The string, rotated so that its first four characters come last,
//     must be congruent to 1 modulo 97 when letters are read as two-digit
//     numbers (A=10 ... Z=35).
func IBAN(s string) bool {
	// Compact the input into a fixed buffer, dropping spaces.
	var b [34]byte
	var n int
	for i := 0; i < len(s); i++ {
		c := s[i]
		if c == ' ' {
			continue
		}
		// More than 34 significant characters, or any non-alphanumeric
		// character, is invalid.
		if n >= 34 || !ascii.IsAlphaNum(rune(c)) {
			return false
		}
		b[n] = c
		n++
	}
	if n < 15 {
		return false
	}
	// Country code (2 letters) followed by check digits (2 digits).
	if !ascii.IsAlpha(rune(b[0])) ||
		!ascii.IsAlpha(rune(b[1])) ||
		!ascii.IsDigit(rune(b[2])) ||
		!ascii.IsDigit(rune(b[3])) {
		return false
	}
	var rem int
	// Modulo 97 check: move first 4 characters to the end.
	for i := 4; i < n; i++ {
		rem = mod97(rem, b[i])
	}
	for i := 0; i < 4; i++ {
		rem = mod97(rem, b[i])
	}
	return rem == 1
}
// isHash reports whether s is exactly size bytes long and consists solely
// of hexadecimal characters.
func isHash(s string, size int) bool {
	if len(s) != size {
		return false
	}
	for i := range s {
		if !ascii.IsHex(rune(s[i])) {
			return false
		}
	}
	return true
}
// mod97 folds one more character into the running remainder of the modulo 97
// computation used for IBAN validation. A digit contributes one decimal
// digit; a letter is expanded to its two-digit numeric value (A=10, ...,
// Z=35) per the ISO 13616 standard.
func mod97(rem int, c byte) int {
	if ascii.IsDigit(rune(c)) {
		return (rem*10 + int(c-'0')) % 97
	}
	return (rem*100 + int(ascii.ToUpper(rune(c))-'A'+10)) % 97
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package valid
import (
"fmt"
"reflect"
"regexp"
"slices"
"strings"
"time"
)
// Error represents a collection of validation errors mapped by their
// corresponding field paths in dot notation. It naturally serializes to JSON,
// making it ideal for API error responses.
type Error map[string][]string

// Error implements the error interface, providing a consolidated string
// representation of all validation failures. Field paths are listed in
// sorted order so that the message is deterministic across calls.
func (e Error) Error() string {
	// Sort the paths first: Go randomizes map iteration order, which would
	// otherwise make the message unstable between invocations (bad for
	// logging, deduplication, and tests).
	paths := make([]string, 0, len(e))
	for path := range e {
		paths = append(paths, path)
	}
	slices.Sort(paths)

	var sb strings.Builder
	sb.WriteString("validation failed: ")
	for i, path := range paths {
		if i > 0 {
			sb.WriteString("; ")
		}
		sb.WriteString(path)
		sb.WriteString(": ")
		sb.WriteString(strings.Join(e[path], ", "))
	}
	return sb.String()
}
// Validatable describes a structure that can self-validate using a [Validator].
// It is typically implemented by API DTOs and request payloads.
type Validatable interface {
	// Validate executes validation logic on the object using the provided
	// [Validator]. It records any detected failures in the validator.
	// Implementations report failures through the validator (e.g. via
	// [Validator.Fail]) rather than returning an error.
	Validate(v *Validator)
}
// Test validates a single [Validatable] instance or a slice of them.
// It returns a composite error if any validation checks fail, or nil if
// all checks pass.
//
// Targets that are neither [Validatable] nor slices are treated as valid.
func Test(target any) error {
	// A value that validates itself takes precedence over slice handling.
	if t, ok := target.(Validatable); ok {
		vr := New()
		t.Validate(vr)
		return vr.Error()
	}
	// Otherwise, fall back to element-wise validation for slices.
	rt := reflect.TypeOf(target)
	if rt == nil || rt.Kind() != reflect.Slice {
		return nil
	}
	vr := New()
	vr.Each("", target)
	return vr.Error()
}
// Each validates every element in a slice that implements the [Validatable]
// interface.
// It returns a composite error if any element fails validation, or nil if
// all elements are valid.
//
// Arguments that are not slices are ignored and yield a nil error.
func Each(target any) error {
	v := New()
	v.Each("", target)
	return v.Error()
}
// Validator orchestrates the validation of fields, builds dot-notation paths
// for nested structures, and aggregates error messages.
type Validator struct {
	errs Error  // accumulated messages, keyed by dot-notation field path
	path string // current path prefix used when validating nested values
}

// New creates and returns a new empty [Validator].
func New() *Validator {
	return &Validator{
		errs: make(Error),
	}
}
// Error returns the composite validation error if any checks failed, or nil
// if all checks passed.
//
// It deliberately returns an untyped nil (rather than an empty [Error]
// value) so that callers can rely on a plain err == nil comparison.
func (v *Validator) Error() error {
	if len(v.errs) == 0 {
		return nil
	}
	return v.errs
}
// Fail records an explicit error message against the given field.
// The field name is joined onto the validator's current path; an empty
// field attaches the message to the current path itself.
func (v *Validator) Fail(field, msg string) {
	// Lazily initialize the map so that a zero-value Validator also works.
	if v.errs == nil {
		v.errs = make(Error)
	}
	p := v.join(field)
	v.errs[p] = append(v.errs[p], msg)
}
// Test dives into a nested [Validatable] struct. It appends the field name
// to the current path, seamlessly propagating any validation errors using dot
// notation (e.g., "user.address" or "items[0].name").
//
// Nil targets are skipped — including typed-nil pointers wrapped in a
// non-nil interface — mirroring the nil handling in [Validator.Each].
func (v *Validator) Test(field string, target Validatable) {
	if target == nil {
		return
	}
	// Guard against the typed-nil interface trap: an interface holding a
	// nil *T compares non-nil, but calling Validate on it would invoke the
	// method with a nil receiver.
	if rv := reflect.ValueOf(target); rv.Kind() == reflect.Pointer && rv.IsNil() {
		return
	}
	sub := &Validator{
		errs: v.errs,
		path: v.join(field),
	}
	target.Validate(sub)
}
// Each iterates over a slice and validates each element that implements the
// [Validatable] interface. It automatically manages array indexing in the
// dot-notation path (e.g., "items[0]", "items[1]").
//
// Non-slice arguments and elements that implement [Validatable] neither
// directly nor via their address are silently ignored.
func (v *Validator) Each(field string, slice any) {
	rv := reflect.ValueOf(slice)
	if rv.Kind() != reflect.Slice {
		return
	}
	p := v.join(field)
	for i := 0; i < rv.Len(); i++ {
		val := rv.Index(i)
		var target Validatable
		// Safely unwind interfaces and nested pointers (read-only).
		for {
			k := val.Kind()
			// A nil pointer or interface cannot be validated; skip it.
			if (k == reflect.Pointer || k == reflect.Interface) && val.IsNil() {
				break
			}
			// Prefer the value itself if it already satisfies the interface.
			if t, ok := val.Interface().(Validatable); ok {
				target = t
				break
			}
			// Dereference one level and try again.
			if k == reflect.Pointer || k == reflect.Interface {
				val = val.Elem()
				continue
			}
			// Finally, try the address: slice elements are addressable, so
			// this covers []T where only *T implements Validatable.
			if val.CanAddr() {
				if t, ok := val.Addr().Interface().(Validatable); ok {
					target = t
					break
				}
			}
			break
		}
		if target != nil {
			sub := &Validator{
				errs: v.errs,
				path: fmt.Sprintf("%s[%d]", p, i),
			}
			target.Validate(sub)
		}
	}
}
// join builds the dot-notation path for a field, escaping literal dots in
// the field name so they cannot be confused with path separators. An empty
// field name yields the current path unchanged.
func (v *Validator) join(field string) string {
	if field == "" {
		return v.path
	}
	escaped := strings.ReplaceAll(field, ".", "\\.")
	if v.path == "" {
		return escaped
	}
	return v.path + "." + escaped
}
// ----------------------------------------------------------------------------
// Comparison-based Checks
// ----------------------------------------------------------------------------

// Each method below compares a value against caller-supplied bounds and
// records a failure message via Fail when the check does not hold. All
// bounds are inclusive unless noted otherwise.

// Min asserts that a numeric value is at least the given minimum (inclusive).
func (v *Validator) Min(field string, val, min float64) {
	if val < min {
		v.Fail(field, fmt.Sprintf("must be at least %v", min))
	}
}

// Max asserts that a numeric value is at most the given maximum (inclusive).
func (v *Validator) Max(field string, val, max float64) {
	if val > max {
		v.Fail(field, fmt.Sprintf("must be at most %v", max))
	}
}

// MinInt asserts that an integer value is at least the given minimum
// (inclusive).
func (v *Validator) MinInt(field string, val, min int) {
	if val < min {
		v.Fail(field, fmt.Sprintf("must be at least %d", min))
	}
}

// MaxInt asserts that an integer value is at most the given maximum
// (inclusive).
func (v *Validator) MaxInt(field string, val, max int) {
	if val > max {
		v.Fail(field, fmt.Sprintf("must be at most %d", max))
	}
}

// Between asserts that a numeric value is between min and max inclusive.
func (v *Validator) Between(field string, val, min, max float64) {
	if val < min || val > max {
		v.Fail(field, fmt.Sprintf("must be between %v and %v", min, max))
	}
}

// BetweenInt asserts that an integer value is between min and max inclusive.
func (v *Validator) BetweenInt(field string, val, min, max int) {
	if val < min || val > max {
		v.Fail(field, fmt.Sprintf("must be between %d and %d", min, max))
	}
}

// MinLen asserts that the length of a string is at least min.
// Note: length is measured in bytes, not runes.
func (v *Validator) MinLen(field, val string, min int) {
	if len(val) < min {
		v.Fail(field, fmt.Sprintf("length must be at least %d", min))
	}
}

// MaxLen asserts that the length of a string is at most max.
// Note: length is measured in bytes, not runes.
func (v *Validator) MaxLen(field, val string, max int) {
	if len(val) > max {
		v.Fail(field, fmt.Sprintf("length must be at most %d", max))
	}
}

// Len asserts that the length of a string is exactly the given length.
// Note: length is measured in bytes, not runes.
func (v *Validator) Len(field, val string, n int) {
	if len(val) != n {
		v.Fail(field, fmt.Sprintf("length must be exactly %d", n))
	}
}

// MinSize asserts that the size of a slice or map is at least min.
func (v *Validator) MinSize(field string, size, min int) {
	if size < min {
		v.Fail(field, fmt.Sprintf("size must be at least %d", min))
	}
}

// MaxSize asserts that the size of a slice or map is at most max.
func (v *Validator) MaxSize(field string, size, max int) {
	if size > max {
		v.Fail(field, fmt.Sprintf("size must be at most %d", max))
	}
}

// Size asserts that the size of a slice or map is exactly the given size.
func (v *Validator) Size(field string, size, n int) {
	if size != n {
		v.Fail(field, fmt.Sprintf("size must be exactly %d", n))
	}
}
// Unique asserts that all elements in a string slice are unique. At most one
// failure is recorded, on the first duplicate encountered.
func (v *Validator) Unique(field string, slice []string) {
	// Fewer than two elements can never contain a duplicate.
	if len(slice) < 2 {
		return
	}
	seen := make(map[string]struct{}, len(slice))
	for _, item := range slice {
		if _, dup := seen[item]; dup {
			v.Fail(field, "must contain unique items")
			return
		}
		seen[item] = struct{}{}
	}
}
// Whitelist asserts that a value exactly matches one of the allowed options.
// The underlying concrete types must match (comparison uses interface
// equality, so an int never matches an int64).
func (v *Validator) Whitelist(field string, val any, list ...any) {
	if !slices.Contains(list, val) {
		v.Fail(field, "must be one of the allowed values")
	}
}

// Blacklist asserts that a value does not match any of the denied options.
// The underlying concrete types must match.
func (v *Validator) Blacklist(field string, val any, list ...any) {
	if slices.Contains(list, val) {
		v.Fail(field, "must not be one of the denied values")
	}
}

// NotEmpty asserts that a string is not empty.
func (v *Validator) NotEmpty(field, val string) {
	if val == "" {
		v.Fail(field, "must not be empty")
	}
}

// NotBlank asserts that a string is not blank (contains at least one
// non-whitespace character).
func (v *Validator) NotBlank(field, val string) {
	if strings.TrimSpace(val) == "" {
		v.Fail(field, "must not be blank")
	}
}

// Prefix asserts that a string starts with a specific prefix.
func (v *Validator) Prefix(field, val, prefix string) {
	if !strings.HasPrefix(val, prefix) {
		v.Fail(field, fmt.Sprintf("must start with %q", prefix))
	}
}

// Suffix asserts that a string ends with a specific suffix.
func (v *Validator) Suffix(field, val, suffix string) {
	if !strings.HasSuffix(val, suffix) {
		v.Fail(field, fmt.Sprintf("must end with %q", suffix))
	}
}

// Contains asserts that a string contains a specific substring.
func (v *Validator) Contains(field, val, sub string) {
	if !strings.Contains(val, sub) {
		v.Fail(field, fmt.Sprintf("must contain %q", sub))
	}
}

// Match asserts that a string matches a regular expression.
func (v *Validator) Match(field, val string, rx *regexp.Regexp) {
	if !rx.MatchString(val) {
		v.Fail(field, fmt.Sprintf("must match the pattern %s", rx.String()))
	}
}

// Before asserts that a time is strictly before the given threshold
// (a value equal to max fails).
func (v *Validator) Before(field string, val, max time.Time) {
	if !val.Before(max) {
		v.Fail(field, fmt.Sprintf("must be before %v", max.Format(time.RFC3339)))
	}
}

// After asserts that a time is strictly after the given threshold
// (a value equal to min fails).
func (v *Validator) After(field string, val, min time.Time) {
	if !val.After(min) {
		v.Fail(field, fmt.Sprintf("must be after %v", min.Format(time.RFC3339)))
	}
}
// ----------------------------------------------------------------------------
// Standard Format Checks
// ----------------------------------------------------------------------------

// Each method below delegates to the package-level check function of the same
// name and records a uniform failure message when the value does not conform.

// CIDR ensures that the given string satisfies [CIDR].
func (v *Validator) CIDR(field, val string) {
	if !CIDR(val) {
		v.Fail(field, "must be a valid CIDR")
	}
}

// CIDRv4 ensures that the given string satisfies [CIDRv4].
func (v *Validator) CIDRv4(field, val string) {
	if !CIDRv4(val) {
		v.Fail(field, "must be a valid IPv4 CIDR")
	}
}

// CIDRv6 ensures that the given string satisfies [CIDRv6].
func (v *Validator) CIDRv6(field, val string) {
	if !CIDRv6(val) {
		v.Fail(field, "must be a valid IPv6 CIDR")
	}
}

// Hostname ensures that the given string satisfies [Hostname].
func (v *Validator) Hostname(field, val string) {
	if !Hostname(val) {
		v.Fail(field, "must be a valid hostname")
	}
}

// Port ensures that the given value satisfies [Port].
func (v *Validator) Port(field string, val int) {
	if !Port(val) {
		v.Fail(field, "must be a valid port number")
	}
}

// IP ensures that the given string satisfies [IP].
func (v *Validator) IP(field, val string) {
	if !IP(val) {
		v.Fail(field, "must be a valid IP address")
	}
}

// IPv4 ensures that the given string satisfies [IPv4].
func (v *Validator) IPv4(field, val string) {
	if !IPv4(val) {
		v.Fail(field, "must be a valid IPv4 address")
	}
}

// IPv6 ensures that the given string satisfies [IPv6].
func (v *Validator) IPv6(field, val string) {
	if !IPv6(val) {
		v.Fail(field, "must be a valid IPv6 address")
	}
}

// FQDN ensures that the given string satisfies [FQDN].
func (v *Validator) FQDN(field, val string) {
	if !FQDN(val) {
		v.Fail(field, "must be a valid FQDN")
	}
}

// URI ensures that the given string satisfies [URI].
func (v *Validator) URI(field, val string) {
	if !URI(val) {
		v.Fail(field, "must be a valid URI")
	}
}

// URL ensures that the given string satisfies [URL].
func (v *Validator) URL(field, val string) {
	if !URL(val) {
		v.Fail(field, "must be a valid URL")
	}
}

// URN ensures that the given string satisfies [URN].
func (v *Validator) URN(field, val string) {
	if !URN(val) {
		v.Fail(field, "must be a valid URN")
	}
}

// Alpha ensures that the given string satisfies [Alpha].
func (v *Validator) Alpha(field, val string) {
	if !Alpha(val) {
		v.Fail(field, "must contain only alphabetical characters")
	}
}

// AlphaNum ensures that the given string satisfies [AlphaNum].
func (v *Validator) AlphaNum(field, val string) {
	if !AlphaNum(val) {
		v.Fail(field, "must contain only alphanumeric characters")
	}
}

// ASCII ensures that the given string satisfies [ASCII].
func (v *Validator) ASCII(field, val string) {
	if !ASCII(val) {
		v.Fail(field, "must contain only ASCII characters")
	}
}

// Slug ensures that the given string satisfies [Slug].
func (v *Validator) Slug(field, val string) {
	if !Slug(val) {
		v.Fail(field, "must be a valid slug")
	}
}

// Upper ensures that the given string satisfies [Upper].
func (v *Validator) Upper(field, val string) {
	if !Upper(val) {
		v.Fail(field, "must contain only uppercase characters")
	}
}

// Lower ensures that the given string satisfies [Lower].
func (v *Validator) Lower(field, val string) {
	if !Lower(val) {
		v.Fail(field, "must contain only lowercase characters")
	}
}

// Base64 ensures that the given string satisfies [Base64].
func (v *Validator) Base64(field, val string) {
	if !Base64(val) {
		v.Fail(field, "must be a valid Base64 string")
	}
}

// Base64URL ensures that the given string satisfies [Base64URL].
func (v *Validator) Base64URL(field, val string) {
	if !Base64URL(val) {
		v.Fail(field, "must be a valid Base64URL string")
	}
}

// MAC ensures that the given string satisfies [MAC].
func (v *Validator) MAC(field, val string) {
	if !MAC(val) {
		v.Fail(field, "must be a valid MAC address")
	}
}

// Lang ensures that the given string satisfies [Lang].
func (v *Validator) Lang(field, val string) {
	if !Lang(val) {
		v.Fail(field, "must be a valid BCP 47 language tag")
	}
}

// JSON ensures that the given string satisfies [JSON].
func (v *Validator) JSON(field, val string) {
	if !JSON(val) {
		v.Fail(field, "must be a valid JSON document")
	}
}

// MIME ensures that the given string satisfies [MIME].
func (v *Validator) MIME(field, val string) {
	if !MIME(val) {
		v.Fail(field, "must be a valid MIME type")
	}
}

// CreditCard ensures that the given string satisfies [CreditCard].
func (v *Validator) CreditCard(field, val string) {
	if !CreditCard(val) {
		v.Fail(field, "must be a valid credit card number")
	}
}

// Email ensures that the given string satisfies [Email].
func (v *Validator) Email(field, val string) {
	if !Email(val) {
		v.Fail(field, "must be a valid email address")
	}
}

// Hex ensures that the given string satisfies [Hex].
func (v *Validator) Hex(field, val string) {
	if !Hex(val) {
		v.Fail(field, "must be a valid hexadecimal number")
	}
}

// HexColor ensures that the given string satisfies [HexColor].
func (v *Validator) HexColor(field, val string) {
	if !HexColor(val) {
		v.Fail(field, "must be a valid hex color code")
	}
}

// ISSN ensures that the given string satisfies [ISSN].
func (v *Validator) ISSN(field, val string) {
	if !ISSN(val) {
		v.Fail(field, "must be a valid ISSN")
	}
}

// ISBN10 ensures that the given string satisfies [ISBN10].
func (v *Validator) ISBN10(field, val string) {
	if !ISBN10(val) {
		v.Fail(field, "must be a valid ISBN-10")
	}
}

// ISBN13 ensures that the given string satisfies [ISBN13].
func (v *Validator) ISBN13(field, val string) {
	if !ISBN13(val) {
		v.Fail(field, "must be a valid ISBN-13")
	}
}

// ISBN ensures that the given string satisfies [ISBN].
func (v *Validator) ISBN(field, val string) {
	if !ISBN(val) {
		v.Fail(field, "must be a valid ISBN")
	}
}

// Country2 ensures that the given string satisfies [Country2].
func (v *Validator) Country2(field, val string) {
	if !Country2(val) {
		v.Fail(field, "must be a valid ISO 3166-1 alpha-2 code")
	}
}

// Country3 ensures that the given string satisfies [Country3].
func (v *Validator) Country3(field, val string) {
	if !Country3(val) {
		v.Fail(field, "must be a valid ISO 3166-1 alpha-3 code")
	}
}

// CountryN ensures that the given string satisfies [CountryN].
func (v *Validator) CountryN(field, val string) {
	if !CountryN(val) {
		v.Fail(field, "must be a valid ISO 3166-1 numeric code")
	}
}

// Currency ensures that the given string satisfies [Currency].
func (v *Validator) Currency(field, val string) {
	if !Currency(val) {
		v.Fail(field, "must be a valid ISO 4217 currency code")
	}
}
// UUIDv7 ensures that the given string satisfies [UUIDv7].
//
// Previously this method called Currency instead of UUIDv7 (a copy-paste
// slip), so any three-letter uppercase code would have passed as a UUID;
// it now delegates to the correct check.
func (v *Validator) UUIDv7(field, val string) {
	if !UUIDv7(val) {
		v.Fail(field, "must be a valid UUIDv7")
	}
}
// Lat ensures that the given value satisfies [Lat] (a latitude in the
// range -90 to 90).
func (v *Validator) Lat(field string, val float32) {
	if !Lat(val) {
		v.Fail(field, "must be a valid latitude")
	}
}

// Lon ensures that the given value satisfies [Lon] (a longitude in the
// range -180 to 180).
func (v *Validator) Lon(field string, val float32) {
	if !Lon(val) {
		v.Fail(field, "must be a valid longitude")
	}
}

// MD5 ensures that the given string satisfies [MD5].
func (v *Validator) MD5(field, val string) {
	if !MD5(val) {
		v.Fail(field, "must be a valid MD5 hash")
	}
}

// SHA256 ensures that the given string satisfies [SHA256].
func (v *Validator) SHA256(field, val string) {
	if !SHA256(val) {
		v.Fail(field, "must be a valid SHA256 hash")
	}
}

// SHA384 ensures that the given string satisfies [SHA384].
func (v *Validator) SHA384(field, val string) {
	v.Fail(field, "must be a valid SHA384 hash")
}

// SHA512 ensures that the given string satisfies [SHA512].
func (v *Validator) SHA512(field, val string) {
	if !SHA512(val) {
		v.Fail(field, "must be a valid SHA512 hash")
	}
}

// SemVer ensures that the given string satisfies [SemVer].
func (v *Validator) SemVer(field, val string) {
	if !SemVer(val) {
		v.Fail(field, "must be a valid semantic version")
	}
}

// Phone ensures that the given string satisfies [Phone].
func (v *Validator) Phone(field, val string) {
	if !Phone(val) {
		v.Fail(field, "must be a valid E.164 phone number")
	}
}

// BIC ensures that the given string satisfies [BIC].
func (v *Validator) BIC(field, val string) {
	if !BIC(val) {
		v.Fail(field, "must be a valid BIC")
	}
}

// IBAN ensures that the given string satisfies [IBAN].
func (v *Validator) IBAN(field, val string) {
	if !IBAN(val) {
		v.Fail(field, "must be a valid IBAN")
	}
}