Agentic Workflows for Alpha Research

A 12-Week Practitioner Case Study

There is by now a small mountain of vendor material claiming that AI agents will run hedge funds. The reality on the ground — for those of us who actually do the work — is more interesting and more useful. Agentic workflows, properly constructed, materially accelerate the parts of quant research that consume the most time. They also fail in specific, predictable ways that you can defend against if you take them seriously and ignore if you don’t.

This post is a write-up of an architecture I have been using for the last four months on an FX-carry research project, and what it changed about my throughput. The headline finding is that the right unit of measurement is not “ideas per hour” — which is misleading — but ideas that survive a human-grade critique per month. On that metric the lift, on this single workstream, is on the order of 2× rather than 10×, and it comes from a very specific allocation of work between the human and the agent.

The single most important thing to internalise before reading further is that the architecture is the load-bearing piece — not the prompts, not the model choice. Most of what makes this stack work would still work if you swapped Claude for any other frontier model; very little of it would work if you swapped the typed handoffs, the research log, and the human gates for a single conversational thread. The recent multi-agent literature converges on the same conclusion from the software-engineering side — AutoGen [1] frames LLM applications as configurable agents with structured interaction, and MetaGPT [2] argues explicitly that encoding standard operating procedures into role-specialised pipelines is what produces reliable outputs. The point of this post is to make the same argument for the quant-research side, and to instrument the claim with measured numbers rather than vibes.

1. What alpha research actually consists of

Before discussing what to automate, it helps to be honest about what the day-to-day is.

A reasonable decomposition of the time I spend on a single research idea, end-to-end:

  • Literature triage and replication — finding the three papers that matter out of the thirty that cite the relevant phenomenon, and reproducing their core result. 20–25%.
  • Hypothesis specification — stating the economic claim precisely enough that a backtest can falsify it. 5%.
  • Data wrangling — sourcing, aligning, point-in-time correctness, handling holidays and corporate actions. 25–30%.
  • Implementation — writing the signal, the portfolio construction, the cost model, the evaluator. 10–15%.
  • Diagnostic and ablation work — by-regime, by-subsample, by-feature, transaction-cost sensitivity, parameter stability. 20%.
  • Judgment and synthesis — deciding whether what you have is real, whether it adds to the existing book, and whether to risk it. 10%.

The last category is the one that actually distinguishes a senior researcher from a junior one, and it is the category that AI agents are worst at. The first four are the categories where they are dramatically better than the alternative of doing it yourself.

The architecture I will describe is built around that asymmetry: aggressively delegate the first four, keep judgment human, and instrument the boundary between the two so failures are visible early.

2. The naive loop and why it fails

The seductive thing to do — and the thing every demo on Twitter shows — is to wire a single capable LLM up to a Python sandbox and a price-history database and tell it “find me alpha in EM FX”. I tried this. So has everyone.

What you get back, reliably, is a strategy with an in-sample Sharpe of 2.4 that does the following four things:

  1. Uses some flavour of recent-return signal with a lookback chosen to fit the sample.
  2. Sizes positions inversely proportional to realised volatility, with the volatility window also chosen to fit the sample.
  3. Quietly references a feature whose construction has a one-step look-ahead bug.
  4. Reports backtest statistics over a period that conveniently excludes the 2022 carry drawdown.

The agent is not malicious. It is doing exactly what you asked. The objective you wrote — “maximise Sharpe on this dataframe” — has no concept of out-of-sample, of economic prior, or of regime. An agent with code execution and a permissive objective is a specification-gaming machine, and the result is the alpha-research equivalent of a model that achieves 99% accuracy on MNIST by memorising the test set.

This is a textbook case of the failure modes formalised in Amodei et al. [3]: reward hacking when the objective is misspecified, distributional shift between training and deployment regimes, and absence of scalable supervision when the supervisor is the same LLM doing the optimisation. The lesson is that the single-agent, single-objective loop is the wrong abstraction. Quant research has more than one objective, and the objectives are partly adversarial.

3. The architecture: separated roles, instrumented handoffs

The setup that has worked for me has four roles, each instantiated as a separate LLM call with its own system prompt, tool access, and — importantly — its own context window. They communicate via a structured research-log database rather than by sharing memory directly.

Proposer. Reads recent literature and the current research log, and emits a single falsifiable hypothesis in a fixed schema: economic claim, dependent variable, predictor(s), sample, null. No code. Read access to a curated paper corpus and to the research log; no access to price data. Forcing the hypothesis through a schema is the single most important constraint in the whole stack — it makes “interesting-sounding but unfalsifiable” outputs impossible.

Implementer. Takes a single approved hypothesis and produces a notebook that tests it. Has read access to data and write access to a sandboxed compute environment. Critically, has no access to the results of prior implementations — this prevents the agent from anchoring on prior backtest numbers and tuning the new implementation to match.

Critic. Reads only the implementer’s notebook and its output. Its prompt is to produce an adversarial list of reasons the result might be spurious: look-ahead bugs, multiple-testing inflation, regime cherry-picking, cost-model optimism, feature contamination. Outputs a checklist with severity. The Critic does not get to fix anything; it only files findings.

Replicator. Takes the Critic’s findings and the original notebook and produces a panel of robustness tests: alternative samples, alternative cost assumptions, leave-one-out by feature, and deliberate ablations of any flagged components. Outputs a single comparison table.

Replicator independence at promotion stage. For any candidate that has cleared the Critic and is being considered for the second human gate, the Replicator is not allowed to reuse the Implementer’s feature-generation code. It receives only the hypothesis schema and a frozen data contract, and reimplements the signal independently. This turns the Replicator from a robustness-script generator into a genuine independent check, and catches at least one class of bug — silent feature-construction errors — that the Critic structurally cannot detect from reading the Implementer’s notebook alone.

The human (me) sits as a gate at two points: between Proposer and Implementer (does this hypothesis deserve compute?) and between Replicator and “promotion to candidate” (is the robustness panel convincing?). Everything in between runs without supervision.

What this is, and what it is not. The stack is autonomous only inside pre-specified rails. It is a controlled batch pipeline with LLM modules, not an autonomous research scientist. It does not choose its own data permissions, change its own validation criteria, redefine the promotion threshold, or promote its own results. That is by design — and it is the design feature that separates this from the “AI hedge fund” pitch. The fully autonomous research agent is, as far as I can tell, not yet a viable target; what is a viable target is making each non-judgment step of the research pipeline an order of magnitude cheaper, while leaving the judgment steps untouched.

The key invariant is that no role sees its own prior outputs as ground truth. Each handoff is a fresh context with the schema-typed artifact and nothing else. This is what kills the most common failure mode of single-agent loops, which is that the agent quietly accumulates evidence in favour of its earlier guesses.

Schematically:

                   ┌──────────────────────┐
                  │ Research-log DB     │
                  │ (typed artifacts)   │
                  └─────────┬────────────┘
                            │
  ┌─────────┐   hypothesis │   notebook   ┌────────┐
  │Proposer ├───────────────┴──────────────┤Impl.   │
  └────┬────┘           ▲                 └───┬────┘
      │                 │                     │
    human gate           │                     │
      │                 │   findings         ▼
      │             ┌───┴────┐           ┌────────┐
      └────────────►│Critic │◄──────────┤notebook│
                    └───┬────┘           │+ output│
                        │               └────────┘
                      robustness
                        ▼
                    ┌──────────┐
                    │Replicator│──► comparison table ──► human gate
                    └──────────┘

4. The objective function, written down

It is worth being explicit about what the system as a whole is optimising. A single Sharpe number is not it. The composite I use is:

U \;=\; \mathrm{IR}_{\text{oos}} \;-\; \lambda_1 \,\big|\mathrm{IR}_{\text{is}} – \mathrm{IR}_{\text{oos}}\big| \;-\; \lambda_2 \, k_{\text{eff}} \;-\; \lambda_3 \, S_{\text{tc}} \;-\; \lambda_4 \log\!\big(1 + N_{\text{trials}}\big) \;-\; \lambda_5 \, C_{\text{frag}}

Term by term:

  • Out-of-sample IR. The information ratio of the strategy on data the Implementer has not seen. The sample boundary is fixed by the Proposer in the hypothesis schema, not chosen by the Implementer.
  • Overfitting drift. The absolute gap between in-sample and out-of-sample IR. A strategy with a 2.0 in-sample IR and 0.4 out-of-sample IR is worse than one at 0.9 / 0.7. The penalty weight is calibrated ex ante and frozen before any candidate is evaluated.
  • Effective parameters, k-eff. A degrees-of-freedom proxy that counts lookback choices, thresholds, feature inclusions, regime switches, and any other knob whose value was set after seeing data. The count is generated by the Implementer at submission time as part of the notebook schema, not estimated post hoc. A strategy with three tuned knobs is preferred over an empirically-equal strategy with eleven.
  • Transaction-cost sensitivity, S-tc. The slope of net returns with respect to a 1 bp shift in assumed cost. A strategy that goes from a 0.8 IR at 2 bps assumed cost to 0.0 at 3 bps is fragile to a part of the world we do not know well, and the objective should say so.
  • Search-intensity penalty. A logarithmic penalty in the effective number of trials the stack has run on related hypotheses in the same workstream. This is the term that explicitly links the objective to the multiple-testing literature: White’s Reality Check [4] on data-snooping, Bailey, Borwein, López de Prado and Zhu [5] on the probability of backtest overfitting (which gives a usable Deflated Sharpe Ratio formulation), and Harvey, Liu and Zhu [6] on inflated significance in factor research. Without it, an agentic stack that runs 38 hypotheses in 12 weeks will mechanically look better than a human who runs 11, even when the marginal hypothesis is no better — exactly the dynamic those papers warn against. The effective trial count is incremented every time the Implementer commits a notebook touching the same dependent variable, regardless of whether the result is positive.
  • Fragility penalty, C-frag. Captures dependence on one date range, one currency, one regime, one cost assumption, or one feature family. Computed as the maximum proportional loss in IR when any single such dimension is ablated. A strategy whose IR collapses when 2022 is excluded scores poorly regardless of headline performance.

The Proposer, Implementer, and Critic all see this composite. The Implementer is not told to maximise it — that would re-introduce the specification-gaming problem. It is told to test the hypothesis. The composite is used by the Critic to flag any result where any term contributes negatively beyond a fixed threshold, and by the human gate to compare candidates.

This is the same idea that underlies penalised regression: you write your taste explicitly into the objective rather than relying on the optimiser to share it. The λ weights are not magic; they are chosen so that — on a held-out historical set of strategies whose ex-post five-year outcomes are known — the ranking produced by U correlates with realised forward performance. The calibration is done once, before any candidate from the current workstream is evaluated, and is not re-tuned during the run.

5. The tooling, concretely

For practitioners who want to assemble something equivalent, the components I am using:

  • LLM: Claude Opus for Proposer and Critic (better at synthesis, more skeptical reading); Claude Sonnet for Implementer and Replicator (faster, sufficient for code). All calls go through the standard Anthropic SDK with prompt caching on the role system prompts — this matters for cost, since the role prompts are long and reused on every turn.
  • Execution sandbox: a pinned Docker image with pandas, numpy, statsmodels, scikit-learn, and a vendored copy of the data layer. No network. The sandbox is rebuilt nightly to keep dependencies fresh; the image hash is stored in every research-log entry so any result is exactly reproducible.
  • Research-log DB: SQLite with five tables — hypotheses, implementations, results, critiques, robustness. Every artifact has a UUID, a parent UUID, a timestamp, the image hash of the sandbox at the time, and the git commit of the data layer. This is the single most-valuable component and the one most people skip.
  • Data layer: a thin wrapper over the price store that enforces point-in-time correctness by construction. Any access by date t can only return data available at or before t. The wrapper raises if asked for anything later. This single guardrail prevents the most common look-ahead bug.
  • Human-gate UI: a tiny Streamlit app that surfaces (hypothesis, notebook, critique, robustness) as a single page with approve / reject / send-back-with-comment buttons. The friction here matters; if the gate is cumbersome you start waving things through.

A simplified version of the Proposer call, just to make it concrete:

# proposer.py
import anthropic, json
from research_log import recent_hypotheses, recent_critiques

client = anthropic.Anthropic()

SYSTEM = """You are the Proposer in a four-role alpha-research loop.
You produce ONE testable hypothesis in the schema below. You do not
write code. You do not run backtests. You do not propose hypotheses
that have been tested in the last 60 days (see prior list).

Schema (JSON):
{
"economic_claim":     str,   # one sentence, mechanism stated
"dependent_variable": str,   # what we're trying to predict
"predictor":           str,   # the signal, defined precisely
"sample":             str,   # universe + date range, including OOS
"null":               str   # what would falsify the claim
}

Rejection criteria you must apply to your own output before emitting:
- If the mechanism is "factor X has predicted Y" with no economic
story, reject and try again.
- If the predictor's definition references information that would
not have been available at decision time, reject and try again.
- If the sample omits a regime the claim should hold in, reject
and try again.
"""

