NUTS for Galaxies: Extending Leading-Edge Statistical Tools for Use in Astronomy and Cosmology

Published: Nov. 26, 2025

2025A

Scientific Background

Modern Bayesian inference in astronomy increasingly relies on gradient-based MCMC methods such as the No-U-Turn Sampler (NUTS), which enable efficient and scalable exploration of complex, high-dimensional parameter spaces. These approaches underpin next-generation probabilistic programming frameworks such as PyMC and NumPyro, which use automatic differentiation engines (PyTensor and JAX, respectively) to compute exact gradients of user-defined likelihoods.
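The following plain-Python sketch shows why these samplers need gradients. It is not PyMC or NumPyro code. It implements the leapfrog integrator that Hamiltonian Monte Carlo, and NUTS built on top of it, uses to propose moves through parameter space. The hand-written gradient of a toy standard-normal target stands in for what PyTensor or JAX would compute automatically. NUTS additionally chooses the trajectory length adaptively, which is omitted here.

```python
import math


def leapfrog(q, p, grad_log_prob, step_size, n_steps):
    """Advance position q and momentum p along Hamiltonian dynamics (unit mass)."""
    p = p + 0.5 * step_size * grad_log_prob(q)      # half step in momentum
    for i in range(n_steps):
        q = q + step_size * p                       # full step in position
        if i < n_steps - 1:
            p = p + step_size * grad_log_prob(q)    # full step in momentum
    p = p + 0.5 * step_size * grad_log_prob(q)      # closing half step
    return q, p


def hamiltonian(q, p, log_prob):
    """Potential energy -log p(q) plus kinetic energy p^2 / 2."""
    return -log_prob(q) + 0.5 * p * p


# Toy target: a standard normal, whose log density and gradient are known exactly.
def log_prob(q):
    return -0.5 * q * q


def grad_log_prob(q):
    return -q


q0, p0 = 1.0, 0.5
q1, p1 = leapfrog(q0, p0, grad_log_prob, step_size=0.1, n_steps=20)
energy_error = hamiltonian(q1, p1, log_prob) - hamiltonian(q0, p0, log_prob)
exact_q = q0 * math.cos(2.0) + p0 * math.sin(2.0)  # exact trajectory at t = 20 * 0.1
print(f"leapfrog q = {q1:.4f}, exact q = {exact_q:.4f}, energy error = {energy_error:.1e}")
```

The small energy error is what keeps HMC acceptance rates high. Computing it requires the gradient of the log density at every step, which is why a non-differentiable likelihood component blocks the whole method.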

For extragalactic astronomy, one of the most fundamental likelihood components is the Schechter luminosity function, which models the distribution of galaxies as a function of luminosity or stellar mass. Its simplicity and physical interpretability have made it the standard model for describing galaxy demographics across cosmic time. However, normalising the Schechter function requires evaluating the upper incomplete gamma function at a shape parameter set by the faint-end slope α, and commonly used implementations accept only positive shape parameters. They therefore break down for the steep, negative faint-end slopes that matter most for galaxies (−2 ≲ α ≲ 0), even though the integral itself remains finite for any positive lower luminosity limit. While numerical libraries such as SciPy and SymPy can evaluate this integral, their routines cannot be traced by the automatic differentiation engines in JAX or PyTensor, and therefore cannot be used within NUTS or other Hamiltonian Monte Carlo (HMC) frameworks.
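The normalisation argument can be written out explicitly. Assume the standard form of the Schechter function in luminosity, with x = L/L*. The number density above a limit L_min then reduces to an upper incomplete gamma function whose shape parameter is α + 1:

```latex
\Gamma(s, x) = \int_{x}^{\infty} t^{s-1} e^{-t}\,\mathrm{d}t

\phi(L)\,\mathrm{d}L = \frac{\phi^*}{L^*}\left(\frac{L}{L^*}\right)^{\alpha} e^{-L/L^*}\,\mathrm{d}L

n(>L_{\min}) = \int_{L_{\min}}^{\infty} \phi(L)\,\mathrm{d}L
  = \phi^* \int_{x_{\min}}^{\infty} x^{\alpha} e^{-x}\,\mathrm{d}x
  = \phi^*\,\Gamma(\alpha + 1,\, x_{\min}),
  \qquad x_{\min} = L_{\min}/L^*

p(L \mid L \ge L_{\min}) = \frac{\phi(L)}{n(>L_{\min})}
  = \frac{(L/L^*)^{\alpha}\, e^{-L/L^*}}{L^*\,\Gamma(\alpha + 1,\, x_{\min})}
```

For α ≤ −1 the shape s = α + 1 is non-positive. In that range the complete integral from zero diverges, and routines defined only for s > 0 (for example the regularised gammaincc functions in SciPy and jax.scipy.special) no longer apply. Γ(s, x_min) itself nevertheless stays finite for any x_min > 0.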

This limitation has quietly become a bottleneck for astronomy: without a differentiable, JAX- or PyTensor-compatible incomplete gamma function, astronomers cannot define a fully normalised Schechter likelihood for use in hierarchical Bayesian models. As a result, most analyses remain confined to low-dimensional parameter spaces or rely on less efficient ensemble samplers such as emcee. Enabling a differentiable Schechter distribution within NUTS-based frameworks is therefore a small but crucial step toward scaling hierarchical galaxy modelling to cosmological sample sizes.
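A normalised Schechter likelihood divides each galaxy's density by the integral above the survey limit, Γ(α + 1, L_min/L*). The sketch below makes that concrete in plain Python rather than JAX or PyTensor. It approximates the normalisation by numerical quadrature purely so the block is self-contained. The function names upper_gamma_quad and schechter_log_likelihood are hypothetical, as are the quadrature settings (upper cut-off x_max = 60, 4000 Simpson intervals). A gradient-based sampler would instead call a differentiable incomplete gamma routine.

```python
import math


def upper_gamma_quad(s, x_min, x_max=60.0, n=4000):
    """Hypothetical helper: approximate Gamma(s, x_min) = int_{x_min}^inf t^(s-1) e^-t dt.

    Substituting u = ln t turns the integrand into t^s e^-t, which stays smooth
    for negative s. Composite Simpson's rule is applied on [ln x_min, ln x_max];
    the tail beyond x_max (assumed > x_min) is neglected.
    """
    if n % 2:
        n += 1  # Simpson's rule needs an even number of intervals
    u0, u1 = math.log(x_min), math.log(x_max)
    h = (u1 - u0) / n
    total = 0.0
    for k in range(n + 1):
        t = math.exp(u0 + k * h)
        weight = 1.0 if k in (0, n) else (4.0 if k % 2 else 2.0)
        total += weight * t**s * math.exp(-t)
    return total * h / 3.0


def schechter_log_likelihood(lums, alpha, l_star, l_min):
    """Hypothetical helper: log-likelihood of luminosities L >= l_min.

    Uses p(L) = (L/L*)^alpha exp(-L/L*) / (L* Gamma(alpha + 1, l_min / L*)).
    """
    log_norm = math.log(l_star * upper_gamma_quad(alpha + 1.0, l_min / l_star))
    return sum(alpha * math.log(lum / l_star) - lum / l_star - log_norm for lum in lums)


# Steep faint end (alpha = -1.5): the shape alpha + 1 = -0.5 is negative, yet the
# normalisation is finite because the integral starts at l_min > 0.
print(schechter_log_likelihood([0.2, 0.5, 1.0, 3.0], alpha=-1.5, l_star=1.0, l_min=0.1))
```

Keeping the −log L* term matters when L* is a free parameter, because it changes with L* and would otherwise bias the fit.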

An overview of the datasets that motivate this work is shown in Figure 1, which illustrates the footprints of the 6dF, 2MRS, and GAMA surveys on the sky, underlaid with an SDSS image of the Milky Way and the nearby Universe.


Figure 1: Survey footprints for the 6dF, 2MRS, and GAMA datasets underlaid on an SDSS composite image of the Milky Way and local Universe. The combined sky coverage highlights the complementary depth and area of each survey used in this work.

Project Achievements

Through this six-week ADACS collaboration, we successfully developed a JAX-compatible, jit-compiled incomplete gamma function capable of handling negative values of α while preserving differentiability under automatic differentiation. The work began by analysing the mathematical conditions under which the recurrence relations for the incomplete gamma function converge for α < 0, and by reformulating these relations into a numerically stable recursive form suitable for JAX compilation.
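The downward-recurrence idea can be sketched in plain Python. This is an illustration, not the project's JAX implementation, and all names are hypothetical.

For positive shapes it uses the classic pairing of a power series (x < s + 1) with a continued fraction evaluated by Lentz's method (x ≥ s + 1). It then applies Γ(s, x) = (Γ(s + 1, x) − x^s e^(−x)) / s repeatedly to reach negative s.

Non-positive integer shapes (s = 0, −1, ...) would need the exponential-integral special case and are rejected here. A jit-compatible JAX version would also need to replace the Python branches and loops with constructs such as jnp.where and jax.lax.while_loop.

```python
import math

EPS = 1e-15     # relative convergence tolerance
TINY = 1e-300   # guards against division by zero in Lentz's method


def _upper_gamma_positive(s, x, max_iter=10_000):
    """Gamma(s, x) for s > 0 and x > 0 (series or continued fraction)."""
    if x < s + 1.0:
        # Series for the lower incomplete gamma:
        # gamma(s, x) = x^s e^-x * sum_k x^k / (s (s+1) ... (s+k))
        term = total = 1.0 / s
        a = s
        for _ in range(max_iter):
            a += 1.0
            term *= x / a
            total += term
            if abs(term) < abs(total) * EPS:
                break
        lower = total * math.exp(s * math.log(x) - x)
        return math.gamma(s) - lower
    # Legendre continued fraction for Gamma(s, x), evaluated with modified Lentz.
    b = x + 1.0 - s
    c = 1.0 / TINY
    d = 1.0 / b
    h = d
    for i in range(1, max_iter):
        an = -i * (i - s)
        b += 2.0
        d = an * d + b
        if abs(d) < TINY:
            d = TINY
        c = b + an / c
        if abs(c) < TINY:
            c = TINY
        d = 1.0 / d
        delta = d * c
        h *= delta
        if abs(delta - 1.0) < EPS:
            break
    return math.exp(s * math.log(x) - x) * h


def upper_gamma(s, x):
    """Gamma(s, x) for x > 0 and either s > 0 or non-integer s < 0.

    Negative s: evaluate Gamma(s + n, x) with s + n > 0, then step down with
    Gamma(s, x) = (Gamma(s + 1, x) - x^s e^-x) / s.
    """
    if x <= 0.0:
        raise ValueError("x must be positive")
    if s > 0.0:
        return _upper_gamma_positive(s, x)
    if s == math.floor(s):
        raise ValueError("s = 0, -1, -2, ... needs the exponential-integral special case")
    n = math.ceil(-s)
    value = _upper_gamma_positive(s + n, x)
    for k in range(n - 1, -1, -1):
        shape = s + k
        value = (value - math.exp(shape * math.log(x) - x)) / shape
    return value


print(upper_gamma(-0.5, 0.01))  # e.g. Gamma(alpha + 1, x_min) with alpha = -1.5, x_min = 0.01
```

When x is large, each downward step subtracts two nearly equal terms, so relative precision degrades by roughly a factor x/|s|. That regime is worth covering explicitly in validation tests.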

The resulting implementation, written entirely in JAX, was validated against SciPy’s reference routine across a wide parameter domain (−3 < α < +3, 10^−5 < x < 10^3). The implementation remained fully differentiable and compatible with jax.jit. Gradient checks confirmed that the function propagated stable derivatives, a critical requirement for Hamiltonian Monte Carlo inference.
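Because Γ(s, x) = ∫_x^∞ t^(s−1) e^(−t) dt, its derivative with respect to x is exactly −x^(s−1) e^(−x) for any real s. That gives a simple reference for gradient checks.

The plain-Python sketch below compares a central finite difference against that expression for s = −1/2. It uses hypothetical helper names and is not the project's test suite. For this shape the standard library's error function gives a closed form, Γ(−1/2, x) = 2(x^(−1/2) e^(−x) − √π erfc(√x)).

In a JAX test the reference would typically be jax.grad rather than a finite difference. Derivatives with respect to the shape parameter lack a comparably simple closed form, so they could only be checked against finite differences.

```python
import math


def upper_gamma_minus_half(x):
    """Closed form Gamma(-1/2, x) = 2 (x^(-1/2) e^-x - sqrt(pi) erfc(sqrt(x)))."""
    return 2.0 * (math.exp(-x) / math.sqrt(x) - math.sqrt(math.pi) * math.erfc(math.sqrt(x)))


def d_upper_gamma_dx(s, x):
    """Exact derivative d Gamma(s, x) / dx = -x^(s-1) e^-x, valid for any real s."""
    return -math.exp((s - 1.0) * math.log(x) - x)


def central_difference(f, x, rel_step=1e-5):
    """Second-order finite-difference estimate of f'(x) with a step relative to x."""
    h = rel_step * x
    return (f(x + h) - f(x - h)) / (2.0 * h)


rel_errors = []
for x in (1e-5, 1e-3, 0.1, 1.0, 10.0):
    numeric = central_difference(upper_gamma_minus_half, x)
    exact = d_upper_gamma_dx(-0.5, x)
    rel_errors.append(abs(numeric - exact) / abs(exact))
    print(f"x = {x:g}: finite difference {numeric:.6e}, exact {exact:.6e}")
```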

Once validated, the new incomplete gamma function was integrated into a NumPyro-compatible Schechter distribution framework that includes both single- and double-Schechter forms. The corresponding Python classes provide native log-probability evaluation and automatic normalisation for arbitrary α, ensuring complete differentiability throughout the likelihood. This implementation enables the Schechter function, and by extension luminosity and mass function models, to be used directly as priors or likelihoods within NumPyro and PyMC. The double-Schechter variant, in particular, reproduces the two-component luminosity distributions observed in real galaxy populations, making it directly applicable to large survey analyses. Benchmarks on synthetic datasets demonstrated a 10–15× speed improvement over traditional ensemble samplers such as emcee, while maintaining consistent posterior convergence.
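Suppose the double-Schechter form follows the common parametrisation in which both components share one characteristic luminosity L*. That is an assumption; the parametrisation used in the package is not stated here. Its normalisation is then simply a sum of two upper incomplete gamma terms, so a single differentiable routine covers both forms:

```latex
\phi(L)\,\mathrm{d}L = \left[\phi_1^*\left(\frac{L}{L^*}\right)^{\alpha_1}
  + \phi_2^*\left(\frac{L}{L^*}\right)^{\alpha_2}\right] e^{-L/L^*}\,\frac{\mathrm{d}L}{L^*}

n(>L_{\min}) = \phi_1^*\,\Gamma(\alpha_1 + 1,\, x_{\min}) + \phi_2^*\,\Gamma(\alpha_2 + 1,\, x_{\min}),
  \qquad x_{\min} = L_{\min}/L^*
```

Each galaxy's likelihood is φ(L) divided by this sum. The steeper of the two slopes is the one most likely to push its shape parameter below zero.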

The code was modularised and fully documented for community use. Figure 2 shows a practical example of this framework in use: the luminosity density distribution reconstructed across the full sky from our hierarchical Bayesian model. Each cell represents a separate fit to the galaxy luminosity function, with all cells linked through the hierarchical model, allowing spatial variations and cosmic variance to be quantified. The per-cell fits clearly track the underlying large-scale structure: cells containing structures (a high density of galaxies) show higher luminosity densities and appear in yellower colours. Figure 3 illustrates one such cell (from the GAMA region), showing the fitted double-Schechter function and the underlying galaxy counts. The blue line in the plot shows the fit after accounting for selection effects, and it agrees closely with the observed data.
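The toy sketch below shows what a hierarchical link between sky cells typically looks like. It assumes, hypothetically, that each cell's faint-end slope is drawn from a Normal population distribution with hyperparameters mu_alpha and tau_alpha. The model's actual hierarchy is not described here. A Gaussian per-cell estimate (alpha_hat with uncertainty sigma_cell) stands in for the real Schechter likelihood, to show how sparse cells borrow strength from the population while well-populated cells follow their own data.

```python
import math


def normal_logpdf(x, mean, sd):
    """Log density of a Normal(mean, sd) distribution."""
    return -0.5 * ((x - mean) / sd) ** 2 - math.log(sd) - 0.5 * math.log(2.0 * math.pi)


def population_log_prior(cell_alphas, mu_alpha, tau_alpha):
    """Hypothetical hierarchical term tying every cell's slope to shared hyperparameters."""
    return sum(normal_logpdf(a, mu_alpha, tau_alpha) for a in cell_alphas)


def pooled_estimate(alpha_hat, sigma_cell, mu_alpha, tau_alpha):
    """Posterior mean of one cell's slope in the toy Gaussian model.

    The cell's own estimate and the population mean are averaged,
    each weighted by its inverse variance (partial pooling).
    """
    w_cell = 1.0 / sigma_cell**2
    w_pop = 1.0 / tau_alpha**2
    return (w_cell * alpha_hat + w_pop * mu_alpha) / (w_cell + w_pop)


# A sparse cell (large uncertainty) is pulled strongly towards the population mean,
# while a well-populated cell (small uncertainty) stays close to its own estimate.
print(pooled_estimate(-1.6, sigma_cell=0.4, mu_alpha=-1.2, tau_alpha=0.1))
print(pooled_estimate(-1.6, sigma_cell=0.02, mu_alpha=-1.2, tau_alpha=0.1))
```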


Figure 2: Reconstructed luminosity density distribution across the sky from the hierarchical Bayesian model. Each pixel represents an independent sky cell, colour-coded by its mean luminosity density J, enabling direct visualisation of large-scale structure and cosmic variance.


Figure 3: Fit of the double-Schechter function for the GAMA cell. The red line shows the intrinsic luminosity function predicted by our model. The blue line accounts for selection effects and shows the luminosity function we expect to observe, given the survey selection; it agrees closely with the observed GAMA galaxy data.
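The two curves in Figure 3 differ only by the survey selection. One simple way to express this is to multiply the intrinsic luminosity function by a detection probability. The sketch below does so under hypothetical assumptions: a single luminosity limit l_lim and an error-function completeness with a made-up width of 0.1 dex. In a flux-limited survey the limit actually varies with distance, and the project's selection model is not described here.

```python
import math


def schechter(lum, phi_star, l_star, alpha):
    """Intrinsic Schechter function phi(L) = (phi*/L*) (L/L*)^alpha exp(-L/L*)."""
    x = lum / l_star
    return phi_star / l_star * x**alpha * math.exp(-x)


def completeness(lum, l_lim, width_dex=0.1):
    """Hypothetical detection probability: a smooth step in log10 L centred on l_lim.

    Equals 0.5 at l_lim and tends to 1 for bright and 0 for faint galaxies;
    width_dex (an assumed value) sets how sharp the step is.
    """
    z = (math.log10(lum) - math.log10(l_lim)) / (math.sqrt(2.0) * width_dex)
    return 0.5 * math.erfc(-z)


def observed_lf(lum, phi_star, l_star, alpha, l_lim):
    """Luminosity function expected in the catalogue: intrinsic times completeness."""
    return schechter(lum, phi_star, l_star, alpha) * completeness(lum, l_lim)


for lum in (0.003, 0.01, 0.03, 0.1, 1.0):
    ratio = observed_lf(lum, 1.0, 1.0, -1.2, l_lim=0.01) / schechter(lum, 1.0, 1.0, -1.2)
    print(f"L/L* = {lum:g}: observed / intrinsic = {ratio:.3f}")
```

Fitting the observed curve to the data while reporting the intrinsic one is what lets the red and blue lines be drawn from the same set of parameters.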

Impact and Next Steps

This development removes a long-standing barrier in Bayesian galaxy modelling. For the first time, a fully differentiable, NUTS-ready Schechter distribution is available for use in large-scale hierarchical inference. This capability directly benefits a wide range of extragalactic applications, from measuring spatial variations in luminosity function parameters across the sky (as in 2MRS, 6dFGS, and GAMA) to modelling environmental dependencies, cosmic variance, and galaxy evolution in upcoming surveys such as DESI, 4MOST, and LSST.

The broader impact extends beyond astronomy. A differentiable upper incomplete gamma function is a reusable numerical primitive with applications in probabilistic machine learning, cosmology, and any field requiring differentiable special functions. By bridging the gap between astronomical methodology and modern machine-learning infrastructure, this ADACS project exemplifies how targeted software development can unlock entire classes of scientific models.

The codebase has been released as an open-source Python package, numpyro-schechter, available on GitHub and registered with PyPI for easy installation. The package provides both the differentiable incomplete gamma function and ready-to-use Schechter distribution classes, complete with unit tests and worked examples. Future work will focus on integrating this implementation into NumPyro itself rather than maintaining it as a standalone Python library.

In summary, this ADACS project delivered a compact but enabling piece of computational infrastructure that allows astronomers to move from sampling tens of parameters to hundreds or thousands, thereby scaling Bayesian hierarchical analysis to the era of precision cosmology.

Written by Aryan Bansal.

Project Details

Node: Swinburne University of Technology
Project Length: 6 weeks
Development Team:
  • Alice Serene
Research Science Team:
  • Aryan Bansal (project lead)
  • Edward Taylor
