Satwik Sai Prakash Sahoo

First internal integration of the new API

Hello again! The GSoC coding period is in full swing, and weeks 3 and 4 have been absolutely packed with progress.

Following up on the foundational work from my first PR, I have just successfully merged two major PRs (#1877 and #1882) into the gsoc-2026 branch for sbi. These PRs introduce the first concrete builder class and wire it directly into the core Neural Posterior Estimation (NPE) trainers.

Here is a dive into what I built, the technical hurdles, and a very valuable lesson I learned about software architecture along the way.

PR #1877: The DensityEstimatorBuilder

The primary goal of this phase was to replace the old, opaque posterior_nn() and likelihood_nn() factory closures with something typed, inspectable, and much more robust.

To solve this, I introduced the DensityEstimatorBuilder. It inherits from the base contract we established in Week 1 and serves as the unified entry point for creating neural networks in sbi. Using the __post_init__ method in Python dataclasses, it immediately validates the model name against a _VALID_DENSITY_MODELS set, failing early if the user provides an unknown architecture.
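The fail-early validation described above can be sketched in a few lines. This is a minimal illustration, not the actual sbi code: the class name `DensityEstimatorBuilderSketch`, the contents of the model set, and the choice of `ValueError` are all assumptions made for the example.

```python
from dataclasses import dataclass

# Hypothetical set of supported architectures; the real contents of
# _VALID_DENSITY_MODELS in sbi may differ.
_VALID_DENSITY_MODELS = frozenset({"maf", "nsf", "mdn"})


@dataclass
class DensityEstimatorBuilderSketch:
    """Illustrative stand-in, not the actual sbi builder class."""

    model: str = "maf"

    def __post_init__(self) -> None:
        # Runs right after the generated __init__, so a bad name fails at
        # construction time rather than later during training.
        if self.model not in _VALID_DENSITY_MODELS:
            valid = ", ".join(sorted(_VALID_DENSITY_MODELS))
            raise ValueError(
                f"Unknown density model {self.model!r}. Valid options: {valid}."
            )


builder = DensityEstimatorBuilderSketch(model="nsf")  # accepted

try:
    DensityEstimatorBuilderSketch(model="transformer9000")  # rejected
except ValueError as err:
    error_message = str(err)
```

Because the check lives in `__post_init__`, a typo in the model name surfaces at the line where the builder is created, not deep inside a training loop.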

Major architectural shift

Initially, our plan dictated that the build() method should take a BuildContext object. The idea was that this context would hold all necessary information, including pre-computed z-scoring stats, and pass it neatly down the chain.

However, as I implemented the body of the build() method, my mentor Jan Teusen noticed that the context parameter wasn't actually being used. Every piece of information the builder needed could be derived directly from the raw batch_theta and batch_x tensors.

We realized that forcing the BuildContext into this signature was a premature abstraction.

Instead of holding onto a design just because it was the original plan, we decided to defer the context object entirely until the z-scoring stats are actually pre-computed in a later phase. We updated our roadmap and simplified the build signature.

def build(self, batch_theta: Tensor, batch_x: Tensor):
    pass

Deferring the context implementation was the right call: it let us cover the breadth of the API first and add depth later.

PR #1882: Integrating the Builder into NPE Trainers

With the builder merged, the next step was integration. I updated PosteriorEstimatorTrainer, NPE_B, NPE_C, and MNPE to accept the new DensityEstimatorBuilder instead of relying solely on strings or callables.

To maintain backward compatibility while moving the API forward, I implemented a graceful deprecation path. If a user passes a string (e.g., "maf"), the code still works as before, but it now emits a FutureWarning.

# No estimator given: fall back to the default "maf" builder, without a warning.
if density_estimator is None:
    self._build_neural_net = self._wrap_builder(DensityEstimatorBuilder(model="maf"))
# Legacy string path: still supported, but deprecated.
elif isinstance(density_estimator, str):
    warnings.warn(
        "Passing a string for `density_estimator` is deprecated. "
        "Use DensityEstimatorBuilder(model=...) instead.",
        FutureWarning,
        stacklevel=3,
    )
    self._build_neural_net = posterior_nn(model=density_estimator)
# New path: any builder that implements the base contract.
elif isinstance(density_estimator, _EstimatorBuilderBase):
    self._build_neural_net = self._wrap_builder(density_estimator)
# Legacy callable path: use the user-supplied build function as-is.
else:
    self._build_neural_net = density_estimator

Overall feedback on my work

The code review for PR #1882 was intense but incredibly rewarding. My mentor provided feedback on how to write tests that are not just concise, but strong and explicit in their intent.

For example, I originally wrote a test to check that passing a callable did not trigger the deprecation warning. But I wasn't actually asserting that no warning was emitted; I was just running the code and assuming silence meant success.

My mentor showed me how to use warnings.catch_warnings() with a strict filter to instantly fail the test if a FutureWarning leaked through:

import warnings

with warnings.catch_warnings():
    warnings.simplefilter("error", FutureWarning)
    inference = NPE_C(prior, density_estimator=builder, show_progress_bars=False)
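To see why the strict filter makes the test explicit, here is a self-contained sketch that runs without sbi. The function `make_trainer` is a hypothetical stand-in for a trainer constructor such as NPE_C, and `emits_future_warning` is a helper invented for this example.

```python
import warnings


def make_trainer(density_estimator):
    """Hypothetical stand-in for a trainer constructor such as NPE_C(...)."""
    if isinstance(density_estimator, str):
        warnings.warn(
            "Passing a string for `density_estimator` is deprecated.",
            FutureWarning,
            stacklevel=2,
        )
    return density_estimator


def emits_future_warning(density_estimator) -> bool:
    """Return True if building the trainer emits a FutureWarning."""
    with warnings.catch_warnings():
        # Inside this block, a FutureWarning is raised as an exception
        # instead of being printed, so it cannot pass unnoticed.
        warnings.simplefilter("error", FutureWarning)
        try:
            make_trainer(density_estimator)
        except FutureWarning:
            return True
    return False


string_path_warns = emits_future_warning("maf")
callable_path_warns = emits_future_warning(lambda theta, x: None)
```

In a real test the `try`/`except` is not needed: with the filter set to "error", a leaked warning propagates as an exception and fails the test on its own.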

We also did a deep dive into correct type hinting and managing default arguments. I initially left density_estimator="maf" as the default argument in the NPE_C initialization. My mentor pointed out that this would cause the deprecation warning to fire every single time a user initialized the class without specifying a density estimator! The fix was to change the default to None, update the type hint to match, and handle the "maf" fallback inside the logic block.
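The effect of the None default can be shown with a small sketch. `TrainerSketch` is a hypothetical class written for this post, not the sbi trainer; it only mirrors the shape of the fallback logic.

```python
import warnings
from typing import Callable, Optional, Union


class TrainerSketch:
    """Hypothetical stand-in for an NPE trainer; not the sbi implementation."""

    def __init__(
        self,
        density_estimator: Optional[Union[str, Callable]] = None,
    ) -> None:
        if density_estimator is None:
            # Default path: choose "maf" internally, with no deprecation warning.
            self.model = "maf"
        elif isinstance(density_estimator, str):
            # Only an explicitly passed string counts as deprecated usage.
            warnings.warn(
                "Passing a string for `density_estimator` is deprecated.",
                FutureWarning,
                stacklevel=2,
            )
            self.model = density_estimator
        else:
            self.model = density_estimator


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    default_trainer = TrainerSketch()      # silent: user passed nothing
    n_after_default = len(caught)
    legacy_trainer = TrainerSketch("maf")  # warns: user passed a string
    n_after_string = len(caught)
```

With a "maf" string as the default instead, the first call would be indistinguishable from the second and would warn users who never touched the deprecated argument.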

What's Next?

Weeks 3 and 4 were a massive leap forward for the API refactor. We now have a working, integrated builder that correctly handles all continuous density estimators.

Next up, I will be tackling the remaining likelihood and classifier builders. Thanks for following along on this journey, and see you in the next update!