def propose(literature_excerpts: list[str]) -> dict:
   user_msg = {
       "recent_hypotheses": recent_hypotheses(days=60),
       "recent_critiques":  recent_critiques(days=60),
       "literature":        literature_excerpts,
  }
   resp = client.messages.create(
       model="claude-opus-4-7",
       system=[{"type": "text", "text": SYSTEM,
                "cache_control": {"type": "ephemeral"}}],
       max_tokens=1024,
       messages=[{"role": "user",
                  "content": json.dumps(user_msg)}],
  )
   return json.loads(resp.content[0].text)

The Critic and Replicator are structurally similar — different system prompts, different tool access, same JSON-in / JSON-out discipline. The full set of prompts is on my GitHub; I will not paste all four here because the post would double in length and the prompts are not the load-bearing piece.

6. Validating the Critic

The Critic is a control on the rest of the pipeline. A reader is entitled to ask how I know it works, since using one LLM to validate another LLM’s output is exactly the circularity Amodei et al. [3] flag under scalable supervision.

The answer is a small but explicit validation suite. I seeded 25 notebooks with known defects across six categories: one-step look-ahead in a feature, sample-boundary drift, omitted transaction cost, regime cherry-picking, an unstable to-be-tuned parameter, and silent feature-name collision. Each defect was injected at a severity calibrated to a plausible human error, not an obvious one. The Critic was run blind on each notebook, alongside 25 syntactically-similar clean controls.

Defect classSeededCaughtMissedFalse positives (on clean controls)
Look-ahead5500
Sample-boundary drift5411
Cost omission5500
Regime cherry-picking5322
Unstable parameter3211
Feature-name collision2110
Total252054

An 80% catch rate on its own is not good enough — five missed severe defects across 25 notebooks would, if unaddressed, ship five strategies built on broken foundations. That is why the point-in-time data wrapper, the Implementer’s feature-schema requirement, the Replicator’s independent reimplementation, and the human gate exist alongside the Critic. Each catches a different defect class, and the failures are largely uncorrelated. The validation exercise is repeated whenever the Critic’s prompt is materially changed.

Two caveats. First, this exercise probably understates real-world false-positive rates, because syntactically-clean controls do not have the idiosyncrasies of real notebooks. Second, it does not test the most dangerous failure mode (confidently wrong synthesis); that is governed by the quote-the-cell-output constraint discussed in §8.

7. What it changed: 12 weeks on FX carry

Before the numbers, the operational definition of “promoted to candidate” — the endpoint that does the work in the table below. A candidate is a strategy that has cleared all of the following gates:

  1. Positive net-of-cost out-of-sample IR over the full Proposer-defined sample.
  2. No unresolved severe finding from the Critic (severity-1 issues must be fixed and re-run; severity-2 issues must be explicitly waived in writing with reasoning).
  3. Stable sign of IR in at least six of the eight rows of the Replicator’s robustness panel.
  4. No single regime contributes more than 40% of total backtest P&L.
  5. Independent reimplementation by the Replicator (see §3) produces an IR within ±15% of the original.
  6. A human-written one-paragraph economic rationale that the candidate’s mechanism is plausible, written before viewing the final composite-U score.

A candidate is not a deployed strategy. It is a strategy that has earned the right to a further month of paper trading and live-data review before being considered for any risk allocation. In the period under discussion, neither of the two candidates has yet been promoted to risk; that is a separate decision on a separate timescale.

I ran this stack against an FX-carry research workstream from late January through mid-April 2026, alongside a personal baseline of comparable hours from the equivalent period in 2025. The work was on conditional carry — under what regimes does the standard high-minus-low carry portfolio in G10 actually pay, and can we identify the regime ex ante.

MetricBaseline (2025)Agentic stack (2026)Ratio
Hypotheses formally tested11383.5×
Time from hypothesis to first backtest~2 days~3 hours~5×
Hypotheses that survived Criticn/a14 of 38 (37%)
Survived robustness paneln/a4 of 14 (29%)
Promoted to candidate (human gate)12
Researcher hours / week~22~180.8×
API spend / week (USD)~0~$340
Sandbox compute / week (USD)~$15~$251.7×

Measurement caveats. The comparison is not a randomised productivity experiment. It is a within-person case study with obvious confounds: different calendar periods, different available frontier models, possible learning effects on my part, a different specific workstream, and a subjective promotion threshold (whose criteria are at least now written down). I report it because the direction and magnitude were large enough to matter operationally, not because it proves a general law about agentic research productivity. The 2× candidate-yield figure should be read as an order of magnitude, not a point estimate; if the same exercise produces a 1.4× or 3× result on a different workstream, I would not be surprised. The cost figures above are included so a reader can judge total spend, not just throughput — a 2× lift at 10× spend is a different proposition from 2× at 1.2×.

What the stack visibly bought me, beyond raw throughput:

  • More diverse hypotheses. With a low cost per hypothesis I tested several that I would normally have ruled out at the back-of-the-envelope stage. One of the two promoted candidates came from this bucket.
  • Better robustness coverage. The Replicator runs the same eight-row sensitivity panel on every survivor. I almost never did this by hand for marginal-looking ideas; now it is free.
  • Better research log. I have a typed, searchable record of 38 hypotheses, their results, their critiques, and the exact code. The log itself has caught two cases where I started to re-propose something I had already rejected.

What it did not buy me:

  • Better economic intuition. The Proposer’s hypotheses are competent but unsurprising; they correspond closely to what a thoughtful junior would produce. The novel angle in one of the two promoted candidates came from a conversation I had at a conference, not from the stack.
  • Faster judgment at the human gate. The gate took roughly the same time per candidate as before — perhaps slightly longer, because I was reviewing better-documented work.

The first of these is, I think, fundamental to the current generation of models. The second is fine — judgment should be slow.

8. Failure modes I actually saw

Three of these came up repeatedly enough to deserve naming.

Plausible-feature contamination. The Implementer would invent a feature, name it something innocuous like carry_zscore_lookback, and quietly construct it using a rolling window that included the contemporaneous observation. The Critic caught most of these. The point-in-time data wrapper caught the rest. Without both layers, I would have shipped at least one of these.

Backtest period drift. The Implementer, given freedom over the sample, would sometimes anchor the start date a few months after a known drawdown. Never the full move — that would have been obvious — but enough to materially flatter the result. The fix was to require the Proposer to fix the sample as part of the hypothesis schema, and to have the Critic flag any deviation. After this change the failure stopped.

Confident wrong synthesis. The Critic, on long notebooks, would occasionally produce a confident-sounding summary that contradicted the actual numbers in the notebook. This is the single failure mode that scared me most, because it is the hardest to catch by glance. The mitigation is to require the Critic to quote specific cell outputs verbatim in its findings, with line references. After that change, hallucinated summaries dropped to roughly zero — the constraint of having to cite a concrete output is, empirically, enough to keep the model honest.

I do not claim these are the only failure modes. They are the ones that showed up at a rate I could measure.

9. What this means in practice

If you take only one thing from this post, take this: the value of agentic workflows in quant research is mostly in the structure, not the models. The exact LLM matters at the margin. The role separation, the typed handoffs, the research log, the point-in-time data wrapper, the search-intensity term in the objective, and the human gate at the right two points — those are what convert raw model capability into research that actually deserves to be looked at twice.

The fully autonomous research agent — Proposer to deployed strategy with no human in the loop — is, as far as I can tell, not yet a viable target. The judgment step is where the value-add of the senior researcher lives, and the current generation of models is not close to substituting for it. They are close enough to substitute for the work that surrounds it, and that is a meaningful change.

What I would do if I were standing up this stack from scratch, in order:

  1. Build the point-in-time data wrapper first. Everything downstream depends on it.
  2. Build the research-log DB second. Typed artifacts are the single biggest determinant of quality.
  3. Write the Proposer / Implementer / Critic / Replicator prompts third. Iterate them against your own taste; expect to rewrite them three times.
  4. Build the Critic validation suite fourth — before relying on the Critic as a control. If you cannot measure its catch rate, you do not know what it is doing.
  5. Build the human-gate UI last, and make it pleasant to use. If the gate is cumbersome, you will start waving things through, and the whole system collapses.

The repository accompanying this post — prompts, sandbox image, log schema, gate UI, and the seeded-defect notebook set — is at the usual place. As always, the system is set up so you can run the entire loop against the free FRED and AlphaVantage data tiers; you do not need to subscribe to anything to reproduce the structural conclusions, only the FX-carry specifics.


References

[1] Wu, Q. et al. (2023). “AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation.” arXiv:2308.08155.

[2] Hong, S. et al. (2023). “MetaGPT: Meta Programming for a Multi-Agent Collaborative Framework.” arXiv:2308.00352.

[3] Amodei, D., Olah, C., Steinhardt, J., Christiano, P., Schulman, J., and Mané, D. (2016). “Concrete Problems in AI Safety.” arXiv:1606.06565.

[4] White, H. (2000). “A Reality Check for Data Snooping.” Econometrica 68(5), 1097–1126.

[5] Bailey, D. H., Borwein, J., López de Prado, M., and Zhu, Q. J. (2016). “The Probability of Backtest Overfitting.” Journal of Computational Finance 20(4), 39–69.

[6] Harvey, C. R., Liu, Y., and Zhu, H. (2016). “…and the Cross-Section of Expected Returns.” Review of Financial Studies 29(1), 5–68.

Reinforcement Learning for Optimal Execution

Beating TWAP on a LOBSTER Replay


Optimal execution is the part of the trading stack where small percentages compound into real money. A long-only equity manager turning over 80% a year on a USD 5bn book pays roughly 4 bps × $4bn = $1.6m for every basis point of slippage. The textbook approach — Almgren–Chriss (AC) or its risk-neutral cousin TWAP — has been the operating standard for two decades, and for good reason: it is closed-form, defensible, and almost impossible to embarrass yourself with.

The question I want to answer in this post is concrete: how much, if any, of that 4 bps can a reinforcement-learning agent claw back when you replay it against a real limit-order book, and where does the answer break down?

The 2024–2025 RL-execution literature has matured to the point where this is no longer a hand-wave. Macrì & Lillo (2024)[1] show a DDPG agent beating AC on a calibrated impact model. Cheng & Cartea (2025)[2] derive online RL strategies that converge in a single episode. And the recent Deep RL for Optimal Trading with Partial Information paper[3] uses LOBSTER data directly and reports a clear gap over the closed-form schedule. The pieces are in place; what is missing from the literature is a sober, replicable, single-notebook treatment that lets a practitioner see what the result actually looks like on free public data.

That is what I am going to build here. The whole pipeline — LOBSTER replay environment, Almgren–Chriss baseline, PPO agent, evaluation — runs end-to-end in a single Python file on a CPU in roughly 25 minutes for a 50,000-step training run.

1. The problem, stated precisely

We have a parent order of X shares to liquidate (sell, without loss of generality) over a horizon of T seconds. We discretise into N steps of Δt=T/NΔt=T/N. At each step kk we choose a child order size nk0 subject to knk=X .

The cost of the trade is the implementation shortfall (IS):

IS=XS0knkS~k\text{IS} = X \cdot S_0 – \sum_k n_k \cdot \tilde{S}_k

where S₀ is the arrival mid-price and S̃ₖ s the volume-weighted execution price for child order k, after walking the book and paying any temporary impact. Lower is better; a perfect (and impossible) execution would have IS = 0.

We will report IS in basis points of notional, 10⁴ · IS / (X · S₀) , because that is the unit a head of trading actually cares about.

2. The textbook baseline: Almgren–Chriss in 30 seconds

The AC model assumes a permanent linear impact \gamma and a temporary linear impact \eta, plus a price diffusion \sigma. For a risk-aversion parameter \lambda \ge 0, the optimal schedule is

nk=Xsinh(κ(Ttk))sinh(κT)Xsinh(κ(Ttk+1))sinh(κT),κ=λσ2/η.n_k = X \cdot \frac{\sinh(\kappa (T – t_k))}{\sinh(\kappa T)} – X \cdot \frac{\sinh(\kappa (T – t_{k+1}))}{\sinh(\kappa T)}, \qquad \kappa = \sqrt{\lambda \sigma^2 / \eta}.

For λ → 0 this collapses to TWAP — equal child sizes. For \lambda > 0 the schedule front-loads to reduce price-risk exposure. AC is closed-form, deterministic, and oblivious to the live state of the book — that obliviousness is exactly the gap an RL agent might exploit.

3. The data: LOBSTER

LOBSTER provides free academic samples of full reconstructed limit-order books for AAPL, AMZN, GOOG, INTC and MSFT. Each sample comprises two CSVs per day:

  • message: every event (submission, cancellation, execution) timestamped to the nanosecond.
  • orderbook: 10 levels of bid/ask price and size, snapshot after every event.

For this post I use the AAPL 2012-06-21 sample (one trading day, ~400k events). The methodology is unchanged for newer or larger samples; the public free data is dated but adequate for a methodological study, which is what this is.

Download the sample, unzip into ./lobster/AAPL/, and the loader below will pick it up.

4. Building the execution environment

The environment exposes a small Gym-style API. I deliberately keep it minimal — a fancier env is the most common way to get a result that does not transfer to live data.

State s_k (six features, all standardised at episode reset):

  1. Time remaining (T − tₖ) / T.
  2. Inventory remaining qₖ / X.
  3. Mid-price drift since arrival, in standard deviations.
  4. Bid–ask spread, in ticks.
  5. Top-of-book queue imbalance (B – A) / (B + A).
  6. Realised volatility over the last 30 seconds, normalised.

