Gradient Fallback

Motivation

TS-DDR training relies on solving an optimization subproblem at every stage and differentiating through it (via Lagrange duals or DiffOpt). In practice, some solves may fail — the solver hits numerical trouble, DiffOpt encounters degenerate duals, or the subproblem is infeasible for a particular sample. A single uncaught error kills the entire training run.

The gradient fallback system provides a principled, extensible way to handle these errors at three levels:

| Level | Where it fires | What it controls |
| --- | --- | --- |
| rrule pullback | Inside the ChainRules rrule for get_next_state | Whether a bad-solver-status pullback returns zeros or throws |
| Training loop | Around Flux.gradient(...) in train_multistage / train_multiple_shooting | Whether a DiffOpt error skips the iteration or crashes |
| Rollout evaluation | Inside RolloutEvaluation, per scenario | Whether a failed scenario is excluded from the metric or crashes |

Built-in fallback types

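The examples on this page use two built-in types: ZeroGradientFallback (the default), which catches errors and skips the affected iteration or scenario, and ErrorGradientFallback, which throws. See the API Reference for full docstrings.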

Usage

Default behavior (zero gradients)

By default, all training functions use ZeroGradientFallback. Failed iterations log a warning and skip the parameter update:

train_multistage(
    model, x0, subproblems, spi, spo, uncertainty;
    num_batches=500,
    # gradient_fallback=ZeroGradientFallback()  # this is the default
)

At training start you will see:

[ Info: Training with ZeroGradientFallback: solver/differentiation errors
will be caught and the iteration skipped (zero gradient). Pass
`gradient_fallback=ErrorGradientFallback()` to throw instead, or implement
a custom `AbstractGradientFallback` subtype.
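
The skip-on-error behavior can be pictured with a small standalone sketch. It does not use DecisionRules and is not the package's implementation; the names SkipOnError, ThrowOnError, on_training_error, and sketch_train! are hypothetical stand-ins chosen for illustration, and the update rule is plain gradient descent rather than a Flux optimizer.

```julia
# Standalone sketch of a skip-on-error training loop (hypothetical names).
abstract type LoopFallback end
struct SkipOnError  <: LoopFallback end   # plays the role of ZeroGradientFallback
struct ThrowOnError <: LoopFallback end   # plays the role of ErrorGradientFallback

# Return `true` to skip the iteration, or throw to abort the run.
function on_training_error(::SkipOnError, e, iter)
    @warn "Skipping iteration $iter" exception = e
    return true
end
on_training_error(::ThrowOnError, e, iter) = throw(e)

function sketch_train!(params::Vector{Float64}, grad_fn, num_batches::Int;
                       fallback::LoopFallback = SkipOnError(), lr = 0.1)
    skipped = 0
    for iter in 1:num_batches
        g = try
            grad_fn(params, iter)
        catch e
            on_training_error(fallback, e, iter)  # throws in strict mode
            skipped += 1
            continue                              # no parameter update
        end
        params .-= lr .* g
    end
    return skipped
end

# The gradient of 0.5 * sum(abs2, p) is p itself; fail on every third call.
flaky_grad(p, iter) = iter % 3 == 0 ? error("solver failed") : copy(p)

p = [1.0, -2.0]
skipped = sketch_train!(p, flaky_grad, 9)
println("skipped $skipped of 9 iterations")
```

With the skipping fallback the loop finishes and only the successful iterations move the parameters. Passing `fallback = ThrowOnError()` to the same call would stop at the first failing iteration instead.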

Strict mode (for tests)

Use ErrorGradientFallback when you want errors to surface immediately — typically in unit tests where every solve should succeed:

train_multistage(
    model, x0, subproblems, spi, spo, uncertainty;
    num_batches=10,
    gradient_fallback=ErrorGradientFallback(),
)

The same keyword works for train_multiple_shooting and RolloutEvaluation:

rollout = RolloutEvaluation(
    subproblems, spi, spo, x0, scenarios;
    gradient_fallback=ErrorGradientFallback(),
)

Custom fallbacks (extending the type system)

Subtype AbstractGradientFallback and implement three methods:

struct LoggingFallback <: DecisionRules.AbstractGradientFallback
    log::Vector{Any}
end

# Called when the rrule pullback (DiffOpt / dual extraction) fails.
# Return a tuple of cotangents matching the rrule signature, or rethrow.
function DecisionRules.handle_gradient_error(fb::LoggingFallback, e, n_in, n_out)
    push!(fb.log, (:gradient, e))
    return DecisionRules._zero_cotangents(n_in, n_out)
end

# Called when Flux.gradient(...) throws in the training loop.
# Return `true` to skip this iteration, or rethrow.
function DecisionRules.handle_training_error(fb::LoggingFallback, e, iter)
    push!(fb.log, (:training, iter, e))
    return true
end

# Called when a rollout scenario fails during evaluation.
# Return `true` to exclude this scenario from the metric, or rethrow.
function DecisionRules.handle_rollout_error(fb::LoggingFallback, e, iter)
    push!(fb.log, (:rollout, iter, e))
    return true
end

Then pass it to any training function:

fb = LoggingFallback(Any[])
train_multistage(
    model, x0, subproblems, spi, spo, uncertainty;
    gradient_fallback=fb,
)
println("Caught $(length(fb.log)) errors during training")

This is useful for:

  • Monitoring: count how often solves fail and on which iterations
  • Adaptive recovery: adjust solver tolerances, restart from a checkpoint, etc.
  • Selective rethrowing: catch known benign errors but let unexpected ones through
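
As an illustration of selective rethrowing, the sketch below follows the three-method interface shown above. SelectiveFallback and is_benign are hypothetical names, and treating AssertionError as the only benign error type is an assumption made for this example; which errors are safe to swallow depends on your solver and model.

```julia
using DecisionRules

# Hypothetical fallback: tolerate errors judged benign, rethrow the rest.
struct SelectiveFallback <: DecisionRules.AbstractGradientFallback
    log::Vector{Any}
end

# Assumption for this sketch: only assertion failures count as benign.
is_benign(e) = e isa AssertionError

function DecisionRules.handle_gradient_error(fb::SelectiveFallback, e, n_in, n_out)
    is_benign(e) || throw(e)              # unexpected error: let it surface
    push!(fb.log, (:gradient, e))
    return DecisionRules._zero_cotangents(n_in, n_out)
end

function DecisionRules.handle_training_error(fb::SelectiveFallback, e, iter)
    is_benign(e) || throw(e)
    push!(fb.log, (:training, iter, e))
    return true                           # skip this iteration
end

function DecisionRules.handle_rollout_error(fb::SelectiveFallback, e, iter)
    is_benign(e) || throw(e)
    push!(fb.log, (:rollout, iter, e))
    return true                           # exclude this scenario
end
```

Keeping the benign-error test in one helper means the policy can be widened or narrowed in a single place without touching the three handlers.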

Relationship to STRICT_GRADIENTS

The global STRICT_GRADIENTS flag controls a separate, lower-level mechanism. When the forward solver terminates with a non-optimal status (e.g., ITERATION_LIMIT), the rrule pullback either returns zero gradients (STRICT_GRADIENTS[] == false, the default) or throws (STRICT_GRADIENTS[] == true).

The gradient_fallback keyword operates at a higher level — it catches errors from DiffOpt's reverse_differentiate! (assertion errors, degenerate duals, etc.) and from the training loop itself. Both mechanisms are independent and complementary:

Forward solve
  ├─ bad termination status → STRICT_GRADIENTS controls behavior
  └─ good status → DiffOpt reverse_differentiate!
       └─ error (assertion, numerical) → gradient_fallback catches it
            ├─ in rrule pullback: handle_gradient_error
            ├─ in training loop: handle_training_error
            └─ in rollout eval: handle_rollout_error
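
The same flow can be written as a short standalone sketch. Everything here is a stand-in: SKETCH_STRICT mimics the STRICT_GRADIENTS flag, the status is a plain Symbol rather than a solver termination code, and sketch_pullback, on_diff_error, SketchZero, and SketchError are hypothetical names, not the package's internals.

```julia
# Standalone sketch of the two independent mechanisms (hypothetical names).
abstract type SketchFallback end
struct SketchZero  <: SketchFallback end   # stands in for ZeroGradientFallback
struct SketchError <: SketchFallback end   # stands in for ErrorGradientFallback

# Fallback policy: what to do when differentiation itself throws.
on_diff_error(::SketchZero, e, n)  = zeros(n)
on_diff_error(::SketchError, e, n) = throw(e)

# Stand-in for the global STRICT_GRADIENTS flag (a Ref, read with []).
const SKETCH_STRICT = Ref(false)

function sketch_pullback(status::Symbol, differentiate, n::Int, fb::SketchFallback)
    # Stage 1: a bad termination status is governed by the strict flag only.
    if status != :OPTIMAL
        SKETCH_STRICT[] && error("solver terminated with status $status")
        return zeros(n)
    end
    # Stage 2: the status was fine, so any error comes from differentiation
    # and is routed to the fallback object.
    try
        return differentiate()
    catch e
        return on_diff_error(fb, e, n)
    end
end

# Bad status with the flag off: zeros, and the fallback is never consulted.
g_bad_status = sketch_pullback(:ITERATION_LIMIT, () -> [1.0, 2.0], 2, SketchError())
# Good status, differentiation succeeds: the gradient passes through.
g_ok = sketch_pullback(:OPTIMAL, () -> [1.0, 2.0], 2, SketchZero())
# Good status, differentiation fails: the fallback supplies zeros.
g_caught = sketch_pullback(:OPTIMAL, () -> error("degenerate duals"), 2, SketchZero())
```

Note that in the first call the strict fallback object does not cause a throw, because the bad-status branch consults only the flag. This is the sense in which the two mechanisms are independent.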