// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package app provides a managed lifecycle for command-line applications,
// ensuring graceful shutdown on OS signals.
//
// The Run function is the main entry point. It wraps your application
// components (Runnables), executing them concurrently. It listens for interrupt
// signals (like SIGINT/SIGTERM) and propagates a cancellation signal via a
// context. This allows your application to perform cleanup tasks before
// exiting.
//
// # Usage
//
// A typical use case involves starting workers or servers that run until
// interrupted. The Run function handles signal trapping, concurrency, and
// timeouts, letting you focus on the business logic.
//
// func main() {
// // 1. Configure a logger (slog).
// logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
//
// // 2. Define the application components.
// // These functions block until ctx is canceled or an error occurs.
// worker := func(ctx context.Context) error {
// logger.Info("Worker started")
//
// // Simulate a task that runs periodically.
// ticker := time.NewTicker(1 * time.Second)
// defer ticker.Stop()
//
// for {
// select {
// case <-ctx.Done():
// // Context canceled (signal received or sibling component failed).
// logger.Info("Worker stopping...")
//
// // Perform cleanup (e.g., closing DB connections).
// time.Sleep(500 * time.Millisecond)
// return nil
//
// case t := <-ticker.C:
// logger.Info("Working...", "time", t.Format(time.TimeOnly))
// }
// }
// }
//
// server := func(ctx context.Context) error {
// logger.Info("Server started")
// <-ctx.Done()
// logger.Info("Server stopping...")
// return nil
// }
//
// // 3. Run the application components concurrently.
// err := app.RunAll(
// []app.Runnable{worker, server},
// app.WithLogger(logger),
// )
// if err != nil {
// logger.Error("Application failed", "error", err)
// os.Exit(1)
// }
// }
package app
import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"os/signal"
"runtime/debug"
"syscall"
"time"
"golang.org/x/sync/errgroup"
)
// DefaultTimeout is the default duration to wait for the application to
// gracefully shut down after receiving a termination signal.
const DefaultTimeout = 10 * time.Second
// Runnable defines a function that can be executed by the application runner.
// It receives a context that is canceled when a shutdown signal is received,
// or if another concurrently running Runnable returns an error.
// The function should perform its cleanup and return when the context is done.
type Runnable func(ctx context.Context) error
// config aggregates the runner settings adjusted through Option functions.
// The zero value is never used; RunAll fills in the documented defaults.
type config struct {
	logger  *slog.Logger    // destination for lifecycle messages (default slog.Default())
	timeout time.Duration   // graceful-shutdown grace period (default DefaultTimeout)
	signals []os.Signal     // signals that trigger shutdown (default SIGTERM, SIGINT)
	ctx     context.Context // parent context (default context.Background())
}
// Option is a function that configures the application runner.
type Option func(*config)
// WithLogger provides a custom logger for the application runner. If not set,
// the runner defaults to slog.Default(). A nil value will be ignored.
func WithLogger(log *slog.Logger) Option {
	return func(c *config) {
		if log == nil {
			return // keep the default logger
		}
		c.logger = log
	}
}
// WithTimeout sets a custom timeout for the graceful shutdown process.
// If the application components take longer than this duration to return after
// a shutdown signal is received, the runner will exit with an error. A negative
// or zero duration will be ignored, and the DefaultTimeout is used instead.
func WithTimeout(d time.Duration) Option {
	return func(c *config) {
		if d <= 0 {
			return // non-positive durations keep DefaultTimeout
		}
		c.timeout = d
	}
}
// WithSignals allows customization of which OS signals trigger a shutdown.
// If not used, it defaults to SIGTERM and SIGINT.
func WithSignals(signals ...os.Signal) Option {
	return func(c *config) {
		if len(signals) == 0 {
			return // nothing given: keep the default signal set
		}
		c.signals = signals
	}
}
// WithContext sets a parent context for the runner. The runner's main
// context will be a child of this parent. Cancelling the parent context
// triggers a graceful shutdown. If not set, context.Background() is used as
// the default parent. A nil value will be ignored.
func WithContext(ctx context.Context) Option {
	return func(c *config) {
		if ctx == nil {
			return // nil parents are ignored
		}
		c.ctx = ctx
	}
}
// Run provides a managed execution environment for a single Runnable.
// It launches the Runnable in a separate goroutine and blocks until it
// completes, an OS interrupt signal is caught, or the parent context is
// canceled. For running multiple components concurrently, see RunAll.
func Run(runnable Runnable, opts ...Option) error {
	group := []Runnable{runnable}
	return RunAll(group, opts...)
}
// RunAll provides a managed execution environment for multiple Runnables.
// It launches each Runnable in a separate goroutine and blocks until they
// all complete on their own, an OS interrupt signal is caught, the parent
// context (if specified via WithContext) is canceled, or any single Runnable
// returns an error.
//
// Upon receiving a signal or encountering an error in any Runnable, it
// cancels the context passed to all Runnables and waits for the specified
// shutdown timeout. The Runnables are expected to honor the context
// cancellation and perform any necessary cleanup before returning. RunAll
// returns any error from the Runnables themselves, or an error if the shutdown
// process times out.
func RunAll(runnables []Runnable, opts ...Option) error {
	// Start from the documented defaults; options may override each field.
	cfg := config{
		logger:  slog.Default(),
		timeout: DefaultTimeout,
		signals: []os.Signal{syscall.SIGTERM, syscall.SIGINT},
		ctx:     context.Background(),
	}
	for _, opt := range opts {
		opt(&cfg)
	}
	// Create a context that cancels on OS signals.
	ctx, cancel := signal.NotifyContext(cfg.ctx, cfg.signals...)
	defer cancel()
	// Use errgroup to manage concurrent runnables. The group context will be
	// canceled if the base context is canceled, or if any goroutine returns an
	// error.
	g, gCtx := errgroup.WithContext(ctx)
	cfg.logger.Info("Application started", slog.Int("components", len(runnables)))
	for _, fn := range runnables {
		// Shadow the loop variable: under Go versions before 1.22 all
		// goroutines would otherwise share one variable and run only the
		// last Runnable. Harmless (a no-op) under Go 1.22+ semantics.
		fn := fn
		g.Go(func() (err error) {
			// Convert a panic in any component into an ordinary error so a
			// crashing Runnable tears the application down gracefully
			// instead of killing the whole process.
			defer func() {
				if r := recover(); r != nil {
					stack := string(debug.Stack())
					err = fmt.Errorf("application panic: %v\nstack: %s", r, stack)
				}
			}()
			return fn(gCtx)
		})
	}
	// Wait for the group in a separate goroutine so the select below can race
	// group completion against signal-triggered cancellation.
	errCh := make(chan error, 1)
	go func() {
		errCh <- g.Wait()
	}()
	select {
	case err := <-errCh:
		// The application exited naturally or due to a failure in one component.
		if err != nil {
			return fmt.Errorf("application exited with error: %w", err)
		}
		cfg.logger.Info("Application stopped")
		return nil
	case <-ctx.Done():
		// A signal was received (or the parent context was canceled).
		cfg.logger.Info("Shutdown signal received, initiating graceful shutdown")
		timer := time.NewTimer(cfg.timeout)
		defer timer.Stop()
		select {
		case err := <-errCh:
			// If the error encountered is just "context canceled", we consider it a
			// successful shutdown.
			if err != nil && !errors.Is(err, context.Canceled) {
				return fmt.Errorf("error during graceful shutdown: %w", err)
			}
			cfg.logger.Info("Shutdown completed successfully")
			return nil
		case <-timer.C:
			// NOTE(review): on timeout, the g.Wait goroutine and any stuck
			// Runnables are abandoned and keep running until process exit.
			return fmt.Errorf("shutdown timed out after %v", cfg.timeout)
		}
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package backoff provides customizable strategies for retrying operations with
// increasing delays.
//
// The core of the package is the Strategy interface, which computes the next
// backoff duration. Implementations are stateful; the Next method returns
// progressively longer durations with each call. Once the retried operation is
// successful or abandoned, Done must be called to reset the strategy's internal
// state.
//
// # Usage
//
// A default exponential backoff strategy with jitter can be created using the
// New function. The behavior can be customized using various Option functions,
// such as WithMinDelay, WithMaxDelay, WithGrowthFactor, and WithJitterAmount.
// Jitter is added by default to prevent multiple clients from retrying in sync
// (the "thundering herd" problem), which can overwhelm a recovering service.
package backoff
import (
"math"
"sync/atomic"
"time"
"github.com/deep-rent/nexus/internal/jitter"
)
const (
// DefaultMinDelay is the default minimum time between consecutive retries.
DefaultMinDelay = 1 * time.Second
// DefaultMaxDelay is the default maximum time between consecutive retries.
DefaultMaxDelay = 1 * time.Minute
// DefaultGrowthFactor is the default growth factor in exponential backoff.
DefaultGrowthFactor float64 = 2.0
// DefaultJitterAmount is the default amount of jitter applied.
DefaultJitterAmount float64 = 0.5
)
// Strategy defines the contract for a backoff algorithm.
// Implementations of this interface are expected to be safe for concurrent use.
type Strategy interface {
// Next returns the backoff duration for the upcoming retry attempt.
// This method is stateful and returns incrementally larger durations based
// on the number of times it has been called since the last call to Done.
// The returned duration is bounded by MinDelay() and MaxDelay().
Next() time.Duration
// Done resets the strategy's internal state, such as its attempt counter.
// This must be called after the retried operation succeeds or is abandoned.
Done()
// MinDelay returns the lower bound for the backoff duration returned by Next.
MinDelay() time.Duration
// MaxDelay returns the upper bound for the backoff duration returned by Next.
MaxDelay() time.Duration
}
type constant struct {
delay time.Duration
}
// Constant produces a Strategy that always yields the same delay duration.
// If the provided delay is negative, it is treated as zero (meaning no delay).
func Constant(delay time.Duration) Strategy {
return &constant{delay: max(0, delay)}
}
func (c *constant) Next() time.Duration { return c.delay }
func (c *constant) Done() {}
func (c *constant) MinDelay() time.Duration { return c.delay }
func (c *constant) MaxDelay() time.Duration { return c.delay }
var _ Strategy = (*constant)(nil)
type linear struct {
minDelay time.Duration
maxDelay time.Duration
attempts atomic.Int64
}
func (l *linear) Next() time.Duration {
n := l.attempts.Add(1)
// If n is huge, simply return the delay limit to avoid overflow.
if l.minDelay > 0 && n > int64(l.maxDelay/l.minDelay) {
return l.maxDelay
}
d := l.minDelay * time.Duration(n)
return max(l.minDelay, min(l.maxDelay, d))
}
func (l *linear) Done() { l.attempts.Store(0) }
func (l *linear) MinDelay() time.Duration { return l.minDelay }
func (l *linear) MaxDelay() time.Duration { return l.maxDelay }
var _ Strategy = (*linear)(nil)
type exponential struct {
minDelay time.Duration
maxDelay time.Duration
growthFactor float64
attempts atomic.Int64
}
func (e *exponential) Next() time.Duration {
n := e.attempts.Add(1)
d := time.Duration(float64(e.minDelay) * math.Pow(e.growthFactor, float64(n)))
return max(e.minDelay, min(e.maxDelay, d))
}
func (e *exponential) Done() { e.attempts.Store(0) }
func (e *exponential) MinDelay() time.Duration { return e.minDelay }
func (e *exponential) MaxDelay() time.Duration { return e.maxDelay }
var _ Strategy = (*exponential)(nil)
// spread decorates a Strategy with jittering capabilities in order to spread
// out retry attempts over time.
type spread struct {
	s Strategy       // the wrapped base strategy
	j *jitter.Jitter // randomizer applied to each delay
}

// Next draws the base delay from the wrapped strategy and randomizes it.
func (s *spread) Next() time.Duration {
	base := s.s.Next()
	return s.j.Apply(base)
}

// Done forwards the reset to the wrapped strategy.
func (s *spread) Done() { s.s.Done() }

// MinDelay reports the wrapped minimum, lowered by the jitter floor.
func (s *spread) MinDelay() time.Duration {
	return s.j.Floor(s.s.MinDelay(), 1)
}

// MaxDelay reports the wrapped maximum; jitter never increases delays.
func (s *spread) MaxDelay() time.Duration {
	return s.s.MaxDelay()
}

var _ Strategy = (*spread)(nil)
// config collects the adjustable knobs consumed by New when constructing a
// Strategy.
type config struct {
	minDelay     time.Duration // lower delay bound (DefaultMinDelay)
	maxDelay     time.Duration // upper delay bound (DefaultMaxDelay)
	growthFactor float64       // exponential multiplier (DefaultGrowthFactor)
	jitterAmount float64       // jitter fraction (DefaultJitterAmount)
	r            jitter.Rand   // randomness source; nil selects a seeded default
}
// Option customizes the behavior of a backoff Strategy.
type Option func(*config)
// WithMinDelay sets the minimum time between consecutive retries.
// It is capped at zero (meaning no delay) if a negative duration is provided.
// If equal to or greater than the maximum delay, the backoff delays remain
// constant at the maximum delay. If not customized, the DefaultMinDelay
// is used.
//
// When jitter is introduced, the minimum delay is effectively reduced
// proportional to the jitter amount. Thus, the strategy might return a delay
// shorter than the configured minimum delay, depending on the random output.
func WithMinDelay(d time.Duration) Option {
	return func(c *config) {
		if d < 0 {
			d = 0 // negative durations mean "no delay"
		}
		c.minDelay = d
	}
}
// WithMaxDelay sets the maximum time between consecutive retries.
// It is capped at zero (meaning no delay) if a negative duration is provided.
// If less than or equal to the minimum delay, the backoff delays remain
// constant at the maximum delay. If not customized, the DefaultMaxDelay
// is used.
func WithMaxDelay(d time.Duration) Option {
	return func(c *config) {
		if d < 0 {
			d = 0 // negative durations mean "no delay"
		}
		c.maxDelay = d
	}
}
// WithGrowthFactor determines the growth factor (multiplier) for exponential
// backoff. A factor equal to one results in linear backoff, where the minimum
// delay becomes the step size. Any factor less than one is treated as one.
// If not customized, the DefaultGrowthFactor is used.
func WithGrowthFactor(f float64) Option {
	return func(cfg *config) {
		cfg.growthFactor = f
	}
}
// WithJitterAmount specifies the amount of random jitter to apply to the
// backoff delays. It is expressed as a fraction of the delay, where 0 means no
// jitter and 1 means full jitter. The given number is capped between 0 and 1.
// If not customized, the DefaultJitterAmount is used.
//
// Jitter scatters the retry attempts in time, which aims to mitigate the
// thundering herd problem, where many clients retry simultaneously.
func WithJitterAmount(p float64) Option {
	return func(c *config) {
		// Clamp to [0, 1] as documented; previously only the upper bound was
		// enforced, letting negative fractions pass through unchanged.
		// (Behavior is unaffected: New already treats <= 0 as "no jitter".)
		c.jitterAmount = min(1, max(0, p))
	}
}
// WithRand sets the source of randomness for jittering. If not specified or
// nil, a default source will be seeded with the current system time.
func WithRand(r jitter.Rand) Option {
	return func(c *config) {
		if r == nil {
			return // nil sources are ignored; the default remains
		}
		c.r = r
	}
}
// New creates a new backoff Strategy based on the provided options.
func New(opts ...Option) Strategy {
	// Start from the documented defaults; options may override each field.
	c := config{
		minDelay:     DefaultMinDelay,
		maxDelay:     DefaultMaxDelay,
		growthFactor: DefaultGrowthFactor,
		jitterAmount: DefaultJitterAmount,
	}
	for _, opt := range opts {
		opt(&c)
	}
	// Degenerate range: delays cannot vary, so fall back to a constant delay.
	if c.minDelay >= c.maxDelay {
		return &constant{
			delay: c.maxDelay,
		}
	}
	// A growth factor of one or less degenerates to linear backoff with
	// minDelay as the step size.
	if c.growthFactor <= 1 {
		return &linear{
			minDelay: c.minDelay,
			maxDelay: c.maxDelay,
		}
	}
	s := &exponential{
		minDelay:     c.minDelay,
		maxDelay:     c.maxDelay,
		growthFactor: c.growthFactor,
	}
	if c.jitterAmount <= 0 {
		return s
	}
	// NOTE(review): jitter is only applied to the exponential strategy; the
	// constant and linear paths above return before this wrapper is reached.
	// Confirm whether that is intentional.
	return &spread{
		s: s,
		j: jitter.New(c.jitterAmount, c.r),
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package cache provides a generic, auto-refreshing in-memory cache for a
// resource fetched from a URL.
//
// The core of the package is the Controller, a scheduler.Tick that periodically
// fetches a remote resource, parses it, and caches it in memory. It is designed
// to be resilient, with a built-in, configurable retry mechanism for handling
// transient network failures.
//
// The refresh interval is intelligently determined by the resource's caching
// headers (e.g., Cache-Control, Expires), but can be clamped within a specified
// min/max range. The controller also handles conditional requests using ETag
// and Last-Modified headers to reduce bandwidth and server load.
//
// # Usage
//
// A typical use case involves creating a scheduler, defining a Mapper function
// to parse the HTTP response, creating and configuring a Controller, and then
// dispatching it to run in the background.
//
// Example:
//
// type Resource struct {
// // fields for the parsed data
// }
//
// // 1. Create a scheduler to manage the refresh ticks.
// sched := scheduler.New(context.Background())
// defer sched.Shutdown()
//
// // 2. Define a mapper to parse the response body into your target type.
// var mapper cache.Mapper[Resource] = func(r *cache.Response) (Resource, error) {
// var data Resource
// err := json.Unmarshal(r.Body, &data)
// return data, err
// }
//
// // 3. Create and configure the cache controller.
// ctrl := cache.NewController(
// "https://api.example.com/resource",
// mapper,
// cache.WithMinInterval(5*time.Minute),
// cache.WithHeader("Authorization", "Bearer *****"),
// )
//
// // 4. Dispatch the controller to start fetching in the background.
// sched.Dispatch(ctrl)
//
// // 5. You can wait for the first successful fetch.
// <-ctrl.Ready()
//
// // 6. Get the cached data.
// if data, ok := ctrl.Get(); ok {
// fmt.Printf("Successfully fetched and cached data: %+v\n", data)
// }
package cache
import (
	"context"
	"crypto/tls"
	"errors"
	"io"
	"log/slog"
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/deep-rent/nexus/header"
	"github.com/deep-rent/nexus/retry"
	"github.com/deep-rent/nexus/scheduler"
)
// Default configuration values for the cache controller.
const (
// DefaultTimeout is the default timeout for a single HTTP request.
DefaultTimeout = 30 * time.Second
// DefaultMinInterval is the default lower bound for the refresh interval.
DefaultMinInterval = 15 * time.Minute
// DefaultMaxInterval is the default upper bound for the refresh interval.
DefaultMaxInterval = 24 * time.Hour
)
// Mapper is a function that parses a response's raw response body into the
// target type T. It is responsible for decoding the data (e.g., from JSON or
// XML) and returning the structured result. An error should be returned if
// parsing fails fatally. For warnings or debug information, invoke the logger
// contained in the Response. If the mapping takes a considerable amount of
// time, it should generally respect the context contained in the Response.
type Mapper[T any] func(r *Response) (T, error)
// Response provides contextual information to a Mapper function,
// including the response body, request context, and a logger.
type Response struct {
Body []byte // Raw response payload to be mapped.
Ctx context.Context // Context controlling the HTTP exchange.
Logger *slog.Logger // Logger instance inherited from the Controller.
}
// Controller manages the lifecycle of a cached resource. It implements
// scheduler.Tick, allowing it to be run by a scheduler to periodically refresh
// the resource from a URL.
type Controller[T any] interface {
scheduler.Tick
// Get retrieves the currently cached resource. The boolean return value is
// true if the cache has been successfully populated at least once.
Get() (T, bool)
// Ready returns a channel that is closed once the first successful fetch of
// the resource is complete. This allows consumers to block until the cache
// is warmed up.
Ready() <-chan struct{}
}
// NewController creates and configures a new cache Controller.
//
// It requires a URL for the resource to fetch and a Mapper function to parse
// the response. If no http.Client is provided via options, it creates a default
// one with a sensible timeout and a retry transport.
func NewController[T any](
	url string,
	mapper Mapper[T],
	opts ...Option,
) Controller[T] {
	// Start from the documented defaults; options may override each field.
	cfg := config{
		client:      nil,
		timeout:     DefaultTimeout,
		headers:     make([]header.Header, 0, 3),
		minInterval: DefaultMinInterval,
		maxInterval: DefaultMaxInterval,
		logger:      slog.Default(),
	}
	for _, opt := range opts {
		opt(&cfg)
	}
	client := cfg.client
	if client == nil {
		// Build a conservative one-shot client: connections are not reused
		// (DisableKeepAlives) and the phase timeouts below are derived as
		// fractions of the overall request timeout.
		d := &net.Dialer{
			Timeout:   cfg.timeout / 3,
			KeepAlive: 0,
		}
		var t http.RoundTripper = &http.Transport{
			Proxy:                 http.ProxyFromEnvironment,
			DialContext:           d.DialContext,
			TLSClientConfig:       cfg.tls,
			TLSHandshakeTimeout:   cfg.timeout / 3,
			ResponseHeaderTimeout: cfg.timeout * 9 / 10,
			ExpectContinueTimeout: 1 * time.Second,
			DisableKeepAlives:     true,
		}
		// Static headers are attached by the header transport; the retry
		// transport wraps it so each retry carries the same headers.
		t = retry.NewTransport(header.NewTransport(t, cfg.headers...), cfg.retry...)
		client = &http.Client{
			Timeout:   cfg.timeout,
			Transport: t,
		}
	}
	return &controller[T]{
		url:         url,
		mapper:      mapper,
		client:      client,
		minInterval: cfg.minInterval,
		maxInterval: cfg.maxInterval,
		logger:      cfg.logger,
		readyChan:   make(chan struct{}),
	}
}
// controller is the internal implementation of the Controller interface.
type controller[T any] struct {
	url         string        // location of the remote resource
	mapper      Mapper[T]     // parses response bodies into T
	client      *http.Client  // HTTP client used for all fetches
	minInterval time.Duration // lower bound for the refresh delay
	maxInterval time.Duration // upper bound for the refresh delay
	// now is the clock handed to header.Lifetime in delay. NOTE(review): it
	// is never assigned by NewController, so it is always nil here — confirm
	// that header.Lifetime falls back to time.Now for a nil clock.
	now    func() time.Time
	logger *slog.Logger // destination for fetch/cache log messages
	readyOnce sync.Once     // guards the one-time close of readyChan
	readyChan chan struct{} // closed after the first successful fetch
	mu           sync.RWMutex // guards the mutable fields below
	resource     T            // last successfully parsed resource
	ok           bool         // true once resource has been populated
	etag         string       // ETag of the cached copy, for If-None-Match
	lastModified string       // Last-Modified of the cached copy
}
// Get returns the most recently cached resource. The boolean reports whether
// the cache has ever been successfully populated.
func (c *controller[T]) Get() (T, bool) {
	c.mu.RLock()
	resource, ok := c.resource, c.ok
	c.mu.RUnlock()
	return resource, ok
}

// Ready exposes the channel that is closed after the first successful fetch.
func (c *controller[T]) Ready() <-chan struct{} {
	return c.readyChan
}

// ready closes the ready channel, guarding against repeated closes.
func (c *controller[T]) ready() {
	c.readyOnce.Do(func() { close(c.readyChan) })
}
// Run executes a single fetch-and-cache cycle. It implements the scheduler.Tick
// interface. It handles conditional requests, response parsing, and caching,
// and returns the duration to wait before the next run.
func (c *controller[T]) Run(ctx context.Context) time.Duration {
	c.logger.Debug("Fetching resource")
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil)
	if err != nil {
		// This is a non-retriable error in request creation.
		c.logger.Error("Failed to create request", slog.Any("error", err))
		return c.minInterval // Wait a long time before trying to create it again.
	}
	// Add conditional headers if we have them from a previous request.
	c.mu.RLock()
	if c.etag != "" {
		req.Header.Set("If-None-Match", c.etag)
	}
	if c.lastModified != "" {
		req.Header.Set("If-Modified-Since", c.lastModified)
	}
	c.mu.RUnlock()
	res, err := c.client.Do(req)
	if err != nil {
		// http.Client wraps transport errors (e.g. in *url.Error), so a
		// direct comparison with context.Canceled never matched; unwrap with
		// errors.Is so a graceful shutdown does not log a spurious error.
		if !errors.Is(err, context.Canceled) {
			c.logger.Error(
				"HTTP request failed after retries",
				slog.Any("error", err),
			)
		}
		return c.minInterval
	}
	defer func() {
		if err := res.Body.Close(); err != nil {
			c.logger.Warn("Failed to close response body", slog.Any("error", err))
		}
	}()
	switch code := res.StatusCode; code {
	case http.StatusNotModified:
		// The cached copy is still fresh; just reschedule.
		c.logger.Debug("Resource unchanged", slog.String("etag", c.etag))
		c.ready()
		return c.delay(res.Header)
	case http.StatusOK:
		body, err := io.ReadAll(res.Body)
		if err != nil {
			c.logger.Error("Failed to read response body", slog.Any("error", err))
			return c.minInterval
		}
		resource, err := c.mapper(&Response{
			Body:   body,
			Ctx:    req.Context(),
			Logger: c.logger,
		})
		if err != nil {
			c.logger.Error("Couldn't parse response body", slog.Any("error", err))
			return c.minInterval
		}
		// Publish the new resource and its validators atomically.
		c.mu.Lock()
		c.resource = resource
		c.etag = res.Header.Get("ETag")
		c.lastModified = res.Header.Get("Last-Modified")
		c.ok = true
		c.mu.Unlock()
		c.logger.Info("Resource updated successfully")
		c.ready()
		return c.delay(res.Header)
	default:
		c.logger.Error(
			"Received a non-retriable HTTP status code",
			slog.Int("status", code),
		)
		return c.minInterval
	}
}
// delay calculates the duration until the next fetch based on caching headers,
// clamped by the configured min/max intervals.
func (c *controller[T]) delay(h http.Header) time.Duration {
	d := header.Lifetime(h, c.now)
	switch {
	case d > c.maxInterval:
		return c.maxInterval
	case d < c.minInterval:
		return c.minInterval
	default:
		return d
	}
}
var _ Controller[any] = (*controller[any])(nil)
// config holds the internal configuration for the cache controller.
type config struct {
client *http.Client
timeout time.Duration
headers []header.Header
tls *tls.Config
minInterval time.Duration
maxInterval time.Duration
retry []retry.Option
logger *slog.Logger
}
// Option is a function that configures the cache Controller.
type Option func(*config)
// WithClient provides a custom http.Client to be used for requests. This is
// useful for advanced configurations, such as custom transports or connection
// pooling. If not provided, a default client with retry logic is created.
func WithClient(client *http.Client) Option {
	return func(c *config) {
		if client == nil {
			return // nil clients are ignored
		}
		c.client = client
	}
}
// WithTimeout sets the total timeout for a single HTTP fetch attempt, including
// connection, redirects, and reading the response body. This is ignored if a
// custom client is provided via WithClient.
func WithTimeout(d time.Duration) Option {
	return func(c *config) {
		if d <= 0 {
			return // non-positive durations keep the default
		}
		c.timeout = d
	}
}
// WithHeader adds a static header to every request sent by the controller. This
// can be called multiple times to add multiple headers.
func WithHeader(k, v string) Option {
	return func(c *config) {
		h := header.New(k, v)
		c.headers = append(c.headers, h)
	}
}
// WithTLSConfig provides a custom tls.Config for the default HTTP transport.
// This is ignored if a custom client is provided via WithClient.
func WithTLSConfig(cfg *tls.Config) Option {
	return func(c *config) {
		// Parameter renamed from "tls", which shadowed the crypto/tls
		// package inside the closure.
		c.tls = cfg
	}
}
// WithMinInterval sets the minimum duration between refresh attempts. The
// refresh delay, typically determined by caching headers, will not be shorter
// than this.
func WithMinInterval(d time.Duration) Option {
	return func(c *config) {
		if d <= 0 {
			return // non-positive durations keep the default
		}
		c.minInterval = d
	}
}
// WithMaxInterval sets the maximum duration between refresh attempts. The
// refresh delay will not be longer than this value.
func WithMaxInterval(d time.Duration) Option {
	return func(c *config) {
		if d <= 0 {
			return // non-positive durations keep the default
		}
		c.maxInterval = d
	}
}
// WithRetryOptions configures the retry mechanism for the default HTTP client.
// These options are ignored if a custom client is provided via WithClient.
func WithRetryOptions(opts ...retry.Option) Option {
	return func(c *config) {
		for _, o := range opts {
			c.retry = append(c.retry, o)
		}
	}
}
// WithLogger provides a custom slog.Logger for the controller. If not provided,
// slog.Default() is used.
func WithLogger(log *slog.Logger) Option {
	return func(c *config) {
		if log == nil {
			return // nil loggers are ignored
		}
		c.logger = log
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package di provides a type-safe, concurrent dependency injection
// container for Go applications.
//
// The core concepts are:
// - Injector: The main container that holds all service bindings.
// - Slot: A unique, typed key used to register and retrieve services.
// - Provider: A factory function that creates an instance of a service.
// - Resolver: A strategy that defines the lifecycle of a service.
//
// # Usage
//
// Let's explore how to use this container by modeling a simple chemical
// reaction: forming a Salt. This example shows how to use a single interface
// (Ion) for dependencies that fulfill different roles (Cation and Anion).
//
// Step 1: Define a Reusable Abstraction (The Ion Role)
//
// Instead of separate Cation and Anion interfaces, we can define a single,
// reusable Ion interface. Our Salt struct will now depend on two instances
// of this same interface type.
//
// // Ion represents any particle with a symbol and a charge.
// type Ion interface {
// Symbol() string
// Charge() int
// }
//
// // Salt is our final product, which depends on two Ions.
// type Salt struct {
// cation Ion
// anion Ion
// }
//
// func (s Salt) Formula() string {
// return s.cation.Symbol() + s.anion.Symbol()
// }
//
// Step 2: Create Slots for Roles (The Unique Labels)
//
// The key insight here is that slots distinguish dependencies by their role,
// not just their type. Even though both slots below are for the Ion type,
// they are unique keys. This allows us to inject the right ion into the right
// place. The tags ("ion", "cation", etc.) are optional but help with debugging
// in case something goes wrong.
//
// var (
// SlotCation = di.NewSlot[Ion]("ion", "cation")
// SlotAnion = di.NewSlot[Ion]("ion", "anion")
// SlotSalt = di.NewSlot[Salt]("compound", "salt")
// )
//
// Step 3: Write Providers (The Recipes)
//
// Providers now return the generic Ion interface. The Salt provider can
// then request two different Ions by using their distinct role-based slots.
//
// // Sodium is a concrete Ion that fulfills the Cation role.
// type Sodium struct{}
//
// func (na Sodium) Symbol() string { return "Na" }
// func (na Sodium) Charge() int { return 1 }
//
// // ProvideSodium provides a concrete Ion to fulfill the Cation role.
// func ProvideSodium(*di.Injector) (Ion, error) {
// return Sodium{}, nil
// }
//
// // Chloride is a concrete Ion that fulfills the Anion role.
// type Chloride struct{}
//
// func (cl Chloride) Symbol() string { return "Cl" }
// func (cl Chloride) Charge() int { return -1 }
//
// // ProvideChloride provides a concrete Ion to fulfill the Anion role.
// func ProvideChloride(*di.Injector) (Ion, error) {
// return Chloride{}, nil
// }
//
// // ProvideSalt requests dependencies by their role-specific slots.
// func ProvideSalt(in *di.Injector) (Salt, error) {
// // Request the Ion fulfilling the "Cation" role.
// cation := di.Required[Ion](in, SlotCation)
// // Request the Ion fulfilling the "Anion" role.
// anion := di.Required[Ion](in, SlotAnion)
// return Salt{cation: cation, anion: anion}, nil
// }
//
// Step 4: Assemble the Solution (Configure the Injector)
//
// Now, create an Injector and bind the concrete providers to their
// respective role slots. We are telling the container that Sodium will act
// as our Cation and Chloride will act as our Anion.
//
// // 1. Create the injector.
// solution := di.NewInjector()
//
// // 2. Bind concrete providers to their roles. We use Transient scope to
// // obtain fresh ions each time we form a new salt molecule.
// di.Bind(solution, SlotCation, ProvideSodium, di.Transient())
// di.Bind(solution, SlotAnion, ProvideChloride, di.Transient())
//
// // 3. Bind the provider for the final product. A salt molecule is very
// // stable, so we treat it as a singleton.
// di.Bind(solution, SlotSalt, ProvideSalt, di.Singleton())
//
// Step 5: Trigger the Reaction (Resolve the Final Product)
//
// When we ask for the Salt, the injector provides the previously registered
// atoms (dependencies) to the Salt provider to form the final molecule. As
// expected, we obtain ordinary table salt (NaCl).
//
// // This call triggers the entire dependency chain.
// salt := di.Required[Salt](solution, SlotSalt)
//
// fmt.Printf("Successfully formed: %s\n", salt.Formula())
// // Output: Successfully formed: NaCl
package di
import (
"context"
"fmt"
"reflect"
"strings"
"sync"
)
// Slot is an abstract, typed symbol for an injectable service.
// It is a unique pointer that acts as a map key within the Injector,
// while the generic type T provides compile-time type safety.
type Slot[T any] *struct{}

// slots is a global, concurrent map that stores the debug tag for each slot.
var slots = &sync.Map{}

// func Reset() {
// slots = &sync.Map{}
// }

// NewSlot creates a new, unique Slot for a given type T.
//
// The optional keys are used to create a descriptive name for debugging and
// error messages. Multiple keys are joined with dots. This is useful to group
// related services, e.g., by package or feature. The assigned tag can be
// retrieved later using the Tag function.
func NewSlot[T any](keys ...string) Slot[T] {
	// Convert to Slot[T] BEFORE storing: interface keys in sync.Map compare
	// by dynamic type as well as value, so a key stored as a bare *struct{}
	// would never match a lookup with the Slot[T] handed to the caller, and
	// Tag would always hit its pointer-address fallback.
	s := Slot[T](new(struct{}))
	t := reflect.TypeOf((*T)(nil)).Elem().String()
	var tag string
	if len(keys) == 0 {
		// Case 1: Unnamed slot, e.g., @string
		tag = "@" + t
	} else {
		// Case 2: Named slot, e.g., a.b.c@int
		tag = strings.Join(keys, ".") + "@" + t
	}
	slots.Store(s, tag)
	return s
}

// Tag returns the pre-formatted debug string for a slot.
//
// The tag is of the form "name@type", where "name" is the optional name
// assigned during slot creation, and "type" is the Go type of the slot. If no
// name was provided, it returns just the type prefixed with "@". If the slot is
// unknown, it falls back to the pointer address.
func Tag(slot any) string {
	if name, ok := slots.Load(slot); ok {
		return name.(string)
	}
	// Fallback to pointer address for unknown slots.
	return fmt.Sprintf("%p", slot)
}
// Provider defines the function signature for a service factory.
//
// When a service is requested, its provider is called with an instance of the
// Injector, which it can then use to resolve any of its own dependencies (e.g.,
// by calling Use). How often the provider is called depends on the number of
// injection sites and the resolution strategy used when binding the provider to
// a slot. By convention, provider functions should be named "Provide<Type>".
// The associated call to di.Bind should then be done in a function named
// "Bind<Type>".
type Provider[T any] func(in *Injector) (T, error)

// binding holds a provider and its associated resolution strategy.
type binding struct {
	provider any      // the Provider[T] registered via Bind, stored untyped
	resolver Resolver // lifecycle strategy (e.g., Singleton, Transient, Scoped)
}
// config holds configuration options for an Injector.
type config struct {
	ctx context.Context
}

// Option configures an Injector.
type Option func(*config)

// WithContext sets the root context for the Injector. A nil ctx is ignored,
// leaving the default (background) context in place.
func WithContext(ctx context.Context) Option {
	return func(cfg *config) {
		if ctx == nil {
			return
		}
		cfg.ctx = ctx
	}
}
// Injector is the main dependency injection container.
// It holds all service bindings and manages their lifecycle. An Injector is
// safe for concurrent reads (e.g., using Use, Required), but is not safe for
// concurrent writes (e.g., using Bind, Override). Bindings should be configured
// once at application startup.
type Injector struct {
	ctx      context.Context  // root context; on proxies it also carries the visiting map
	bindings map[any]*binding // slot -> registered provider/resolver pair
	mu       sync.RWMutex     // guards bindings
	parent   *Injector        // non-nil only on proxies created during provider calls
}
// NewInjector creates and returns a new, empty Injector with the given options.
// Without any options, the injector defaults to context.Background().
func NewInjector(opts ...Option) *Injector {
	cfg := config{ctx: context.Background()}
	for _, apply := range opts {
		apply(&cfg)
	}
	in := &Injector{
		ctx:      cfg.ctx,
		bindings: map[any]*binding{},
	}
	return in
}
// Context returns the injector's context.
//
// This context is provided during the injector's creation via the WithContext
// option. It serves two primary purposes:
//
// 1. Propagation: It allows for the propagation of request-scoped values,
// deadlines, and cancellation signals throughout the dependency graph.
//
// 2. Scoping: It is the key mechanism for enabling scoped dependencies.
// Resolvers like Scoped() use this context to cache instances that live
// for the duration of the context's lifecycle (e.g., an HTTP request).
//
// Note that on proxy injectors created internally for provider calls, this
// context additionally carries the cycle-detection map (see provide).
func (in *Injector) Context() context.Context {
	return in.ctx
}
// Bind registers a provider and its resolver for a specific service slot.
// It is typically called during application initialization.
//
// Bind panics if the slot is already bound in the injector.
func Bind[T any](
	in *Injector,
	slot Slot[T],
	provider Provider[T],
	resolver Resolver,
) {
	in.mu.Lock()
	defer in.mu.Unlock()
	if _, bound := in.bindings[slot]; bound {
		panic(fmt.Sprintf("slot %s is already bound", Tag(slot)))
	}
	b := &binding{provider: provider, resolver: resolver}
	in.bindings[slot] = b
}
// Use resolves a service from the Injector for a given slot. It is the primary
// method for retrieving dependencies when an error is an expected outcome.
//
// It returns an error if the slot is not bound, if the provider returns an
// error, or if a circular dependency is detected. If the provider returns a
// nil value with no error, Use will return the zero value of T.
//
// Use will panic if the value returned by the provider is not assignable to T,
// which indicates a programming error (e.g., a provider returning an
// incompatible type).
func Use[T any](in *Injector, slot Slot[T]) (T, error) {
	v, err := in.Resolve(slot)
	if err != nil {
		var zero T
		return zero, err
	}
	// The provider returned a nil interface or pointer with no error.
	if v == nil {
		var zero T
		return zero, nil
	}
	// This type assertion is critical. It ensures that the value returned
	// from the non-generic resolver is of the correct type.
	if t, ok := v.(T); ok {
		return t, nil
	}
	// This panic indicates a bug in a provider implementation: it returned a
	// concrete type that does not match the slot's type.
	panic(fmt.Sprintf("provider returned %T for slot %s", v, Tag(slot)))
}
// Optional resolves a service and panics if any resolution error occurs (e.g.,
// an unbound slot or a provider error). However, unlike Required, it allows the
// provider to return a nil value without panicking. It is useful for
// dependencies that are truly optional.
func Optional[T any](in *Injector, slot Slot[T]) T {
	value, err := Use(in, slot)
	if err == nil {
		return value
	}
	panic(err)
}
// Required resolves a service and panics if an error occurs OR if the resolved
// value is nil. This should be used for critical dependencies that must always
// be present.
//
// It checks for nil-ness on interfaces, pointers, maps, slices, channels, and
// functions.
func Required[T any](in *Injector, slot Slot[T]) T {
	v := Optional(in, slot)
	val := reflect.ValueOf(v)
	// An invalid reflect.Value means v is a nil interface value (e.g., T is an
	// interface type and the provider returned nil). Previously this case
	// slipped through unnoticed, because reflect.Invalid matched no case in
	// the switch below.
	if !val.IsValid() {
		panic(fmt.Sprintf(
			"required dependency for slot %s is nil",
			Tag(slot),
		))
	}
	switch val.Kind() {
	case
		reflect.Pointer,
		reflect.Interface,
		reflect.Slice,
		reflect.Map,
		reflect.Chan,
		reflect.Func:
		if val.IsNil() {
			panic(fmt.Sprintf(
				"required dependency for slot %s is nil",
				Tag(slot),
			))
		}
	}
	return v
}
// Override registers a provider for a slot, replacing any existing binding.
// This is primarily useful in testing environments to replace production
// services with mocks.
func Override[T any](
	in *Injector,
	slot Slot[T],
	provider Provider[T],
	resolver Resolver,
) {
	b := &binding{provider: provider, resolver: resolver}
	in.mu.Lock()
	defer in.mu.Unlock()
	in.bindings[slot] = b
}
// visitingKey is the context key for the circular dependency detection map.
type visitingKey struct{}

// Resolve is a non-generic method to resolve a dependency from a slot.
// In most cases, the type-safe functions (Use, Optional, Required) should be
// preferred. Resolve is mostly useful for framework integrations that may need
// to work with slots of an unknown type.
func (in *Injector) Resolve(slot any) (any, error) {
	if in.parent == nil {
		// Top-level call: start cycle detection with a fresh map.
		return in.resolve(slot, make(map[any]bool))
	}
	// This injector is a proxy, meaning we are inside a nested resolve() call.
	// Use the parent's resolution logic with our context, which carries the
	// visiting map. The type assertion is safe because we control proxy
	// creation.
	visiting := in.ctx.Value(visitingKey{}).(map[any]bool)
	return in.parent.resolve(slot, visiting)
}
// resolve is the internal, recursive implementation for dependency resolution.
// The visiting map tracks the current resolution path to detect cycles.
func (in *Injector) resolve(slot any, visiting map[any]bool) (any, error) {
	if visiting[slot] {
		return nil, fmt.Errorf(
			"circular dependency detected while resolving slot %s",
			Tag(slot),
		)
	}
	visiting[slot] = true
	// Clean up the map on every exit path. Previously the "no provider bound"
	// return below skipped the cleanup, leaving the slot marked as
	// in-progress; a later lookup of the same slot within the same resolution
	// (e.g., after a caller handled the error) would then be misreported as a
	// circular dependency.
	defer delete(visiting, slot)
	in.mu.RLock()
	b, ok := in.bindings[slot]
	in.mu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("no provider bound for slot %s", Tag(slot))
	}
	// Delegate to the resolver (e.g., Singleton), passing the visiting map.
	return b.resolver.Resolve(in, b.provider, slot, visiting)
}
// Resolver defines a strategy for managing a service's lifecycle.
// This package ships three implementations: Singleton (create once and cache),
// Transient (create on every request), and Scoped (create once per context
// scope).
type Resolver interface {
	// Resolve provides an instance according to the strategy it implements.
	// The visiting map tracks the current resolution path to detect cycles.
	Resolve(
		in *Injector,
		provider any,
		slot any,
		visiting map[any]bool,
	) (any, error)
}
// singleton is a Resolver that caches the service instance.
type singleton struct {
	instance any       // result of the one-time provider call
	err      error     // error of the one-time provider call
	once     sync.Once // guarantees the provider runs at most once
}

// Resolve invokes the provider exactly once and returns the cached instance
// (or error) on all subsequent calls. Note that a provider failure is cached
// too and will not be retried.
func (s *singleton) Resolve(
	in *Injector,
	provider any,
	slot any,
	visiting map[any]bool,
) (any, error) {
	s.once.Do(func() {
		s.instance, s.err = provide(in, provider, slot, visiting)
	})
	return s.instance, s.err
}

// Singleton returns a Resolver that creates an instance once per injector and
// reuses it for all subsequent requests. The cache lives on the returned
// resolver value itself, so each call to Singleton() (typically one per
// binding) maintains its own instance.
func Singleton() Resolver {
	return &singleton{}
}
// transient is a Resolver that always creates a new service instance.
type transient struct{}

// Resolve invokes the provider on every call; nothing is cached.
func (transient) Resolve(
	in *Injector,
	provider any,
	slot any,
	visiting map[any]bool,
) (any, error) {
	return provide(in, provider, slot, visiting)
}
// provide is an internal helper that safely invokes a provider function.
// It recovers from panics and creates a proxy injector to propagate the
// circular dependency map.
func provide(
	in *Injector,
	provider any,
	slot any,
	visiting map[any]bool,
) (instance any, err error) {
	// Convert provider panics into errors so a faulty provider cannot crash
	// the whole resolution. Named results let the deferred function overwrite
	// the return values after the panic is recovered.
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf(
				"panic during provider call for slot %s: %v",
				Tag(slot), r,
			)
			instance = nil
		}
	}()
	// Create a proxy injector for the provider call. This proxy carries the
	// visiting map within its context. When the provider calls Use/Required,
	// the proxy's Resolve method is called, which correctly propagates the map.
	proxy := &Injector{
		parent: in,
		ctx:    context.WithValue(in.ctx, visitingKey{}, visiting),
	}
	// Use reflection to call the provider, since its concrete Provider[T]
	// type is not known here.
	val := reflect.ValueOf(provider)
	out := val.Call([]reflect.Value{reflect.ValueOf(proxy)})
	// The provider signature is func(*Injector) (T, error): out[0] is the
	// instance, out[1] the error.
	if out[1].IsNil() {
		instance = out[0].Interface()
	} else {
		err = out[1].Interface().(error)
	}
	return
}
// Transient returns a Resolver that creates a new instance of the service
// every time it is requested. The resolver is stateless, so a single value
// can be shared across bindings.
func Transient() Resolver {
	return transient{}
}
// scopedCacheKey is the context key for the scoped dependency cache.
type scopedCacheKey struct{}

// NewScope creates a new context that carries a cache for scoped dependencies.
// Call it at the beginning of an operation that defines a scope, such as a new
// HTTP request, and hand the returned context to a new or child injector via
// WithContext.
func NewScope(ctx context.Context) context.Context {
	cache := new(sync.Map)
	return context.WithValue(ctx, scopedCacheKey{}, cache)
}
// scoped is a Resolver that ties the service lifecycle to a context scope.
type scoped struct{}

// Resolve returns the instance cached in the current scope, creating and
// caching it on first use. It fails if the injector's context was not
// prepared via NewScope.
func (s scoped) Resolve(
	in *Injector,
	provider any,
	slot any,
	visiting map[any]bool,
) (any, error) {
	val := in.Context().Value(scopedCacheKey{})
	cache, ok := val.(*sync.Map)
	if !ok || cache == nil {
		return nil, fmt.Errorf(
			"no scope cache found in context for scoped slot %s",
			Tag(slot),
		)
	}
	// Check if an instance already exists in the scope's cache. sync.Map
	// reports key presence separately from the value, so even a provider that
	// legitimately returned nil is found here and not re-invoked.
	if instance, loaded := cache.Load(slot); loaded {
		return instance, nil
	}
	// If not found, create a new instance.
	instance, err := provide(in, provider, slot, visiting)
	if err != nil {
		// Do not cache the slot if the provider failed.
		return nil, err
	}
	// Store the new instance in the cache; LoadOrStore prefers a concurrently
	// stored instance so all callers within the scope observe the same one.
	actual, _ := cache.LoadOrStore(slot, instance)
	return actual, nil
}
// Scoped returns a Resolver that ties the lifecycle of a service to a
// context.Context. A new instance is created once per scope, defined by a
// call to NewScope. It requires that the injector's context was created
// via NewScope; otherwise resolution fails with an error.
func Scoped() Resolver {
	return scoped{}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package env provides functionality for unmarshaling environment variables
// into Go structs.
//
// By default, all exported fields in a struct are mapped to environment
// variables. The variable name is derived by converting the field's name to
// uppercase SNAKE_CASE (e.g., a field named APIKey maps to API_KEY).
// This behavior can be customized or disabled on a per-field basis using
// struct tags.
//
// # Usage
//
// Define a struct to hold your configuration. Only exported fields will be
// considered. The code snippet below showcases various field types and
// struct tag options:
//
// type Config struct {
// Host string `env:",required"`
// Port int `env:",default:8080"`
// Timeout time.Duration `env:",unit:s"`
// Debug bool
// Proxy ProxyConfig `env:",prefix:'HTTP_PROXY_'"`
// Roles []string `env:",split:';'"`
// Internal int `env:"-"`
// internal int
// }
//
// var cfg Config
// if err := env.Unmarshal(&cfg); err != nil {
// log.Fatalf("failed to unmarshal config: %v", err)
// }
// // Use the configuration to bootstrap your application...
//
// # Options
//
// The behavior of the unmarshaler is controlled by the env struct field tag.
// The tag is a comma-separated string of options.
//
// The first value is the name of the environment variable. If it is omitted,
// the field's name is used as the base for the variable name.
//
// DatabaseURL string `env:"MY_DATABASE_URL"`
//
// The subsequent parts of the tag are options, which can be in a key:value
// format or be boolean flags.
//
// Option "default"
//
// Sets a default value to be used if the environment variable is not set.
//
// Port int `env:",default:8080"`
//
// Option "required"
//
// Marks the variable as required. Unmarshal will return an error if the
// variable is not set and no default is provided.
//
// APIKey string `env:",required"`
//
// Option "prefix"
//
// For nested struct fields, this overrides the default prefix. By default,
// the prefix is the field's name in SNAKE_CASE followed by an underscore.
// It can be set to an empty string to omit the prefix entirely.
//
// DBConfig `env:",prefix:DB_"`
//
// Option "inline"
//
// When applied to an anonymous struct field, it flattens the struct,
// effectively treating its fields as if they belonged to the parent struct.
//
// Nested `env:",inline"`
//
// Option "split"
//
// For slice types, this specifies the delimiter to split the environment
// variable string. The default separator is a comma.
//
// Hosts []string `env:",split:';'"`
//
// Option "format"
//
// Provides a format specifier for special types. For time.Time it can
// be a Go-compliant layout string (e.g., "2006-01-02") or one of the predefined
// constants "unix", "dateTime", "date", and "time". Defaults to the RFC
// 3339 format. For []byte, it can be "hex", "base32", or "base64" to alter
// the encoding format.
//
// StartDate time.Time `env:",format:date"`
//
// Option "unit"
//
// Specifies the unit for time.Time or time.Duration when parsing from an
// integer. For time.Duration: "ns", "us" (or "μs"), "ms", "s", "m", "h".
// For time.Time (with format:unix): "s", "ms", "us" (or "μs").
//
// CacheTTL time.Duration `env:",unit:m,default:5"`
package env
import (
"encoding/base32"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"net/url"
"os"
"reflect"
"strconv"
"strings"
"time"
"github.com/deep-rent/nexus/internal/pointer"
"github.com/deep-rent/nexus/internal/snake"
"github.com/deep-rent/nexus/internal/tag"
)
// Lookup is a function that retrieves the value of an environment variable.
// It follows the signature of os.LookupEnv, returning the value and a boolean
// indicating whether the variable was present. This type allows for custom
// lookup mechanisms, such as reading from sources other than the actual
// environment, which is especially useful for testing.
type Lookup func(key string) (string, bool)

// Unmarshaler is an interface that can be implemented by types to provide their
// own custom logic for parsing an environment variable string. Both value and
// pointer receivers are supported.
type Unmarshaler interface {
	// UnmarshalEnv unmarshals the string value of an environment variable.
	// The value is the raw string from the environment or a default value.
	// It returns an error if the value cannot be parsed.
	UnmarshalEnv(value string) error
}

// Option is a function that configures the behavior of the Unmarshal and Expand
// functions. It follows the functional options pattern.
type Option func(*config)
// WithPrefix returns an Option that adds a common prefix to all environment
// variable keys looked up during unmarshaling. For example, WithPrefix("APP_")
// would cause a field with the env tag "PORT" to look for the "APP_PORT"
// variable.
func WithPrefix(prefix string) Option {
	return func(cfg *config) {
		cfg.Prefix = prefix
	}
}

// WithLookup returns an Option that sets a custom lookup function for
// retrieving environment variable values. If not customized, os.LookupEnv will
// be used by default. This is useful for testing or if you need to load
// environment variables from alternative sources. A nil lookup is ignored.
func WithLookup(lookup Lookup) Option {
	return func(cfg *config) {
		if lookup == nil {
			return
		}
		cfg.Lookup = lookup
	}
}
// Unmarshal populates the fields of a struct with values from environment
// variables. The given value v must be a non-nil pointer to a struct.
//
// By default, Unmarshal processes all exported fields. A field's environment
// variable name is derived from its name, converted to uppercase SNAKE_CASE.
// To ignore a field, tag it with `env:"-"`. Unexported fields are always
// excluded. If a variable is not set, the field remains unchanged unless a
// default value is specified in the struct tag, or it is marked as required.
func Unmarshal(v any, opts ...Option) error {
	err := unmarshal(v, opts...)
	if err != nil {
		// Prefix all errors with the package name for easier attribution.
		return fmt.Errorf("env: %w", err)
	}
	return nil
}
// Expand substitutes environment variables in a string.
//
// It replaces references to environment variables in the formats ${KEY} or $KEY
// with their corresponding values. A literal dollar sign can be escaped with $$.
// If a referenced variable is not found in the environment, the function
// returns an error. Its behavior can be adjusted through functional options.
func Expand(s string, opts ...Option) (string, error) {
	cfg := config{
		Lookup: os.LookupEnv,
	}
	for _, opt := range opts {
		opt(&cfg)
	}
	var b strings.Builder
	b.Grow(len(s)) // the result is usually about as long as the input
	i := 0
	for i < len(s) {
		// Find the next dollar sign.
		start := strings.IndexByte(s[i:], '$')
		if start == -1 {
			// No more variables; copy the remainder verbatim.
			b.WriteString(s[i:])
			break
		}
		// Append the text before the dollar sign.
		b.WriteString(s[i : i+start])
		// Move our main index to the location of the dollar sign.
		i += start
		// Check what follows the dollar sign.
		switch {
		case i+1 < len(s) && s[i+1] == '$':
			// Case 1: Escaped dollar sign ($$) emits one literal '$'.
			b.WriteByte('$')
			i += 2 // Skip both signs.
		case i+1 < len(s) && s[i+1] == '{':
			// Case 2: Bracketed variable expansion (${KEY}).
			end := strings.IndexByte(s[i+2:], '}')
			if end == -1 {
				return "", errors.New("env: variable bracket not closed")
			}
			// Extract the variable name; any configured prefix is prepended
			// before lookup.
			key := cfg.Prefix + s[i+2:i+2+end]
			val, ok := cfg.Lookup(key)
			if !ok {
				return "", fmt.Errorf("env: variable %q is not set", key)
			}
			b.WriteString(val)
			// Move the index past the processed variable `${KEY}`.
			i += 2 + end + 1
		default:
			// Case 3: Standard variable expansion ($KEY) or lone dollar sign. Scan
			// ahead for valid identifier characters. The first character must be
			// a letter or underscore; subsequent characters can include digits.
			n := 0
			for j := i + 1; j < len(s); j++ {
				c := s[j]
				// Allow if it's a letter/underscore, OR if it's a digit but NOT the
				// first character.
				if (c == '_' || ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z')) ||
					(n > 0 && ('0' <= c && c <= '9')) {
					n++
				} else {
					break
				}
			}
			if n == 0 {
				// No valid identifier characters found (e.g., "$5", "$!").
				// Treat as a literal dollar sign.
				b.WriteByte('$')
				i++
			} else {
				// Extract the unbracketed variable name.
				key := cfg.Prefix + s[i+1:i+1+n]
				val, ok := cfg.Lookup(key)
				if !ok {
					return "", fmt.Errorf("env: variable %q is not set", key)
				}
				b.WriteString(val)
				// Move the index past the processed variable `$KEY`.
				i += 1 + n
			}
		}
	}
	return b.String(), nil
}
// flags encapsulates the options parsed from an `env` struct tag.
type flags struct {
	Name     string  // Name of the environment variable.
	Prefix   *string // Optional prefix override for nested structs (nil = derive from field name).
	Split    string  // Delimiter for slice types (defaults to ",").
	Unit     string  // Unit for time.Time or time.Duration.
	Format   string  // Format specifier for special types.
	Default  string  // Fallback value if the variable is not found.
	Inline   bool    // Whether to inline an anonymous struct field.
	Required bool    // Whether the variable is required.
}

// config holds the effective settings for a single Unmarshal or Expand call.
type config struct {
	Prefix string // Common prefix for all environment variable keys.
	Lookup Lookup // Injectable callback for variable lookup.
}

// Cache types with special unmarshaling logic, so the per-field dispatch in
// setValue is a cheap type comparison instead of repeated reflection.
var (
	typeTime        = reflect.TypeFor[time.Time]()
	typeDuration    = reflect.TypeFor[time.Duration]()
	typeLocation    = reflect.TypeFor[time.Location]()
	typeURL         = reflect.TypeFor[url.URL]()
	typeUnmarshaler = reflect.TypeFor[Unmarshaler]()
)
// unmarshal validates the target value, applies the options, and starts the
// recursive field walk.
func unmarshal(v any, opts ...Option) error {
	root := reflect.ValueOf(v)
	if root.Kind() != reflect.Pointer || root.IsNil() {
		return errors.New("expected a non-nil pointer to a struct")
	}
	target := root.Elem()
	if kind := target.Kind(); kind != reflect.Struct {
		return fmt.Errorf("expected a pointer to a struct, but got pointer to %v", kind)
	}
	cfg := config{Lookup: os.LookupEnv}
	for _, opt := range opts {
		opt(&cfg)
	}
	return process(target, cfg.Prefix, cfg.Lookup)
}
// process recursively walks the struct fields of rv. For each field it derives
// the environment variable key (prefix + tag name or SNAKE_CASE field name),
// looks up the value, and assigns it; nested structs recurse with an extended
// prefix.
func process(rv reflect.Value, prefix string, lookup Lookup) error {
	rt := rv.Type()
	for i := 0; i < rt.NumField(); i++ {
		ft := rt.Field(i)
		fv := rv.Field(i)
		// Skip unexported or otherwise unsettable fields.
		if !ft.IsExported() || !fv.CanSet() {
			continue
		}
		tag := ft.Tag.Get("env")
		if tag == "-" {
			// Field explicitly excluded via `env:"-"`.
			continue
		}
		opts, err := parse(tag)
		if err != nil {
			return fmt.Errorf("failed to parse tag for field %q: %w", ft.Name, err)
		}
		if ft.Anonymous && opts.Inline {
			// Inlined anonymous struct: treat its fields as if they belonged
			// to the parent struct, keeping the current prefix.
			// Dereference and allocate in case the inline field is a pointer.
			embedded := pointer.Deref(fv)
			if err := process(embedded, prefix, lookup); err != nil {
				return err
			}
			continue
		}
		key := opts.Name
		if key == "" {
			// No explicit name in the tag: derive it from the field name,
			// e.g. "APIKey" -> "API_KEY".
			key = snake.ToUpper(ft.Name)
		}
		// Check for true embedded structs.
		if isEmbedded(ft, fv) {
			nested := prefix
			if opts.Prefix != nil {
				// Tag-provided prefix override (may be empty to omit it).
				nested += *opts.Prefix
			} else {
				// Default: field name in SNAKE_CASE plus an underscore.
				nested += key + "_"
			}
			// Dereference and allocate in case the field is a pointer.
			embedded := pointer.Deref(fv)
			if err := process(embedded, nested, lookup); err != nil {
				return err
			}
			continue
		}
		key = prefix + key
		val, ok := lookup(key)
		// A variable is "set" even if it is empty. We only trigger the strictly
		// missing variable logic if 'ok' is false.
		if !ok {
			if opts.Default != "" {
				val = opts.Default
			} else if opts.Required {
				return fmt.Errorf("required variable %q is not set", key)
			} else {
				// Not set, no default, not required: leave the field as-is.
				continue
			}
		} else if val == "" && opts.Default != "" {
			// If the variable is explicitly set to empty (""), but a default
			// exists in the tags, fall back to the default value.
			val = opts.Default
		}
		// If a field is required and set to "", it bypasses the errors above.
		// For strings, setValue will correctly assign the empty string. For types
		// like int or bool, setValue will return a natural parsing error.
		if err := setValue(fv, val, opts); err != nil {
			return fmt.Errorf(
				"error setting field %q from variable %q: %w",
				ft.Name, key, err,
			)
		}
	}
	return nil
}
// setValue assigns the string v to the field rv, dispatching on the field's
// type. A custom Unmarshaler always takes precedence; the special types
// time.Time, time.Duration, time.Location, and url.URL have dedicated parsers,
// and everything else falls through to setOther.
func setValue(rv reflect.Value, v string, f *flags) error {
	if u, ok := asUnmarshaler(rv); ok {
		return u.UnmarshalEnv(v)
	}
	rv = pointer.Deref(rv)
	switch t := rv.Type(); t {
	case typeTime:
		return setTime(rv, v, f)
	case typeDuration:
		return setDuration(rv, v, f)
	case typeLocation:
		return setLocation(rv, v)
	case typeURL:
		return setURL(rv, v)
	}
	return setOther(rv, v, f)
}
// setOther handles all "regular" (primitive and slice) types by delegating to
// the appropriate parsing logic based on the reflective kind. Slices go to
// setSlice; every other supported kind is parsed from v and assigned directly.
func setOther(rv reflect.Value, v string, f *flags) error {
	kind := rv.Kind()
	switch kind {
	case reflect.Slice:
		return setSlice(rv, v, f)
	case reflect.String:
		rv.SetString(v)
	case reflect.Bool:
		parsed, err := strconv.ParseBool(v)
		if err != nil {
			return fmt.Errorf("%q is not a bool", v)
		}
		rv.SetBool(parsed)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		bits := rv.Type().Bits()
		n, err := strconv.ParseInt(v, 10, bits)
		if err != nil {
			return fmt.Errorf("%q is not an int%d", v, bits)
		}
		rv.SetInt(n)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		bits := rv.Type().Bits()
		n, err := strconv.ParseUint(v, 10, bits)
		if err != nil {
			return fmt.Errorf("%q is not a uint%d", v, bits)
		}
		rv.SetUint(n)
	case reflect.Float32, reflect.Float64:
		bits := rv.Type().Bits()
		n, err := strconv.ParseFloat(v, bits)
		if err != nil {
			return fmt.Errorf("%q is not a float%d", v, bits)
		}
		rv.SetFloat(n)
	case reflect.Complex64, reflect.Complex128:
		bits := rv.Type().Bits()
		n, err := strconv.ParseComplex(v, bits)
		if err != nil {
			return fmt.Errorf("%q is not a complex%d", v, bits)
		}
		rv.SetComplex(n)
	default:
		return fmt.Errorf("unsupported type: %s", kind)
	}
	return nil
}
// setTime parses and sets a time.Time value based on the provided format and
// unit options. An empty format defaults to RFC 3339.
func setTime(rv reflect.Value, v string, f *flags) error {
	var t time.Time
	var err error
	switch format := f.Format; format {
	case "unix":
		// Integer timestamp; the unit option selects its resolution.
		var i int64
		if i, err = strconv.ParseInt(v, 10, 64); err == nil {
			switch unit := f.Unit; unit {
			case "s", "":
				t = time.Unix(i, 0)
			case "ms":
				t = time.UnixMilli(i)
			case "us", "μs":
				t = time.UnixMicro(i)
			default:
				err = fmt.Errorf("invalid time unit: %q", unit)
			}
		}
	case "dateTime":
		t, err = time.Parse(time.DateTime, v)
	case "date":
		t, err = time.Parse(time.DateOnly, v)
	case "time":
		t, err = time.Parse(time.TimeOnly, v)
	default:
		// Any other value is treated as a Go layout string; empty means
		// RFC 3339.
		if format == "" {
			format = time.RFC3339
		}
		t, err = time.Parse(format, v)
	}
	if err != nil {
		return err
	}
	rv.Set(reflect.ValueOf(t))
	return nil
}
// setDuration parses and sets a time.Duration value based on the provided unit
// option. Without a unit, v must be a Go duration string (e.g. "1h30m"); with
// a unit, v is an integer count of that unit.
func setDuration(rv reflect.Value, v string, f *flags) error {
	unit := f.Unit
	if unit == "" {
		d, err := time.ParseDuration(v)
		if err != nil {
			return err
		}
		rv.SetInt(int64(d))
		return nil
	}
	i, err := strconv.ParseInt(v, 10, 64)
	if err != nil {
		return err
	}
	var scale time.Duration
	switch unit {
	case "ns":
		scale = time.Nanosecond
	case "us", "μs":
		scale = time.Microsecond
	case "ms":
		scale = time.Millisecond
	case "s":
		scale = time.Second
	case "m":
		scale = time.Minute
	case "h":
		scale = time.Hour
	default:
		return fmt.Errorf("invalid duration unit: %q", unit)
	}
	rv.SetInt(i * int64(scale))
	return nil
}
// setLocation parses and sets a time.Location value from an IANA zone name.
func setLocation(rv reflect.Value, v string) error {
	loc, err := time.LoadLocation(v)
	if err == nil {
		rv.Set(reflect.ValueOf(*loc))
	}
	return err
}
// setURL parses and sets a url.URL value.
func setURL(rv reflect.Value, v string) error {
	u, err := url.Parse(v)
	if err == nil {
		rv.Set(reflect.ValueOf(*u))
	}
	return err
}
// setSlice parses and sets a slice value. []byte receives special treatment
// with optional encoding formats; all other slice types are split on the
// configured delimiter and parsed element-wise.
func setSlice(rv reflect.Value, v string, f *flags) error {
	if rv.Type().Elem().Kind() == reflect.Uint8 {
		var b []byte
		var err error
		switch f.Format {
		case "":
			b = []byte(v)
		case "hex":
			b, err = hex.DecodeString(v)
		case "base32":
			b, err = base32.StdEncoding.DecodeString(v)
		case "base64":
			b, err = base64.StdEncoding.DecodeString(v)
		default:
			err = fmt.Errorf("unsupported format for []byte: %q", f.Format)
		}
		if err != nil {
			return err
		}
		rv.SetBytes(b)
		return nil
	}
	parts := strings.Split(v, f.Split)
	if len(parts) == 1 && parts[0] == "" {
		// An empty input yields an empty (non-nil) slice.
		rv.Set(reflect.MakeSlice(rv.Type(), 0, 0))
		return nil
	}
	out := reflect.MakeSlice(rv.Type(), len(parts), len(parts))
	for idx, part := range parts {
		if err := setValue(out.Index(idx), part, f); err != nil {
			return fmt.Errorf("failed to parse slice element at index %d: %w", idx, err)
		}
	}
	rv.Set(out)
	return nil
}
// parse parses the `env` tag string. It supports quoted values for options
// to allow commas within them, e.g., `default:'a,b,c'`. Each option may
// appear at most once; unknown options are rejected.
func parse(s string) (*flags, error) {
	parsed := tag.Parse(s)
	f := &flags{Name: parsed.Name, Split: ","}
	seen := map[string]bool{}
	for key, val := range parsed.Opts() {
		if seen[key] {
			return nil, fmt.Errorf("duplicate option: %q", key)
		}
		seen[key] = true
		switch key {
		case "default":
			f.Default = val
		case "format":
			f.Format = val
		case "prefix":
			f.Prefix = &val
		case "split":
			f.Split = val
		case "unit":
			f.Unit = val
		case "inline":
			f.Inline = true
		case "required":
			f.Required = true
		default:
			return nil, fmt.Errorf("unknown option: %q", key)
		}
	}
	return f, nil
}
// isEmbedded reports whether a struct field is a plain nested struct that
// should be processed recursively, as opposed to one parsed from a single
// variable.
func isEmbedded(f reflect.StructField, rv reflect.Value) bool {
	t := f.Type
	// Unwrap pointer(s) to inspect the underlying type.
	for t.Kind() == reflect.Pointer {
		t = t.Elem()
	}
	switch {
	case t.Kind() != reflect.Struct:
		// Not a struct at all.
		return false
	case t == typeTime, t == typeURL, t == typeLocation:
		// Special struct types that are handled directly.
		return false
	}
	// Structs with a custom Unmarshaler are parsed as a whole, not recursed.
	if _, ok := asUnmarshaler(rv); ok {
		return false
	}
	return true
}
// asUnmarshaler checks if the given reflect.Value implements the Unmarshaler
// interface, either directly or via a pointer receiver. On success it returns
// the type-casted Unmarshaler and true; otherwise nil and false.
func asUnmarshaler(rv reflect.Value) (Unmarshaler, bool) {
	// Case 1: The field's type implements Unmarshaler directly. This covers
	// pointer types and value types with value receivers.
	if rv.Type().Implements(typeUnmarshaler) {
		if rv.Kind() == reflect.Pointer && rv.IsNil() {
			// A nil pointer must be allocated first, otherwise calling the
			// interface method would hit a nil receiver.
			pointer.Alloc(rv)
		}
		return rv.Interface().(Unmarshaler), true
	}
	// Case 2: A pointer to the field's type implements Unmarshaler. This
	// covers value types with pointer receivers.
	if rv.CanAddr() {
		if addr := rv.Addr(); addr.Type().Implements(typeUnmarshaler) {
			return addr.Interface().(Unmarshaler), true
		}
	}
	return nil, false
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package header provides a collection of utility functions for parsing,
// interpreting, and manipulating HTTP headers.
//
// The package includes helpers for common header-related tasks, such as:
// - Parsing comma-separated directives (e.g., "max-age=3600").
// - Parsing wildcard-aware content negotiation headers with q-factors.
// - Parsing RFC 5988 Link headers to extract relations for API pagination.
// - Extracting credentials from an Authorization header.
// - Calculating cache lifetime from Cache-Control and Expires headers.
// - Determining throttle delays from Retry-After and X-Ratelimit-* headers.
//
// It also provides a convenient http.RoundTripper implementation for
// automatically attaching a static set of headers to all outgoing requests.
package header
import (
"iter"
"mime"
"net/http"
"strconv"
"strings"
"time"
)
// Directives parses a comma-separated header value into an iterator of
// key-value pairs.
//
// For example, parsing "no-cache, max-age=3600" would yield twice: first
// "no-cache", "" and then "max-age", "3600".
func Directives(s string) iter.Seq2[string, string] {
return func(yield func(string, string) bool) {
for kv := range strings.SplitSeq(s, ",") {
k, v, ok := strings.Cut(strings.TrimSpace(kv), "=")
k = strings.ToLower(strings.TrimSpace(k))
if ok {
v = strings.TrimSpace(v)
}
if !yield(k, v) {
return
}
}
}
}
// Throttle determines the required delay before the next request based on
// rate-limiting headers in the response. It accepts a clock function to
// calculate relative times. If no throttling is indicated, it returns a
// duration of 0.
func Throttle(h http.Header, now func() time.Time) time.Duration {
if v := h.Get("Retry-After"); v != "" {
if d, err := strconv.ParseInt(v, 10, 64); err == nil && d > 0 {
return time.Duration(d) * time.Second
}
if t, err := http.ParseTime(v); err == nil {
if d := t.Sub(now()); d > 0 {
return d
}
}
}
if h.Get("X-Ratelimit-Remaining") == "0" {
if v := h.Get("X-Ratelimit-Reset"); v != "" {
if t, err := strconv.ParseInt(v, 10, 64); err == nil && t > 0 {
if d := time.Unix(t, 0).Sub(now()); d > 0 {
return d
}
}
}
}
return 0
}
// Lifetime determines the cache lifetime of a response based on caching
// headers. It accepts a clock function to calculate relative times. It returns
// a duration of 0 if the response is not cacheable or does not carry any
// caching information.
func Lifetime(h http.Header, now func() time.Time) time.Duration {
// Cache-Control takes precedence over Expires
if v := h.Get("Cache-Control"); v != "" {
for k, v := range Directives(v) {
switch k {
case "no-cache", "no-store":
return 0
case "max-age":
if d, err := strconv.ParseInt(v, 10, 64); err == nil {
return time.Duration(d) * time.Second
}
}
}
}
if v := h.Get("Expires"); v != "" {
if t, err := http.ParseTime(v); err == nil {
if d := t.Sub(now()); d > 0 {
return d
}
}
}
return 0
}
// Credentials extracts the credentials from the Authorization header of an
// HTTP request for a specific authentication scheme (e.g., "Basic", "Bearer").
//
// It returns the raw credentials as-is, or an empty string if the header is
// not present, not well-formed, or does not match the specified scheme. The
// scheme comparison is case-insensitive.
func Credentials(h http.Header, scheme string) string {
auth := h.Get("Authorization")
if auth == "" {
return ""
}
prefix, credentials, ok := strings.Cut(auth, " ")
if !ok || !strings.EqualFold(prefix, scheme) {
return ""
}
return credentials
}
// Preferences parses a header value with quality factors (e.g., Accept,
// Accept-Encoding, Accept-Language) into an iterator quality factors (q-value)
// by name (media range). The values are yielded in the order they appear in the
// header, not sorted by quality. Values without an explicit q-factor are
// assigned a default quality of 1.0. Malformed q-factors are also treated as
// 1.0, while out-of-range values are clamped into the [0.0, 1.0] interval.
func Preferences(s string) iter.Seq2[string, float64] {
return func(yield func(string, float64) bool) {
for part := range strings.SplitSeq(s, ",") {
part = strings.TrimSpace(part)
if part == "" {
continue
}
params := strings.Split(part, ";")
var q = 1.0
for i := 1; i < len(params); i++ {
p := strings.TrimSpace(params[i])
k, v, found := strings.Cut(p, "=")
if found && strings.TrimSpace(k) == "q" {
v = strings.TrimSpace(v)
if f, err := strconv.ParseFloat(v, 64); err == nil {
q = min(1.0, max(0.0, f))
}
break
}
}
if !yield(strings.TrimSpace(params[0]), q) {
return
}
}
}
}
// Accepts checks if the given key is accepted based on a header value with
// quality factors (e.g., Accept, Accept-Encoding, or Accept-Language).
// It properly weights exact matches over partial wildcards (e.g., "text/*")
// and global wildcards ("*/*" or "*"), returning true if the best match has
// a q-value greater than zero.
func Accepts(s, key string) bool {
var (
maxQ float64
maxP int
)
// Extract the major type (e.g., "text" from "text/html") for partial
// wildcards.
major, _, hasSlash := strings.Cut(key, "/")
for k, q := range Preferences(s) {
var p int
if k == key {
p = 3 // Exact match (highest precedence)
} else if hasSlash && k == major+"/*" {
p = 2 // Partial wildcard match (e.g., "text/*")
} else if k == "*/*" || k == "*" {
p = 1 // Global wildcard match
}
// Update if we found a more specific match than our current best.
if p > maxP {
maxP = p
maxQ = q
}
}
// It is accepted if we found a valid match and its q-value is greater than 0.
return maxP > 0 && maxQ > 0
}
// MediaType extracts and returns the media type from a Content-Type header.
// It returns the media type in lowercase, trimmed of whitespace. If the header
// is empty or malformed, it returns an empty string.
//
// This function is similar to mime.ParseMediaType but does not return any
// parameters and ignores parsing errors.
func MediaType(h http.Header) string {
v := h.Get("Content-Type")
if v == "" {
return ""
}
i := strings.IndexByte(v, ';')
if i != -1 {
v = v[:i]
}
return strings.ToLower(strings.TrimSpace(v))
}
// Links parses an RFC 5988 Link header into an iterator of relation types (rel)
// and their corresponding URLs.
//
// If a link has multiple space-separated relations (e.g., rel="next archive"),
// it yields the URL for each relation separately.
func Links(s string) iter.Seq2[string, string] {
return func(yield func(string, string) bool) {
for part := range strings.SplitSeq(s, ",") {
s := strings.IndexByte(part, '<')
e := strings.IndexByte(part, '>')
// Ensure the URL brackets are present and valid.
if s == -1 || e == -1 || s >= e {
continue
}
url := part[s+1 : e]
// Parse the parameters following the URL.
params := strings.SplitSeq(part[e+1:], ";")
for p := range params {
p = strings.TrimSpace(p)
k, v, found := strings.Cut(p, "=")
if found && strings.ToLower(strings.TrimSpace(k)) == "rel" {
// Remove optional quotes around the relation value.
v = strings.Trim(strings.TrimSpace(v), `"`)
// A single link can have multiple relation types.
for rel := range strings.FieldsSeq(v) {
if !yield(strings.ToLower(rel), url) {
return
}
}
}
}
}
}
}
// Link extracts the URL for a specific relation (e.g., "next" or "last") from
// a Link header. It returns an empty string if the relation is not found.
func Link(s, rel string) string {
rel = strings.ToLower(rel)
for k, v := range Links(s) {
if k == rel {
return v
}
}
return ""
}
// Filename extracts the intended filename from a Content-Disposition header.
//
// It automatically handles both the standard "filename" parameter and the
// RFC 6266 "filename*" parameter, which is used for non-ASCII (UTF-8) names.
// It returns an empty string if the header is missing, malformed, or does
// not contain a filename.
func Filename(h http.Header) string {
v := h.Get("Content-Disposition")
if v == "" {
return ""
}
_, params, err := mime.ParseMediaType(v)
if err != nil {
return ""
}
// The filename* parameter is decoded automatically.
return params["filename"]
}
// Header represents a single HTTP header key-value pair.
type Header struct {
	Key   string // Key is the canonicalized header name.
	Value string // Value is the raw value of the header.
}

// String formats the header as "Key: Value".
func (h Header) String() string {
	var b strings.Builder
	b.Grow(len(h.Key) + len(h.Value) + 2)
	b.WriteString(h.Key)
	b.WriteString(": ")
	b.WriteString(h.Value)
	return b.String()
}

// New creates a new Header with the given key and value. The key is
// automatically canonicalized to the standard HTTP header format.
func New(key, value string) Header {
	return Header{Key: http.CanonicalHeaderKey(key), Value: value}
}

// UserAgent constructs a User-Agent header with the specified name, version,
// and an optional comment. The resulting value follows the format "name/version
// (comment)". The first part is the product token, while the parenthesized
// section provides supplementary information about the client. For external
// calls, it is best practice to include maintainer contact details in the
// comment (such as an URL or email address).
func UserAgent(name, version, comment string) Header {
	var b strings.Builder
	b.WriteString(name)
	b.WriteByte('/')
	b.WriteString(version)
	if comment != "" {
		b.WriteString(" (")
		b.WriteString(comment)
		b.WriteByte(')')
	}
	return Header{Key: "User-Agent", Value: b.String()}
}
// transport decorates a base http.RoundTripper by applying a fixed set of
// headers to every outgoing request.
type transport struct {
	wrapped http.RoundTripper // the underlying transport doing the real work
	headers []Header          // headers applied to each request
}

// RoundTrip implements http.RoundTripper. It sets the configured headers on
// a clone of the request, leaving the caller's request untouched, and then
// delegates to the wrapped transport.
func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
	out := req.Clone(req.Context())
	for _, h := range t.headers {
		out.Header.Set(h.Key, h.Value)
	}
	return t.wrapped.RoundTrip(out)
}

// Ensure transport satisfies the http.RoundTripper interface.
var _ http.RoundTripper = (*transport)(nil)
// NewTransport wraps a base transport and sets a static set of headers on
// each outgoing request. If no headers are provided, the base transport is
// returned unmodified. The function stores a defensive copy of the provided
// headers, so later mutation of the caller's slice has no effect. The
// resulting transport clones the request before delegating to the base
// transport, so the original request is not changed.
func NewTransport(
	t http.RoundTripper,
	headers ...Header,
) http.RoundTripper {
	if len(headers) == 0 {
		return t
	}
	return &transport{
		wrapped: t,
		// Copy the variadic slice: the documented contract promises a
		// defensive copy, and callers could otherwise alias and mutate
		// the backing array after the transport is created.
		headers: append([]Header(nil), headers...),
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package buffer provides a sync.Pool-backed implementation of
// httputil.BufferPool for reusing byte slices, aiming to save memory and
// reduce GC pressure when dealing with large response bodies.
package buffer
import (
"net/http/httputil"
"sync"
)
// Pool implements httputil.BufferPool backed by a sync.Pool internally.
// It reduces allocations for large response bodies by reusing byte slices,
// thus lowering GC pressure.
type Pool struct {
	pool sync.Pool // cache of *[]byte buffers
	size int       // maximum capacity retained on Put
}

// NewPool creates a new Pool that returns buffers of at least minSize
// bytes. Buffers that grow beyond maxSize will be discarded.
//
// Both numbers must be positive, or else the function panics; minSize will be
// clamped by maxSize.
func NewPool(minSize, maxSize int) *Pool {
	switch {
	case minSize <= 0:
		panic("buffer: minSize must be positive")
	case maxSize <= 0:
		panic("buffer: maxSize must be positive")
	}
	if minSize > maxSize {
		minSize = maxSize
	}
	return &Pool{
		pool: sync.Pool{
			// Store a pointer to the slice header so that placing it in the
			// interface-typed pool does not allocate.
			New: func() any {
				b := make([]byte, minSize)
				return &b
			},
		},
		size: maxSize,
	}
}

// Get returns a reusable buffer slice.
func (b *Pool) Get() []byte {
	//nolint:errcheck // The type assertion is guaranteed to succeed.
	return *b.pool.Get().(*[]byte)
}

// Put returns the buffer to the pool unless it grew beyond the size limit.
func (b *Pool) Put(buf []byte) {
	// Dropping oversized buffers keeps the pool's memory footprint bounded.
	if cap(buf) <= b.size {
		b.pool.Put(&buf)
	}
}

// Ensure Pool satisfies the httputil.BufferPool interface.
var _ httputil.BufferPool = (*Pool)(nil)
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package jitter provides functionality for adding random variation (jitter)
// to time durations.
//
// This package is designed to help distributed systems avoid "thundering herd"
// problems by desynchronizing retry attempts or periodic jobs.
//
// The jitter implementation is "subtractive". It calculates a duration
// randomly chosen between [d * (1 - p), d], where p is the jitter percentage.
// This ensures that the returned duration never exceeds the input duration,
// allowing strict adherence to maximum delay limits (e.g., in backoff
// strategies).
package jitter
import (
"math/rand/v2"
"time"
)
// Rand serves as a minimal facade over rand.Rand to ease mocking.
type Rand interface {
// Float64 generates a pseudo-random number in [0.0, 1.0).
Float64() float64
}
// Ensure compliance with parent interface.
var _ Rand = (*rand.Rand)(nil)
// seeded is a pre-seeded Rand instance for default use.
// Note: Go 1.20+ auto-seeds the global RNG, which spares us from time-based
// seeding.
var seeded Rand = rand.New(rand.NewPCG(rand.Uint64(), rand.Uint64()))
// Jitter applies subtractive random jitter to a duration.
type Jitter struct {
p float64 // jitter percentage
r Rand // random number generator
}
// New creates a new Jitter instance with the given percentage p (0.0 to 1.0)
// and source of randomness r.
//
// If r is nil, a default thread-safe, seeded generator is used.
func New(p float64, r Rand) *Jitter {
if r == nil {
r = seeded // Fallback
}
return &Jitter{
r: r,
p: p,
}
}
// Apply returns the duration d damped by a random amount based on the
// jitter percentage.
//
// The result is guaranteed to be in the range [Floor(d), d].
func (j *Jitter) Apply(d time.Duration) time.Duration {
return j.Floor(d, j.r.Float64())
}
// Floor returns the minimum possible duration that Apply could return for
// the given input d.
//
// This is equivalent to applying maximum jitter (factor = 1.0).
func (j *Jitter) Floor(d time.Duration, f float64) time.Duration {
return time.Duration(float64(d) * (1 - f*j.p))
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package pointer provides reflection-based helpers for working with pointers.
// These functions are useful for dynamically allocating and dereferencing
// variables in contexts like configuration loading or data mapping.
package pointer
import "reflect"
// Alloc allocates a new value for a nil pointer and sets the pointer to it.
//
// This function causes a panic if rv is not a settable pointer.
func Alloc(rv reflect.Value) {
rv.Set(reflect.New(rv.Type().Elem()))
}
// Deref follows pointers until it reaches a non-pointer, allocating if nil.
//
// If a nil pointer is encountered along the way, Deref will attempt to allocate
// a new value for it. If it encounters an un-settable nil pointer (e.g., one
// within an unexported struct field), it stops and returns that pointer to
// prevent a panic. The final, non-pointer value is returned.
func Deref(rv reflect.Value) reflect.Value {
// Loop through multi-level pointers to handle cases like **int.
for rv.Kind() == reflect.Pointer {
if rv.IsNil() {
// If the pointer is nil but cannot be set, we must stop
// here to avoid a panic.
if !rv.CanSet() {
break
}
Alloc(rv)
}
rv = rv.Elem()
}
return rv
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package quote provides utility functions for working with quoted strings.
package quote
// Remove strips a single layer of surrounding single or double quotes from a
// string. If the string is not quoted or too short, it is returned unchanged.
func Remove(s string) string {
	// A quoted string needs at least the two quote characters.
	if len(s) >= 2 {
		first, last := s[0], s[len(s)-1]
		// Only a matching pair of the same quote style is stripped.
		if first == last && (first == '"' || first == '\'') {
			return s[1 : len(s)-1]
		}
	}
	// No matching quotes: hand back the input untouched.
	return s
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package rotor provides a thread-safe, generic type for rotating
// through a slice of items in a round-robin fashion.
package rotor
import "sync/atomic"
// Rotor provides thread-safe round-robin access to a slice of items.
// It must be initialized with the New function.
type Rotor[E any] interface {
	// Next returns the next item in the rotation.
	// This method is safe for concurrent use by multiple goroutines.
	Next() E
}

// singleton is a Rotor that contains only a single item.
type singleton[E any] struct{ item E }

// Next implements the Rotor interface by always returning the sole item.
func (s *singleton[E]) Next() E { return s.item }

// rotor is a generic implementation of the Rotor interface.
type rotor[E any] struct {
	items []E           // immutable snapshot of the rotation
	index atomic.Uint64 // position of the next item, always in [0, len(items))
}

// New creates a new Rotor.
// It makes a defensive copy of the provided items slice to ensure immutability.
// This function panics if the items slice is empty.
func New[E any](items []E) Rotor[E] {
	switch len(items) {
	case 0:
		panic("rotor: items slice must not be empty")
	case 1:
		// A single item needs no atomic bookkeeping.
		return &singleton[E]{item: items[0]}
	}
	clone := make([]E, len(items))
	copy(clone, items)
	return &rotor[E]{items: clone}
}

// Next implements the Rotor interface.
func (r *rotor[E]) Next() E {
	size := uint64(len(r.items))
	for {
		// The CAS loop keeps the counter within [0, size), so it can never
		// drift out of range even after the counter wraps many times.
		current := r.index.Load()
		if r.index.CompareAndSwap(current, (current+1)%size) {
			return r.items[current]
		}
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package snake provides functions for converting strings between camelCase
// and snake_case formats.
package snake
import (
"strings"
"unicode"
)
// ToUpper converts a camelCase string to an uppercase SNAKE_CASE string.
//
// For example, "fooBar" is converted to "FOO_BAR", and so is "FOOBar". A
// digit following a lowercase letter also starts a new word, so "foo1"
// becomes "FOO_1".
func ToUpper(s string) string { return transform(s, unicode.ToUpper) }

// ToLower converts a camelCase string to a lowercase snake_case string.
//
// For example, "fooBar" is converted to "foo_bar", and so is "FOOBar". A
// digit following a lowercase letter also starts a new word, so "Foo1"
// becomes "foo_1".
func ToLower(s string) string { return transform(s, unicode.ToLower) }

// transform is a helper function that performs the actual text conversion.
// It walks the runes of s, inserting an underscore at each word boundary
// and mapping every rune through toCase.
func transform(s string, toCase func(rune) rune) string {
	var b strings.Builder
	b.Grow(len(s) + 5)
	runes := []rune(s)
	for i, r := range runes {
		// Insert an underscore before a capital letter or digit.
		if i != 0 {
			q := runes[i-1]
			if (unicode.IsLower(q) &&
				// Case 1: Lowercase to uppercase/digit transition ("myVar", "myVar1").
				(unicode.IsUpper(r) || unicode.IsDigit(r))) ||
				(unicode.IsUpper(q) &&
					// Case 2: Acronym to new word transition ("MYVar").
					unicode.IsUpper(r) &&
					i+1 < len(runes) &&
					unicode.IsLower(runes[i+1])) {
				b.WriteRune('_')
			}
		}
		b.WriteRune(toCase(r))
	}
	return b.String()
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package tag provides utility for parsing Go struct tags that follow a
// comma-separated key-value option format, similar to the `json` tag known from
// the standard library.
package tag
import (
"iter"
"strings"
"unicode"
"github.com/deep-rent/nexus/internal/quote"
)
// Tag represents a parsed struct tag, separating the primary name from the
// additional options.
type Tag struct {
	Name string // Name is the primary (first, comma-delimited) segment of the tag.
	opts string // opts holds the raw, unparsed option list following the name.
}
// Opts returns an iterator sequence over the tag's options.
//
// Each element yielded is a key-value pair. If an option does not have an
// explicit value (e.g., "omitempty"), the value string will be empty. Keys and
// values are trimmed of surrounding whitespace. Values that were quoted in the
// source string (e.g., `key:"value"`) will have the quotes removed. Commas
// inside quoted values are preserved and not treated as option separators
// (e.g., `key:"val1,val2"` is one option).
func (t *Tag) Opts() iter.Seq2[string, string] {
	return func(yield func(string, string) bool) {
		rest := t.opts
		// Scan through the rest of the string until it's completely consumed.
		for rest != "" {
			// Trim leading space from the rest of the string.
			rest = strings.TrimLeftFunc(rest, unicode.IsSpace)
			if rest == "" {
				break
			}
			// Find the end of the current option part by finding the next
			// comma that is not inside quotes.
			end := -1
			inQuote := false
			// q remembers which quote rune ('\'' or '"') opened the current
			// quoted run; it is 0 outside a run, which never matches a real
			// rune in the first branch below.
			var q rune
			for i, r := range rest {
				if r == q {
					// Matching closing quote: leave the quoted run.
					inQuote = false
					q = 0
				} else if !inQuote && (r == '\'' || r == '"') {
					// Opening quote: enter a quoted run of this style.
					inQuote = true
					q = r
				} else if !inQuote && r == ',' {
					// An unquoted comma terminates the current option.
					end = i
					break
				}
			}
			var part string
			if end == -1 {
				// This is the last option part.
				part = rest
				rest = ""
			} else {
				part = rest[:end]
				rest = rest[end+1:]
			}
			// Now, parse the individual part (e.g., "default:'foo,bar'").
			k, v, found := strings.Cut(part, ":")
			if found {
				// Strip one layer of surrounding quotes from the value.
				v = quote.Remove(v)
			}
			if !yield(strings.TrimRightFunc(k, unicode.IsSpace), v) {
				return
			}
		}
	}
}
// Parse takes a raw tag string and splits it at the first comma into the
// primary name and the options string.
//
// For example, "name,opt1,opt2:val" yields a Tag with Name "name" and the
// options "opt1,opt2:val". If the string contains no comma, the whole input
// becomes the Name and the options are empty.
func Parse(s string) *Tag {
	name, opts, _ := strings.Cut(s, ",")
	return &Tag{
		Name: name,
		opts: opts,
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package jwa provides implementations for asymmetric JSON Web Algorithms
// (JWA) as defined in RFC 7518.
//
// It provides a unified interface for signature verification using public
// keys and signature creation using crypto.Signer. This abstraction handles
// algorithm-specific complexities such as hash function selection, padding
// schemes (e.g., PSS vs PKCS1v15), and signature format transcoding (e.g.,
// converting ECDSA ASN.1 DER to raw concatenation).
//
// Note: Symmetric algorithms (such as HMAC) are not supported.
package jwa
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"encoding/asn1"
"errors"
"fmt"
"hash"
"math/big"
"sync"
"github.com/cloudflare/circl/sign/ed448"
)
// Algorithm represents an asymmetric JSON Web Algorithm (JWA) used for
// verifying and calculating signatures. The type parameter T specifies the type
// of public key that the algorithm works with.
//
// Implementations in this package cover the RSASSA-PKCS1-v1_5 (RSxxx),
// RSASSA-PSS (PSxxx), ECDSA (ESxxx), and EdDSA families.
type Algorithm[T crypto.PublicKey] interface {
	// fmt.Stringer provides the standard JWA name for the algorithm.
	fmt.Stringer
	// Verify checks a signature against a message using the provided public key.
	// It returns true if the signature is valid, and false otherwise.
	// None of the parameters may be nil.
	Verify(key T, msg, sig []byte) bool
	// Sign creates a signature for the message using the provided signer.
	// The signer must be capable of using the algorithm's specific hash
	// and padding scheme.
	Sign(signer crypto.Signer, msg []byte) ([]byte, error)
}
// rs implements the RSASSA-PKCS1-v1_5 family of algorithms (RSxxx).
type rs struct {
	name string    // standard JWA name, e.g. "RS256"
	pool *hashPool // reusable hashes for the configured function
}

// newRS creates a new Algorithm for RSASSA-PKCS1-v1_5 signatures
// with the given JWA name and hash function.
func newRS(name string, hash crypto.Hash) Algorithm[*rsa.PublicKey] {
	return &rs{name: name, pool: newHashPool(hash)}
}

// Verify reports whether sig is a valid RSASSA-PKCS1-v1_5 signature of msg
// under the given public key.
func (a *rs) Verify(key *rsa.PublicKey, msg, sig []byte) bool {
	h := a.pool.Get()
	defer a.pool.Put(h)
	h.Write(msg)
	return rsa.VerifyPKCS1v15(key, a.pool.Hash, h.Sum(nil), sig) == nil
}

// Sign produces an RSASSA-PKCS1-v1_5 signature of msg using the signer.
func (a *rs) Sign(signer crypto.Signer, msg []byte) ([]byte, error) {
	h := a.pool.Get()
	defer a.pool.Put(h)
	h.Write(msg)
	return signer.Sign(rand.Reader, h.Sum(nil), a.pool.Hash)
}

// String returns the standard JWA name of the algorithm.
func (a *rs) String() string { return a.name }

// RS256 represents the RSASSA-PKCS1-v1_5 signature algorithm using SHA-256.
var RS256 = newRS("RS256", crypto.SHA256)

// RS384 represents the RSASSA-PKCS1-v1_5 signature algorithm using SHA-384.
var RS384 = newRS("RS384", crypto.SHA384)

// RS512 represents the RSASSA-PKCS1-v1_5 signature algorithm using SHA-512.
var RS512 = newRS("RS512", crypto.SHA512)
// ps implements the RSASSA-PSS family of algorithms (PSxxx).
type ps struct {
	name string    // standard JWA name, e.g. "PS256"
	pool *hashPool // reusable hashes for the configured function
}

// newPS creates a new Algorithm for RSASSA-PSS signatures
// with the given JWA name and hash function.
func newPS(name string, hash crypto.Hash) Algorithm[*rsa.PublicKey] {
	return &ps{name: name, pool: newHashPool(hash)}
}

// Verify reports whether sig is a valid RSASSA-PSS signature of msg under
// the given public key.
func (a *ps) Verify(key *rsa.PublicKey, msg, sig []byte) bool {
	h := a.pool.Get()
	defer a.pool.Put(h)
	h.Write(msg)
	// The salt length is set to match the hash size.
	opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash}
	return rsa.VerifyPSS(key, a.pool.Hash, h.Sum(nil), sig, opts) == nil
}

// Sign produces an RSASSA-PSS signature of msg using the signer.
func (a *ps) Sign(signer crypto.Signer, msg []byte) ([]byte, error) {
	h := a.pool.Get()
	defer a.pool.Put(h)
	h.Write(msg)
	// PSS signing needs both the salt length and the hash in the options.
	return signer.Sign(rand.Reader, h.Sum(nil), &rsa.PSSOptions{
		SaltLength: rsa.PSSSaltLengthEqualsHash,
		Hash:       a.pool.Hash,
	})
}

// String returns the standard JWA name of the algorithm.
func (a *ps) String() string { return a.name }

// PS256 represents the RSASSA-PSS signature algorithm using SHA-256.
var PS256 = newPS("PS256", crypto.SHA256)

// PS384 represents the RSASSA-PSS signature algorithm using SHA-384.
var PS384 = newPS("PS384", crypto.SHA384)

// PS512 represents the RSASSA-PSS signature algorithm using SHA-512.
var PS512 = newPS("PS512", crypto.SHA512)
// es implements the ECDSA family of algorithms (ESxxx).
type es struct {
	name string    // standard JWA name, e.g. "ES256"
	pool *hashPool // reusable hashes for the configured function
}

// newES creates a new Algorithm for ECDSA signatures
// with the given JWA name and hash function.
func newES(name string, hash crypto.Hash) Algorithm[*ecdsa.PublicKey] {
	return &es{name: name, pool: newHashPool(hash)}
}

// Verify reports whether sig (in the raw R||S concatenation form required by
// JWA) is a valid ECDSA signature of msg under the given public key.
func (a *es) Verify(key *ecdsa.PublicKey, msg, sig []byte) bool {
	// R and S are encoded back-to-back as fixed-width big-endian integers,
	// each as wide as the curve order in bytes.
	size := (key.Curve.Params().BitSize + 7) / 8
	if len(sig) != 2*size {
		return false
	}
	h := a.pool.Get()
	defer a.pool.Put(h)
	h.Write(msg)
	r := new(big.Int).SetBytes(sig[:size])
	s := new(big.Int).SetBytes(sig[size:])
	return ecdsa.Verify(key, h.Sum(nil), r, s)
}

// Sign produces an ECDSA signature of msg using the signer and transcodes it
// from ASN.1 DER (as returned by crypto.Signer) into the raw R||S form
// required by JWA.
func (a *es) Sign(signer crypto.Signer, msg []byte) ([]byte, error) {
	h := a.pool.Get()
	defer a.pool.Put(h)
	h.Write(msg)
	der, err := signer.Sign(rand.Reader, h.Sum(nil), nil)
	if err != nil {
		return nil, err
	}
	// Decode the DER SEQUENCE { r INTEGER, s INTEGER }.
	var decoded struct{ R, S *big.Int }
	if _, err := asn1.Unmarshal(der, &decoded); err != nil {
		return nil, fmt.Errorf("failed to parse ECDSA signature: %w", err)
	}
	pub, ok := signer.Public().(*ecdsa.PublicKey)
	if !ok {
		return nil, errors.New("signer public key is not ECDSA")
	}
	// Left-pad both integers to the full curve width.
	size := (pub.Curve.Params().BitSize + 7) / 8
	raw := make([]byte, 2*size)
	decoded.R.FillBytes(raw[:size])
	decoded.S.FillBytes(raw[size:])
	return raw, nil
}

// String returns the standard JWA name of the algorithm.
func (a *es) String() string { return a.name }

// ES256 represents the ECDSA signature algorithm using P-256 and SHA-256.
var ES256 = newES("ES256", crypto.SHA256)

// ES384 represents the ECDSA signature algorithm using P-384 and SHA-384.
var ES384 = newES("ES384", crypto.SHA384)

// ES512 represents the ECDSA signature algorithm using P-521 and SHA-512.
var ES512 = newES("ES512", crypto.SHA512)
// ed implements the EdDSA family of algorithms.
type ed struct{}

// Verify reports whether sig is a valid EdDSA signature of msg. The curve
// (Ed25519 or Ed448) is selected by the length of the public key.
func (a *ed) Verify(key []byte, msg, sig []byte) bool {
	switch len(key) {
	case ed25519.PublicKeySize:
		return ed25519.Verify(ed25519.PublicKey(key), msg, sig)
	case ed448.PublicKeySize:
		// Per RFC 8037, the JWS "EdDSA" algorithm corresponds to the "pure"
		// EdDSA variant, which uses an empty string for the context
		// parameter.
		return ed448.Verify(ed448.PublicKey(key), msg, sig, "")
	default:
		// The key matches neither supported curve.
		return false
	}
}

// Sign produces a "pure" EdDSA signature. The message is passed through
// unhashed, signaled by crypto.Hash(0).
func (a *ed) Sign(signer crypto.Signer, msg []byte) ([]byte, error) {
	return signer.Sign(rand.Reader, msg, crypto.Hash(0))
}

// String returns the standard JWA name of the algorithm.
func (a *ed) String() string { return "EdDSA" }

// EdDSA represents the EdDSA signature algorithm. It supports both Ed25519
// and Ed448 curves. The curve is determined by the size of the public key.
var EdDSA Algorithm[[]byte] = &ed{}
// hashPool manages a pool of hash.Hash instances for a single hash function,
// so hot signing/verification paths avoid allocating a fresh hasher per call.
type hashPool struct {
	Hash crypto.Hash // hash function backing every pooled instance
	pool *sync.Pool
}

// newHashPool creates a hashPool whose instances are produced by hash.New.
func newHashPool(hash crypto.Hash) *hashPool {
	return &hashPool{
		Hash: hash,
		pool: &sync.Pool{New: func() any { return hash.New() }},
	}
}

// Get retrieves a ready-to-use hash.Hash from the pool.
func (p *hashPool) Get() hash.Hash {
	return p.pool.Get().(hash.Hash)
}

// Put resets h and returns it to the pool for reuse.
func (p *hashPool) Put(h hash.Hash) {
	h.Reset()
	p.pool.Put(h)
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package jwk provides functionality to parse, manage, and marshal JSON Web
// Keys (JWK) and JSON Web Key Sets (JWKS), as defined in RFC 7517.
//
// # Verification
//
// The package is primarily designed to consume public keys from a remote JWKS
// endpoint for the purpose of verifying JWT signatures.
//
// # Signing
//
// While JWKS parsing focuses on public keys, this package also supports the
// creation of signing keys via the KeyBuilder. These keys wrap a crypto.Signer
// (e.g., hardware modules, KMS, or standard library keys) to support token
// issuance operations.
//
// # Encoding
//
// The package supports serializing keys back to JSON. This is useful for
// services that need to expose their own public keys via a JWKS endpoint or
// for persisting key sets. The marshaling logic is strict: it only outputs
// public key material and adheres to RFC 7518 fixed-width requirements for
// elliptic curve coordinates.
//
// # Eligible Keys
//
// Keys that are not intended for signature verification are considered
// ineligible and will be skipped during parsing of a JWKS. A key is eligible
// if it meets at least one of the following criteria:
//
// - The "use" (Public Key Use) parameter is set to "sig".
// - The "key_ops" (Key Operations) parameter includes "verify".
//
// # Key Selection
//
// This implementation deliberately deviates from the RFC for robustness and
// simplicity:
//
// 1. The "alg" (Algorithm) parameter, optional in the standard, is treated as
// mandatory for all eligible keys. Enforcing this is a best practice that
// mitigates algorithm confusion attacks.
// 2. For key selection, either "kid" (Key ID) or "x5t#S256" (SHA-256
// Thumbprint) must be defined. The "x5t" (SHA-1 Thumbprint) parameter is
// explicitly ignored as it is considered outdated. No other lookup
// mechanism is supported.
package jwk
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"encoding/json/jsontext"
"encoding/json/v2"
"errors"
"fmt"
"iter"
"log/slog"
"math/big"
"slices"
"time"
"github.com/cloudflare/circl/sign/ed448"
"github.com/deep-rent/nexus/cache"
"github.com/deep-rent/nexus/jose/jwa"
"github.com/deep-rent/nexus/scheduler"
)
// Hint represents a reference to a Key, containing the minimum information
// needed to look one up in a Set. It effectively abstracts the JWS header
// fields used to select a key for signature verification.
type Hint interface {
	// Algorithm returns the JWA algorithm name that the key is intended for.
	// This must match the "alg" parameter in the JWS header.
	Algorithm() string
	// KeyID returns the unique identifier for the key, or an empty string if
	// absent. This must match the "kid" parameter in the JWS header.
	// One of "kid" or "x5t#S256" must be present. If both are present, "kid"
	// takes precedence during lookups.
	KeyID() string
	// Thumbprint returns the base64url-encoded SHA-256 digest of the DER-encoded
	// X.509 certificate associated with the key, or an empty string if absent.
	// This must match the "x5t#S256" parameter in the JWS header. One of "kid" or
	// "x5t#S256" must be present. If both are present, "kid" takes precedence
	// during lookups.
	Thumbprint() string
}

// Key represents a public JSON Web Key (JWK) used for signature verification.
type Key interface {
	Hint
	// Verify checks a signature against a message using the key's material
	// and its associated algorithm. It reports whether the signature is valid.
	Verify(msg, sig []byte) bool
	// Material returns the raw cryptographic public key (e.g. *rsa.PublicKey)
	// for encoding purposes. The private key is never exposed.
	Material() any
}
// newKey creates a new Key programmatically from its constituent parts. The
// type parameter T must match the public key type expected by the provided
// algorithm (e.g., *rsa.PublicKey for jwa.RS256).
func newKey[T crypto.PublicKey](
	alg jwa.Algorithm[T],
	kid string,
	x5t string,
	mat T,
) Key {
	k := key[T]{
		alg: alg,
		kid: kid,
		x5t: x5t,
		mat: mat,
	}
	return &k
}
// key is a concrete implementation of the Key interface, generic over the
// public key type.
type key[T crypto.PublicKey] struct {
	alg jwa.Algorithm[T] // The JWA this key is bound to.
	kid string           // "kid" (Key ID) parameter; may be empty.
	x5t string           // "x5t#S256" (certificate thumbprint); may be empty.
	mat T                // The actual cryptographic public key material.
}

func (k *key[T]) Algorithm() string  { return k.alg.String() }
func (k *key[T]) KeyID() string      { return k.kid }
func (k *key[T]) Thumbprint() string { return k.x5t }
func (k *key[T]) Material() any      { return k.mat }

// Verify delegates signature verification to the key's algorithm.
func (k *key[T]) Verify(msg, sig []byte) bool {
	return k.alg.Verify(k.mat, msg, sig)
}

// KeyPair represents a JSON Web Key that is capable of both verification and
// signing. It embeds the public Key interface and wraps a crypto.Signer for
// the private key operations.
type KeyPair interface {
	Key
	// Sign generates a signature for the given message.
	Sign(msg []byte) ([]byte, error)
}

// keyPair is the concrete implementation of KeyPair.
type keyPair[T crypto.PublicKey] struct {
	// We embed the struct value (not the pointer) so that the inner fields (alg,
	// kid, etc.) are allocated together.
	key[T]
	signer crypto.Signer // Private-key half; used only by Sign.
}

// Sign produces a signature over msg using the wrapped signer, in the format
// required by the embedded key's algorithm.
func (s *keyPair[T]) Sign(msg []byte) ([]byte, error) {
	return s.alg.Sign(s.signer, msg)
}
// KeyBuilder assists in the programmatic construction of Key and KeyPair
// instances. It ensures that the resulting keys possess the required metadata
// and that the cryptographic material matches the intended algorithm.
type KeyBuilder[T crypto.PublicKey] struct {
	alg jwa.Algorithm[T] // Target algorithm; fixed at construction.
	kid string           // Optional "kid" parameter.
	x5t string           // Optional "x5t#S256" parameter.
}

// NewKeyBuilder starts the construction of a key for the specified algorithm.
// The generic type T determines the expected public key material (e.g.,
// *rsa.PublicKey).
func NewKeyBuilder[T crypto.PublicKey](alg jwa.Algorithm[T]) *KeyBuilder[T] {
	return &KeyBuilder[T]{alg: alg}
}

// Algorithm returns the JWA associated with this builder.
func (b *KeyBuilder[T]) Algorithm() jwa.Algorithm[T] { return b.alg }

// KeyID returns the currently configured key identifier, or an empty string.
func (b *KeyBuilder[T]) KeyID() string { return b.kid }

// Thumbprint returns the currently configured certificate thumbprint, or an
// empty string.
func (b *KeyBuilder[T]) Thumbprint() string { return b.x5t }

// WithKeyID sets the "kid" (Key ID) parameter. It returns the builder to
// allow call chaining.
func (b *KeyBuilder[T]) WithKeyID(kid string) *KeyBuilder[T] {
	b.kid = kid
	return b
}

// WithThumbprint sets the "x5t#S256" (SHA-256 Certificate Thumbprint)
// parameter. It returns the builder to allow call chaining.
func (b *KeyBuilder[T]) WithThumbprint(x5t string) *KeyBuilder[T] {
	b.x5t = x5t
	return b
}

// Build creates a verification-only Key using the provided public key material.
// It panics if neither a Key ID nor a Thumbprint has been configured.
func (b *KeyBuilder[T]) Build(mat T) Key {
	return b.build(mat)
}
// BuildPair creates a signing-capable KeyPair using the provided signer.
//
// It panics if:
//  1. The signer's public key cannot be cast to type T.
//  2. Neither a Key ID nor a Thumbprint has been configured.
func (b *KeyBuilder[T]) BuildPair(signer crypto.Signer) KeyPair {
	pub := signer.Public()
	mat, ok := pub.(T)
	if !ok {
		// Include the concrete type in the message so misconfigured signers
		// are easy to diagnose from the panic alone.
		panic(fmt.Sprintf(
			"signer public key type %T does not match key builder type", pub,
		))
	}
	return &keyPair[T]{
		key:    *b.build(mat),
		signer: signer,
	}
}
// build validates the builder state and assembles the internal key value. It
// panics unless at least one lookup parameter (kid or x5t) is configured,
// since a key without either could never be found in a Set.
func (b *KeyBuilder[T]) build(mat T) *key[T] {
	if b.kid == "" && b.x5t == "" {
		panic("either key id or thumbprint must be set")
	}
	k := key[T]{
		alg: b.alg,
		kid: b.kid,
		x5t: b.x5t,
		mat: mat,
	}
	return &k
}
// ErrIneligibleKey indicates that a key may be syntactically valid but should
// not be used for signature verification according to its "use" or "key_ops"
// parameters. Parse returns this sentinel; compare with errors.Is.
var ErrIneligibleKey = errors.New("ineligible for signature verification")
// Parse parses a single Key from the provided JSON input.
//
// It first checks if the key is eligible for signature verification. If not,
// it returns ErrIneligibleKey. Otherwise, it proceeds to validate the presence
// of required parameters ("kty" and "alg"), whether the algorithm is supported,
// and the integrity of the key material itself.
func Parse(in []byte) (Key, error) {
	var r raw
	if err := json.Unmarshal(in, &r); err != nil {
		return nil, fmt.Errorf("invalid json format: %w", err)
	}
	// Per RFC 7517, a key's purpose is determined by the union of "use" and
	// "key_ops". We perform this check first for efficiency, as we only care
	// about signature verification keys.
	eligible := r.Use == "sig" || slices.Contains(r.Ops, "verify")
	if !eligible {
		return nil, ErrIneligibleKey
	}
	switch {
	case r.Kty == "":
		return nil, errors.New("undefined key type")
	case r.Alg == "":
		return nil, errors.New("algorithm not specified")
	}
	read, ok := readers[r.Alg]
	if !ok {
		return nil, fmt.Errorf("unknown algorithm %q", r.Alg)
	}
	key, err := read(&r)
	if err != nil {
		return nil, fmt.Errorf("read %s key material: %w", r.Kty, err)
	}
	return key, nil
}
// Set stores an immutable collection of Keys, typically parsed from a JWKS.
// It provides efficient lookups of keys for signature verification.
type Set interface {
	// Keys returns an iterator over all keys in this set.
	Keys() iter.Seq[Key]
	// Len returns the number of keys in this set.
	Len() int
	// Find looks up a key using the specified hint: first by key id, then by
	// certificate thumbprint. A key is returned only if it is found by one of
	// these parameters and its algorithm matches the hint exactly.
	// Otherwise, it returns nil.
	Find(hint Hint) Key
}
// newSet creates a new, empty Set with the specified initial capacity.
func newSet(n int) *set {
	s := set{
		keys: make([]Key, 0, n),
		kid:  make(map[string]int, n),
		x5t:  make(map[string]int, n),
	}
	return &s
}
// set is the concrete implementation of the Set interface.
// It uses maps for efficient O(1) average time complexity lookups.
type set struct {
	keys []Key
	kid  map[string]int // Maps key id to index in keys array.
	x5t  map[string]int // Maps thumbprint to index in keys array.
}

// Keys iterates the keys in insertion order.
func (s *set) Keys() iter.Seq[Key] { return slices.Values(s.keys) }

// Len reports the number of stored keys.
func (s *set) Len() int { return len(s.keys) }
// Find resolves the hint by key id first, falling back to the certificate
// thumbprint, and returns the resolved key only if its algorithm matches the
// hint. A nil hint or a failed lookup yields nil.
func (s *set) Find(hint Hint) Key {
	if hint == nil {
		return nil
	}
	i, ok := s.kid[hint.KeyID()]
	if !ok {
		i, ok = s.x5t[hint.Thumbprint()]
	}
	if !ok {
		return nil
	}
	k := s.keys[i]
	if k.Algorithm() != hint.Algorithm() {
		return nil
	}
	return k
}
// emptySet is a stateless Set implementation that contains no keys. It backs
// the package-level empty singleton below.
type emptySet struct{}

func (e emptySet) Keys() iter.Seq[Key] { return func(func(Key) bool) {} }
func (e emptySet) Len() int            { return 0 }
func (e emptySet) Find(Hint) Key       { return nil }

// empty is a singleton instance of an empty Set.
var empty Set = emptySet{}
// singletonSet is an adapter that wraps a single Key as a Set.
type singletonSet struct{ key Key }

func (s *singletonSet) Keys() iter.Seq[Key] {
	return func(f func(Key) bool) { f(s.key) }
}

func (s *singletonSet) Len() int { return 1 }

// Find returns the wrapped key if the hint matches it: the algorithm must be
// identical, and a non-empty "kid" or "x5t#S256" parameter must equal the
// key's corresponding parameter. Empty parameters never match; otherwise a
// key carrying only a thumbprint would be returned for any hint that also
// lacks a key id, even if the thumbprints differ. A nil hint returns nil,
// consistent with the map-backed set implementation.
func (s *singletonSet) Find(hint Hint) Key {
	if hint == nil || s.key.Algorithm() != hint.Algorithm() {
		return nil
	}
	if kid := hint.KeyID(); kid != "" && kid == s.key.KeyID() {
		return s.key
	}
	if x5t := hint.Thumbprint(); x5t != "" && x5t == s.key.Thumbprint() {
		return s.key
	}
	return nil
}
// ParseSet parses a Set from a JWKS JSON input.
//
// If the top-level JSON structure is malformed, it returns an empty set and
// a fatal error. Otherwise, it iterates through the "keys" array, parsing
// each key individually. Keys that are invalid, unsupported, or occur multiple
// times, result in non-fatal errors. Ineligible keys (e.g., those meant for
// encryption) are silently skipped. If any non-fatal errors occurred, a joined
// error is returned alongside the set of successfully parsed keys.
func ParseSet(in []byte) (Set, error) {
	var raw struct {
		// Defer unmarshaling of individual keys to safely skip ineligible ones.
		Keys []jsontext.Value `json:"keys"`
	}
	if err := json.Unmarshal(in, &raw); err != nil {
		return empty, fmt.Errorf("invalid format: %w", err)
	}
	n := len(raw.Keys)
	if n == 0 {
		return empty, nil
	}
	s := newSet(n)
	var errs []error
	for i, v := range raw.Keys {
		k, err := Parse(v)
		if err != nil {
			// Ineligible keys (wrong "use"/"key_ops") are expected in mixed
			// JWKS documents and are skipped without reporting.
			if errors.Is(err, ErrIneligibleKey) {
				continue
			}
			err = fmt.Errorf("key at index %d: %w", i, err)
			errs = append(errs, err)
			continue
		}
		kid := k.KeyID()
		x5t := k.Thumbprint()
		// Without at least one lookup parameter, the key could never be
		// found by Set.Find, so reject it here.
		if kid == "" && x5t == "" {
			errs = append(errs, fmt.Errorf(
				"key at index %d: missing both key id and thumbprint", i,
			))
			continue
		}
		// Check for duplicates before mutating the set.
		if kid != "" {
			if _, ok := s.kid[kid]; ok {
				errs = append(errs, fmt.Errorf(
					"key at index %d: duplicate key id %q", i, kid,
				))
				continue
			}
		}
		if x5t != "" {
			if _, ok := s.x5t[x5t]; ok {
				errs = append(errs, fmt.Errorf(
					"key at index %d: duplicate thumbprint %q", i, x5t,
				))
				continue
			}
		}
		// Determine the index in the keys slice where this new key will be
		// stored. This is safe because we are appending linearly.
		idx := len(s.keys)
		// Append the key exactly once.
		s.keys = append(s.keys, k)
		// Update the lookup maps.
		if kid != "" {
			s.kid[kid] = idx
		}
		if x5t != "" {
			s.x5t[x5t] = idx
		}
	}
	return s, errors.Join(errs...)
}
// Write marshals a single Key into its JSON Web Key representation.
//
// It populates the standard JWK fields ("kty", "alg", "use", "kid", "x5t#S256")
// and the algorithm-specific public key parameters (e.g., "n" and "e" for RSA).
// The output is strictly compliant with RFC 7517 and RFC 7518, ensuring that
// elliptic curve coordinates are padded to the correct fixed width.
func Write(k Key) ([]byte, error) {
	// Convert to the wire DTO first; this also rejects unsupported algorithms.
	r, err := toRaw(k)
	if err != nil {
		return nil, err
	}
	return json.Marshal(r)
}
// WriteSet marshals a Set into a JSON Web Key Set (JWKS) document.
//
// The resulting JSON corresponds to the standard JWKS structure:
//
//	{
//	  "keys": [ ... ]
//	}
//
// Each key is converted to its raw DTO before the whole collection is
// marshaled once. This is more efficient than calling Write() in a loop,
// which would marshal every key twice.
func WriteSet(s Set) ([]byte, error) {
	type document struct {
		Keys []raw `json:"keys"`
	}
	doc := document{Keys: make([]raw, 0, s.Len())}
	for k := range s.Keys() {
		r, err := toRaw(k)
		if err != nil {
			return nil, fmt.Errorf("encode key %q: %w", k.KeyID(), err)
		}
		doc.Keys = append(doc.Keys, *r)
	}
	return json.Marshal(doc)
}
// toRaw converts a Key object into the raw DTO used for JSON marshaling.
// It fails if no writer is registered for the key's algorithm.
func toRaw(k Key) (*raw, error) {
	alg := k.Algorithm()
	write, ok := writers[alg]
	if !ok {
		return nil, fmt.Errorf("unsupported algorithm %q", alg)
	}
	// Populate standard metadata; marshaled keys are always signature keys.
	r := raw{
		Alg: alg,
		Kid: k.KeyID(),
		X5t: k.Thumbprint(),
		Use: "sig",
	}
	// Populate algorithm-specific fields (e.g. "n"/"e" or "crv"/"x"/"y").
	if err := write(k.Material(), &r); err != nil {
		return nil, err
	}
	return &r, nil
}
// Singleton creates a Set that contains only the provided Key. It is useful
// when exactly one verification key is in play.
func Singleton(key Key) Set {
	return &singletonSet{key: key}
}

// CacheSet extends the Set interface with scheduler.Tick, creating a component
// that can be deployed to a scheduler for automatic refreshing of a remote
// JWKS view in the background. The default implementation is backed by a
// cache.Controller.
type CacheSet interface {
	Set
	scheduler.Tick
}
// cacheSet is the concrete implementation of the CacheSet interface. It
// delegates all Set operations to the snapshot currently held by the
// underlying cache controller.
type cacheSet struct {
	ctrl cache.Controller[Set]
}

// get safely retrieves the current Set from the cache controller. If the
// cache has not been populated yet (e.g., due to an initial network failure),
// it returns a static empty set to ensure that delegated operations like Find
// do not panic. This makes the Set resilient to transient startup issues.
func (s *cacheSet) get() Set {
	if set, ok := s.ctrl.Get(); ok {
		return set
	}
	return empty
}

func (s *cacheSet) Keys() iter.Seq[Key] { return s.get().Keys() }
func (s *cacheSet) Len() int            { return s.get().Len() }
func (s *cacheSet) Find(hint Hint) Key  { return s.get().Find(hint) }

// Run implements scheduler.Tick by delegating to the cache controller.
func (s *cacheSet) Run(ctx context.Context) time.Duration {
	return s.ctrl.Run(ctx)
}

// Ensure cacheSet implements CacheSet.
var _ CacheSet = (*cacheSet)(nil)
// mapper adapts the ParseSet function to the cache.Mapper interface.
//
// Parsing is deliberately lenient: partial failures are only logged at debug
// level, and an error is returned solely when no usable keys could be
// extracted at all. In that case any underlying parse error is wrapped so
// the root cause is not lost.
var mapper cache.Mapper[Set] = func(r *cache.Response) (Set, error) {
	set, err := ParseSet(r.Body)
	if set.Len() == 0 {
		// Surface the underlying parse error instead of discarding it;
		// a bare "no valid keys found" hides why parsing failed.
		if err != nil {
			return nil, fmt.Errorf("no valid keys found: %w", err)
		}
		return nil, errors.New("no valid keys found")
	}
	if err != nil && r.Logger.Enabled(r.Ctx, slog.LevelDebug) {
		r.Logger.Debug("Some keys could not be parsed", slog.Any("error", err))
	}
	// Don't complain unless there are no keys available at all.
	return set, nil
}
// NewCacheSet creates a new CacheSet that stays in sync with a remote JWKS
// endpoint. It must be deployed to a scheduler.Scheduler to begin the
// background fetching and refreshing process.
//
// The provided cache.Options can configure behaviors like refresh interval,
// request timeouts, and error handling. Parsing of retrieved key sets is
// extremely lenient: it will only fail if no valid keys are found at all.
func NewCacheSet(url string, opts ...cache.Option) CacheSet {
	ctrl := cache.NewController(url, mapper, opts...)
	return &cacheSet{ctrl}
}

// raw holds the JWK parameters including the key material. It is the
// wire-format DTO used for both parsing and marshaling.
type raw struct {
	Kty string   `json:"kty"`                // Key type: "RSA", "EC", or "OKP".
	Alg string   `json:"alg"`                // JWA algorithm name.
	Use string   `json:"use,omitempty"`      // Intended use; "sig" for signatures.
	Ops []string `json:"key_ops,omitempty"`  // Permitted key operations.
	Kid string   `json:"kid,omitempty"`      // Key ID.
	X5t string   `json:"x5t#S256,omitempty"` // SHA-256 certificate thumbprint.
	N   string   `json:"n,omitempty"`        // RSA modulus (base64url).
	E   string   `json:"e,omitempty"`        // RSA public exponent (base64url).
	Crv string   `json:"crv,omitempty"`      // EC/OKP curve name.
	X   string   `json:"x,omitempty"`        // EC x coordinate / OKP key bytes.
	Y   string   `json:"y,omitempty"`        // EC y coordinate.
}

// reader defines a function that decodes the key material from a raw JWK
// and constructs a concrete Key.
type reader func(r *raw) (Key, error)

// readers maps a JWA algorithm name to the function responsible for parsing
// its key material. Populated once in init.
var readers map[string]reader
// addReader registers a parser for alg in a type-safe manner, wiring the
// algorithm-specific decoder into the generic readers map.
func addReader[T crypto.PublicKey](alg jwa.Algorithm[T], dec decoder[T]) {
	name := alg.String()
	readers[name] = func(r *raw) (Key, error) {
		mat, err := dec(r)
		if err != nil {
			return nil, err
		}
		return newKey(alg, r.Kid, r.X5t, mat), nil
	}
}
// decoder decodes the key material for a specific key type T from the raw
// JWK parameters (e.g., "n"/"e" for RSA or "crv"/"x"/"y" for EC).
type decoder[T crypto.PublicKey] func(*raw) (T, error)
// decodeRSA parses the material for an RSA public key from the "n" and "e"
// parameters, both base64url-encoded big-endian integers.
func decodeRSA(raw *raw) (*rsa.PublicKey, error) {
	if raw.Kty != "RSA" {
		return nil, fmt.Errorf("incompatible key type %q", raw.Kty)
	}
	switch {
	case len(raw.N) == 0:
		return nil, errors.New("missing modulus")
	case len(raw.E) == 0:
		return nil, errors.New("missing public exponent")
	}
	modulus, err := base64.RawURLEncoding.DecodeString(raw.N)
	if err != nil {
		return nil, fmt.Errorf("decode modulus: %w", err)
	}
	exponent, err := base64.RawURLEncoding.DecodeString(raw.E)
	if err != nil {
		return nil, fmt.Errorf("decode public exponent: %w", err)
	}
	// Exponents > 2^31-1 are extremely rare and not recommended.
	if len(exponent) > 4 {
		return nil, errors.New("public exponent exceeds 32 bits")
	}
	// The conversion to a big-endian unsigned integer is safe because of the
	// length check above.
	var e int
	for _, b := range exponent {
		e = e<<8 | int(b)
	}
	return &rsa.PublicKey{
		N: new(big.Int).SetBytes(modulus),
		E: e,
	}, nil
}
// decodeECDSA creates a decoder for the specified elliptic curve. The
// returned decoder parses the "x" and "y" parameters of an "EC" key and
// rejects material whose declared curve does not match crv.
func decodeECDSA(crv elliptic.Curve) decoder[*ecdsa.PublicKey] {
	return func(raw *raw) (*ecdsa.PublicKey, error) {
		if raw.Kty != "EC" {
			return nil, fmt.Errorf("incompatible key type %q", raw.Kty)
		}
		if raw.Crv != crv.Params().Name {
			return nil, fmt.Errorf("incompatible curve %q", raw.Crv)
		}
		if len(raw.X) == 0 {
			return nil, errors.New("missing x coordinate")
		}
		if len(raw.Y) == 0 {
			return nil, errors.New("missing y coordinate")
		}
		xBytes, err := base64.RawURLEncoding.DecodeString(raw.X)
		if err != nil {
			return nil, fmt.Errorf("decode x coordinate: %w", err)
		}
		yBytes, err := base64.RawURLEncoding.DecodeString(raw.Y)
		if err != nil {
			return nil, fmt.Errorf("decode y coordinate: %w", err)
		}
		// Calculate the required byte size for the curve coordinates.
		size := (crv.Params().BitSize + 7) / 8
		if len(xBytes) > size || len(yBytes) > size {
			return nil, errors.New("coordinate length exceeds curve size")
		}
		// Construct the SEC 1 uncompressed point format: 0x04 || X || Y.
		// Shorter coordinates are left-padded with zeros by copying them
		// right-aligned into their fixed-width slot.
		uncompressed := make([]byte, 1+(2*size))
		uncompressed[0] = 4
		copy(uncompressed[1+size-len(xBytes):1+size], xBytes)
		copy(uncompressed[1+(2*size)-len(yBytes):], yBytes)
		// ParseUncompressedPublicKey also rejects points that are not on
		// the curve, so no separate on-curve check is needed here.
		pub, err := ecdsa.ParseUncompressedPublicKey(crv, uncompressed)
		if err != nil {
			return nil, fmt.Errorf("parse public key: %w", err)
		}
		return pub, nil
	}
}
// decodeEdDSA parses the material for an EdDSA public key from the "x"
// parameter of an "OKP" key. The curve must be Ed25519 or Ed448, and the
// decoded key must have exactly the size dictated by that curve.
func decodeEdDSA(raw *raw) ([]byte, error) {
	if raw.Kty != "OKP" {
		return nil, fmt.Errorf("incompatible key type %q", raw.Kty)
	}
	var want int
	switch raw.Crv {
	case "Ed25519":
		want = ed25519.PublicKeySize
	case "Ed448":
		want = ed448.PublicKeySize
	default:
		return nil, fmt.Errorf("unsupported curve %q", raw.Crv)
	}
	key, err := base64.RawURLEncoding.DecodeString(raw.X)
	if err != nil {
		return nil, fmt.Errorf("decode x coordinate: %w", err)
	}
	if got := len(key); got != want {
		return nil, fmt.Errorf(
			"illegal key size for %s curve: got %d, want %d", raw.Crv, got, want,
		)
	}
	return key, nil
}
// writers maps a JWA algorithm name to the function responsible for encoding
// its key material. Populated once in init.
var writers map[string]writer

// writer defines a function that encodes the key material into a marshallable
// JWK struct. It fails if mat is not the key type expected by the algorithm.
type writer func(mat any, r *raw) error
// addWriter registers an encoder for alg in a type-safe manner. The stored
// closure narrows the untyped material to T before delegating to the
// algorithm-specific encoder.
func addWriter[T crypto.PublicKey](alg jwa.Algorithm[T], enc encoder[T]) {
	name := alg.String()
	writers[name] = func(mat any, r *raw) error {
		pub, ok := mat.(T)
		if !ok {
			return fmt.Errorf("invalid key for algorithm %q", name)
		}
		return enc(pub, r)
	}
}
// encoder defines a function that populates the raw JWK parameters from the
// algorithm-specific key material (the inverse of decoder).
type encoder[T crypto.PublicKey] func(mat T, r *raw) error
// encodeRSA populates the RSA-specific fields ("n", "e") in the raw JWK.
// Both values are emitted as minimal big-endian byte strings, base64url
// encoded without padding.
func encodeRSA(key *rsa.PublicKey, r *raw) error {
	r.Kty = "RSA"
	r.N = base64.RawURLEncoding.EncodeToString(key.N.Bytes())
	e := key.E
	if e == 0 {
		return errors.New("RSA public exponent is zero")
	}
	// A negative exponent would otherwise slip through the encoding loop
	// below and silently produce an empty "e" parameter.
	if e < 0 {
		return errors.New("RSA public exponent is negative")
	}
	// Gather the exponent bytes little-endian first, then reverse in place.
	// (Prepending inside the loop would reallocate the slice on every
	// iteration, defeating any capacity pre-allocation.)
	eBytes := make([]byte, 0, 4)
	for ; e > 0; e >>= 8 {
		eBytes = append(eBytes, byte(e))
	}
	for i, j := 0, len(eBytes)-1; i < j; i, j = i+1, j-1 {
		eBytes[i], eBytes[j] = eBytes[j], eBytes[i]
	}
	r.E = base64.RawURLEncoding.EncodeToString(eBytes)
	return nil
}
// encodeECDSA populates the ECDSA-specific fields ("crv", "x", "y").
// It enforces fixed-width padding for coordinates as required by RFC 7518.
func encodeECDSA(key *ecdsa.PublicKey, r *raw) error {
	r.Kty = "EC"
	params := key.Params()
	r.Crv = params.Name
	// Obtain the SEC 1 uncompressed format: 0x04 || X || Y.
	b, err := key.Bytes()
	if err != nil {
		return fmt.Errorf("encode ecdsa key: %w", err)
	}
	// Defensive check: the uncompressed encoding must begin with the 0x04 tag.
	if len(b) < 1 || b[0] != 4 {
		return errors.New("invalid public key format")
	}
	// Calculate coordinate size dynamically based on the returned slice.
	// Both halves of the uncompressed encoding are already left-padded to
	// the full field width, which is exactly the fixed-width form RFC 7518
	// requires for "x" and "y".
	size := (len(b) - 1) / 2
	x := b[1 : 1+size]
	y := b[1+size : 1+(2*size)]
	r.X = base64.RawURLEncoding.EncodeToString(x)
	r.Y = base64.RawURLEncoding.EncodeToString(y)
	return nil
}
// encodeEdDSA populates the EdDSA-specific fields ("crv", "x"). The curve
// name is inferred from the raw key length; keys that are neither Ed25519
// nor Ed448 sized are rejected.
func encodeEdDSA(key []byte, r *raw) error {
	r.Kty = "OKP"
	var crv string
	switch len(key) {
	case ed25519.PublicKeySize:
		crv = "Ed25519"
	case ed448.PublicKeySize:
		crv = "Ed448"
	default:
		return fmt.Errorf("invalid EdDSA key length: %d", len(key))
	}
	r.Crv = crv
	r.X = base64.RawURLEncoding.EncodeToString(key)
	return nil
}
// init initializes the readers and writers maps with all supported
// algorithms. Each algorithm is registered symmetrically: a reader for
// parsing and a writer for marshaling.
func init() {
	// Number of registered algorithms; sizes both maps up front.
	const size = 10
	readers = make(map[string]reader, size)
	addReader(jwa.RS256, decodeRSA)
	addReader(jwa.RS384, decodeRSA)
	addReader(jwa.RS512, decodeRSA)
	addReader(jwa.PS256, decodeRSA)
	addReader(jwa.PS384, decodeRSA)
	addReader(jwa.PS512, decodeRSA)
	addReader(jwa.ES256, decodeECDSA(elliptic.P256()))
	addReader(jwa.ES384, decodeECDSA(elliptic.P384()))
	addReader(jwa.ES512, decodeECDSA(elliptic.P521()))
	addReader(jwa.EdDSA, decodeEdDSA)
	writers = make(map[string]writer, size)
	addWriter(jwa.RS256, encodeRSA)
	addWriter(jwa.RS384, encodeRSA)
	addWriter(jwa.RS512, encodeRSA)
	addWriter(jwa.PS256, encodeRSA)
	addWriter(jwa.PS384, encodeRSA)
	addWriter(jwa.PS512, encodeRSA)
	addWriter(jwa.ES256, encodeECDSA)
	addWriter(jwa.ES384, encodeECDSA)
	addWriter(jwa.ES512, encodeECDSA)
	addWriter(jwa.EdDSA, encodeEdDSA)
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package jwt provides tools for parsing, verifying, and signing JSON Web
// Tokens (JWTs).
//
// This package uses generics to allow users to define their own custom claims
// structures. A common pattern is to embed the provided Reserved claims
// struct and add extra fields for any other claims present in the token.
//
// # Basic Verification
//
// Start by defining custom claims:
//
// type Claims struct {
// jwt.Reserved
// Scope string `json:"scp"`
// Extra map[string]any `json:",unknown"`
// }
//
// The top-level Verify function can be used for simple, one-off signature
// verification without claim validation:
//
// keySet, err := jwk.ParseSet([]byte(`{"keys": [...]}`))
// if err != nil { /* handle parsing error */ }
// claims, err := jwt.Verify[Claims](keySet, []byte("eyJhb..."))
//
// # Advanced Validation
//
// For advanced validation of claims like issuer, audience, and token age,
// create a reusable Verifier with the desired configuration:
//
// verifier := jwt.NewVerifier[Claims](keySet).
// WithIssuer("foo", "bar").
// WithAudience("baz").
// WithLeeway(1 * time.Minute).
// WithMaxAge(1 * time.Hour)
//
// claims, err := verifier.Verify([]byte("eyJhb..."))
// if err != nil { /* handle validation error */ }
// fmt.Println("Scope:", claims.Scope)
//
// # Basic Signing
//
// The top-level Sign function can be used to create signed tokens from any
// JSON-serializable struct or map. This is useful for simple tokens where
// you manually handle all claims:
//
// // keyPair must be a jwk.KeyPair (containing a private key)
// claims := map[string]any{"sub": "user_123", "admin": true}
// token, err := jwt.Sign(keyPair, claims)
//
// # Advanced Signing
//
// To enforce policies like expiration or consistent issuers, create a reusable
// Signer. Your claims struct must implement MutableClaims (embedding
// jwt.Reserved handles this automatically).
//
// signer := jwt.NewSigner(keyPair).
// WithIssuer("https://api.example.com").
// WithLifetime(1 * time.Hour)
//
// // The signer will automatically set "iss", "iat", and "exp" on the struct.
// claims := &MyClaims{
// Reserved: jwt.Reserved{Subject: "user_123"},
// Scope: "admin",
// }
// token, err := signer.Sign(claims)
package jwt
import (
"bytes"
"encoding/base64"
"encoding/json/jsontext"
"encoding/json/v2"
"errors"
"fmt"
"slices"
"time"
"github.com/deep-rent/nexus/internal/rotor"
"github.com/deep-rent/nexus/jose/jwk"
)
// Header represents the decoded JOSE header of a JWT. It is defined in terms
// of jwk.Hint, since it carries exactly the fields needed to select a
// verification key from a jwk.Set.
type Header jwk.Hint

// header is the wire representation of the JOSE header.
type header struct {
	Typ string `json:"typ,omitempty"`      // Token type, e.g. "JWT".
	Alg string `json:"alg"`                // Signature algorithm name.
	Kid string `json:"kid,omitempty"`      // Key ID.
	X5t string `json:"x5t#S256,omitempty"` // SHA-256 certificate thumbprint.
}

func (h *header) Type() string       { return h.Typ }
func (h *header) Algorithm() string  { return h.Alg }
func (h *header) KeyID() string      { return h.Kid }
func (h *header) Thumbprint() string { return h.X5t }

var (
	// ErrKeyNotFound is returned when no matching key is found in the JWK set.
	ErrKeyNotFound = errors.New("no matching key found")
	// ErrInvalidSignature is returned when the token's signature differs from
	// the computed signature.
	ErrInvalidSignature = errors.New("invalid signature")
)

// Token represents a parsed, but not necessarily verified, JWT.
// The generic type T is the user-defined claims structure.
type Token[T Claims] interface {
	// Header returns the token's header parameters.
	Header() Header
	// Claims returns the token's payload claims.
	Claims() T
	// Verify checks the token's signature using the provided JWK set.
	// It returns ErrKeyNotFound if no matching key is found or
	// ErrInvalidSignature if the signature is incorrect.
	Verify(set jwk.Set) error
}
// token is the concrete implementation of Token. It retains the signed bytes
// and the decoded signature so that verification can be deferred until a key
// set is supplied.
type token[T Claims] struct {
	header Header // Decoded JOSE header; doubles as the key lookup hint.
	claims T      // Decoded payload claims.
	msg    []byte // The signed bytes checked by Verify.
	sig    []byte // The decoded signature bytes.
}

func (t *token[T]) Header() Header { return t.header }
func (t *token[T]) Claims() T      { return t.claims }

// Verify looks up the key referenced by the token's header and checks the
// signature against it.
func (t *token[T]) Verify(set jwk.Set) error {
	key := set.Find(t.header)
	if key == nil {
		return ErrKeyNotFound
	}
	if !key.Verify(t.msg, t.sig) {
		return ErrInvalidSignature
	}
	return nil
}
// audience is a custom type for the JWT "aud" claim, which may appear in a
// token as either a single string or an array of strings. Both forms decode
// into a slice.
type audience []string

// UnmarshalJSON accepts a JSON string or a JSON array of strings; anything
// else is an error.
func (a *audience) UnmarshalJSON(b []byte) error {
	var one string
	if json.Unmarshal(b, &one) == nil {
		*a = audience{one}
		return nil
	}
	var many []string
	if json.Unmarshal(b, &many) == nil {
		*a = audience(many)
		return nil
	}
	return errors.New("expected a string or an array of strings")
}
// Claims provides access to the standard JWT claims.
// It is used by Verifier for claim validation.
//
// All getters are read-only accessors; an absent claim is reported as the
// type's zero value rather than as an error.
type Claims interface {
	// ID returns the "jti" (JWT ID) claim, or an empty string if absent.
	ID() string
	// Subject returns the "sub" (Subject) claim, or an empty string if absent.
	Subject() string
	// Issuer returns the "iss" (Issuer) claim, or an empty string if absent.
	Issuer() string
	// Audience returns the "aud" (Audience) claim, or nil if absent.
	Audience() []string
	// IssuedAt returns the "iat" (Issued At) claim, or the zero time if absent.
	IssuedAt() time.Time
	// ExpiresAt returns the "exp" (Expires At) claim, or the zero time if absent.
	ExpiresAt() time.Time
	// NotBefore returns the "nbf" (Not Before) claim, or the zero time if absent.
	NotBefore() time.Time
}

// MutableClaims extends Claims with setters for standard JWT claims.
// Reserved is the canonical implementation of both interfaces.
//
// The setter methods are not safe for concurrent use and should only be called
// during token creation.
type MutableClaims interface {
	Claims
	// SetID sets the "jti" (JWT ID) claim.
	SetID(id string)
	// SetSubject sets the "sub" (Subject) claim.
	SetSubject(sub string)
	// SetIssuer sets the "iss" (Issuer) claim.
	SetIssuer(iss string)
	// SetAudience sets the "aud" (Audience) claim.
	SetAudience(aud []string)
	// SetIssuedAt sets the "iat" (Issued At) claim.
	SetIssuedAt(t time.Time)
	// SetExpiresAt sets the "exp" (Expires At) claim.
	SetExpiresAt(t time.Time)
	// SetNotBefore sets the "nbf" (Not Before) claim.
	SetNotBefore(t time.Time)
}
// Reserved contains the standard registered claims for a JWT. It implements
// the Claims interface and should be embedded in custom claims structs to
// enable standard claim handling.
//
// The time-valued fields carry the "format:unix" tag option, so they are
// serialized as numeric Unix timestamps rather than RFC 3339 strings.
type Reserved struct {
	Jti string    `json:"jti,omitempty"`            // JWT ID
	Sub string    `json:"sub,omitempty"`            // Subject
	Iss string    `json:"iss,omitempty"`            // Issuer
	Aud audience  `json:"aud,omitempty"`            // Audience
	Iat time.Time `json:"iat,omitzero,format:unix"` // Issued At
	Exp time.Time `json:"exp,omitzero,format:unix"` // Expires At
	Nbf time.Time `json:"nbf,omitzero,format:unix"` // Not Before
}

// The getter/setter pairs below satisfy the Claims and MutableClaims
// interfaces by exposing the exported struct fields one-to-one.

func (r *Reserved) ID() string { return r.Jti }
func (r *Reserved) SetID(id string) { r.Jti = id }
func (r *Reserved) Subject() string { return r.Sub }
func (r *Reserved) SetSubject(sub string) { r.Sub = sub }
func (r *Reserved) Issuer() string { return r.Iss }
func (r *Reserved) SetIssuer(iss string) { r.Iss = iss }
func (r *Reserved) Audience() []string { return r.Aud }
func (r *Reserved) SetAudience(aud []string) { r.Aud = aud }
func (r *Reserved) IssuedAt() time.Time { return r.Iat }
func (r *Reserved) SetIssuedAt(t time.Time) { r.Iat = t }
func (r *Reserved) ExpiresAt() time.Time { return r.Exp }
func (r *Reserved) SetExpiresAt(t time.Time) { r.Exp = t }
func (r *Reserved) NotBefore() time.Time { return r.Nbf }
func (r *Reserved) SetNotBefore(t time.Time) { r.Nbf = t }

// Ensure Reserved implements the MutableClaims interface.
var _ MutableClaims = (*Reserved)(nil)
// DynamicClaims represents a standard JWT payload extended with arbitrary
// custom claims. It embeds the standard Reserved claims and captures any
// unmapped JSON properties into the Other map.
//
// By applying jsontext.Value and the `json:",inline"` tag from the
// encoding/json/v2 package, custom claims are retained as raw JSON bytes.
// This defers parsing until the exact target type is known, avoiding the
// common pitfalls of default map[string]any unmarshaling (such as all
// numbers defaulting to float64).
type DynamicClaims struct {
	Reserved
	// Other holds every JSON property that did not map to a Reserved
	// field, keyed by property name, as raw JSON. Use Get to decode an
	// entry into a concrete type.
	Other map[string]jsontext.Value `json:",inline"`
}
// Get retrieves a specific custom claim by key from the DynamicClaims
// payload and unmarshals it into the requested type T.
//
// It safely handles nil pointers, missing keys, and parsing errors. If the
// receiver 'c' is nil, the key is not found (including when the 'Other' map
// is uninitialized), or the raw JSON cannot be successfully unmarshaled into
// type T, Get returns the zero value of T and false. Otherwise, it returns
// the parsed value and true.
func Get[T any](c *DynamicClaims, key string) (T, bool) {
	var out T
	if c == nil {
		return out, false
	}
	// A lookup on a nil map safely yields the zero value and ok == false,
	// so no separate nil check on c.Other is required.
	val, ok := c.Other[key]
	if !ok {
		return out, false
	}
	if err := json.Unmarshal(val, &out); err != nil {
		// Return a fresh zero value in case Unmarshal partially
		// populated out before failing.
		var zero T
		return zero, false
	}
	return out, true
}
// dot is the byte value for the delimiting character of JWS segments.
const dot = byte('.')

// Parse decodes a JWT from its compact serialization format into a Token
// without verifying the signature. The type parameter T specifies the target
// struct for the token's claims. If the token is malformed or the payload does
// not unmarshal into T (using encoding/json/v2), an error is returned.
func Parse[T Claims](in []byte) (Token[T], error) {
	// A compact JWS reads "header.claims.signature". The first and last
	// dot delimit the three segments; any stray dot in between would land
	// inside the claims segment and fail Base64URL decoding below.
	i := bytes.IndexByte(in, dot)
	j := bytes.LastIndexByte(in, dot)
	// Reject: missing or leading dot (i <= 0), fewer than three segments
	// (i == j), or an empty trailing signature segment (j == len(in)-1).
	if i <= 0 || i == j || j == len(in)-1 {
		return nil, errors.New("expected three dot-separated segments")
	}
	h, err := decode(in[:i])
	if err != nil {
		return nil, fmt.Errorf("failed to decode header: %w", err)
	}
	header := new(header)
	if err := json.Unmarshal(h, header); err != nil {
		return nil, fmt.Errorf("failed to unmarshal header: %w", err)
	}
	// Tolerate an absent "typ" header, but reject explicit non-JWT types.
	if typ := header.Typ; typ != "" && typ != "JWT" {
		return nil, fmt.Errorf("unexpected token type %q", typ)
	}
	c, err := decode(in[i+1 : j])
	if err != nil {
		return nil, fmt.Errorf("failed to decode claims: %w", err)
	}
	var claims T
	if err := json.Unmarshal(c, &claims); err != nil {
		return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
	}
	sig, err := decode(in[j+1:])
	if err != nil {
		return nil, fmt.Errorf("failed to decode signature: %w", err)
	}
	// Keep the raw signing input (everything before the final dot) so the
	// signature can be verified later without re-encoding.
	msg := in[:j]
	return &token[T]{
		header: header,
		claims: claims,
		msg:    msg,
		sig:    sig,
	}, nil
}
// decode converts a Base64URL segment (unpadded, per JWS) back into its raw
// byte form. It returns an error for input outside the Base64URL alphabet.
func decode(src []byte) ([]byte, error) {
	enc := base64.RawURLEncoding
	buf := make([]byte, enc.DecodedLen(len(src)))
	n, err := enc.Decode(buf, src)
	if err != nil {
		return nil, err
	}
	return buf[:n], nil
}
// Verify is a shorthand for Parse followed by calling Verify on the
// resulting Token: it decodes a JWT from its compact serialization and
// checks the signature against the given key set. The type parameter T
// specifies the target struct for the token's claims.
//
// Only the cryptographic signature is checked, not the content of the
// claims. For claim validation (e.g., issuer, audience, expiration), create
// and configure a Verifier.
func Verify[T Claims](set jwk.Set, in []byte) (T, error) {
	var zero T
	tok, err := Parse[T](in)
	if err != nil {
		return zero, err
	}
	if err = tok.Verify(set); err != nil {
		return zero, err
	}
	return tok.Claims(), nil
}
// Claim-validation errors returned by Verifier.Verify. Match them with
// errors.Is.
var (
	// ErrInvalidIssuer signals that the "iss" claim did not match any of the
	// expected issuers.
	ErrInvalidIssuer = errors.New("invalid issuer")
	// ErrInvalidAudience signals that the "aud" claim did not match any of the
	// expected audiences.
	ErrInvalidAudience = errors.New("invalid audience")
	// ErrTokenExpired signals that the "exp" claim is in the past.
	ErrTokenExpired = errors.New("token is expired")
	// ErrTokenNotYetActive signals that the "nbf" claim is in the future.
	ErrTokenNotYetActive = errors.New("token not yet active")
	// ErrTokenTooOld signals that the "iat" claim is further in the past than
	// the configured maximum age.
	ErrTokenTooOld = errors.New("token is too old")
)
// Verifier is a configured, reusable JWT verifier. The type parameter T is the
// user-defined struct for the token's claims. It must implement the Claims
// interface, or else verification will always fail.
type Verifier[T Claims] struct {
	set       jwk.Set          // keys used for signature verification
	issuers   []string         // accepted "iss" values (empty = no check)
	audiences []string         // accepted "aud" values (empty = no check)
	leeway    time.Duration    // clock-skew tolerance for temporal checks
	age       time.Duration    // max token age via "iat" (0 = no check)
	now       func() time.Time // clock, overridable for tests
}
// WithIssuers adds one or more trusted issuers to the verifier. If a token's
// "iss" claim is missing or does not match one of these, it will be rejected.
// This option can be used multiple times to append additional values. By
// default, no issuer validation is performed.
//
// This method is not thread-safe and should be called only during setup.
func (v *Verifier[T]) WithIssuers(iss ...string) *Verifier[T] {
	v.issuers = append(v.issuers, iss...)
	return v
}

// WithAudiences adds one or more trusted audiences to the verifier. If the
// token's "aud" claim is missing or does not contain at least one of these
// values, it will be rejected. This option can be used multiple times to append
// additional values. By default, no audience validation is performed.
//
// This method is not thread-safe and should be called only during setup.
func (v *Verifier[T]) WithAudiences(aud ...string) *Verifier[T] {
	v.audiences = append(v.audiences, aud...)
	return v
}

// WithLeeway sets a grace period to allow for clock skew in temporal
// validations of the "exp", "nbf", and "iat" claims. It is subtracted from or
// added to the current time as appropriate. The default is zero, meaning no
// leeway. Negative values will be ignored.
//
// Note that because non-positive values are ignored, a previously set leeway
// cannot be reset back to zero through this method.
//
// This method is not thread-safe and should be called only during setup.
func (v *Verifier[T]) WithLeeway(d time.Duration) *Verifier[T] {
	if d > 0 {
		v.leeway = d
	}
	return v
}

// WithMaxAge sets the maximum age for tokens based on their "iat" claim.
// Tokens without an "iat" claim will no longer be accepted. The default is
// zero, meaning no age validation. Negative values will be ignored.
//
// This method is not thread-safe and should be called only during setup.
func (v *Verifier[T]) WithMaxAge(d time.Duration) *Verifier[T] {
	if d > 0 {
		v.age = d
	}
	return v
}

// WithClock sets the function used to retrieve the current time during
// validation. This is useful for deterministic testing or synchronizing with
// an external time source. The default is time.Now. A nil argument is
// ignored.
//
// This method is not thread-safe and should be called only during setup.
func (v *Verifier[T]) WithClock(now func() time.Time) *Verifier[T] {
	if now != nil {
		v.now = now
	}
	return v
}
// NewVerifier creates a new verifier bound to a specific JWK set.
// The type parameter T is the user-defined struct for the token's claims.
// Further configuration can be applied using the With... setters.
//
// The returned verifier uses time.Now as its clock; override it with
// WithClock for deterministic tests.
func NewVerifier[T Claims](set jwk.Set) *Verifier[T] {
	return &Verifier[T]{
		set: set,
		now: time.Now,
	}
}
// Verify parses a token from its compact serialization, verifies its
// signature against the verifier's key set, and validates its claims
// according to the verifier's configuration.
//
// Claim checks run in order: issuer, audience, "nbf", "exp", and finally the
// maximum age based on "iat". The configured leeway is applied to all
// temporal checks.
func (v *Verifier[T]) Verify(in []byte) (T, error) {
	var zero T
	c, err := Verify[T](v.set, in)
	if err != nil {
		return zero, err
	}
	now := v.now()
	if len(v.issuers) > 0 && !slices.Contains(v.issuers, c.Issuer()) {
		return zero, ErrInvalidIssuer
	}
	if len(v.audiences) > 0 {
		// At least one expected audience must appear in the token's
		// "aud" claim.
		found := false
		for _, aud := range v.audiences {
			if slices.Contains(c.Audience(), aud) {
				found = true
				break
			}
		}
		if !found {
			return zero, ErrInvalidAudience
		}
	}
	if nbf := c.NotBefore(); !nbf.IsZero() && now.Add(v.leeway).Before(nbf) {
		return zero, ErrTokenNotYetActive
	}
	if exp := c.ExpiresAt(); !exp.IsZero() && now.Add(-v.leeway).After(exp) {
		return zero, ErrTokenExpired
	}
	if v.age > 0 {
		// WithMaxAge documents that tokens without an "iat" claim are
		// rejected once a maximum age is configured. Previously a zero
		// "iat" silently skipped the age check entirely.
		iat := c.IssuedAt()
		if iat.IsZero() || iat.Add(v.age).Before(now.Add(-v.leeway)) {
			return zero, ErrTokenTooOld
		}
	}
	return c, nil
}
// Sign creates a new signed JWT using the provided KeyPair and claims.
//
// It marshals the claims using encoding/json/v2, derives a JOSE header from
// the key's properties, signs the "<header>.<claims>" message, and appends
// the encoded signature. The claims argument can be any type that serializes
// to a JSON object.
func Sign(k jwk.KeyPair, claims any) ([]byte, error) {
	// Build the JOSE header from the key's properties.
	hdr := &header{
		Typ: "JWT",
		Alg: k.Algorithm(),
		Kid: k.KeyID(),
		X5t: k.Thumbprint(),
	}
	rawHeader, err := json.Marshal(hdr)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal header: %w", err)
	}
	rawClaims, err := json.Marshal(claims)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal claims: %w", err)
	}
	sep := []byte{dot}
	// The signing input is "<base64 header>.<base64 claims>".
	msg := bytes.Join([][]byte{encode(rawHeader), encode(rawClaims)}, sep)
	sig, err := k.Sign(msg)
	if err != nil {
		return nil, fmt.Errorf("failed to sign token: %w", err)
	}
	// The final token is "<msg>.<base64 signature>".
	return bytes.Join([][]byte{msg, encode(sig)}, sep), nil
}
// encode renders src as unpadded Base64URL, the segment encoding mandated by
// the JWS compact serialization.
func encode(src []byte) []byte {
	enc := base64.RawURLEncoding
	out := make([]byte, enc.EncodedLen(len(src)))
	enc.Encode(out, src)
	return out
}
// Signer is a configured, reusable JWT creator. It allows setting default
// claims (like Issuer and Audience) and enforcing token lifetime (Expiration).
type Signer struct {
	rot rotor.Rotor[jwk.KeyPair] // round-robin pool of signing keys
	iat bool                     // stamp "iat" on issued tokens
	iss string                   // issuer applied to all tokens ("" = none)
	aud []string                 // audience applied to all tokens (nil = none)
	ttl time.Duration            // token lifetime (0 = no "exp" stamped)
	now func() time.Time         // clock, overridable for tests
}

// NewSigner creates a new Signer that uses the provided key pool for signing.
// At least one key pair must be provided; otherwise, it panics. If multiple
// keys are given, they will be rotated through in a round-robin fashion to
// ensure even usage across the key pool. Further configuration can be applied
// using the With... setters.
//
// NOTE(review): the panic on an empty key list is assumed to originate in
// rotor.New — confirm against that package.
func NewSigner(keys ...jwk.KeyPair) *Signer {
	return &Signer{
		rot: rotor.New(keys),
		iat: true,
		now: time.Now,
	}
}
// WithIssuedAt enables or disables automatic setting of the "iat" (Issued At)
// claim for all tokens created by this signer. It is enabled by default and
// will be stamped with the current time.
//
// This method is not thread-safe and should be called only during setup.
func (s *Signer) WithIssuedAt(use bool) *Signer {
	s.iat = use
	return s
}

// WithIssuer sets the "iss" (Issuer) claim for all tokens created by this
// signer. If the user-provided claims already contain an issuer, this
// configuration will overwrite it.
//
// This method is not thread-safe and should be called only during setup.
func (s *Signer) WithIssuer(iss string) *Signer {
	s.iss = iss
	return s
}

// WithAudience sets the "aud" (Audience) claim. If the user-provided claims
// already contain an audience, this configuration will overwrite it.
//
// This method is not thread-safe and should be called only during setup.
func (s *Signer) WithAudience(aud ...string) *Signer {
	s.aud = aud
	return s
}

// WithLifetime sets the duration for which tokens are valid. It calculates the
// "exp" (Expires At) claim by adding this duration to the current time.
// If zero (default), no "exp" claim is added unless provided in the input
// claims. Non-positive values are ignored, so a previously set lifetime
// cannot be reset to zero through this method.
//
// This method is not thread-safe and should be called only during setup.
func (s *Signer) WithLifetime(d time.Duration) *Signer {
	if d > 0 {
		s.ttl = d
	}
	return s
}

// WithClock sets the function used to retrieve the current time when
// timestamping tokens ("iat", "nbf", "exp"). This is useful for deterministic
// testing. The default is time.Now. A nil argument is ignored.
//
// This method is not thread-safe and should be called only during setup.
func (s *Signer) WithClock(now func() time.Time) *Signer {
	if now != nil {
		s.now = now
	}
	return s
}
// Sign applies the signer's configuration (issuer, audience, and temporal
// validity) directly to the mutable claims object, then signs it.
//
// Note that configured values overwrite any issuer, audience, or expiration
// already present on the claims, as documented on the With... setters.
func (s *Signer) Sign(claims MutableClaims) ([]byte, error) {
	now := s.now()
	// Stamp the time of issuance unless disabled via WithIssuedAt.
	if s.iat {
		claims.SetIssuedAt(now)
	}
	// Apply configured issuer name.
	if s.iss != "" {
		claims.SetIssuer(s.iss)
	}
	// Apply configured audience.
	if len(s.aud) > 0 {
		claims.SetAudience(s.aud)
	}
	// Calculate and apply expiration if a lifetime is configured.
	if s.ttl > 0 {
		claims.SetExpiresAt(now.Add(s.ttl))
	}
	// Rotate to the next key so usage spreads evenly across the pool.
	key := s.rot.Next()
	// Delegate to the low-level Sign function.
	return Sign(key, claims)
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package log provides a configurable constructor for the standard slog.Logger,
// allowing for easy setup using the functional options pattern.
//
// It simplifies the creation of a structured logger by abstracting away the
// handler setup and providing flexible options for setting the level, format,
// and output from common types like strings.
//
// # Usage:
//
// Create a logger that outputs JSON at a debug level to standard error:
//
// logger := log.New(
// log.WithLevel("debug"),
// log.WithFormat("json"),
// log.WithWriter(os.Stderr),
// log.WithAddSource(true), // Include file and line number.
// )
//
// slog.SetDefault(logger)
// slog.Debug("This is a debug message")
//
// Create a multi-target logger using Combine and NewHandler:
//
// h1 := log.NewHandler(log.WithLevel("debug"), log.WithFormat("text"), log.WithWriter(os.Stdout))
// h2 := log.NewHandler(log.WithLevel("error"), log.WithFormat("json"), log.WithWriter(os.Stderr))
// multiLogger := log.Combine(h1, h2)
//
// slog.SetDefault(multiLogger)
// slog.Debug("This is a debug message")
package log
import (
"fmt"
"io"
"log/slog"
"os"
"strings"
)
// Default configuration values for a new logger.
const (
	DefaultLevel     = slog.LevelInfo
	DefaultAddSource = false
	DefaultFormat    = FormatText
)

// Format defines the log output format, such as JSON or plain text.
// The zero value is FormatText.
type Format uint8

const (
	FormatText Format = iota // Human-readable text format.
	FormatJSON               // JSON format suitable for machine parsing.
)
// String returns the lower-case string representation of the log format:
// "json" for FormatJSON, and "text" for every other value.
func (f Format) String() string {
	if f == FormatJSON {
		return "json"
	}
	return "text"
}
// New creates and configures a new slog.Logger. By default, it logs at
// slog.LevelInfo in plain text to os.Stdout, without source information.
// These defaults can be overridden by passing in one or more Option functions.
// It is a convenience wrapper around NewHandler.
func New(opts ...Option) *slog.Logger {
	return slog.New(NewHandler(opts...))
}

// Combine creates a new slog.Logger that broadcasts log records to multiple
// provided slog.Handlers simultaneously using slog.NewMultiHandler.
// Build the individual handlers with NewHandler to give each target its own
// level, format, and writer.
func Combine(handlers ...slog.Handler) *slog.Logger {
	return slog.New(slog.NewMultiHandler(handlers...))
}
// NewHandler creates and configures a new slog.Handler. By default, it
// sets up a text handler logging at slog.LevelInfo to os.Stdout.
// These defaults can be overridden by passing in one or more Option functions.
func NewHandler(opts ...Option) slog.Handler {
	// Start from the package defaults, then layer the options on top.
	cfg := config{
		Level:     DefaultLevel,
		AddSource: DefaultAddSource,
		Format:    DefaultFormat,
		Writer:    os.Stdout,
	}
	for _, apply := range opts {
		apply(&cfg)
	}
	ho := &slog.HandlerOptions{
		Level:     cfg.Level,
		AddSource: cfg.AddSource,
	}
	if cfg.Format == FormatJSON {
		return slog.NewJSONHandler(cfg.Writer, ho)
	}
	return slog.NewTextHandler(cfg.Writer, ho)
}
// config holds the configuration settings for the logger.
type config struct {
	Level     slog.Level // minimum level a record must have to be logged
	AddSource bool       // include file/line of the logging call
	Format    Format     // text or JSON output
	Writer    io.Writer  // destination for log output
}

// Option defines a function that modifies the logger configuration.
type Option func(*config)
// WithLevel sets the minimum log level. It accepts either a slog.Level constant
// (e.g., slog.LevelDebug) or a case-insensitive string (e.g., "debug") as
// handled by ParseLevel. If an invalid string or type is provided, the option
// is a no-op.
func WithLevel(v any) Option {
	return func(c *config) {
		switch val := v.(type) {
		case slog.Level:
			c.Level = val
		case string:
			if parsed, err := ParseLevel(val); err == nil {
				c.Level = parsed
			}
		}
	}
}

// WithFormat sets the log output format. It accepts either a Format constant
// (FormatText or FormatJSON) or a case-insensitive string ("text" or "json")
// as handled by ParseFormat. If an invalid string or type is provided, the
// option is a no-op.
func WithFormat(v any) Option {
	return func(c *config) {
		switch val := v.(type) {
		case Format:
			c.Format = val
		case string:
			if parsed, err := ParseFormat(val); err == nil {
				c.Format = parsed
			}
		}
	}
}
// WithAddSource configures the logger to include the source code position (file
// and line number) in each log entry.
//
// Note that this has a performance cost and should be used judiciously, often
// enabled only during development or at debug levels.
func WithAddSource(add bool) Option {
	return func(c *config) {
		c.AddSource = add
	}
}

// WithWriter returns an Option that sets the output destination for the logs.
// If the provided io.Writer is nil, it is ignored and the previous (or
// default) writer remains in effect.
func WithWriter(w io.Writer) Option {
	return func(c *config) {
		if w != nil {
			c.Writer = w
		}
	}
}
// ParseLevel converts a case-insensitive string into a slog.Level.
// It accepts standard level names like "debug", "info", "warn", and "error".
// It returns an error if the string is not a valid level.
func ParseLevel(s string) (level slog.Level, err error) {
if e := level.UnmarshalText([]byte(s)); e != nil {
err = fmt.Errorf("invalid log level %q", s)
}
return
}
// ParseFormat converts a case-insensitive string into a Format.
// Valid inputs are "text" and "json". It returns an error for any other
// value, together with the zero Format (FormatText).
func ParseFormat(s string) (Format, error) {
	switch strings.ToLower(s) {
	case "json":
		return FormatJSON, nil
	case "text":
		return FormatText, nil
	default:
		return FormatText, fmt.Errorf("invalid log format %q", s)
	}
}
// Silent creates a logger that discards all output.
//
// It belts-and-braces: records are filtered by a level far above any
// standard slog level, and anything that slips through goes to io.Discard.
func Silent() *slog.Logger {
	const LevelSilent = slog.Level(100)
	return New(
		WithWriter(io.Discard),
		WithLevel(LevelSilent),
	)
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package cors provides a configurable CORS (Cross-Origin Resource Sharing)
// middleware for http.Handlers.
//
// # Usage
//
// The New function creates the middleware pipe, which can be configured with
// functional options (e.g., WithAllowedOrigins, WithAllowedMethods). The
// middleware automatically handles preflight (OPTIONS) requests and injects the
// appropriate CORS headers into responses for actual requests.
//
// Example:
//
// // Configure CORS to allow requests from a specific origin with
// // restricted methods and additional headers.
// pipe := cors.New(
// cors.WithAllowedOrigins("https://example.com"),
// cors.WithAllowedMethods(http.MethodGet, http.MethodOptions),
// cors.WithAllowedHeaders("Authorization", "Content-Type"),
// cors.WithMaxAge(12*time.Hour),
// )
//
// handler := http.HandlerFunc( ... )
// // Apply the CORS middleware as one of the first layers.
// chainedHandler := middleware.Chain(handler, pipe)
//
// http.ListenAndServe(":8080", chainedHandler)
package cors
import (
"net/http"
"slices"
"strconv"
"strings"
"time"
"github.com/deep-rent/nexus/middleware"
)
// wildcard is a special value that can be passed in configuration to allow
// requests from any origin.
const wildcard = "*"

// config stores the pre-computed configuration for internal use. Header
// values are pre-joined/pre-formatted once at setup so request handling
// only performs map lookups and header sets.
type config struct {
	allowedOrigins   map[string]struct{} // nil means wildcard mode (any origin)
	allowedMethods   string              // pre-joined Allow-Methods value ("" = omit)
	allowedHeaders   string              // pre-joined Allow-Headers value ("" = omit)
	exposedHeaders   string              // pre-joined Expose-Headers value ("" = omit)
	allowCredentials bool
	maxAge           string // pre-formatted Max-Age in seconds ("" = omit)
}

// Option is a function that configures the CORS middleware.
type Option func(*config)
// WithAllowedOrigins sets the allowed origins for CORS requests.
//
// By default, all origins are allowed. The same behavior can be achieved by
// leaving the list empty or by manually including the special wildcard "*".
// In other cases, this option restricts requests to a specific whitelist. If
// credentials are enabled via WithAllowCredentials, browsers forbid a wildcard
// origin, and this middleware will dynamically reflect the request's Origin
// header if it is in the allowed list.
func WithAllowedOrigins(origins ...string) Option {
	return func(c *config) {
		// An empty list or an explicit "*" means wildcard mode, which
		// is represented by leaving the map nil.
		if len(origins) == 0 || slices.Contains(origins, wildcard) {
			return
		}
		set := make(map[string]struct{}, len(origins))
		for _, origin := range origins {
			set[origin] = struct{}{}
		}
		c.allowedOrigins = set
	}
}
// WithAllowedMethods sets the allowed HTTP methods for CORS requests.
//
// If no methods are provided, this header is omitted by default, and only
// simple methods (GET, POST, HEAD) are implicitly allowed by browsers for
// non-preflighted requests. It is recommended to list all methods your API
// supports, including OPTIONS.
func WithAllowedMethods(methods ...string) Option {
	return func(c *config) {
		if len(methods) == 0 {
			return
		}
		c.allowedMethods = strings.Join(methods, ", ")
	}
}

// WithAllowedHeaders sets the allowed HTTP headers for CORS requests.
//
// This is necessary for any non-standard headers the client needs to send,
// such as "Authorization" or custom "X-" headers. If not set, browsers will
// only permit requests with CORS-safelisted request headers.
func WithAllowedHeaders(headers ...string) Option {
	return func(c *config) {
		if len(headers) == 0 {
			return
		}
		c.allowedHeaders = strings.Join(headers, ", ")
	}
}

// WithExposedHeaders sets the response headers that client-side scripts are
// allowed to read from a cross-origin response.
//
// By default, client-side scripts can only access a limited set of simple
// response headers. This option lists additional headers (like a custom
// "X-Pagination-Total" header) that should be made accessible to the script.
func WithExposedHeaders(headers ...string) Option {
	return func(c *config) {
		if len(headers) == 0 {
			return
		}
		c.exposedHeaders = strings.Join(headers, ", ")
	}
}
// WithAllowCredentials indicates whether the response to the request can be
// exposed when the credentials flag is true.
//
// When used as part of a response to a preflight request, it indicates that the
// actual request can include cookies and other user credentials. This option
// defaults to false. Note that browsers require a specific origin (not a
// wildcard) in the Access-Control-Allow-Origin header when this is enabled.
func WithAllowCredentials(allow bool) Option {
	return func(c *config) {
		c.allowCredentials = allow
	}
}

// WithMaxAge indicates how long the results of a preflight request can be
// cached by the browser, in seconds.
//
// If set to 0 (the default), the header is omitted. Be aware that browsers
// have a default internal limit (usually 5 seconds) when this header is
// missing. This results in a preflight request for almost every API call, which
// can double the traffic to your server. It is recommended to set this to a
// higher value (e.g., 10 minutes) for stable APIs to reduce latency.
func WithMaxAge(d time.Duration) Option {
	return func(c *config) {
		if d > 0 {
			// The duration is truncated to whole seconds, which is
			// the unit the header is defined in.
			c.maxAge = strconv.FormatInt(int64(d.Seconds()), 10)
		}
	}
}
// New creates a middleware Pipe that handles CORS based on the provided
// options.
//
// The middleware distinguishes between preflight and actual requests.
// Preflight (OPTIONS) requests are intercepted and terminated with a 204 No
// Content response. For actual requests, it adds the necessary CORS headers
// to the response before passing control to the next handler. Non-CORS
// requests are passed through without modification.
func New(opts ...Option) middleware.Pipe {
	var cfg config
	for _, apply := range opts {
		apply(&cfg)
	}
	return func(next http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			// handle reports whether the request still needs the
			// downstream handler.
			if handle(&cfg, w, r) {
				next.ServeHTTP(w, r)
			}
		}
		return http.HandlerFunc(fn)
	}
}
// handle processes CORS headers and returns true if the request should
// be passed to the next handler. It returns false if the request has been
// fully handled, such as in a preflight request.
func handle(cfg *config, w http.ResponseWriter, r *http.Request) bool {
	origin := r.Header.Get("Origin")
	// Pass through non-CORS requests.
	if origin == "" {
		return true
	}
	// Apply this header immediately to ensure caches respect the difference
	// between allowed and disallowed origin responses.
	h := w.Header()
	h.Add("Vary", "Origin")
	preflight := r.Method == http.MethodOptions
	if preflight {
		// Preflight results also depend on the requested method and
		// headers; without these Vary values a cache could replay a
		// preflight response for a request with a different method or
		// header set.
		h.Add("Vary", "Access-Control-Request-Method")
		h.Add("Vary", "Access-Control-Request-Headers")
		// Pass through invalid preflight requests.
		if r.Header.Get("Access-Control-Request-Method") == "" {
			return true
		}
	}
	// Validate origin if not in wildcard mode.
	if cfg.allowedOrigins != nil {
		if _, ok := cfg.allowedOrigins[origin]; !ok {
			return true // Let non-matching origins pass through without CORS headers.
		}
	}
	// In wildcard mode without credentials, "*" is both valid and cache
	// friendly; with credentials, browsers require the concrete origin.
	if !cfg.allowCredentials && cfg.allowedOrigins == nil {
		origin = wildcard
	}
	h.Set("Access-Control-Allow-Origin", origin)
	if cfg.allowCredentials {
		h.Set("Access-Control-Allow-Credentials", "true")
	}
	// Handle preflight requests.
	if preflight {
		if cfg.allowedMethods != "" {
			h.Set("Access-Control-Allow-Methods", cfg.allowedMethods)
		}
		if cfg.allowedHeaders != "" {
			h.Set("Access-Control-Allow-Headers", cfg.allowedHeaders)
		}
		if cfg.maxAge != "" {
			h.Set("Access-Control-Max-Age", cfg.maxAge)
		}
		w.WriteHeader(http.StatusNoContent)
		return false // Terminate request chain.
	}
	// Handle actual requests.
	if cfg.exposedHeaders != "" {
		h.Set("Access-Control-Expose-Headers", cfg.exposedHeaders)
	}
	return true
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package gzip provides an HTTP middleware for compressing response bodies
// using the gzip algorithm. It automatically adds the "Content-Encoding: gzip"
// header and compresses the payload for clients that support it (indicated by
// the "Accept-Encoding" request header).
//
// # Usage
//
// The middleware is designed to be efficient. It pools gzip writers to reduce
// memory allocations and gracefully skips compression for responses that
// already have a "Content-Encoding" header set.
//
// Example:
//
// // Create the final handler.
// handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// w.Header().Set("Content-Type", "text/plain")
// w.Write([]byte("This is a long string that will be compressed."))
// })
//
// // Create a gzip middleware pipe with the highest level of compression.
// pipe := gzip.New(
// gzip.WithCompressionLevel(gzip.BestCompression),
// gzip.WithExcludeMimeTypes("text/*", "application/font-woff"),
// )
//
// // Apply the gzip middleware as one of the first layers.
// chainedHandler := middleware.Chain(handler, pipe)
//
// http.ListenAndServe(":8080", chainedHandler)
package gzip
import (
"bufio"
"compress/gzip"
"errors"
"io"
"net"
"net/http"
"strings"
"sync"
"github.com/deep-rent/nexus/header"
"github.com/deep-rent/nexus/middleware"
)
// Mirror constants from the compress/gzip package for easy access without
// requiring an extra import. They can be passed wherever this package
// expects a compression level.
const (
	BestCompression    = gzip.BestCompression
	BestSpeed          = gzip.BestSpeed
	DefaultCompression = gzip.DefaultCompression
	NoCompression      = gzip.NoCompression
)
// defaultExcludeList lists common media types that are already compressed.
// Re-compressing them wastes CPU for little or no size benefit, so they are
// skipped by default. Entries ending in "*" act as prefix patterns.
var defaultExcludeList = []string{
	// Media
	"image/*",
	"video/*",
	"audio/*",
	// Fonts
	"font/*",
	// Archives & Documents
	"application/zip",
	"application/gzip",
	"application/pdf",
	"application/wasm",
}
// interceptor wraps an http.ResponseWriter to transparently compress the
// response body with gzip. It also implements http.Hijacker and http.Flusher to
// support protocol upgrades and streaming.
type interceptor struct {
	http.ResponseWriter
	gz       *gzip.Writer // nil until WriteHeader decides to compress
	level    int          // configured compression level
	exclude  []string     // media types exempt from compression ("*" suffix = prefix match)
	pool     *sync.Pool   // pool of reusable *gzip.Writer instances
	wrote    bool         // Tracks if WriteHeader has been called.
	hijacked bool         // Tracks if the connection has been hijacked.
	skip     bool         // Decide whether to skip compression.
}
// WriteHeader sets the Content-Encoding header and deletes Content-Length
// before writing the status code. Deleting Content-Length is crucial, as the
// size of the compressed content is unknown until it's fully written.
func (w *interceptor) WriteHeader(statusCode int) {
	// Only the first call wins; subsequent calls are ignored, mirroring
	// net/http semantics.
	if w.wrote {
		return
	}
	w.wrote = true
	// Never double-compress: respect an encoding already set by an inner
	// handler.
	if w.ResponseWriter.Header().Get("Content-Encoding") != "" {
		w.skip = true
	}
	// Skip compression for excluded media types. Entries ending in "*"
	// are prefix patterns (e.g. "image/*"); all others must match exactly.
	mime := header.MediaType(w.Header())
	if mime != "" {
		for _, t := range w.exclude {
			if strings.HasSuffix(t, "*") {
				if strings.HasPrefix(mime, t[:len(t)-1]) {
					w.skip = true
					break
				}
			} else {
				if mime == t {
					w.skip = true
					break
				}
			}
		}
	}
	if !w.skip {
		w.Header().Set("Content-Encoding", "gzip")
		// The compressed size is unknown up front, so any declared
		// Content-Length would be wrong.
		w.Header().Del("Content-Length")
		// Borrow a pooled writer and point it at the real response.
		w.gz = w.pool.Get().(*gzip.Writer)
		w.gz.Reset(w.ResponseWriter)
	}
	w.ResponseWriter.WriteHeader(statusCode)
}
// Write sends body bytes to the client, compressing them unless WriteHeader
// decided to skip compression. The first write triggers an implicit
// WriteHeader(200), mirroring the behavior of net/http.
func (w *interceptor) Write(b []byte) (int, error) {
	if !w.wrote {
		w.WriteHeader(http.StatusOK)
	}
	if !w.skip {
		return w.gz.Write(b)
	}
	// Compression disabled for this response: pass bytes through untouched.
	return w.ResponseWriter.Write(b)
}
// Close finalizes the compressed stream and recycles the gzip writer.
//
// If the connection was hijacked, the gzip footer must not be written (the
// protocol has switched), so the writer is recycled without closing.
func (w *interceptor) Close() {
	gz := w.gz
	if gz == nil {
		return
	}
	w.gz = nil
	if !w.hijacked {
		// Flushes buffered data and emits the trailing gzip footer.
		_ = gz.Close()
	}
	// Detach from the response writer before pooling so the pooled
	// compressor does not pin the connection.
	gz.Reset(io.Discard)
	w.pool.Put(gz)
}
// Hijack implements http.Hijacker by delegating to the wrapped
// ResponseWriter, allowing the underlying connection to be taken over for
// protocol upgrades like WebSockets. It fails when the underlying writer
// does not support hijacking.
func (w *interceptor) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if hijacker, ok := w.ResponseWriter.(http.Hijacker); ok {
		// Remember the takeover so Close skips the gzip footer.
		w.hijacked = true
		return hijacker.Hijack()
	}
	return nil, nil, errors.New("hijacking not supported")
}
// Flush implements the http.Flusher interface, enabling incremental flushing
// of the response body, which is useful for streaming data.
//
// Flushing the underlying writer commits the response headers, so the
// compression decision must be made first: without this, a handler that
// flushes before writing would send headers lacking Content-Encoding, while
// later writes would still be gzip-compressed.
func (w *interceptor) Flush() {
	flusher, ok := w.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	// Make sure WriteHeader has run before the headers go out.
	if !w.wrote {
		w.WriteHeader(http.StatusOK)
	}
	if w.gz != nil {
		_ = w.gz.Flush()
	}
	flusher.Flush()
}
// Compile-time checks that interceptor satisfies the contracts it advertises.
var (
	_ http.ResponseWriter = (*interceptor)(nil)
	_ http.Hijacker       = (*interceptor)(nil)
	_ http.Flusher        = (*interceptor)(nil)
)
// New builds a middleware Pipe that gzip-compresses HTTP responses according
// to the given options.
//
// The middleware does nothing when the client's Accept-Encoding does not
// include "gzip", or when the response already carries a non-empty
// Content-Encoding. It adds "Vary: Accept-Encoding" to handled responses so
// shared caches never serve the wrong representation.
func New(opts ...Option) middleware.Pipe {
	cfg := config{
		level:   DefaultCompression,
		exclude: defaultExcludeList,
	}
	for _, opt := range opts {
		opt(&cfg)
	}
	// Pool writers to avoid re-allocating compressor state per request.
	pool := &sync.Pool{
		New: func() any {
			// The error is impossible here: levels are validated by
			// WithCompressionLevel before they reach the pool.
			gw, _ := gzip.NewWriterLevel(io.Discard, cfg.level)
			return gw
		},
	}
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			accepts := header.Accepts(r.Header.Get("Accept-Encoding"), "gzip")
			if !accepts || w.Header().Get("Content-Encoding") != "" {
				// Client can't take gzip, or the response is pre-encoded.
				next.ServeHTTP(w, r)
				return
			}
			gzw := &interceptor{
				ResponseWriter: w,
				level:          cfg.level,
				exclude:        cfg.exclude,
				pool:           pool,
			}
			defer gzw.Close()
			// Signal caches that the body depends on Accept-Encoding.
			gzw.Header().Add("Vary", "Accept-Encoding")
			next.ServeHTTP(gzw, r)
		})
	}
}
// config holds the middleware configuration.
type config struct {
	level   int      // gzip compression level passed to gzip.NewWriterLevel
	exclude []string // lowercase MIME types/prefixes that skip compression
}

// Option is a function that configures the middleware.
type Option func(*config)
// WithCompressionLevel selects the gzip compression level. Accepted values
// range from NoCompression (0) and BestSpeed (1) up to BestCompression (9).
// Anything outside that range is ignored, leaving DefaultCompression in
// effect — a good balance between speed and compression ratio.
func WithCompressionLevel(level int) Option {
	return func(c *config) {
		if level < NoCompression || level > BestCompression {
			return // out of range: keep the configured default
		}
		c.level = level
	}
}
// WithExcludeMimeTypes appends MIME types to the list of content types that
// are never compressed. The option is additive: it can be applied multiple
// times and always extends (never replaces) the default exclusion list.
//
// Two match forms are supported:
//
//   - Exact: the full MIME type, e.g. "application/pdf".
//   - Prefix: a trailing "*" wildcard, e.g. "image/*", matching every
//     subtype of that primary type.
func WithExcludeMimeTypes(types ...string) Option {
	return func(c *config) {
		// Store lowercase so matching against the response type is
		// effectively case-insensitive.
		lowered := make([]string, len(types))
		for i, t := range types {
			lowered[i] = strings.ToLower(t)
		}
		c.exclude = append(c.exclude, lowered...)
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package middleware provides a standard approach for chaining and composing
// HTTP middleware.
//
// # Usage
//
// The core type is Pipe, an adapter that wraps an http.Handler to add
// functionality. The Chain function composes these pipes into a single handler.
// The package also includes common middleware like Recover for panic handling,
// RequestID for tracing, and Log for request logging.
//
// Example:
//
// handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// w.Write([]byte("OK"))
// })
//
// // Chain middleware around the final handler.
// // Order matters: Recover must be first (outermost).
// logger := slog.Default()
// chainedHandler := middleware.Chain(handler,
// middleware.Recover(logger),
// middleware.RequestID(),
// middleware.Log(logger),
// )
//
// http.ListenAndServe(":8080", chainedHandler)
package middleware
import (
"context"
"crypto/rand"
"encoding/hex"
"log/slog"
"net/http"
"runtime/debug"
"strconv"
"strings"
"time"
)
// Pipe is a middleware function. It's an adapter that takes an http.Handler
// and returns a new http.Handler, allowing functionality to be composed in
// layers.
type Pipe func(http.Handler) http.Handler
// Chain combines a handler with multiple middleware Pipes. The pipes are
// applied in reverse order, meaning the first pipe in the list is the outermost
// and executes first.
//
// For example, Chain(h, A, B, C) results in a handler equivalent to A(B(C(h))).
// Any nil pipes in the list are safely ignored.
func Chain(h http.Handler, pipes ...Pipe) http.Handler {
for i := len(pipes) - 1; i >= 0; i-- {
if pipe := pipes[i]; pipe != nil {
h = pipe(h)
}
}
return h
}
// Recover produces a Pipe that traps panics raised by downstream handlers.
// The panic is reported through the supplied logger together with a stack
// trace and the offending request's method and URL, and the client receives
// an empty 500 response. Place this first (outermost) in the chain so it
// covers all other middleware.
func Recover(logger *slog.Logger) Pipe {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			defer func() {
				rec := recover()
				if rec == nil {
					return
				}
				logger.Error(
					"Panic caught by middleware",
					"method", r.Method,
					"url", r.URL.String(),
					"error", rec,
					"stack", string(debug.Stack()),
				)
				w.WriteHeader(http.StatusInternalServerError)
			}()
			next.ServeHTTP(w, r)
		})
	}
}
// contextKey is a private key type for context values, preventing collisions
// with values stored by other packages.
type contextKey string

// requestIDKey is the key under which the request ID is stored in the request
// context.
const requestIDKey = contextKey("RequestID")
// RequestID returns a Pipe that tags every request with a fresh random ID.
// The ID is exposed to the client through the "X-Request-ID" response header
// and to downstream handlers through the request context (see GetRequestID).
//
// If the random source cannot produce an ID, the request proceeds untouched.
func RequestID() Pipe {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			var buf [16]byte
			if _, err := rand.Read(buf[:]); err != nil {
				// No entropy available; skip tagging rather than fail.
				next.ServeHTTP(w, r)
				return
			}
			id := hex.EncodeToString(buf[:])
			w.Header().Set("X-Request-ID", id)
			ctx := SetRequestID(r.Context(), id)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// GetRequestID extracts the request ID carried by ctx. It yields the empty
// string when no ID has been stored.
func GetRequestID(ctx context.Context) string {
	if id, ok := ctx.Value(requestIDKey).(string); ok {
		return id
	}
	return ""
}
// SetRequestID sets the request ID in the provided context, returning a new
// context that carries the ID. The ID can later be read back with
// GetRequestID.
func SetRequestID(ctx context.Context, id string) context.Context {
	return context.WithValue(ctx, requestIDKey, id)
}
// interceptor is used to wrap the original http.ResponseWriter to
// capture the status code.
//
// NOTE(review): the wrapper only re-implements WriteHeader; it does not
// forward http.Flusher, http.Hijacker, or http.Pusher, so handlers that
// type-assert the writer for those interfaces won't find them. Confirm this
// is acceptable for streaming/upgrade endpoints placed behind Log.
type interceptor struct {
	http.ResponseWriter
	statusCode int // last status code passed to WriteHeader
}

// WriteHeader captures the status code before calling the original WriteHeader.
func (i *interceptor) WriteHeader(code int) {
	i.statusCode = code
	i.ResponseWriter.WriteHeader(code)
}
// Log returns a Pipe that emits a debug-level summary for every HTTP request
// after it has been handled. The final status code is observed by wrapping
// the http.ResponseWriter. The entry covers method, URL, status, duration,
// and other common attributes; place this after the RequestID middleware if
// the log should include the request ID.
func Log(logger *slog.Logger) Pipe {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			began := time.Now()
			// Default to 200: handlers that never call WriteHeader send it
			// implicitly.
			wrapped := &interceptor{ResponseWriter: w, statusCode: http.StatusOK}
			next.ServeHTTP(wrapped, r)
			logger.Debug(
				"HTTP request handled",
				slog.String("id", GetRequestID(r.Context())),
				slog.String("method", r.Method),
				slog.String("url", r.URL.String()),
				slog.String("remote", r.RemoteAddr),
				slog.String("agent", r.UserAgent()),
				slog.Int("status", wrapped.statusCode),
				slog.Duration("duration", time.Since(began)),
			)
		})
	}
}
// Volatile returns a middleware Pipe that prevents caching of the response.
// It sets the standard anti-cache headers (Cache-Control, Pragma, Expires)
// so clients and proxies always fetch a fresh copy of the resource.
func Volatile() Pipe {
	// The directive list never varies, so spell it out once.
	const control = "no-store, no-cache, must-revalidate, proxy-revalidate"
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()
			h.Set("Cache-Control", control)
			h.Set("Pragma", "no-cache")
			h.Set("Expires", "0")
			next.ServeHTTP(w, r)
		})
	}
}
// SecurityConfig defines the headers applied by the Secure middleware.
// Zero-valued fields leave the corresponding header unset.
type SecurityConfig struct {
	// STSMaxAge is the maximum age for HSTS (Strict-Transport-Security) in
	// seconds. If 0, the header is not set.
	STSMaxAge int64
	// STSIncludeSubdomains adds the "includeSubDomains" directive to HSTS.
	// Only effective when STSMaxAge is positive.
	STSIncludeSubdomains bool
	// FrameOptions sets the X-Frame-Options header (e.g., "DENY", "SAMEORIGIN").
	// If empty, the header is not set.
	FrameOptions string
	// NoSniff sets X-Content-Type-Options to "nosniff" if true.
	// This helps prevent MIME type sniffing by browsers.
	NoSniff bool
	// CSP sets the Content-Security-Policy header.
	// If empty, it is not set.
	CSP string
	// ReferrerPolicy sets the Referrer-Policy header.
	// If empty, it is not set.
	ReferrerPolicy string
	// PermissionsPolicy sets the Permissions-Policy header.
	// Example: "geolocation=(), microphone=()"
	// If empty, it is not set.
	PermissionsPolicy string
	// CrossOriginOpenerPolicy sets the Cross-Origin-Opener-Policy header.
	// Recommended: "same-origin". If empty, it is not set.
	CrossOriginOpenerPolicy string
}
// DefaultSecurityConfig provides a baseline configuration that enables HSTS
// for 1 year (including subdomains), disables MIME sniffing, protects against
// clickjacking by denying framing, locks down powerful browser features, and
// isolates the browsing context from cross-origin openers.
var DefaultSecurityConfig = SecurityConfig{
	STSMaxAge:               31536000, // 365 days, in seconds
	STSIncludeSubdomains:    true,
	FrameOptions:            "DENY",
	NoSniff:                 true,
	PermissionsPolicy:       "geolocation=(), microphone=(), camera=(), payment=()",
	CrossOriginOpenerPolicy: "same-origin",
}
// Secure returns a middleware Pipe that applies the security-related HTTP
// response headers described by cfg. Headers whose configuration is empty or
// zero are omitted, except X-Permitted-Cross-Domain-Policies, which is always
// set to "none".
func Secure(cfg SecurityConfig) Pipe {
	// Assemble the HSTS value once up front instead of per request.
	var hsts string
	if cfg.STSMaxAge > 0 {
		hsts = "max-age=" + strconv.FormatInt(cfg.STSMaxAge, 10)
		if cfg.STSIncludeSubdomains {
			hsts += "; includeSubDomains"
		}
	}
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()
			// Helper: set the header only when a value is configured.
			set := func(key, value string) {
				if value != "" {
					h.Set(key, value)
				}
			}
			set("Strict-Transport-Security", hsts)
			if cfg.NoSniff {
				h.Set("X-Content-Type-Options", "nosniff")
			}
			set("X-Frame-Options", cfg.FrameOptions)
			set("Content-Security-Policy", cfg.CSP)
			set("Referrer-Policy", cfg.ReferrerPolicy)
			set("Permissions-Policy", cfg.PermissionsPolicy)
			set("Cross-Origin-Opener-Policy", cfg.CrossOriginOpenerPolicy)
			// Hardening against legacy PDF/Flash cross-domain loading.
			h.Set("X-Permitted-Cross-Domain-Policies", "none")
			next.ServeHTTP(w, r)
		})
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package proxy provides a configurable reverse proxy handler. It constructs an
// httputil.ReverseProxy, starting with sensible defaults, integrating a
// reusable buffer pool, structured logging, and robust error handling
// via a functional options API.
package proxy
import (
"context"
"errors"
"log/slog"
"net/http"
"net/http/httputil"
"net/url"
"time"
"github.com/deep-rent/nexus/internal/buffer"
)
// Default bounds for the proxy's pooled copy buffers; override with
// WithMinBufferSize and WithMaxBufferSize.
const (
	// DefaultMinBufferSize is the default minimum size of pooled buffers.
	DefaultMinBufferSize = 32 << 10 // 32 KiB
	// DefaultMaxBufferSize is the default maximum size of pooled buffers.
	DefaultMaxBufferSize = 256 << 10 // 256 KiB
)
// Handler is an alias of http.Handler representing a reverse proxy. The alias
// keeps callers decoupled from the concrete httputil.ReverseProxy type.
type Handler = http.Handler
// NewHandler builds a reverse proxy Handler that forwards requests to the
// target URL. Behavior is customized through functional options; unset
// options fall back to sensible defaults (a cloned http.DefaultTransport,
// pooled copy buffers, the default rewrite and error handling, and
// slog.Default()).
func NewHandler(target *url.URL, opts ...HandlerOption) Handler {
	cfg := handlerConfig{
		transport:       http.DefaultTransport.(*http.Transport).Clone(),
		flushInterval:   0,
		minBufferSize:   DefaultMinBufferSize,
		maxBufferSize:   DefaultMaxBufferSize,
		newRewrite:      NewRewrite,
		newErrorHandler: NewErrorHandler,
		logger:          slog.Default(),
	}
	for _, opt := range opts {
		opt(&cfg)
	}
	// Keep the pool bounds consistent if the options disagree.
	cfg.minBufferSize = min(cfg.minBufferSize, cfg.maxBufferSize)
	// The base rewrite adds the X-Forwarded-* headers and points the request
	// at the target host.
	base := func(pr *httputil.ProxyRequest) {
		pr.SetXForwarded()
		pr.SetURL(target)
	}
	// Construct ReverseProxy directly to avoid the deprecated Director hook
	// set by NewSingleHostReverseProxy.
	return &httputil.ReverseProxy{
		Rewrite:       cfg.newRewrite(base),
		ErrorHandler:  cfg.newErrorHandler(cfg.logger),
		Transport:     cfg.transport,
		BufferPool:    buffer.NewPool(cfg.minBufferSize, cfg.maxBufferSize),
		FlushInterval: cfg.flushInterval,
	}
}
// RewriteFunc defines a function to modify the proxy request before it is sent
// to the upstream target.
//
// The signature matches httputil.ReverseProxy.Rewrite.
type RewriteFunc func(*httputil.ProxyRequest)
// RewriteFactory creates a RewriteFunc using the provided original RewriteFunc.
// The returned RewriteFunc may call original to retain its behavior.
type RewriteFactory = func(original RewriteFunc) RewriteFunc
// NewRewrite is the default RewriteFactory for the proxy.
// It returns the original RewriteFunc unmodified.
func NewRewrite(original RewriteFunc) RewriteFunc {
// The default rewrite need not be overridden; it already sets the
// X-Forwarded-Host, X-Forwarded-Proto, and X-Forwarded-For headers, which is
// exactly what most proxies expect. It also correctly rewrites the Host header
// to match the target (required for sidecar setups to function).
return original
}
// ErrorHandler defines a function for handling errors that occur during the
// reverse proxy's operation.
//
// The signature matches httputil.ReverseProxy.ErrorHandler.
type ErrorHandler = func(http.ResponseWriter, *http.Request, error)
// ErrorHandlerFactory creates an ErrorHandler using the provided logger.
// It receives the configured logger to be used for error reporting.
type ErrorHandlerFactory = func(*slog.Logger) ErrorHandler
// NewErrorHandler is the default ErrorHandlerFactory for the proxy.
// It creates an error handler that logs upstream errors using
// the provided logger and maps them to appropriate HTTP status codes.
func NewErrorHandler(logger *slog.Logger) ErrorHandler {
return func(w http.ResponseWriter, r *http.Request, err error) {
if errors.Is(err, context.Canceled) {
// Silence client-initiated disconnects; there's nothing useful to send
return
}
status := http.StatusBadGateway
method, uri := r.Method, r.RequestURI
if errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, http.ErrHandlerTimeout) {
status = http.StatusGatewayTimeout
logger.ErrorContext(
r.Context(),
"Upstream request timed out",
slog.String("method", method),
slog.String("uri", uri),
)
} else {
logger.ErrorContext(
r.Context(),
"Upstream request failed",
slog.String("method", method),
slog.String("uri", uri),
slog.Any("error", err),
)
}
w.WriteHeader(status)
}
}
// handlerConfig holds the configurable settings for the proxy handler.
type handlerConfig struct {
	transport       *http.Transport     // upstream transport (cloned default if unset)
	flushInterval   time.Duration       // value for ReverseProxy.FlushInterval
	minBufferSize   int                 // lower bound for pooled copy buffers
	maxBufferSize   int                 // upper bound for pooled copy buffers
	newRewrite      RewriteFactory      // builds the request rewrite hook
	newErrorHandler ErrorHandlerFactory // builds the upstream error handler
	logger          *slog.Logger        // sink passed to the error handler
}

// HandlerOption defines a function for setting reverse proxy options.
type HandlerOption func(*handlerConfig)
// WithTransport overrides the http.Transport used for upstream requests.
//
// Use it to tune connection pooling, timeouts (e.g., Dial, TLSHandshake),
// and keep-alives. A nil transport is ignored.
func WithTransport(t *http.Transport) HandlerOption {
	return func(cfg *handlerConfig) {
		if t == nil {
			return
		}
		cfg.transport = t
	}
}
// WithFlushInterval sets how often the proxy flushes copied response data to
// the client.
//
// Zero (the default) disables periodic flushing; a negative value flushes
// immediately after every write. The proxy recognizes streaming responses on
// its own and ignores the interval for them.
//
// Lower this value if fully buffered responses introduce unwanted latency;
// the trade-off is increased CPU usage.
func WithFlushInterval(d time.Duration) HandlerOption {
	return func(cfg *handlerConfig) {
		cfg.flushInterval = d
	}
}
// WithMinBufferSize sets the minimum allocation size of buffers from the
// buffer pool, reducing re-allocations when proxying large response bodies.
//
// Non-positive values are ignored, leaving DefaultMinBufferSize in place.
// The effective value is capped at the configured maximum buffer size.
//
// The pool adapts to larger, common responses automatically, and the
// maximum protects against memory bloat; override this only when profiling
// shows that nearly all responses (say, 99%) exceed the default.
func WithMinBufferSize(n int) HandlerOption {
	return func(cfg *handlerConfig) {
		if n <= 0 {
			return
		}
		cfg.minBufferSize = n
	}
}
// WithMaxBufferSize sets the largest buffer size the pool will retain.
// Buffers that grow beyond this size are discarded after use to prevent
// memory bloat.
//
// Non-positive values are ignored, leaving DefaultMaxBufferSize in place.
//
// This is a critical tuning parameter: if typical (e.g., P95) responses are
// larger than this value, most buffers get discarded instead of reused and
// the pool becomes ineffective.
func WithMaxBufferSize(n int) HandlerOption {
	return func(cfg *handlerConfig) {
		if n <= 0 {
			return
		}
		cfg.maxBufferSize = n
	}
}
// WithRewrite installs a custom RewriteFactory for the proxy.
//
// A nil factory is ignored; NewRewrite remains the default.
func WithRewrite(f RewriteFactory) HandlerOption {
	return func(cfg *handlerConfig) {
		if f == nil {
			return
		}
		cfg.newRewrite = f
	}
}
// WithErrorHandler installs a custom ErrorHandlerFactory for the proxy.
//
// A nil factory is ignored; NewErrorHandler remains the default.
func WithErrorHandler(f ErrorHandlerFactory) HandlerOption {
	return func(cfg *handlerConfig) {
		if f == nil {
			return
		}
		cfg.newErrorHandler = f
	}
}
// WithLogger sets the logger consumed by the proxy's ErrorHandler (the
// default handler, NewErrorHandler, reports upstream errors through it).
//
// A nil logger is ignored; slog.Default() remains the default.
func WithLogger(log *slog.Logger) HandlerOption {
	return func(cfg *handlerConfig) {
		if log == nil {
			return
		}
		cfg.logger = log
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package retry is an http.RoundTripper middleware that provides automatic,
// policy-driven retries for HTTP requests.
//
// It wraps an existing http.RoundTripper (such as http.DefaultTransport)
// and intercepts requests to apply retry logic. The decision to retry is
// controlled by a Policy, and the delay between attempts is determined by a
// backoff.Strategy.
//
// # Usage
//
// A new transport is created with NewTransport, configured with functional
// options like WithAttemptLimit and WithBackoff.
//
// Example:
//
// // Retry up to 3 times with exponential backoff starting at 1 second.
// transport := retry.NewTransport(
// http.DefaultTransport,
// retry.WithAttemptLimit(3),
// retry.WithBackoff(backoff.New(
// backoff.WithMinDelay(1*time.Second),
// )),
// )
//
// client := &http.Client{Transport: transport}
//
// // This request will be retried automatically on temporary failures.
// res, err := client.Get("http://example.com/flaky")
// if err != nil {
// slog.Error("Request failed after all retries", "error", err)
// return
// }
// defer res.Body.Close()
package retry
import (
"context"
"errors"
"io"
"log/slog"
"net"
"net/http"
"time"
"github.com/deep-rent/nexus/backoff"
"github.com/deep-rent/nexus/header"
)
// Attempt encapsulates the state of a single HTTP request attempt. It is passed
// to a Policy to determine if a retry is warranted.
type Attempt struct {
Request *http.Request
Response *http.Response
Error error
Count int
}
// Idempotent reports whether the request can be safely retried without
// unintended side effects. It considers standard HTTP methods that are
// idempotent according to RFC 7231.
func (a Attempt) Idempotent() bool {
switch a.Request.Method {
case
http.MethodGet,
http.MethodHead,
http.MethodOptions,
http.MethodTrace,
http.MethodPut,
http.MethodDelete:
return true
default:
return false
}
}
// Temporary reports whether the response indicates a server-side temporary
// failure. This is determined by specific HTTP status codes that suggest the
// request might succeed if retried.
func (a Attempt) Temporary() bool {
if a.Response != nil {
switch a.Response.StatusCode {
case
http.StatusRequestTimeout, // 408
http.StatusTooManyRequests, // 429
http.StatusInternalServerError, // 500
http.StatusBadGateway, // 502
http.StatusServiceUnavailable, // 503
http.StatusGatewayTimeout: // 504
return true
}
}
return false
}
// Transient reports whether the error suggests a temporary network-level
// issue that might be resolved on a subsequent attempt. It returns true for
// network timeouts and unexpected EOF errors.
//
// It returns false for context cancellations (context.Canceled,
// context.DeadlineExceeded), as these are intentional and should not be
// retried.
func (a Attempt) Transient() bool {
if a.Error == nil ||
errors.Is(a.Error, context.Canceled) ||
errors.Is(a.Error, context.DeadlineExceeded) {
return false
}
if errors.Is(a.Error, io.ErrUnexpectedEOF) || errors.Is(a.Error, io.EOF) {
return true
}
var err net.Error
return errors.As(a.Error, &err) && err.Timeout()
}
// Policy decides after each attempt whether the request should be retried.
// Returning true schedules another attempt; returning false stops the loop
// and yields the last response/error to the caller.
type Policy func(a Attempt) bool

// LimitAttempts caps the wrapped policy at n total attempts. Once the
// attempt count reaches n the decision is false regardless of the wrapped
// policy; otherwise the decision is delegated. A limit of n means a request
// runs at most n times (one initial attempt plus n-1 retries); 1 disables
// retries, and n <= 0 leaves the policy unlimited.
func (p Policy) LimitAttempts(n int) Policy {
	if n <= 0 {
		return p
	}
	return func(a Attempt) bool {
		if a.Count >= n {
			return false
		}
		return p(a)
	}
}

// DefaultPolicy returns the safe baseline strategy: retry only idempotent
// requests whose failure was a temporary server error or a transient network
// error such as a timeout.
func DefaultPolicy() Policy {
	return func(a Attempt) bool {
		if !a.Idempotent() {
			return false
		}
		return a.Temporary() || a.Transient()
	}
}
// transport is an http.RoundTripper decorator that retries failed requests
// according to a Policy, pacing attempts with a backoff.Strategy.
type transport struct {
	next    http.RoundTripper // underlying transport performing the I/O
	policy  Policy            // decides whether an attempt is retried
	backoff backoff.Strategy  // yields the delay before each retry
	logger  *slog.Logger      // debug-level logging of retry attempts
	now     func() time.Time  // clock, injectable for tests
}
// RoundTrip executes a single HTTP transaction, applying retry logic as
// configured. It is the implementation of the http.RoundTripper interface.
//
// For a request to be retryable, its body must be rewindable. This is
// achieved by setting the http.Request.GetBody field. If GetBody is nil,
// the request is attempted only once, as its body stream cannot be read
// a second time.
//
// RoundTrip is responsible for handling the response body. On a successful
// attempt (or the final failed attempt), the response body is returned to the
// caller, who is responsible for closing it. On intermediary failed attempts,
// the response body is fully read and closed to ensure the underlying
// connection can be reused.
//
// The retry loop is sensitive to the request's context. If the context is
// cancelled, the retry loop terminates immediately.
func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
	var (
		res   *http.Response
		err   error
		count int
	)
	// NOTE(review): Done is called once per RoundTrip — presumably it resets
	// per-call backoff state; confirm against the backoff package if the
	// transport is shared across concurrent requests.
	defer t.backoff.Done()
	rewindable := req.GetBody != nil
	for {
		count++
		// If this is a retry and the body is rewindable, obtain a new reader.
		if count > 1 && rewindable {
			var e error
			req.Body, e = req.GetBody()
			if e != nil {
				// Cannot rewind the body, so we must stop here.
				return nil, e
			}
		}
		res, err = t.next.RoundTrip(req)
		// Ask the policy if we should retry.
		if !t.policy(Attempt{
			Request:  req,
			Response: res,
			Error:    err,
			Count:    count,
		}) {
			break // Success or policy decided to exit
		}
		// Check if the request body is rewindable. If not, we must stop here.
		// This is checked after the policy to ensure the policy still gets notified
		// of the attempt.
		if req.Body != nil && !rewindable {
			break
		}
		// If retrying, drain and close the previous response body.
		// Draining lets the transport reuse the underlying connection.
		if res != nil && res.Body != nil {
			if _, err := io.Copy(io.Discard, res.Body); err != nil {
				t.logger.Warn("Failed to drain response body", slog.Any("error", err))
			}
			if err := res.Body.Close(); err != nil {
				t.logger.Warn("Failed to close response body", slog.Any("error", err))
			}
		}
		delay := t.backoff.Next()
		if res != nil {
			// Honor server-advertised throttling (e.g. Retry-After).
			if d := header.Throttle(res.Header, t.now); d != 0 {
				// Use the longer of the two delays to respect both the server
				// and our own backoff policy.
				delay = max(delay, d)
			}
		}
		// Emit a debug line only when the level is enabled, to avoid building
		// the attribute slice for nothing.
		if ctx := req.Context(); t.logger.Enabled(ctx, slog.LevelDebug) {
			attrs := []any{
				slog.Int("attempt", count),
				slog.Duration("delay", delay),
				slog.String("method", req.Method),
				slog.String("url", req.URL.String()),
			}
			if err != nil {
				attrs = append(attrs, slog.Any("error", err))
			}
			if res != nil {
				attrs = append(attrs, slog.Int("status", res.StatusCode))
			}
			t.logger.DebugContext(ctx, "Request attempt failed, retrying", attrs...)
		}
		if delay <= 0 {
			continue // Retry without delay
		}
		// Wait for the delay, respecting context cancellation.
		select {
		case <-time.After(delay):
			continue // Proceed to next attempt
		case <-req.Context().Done():
			return nil, req.Context().Err()
		}
	}
	return res, err
}
// Compile-time check that transport satisfies http.RoundTripper.
var _ http.RoundTripper = (*transport)(nil)

// NewTransport wraps next in retry behavior driven by the configured policy
// and backoff strategy, returning the decorated http.RoundTripper.
func NewTransport(
	next http.RoundTripper,
	opts ...Option,
) http.RoundTripper {
	// Defaults: safe policy, no attempt cap, zero delay, default logger/clock.
	cfg := config{
		policy:  DefaultPolicy(),
		limit:   0,
		backoff: backoff.Constant(0),
		logger:  slog.Default(),
		now:     time.Now,
	}
	for _, opt := range opts {
		opt(&cfg)
	}
	t := &transport{
		next:    next,
		policy:  cfg.policy.LimitAttempts(cfg.limit),
		backoff: cfg.backoff,
		logger:  cfg.logger,
		now:     cfg.now,
	}
	return t
}
// config collects the settings assembled by the functional options.
type config struct {
	policy  Policy           // retry decision function
	limit   int              // max attempts; <= 0 means unlimited
	backoff backoff.Strategy // delay source between attempts
	logger  *slog.Logger     // debug log sink
	now     func() time.Time // clock used for throttle calculations
}

// Option is a function that configures the retry transport.
type Option func(*config)
// WithPolicy overrides the retry policy used by the transport. A nil policy
// is ignored; DefaultPolicy remains the default.
func WithPolicy(policy Policy) Option {
	return func(c *config) {
		if policy == nil {
			return
		}
		c.policy = policy
	}
}
// WithAttemptLimit caps the total number of attempts per request, counting
// the initial one: 3 means one initial attempt plus up to two retries, and 1
// effectively disables retries. Values <= 0 impose no cap, leaving retries
// governed solely by the policy.
func WithAttemptLimit(n int) Option {
	return func(c *config) {
		c.limit = n
	}
}
// WithBackoff selects the strategy that computes the delay between retries.
// Without this option attempts run back-to-back with no delay. A nil
// strategy is ignored.
func WithBackoff(strategy backoff.Strategy) Option {
	return func(c *config) {
		if strategy == nil {
			return
		}
		c.backoff = strategy
	}
}
// WithLogger selects the logger for debug messages. A nil logger is ignored;
// slog.Default() remains the default.
func WithLogger(log *slog.Logger) Option {
	return func(c *config) {
		if log == nil {
			return
		}
		c.logger = log
	}
}
// WithClock injects a custom time source, primarily for testing. A nil
// function is ignored; time.Now remains the default.
func WithClock(now func() time.Time) Option {
	return func(c *config) {
		if now == nil {
			return
		}
		c.now = now
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package router provides a lightweight, JSON-centric wrapper around Go's
// native http.ServeMux.
//
// It simplifies building JSON APIs by offering a consolidated "Exchange" object
// for handling requests and responses, standardized error formatting, and a
// middleware chaining mechanism.
//
// Basic Usage:
//
// // 1. Setup the router with options
// logger := log.New()
// r := router.New(
// router.WithLogger(logger),
// router.WithMiddleware(middleware.Log(logger)),
// )
//
// // 2. Define a handler
// // You can use a closure, or a struct that satisfies the Handler interface.
// r.HandleFunc("POST /users", func(e *router.Exchange) error {
// var req CreateUserRequest
//
// // BindJSON enforces Content-Type and parses the body.
// // It returns a specific *router.Error type if validation fails.
// if err := e.BindJSON(&req); err != nil {
// return err
// }
//
// // ... Logic to save user ...
//
// // Return JSON response
// return e.JSON(http.StatusCreated, UserResponse{ID: "123"})
// })
//
// // 3. Start the server
// http.ListenAndServe(":8080", r)
package router
import (
	"context"
	"encoding/json/v2"
	"errors"
	"log/slog"
	"net/http"
	"net/url"

	"github.com/deep-rent/nexus/header"
	"github.com/deep-rent/nexus/middleware"
)
// Standard error reasons used for machine-readable error codes.
// Clients can branch on these instead of parsing human-readable text.
const (
	// ReasonWrongType indicates that the request had an unsupported content type.
	ReasonWrongType = "wrong_type"
	// ReasonEmptyBody indicates that the request body was empty.
	ReasonEmptyBody = "empty_body"
	// ReasonParseJSON indicates that there was an error parsing the JSON body.
	ReasonParseJSON = "parse_json"
	// ReasonParseForm indicates that there was an error parsing form data.
	ReasonParseForm = "parse_form"
	// ReasonServerError indicates that an unexpected internal error occurred.
	ReasonServerError = "server_error"
)

// Standard media types used in the Content-Type header.
const (
	// MediaTypeJSON is the media type for JSON content.
	MediaTypeJSON = "application/json"
	// MediaTypeForm is the media type for URL-encoded form data.
	MediaTypeForm = "application/x-www-form-urlencoded"
)
// ResponseWriter extends the standard http.ResponseWriter with introspection
// capabilities.
//
// It allows handlers and middleware to check if the response headers have
// already been written, which is crucial for robust error handling.
type ResponseWriter interface {
http.ResponseWriter
// Status returns the HTTP status code written, or 0 if not written yet.
Status() int
// Closed reports whether the headers have already been written.
// This indicates that the response is committed.
Closed() bool
// Unwrap returns the underlying http.ResponseWriter.
// This allows http.ResponseController to access features like Flush(),
// Hijack(), and SetReadDeadline().
Unwrap() http.ResponseWriter
}
// NewResponseWriter wraps an http.ResponseWriter into a ResponseWriter.
func NewResponseWriter(w http.ResponseWriter) ResponseWriter {
return &responseWriter{
ResponseWriter: w,
status: 0,
}
}
// responseWriter is the concrete implementation of ResponseWriter.
type responseWriter struct {
http.ResponseWriter
status int
}
func (rw *responseWriter) WriteHeader(code int) {
if rw.status != 0 {
return
}
rw.status = code
rw.ResponseWriter.WriteHeader(code)
}
func (rw *responseWriter) Write(b []byte) (int, error) {
if rw.status == 0 {
rw.WriteHeader(http.StatusOK)
}
return rw.ResponseWriter.Write(b)
}
func (rw *responseWriter) Status() int {
return rw.status
}
func (rw *responseWriter) Closed() bool {
return rw.status != 0
}
func (rw *responseWriter) Unwrap() http.ResponseWriter {
return rw.ResponseWriter
}
// Error describes the standardized shape of API errors returned to clients.
//
// A handler may return this struct directly to control the HTTP status code
// and the error details sent to the client. Plain Go errors returned by
// handlers are wrapped by the Router into a generic internal server error.
type Error struct {
	// Status is the HTTP status code (e.g., 400, 404, 500).
	Status int `json:"status"`
	// Reason is a short string identifying the error type (e.g.,
	// "invalid_input").
	Reason string `json:"reason"`
	// Description is a human-readable explanation of the error cause.
	Description string `json:"description"`
	// ID is a unique identifier of the specific occurrence for tracing purposes
	// (optional).
	ID string `json:"id,omitempty"`
	// Context contains arbitrary additional data about the error, such as
	// validation fields.
	Context map[string]any `json:"context,omitempty"`
	// Cause is the underlying error that triggered this error (if any).
	// It is excluded from JSON serialization to prevent leaking internal details.
	Cause error `json:"-"`
}

// Error satisfies the built-in error interface by combining the machine
// reason with the human-readable description.
func (e *Error) Error() string {
	msg := e.Reason
	msg += ": "
	msg += e.Description
	return msg
}
// Exchange acts as a context object for a single HTTP request/response cycle.
//
// It wraps the underlying *http.Request and http.ResponseWriter to provide
// convenient helper methods for common API tasks, such as parsing JSON,
// reading parameters, and writing structured responses.
type Exchange struct {
	// R is the incoming HTTP request.
	R *http.Request
	// W is a writer for the outgoing HTTP response.
	W ResponseWriter
	// jsonOpts is inherited from the parent Router and applied to all JSON
	// marshaling and unmarshaling done through this Exchange.
	jsonOpts []json.Options
}

// Context returns the request's context.
// This is commonly used for cancellation signals and request scoping.
func (e *Exchange) Context() context.Context { return e.R.Context() }

// Method returns the HTTP method (GET, POST, etc.) of the request.
func (e *Exchange) Method() string { return e.R.Method }

// URL returns the full URL of the request.
func (e *Exchange) URL() *url.URL { return e.R.URL }

// Path returns the URL path of the request.
func (e *Exchange) Path() string { return e.R.URL.Path }

// Param retrieves a path parameter by name.
//
// This relies on the routing pattern (e.g., "GET /users/{id}"). If the
// parameter does not exist, it returns an empty string.
func (e *Exchange) Param(name string) string { return e.R.PathValue(name) }

// Query parses the URL query parameters of the request. Malformed pairs will
// be silently discarded.
func (e *Exchange) Query() url.Values { return e.R.URL.Query() }

// Header returns the HTTP headers of the request.
func (e *Exchange) Header() http.Header { return e.R.Header }

// GetHeader retrieves a specific request header value by key.
func (e *Exchange) GetHeader(key string) string { return e.R.Header.Get(key) }

// SetHeader sets a specific header value on the response.
func (e *Exchange) SetHeader(key, value string) { e.W.Header().Set(key, value) }
// BindJSON decodes the request body into v.
//
// Strict API hygiene is enforced in three steps:
//  1. The media type must be "application/json".
//  2. The payload must not be empty.
//  3. The JSON must unmarshal into v.
//
// A failure at any step yields a structured *Error suitable for returning
// straight from a handler.
func (e *Exchange) BindJSON(v any) *Error {
	if mt := header.MediaType(e.R.Header); mt != MediaTypeJSON {
		return &Error{
			Status:      http.StatusUnsupportedMediaType,
			Reason:      ReasonWrongType,
			Description: "content-type must be " + MediaTypeJSON,
		}
	}
	body := e.R.Body
	if body == nil || body == http.NoBody {
		return &Error{
			Status:      http.StatusBadRequest,
			Reason:      ReasonEmptyBody,
			Description: "empty request body",
		}
	}
	if err := json.UnmarshalRead(body, v, e.jsonOpts...); err != nil {
		return &Error{
			Status:      http.StatusBadRequest,
			Reason:      ReasonParseJSON,
			Description: "could not parse JSON body",
		}
	}
	return nil
}
// ReadForm parses the request body as URL-encoded form data and returns the
// values.
//
// Unlike the standard http.Request.FormValue(), this reads strictly from the
// PostForm (body), never from URL query parameters. That distinction matters
// for security protocols like OAuth, where query parameter injection must be
// ruled out.
func (e *Exchange) ReadForm() (url.Values, *Error) {
	if mt := header.MediaType(e.R.Header); mt != MediaTypeForm {
		return nil, &Error{
			Status:      http.StatusUnsupportedMediaType,
			Reason:      ReasonWrongType,
			Description: "content-type must be " + MediaTypeForm,
		}
	}
	if err := e.R.ParseForm(); err != nil {
		return nil, &Error{
			Status:      http.StatusBadRequest,
			Reason:      ReasonParseForm,
			Description: "malformed form data",
		}
	}
	return e.R.PostForm, nil
}
// JSON encodes v as JSON and writes it to the response with the given HTTP
// status code.
//
// The Content-Type header is set to MediaTypeJSON unless the handler has
// already chosen one. Encoding failures are returned before anything is
// written, so the error handler can still emit a proper 500 response.
func (e *Exchange) JSON(code int, v any) error {
	data, err := json.Marshal(v, e.jsonOpts...)
	if err != nil {
		// Nothing was committed yet; the error handler maps this to a 500.
		return err
	}
	if e.W.Header().Get("Content-Type") == "" {
		e.SetHeader("Content-Type", MediaTypeJSON)
	}
	e.Status(code)
	_, err = e.W.Write(data)
	return err
}
// Form writes the values as URL-encoded form data with the given status code.
//
// The Content-Type header is set to MediaTypeForm unless the handler has
// already chosen one. Any write error is returned to the caller.
func (e *Exchange) Form(code int, v url.Values) error {
	if e.W.Header().Get("Content-Type") == "" {
		e.SetHeader("Content-Type", MediaTypeForm)
	}
	e.Status(code)
	payload := []byte(v.Encode())
	_, err := e.W.Write(payload)
	return err
}
// Status sends a response with the given status code and no body.
// This is commonly used for HTTP 204 (No Content).
func (e *Exchange) Status(code int) {
	e.W.WriteHeader(code)
}

// NoContent sends a HTTP 204 No Content response.
func (e *Exchange) NoContent() {
	e.Status(http.StatusNoContent)
}

// Redirect replies to the request with a redirect to url, which may be a path
// relative to the request path.
//
// Any non-ASCII characters in url will be percent-encoded, but existing percent
// encodings will not be changed. The provided code should be in the 3xx range.
// It always returns nil so it can be used as a handler's final statement.
func (e *Exchange) Redirect(url string, code int) error {
	http.Redirect(e.W, e.R, url, code)
	return nil
}
// RedirectTo builds a redirect target by merging the query parameters already
// present on base with the provided params, then redirects the client.
//
// This is particularly useful for OAuth-style callbacks.
func (e *Exchange) RedirectTo(base string, params url.Values, code int) error {
	target, err := url.Parse(base)
	if err != nil {
		return &Error{
			Status:      http.StatusInternalServerError,
			Reason:      ReasonServerError,
			Description: "invalid redirect target",
		}
	}
	// Layer params on top of whatever query string base already carries.
	merged := target.Query()
	for key, values := range params {
		for _, value := range values {
			merged.Add(key, value)
		}
	}
	target.RawQuery = merged.Encode()
	http.Redirect(e.W, e.R, target.String(), code)
	return nil
}
// Handler defines the interface for HTTP request handlers used by the Router.
//
// This interface allows using struct-based handlers (useful for dependency
// injection) in addition to simple functions.
type Handler interface {
	// ServeHTTP processes an HTTP request encapsulated in the Exchange object.
	// Returning a non-nil error delegates the response to the error handler.
	ServeHTTP(e *Exchange) error
}

// HandlerFunc defines the function signature for HTTP request handlers.
type HandlerFunc func(e *Exchange) error

// ServeHTTP satisfies the Handler interface, allowing HandlerFunc to be used
// wherever a Handler is expected.
func (f HandlerFunc) ServeHTTP(e *Exchange) error { return f(e) }

// Compile-time check that HandlerFunc implements Handler.
var _ Handler = HandlerFunc(nil)

// ErrorHandler defines a function that handles errors returned by routes.
// It runs after the handler, with the same Exchange.
type ErrorHandler func(e *Exchange, err error)

// Option defines a functional configuration option for the Router.
type Option func(*Router)
// WithMiddleware appends global middleware pipes to the Router; they wrap
// every route registered afterwards.
func WithMiddleware(pipes ...middleware.Pipe) Option {
	return func(r *Router) { r.mws = append(r.mws, pipes...) }
}
// WithMaxBodySize caps the size of request bodies in bytes. The default of 0
// means unlimited, but production setups typically want a bound (e.g. 1MB).
func WithMaxBodySize(bytes int64) Option {
	return func(r *Router) { r.maxBytes = bytes }
}
// WithJSONOptions sets custom JSON options for the Router. They configure
// both marshaling and unmarshaling operations.
func WithJSONOptions(opts ...json.Options) Option {
	return func(r *Router) { r.jsonOpts = opts }
}
// WithErrorHandler replaces the default JSON error formatting with a custom
// handler. A nil handler is ignored.
func WithErrorHandler(h ErrorHandler) Option {
	return func(r *Router) {
		if h == nil {
			return
		}
		r.errorHandler = h
	}
}
// WithLogger rebuilds the default error handler around the given logger.
// Without this option the Router logs via slog.Default(). A nil logger is
// ignored.
func WithLogger(log *slog.Logger) Option {
	return func(r *Router) {
		if log == nil {
			return
		}
		r.errorHandler = defaultErrorHandler(log)
	}
}
// Router represents an HTTP request router with middleware support.
type Router struct {
	// Mux is the underlying http.ServeMux. It is exposed to allow direct
	// usage with http.ListenAndServe.
	Mux *http.ServeMux
	// mws holds the global middleware applied to every registered route.
	mws []middleware.Pipe
	// maxBytes limits request body size; 0 disables the limit.
	maxBytes int64
	// jsonOpts is forwarded to every Exchange for JSON (de)serialization.
	jsonOpts []json.Options
	// errorHandler renders errors returned by handlers.
	errorHandler ErrorHandler
}
// New creates a new Router instance with the provided options.
func New(opts ...Option) *Router {
	router := &Router{
		Mux:          http.NewServeMux(),
		errorHandler: defaultErrorHandler(slog.Default()),
	}
	for _, apply := range opts {
		apply(router)
	}
	return router
}
// ServeHTTP satisfies the http.Handler interface, allowing the Router to be
// used directly with HTTP servers. It delegates request handling to the
// underlying http.ServeMux, where all routes were registered.
func (r *Router) ServeHTTP(res http.ResponseWriter, req *http.Request) {
	r.Mux.ServeHTTP(res, req)
}
// Handle registers a new route with the given pattern, handler, and optional
// middleware pipes.
//
// The pattern string must follow Go 1.22+ syntax (e.g., "GET /users/{id}").
//
// The handler is wrapped with the Router's global middleware and any local
// middleware provided for this specific route.
func (r *Router) Handle(
	pattern string,
	handler Handler,
	mws ...middleware.Pipe,
) {
	h := http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
		// Enforce body size limit if configured.
		if r.maxBytes > 0 {
			req.Body = http.MaxBytesReader(res, req.Body, r.maxBytes)
		}
		e := &Exchange{
			R:        req,
			W:        NewResponseWriter(res),
			jsonOpts: r.jsonOpts,
		}
		err := handler.ServeHTTP(e)
		if err != nil {
			r.errorHandler(e, err)
		}
	})
	// Copy the global middleware into a fresh slice before appending the
	// local pipes. Appending directly to r.mws (append(r.mws, mws...)) could
	// reuse its spare backing-array capacity, letting two routes registered
	// with different local middleware overwrite each other's entries.
	local := make([]middleware.Pipe, 0, len(r.mws)+len(mws))
	local = append(local, r.mws...)
	local = append(local, mws...)
	r.Mux.Handle(pattern, middleware.Chain(h, local...))
}
// HandleFunc is a convenience wrapper for Handle that accepts a function
// instead of a Handler interface. The pattern and middleware semantics are
// identical to Handle.
func (r *Router) HandleFunc(
	pattern string,
	fn func(*Exchange) error,
	mws ...middleware.Pipe,
) {
	r.Handle(pattern, HandlerFunc(fn), mws...)
}
// Mount registers a standard http.Handler (like http.FileServer) under a
// pattern.
//
// The handler is still wrapped by the Router's global middleware, so
// logging/auth logic applies to these routes as well.
func (r *Router) Mount(pattern string, handler http.Handler) {
	wrapped := middleware.Chain(handler, r.mws...)
	r.Mux.Handle(pattern, wrapped)
}
// defaultErrorHandler builds the Router's fallback ErrorHandler: it logs
// internal errors through logger and renders every failure as a JSON *Error
// payload.
func defaultErrorHandler(logger *slog.Logger) ErrorHandler {
	return func(e *Exchange, err error) {
		if e.W.Closed() {
			// Response is already committed; we cannot write a JSON error.
			// Log the error and exit to prevent "superfluous response.WriteHeader".
			logger.Error(
				"Handler returned error after writing response",
				slog.Any("err", err),
			)
			return
		}
		// Use errors.As rather than a direct type assertion so that a
		// *Error wrapped via fmt.Errorf("%w", ...) is still recognized.
		var ae *Error
		if !errors.As(err, &ae) {
			// Log the internal error details for debugging; the client only
			// sees a generic 500 to avoid leaking internals.
			logger.Error("An internal server error occurred", slog.Any("err", err))
			ae = &Error{
				Status:      http.StatusInternalServerError,
				Reason:      ReasonServerError,
				Description: "internal server error",
			}
		}
		// Attempt to write the error response.
		// Note: If the handler has already flushed data to the response writer,
		// this may fail or append garbage, but standard HTTP flow stops here.
		if we := e.JSON(ae.Status, ae); we != nil {
			// If writing the error JSON fails (e.g. broken pipe), log it.
			logger.Warn("Failed to write error response", slog.Any("err", we))
		}
	}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package scheduler provides a flexible framework for running recurring tasks
// concurrently.
//
// The core of the package is the Scheduler interface, which manages the
// lifecycle of scheduled jobs. The basic unit of work is a Task, which can be
// adapted into a schedulable Tick. A Tick is a self-repeating job that
// determines its own next run time by returning a duration after each
// execution.
//
// # Usage
//
// Helpers like Every and After are provided to easily convert a simple Task
// into a Tick with common scheduling patterns:
//
// - Every(d, task): Creates a drift-free Tick that runs at a fixed
// cadence of duration d, accounting for the task's own execution time.
// - After(d, task): Creates a drifting Tick that waits for a fixed
// duration d after the previous run completes.
//
// Example:
//
// s := scheduler.New(context.Background())
// defer s.Shutdown()
//
// task := scheduler.TaskFn(func(context.Context) {
// slog.Info("Tick!")
// })
//
// tick := scheduler.Every(2*time.Second, task)
// s.Dispatch(tick)
//
// // Let the scheduler run for a while.
// time.Sleep(5 * time.Second)
package scheduler
import (
"context"
"sync"
"time"
)
// Tick represents a unit of work that can be scheduled to run repeatedly.
type Tick interface {
// Run executes the job and returns the duration to wait before the next
// execution. It accepts a context that is cancelled when the scheduler
// is shut down.
//
// If the returned duration is zero or negative, the next run is scheduled
// immediately.
Run(ctx context.Context) time.Duration
}
// TickFn is an adapter to allow the use of ordinary functions as Ticks.
type TickFn func(ctx context.Context) time.Duration
func (f TickFn) Run(ctx context.Context) time.Duration { return f(ctx) }
// Task represents a unit of work to be executed in a scheduler's execution
// loop. Helpers like After and Every adapt a Task into a Tick.
type Task interface {
// Run executes the job. It accepts a context for cancellation and
// timeout control.
Run(ctx context.Context)
}
// TaskFn is an adapter to allow the use of ordinary functions as Tasks.
type TaskFn func(ctx context.Context)
func (f TaskFn) Run(ctx context.Context) { f(ctx) }
// After creates a drifting Tick that runs after a fixed delay.
// The scheduler waits for the full delay after the task has completed, so
// the effective cadence will vary based on the task's execution time.
func After(d time.Duration, task Task) Tick {
return TickFn(func(ctx context.Context) time.Duration {
task.Run(ctx)
return d
})
}
// Every creates a drift-free Tick that runs at a fixed interval.
// The wrapper measures the Task's execution time and subtracts it from the
// specified interval, ensuring the task starts at a consistent cadence.
//
// If a task's execution time exceeds the interval, the next run will
// start immediately.
func Every(d time.Duration, task Task) Tick {
return TickFn(func(ctx context.Context) time.Duration {
start := time.Now()
task.Run(ctx)
elapsed := time.Since(start)
return max(0, d-elapsed)
})
}
// Scheduler manages the non-blocking execution of Ticks at their intervals.
type Scheduler interface {
	// Context returns the scheduler's context. This context is cancelled when
	// Shutdown is called. Users can select on this context's Done channel to
	// coordinate with the scheduler's termination.
	Context() context.Context
	// Dispatch executes the given tick in a separate goroutine. The tick will
	// run immediately and then repeat according to the duration it returns
	// until the scheduler is shut down. Multiple ticks can be dispatched
	// concurrently without blocking each other.
	Dispatch(tick Tick)
	// Shutdown gracefully stops the scheduler. It cancels the scheduler's context
	// and waits for all its pending tasks to complete. Shutdown blocks until all
	// dispatched goroutines have finished.
	Shutdown()
}
// New creates a Scheduler bound to the given parent context: cancelling the
// parent also shuts the scheduler down.
func New(ctx context.Context) Scheduler {
	child, cancel := context.WithCancel(ctx)
	return &scheduler{ctx: child, cancel: cancel}
}
// scheduler is the concurrent Scheduler implementation returned by New.
type scheduler struct {
	ctx    context.Context    // done once Shutdown (or the parent) cancels
	cancel context.CancelFunc // stops all dispatched loops
	wg     sync.WaitGroup     // tracks goroutines started by Dispatch
}

// Context returns the scheduler's context; it is done once shutdown begins.
func (s *scheduler) Context() context.Context {
	return s.ctx
}
// Dispatch starts a goroutine, tracked by the WaitGroup, that runs tick in a
// timer-driven loop until the scheduler's context is cancelled.
func (s *scheduler) Dispatch(tick Tick) {
	s.wg.Go(func() {
		// A zero initial delay fires the timer immediately for the first run.
		timer := time.NewTimer(0)
		for {
			select {
			case <-s.ctx.Done():
				// Shutdown (or parent cancellation): release the timer and
				// exit so the WaitGroup can unblock Shutdown.
				timer.Stop()
				return
			case <-timer.C:
				// Run the tick, then rearm with whatever delay it requested.
				// The channel was just drained, so Reset is safe; a
				// non-positive delay fires again right away.
				timer.Reset(tick.Run(s.ctx))
			}
		}
	})
}
// Shutdown cancels the scheduler's context and blocks until every goroutine
// started by Dispatch has returned.
func (s *scheduler) Shutdown() {
	s.cancel()
	s.wg.Wait()
}

// Compile-time check that *scheduler satisfies Scheduler.
var _ Scheduler = (*scheduler)(nil)
// Once creates a synchronous Scheduler that runs each dispatched Tick exactly
// once, in the calling goroutine. Dispatch therefore blocks until the Tick
// returns.
//
// This implementation is useful for testing, or wherever the Scheduler
// interface is needed without true background scheduling.
func Once(ctx context.Context) Scheduler {
	return &once{ctx: ctx}
}

// once is the trivial single-shot Scheduler behind Once.
type once struct {
	ctx context.Context
}

// Context returns the context supplied to Once.
func (o *once) Context() context.Context { return o.ctx }

// Dispatch runs the tick synchronously; the returned delay is discarded.
func (o *once) Dispatch(tick Tick) { tick.Run(o.ctx) }

// Shutdown is a no-op because nothing runs in the background.
func (o *once) Shutdown() {}

var _ Scheduler = (*once)(nil)
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package build provides test helpers for compiling Go binaries.
// It ensures that build artifacts are isolated and automatically cleaned up
// after the tests finish.
package build
import (
"os/exec"
"path/filepath"
"runtime"
"testing"
)
// Binary compiles the Go package in the src directory into an executable
// named dst inside a test-scoped temporary directory and returns its
// absolute path. The test framework deletes the directory (and the binary)
// automatically when the test finishes.
func Binary(t testing.TB, src string, dst string) string {
	t.Helper()
	out := filepath.Join(t.TempDir(), dst)
	if runtime.GOOS == "windows" {
		out += ".exe"
	}
	// Build "." while running inside src so the target program's own module
	// context (go.mod) is respected.
	cmd := exec.Command("go", "build", "-o", out, ".")
	cmd.Dir = src
	combined, err := cmd.CombinedOutput()
	if err != nil {
		t.Fatalf("failed to build %s: %v\n%s", dst, err, combined)
	}
	return out
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package ports provides integration test helpers to find available network
// ports and block until they accept connections.
package ports
import (
"context"
"net"
"strconv"
"testing"
"time"
)
// Free asks the kernel for a free, open port that is ready to use.
func Free() (int, error) {
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return 0, err
	}
	port := l.Addr().(*net.TCPAddr).Port
	// The listener only served to reserve a port; release it right away.
	_ = l.Close()
	return port, nil
}

// FreeT is a test helper that wraps Free. It fails the test immediately if
// a free port cannot be found.
func FreeT(t testing.TB) int {
	t.Helper()
	p, err := Free()
	if err != nil {
		t.Fatalf("failed to get free port: %v", err)
	}
	return p
}
// Wait blocks until the specified host and port are accepting TCP connections,
// or until the context is canceled/times out.
func Wait(ctx context.Context, host string, port int) error {
addr := net.JoinHostPort(host, strconv.Itoa(port))
var d net.Dialer
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
// DialContext respects the context for the dial operation itself.
conn, err := d.DialContext(ctx, "tcp", addr)
if err == nil {
_ = conn.Close()
return nil
}
select {
case <-ctx.Done():
// The context was canceled or its deadline exceeded.
return ctx.Err()
case <-ticker.C:
// Wait for the next tick before trying again.
}
}
}
// WaitT is a test helper that wraps Wait. It fails the test immediately if
// the port does not become available before the test context expires.
func WaitT(t testing.TB, host string, port int) {
t.Helper()
if err := Wait(t.Context(), host, port); err != nil {
t.Fatalf("failed waiting for port %d on %s: %v", port, host, err)
}
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package updater provides functionality to check for newer releases of an
// application hosted on GitHub.
//
// It queries the GitHub Releases API to retrieve the latest release and
// compares its tag against the current application version using semantic
// versioning.
//
// # Usage
//
// cfg := &updater.Config{
// Owner: "deep-rent",
// Repository: "vouch",
// Current: "v1.0.0",
// UserAgent: "Vouch/1.0.0",
// }
//
// // Check for updates.
// rel, err := updater.Check(context.Background(), cfg)
// if err != nil {
// log.Printf("Failed to check for updates: %v", err)
// } else if rel != nil {
// log.Printf("New version available: %s (see %s)", rel.Version, rel.URL)
// }
package updater
import (
"context"
"encoding/json/v2"
"fmt"
"net/http"
"strings"
"time"
"golang.org/x/mod/semver"
)
// Default configuration values for the updater.
const (
	// DefaultBaseURL is the default GitHub API base URL.
	DefaultBaseURL = "https://api.github.com"
	// DefaultTimeout is the default timeout for HTTP requests.
	DefaultTimeout = 5 * time.Second
)

// Release represents a published release on GitHub, decoded from the
// Releases API response.
type Release struct {
	// Version is the tag name of the release (e.g., "v1.0.0").
	Version string `json:"tag_name"`
	// URL to view the release on GitHub.
	URL string `json:"html_url"`
	// Published is the timestamp the release was published on GitHub.
	Published time.Time `json:"published_at"`
	// Notes contains the release notes or description.
	Notes string `json:"body"`
}
// Config holds the configuration for the Updater.
type Config struct {
	// BaseURL is the base URL for the GitHub API. It defaults to DefaultBaseURL
	// if not set. This is primarily used for testing purposes.
	BaseURL string
	// Owner is the GitHub repository owner (required).
	Owner string
	// Repository is the name of the GitHub repository (required).
	Repository string
	// Current is the current version string of the application (required).
	// It is normalized with a "v" prefix before semver comparison.
	Current string
	// UserAgent is the value for the User-Agent header sent with requests.
	// If empty, no User-Agent header is sent.
	UserAgent string
	// Timeout is the time limit for requests made by the updater.
	// It defaults to DefaultTimeout (5 seconds) if not set.
	Timeout time.Duration
}
// Updater checks for updates on GitHub for a specific repository.
// Construct it via New, which validates and normalizes the Config.
type Updater struct {
	baseURL    string       // GitHub API root, DefaultBaseURL unless overridden
	owner      string       // repository owner
	repository string       // repository name
	current    string       // current version, "v"-normalized
	userAgent  string       // optional User-Agent header value
	client     *http.Client // HTTP client with the configured timeout
}
// New creates a new Updater with the given configuration.
//
// The HTTP client is configured with cfg.Timeout (or DefaultTimeout). New
// panics when a required field (Owner, Repository, Current) is missing, or
// when the current version is not valid semver after normalizing it with a
// "v" prefix.
func New(cfg *Config) *Updater {
	switch {
	case cfg.Owner == "":
		panic("updater: owner is required")
	case cfg.Repository == "":
		panic("updater: repository is required")
	case cfg.Current == "":
		panic("updater: current version is required")
	}
	current := normalize(cfg.Current)
	if !semver.IsValid(current) {
		panic(fmt.Sprintf(
			"updater: current version %q is not a valid semver",
			cfg.Current,
		))
	}
	base := cfg.BaseURL
	if base == "" {
		base = DefaultBaseURL
	}
	wait := cfg.Timeout
	if wait == 0 {
		wait = DefaultTimeout
	}
	return &Updater{
		baseURL:    base,
		owner:      cfg.Owner,
		repository: cfg.Repository,
		current:    current,
		userAgent:  cfg.UserAgent,
		client:     &http.Client{Timeout: wait},
	}
}
// Check queries the GitHub Releases API to determine whether a newer version
// is available.
//
// The latest release tag and the current version are compared under semantic
// versioning (both "v"-normalized). A newer release is returned as a
// *Release; nil, nil means the current version is up to date (or newer). An
// error is returned when the API request fails or the latest tag is not
// valid semver.
func (u *Updater) Check(ctx context.Context) (*Release, error) {
	endpoint := fmt.Sprintf(
		"%s/repos/%s/%s/releases/latest",
		u.baseURL, u.owner, u.repository,
	)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Accept", "application/vnd.github.v3+json")
	if u.userAgent != "" {
		req.Header.Set("User-Agent", u.userAgent)
	}
	res, err := u.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch latest release: %w", err)
	}
	defer func() { _ = res.Body.Close() }()
	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status from github api: %s", res.Status)
	}
	var rel Release
	if err := json.UnmarshalRead(res.Body, &rel); err != nil {
		return nil, fmt.Errorf("failed to decode response body: %w", err)
	}
	latest := normalize(rel.Version)
	if !semver.IsValid(latest) {
		return nil, fmt.Errorf("latest version %q is not a valid semver", rel.Version)
	}
	if semver.Compare(latest, u.current) <= 0 {
		// Up to date (or ahead of the published release).
		return nil, nil
	}
	return &rel, nil
}
// Check is a convenience function to check for updates in a single call.
// It creates a temporary Updater with the provided config and calls its Check
// method. Like New, it panics when cfg is missing required fields.
func Check(ctx context.Context, cfg *Config) (*Release, error) {
	return New(cfg).Check(ctx)
}
// normalize prepends the "v" prefix required by the semver package when it
// is not already present.
func normalize(v string) string {
	return "v" + strings.TrimPrefix(v, "v")
}
// Copyright (c) 2025-present deep.rent GmbH (https://deep.rent)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package uuid provides an implementation of Version 7 (Time-ordered)
// Universally Unique Identifiers (UUID) as defined in RFC 4122 and RFC 9562.
//
// MIGRATION NOTE (v4 -> v7):
// We migrated from UUIDv4 (fully random) to UUIDv7 (time-ordered) to improve
// database performance. UUIDv4 causes significant index fragmentation and
// random I/O in B-Tree structures (standard database primary keys) due to
// its lack of locality.
//
// UUIDv7 solves this by being strictly monotonic (like a sequence ID) while
// retaining global uniqueness. This results in "append-only" index behavior,
// significantly higher write throughput, and better cache locality.
// It also aligns with native support arriving in PostgreSQL 18+.
package uuid
import (
"crypto/rand"
"encoding/hex"
"fmt"
"io"
"sync"
"time"
)
// UUIDv7 is a 128-bit time-ordered identifier (16 bytes).
//
// Layout:
//   - 48 bits: Unix Timestamp (milliseconds)
//   - 4 bits:  Version (0111)
//   - 12 bits: Random Data A
//   - 2 bits:  Variant (10)
//   - 62 bits: Random Data B
type UUIDv7 [16]byte

// String renders the UUID in its canonical hyphenated form:
// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx (lowercase hex).
func (u UUIDv7) String() string {
	var out [36]byte
	// Hex-encode the five canonical byte groups straight into place.
	groups := [...]struct{ at, lo, hi int }{
		{0, 0, 4}, {9, 4, 6}, {14, 6, 8}, {19, 8, 10}, {24, 10, 16},
	}
	for _, g := range groups {
		hex.Encode(out[g.at:], u[g.lo:g.hi])
	}
	out[8], out[13], out[18], out[23] = '-', '-', '-', '-'
	return string(out[:])
}
// New generates a strictly monotonic UUIDv7 with sub-millisecond precision.
//
// The timestamp and sequence fields come from a process-wide monotonic
// counter derived from the system clock, so IDs created within the same
// millisecond still sort in creation order. The remaining bits are filled
// with cryptographically secure random data.
func New() UUIDv7 {
	ms, seq := tick()

	var u UUIDv7
	// Bytes 0-5: 48-bit big-endian Unix millisecond timestamp.
	for i := 0; i < 6; i++ {
		u[i] = byte(ms >> uint(40-8*i))
	}
	// Bytes 6-7: version nibble (0b0111) followed by the 12-bit sequence.
	u[6] = 0x70 | byte(seq>>8)
	u[7] = byte(seq)
	// Bytes 8-15: cryptographically secure randomness.
	if _, err := io.ReadFull(rand.Reader, u[8:]); err != nil {
		panic(fmt.Errorf("uuid: failed to read random bytes: %w", err))
	}
	// Byte 8: overwrite the top two bits with the RFC 4122 variant (10).
	u[8] = (u[8] & 0x3f) | 0x80
	return u
}
// Parse converts the canonical 36-character hyphenated representation of a
// UUID into a UUIDv7 value.
//
// The input must encode Version 7 with the RFC 4122 variant; anything else
// is rejected with an error.
func Parse(s string) (UUIDv7, error) {
	var u UUIDv7
	if len(s) != 36 {
		return u, fmt.Errorf("uuid: invalid length (%d)", len(s))
	}
	// Hyphens must sit exactly at the canonical group boundaries.
	for _, i := range [...]int{8, 13, 18, 23} {
		if s[i] != '-' {
			return u, fmt.Errorf("uuid: invalid format")
		}
	}
	// Strip the hyphens and decode the remaining 32 hex digits.
	compact := s[0:8] + s[9:13] + s[14:18] + s[19:23] + s[24:]
	if _, err := hex.Decode(u[:], []byte(compact)); err != nil {
		return u, fmt.Errorf("uuid: invalid characters: %w", err)
	}
	if u[6]&0xf0 != 0x70 {
		return UUIDv7{}, fmt.Errorf("uuid: invalid version: expected v7")
	}
	if u[8]&0xc0 != 0x80 {
		return UUIDv7{}, fmt.Errorf("uuid: invalid variant: expected RFC 4122")
	}
	return u, nil
}
// ParseBytes converts a raw 16-byte slice into a UUIDv7 value.
//
// The slice must be exactly 16 bytes long and encode Version 7 with the
// RFC 4122 variant.
//
// Note: the input slice is never modified; its contents are copied into the
// returned value.
func ParseBytes(b []byte) (UUIDv7, error) {
	var u UUIDv7
	if len(b) != 16 {
		return u, fmt.Errorf("uuid: invalid length (%d)", len(b))
	}
	if b[6]&0xf0 != 0x70 {
		return UUIDv7{}, fmt.Errorf("uuid: invalid version: expected v7")
	}
	if b[8]&0xc0 != 0x80 {
		return UUIDv7{}, fmt.Errorf("uuid: invalid variant: expected RFC 4122")
	}
	copy(u[:], b)
	return u, nil
}
// Shared state backing the monotonic generator.
var (
	mu   sync.Mutex // guards last
	last int64      // most recently issued (ms<<12 | seq) scalar
)

// tick implements Method 3 from the UUIDv7 specification (RFC 9562, Section
// 6.2). It returns the current Unix-millisecond timestamp together with a
// strictly increasing 12-bit sequence derived from the fractional
// nanoseconds of the clock reading.
func tick() (ms, seq int64) {
	mu.Lock()
	defer mu.Unlock()

	// Split the wall clock into whole milliseconds and the leftover
	// nanoseconds within that millisecond.
	now := time.Now().UnixNano()
	ms = now / 1_000_000
	// There are 1,000,000 ns in a ms; shifting right by 8 maps that range
	// to at most ~3906, which fits comfortably in the 12-bit field.
	seq = (now % 1_000_000) >> 8

	// Pack both parts into a single comparable scalar (48-bit ms followed
	// by 12-bit seq) so clock rollbacks and bursts within one reading
	// reduce to plain integer arithmetic.
	ts := ms<<12 | seq
	if ts <= last {
		// Bump just past the previous scalar; a sequence overflowing its
		// 12 bits carries into the millisecond field automatically.
		ts = last + 1
		ms = ts >> 12
		seq = ts & 0xfff
	}
	last = ts
	return ms, seq
}