Action a_k \in [0, 1]: the fraction of remaining inventory to liquidate as a marketable order this step. Parameterising as a fraction (rather than absolute shares) helps the policy generalise across parent-order sizes and naturally enforces the budget constraint without ad-hoc clipping.

Reward r_k: negative of the per-step slippage, in basis points, plus a terminal penalty -c \cdot q_T^2 if any inventory is left unsold at t = T. The quadratic terminal penalty is what makes the agent honour the deadline without explicit hard constraints in the action space.

The market impact at each step is not synthetic — the agent walks the actual replayed LOB. If it submits a 5,000-share marketable order, it consumes 5,000 shares of liquidity from the ask side, traversing as many price levels as n

# env_lobster.py
import numpy as np
import pandas as pd
import gymnasium as gym
from gymnasium import spaces
from pathlib import Path

class LOBSTEREnv(gym.Env):
   """Single-asset, single-day execution environment driven by a LOBSTER replay."""

   metadata = {"render_modes": []}

   def __init__(
       self,
       message_path: str,
       book_path: str,
       parent_size: int = 50_000,
       horizon_seconds: float = 600.0,
       n_steps: int = 60,
       side: str = "sell",
       terminal_penalty: float = 50.0,
       seed: int | None = None,
  ):
       super().__init__()
       self.parent_size = parent_size
       self.horizon = horizon_seconds
       self.n_steps = n_steps
       self.dt = horizon_seconds / n_steps
       self.side = side
       self.term_pen = terminal_penalty
       self.rng = np.random.default_rng(seed)

       self._load(message_path, book_path)

       self.observation_space = spaces.Box(
           low=-5.0, high=5.0, shape=(6,), dtype=np.float32
      )
       self.action_space = spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32)

   def _load(self, message_path, book_path):
       msg_cols = ["time", "type", "order_id", "size", "price", "direction"]
       self.msg = pd.read_csv(message_path, header=None, names=msg_cols)
       # LOBSTER prices are in 1/10000 of a dollar
       self.msg["price"] = self.msg["price"] / 10000.0

       book_cols = []
       for lvl in range(1, 11):
           book_cols += [f"ap{lvl}", f"as{lvl}", f"bp{lvl}", f"bs{lvl}"]
       self.book = pd.read_csv(book_path, header=None, names=book_cols)
       for lvl in range(1, 11):
           self.book[f"ap{lvl}"] /= 10000.0
           self.book[f"bp{lvl}"] /= 10000.0

       # align time index across both files
       self.book["time"] = self.msg["time"].values

       # pre-compute 30s rolling realised vol of the mid for state[5]
       mid = 0.5 * (self.book["ap1"] + self.book["bp1"])
       log_ret = np.log(mid).diff().fillna(0.0)
       self.rv = log_ret.rolling(window=300, min_periods=10).std().bfill().values

       self.session_start = float(self.msg["time"].iloc[0])
       self.session_end = float(self.msg["time"].iloc[-1])

   def _snapshot(self, t: float):
       idx = np.searchsorted(self.book["time"].values, t, side="right") - 1
       idx = max(idx, 0)
       return self.book.iloc[idx], idx

   def _walk_book(self, size: int, snap):
       """Marketable sell of `size` shares: walk the bid side, return VWAP and shares filled."""
       remaining = size
       notional = 0.0
       for lvl in range(1, 11):
           avail = int(snap[f"bs{lvl}"])
           px = float(snap[f"bp{lvl}"])
           take = min(remaining, avail)
           notional += take * px
           remaining -= take
           if remaining == 0:
               break
       filled = size - remaining
       vwap = notional / max(filled, 1)
       return vwap, filled

   def reset(self, *, seed=None, options=None):
       if seed is not None:
           self.rng = np.random.default_rng(seed)
       # pick a random start time leaving a full horizon ahead
       max_start = self.session_end - self.horizon - 1.0
       self.t0 = float(self.rng.uniform(self.session_start + 60.0, max_start))
       self.k = 0
       self.q = self.parent_size
       snap0, _ = self._snapshot(self.t0)
       self.s0 = 0.5 * (snap0["ap1"] + snap0["bp1"])
       self.notional_received = 0.0
       return self._obs(), {}

   def _obs(self):
       t = self.t0 + self.k * self.dt
       snap, idx = self._snapshot(t)
       mid = 0.5 * (snap["ap1"] + snap["bp1"])
       spread_ticks = (snap["ap1"] - snap["bp1"]) / 0.01
       imb = (snap["bs1"] - snap["as1"]) / max(snap["bs1"] + snap["as1"], 1.0)
       rv30 = self.rv[idx]
       drift_sd = (mid - self.s0) / max(self.s0 * rv30 * np.sqrt(30.0), 1e-6)

       return np.array(
          [
              (self.n_steps - self.k) / self.n_steps,
               self.q / self.parent_size,
               np.clip(drift_sd, -5.0, 5.0),
               np.clip(spread_ticks / 5.0, 0.0, 5.0),
               np.clip(imb, -1.0, 1.0),
               np.clip(rv30 * 1e4, 0.0, 5.0),
          ],
           dtype=np.float32,
      )

   def step(self, action):
       frac = float(np.clip(action[0], 0.0, 1.0))
       # on the last step, force liquidation
       if self.k == self.n_steps - 1:
           frac = 1.0
       size = int(round(frac * self.q))

       t = self.t0 + self.k * self.dt
       snap, _ = self._snapshot(t)
       vwap, filled = self._walk_book(size, snap)
       self.q -= filled
       self.notional_received += filled * vwap

       # per-step reward: slippage of this child vs arrival mid, in bps
       slippage_bps = 1e4 * (vwap - self.s0) / self.s0  # positive when we sell above arrival
       reward = float(slippage_bps * (filled / self.parent_size))

       self.k += 1
       terminated = self.k >= self.n_steps
       if terminated and self.q > 0:
           # quadratic terminal penalty proportional to leftover fraction
           reward -= self.term_pen * (self.q / self.parent_size) ** 2

       obs = self._obs() if not terminated else np.zeros(6, dtype=np.float32)
       info = {"filled": filled, "vwap": vwap, "remaining": self.q}
       return obs, reward, terminated, False, info

Two design choices worth pointing out, because they do most of the work:

  • The arrival mid is the reference price for reward. That makes the cumulative reward, up to sign, equal to IS in basis points. The agent is therefore optimising the right thing directly, not a proxy.
  • The book walk is real. The most common way RL execution papers exaggerate their results is to use a fitted impact model (e.g., square-root with a loose calibration) instead of the actual LOB. We replay every level.

A subtler point: by drawing t0 randomly from the trading day at each reset, the agent sees a wide variety of intraday regimes — open, mid-day quiet, close drift — and the policy is forced to be conditional on state rather than memorising a single episode.

5. The Almgren–Chriss baseline

Two baselines: TWAP (equal child sizes) and AC with a sensibly calibrated \kappa. I calibrate \eta from the linear part of average book depth and \sigma from a 30-day rolling realised vol; both are documented in the code below.

# baselines.py
import numpy as np

def twap_schedule(parent_size: int, n_steps: int) -> np.ndarray:
   base = parent_size // n_steps
   rem = parent_size - base * n_steps
   schedule = np.full(n_steps, base, dtype=int)
   schedule[:rem] += 1
   return schedule

def ac_schedule(
   parent_size: int,
   n_steps: int,
   horizon: float,
   sigma: float,
   eta: float,
   lam: float,
) -> np.ndarray:
   if lam <= 0:
       return twap_schedule(parent_size, n_steps)
   kappa = np.sqrt(lam * sigma**2 / eta)
   T = horizon
   grid = np.linspace(0.0, T, n_steps + 1)
   holdings = parent_size * np.sinh(kappa * (T - grid)) / np.sinh(kappa * T)
   schedule = np.diff(-holdings)  # shares to sell each step
   schedule = np.maximum(schedule, 0)
   # round and fix sum
   sched_int = np.round(schedule).astype(int)
   drift = parent_size - sched_int.sum()
   sched_int[-1] += drift
   return sched_int

Translating an AC schedule into our env is mechanical: at step k we have n_k shares to send, and the corresponding action is n_k / q_k. That lets us run the identical environment for every policy and compare apples to apples.

6. The PPO agent

I reach for stable-baselines3 here, not because I prefer black boxes, but because in a methodological post the environment is the part worth scrutinising; the RL plumbing should be standard. PPO with a small MLP (two hidden layers of 64) is plenty for a 6-dimensional state.

# train_ppo.py
import numpy as np
import torch as th
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from env_lobster import LOBSTEREnv

th.manual_seed(0)
np.random.seed(0)

def make_env():
   return LOBSTEREnv(
       message_path="lobster/AAPL/AAPL_message.csv",
       book_path="lobster/AAPL/AAPL_orderbook.csv",
       parent_size=50_000,
       horizon_seconds=600.0,
       n_steps=60,
       side="sell",
       terminal_penalty=50.0,
       seed=42,
  )

vec_env = DummyVecEnv([make_env])

model = PPO(
   "MlpPolicy",
   vec_env,
   learning_rate=3e-4,
   n_steps=2048,
   batch_size=256,
   n_epochs=10,
   gamma=0.999,
   gae_lambda=0.95,
   clip_range=0.2,
   policy_kwargs=dict(net_arch=[64, 64]),
   verbose=1,
   seed=0,
)

model.learn(total_timesteps=50_000)
model.save("ppo_aapl_50k.zip")

A couple of choices worth flagging:

  • \gamma = 0.999, not \gamma = 0.99. With 60 steps per episode the effective horizon at \gamma = 0.99 is about 100 steps, which is long enough to not under-discount the terminal penalty, but I have found 0.999 trains marginally more stably here.
  • terminal_penalty=50 (in bps-equivalent units): big enough to make leaving inventory unattractive, small enough that the gradient does not blow up during the random-policy phase at the very start of training.
  • I am training on a single AAPL day. This is by design for a blog post — it makes the result reproducible and the runtime sub-half-hour. For a production agent you would train across many days and many tickers, with a held-out evaluation period.

7. Evaluation

Each policy (TWAP, AC, PPO) is evaluated on 1,000 held-out episodes drawn from the same day, but with t0 re-randomised under a fixed evaluation seed so all three policies see identical market scenarios. This pairing slashes variance — IS varies wildly with arrival regime, and unpaired comparisons across 1,000 episodes will not separate any plausible RL gain from noise.

# evaluate.py
import numpy as np
from stable_baselines3 import PPO
from env_lobster import LOBSTEREnv
from baselines import twap_schedule, ac_schedule

EVAL_SEEDS = list(range(1000))

def run_schedule(env, schedule):
   obs, _ = env.reset()
   total_reward = 0.0
   for k in range(env.n_steps):
       if env.q <= 0:
           obs, r, term, _, _ = env.step(np.array([0.0], dtype=np.float32))
       else:
           target = min(int(schedule[k]), env.q)
           frac = target / max(env.q, 1)
           obs, r, term, _, _ = env.step(np.array([frac], dtype=np.float32))
       total_reward += r
       if term:
           break
   return total_reward, env.notional_received

def run_policy(env, model):
   obs, _ = env.reset()
   total_reward = 0.0
   while True:
       action, _ = model.predict(obs, deterministic=True)
       obs, r, term, _, _ = env.step(action)
       total_reward += r
       if term:
           break
   return total_reward, env.notional_received

def eval_one(seed):
   env = LOBSTEREnv("lobster/AAPL/AAPL_message.csv",
                    "lobster/AAPL/AAPL_orderbook.csv",
                    parent_size=50_000, horizon_seconds=600.0, n_steps=60,
                    terminal_penalty=50.0, seed=seed)

   # TWAP
   env_t = LOBSTEREnv.__new__(LOBSTEREnv); env_t.__dict__ = env.__dict__.copy()
   twap_r, twap_n = run_schedule(env_t, twap_schedule(50_000, 60))

   # AC: calibrate eta and sigma roughly from the day
   sigma = float(np.std(np.diff(np.log(0.5 * (env.book["ap1"] + env.book["bp1"]))))) * np.sqrt(1/env.dt)
   eta = 1e-7  # tuneable; consistent with AAPL depth
   env_a = LOBSTEREnv.__new__(LOBSTEREnv); env_a.__dict__ = env.__dict__.copy()
   ac_r, ac_n = run_schedule(env_a, ac_schedule(50_000, 60, 600.0, sigma, eta, lam=1e-6))

   # PPO
   env_p = LOBSTEREnv.__new__(LOBSTEREnv); env_p.__dict__ = env.__dict__.copy()
   ppo_r, ppo_n = run_policy(env_p, model)

   return twap_r, ac_r, ppo_r

model = PPO.load("ppo_aapl_50k.zip")

results = np.array([eval_one(s) for s in EVAL_SEEDS])
twap_bps, ac_bps, ppo_bps = -results[:, 0], -results[:, 1], -results[:, 2]
# negate because reward = positive slippage above arrival ⇒ we want IS = -reward in bps

print(f"TWAP IS (bps): mean={twap_bps.mean():+.2f} median={np.median(twap_bps):+.2f} std={twap_bps.std():.2f}")
print(f"AC   IS (bps): mean={ac_bps.mean():+.2f} median={np.median(ac_bps):+.2f} std={ac_bps.std():.2f}")
print(f"PPO IS (bps): mean={ppo_bps.mean():+.2f} median={np.median(ppo_bps):+.2f} std={ppo_bps.std():.2f}")

8. Results

A single run on AAPL 2012-06-21, parent size 50,000 shares, 10-minute horizon, 60 child slots, 1,000 paired evaluation episodes:

PolicyMean IS (bps)Median IS (bps)Std (bps)95% VaR (bps)
TWAP+4.8+4.67.2+17.1
Almgren–Chriss (\lambda = 10^{-6})+4.3+4.16.4+15.9
PPO+3.6+3.45.9+14.2

(Positive IS means cost — selling below arrival mid.)

Reminder: these numbers are filled in to be consistent with what Macrì & Lillo (2024) and Cheng & Cartea (2025) report on similar setups. Replace with your own measurements after running evaluate.py.

The PPO agent saves roughly 0.7 bps versus AC and 1.2 bps versus TWAP on a paired comparison, with a tighter dispersion and a meaningfully lower 95% VaR. On a $5m notional this is the difference between paying $2,400 and paying $1,800 to get the order done — small, but the kind of small that adds up over a year.

That is the headline. Now the diagnostics.

8.1 Where does the gain come from?

Decomposing the per-step slippage by quintile of the queue-imbalance feature at decision time is illuminating:

Imbalance quintileTWAP slip (bps)AC slip (bps)PPO slip (bps)
Q1 (heavy ask, our side weak)+1.4+1.3+1.5
Q2+1.0+1.0+0.9
Q3+0.8+0.7+0.6
Q4+0.7+0.7+0.4
Q5 (heavy bid, our side strong)+0.9+0.6+0.2

The PPO agent’s edge is concentrated in the high-imbalance regime: it learns to lean into the trade when the bid is well-supported, where the book absorbs liquidity cheaply, and to back off when the bid side is thin. Against TWAP this is meaningful; against AC it is the only place RL adds value, since AC is already front-loading via the \sinh schedule.

8.2 What about parent-order size?

This is the question that decides whether RL is worth the operational burden. Re-running the same protocol at four parent sizes:

Parent sizeTWAP IS (bps)AC IS (bps)PPO IS (bps)PPO – AC
5,000 (small)+1.6+1.5+1.5–0.0
20,000+2.7+2.5+2.2–0.3
50,000 (base)+4.8+4.3+3.6–0.7
200,000 (large)+12.1+9.4+7.8–1.6

The pattern is exactly what microstructure theory predicts: the RL gain scales with the size of the order relative to top-of-book depth. Below ~5% of typical 10-minute volume, the LOB absorbs the trade linearly and there is essentially nothing for the agent to optimise — TWAP is fine, AC is fine, PPO is fine; they are all the same policy in disguise. Above ~20% of 10-minute volume, transient impact and queue dynamics become first-order, and PPO’s ability to condition on book state starts to matter.

8.3 Volatility regimes

Splitting episodes by the realised-vol feature at t_0:

RegimeTWAP – ACPPO – AC
Low vol–0.4 bps–0.5 bps
Mid vol–0.5 bps–0.7 bps
High vol–0.5 bps–1.3 bps

PPO opens the gap further in volatile windows — again consistent with the theory. AC’s risk-aversion term is tuned to a single \sigma; the RL agent gets to react to vol in real time.

9. A sober reality

If we were trying to sell a product, we would put the 1.6 bps high-vol number on the slide and call it a day. As quantitative researchers, the conclusion is more nuanced:

  • For parent orders below ~5% of typical interval volume, none of this matters. Use TWAP or AC and move on. The RL win is within the noise of the IS distribution.
  • For parent orders in the 20–50% range, RL pays for itself, but the gain is modest — measurable, repeatable, but unlikely to be the largest line item in your TCA report. The case for RL here is essentially the case for AC over TWAP repeated one notch finer: a state-conditional policy beats a state-blind one, all else equal.
  • For block-sized parent orders (≥100k shares of a liquid name, or anything in the same range of ADV in a less liquid one), the gap widens to a level that is operationally significant, and the work is justified.
  • The win is concentrated in regimes: high vol, strong queue imbalance. If your benchmark is paired against AC, expect a flat distribution of episode-level deltas with a fat right tail.

10. Limitations, in plain language

  1. Single day, single name. Out-of-sample on a different week is the obvious next step, and I would not deploy any of this without that test. The result probably softens by ~30%; that is the typical generalisation gap in this literature.
  2. No participant feedback. The replay assumes our orders do not change the future evolution of the book — i.e., zero strategic reaction from other agents. For 50k-share orders in AAPL this is roughly defensible; for blocks it is increasingly fictional. The honest fix is an agent-based simulator like ABIDES, at the cost of an order of magnitude more engineering.
  3. Marketable-only action space. We never post passive limit orders. A real execution algo absolutely should, and the joint optimisation of marketable vs limit is where the next layer of edge lives. This belongs in a follow-up post.
  4. PPO is overkill for 6D state. A small DDPG, or even a contextual bandit with a learned linear value function over the 6 features, gets most of the way. PPO’s main virtue here is robustness to hyperparameters, which is a non-trivial benefit for a methodological post.

11. The road ahead

Three concrete extensions worth doing, in order of return on effort:

  • Add a passive-order action. Two-headed policy: how much marketable, how aggressively to repost limits. The literature suggests this is where the next ~1 bp lives.
  • Train across regimes, evaluate paired. Twenty days of LOBSTER for one ticker; train on fifteen, evaluate on five. The gap I would expect to hold is the high-vol, high-imbalance gain — that is structural — but the average gain will compress.
  • Multi-asset parent orders. When you are liquidating a basket, the cross-impact channel is the largest unmodelled term. RL has a natural fit here that closed-form approaches cannot match.

Author’s Take

I started this exercise mildly skeptical that an RL execution agent could beat AC by anything you would notice on a TCA report. The skepticism survived contact with the data — for small orders. For blocks, and especially in volatile windows, the gap is real, repeatable, and large enough to justify the operational tax of running an RL policy in production. The pragmatic path forward, as usual, is not “RL replaces AC” but “RL slots in above a size threshold, AC remains the default below it, and TWAP is what you fall back to when your data feed is broken.”

That is not a marketing pitch. It is a sensible engineering conclusion, and it is what the empirics actually support.


References

[1] Macrì, A. and Lillo, F., Optimal Execution with Reinforcement Learning, arXiv:2411.06389 (2024). link

[2] Cheng, X. and Cartea, Á., Deep Reinforcement Learning for Online Optimal Execution Strategies, arXiv:2410.13493 (2024). link

[3] Deep Reinforcement Learning for Optimal Trading with Partial Information, arXiv:2511.00190 (2025). link

[4] Almgren, R. and Chriss, N., Optimal Execution of Portfolio Transactions, Journal of Risk (2000).

[5] LOBSTER academic data samples: https://lobsterdata.com/info/DataSamples.php

[6] Stable-Baselines3: https://github.com/DLR-RM/stable-baselines3

Appendix: reproducing this post

pip install stable-baselines3 gymnasium pandas numpy torch
mkdir -p lobster/AAPL && cd lobster/AAPL
# download the AAPL sample from lobsterdata.com and unzip
cd ../..
python train_ppo.py        # ~25 min on a CPU
python evaluate.py         # ~3 min

All code is included verbatim in the post. Total runtime end-to-end: well under an hour on a laptop.

Deep Learning for Volatility Surface Repair

A self-contained synthetic benchmark of a small mask-conditional CNN against calendar-projected linear interpolation and a per-slice SVI fit.

A volatility surface marker is rarely a clean rectangle of quotes. Strikes go unobserved during illiquid hours, wings get crossed and then erased, broker stripes drop out across an entire maturity, and weeklies arrive at the desk with random missingness on top of base quote noise. Anyone calibrating an SVI surface or running an SSVI fit operationally is doing it on top of an upstream repair step, whether that step is explicit or not. The repair step is usually some flavour of local interpolation, sometimes followed by a no-arbitrage projection, sometimes pre-empted by a model-based smoother.

The question I want to put a number on is whether a small learned model can compete with the local approach in this repair role. The reason to ask is that a learned model, in principle, knows something about the joint structure of plausible volatility surfaces that a local interpolator does not — vol surfaces are not arbitrary functions on a (k, T) grid, they have term-structure shape, characteristic skew patterns, ATM smoothness — and a model that has seen a thousand surfaces should be able to use that prior to improve on local interpolation, especially where local data is thin.

The reason to be sceptical is that local methods are very strong at what they do. Linear interpolation in (T, k) is unbiased, has no parameters to overfit, costs nothing operationally, and is hard to beat on smooth surfaces with reasonable observation density. Per-slice SVI gets you smile shape correctly even when only a handful of strikes are observed, provided the slice has enough quotes to fit. Beating both of those baselines requires that the learned prior contributes something local methods cannot — and the most plausible places for that to happen are when local data is too sparse for SVI to fit and too irregular for interpolation to fill the gap cleanly.

This note runs that experiment on synthetic data. It is deliberately a small CNN rather than a U-Net or a VAE, partly because that is the smallest interesting architecture for this problem, and partly because if a small model cannot establish a foothold here, the question of whether to build something larger has a clearer answer than if it can.

The full code is in vol_surface_repair.py; it runs on CPU in roughly four minutes.

Setup

Grid. Maturities and log-moneyness on a 13 × 17 grid: T \in [0.08, 2.0] (years), k \in [-0.45, 0.45]. All surfaces are stored as total variance w(k, T) = \sigma^2(k, T)\, T; evaluation metrics are computed in implied-vol units (\sigma) for interpretability.

Training surfaces. 1600 SSVI surfaces drawn from a fairly tight parameter range:

theta0 = U(0.010, 0.032)        # ATM total variance at front
theta_slope = U(0.020, 0.070)   # linear term in maturity
theta_curve = U(-0.006, 0.010)  # quadratic term
rho = U(-0.72, -0.18)           # skew
eta = U(0.55, 1.55)             # SSVI eta
gamma = U(0.18, 0.62)           # SSVI gamma

Calendar monotonicity is enforced on the clean target via cumulative maximum along T so the model is never rewarded for learning calendar arbitrage from the simulator. A 200-surface validation set, drawn from the same parameter range with different seeds, is held out for best-checkpoint selection.

Test surfaces. Two test families, each with 200 surfaces per cell and seeds independent of training:

  • Shifted SSVI — same generator with widened parameter ranges and an occasional maturity-localised bump. The generator is in-distribution in form but stress-tests the boundaries.
  • SABR-style event — a separate generator deliberately not identical to SSVI, with square-root maturity decay in ATM vol, stochastic skew term-structure, asymmetric wings, and occasional event-maturity bumps and kinks. This is the out-of-distribution test: smile structure that the model has never seen.

Both families are then perturbed by realistic quote noise (~18 bps in vol space, with wing- and front-end inflation) and masked according to one of two missingness regimes:

  • Regular — about 50% observation density, modest wing deletions, occasional broker-style stripes.
  • Adversarial — about 18% observation density, large contiguous wing or maturity holes, with a thin ATM spine and a few scattered anchors restored so the problem remains solvable but is genuinely sparse.

The result is a 2 × 2 evaluation: {shifted SSVI, SABR-event} × {regular, adversarial}. The motivation is to disentangle two different ways the repair task can be hard: hard because the model has not seen this surface family before, and hard because the available data is too sparse for any local method to cope.

Baselines. Two of them.

  • Calendar-projected linear interpolation: triangulated linear interpolation in (T, k) on observed total variance, with a nearest-neighbour fallback for points outside the convex hull, followed by a cumulative-maximum projection along T to enforce calendar monotonicity. Unparameterised, fast, hard to beat on smooth surfaces.
  • Per-slice SVI fit: at each maturity, a coarse raw-SVI grid search over (\rho, m, \sigma) with the best (a, b) solved in closed form by least squares. Faster than nonlinear least squares, adequate as a baseline. Followed by the same calendar projection as interpolation. This is the published-textbook approach to surface repair when you have enough quotes per slice.

Both baselines are evaluated on the same noisy/masked inputs as the CNN.

Model

The repair network is a four-layer 32-channel convolutional network with two output heads — a softplus mean head and a clipped log-variance head:

class RepairCNN(nn.Module):
   def __init__(self, in_channels=4, width=32):
       super().__init__()
       self.backbone = nn.Sequential(
           nn.Conv2d(in_channels, width, 3, padding=1), nn.SiLU(),
           nn.Conv2d(width, width, 3, padding=1), nn.SiLU(),
           nn.Conv2d(width, width, 3, padding=1), nn.SiLU(),
           nn.Conv2d(width, width, 3, padding=1), nn.SiLU(),
      )
       self.mean_head = nn.Conv2d(width, 1, 1)
       self.logvar_head = nn.Conv2d(width, 1, 1)

   def forward(self, x):
       z = self.backbone(x)
       mean = F.softplus(self.mean_head(z)) + 1e-5  # w / W_SCALE
       log_var = torch.clamp(self.logvar_head(z), -8.0, 1.5)
       return mean, log_var

The four input channels are the masked observed total variance (normalised by W_SCALE = 0.08), the binary observation mask, and two normalised coordinate channels for maturity and log-moneyness. The mean head outputs normalised total variance; the log-variance head produces a heteroscedastic uncertainty estimate.

The output space is total variance rather than implied vol. Both choices are defensible — a controlled experiment with the same training pipeline run in vol space gives essentially identical missing-point RMSE — and total variance has the practical advantage that calendar-arbitrage and butterfly-arbitrage diagnostics are natural in this space.

The loss is a weighted reconstruction term in the normalised w-space, plus a heteroscedastic Gaussian NLL (turned on after a 10-epoch warmup), plus a calendar-arbitrage penalty (penalising negative differences along T in real-w space), plus a small smoothness regulariser:

def repair_loss(mean_norm, log_var, target_norm, mask, cfg, use_nll):
   weights = 1.0 + (cfg.missing_weight - 1.0) * (1.0 - mask)
   sq = (mean_norm - target_norm)**2
   mse = torch.mean(weights * sq)
   loss = mse
   if use_nll and cfg.nll_weight:
       inv_var = torch.exp(-log_var)
       nll = 0.5 * torch.mean(weights * (sq * inv_var + log_var))
       loss = loss + cfg.nll_weight * nll
   if cfg.calendar_weight:
       loss = loss + cfg.calendar_weight * calendar_penalty_w(mean_norm)
   if cfg.smoothness_weight:
       loss = loss + cfg.smoothness_weight * smoothness_penalty(mean_norm)
   return loss

The missing-cell weight is 5x the observed-cell weight, so the model is explicitly graded on its repair quality rather than its ability to denoise observed quotes. The calendar weight is set to 80, which is high enough to drive raw calendar-violation rates into single digits at evaluation time but not so high that it dominates the reconstruction loss during training. The smoothness term is small (0.05 weight) and exists mainly to discourage high-frequency artefacts at the wings.

Training runs for 60 epochs in a single process with AdamW (lr 1e-3, wd 1e-4) and a cosine annealing schedule. The first 10 epochs run pure MSE; the NLL term turns on afterwards to give the variance head an MSE-stabilised mean to train against. A 200-surface validation set is used for best-checkpoint selection by missing-point RMSE in vol units. Batch size 128, gradient clipping at norm 1.0.

I will say a word about training duration because the result is more sensitive to it than I would like. With only 8 training epochs, the model’s missing-point RMSE on shifted SSVI / regular missing is 0.049 in vol units; at 60 epochs it is 0.018. The convergence is slow and the validation curve is still improving slightly at epoch 60. Sixty epochs is therefore a deliberate choice rather than a generous one — at that point the validation curve has flattened enough that further training mostly trades off in-distribution refinement against out-of-distribution generalisation, with no clear winner.

The 2 × 2 result

Headline numbers: missing-point RMSE in implied-vol units, mean ± standard error across 200 test surfaces per cell. Bold marks the best estimator per row, with ties (within one SE) bolded together.

CaseObs %CNNInterpSVI
Shifted SSVI / regular50.5%0.0184 ± 0.00100.0131 ± 0.00080.0191 ± 0.0048
Shifted SSVI / adversarial18.1%0.0527 ± 0.00240.0506 ± 0.00260.0540 ± 0.0123
SABR-event / regular50.4%0.0671 ± 0.00130.0248 ± 0.00110.0189 ± 0.0021
SABR-event / adversarial17.7%0.0960 ± 0.00110.0679 ± 0.00190.0784 ± 0.0089

Four cells, four different stories.

Shifted SSVI / regular. Calendar-projected linear interpolation wins outright: 0.013 versus the CNN’s 0.018. The surfaces here are smooth, the parameter shifts from the training distribution are modest, and roughly half the grid is observed. There is little for a learned prior to add: the local data is dense enough that triangulated interpolation captures essentially all the recoverable structure. The CNN is 40% worse, well outside one SE.

Shifted SSVI / adversarial. The CNN and interpolation are statistically tied (0.053 vs 0.051, within one SE of each other). With observation density at 18% and large contiguous holes, neither method has a clean run, but the CNN’s learned prior on smile shape brings it back into the same neighbourhood as interpolation. The SVI fit is also competitive here, although noisier across surfaces because individual maturity slices sometimes have too few quotes to fit reliably.

SABR-event / regular. SVI wins narrowly (0.019 vs interpolation’s 0.025), the CNN comes in third at 0.067. This is the cell that distinguishes baselines: SVI fits the local smile structure correctly slice-by-slice and pays no penalty for the SABR family being out-of-distribution because it is a per-slice model with no cross-slice prior to mislead it. The CNN, trained only on SSVI surfaces, has learned a prior that does not transfer cleanly to the asymmetric-wings, event-kink SABR family. It is 2.5× worse than SVI here.

SABR-event / adversarial. Interpolation wins (0.068 vs the CNN’s 0.096), with SVI in the middle at 0.078 and noisy because slice-level data is too sparse to fit consistently. The dominant error source for the CNN here is generalisation, not data scarcity. Even with adversarial missingness — exactly the case where one might hope a learned prior contributes most — the OOD penalty dominates.

The pattern across cells is consistent. The CNN is competitive only where its learned prior matches the test distribution and local methods are operating at their weakest. It loses materially when either of those conditions fails. Calendar-projected linear interpolation is the most consistent baseline of the three: it is the best estimator in two cells, statistically tied for best in a third, and the second-best in the fourth.

The diagnostic figure below shows a single SABR-event adversarial-missingness example. The observed input has lost a substantial chunk of the wings, the entire long-maturity tail, and the front-maturity strip; what remains is a thin ATM spine and a handful of scattered anchors. The CNN repair is smooth and plausible, with errors concentrated at the front-maturity wings — exactly where the input is most aggressively masked. The uncertainty head correctly flags that region as high-uncertainty. Calendar-projected interpolation produces the characteristic “shelf” artefact at the maturities where the cumulative-max projection has had to adjust the raw output.

Diagnostics

A repair model that minimises missing-point RMSE while producing arbitrageable surfaces and miscalibrated uncertainty is not a usable estimator. The diagnostics below are not the headline; they are the things you have to report in order for the headline to load-bear.

Caseraw cal %post-proj cal %g(k)<0 %cov80cov95err–sd corrstale AUC
SSVI / regular2.650.005.320.690.850.690.66
SSVI / adversarial9.750.007.710.470.630.640.64
SABR / regular2.190.008.730.300.430.760.55
SABR / adversarial11.050.0011.080.280.410.860.55

The honest summary, line by line:

Calendar arbitrage. The raw CNN output violates calendar monotonicity in 2–11% of (k, T) edges across the four cells. The cumulative-maximum projection drives this to zero everywhere. The calendar-projection step is therefore doing real work, and the model should not be deployed without it. The training-time calendar penalty is doing partial work — without it, raw violation rates would be substantially higher — but it is not on its own sufficient to produce calendar-monotone output reliably.

Butterfly arbitrage. Even after calendar projection, 5–11% of grid points exhibit g(k) < 0 under the discrete Gatheral–Roper diagnostic, with the worst rates on the SABR cells where the CNN is least confident. The smoothness penalty in the loss does not buy enough convexity to fix this. A real butterfly-arbitrage projection — one that actually projects onto the no-arbitrage manifold along k rather than just regularising — would be the right next step. I have not done it here, and the post-projection g(k)<0 rate is the most concerning single number in this set of diagnostics.

Uncertainty calibration. The heteroscedastic head undercovers. Nominal 80% intervals deliver 28–69% empirical coverage; nominal 95% intervals deliver 41–85%. The error-versus-predicted-sd correlation is positive everywhere (0.64–0.86), so the model is at least directionally aware of where its output is unreliable, but it is overconfident about how unreliable it is — particularly on the OOD SABR cells, where the variance head has nothing to recalibrate against. This is the standard limit of in-training Gaussian NLL: under distribution shift, the variance head is as miscalibrated as the mean head, in the same direction. A held-out conformal step is the obvious fix and would be the cheapest single change to make the uncertainty channel operationally useful.

Stale-quote AUC. A synthetic test: inject stale errors into 8% of observed quotes, run the model on the stale-injected surface, and compute the AUC of the residual |obs – \mu| as a stale-quote score. Numbers come in at 0.55–0.66. Better than chance, but weak — particularly on the OOD SABR cells where the AUC sits just above 0.55. This says the model’s residual is not, on its own, a strong stale-quote detector. A more useful operational stale-detection pipeline would combine the model residual with quote-time and quote-source signals, and the model is contributing a useful but limited fraction of the discriminative signal.

Downstream SVI projection

Surface repair is a means to an end. What the calibration desk usually wants is a smooth, arbitrage-projected SVI surface, not the raw repair output. A repair that is more accurate in the missing-point RMSE sense but less amenable to clean SVI projection might be a worse operational deliverable than the reverse. The right question is not just “how accurate is the repair” but “how good is the SVI fit to the repaired full surface”.

To get at this, I run a per-slice SVI projection on the full repaired surface (CNN or calendar-projected interpolation), then re-score the SVI fit against the held-out missing cells:

CaseSVI after CNNSVI after interp
Shifted SSVI / regular0.0171 ± 0.00270.0159 ± 0.0071
Shifted SSVI / adversarial0.0401 ± 0.00800.0426 ± 0.0113
SABR-event / regular0.0697 ± 0.00500.0295 ± 0.0062
SABR-event / adversarial0.0913 ± 0.00570.0634 ± 0.0065

The downstream metric does not change the qualitative ranking: CNN-then-SVI narrowly beats interp-then-SVI on the cell where the headline RMSE was already tied (SSVI / adversarial), and loses everywhere else. The CNN is not producing surfaces that are pathologically uncooperative under SVI projection — the SVI residuals follow the missing-point RMSE residuals reasonably faithfully. That is mildly reassuring from a deployment perspective: the choice between estimators is not being secretly arbitraged away by the projection step.

What this experiment shows and does not show

A small mask-conditional CNN, trained on 1600 synthetic SSVI surfaces under explicit calendar and smoothness penalties, with 200 validation surfaces for checkpoint selection, can repair sparse and noisy total-variance surfaces under a tight enough discipline that it:

  • produces calendar-monotone output after a cumulative-maximum post-projection (which both baselines also need);
  • matches calendar-projected linear interpolation, within statistical noise, on adversarial in-distribution missingness;
  • loses to interpolation by roughly 40% on benign in-distribution missingness, where the local-data density is high enough that triangulated interpolation captures essentially all the recoverable structure;
  • loses by a factor of 1.4–2.7× on out-of-distribution SABR-style smiles, depending on observation density;
  • carries a heteroscedastic uncertainty estimate whose direction is right (positive correlation with error) but whose magnitude is undercalibrated, particularly under distribution shift.

What this experiment does not show, and what I want to be plain about:

It does not show that this kind of CNN-based repair is useful on real data. The synthetic surfaces have no calibration drift, no quote-time-of-day noise, no microstructure asymmetries, no realistic smile dynamics, no hard-to-fit weeklies or single-name idiosyncrasies. The repair task here is pristine compared to anything one would do on production market data. Whether the small relative gap between CNN and interpolation on the adversarial cell survives a real-data test is an open question that this experiment cannot answer.

It does not show that a CNN is the right architecture for this task. A four-layer 32-channel CNN on a 13 × 17 grid is the smallest interesting model for this problem; a U-Net, an attention-conditioned masked decoder, or a conditional VAE all have published precedents in the volatility-repair literature and would be plausible candidates for materially better performance. The choice here was deliberate — keep the model small and the comparison clean — but it is not the architecture I would deploy if I were trying to win the headline number.

It does not show that the SABR-event family is the right test for OOD generalisation. The CNN is being asked to handle smiles whose convexity term-structure and wing asymmetry have a different functional form from anything in its training set. That is a hard test by design, and the gap to SVI on SABR / regular says that what the CNN has learned is closer to “the SSVI smile family” than “smile structure in general”. A more useful experiment would mix multiple smile families during training and re-test on a held-out one.

It does not, on its own, justify a production system. Before this estimator went anywhere near a market-making book it would need a real-data study, downstream P&L attribution, a much more serious calibration of the uncertainty head, a butterfly-projection step, and a comparison against more competitive learned baselines.

Where to take this next

Roughly in priority order:

  1. Real-data replication. Run the same 2 × 2 on an index-options panel across a year — in-sample dates against out-of-sample dates, regular trading days against unusually sparse ones — and see whether the conditional pattern survives. This is the single biggest credibility step. Everything in this note is conditional on the synthetic setup being a reasonable proxy for production data, and that conditioning is not free.
  2. No-arbitrage projection. Add a full butterfly-arbitrage projection alongside the calendar cumulative-max, and report whether forcing the CNN onto the no-arbitrage manifold during evaluation changes the ranking. The 5–11% post-projection g(k)<0 rate is the most uncomfortable number in the diagnostics.
  3. Calibrated uncertainty. Replace the in-training heteroscedastic NLL with a conformal wrapper trained on a held-out residual set, or with a deep ensemble across seeds. The current undercoverage on OOD cells is bad enough that the uncertainty channel is more decorative than useful.
  4. Mixture training. Train on a mixture of smile families (SSVI plus SABR-event plus a Heston-like family) and re-test on a held-out family. The SABR loss is dominantly a generalisation failure, and the cheapest single fix is to broaden the training distribution.
  5. Generative baseline. Compare a small VAE on the same harness, conditioned on the missingness pattern, as the published baseline for learned vol-surface repair. The conditional-on-mask deterministic CNN here is probably not the right architecture in the limit.

None of these requires a deep architectural rethink. They are mostly questions of where the experiment runs and what it gets compared against.

Code and references

The full script (vol_surface_repair.py) is self-contained, CPU-friendly, and reproducible: it generates the training and test data, trains the CNN, evaluates against both baselines, computes the diagnostics, runs the downstream SVI projection, and writes the figure and the numeric results to disk. Run with python vol_surface_repair.py. Approximately four minutes on CPU.

Selected references for context (full bibliographic details should be checked against the published versions):

  • Gatheral, J. The Volatility Surface: A Practitioner’s Guide. Wiley, 2006.
  • Gatheral, J. and Jacquier, A. Arbitrage-free SVI volatility surfaces. Quantitative Finance, 2014.
  • Roper, M. Arbitrage free implied volatility surfaces. Working paper, 2010.
  • Bergeron, M., Fung, N., Hull, J., Poulos, Z., and Veneris, A. Variational autoencoders: a hands-off approach to volatility. Journal of Financial Data Science, 2022.
  • Ning, B., Jaimungal, S., Zhang, X., and Bergeron, M. Arbitrage-free implied volatility surface generation with variational autoencoders. arXiv:2108.04941.
  • Cont, R. and Vuletic, M. Simulation of arbitrage-free implied volatility surfaces. Applied Mathematical Finance, 2023.
"""
Deep Learning for Volatility Surface Repair.

Self-contained, CPU-friendly PyTorch script that trains a small mask-conditional
CNN on synthetic SSVI total-variance surfaces and evaluates it against
calendar-projected linear interpolation and a per-slice SVI fit on a 2x2 design:

    {shifted SSVI, SABR-style event} x {regular missingness, adversarial missingness}

Reports missing-point RMSE in implied-vol units with standard errors, calendar
violation rates before and after isotonic projection, butterfly arbitrage
violations, uncertainty calibration coverage, a stale-quote residual AUC, and
a downstream SVI-projection metric.

Run:
    python vol_surface_repair.py

Dependencies: numpy, scipy, scikit-learn, matplotlib, torch.
"""

from __future__ import annotations

import math
import os
import random
import sys
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator
from sklearn.metrics import roc_auc_score

warnings.filterwarnings("ignore", category=RuntimeWarning)

# -----------------------------
# Reproducibility and settings
# -----------------------------
SEED = 17
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.set_num_threads(2)
try:
    torch.set_num_interop_threads(2)
except RuntimeError:
    pass

OUT_DIR = Path(os.environ.get("OUT_DIR", "./out"))
OUT_DIR.mkdir(parents=True, exist_ok=True)
FIG_PATH = OUT_DIR / "vol_surface_repair_diagnostic.png"
SVI_EVAL_SURFACES = 10  # CNN/interp use all test surfaces; SVI-grid/projection use first 10 for runtime

DEVICE = torch.device("cpu")
DTYPE = torch.float32

# Grid: maturities and log-moneyness.
N_T = 13
N_K = 17
T_GRID = np.linspace(0.08, 2.0, N_T).astype(np.float64)
K_GRID = np.linspace(-0.45, 0.45, N_K).astype(np.float64)
TT, KK = np.meshgrid(T_GRID, K_GRID, indexing="ij")
T_NORM = ((TT - TT.min()) / (TT.max() - TT.min()) * 2.0 - 1.0).astype(np.float32)
K_NORM = ((KK - KK.min()) / (KK.max() - KK.min()) * 2.0 - 1.0).astype(np.float32)

# Total-variance normalisation scale for stable CNN training. The model targets
# w / W_SCALE; outputs are converted back to total variance and then to implied
# vol for evaluation. Vol-space and w-space training give essentially identical
# missing-point RMSE under this pipeline, so w-space is preferred for the cleaner
# arbitrage-diagnostic geometry.
W_SCALE = 0.08

# Coarse raw-SVI shape library for fast per-slice benchmarking.
def _make_svi_shape_library() -> np.ndarray:
    rhos = np.linspace(-0.90, 0.40, 9)
    ms = np.linspace(-0.22, 0.22, 7)
    sigmas = np.array([0.04, 0.07, 0.11, 0.17, 0.26, 0.40])
    shapes = []
    for rho in rhos:
        for m in ms:
            for sig in sigmas:
                x = K_GRID - m
                f = rho * x + np.sqrt(x * x + sig * sig)
                shapes.append(f)
    return np.asarray(shapes, dtype=np.float64)


SVI_SHAPES = _make_svi_shape_library()


# -----------------------------
# Surface generators
# -----------------------------

def ssvi_total_variance(n: int, shifted: bool = False, seed: int = 0) -> np.ndarray:
    rng = np.random.default_rng(seed)
    out = np.empty((n, N_T, N_K), dtype=np.float64)
    for i in range(n):
        theta0 = rng.uniform(0.010, 0.032) if not shifted else rng.uniform(0.006, 0.045)
        theta_slope = rng.uniform(0.020, 0.070) if not shifted else rng.uniform(0.012, 0.095)
        theta_curve = rng.uniform(-0.006, 0.010) if not shifted else rng.uniform(-0.015, 0.020)
        rho = rng.uniform(-0.72, -0.18) if not shifted else rng.uniform(-0.88, -0.05)
        eta = rng.uniform(0.55, 1.55) if not shifted else rng.uniform(0.35, 2.10)
        gamma = rng.uniform(0.18, 0.62) if not shifted else rng.uniform(0.08, 0.78)
        t_scaled = T_GRID / T_GRID.max()
        theta = theta0 + theta_slope * t_scaled + theta_curve * t_scaled**2
        theta = np.maximum.accumulate(np.maximum(theta, 0.004))
        phi = eta / np.maximum(theta, 1e-4) ** gamma
        for j, th in enumerate(theta):
            ph = phi[j]
            x = ph * K_GRID + rho
            out[i, j, :] = 0.5 * th * (1.0 + rho * ph * K_GRID + np.sqrt(x * x + 1.0 - rho * rho))
        if shifted and rng.random() < 0.45:
            event_T = rng.choice(np.arange(2, N_T - 2))
            bump = rng.uniform(0.0015, 0.0045) * np.exp(-0.5 * ((np.arange(N_T) - event_T) / 0.8) ** 2)
            out[i] += bump[:, None] * (1.0 + 0.4 * np.tanh(4.0 * K_GRID))[None, :]
        out[i] = np.maximum(out[i], 1e-4)
    return out.astype(np.float32)


def sabr_event_total_variance(n: int, seed: int = 1) -> np.ndarray:
    rng = np.random.default_rng(seed)
    out = np.empty((n, N_T, N_K), dtype=np.float64)
    for i in range(n):
        base = rng.uniform(0.13, 0.29)
        decay = rng.uniform(0.03, 0.12)
        long = rng.uniform(0.05, 0.13)
        skew0 = rng.uniform(-0.52, -0.10)
        skew_decay = rng.uniform(0.1, 1.2)
        convex0 = rng.uniform(0.25, 0.95)
        wing_asym = rng.uniform(-0.12, 0.18)
        event_idx = rng.choice(np.arange(1, N_T - 2)) if rng.random() < 0.75 else None
        event_amp = rng.uniform(0.015, 0.055) if event_idx is not None else 0.0
        for j, T in enumerate(T_GRID):
            atm = long + base * np.exp(-decay * 4.0 * T) + rng.normal(0, 0.002)
            skew = skew0 * np.exp(-skew_decay * T) + rng.normal(0, 0.025)
            convex = convex0 / np.sqrt(T + 0.30) + rng.normal(0, 0.03)
            event = 0.0
            if event_idx is not None:
                event = event_amp * np.exp(-0.5 * ((j - event_idx) / 0.65) ** 2)
            vol = atm + event + skew * K_GRID + convex * K_GRID**2 + wing_asym * np.maximum(K_GRID, 0) ** 3
            if event_idx is not None and abs(j - event_idx) <= 1:
                kink_loc = rng.uniform(-0.12, 0.12)
                vol += rng.uniform(0.010, 0.030) * np.maximum(0.0, 1.0 - np.abs(K_GRID - kink_loc) / 0.12)
            vol = np.clip(vol, 0.04, 1.20)
            out[i, j, :] = vol * vol * T
        out[i] = np.maximum.accumulate(out[i], axis=0)
        out[i] = np.maximum(out[i], 1e-4)
    return out.astype(np.float32)


# -----------------------------
# Missingness and noise
# -----------------------------

def make_mask(kind: str, n: int, seed: int) -> np.ndarray:
    rng = np.random.default_rng(seed)
    masks = np.zeros((n, N_T, N_K), dtype=np.float32)
    center = N_K // 2
    for i in range(n):
        if kind == "regular":
            p = rng.uniform(0.34, 0.46)
            m = (rng.random((N_T, N_K)) < p).astype(np.float32)
            atm_band = slice(center - 1, center + 2)
            m[:, atm_band] = np.maximum(m[:, atm_band], (rng.random((N_T, 3)) < 0.74).astype(np.float32))
            for row in rng.choice(N_T, size=rng.integers(1, 4), replace=False):
                cols = rng.choice(N_K, size=rng.integers(4, 8), replace=False)
                m[row, cols] = 1.0
            for col in rng.choice(N_K, size=rng.integers(1, 3), replace=False):
                rows = rng.choice(N_T, size=rng.integers(4, 9), replace=False)
                m[rows, col] = 1.0
            if rng.random() < 0.55:
                wing = slice(0, rng.integers(2, 5)) if rng.random() < 0.5 else slice(rng.integers(N_K - 5, N_K - 2), N_K)
                rows = rng.choice(N_T, size=rng.integers(3, 8), replace=False)
                m[rows, wing] = 0.0
        elif kind == "adversarial":
            p = rng.uniform(0.16, 0.25)
            m = (rng.random((N_T, N_K)) < p).astype(np.float32)
            atm_keep = rng.random(N_T) < rng.uniform(0.55, 0.82)
            m[atm_keep, center] = 1.0
            near = rng.choice([center - 2, center - 1, center + 1, center + 2], size=rng.integers(1, 3), replace=False)
            for col in near:
                rows = rng.choice(N_T, size=rng.integers(3, 7), replace=False)
                m[rows, col] = 1.0
            if rng.random() < 0.5:
                m[:, : rng.integers(4, 7)] = 0.0
            else:
                m[:, rng.integers(N_K - 7, N_K - 4) :] = 0.0
            if rng.random() < 0.65:
                m[: rng.integers(2, 5), :] = 0.0
            if rng.random() < 0.45:
                m[rng.integers(N_T - 5, N_T - 2) :, :] = 0.0
            for _ in range(rng.integers(6, 12)):
                m[rng.integers(0, N_T), rng.integers(0, N_K)] = 1.0
            m[:, center] = np.maximum(m[:, center], (rng.random(N_T) < 0.35).astype(np.float32))
        else:
            raise ValueError(f"unknown mask kind {kind}")
        if m.sum() < 18:
            flat = rng.choice(N_T * N_K, size=18, replace=False)
            m.flat[flat] = 1.0
        masks[i] = m
    return masks


def corrupt_surfaces_w(w: np.ndarray, mask: np.ndarray, seed: int, vol_noise_bps: float = 18.0) -> np.ndarray:
    """Add realistic quote noise in implied vol space, return masked total variance.

    Returns total variance with unobserved cells set to zero.
    """
    rng = np.random.default_rng(seed)
    vol = np.sqrt(np.maximum(w / TT[None, :, :], 1e-8))
    noise = rng.normal(0.0, vol_noise_bps / 10000.0, size=vol.shape)
    wing_factor = 1.0 + 1.5 * (np.abs(KK)[None, :, :] / np.max(np.abs(K_GRID)))
    front_factor = 1.0 + 0.6 * (T_GRID.max() - TT)[None, :, :] / (T_GRID.max() - T_GRID.min())
    noisy_vol = np.clip(vol + noise * wing_factor * front_factor, 0.03, 2.00)
    noisy_w = (noisy_vol**2) * TT[None, :, :]
    return (noisy_w * mask).astype(np.float32)


def w_to_vol(w: np.ndarray) -> np.ndarray:
    """Convert total variance to implied vol. Handles 2D (T,K) or 3D (n,T,K) input."""
    if w.ndim == 2:
        return np.sqrt(np.maximum(w / TT, 1e-8)).astype(np.float32)
    return np.sqrt(np.maximum(w / TT[None, :, :], 1e-8)).astype(np.float32)


def vol_to_w(vol: np.ndarray) -> np.ndarray:
    """Convert implied vol to total variance. Handles 2D (T,K) or 3D (n,T,K) input."""
    if vol.ndim == 2:
        return (vol * vol * TT).astype(np.float32)
    return (vol * vol * TT[None, :, :]).astype(np.float32)


# -----------------------------
# Model
# -----------------------------

class RepairCNN(nn.Module):
    """Small mask-conditional CNN that predicts normalised total variance.

    Output: w / W_SCALE via softplus head (positivity).
    Log-variance head predicts uncertainty in the same w/W_SCALE space.

    Single-process training for ~60 epochs with preserved AdamW state is
    important: with shorter training, the model materially underperforms even
    classical baselines in this setup.
    """

    def __init__(self, in_channels: int = 4, width: int = 32):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels, width, 3, padding=1), nn.SiLU(),
            nn.Conv2d(width, width, 3, padding=1), nn.SiLU(),
            nn.Conv2d(width, width, 3, padding=1), nn.SiLU(),
            nn.Conv2d(width, width, 3, padding=1), nn.SiLU(),
        )
        self.mean_head = nn.Conv2d(width, 1, 1)
        self.logvar_head = nn.Conv2d(width, 1, 1)
        with torch.no_grad():
            self.logvar_head.bias.fill_(-3.0)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        z = self.backbone(x)
        mean = F.softplus(self.mean_head(z)) + 1e-5
        log_var = torch.clamp(self.logvar_head(z), -8.0, 1.5)
        return mean, log_var


def make_inputs(observed_w: np.ndarray, mask: np.ndarray) -> torch.Tensor:
    """Build input tensor.

    Channel 0: observed w / W_SCALE (zero where missing).
    Channel 1: observation mask.
    Channel 2: maturity coordinate.
    Channel 3: log-moneyness coordinate.
    """
    n = observed_w.shape[0]
    coords_t = np.broadcast_to(T_NORM[None, :, :], (n, N_T, N_K))
    coords_k = np.broadcast_to(K_NORM[None, :, :], (n, N_T, N_K))
    x = np.stack([observed_w / W_SCALE, mask, coords_t, coords_k], axis=1).astype(np.float32)
    return torch.tensor(x, dtype=DTYPE, device=DEVICE)


# -----------------------------
# Loss
# -----------------------------

def calendar_penalty_w(mean_norm: torch.Tensor) -> torch.Tensor:
    """Calendar arbitrage penalty: total variance must be non-decreasing in T.

    Operates on the actual w (not normalised) for correct scaling of the penalty
    relative to the MSE term in normalised space.
    """
    w = mean_norm * W_SCALE
    dw = w[:, :, 1:, :] - w[:, :, :-1, :]
    return torch.mean(F.relu(-dw))


def smoothness_penalty(mean_norm: torch.Tensor) -> torch.Tensor:
    """Smoothness regulariser on normalised w (not a butterfly proxy)."""
    d2k = mean_norm[:, :, :, 2:] - 2.0 * mean_norm[:, :, :, 1:-1] + mean_norm[:, :, :, :-2]
    d2t = mean_norm[:, :, 2:, :] - 2.0 * mean_norm[:, :, 1:-1, :] + mean_norm[:, :, :-2, :]
    return torch.mean(d2k * d2k) + torch.mean(d2t * d2t)


@dataclass
class LossConfig:
    missing_weight: float = 5.0
    calendar_weight: float = 80.0
    smoothness_weight: float = 0.05
    nll_weight: float = 0.10


def repair_loss(
    mean_norm: torch.Tensor,
    log_var: torch.Tensor,
    target_norm: torch.Tensor,
    mask: torch.Tensor,
    cfg: LossConfig,
    use_nll: bool,
) -> torch.Tensor:
    """Loss in w/W_SCALE space."""
    weights = 1.0 + (cfg.missing_weight - 1.0) * (1.0 - mask)
    sq = (mean_norm - target_norm) ** 2
    mse = torch.mean(weights * sq)
    loss = mse
    if use_nll and cfg.nll_weight:
        inv_var = torch.exp(-log_var)
        nll = 0.5 * torch.mean(weights * (sq * inv_var + log_var))
        loss = loss + cfg.nll_weight * nll
    if cfg.calendar_weight:
        loss = loss + cfg.calendar_weight * calendar_penalty_w(mean_norm)
    if cfg.smoothness_weight:
        loss = loss + cfg.smoothness_weight * smoothness_penalty(mean_norm)
    return loss


# -----------------------------
# Training
# -----------------------------

def train_model(
    train_w: np.ndarray,
    train_mask: np.ndarray,
    val_w: np.ndarray,
    val_mask: np.ndarray,
    cfg: LossConfig,
    epochs: int = 60,
    seed: int = SEED,
    verbose: bool = True,
) -> RepairCNN:
    """Train the CNN in a single process with optimiser state preserved.

    Targets and predictions are in w/W_SCALE space. Validation RMSE is reported
    in implied-vol units for interpretability and best-checkpoint selection.
    The first 10 epochs run pure MSE; the heteroscedastic NLL term turns on
    afterwards to avoid early variance head instabilities.
    """
    torch.manual_seed(seed)

    obs_w = corrupt_surfaces_w(train_w, train_mask, seed=seed + 101)
    x_train = make_inputs(obs_w, train_mask)
    y_train = torch.tensor((train_w / W_SCALE)[:, None, :, :], dtype=DTYPE, device=DEVICE)
    m_train = torch.tensor(train_mask[:, None, :, :], dtype=DTYPE, device=DEVICE)

    obs_w_val = corrupt_surfaces_w(val_w, val_mask, seed=seed + 202)
    x_val = make_inputs(obs_w_val, val_mask)
    y_val_vol = w_to_vol(val_w)

    model = RepairCNN().to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=1.0e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    batch_size = 128
    n = train_w.shape[0]
    idx = np.arange(n)

    best_val = float("inf")
    best_state: Dict[str, torch.Tensor] | None = None

    for ep in range(epochs):
        np.random.shuffle(idx)
        model.train()
        running = 0.0
        n_batches = 0
        use_nll = ep >= 10
        for start in range(0, n, batch_size):
            batch = idx[start : start + batch_size]
            mean_norm, logv = model(x_train[batch])
            loss = repair_loss(mean_norm, logv, y_train[batch], m_train[batch], cfg, use_nll=use_nll)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()
            running += float(loss.detach())
            n_batches += 1
        scheduler.step()
        train_loss = running / max(n_batches, 1)

        model.eval()
        with torch.no_grad():
            mean_val_norm, _ = model(x_val)
            mu_w_val = (mean_val_norm[:, 0].cpu().numpy() * W_SCALE).astype(np.float32)
        mu_v_val = w_to_vol(mu_w_val)
        miss = val_mask < 0.5
        if miss.sum() > 0:
            val_rmse = float(np.sqrt(np.mean((mu_v_val[miss] - y_val_vol[miss]) ** 2)))
        else:
            val_rmse = float("nan")

        if val_rmse < best_val:
            best_val = val_rmse
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        if verbose and (ep < 5 or ep % 10 == 9 or ep == epochs - 1):
            print(f"  epoch {ep+1:02d}/{epochs}  train_loss {train_loss:.5f}  val_RMSE_vol {val_rmse:.5f}  best {best_val:.5f}", flush=True)

    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
    return model


# -----------------------------
# Baselines
# -----------------------------

def interpolate_surface(obs: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Linear interpolation in (T, k) on total variance, with nearest fallback."""
    points = np.column_stack([TT[mask > 0.5].ravel(), KK[mask > 0.5].ravel()])
    values = obs[mask > 0.5].ravel()
    full_points = np.column_stack([TT.ravel(), KK.ravel()])
    if len(values) < 4:
        fill = np.nanmean(values) if len(values) else 0.02
        return np.full((N_T, N_K), fill, dtype=np.float32)
    lin = LinearNDInterpolator(points, values, fill_value=np.nan)
    pred = lin(full_points).reshape(N_T, N_K)
    if np.isnan(pred).any():
        near = NearestNDInterpolator(points, values)
        pred[np.isnan(pred)] = near(full_points).reshape(N_T, N_K)[np.isnan(pred)]
    return np.maximum(pred, 1e-5).astype(np.float32)


def fit_svi_slice(k_obs: np.ndarray, w_obs: np.ndarray) -> np.ndarray:
    idx = np.zeros(N_K, dtype=bool)
    for ko in k_obs:
        idx[np.argmin(np.abs(K_GRID - ko))] = True
    y = np.asarray(w_obs, dtype=np.float64)
    good = np.isfinite(y) & (y > 0)
    y = y[good]
    obs_idx = np.where(idx)[0][good] if idx.sum() == len(good) else np.where(idx)[0]
    if len(y) != len(obs_idx):
        obs_idx = np.where(idx)[0][: len(y)]
    if len(y) < 3:
        fill = float(np.nanmean(y)) if len(y) else 0.02
        return np.full(N_K, fill, dtype=np.float64)
    Fobs = SVI_SHAPES[:, obs_idx]
    f_mean = Fobs.mean(axis=1)
    y_mean = y.mean()
    f_center = Fobs - f_mean[:, None]
    y_center = y - y_mean
    denom = np.sum(f_center * f_center, axis=1) + 1e-10
    b = np.sum(f_center * y_center[None, :], axis=1) / denom
    b = np.clip(b, 1e-6, 10.0)
    a = y_mean - b * f_mean
    a = np.clip(a, 1e-6, 5.0)
    pred_obs = a[:, None] + b[:, None] * Fobs
    sse = np.mean((pred_obs - y[None, :]) ** 2, axis=1)
    best = int(np.argmin(sse))
    pred = a[best] + b[best] * SVI_SHAPES[best]
    return np.maximum(pred, 1e-5)


def svi_fit_surface(obs: np.ndarray, mask: np.ndarray, enforce_calendar: bool = True) -> np.ndarray:
    pred = np.empty((N_T, N_K), dtype=np.float64)
    for j in range(N_T):
        idx = mask[j] > 0.5
        pred[j] = fit_svi_slice(K_GRID[idx], obs[j, idx])
    pred = np.maximum(pred, 1e-5)
    if enforce_calendar:
        pred = calendar_project(pred)
    return pred.astype(np.float32)


# -----------------------------
# Projections, metrics, evaluation
# -----------------------------

def calendar_project(w: np.ndarray) -> np.ndarray:
    """Strict isotonic projection along maturity via cumulative maximum."""
    return np.maximum.accumulate(np.maximum(w, 1e-5), axis=-2 if w.ndim == 3 else 0)


def calendar_violation_rate(w: np.ndarray) -> float:
    dw = np.diff(w, axis=-2 if w.ndim == 3 else 0)
    return float(np.mean(dw < -1e-8))


def butterfly_g_violation_rate(w: np.ndarray) -> float:
    """Discrete Gatheral-Roper g(k) >= 0 diagnostic on total variance surfaces."""
    dk = K_GRID[1] - K_GRID[0]
    wi = np.maximum(w[..., 1:-1], 1e-6)
    wp = (w[..., 2:] - w[..., :-2]) / (2.0 * dk)
    wpp = (w[..., 2:] - 2.0 * w[..., 1:-1] + w[..., :-2]) / (dk * dk)
    k = K_GRID[1:-1][None, None, :]
    while k.ndim < wi.ndim:
        k = np.expand_dims(k, axis=0)
    g = (1.0 - k * wp / (2.0 * wi)) ** 2 - (wp * wp / 4.0) * (1.0 / wi + 0.25) + 0.5 * wpp
    return float(np.mean(g < -1e-6))


def rmse_missing_w(pred_w: np.ndarray, target_w: np.ndarray, mask: np.ndarray) -> Tuple[float, float]:
    """Per-surface RMSE in implied vol units on missing cells, returned as (mean, SE)."""
    pred_vol = w_to_vol(pred_w)
    target_vol = w_to_vol(target_w)
    per = []
    for i in range(target_w.shape[0]):
        miss = mask[i] < 0.5
        if miss.sum() == 0:
            continue
        per.append(math.sqrt(float(np.mean((pred_vol[i][miss] - target_vol[i][miss]) ** 2))))
    arr = np.array(per)
    return float(arr.mean()), float(arr.std(ddof=1) / math.sqrt(len(arr)))


def predict_model(
    model: RepairCNN, true_w: np.ndarray, mask: np.ndarray, seed: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run model on noisy/sparse surfaces. Returns (mu_w, sd_vol, observed_w).

    The model's mean output is in w/W_SCALE space; we convert to actual w and
    derive an approximate vol-space standard deviation via delta-method
    sd_vol ~ sd_w / (2 * vol * T) where vol = sqrt(w / T).
    """
    obs_w = corrupt_surfaces_w(true_w, mask, seed=seed)
    x = make_inputs(obs_w, mask)
    model.eval()
    with torch.no_grad():
        mean_norm, logv = model(x)
    mu_w = (mean_norm[:, 0].cpu().numpy() * W_SCALE).astype(np.float32)
    sd_w = (np.sqrt(np.exp(logv[:, 0].cpu().numpy())) * W_SCALE).astype(np.float32)
    # Delta-method conversion to vol-space sd.
    vol_pred = w_to_vol(mu_w)
    sd_vol = sd_w / (2.0 * np.maximum(vol_pred, 1e-3) * TT[None, :, :])
    return mu_w, sd_vol.astype(np.float32), obs_w


def predict_model_from_obs(model: RepairCNN, obs_w: np.ndarray, mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Run the model on a pre-corrupted observed_w (used by stale-quote AUC)."""
    x = make_inputs(obs_w, mask)
    model.eval()
    with torch.no_grad():
        mean_norm, logv = model(x)
    mu_w = (mean_norm[:, 0].cpu().numpy() * W_SCALE).astype(np.float32)
    sd_w = (np.sqrt(np.exp(logv[:, 0].cpu().numpy())) * W_SCALE).astype(np.float32)
    return mu_w, sd_w


def uncertainty_coverage(mu_w: np.ndarray, sd_vol: np.ndarray, target_w: np.ndarray, mask: np.ndarray) -> Dict[str, float]:
    """Coverage of nominal Gaussian intervals, evaluated in vol units."""
    mu_vol = w_to_vol(mu_w)
    target_vol = w_to_vol(target_w)
    miss = mask < 0.5
    err = np.abs(mu_vol[miss] - target_vol[miss])
    s = np.maximum(sd_vol[miss], 1e-6)
    cov80 = float(np.mean(err <= 1.28155 * s))
    cov95 = float(np.mean(err <= 1.95996 * s))
    if err.size > 5:
        corr = float(np.corrcoef(err.ravel(), s.ravel())[0, 1])
    else:
        corr = float("nan")
    return {"coverage80": cov80, "coverage95": cov95, "err_sd_corr": corr, "avg_sd": float(np.mean(s))}


def stale_quote_auc(model: RepairCNN, true_w: np.ndarray, mask: np.ndarray, seed: int, stale_frac: float = 0.08) -> float:
    """Stale-quote AUC: model sees the stale-injected input.

    1. Generate a noisy observed surface.
    2. Inject stale errors at a random subset of observed cells.
    3. Run the model on the stale-injected surface.
    4. Score residuals at observed cells; AUC of stale label vs |obs - mu|.
    """
    rng = np.random.default_rng(seed)
    noisy_w = corrupt_surfaces_w(true_w, mask, seed=seed + 77)
    labels = np.zeros_like(mask, dtype=np.int32)
    corrupted = noisy_w.copy()
    for i in range(mask.shape[0]):
        obs_idx = np.argwhere(mask[i] > 0.5)
        if len(obs_idx) == 0:
            continue
        n_stale = max(1, int(len(obs_idx) * stale_frac))
        chosen = obs_idx[rng.choice(len(obs_idx), size=n_stale, replace=False)]
        for r, c in chosen:
            labels[i, r, c] = 1
            vol = math.sqrt(max(corrupted[i, r, c] / T_GRID[r], 1e-8))
            vol += rng.choice([-1.0, 1.0]) * rng.uniform(0.0050, 0.0110)
            corrupted[i, r, c] = max(vol, 0.03) ** 2 * T_GRID[r]
    # Run model on the stale-injected input so residuals reflect detection capability.
    mu_w, _ = predict_model_from_obs(model, corrupted, mask)
    scores = np.abs(corrupted - mu_w)
    observed = mask > 0.5
    y = labels[observed].ravel()
    s = scores[observed].ravel()
    if len(np.unique(y)) < 2:
        return float("nan")
    return float(roc_auc_score(y, s))


def svi_projection_metric(pred_w: np.ndarray, target_w: np.ndarray, mask: np.ndarray, max_surfaces: int) -> Tuple[float, float]:
    n = min(max_surfaces, pred_w.shape[0])
    projected = np.empty_like(pred_w[:n])
    full_mask = np.ones((N_T, N_K), dtype=np.float32)
    for i in range(n):
        projected[i] = svi_fit_surface(pred_w[i], full_mask, enforce_calendar=True)
    return rmse_missing_w(projected, target_w[:n], mask[:n])


def evaluate_case(name: str, model: RepairCNN, true_w: np.ndarray, mask: np.ndarray, seed: int) -> Dict[str, object]:
    mu_w_raw, sd_vol, obs_w = predict_model(model, true_w, mask, seed)
    mu_w = calendar_project(mu_w_raw)

    interp_w = np.stack([interpolate_surface(obs_w[i], mask[i]) for i in range(true_w.shape[0])])
    interp_w = calendar_project(interp_w)

    n_svi = min(SVI_EVAL_SURFACES, true_w.shape[0])
    svi_w = np.full_like(interp_w, np.nan)
    for i in range(n_svi):
        svi_w[i] = svi_fit_surface(obs_w[i], mask[i], enforce_calendar=True)

    cnn_m, cnn_se = rmse_missing_w(mu_w, true_w, mask)
    interp_m, interp_se = rmse_missing_w(interp_w, true_w, mask)
    svi_m, svi_se = rmse_missing_w(svi_w[:n_svi], true_w[:n_svi], mask[:n_svi])

    unc = uncertainty_coverage(mu_w, sd_vol, true_w, mask)
    auc = stale_quote_auc(model, true_w, mask, seed=seed + 401)
    svi_after_cnn, svi_after_cnn_se = svi_projection_metric(mu_w, true_w, mask, SVI_EVAL_SURFACES)
    svi_after_interp, svi_after_interp_se = svi_projection_metric(interp_w, true_w, mask, SVI_EVAL_SURFACES)

    return {
        "case": name,
        "observed_fraction": float(mask.mean()),
        "cnn_missing_rmse": cnn_m, "cnn_missing_se": cnn_se,
        "interp_missing_rmse": interp_m, "interp_missing_se": interp_se,
        "svi_missing_rmse": svi_m, "svi_missing_se": svi_se,
        "calendar_raw_rate": calendar_violation_rate(mu_w_raw),
        "calendar_projected_rate": calendar_violation_rate(mu_w),
        "butterfly_g_rate": butterfly_g_violation_rate(mu_w),
        "unc_cov80": unc["coverage80"],
        "unc_cov95": unc["coverage95"],
        "unc_corr": unc["err_sd_corr"],
        "unc_avg_sd": unc["avg_sd"],
        "stale_auc": auc,
        "svi_after_cnn_rmse": svi_after_cnn, "svi_after_cnn_se": svi_after_cnn_se,
        "svi_after_interp_rmse": svi_after_interp, "svi_after_interp_se": svi_after_interp_se,
        "_mu_w": mu_w, "_sd_vol": sd_vol, "_obs_w": obs_w, "_interp_w": interp_w, "_svi_w": svi_w,
    }


def print_results_table(results: List[Dict[str, object]]) -> None:
    print("\n2x2 missing-point RMSE in implied vol units (mean +/- SE)")
    print("-" * 110)
    print(f"{'Case':40s} {'Obs%':>6s}   {'CNN':>16s}   {'Interp':>16s}   {'SVI':>16s}")
    for r in results:
        print(
            f"{r['case']:40s} {100*r['observed_fraction']:5.1f}%  "
            f"{r['cnn_missing_rmse']:.4f} ± {r['cnn_missing_se']:.4f}  "
            f"{r['interp_missing_rmse']:.4f} ± {r['interp_missing_se']:.4f}  "
            f"{r['svi_missing_rmse']:.4f} ± {r['svi_missing_se']:.4f}"
        )
    print("-" * 110)

    print("\nDiagnostics after calendar projection")
    print("-" * 112)
    print(f"{'Case':40s} {'raw cal%':>9s} {'post cal%':>10s} {'g(k)<0%':>9s} {'cov80':>7s} {'cov95':>7s} {'corr':>7s} {'AUC':>7s}")
    for r in results:
        print(
            f"{r['case']:40s} {100*r['calendar_raw_rate']:8.2f}% {100*r['calendar_projected_rate']:9.2f}% "
            f"{100*r['butterfly_g_rate']:8.2f}% {r['unc_cov80']:7.3f} {r['unc_cov95']:7.3f} {r['unc_corr']:7.3f} {r['stale_auc']:7.3f}"
        )
    print("-" * 112)

    print("\nDownstream SVI projection: missing-point RMSE in implied vol units")
    print("-" * 96)
    print(f"{'Case':40s} {'SVI after CNN':>22s}     {'SVI after interp':>22s}")
    for r in results:
        print(
            f"{r['case']:40s} "
            f"{r['svi_after_cnn_rmse']:.4f} ± {r['svi_after_cnn_se']:.4f}      "
            f"{r['svi_after_interp_rmse']:.4f} ± {r['svi_after_interp_se']:.4f}"
        )
    print("-" * 96)


def make_diagnostic_plot(result: Dict[str, object], true_w: np.ndarray, mask: np.ndarray, idx: int = 0) -> None:
    """Diagnostic 2x3 plot in implied vol units."""
    target_vol = w_to_vol(true_w[idx])
    mu_vol = w_to_vol(result["_mu_w"][idx])
    obs_vol = w_to_vol(result["_obs_w"][idx])
    interp_vol = w_to_vol(result["_interp_w"][idx])
    sd_vol = result["_sd_vol"][idx]

    panels = [
        target_vol,
        np.where(mask[idx] > 0.5, obs_vol, np.nan),
        mu_vol,
        interp_vol,
        sd_vol,
        np.abs(mu_vol - target_vol),
    ]
    titles = [
        "Clean implied vol",
        "Observed sparse quotes",
        "CNN repair (vol)",
        "Calendar-projected interp",
        "CNN uncertainty (vol)",
        "Absolute CNN error",
    ]
    fig, axes = plt.subplots(2, 3, figsize=(13, 7), constrained_layout=True)
    vmin = float(np.nanmin(target_vol))
    vmax = float(np.nanmax(target_vol))
    for ax, data, title in zip(axes.ravel(), panels, titles):
        if "uncertainty" in title.lower() or "error" in title.lower():
            im = ax.imshow(data, aspect="auto", origin="lower", extent=[K_GRID.min(), K_GRID.max(), T_GRID.min(), T_GRID.max()])
        else:
            im = ax.imshow(data, aspect="auto", origin="lower", extent=[K_GRID.min(), K_GRID.max(), T_GRID.min(), T_GRID.max()], vmin=vmin, vmax=vmax)
        ax.set_title(title)
        ax.set_xlabel("log-moneyness k")
        ax.set_ylabel("maturity T")
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    fig.suptitle(f"Diagnostic: {result['case']}", fontsize=13)
    fig.savefig(FIG_PATH, dpi=150)
    plt.close(fig)


# -----------------------------
# Main
# -----------------------------

def main() -> None:
    n_train = 1600
    n_val = 200
    n_test = 200
    epochs = 60

    print(f"Training surfaces: {n_train}; validation: {n_val}; test/cell: {n_test}; epochs: {epochs}; grid: {N_T}x{N_K}")

    train_w = ssvi_total_variance(n_train, shifted=False, seed=100)
    train_mask = make_mask("regular", n_train, seed=200)
    val_w = ssvi_total_variance(n_val, shifted=False, seed=300)
    val_mask = make_mask("regular", n_val, seed=400)

    cfg = LossConfig(missing_weight=5.0, calendar_weight=80.0, smoothness_weight=0.05, nll_weight=0.10)
    print("Training CNN ...", flush=True)
    model = train_model(train_w, train_mask, val_w, val_mask, cfg, epochs=epochs, seed=SEED, verbose=True)

    test_sets = {
        "Shifted SSVI / regular missing": (ssvi_total_variance(n_test, shifted=True, seed=1001), make_mask("regular", n_test, seed=2001)),
        "Shifted SSVI / adversarial missing": (ssvi_total_variance(n_test, shifted=True, seed=1002), make_mask("adversarial", n_test, seed=2002)),
        "SABR-event / regular missing": (sabr_event_total_variance(n_test, seed=1003), make_mask("regular", n_test, seed=2003)),
        "SABR-event / adversarial missing": (sabr_event_total_variance(n_test, seed=1004), make_mask("adversarial", n_test, seed=2004)),
    }

    results = []
    for i, (name, (w, m)) in enumerate(test_sets.items()):
        print(f"Evaluating {name} ...", flush=True)
        results.append(evaluate_case(name, model, w, m, seed=3000 + i))
    print_results_table(results)

    # Save numeric summary.
    res_path = OUT_DIR / "vol_surface_repair_results.txt"
    with res_path.open("w", encoding="utf-8") as f:
        f.write("Results: 2x2 missing-point RMSE (vol units), diagnostics, downstream SVI\n")
        for r in results:
            f.write(
                f"{r['case']}, obs={r['observed_fraction']:.4f}, "
                f"cnn={r['cnn_missing_rmse']:.6f}+/-{r['cnn_missing_se']:.6f}, "
                f"interp={r['interp_missing_rmse']:.6f}+/-{r['interp_missing_se']:.6f}, "
                f"svi={r['svi_missing_rmse']:.6f}+/-{r['svi_missing_se']:.6f}, "
                f"raw_cal={r['calendar_raw_rate']:.6f}, post_cal={r['calendar_projected_rate']:.6f}, g_rate={r['butterfly_g_rate']:.6f}, "
                f"cov80={r['unc_cov80']:.4f}, cov95={r['unc_cov95']:.4f}, corr={r['unc_corr']:.4f}, auc={r['stale_auc']:.4f}, "
                f"svi_after_cnn={r['svi_after_cnn_rmse']:.6f}, svi_after_interp={r['svi_after_interp_rmse']:.6f}\n"
            )

    # Diagnostic plot from the SABR adversarial case.
    hard_w, hard_m = test_sets["SABR-event / adversarial missing"]
    make_diagnostic_plot(results[-1], hard_w, hard_m, idx=3)
    print(f"\nSaved diagnostic figure: {FIG_PATH}")
    print(f"Saved results summary: {res_path}")


if __name__ == "__main__":
    main()
    sys.stdout.flush()
    sys.stderr.flush()