Using the picasso trained predictors

This notebook shows how to use the trained models to predict gas thermodynamics from halo properties. For full documentation of the predictor objects and their methods, see picasso.predictors: From halo properties to gas properties.

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from picasso import predictors
from picasso.test_data import halos, profs

# Plotting style: notebook context on a dark-grid background
sns.set_theme(context="notebook", style="darkgrid")

benchmark = True  # set to False to skip the timing cells below

We will use the minimal_576 trained model, which takes halo mass and concentration as inputs:

predictor = predictors.minimal_576
print(predictor.input_names)
['log M200' 'c200']

Predicting gas model parameters

First, we want to compute predictions for the model parameter vector, \(\vartheta_{\rm gas}\). To do so, we only need the vector of scalar halo properties, \(\vartheta_{\rm halo}\). We’ll use some pre-stored data (containing seven halos from the simulations presented in Kéruzoré+24) and write the input vector:

logM200c = jnp.array(halos["log M200"])
c200c = jnp.array(halos["c200"])
# One row per halo, columns in the order of predictor.input_names
theta_halo = jnp.array([logM200c, c200c]).T
print(theta_halo.shape)
(7, 2)
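If your own halo catalogue is stored as a pandas DataFrame, one way to avoid mixing up the inputs is to select its columns by name, in the order printed by predictor.input_names. Below is a minimal sketch using a hypothetical table with illustrative values (not the picasso test data). Its columns are deliberately stored in the opposite order:

```python
import numpy as np
import pandas as pd

# Hypothetical halo table (illustrative values, not the picasso test data);
# columns are stored in a different order from the model inputs on purpose
halos = pd.DataFrame({"c200": [4.1, 5.3, 6.0], "log M200": [14.5, 14.2, 13.9]})

# Input order expected by the predictor, as printed by predictor.input_names
input_names = ["log M200", "c200"]

# Selecting columns by name reorders them; shape is (n_halos, n_inputs)
theta_halo = halos[input_names].to_numpy(dtype=np.float32)
print(theta_halo.shape)
```

The resulting array can then be converted with jnp.array before being passed to the predictor.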

We can then use the predictor.predict_model_parameters() function to predict \(\vartheta_{\rm gas}\). For a single halo:

theta_gas_0 = predictor.predict_model_parameters(theta_halo[0])
print(theta_gas_0)
[ 2.8579862   1.5800635   1.1345413   0.          0.76251674 -1.3342006
 -0.5730133   0.8244546 ]

Note that some of the parameters were trained, and are therefore predicted, in log space. The normalization of each prediction is that of Table 1 in the paper:

| Symbol | Meaning | Range |
|---|---|---|
| \(\log_{10} \rho_0\) | (log-scaled) Central normalized gas density | \((1.5, \, 5)\) |
| \(\log_{10} P_0\) | (log-scaled) Central normalized gas total pressure | \((0, \, 4.5)\) |
| \(\Gamma_0\) | Gas polytropic index limit as \(r \rightarrow \infty\) | \((1, \, 1.4)\) |
| \(c_\gamma\) | Gas polytropic index shape parameter | \([0]\)* |
| \(\theta_0 / (10^{-6} \, {\rm km^2 s^{-2}})\) | Polytropic normalization | \((0, \, 2)\) |
| \(\log_{10} A_{\rm nt}\) | (log-scaled) Central plateau of non-thermal pressure fraction | \((-4, \, 0)\) |
| \(\log_{10} B_{\rm nt}\) | (log-scaled) Non-thermal pressure fraction at \(r=2R_{500c}\) | \((-1.5, \, 0)\) |
| \(C_{\rm nt}\) | Non-thermal pressure fraction profile power law index | \((0, \, 4)\) |

* \((-1,\, 1)\) for the NR+\(\Gamma(r)\) and SG+\(\Gamma(r)\) models.
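To read the predictions in linear units, the log scaling in the table can be undone by hand. Below is a minimal sketch. It assumes the output vector follows the row order of Table 1, which is the order used for the column labels later in this notebook. It also assumes \(\theta_0\) is given in units of \(10^{-6}\,{\rm km^2\,s^{-2}}\). picasso's own model functions may already handle this conversion internally; the sketch only illustrates the arithmetic:

```python
import numpy as np

# Prediction for the first halo, copied from the predict_model_parameters output above
theta_gas_0 = np.array(
    [2.8579862, 1.5800635, 1.1345413, 0.0, 0.76251674, -1.3342006, -0.5730133, 0.8244546]
)
names = ["rho_0", "P_0", "Gamma_0", "c_gamma", "theta_0", "A_nt", "B_nt", "C_nt"]

linear = theta_gas_0.copy()
log_scaled = [0, 1, 5, 6]  # entries tabulated as log10 values in Table 1
linear[log_scaled] = 10.0 ** theta_gas_0[log_scaled]
linear[4] *= 1e-6  # theta_0 is tabulated in units of 1e-6 km^2 s^-2

params = dict(zip(names, linear))
print(params)
```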

The predictor.predict_model_parameters() function can also be used for several halos at a time:

theta_gas = predictor.predict_model_parameters(theta_halo)
print(f"{theta_gas=}")
print(f"{theta_gas.shape=}")
theta_gas=Array([[ 2.8579862 ,  1.5800636 ,  1.1345413 ,  0.        ,  0.7625168 ,
        -1.3342004 , -0.5730133 ,  0.8244548 ],
       [ 3.0891585 ,  1.8412256 ,  1.1362435 ,  0.        ,  0.49745762,
        -1.423728  , -0.56955934,  1.0702431 ],
       [ 3.0761304 ,  1.8216532 ,  1.1312138 ,  0.        ,  0.3338053 ,
        -1.3907712 , -0.5089177 ,  1.0523993 ],
       [ 2.8483396 ,  1.5602771 ,  1.1200436 ,  0.        ,  0.24938817,
        -1.3459554 , -0.46140385,  0.853472  ],
       [ 2.9394426 ,  1.6772497 ,  1.1291878 ,  0.        ,  0.16645454,
        -1.3711526 , -0.45240808,  0.94847983],
       [ 2.9498107 ,  1.6725324 ,  1.1352186 ,  0.        ,  0.1211944 ,
        -1.3999796 , -0.4420457 ,  1.0124265 ],
       [ 3.2267897 ,  2.0133085 ,  1.1500207 ,  0.        ,  0.07647304,
        -1.5900545 , -0.49178195,  1.3881102 ]], dtype=float32)
theta_gas.shape=(7, 8)
df = {"$\\log M_{200c}$": logM200c, "$c_{200c}$": c200c}
for i, name in enumerate(
    [
        "$\\log \\rho_0$",
        "$\\log P_0$",
        "$\\Gamma_0$",
        "$c_\\gamma$",
        "$\\theta_0 \\times 10^6$",
        "$\\log A_{\\rm nt}$",
        "$\\log B_{\\rm nt}$",
        "$C_{\\rm nt}$",
    ]
):
    df[name] = theta_gas[:, i]
df = pd.DataFrame(df)

pg = sns.PairGrid(df, corner=True, diag_sharey=False)
pg.map_lower(sns.scatterplot)
pg.map_diag(sns.histplot)
[Figure: corner plot of the halo inputs and the eight predicted gas model parameters]
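Because theta_gas is a 2D array whose columns follow the row order of Table 1, the column-by-column loop above can also be written by handing whole arrays to pandas at once. Below is a sketch using the first two predicted rows, copied from the output above. The halo inputs are placeholders with hypothetical values, not those of the test halos. np.asarray also converts a JAX array:

```python
import numpy as np
import pandas as pd

halo_names = ["$\\log M_{200c}$", "$c_{200c}$"]
param_names = [
    "$\\log \\rho_0$", "$\\log P_0$", "$\\Gamma_0$", "$c_\\gamma$",
    "$\\theta_0 \\times 10^6$", "$\\log A_{\\rm nt}$", "$\\log B_{\\rm nt}$", "$C_{\\rm nt}$",
]

# Placeholder halo inputs (hypothetical values, not the test halos)
theta_halo = np.array([[14.5, 4.1], [14.2, 5.3]])
# First two rows of theta_gas, copied from the output above
theta_gas = np.array([
    [2.8579862, 1.5800636, 1.1345413, 0.0, 0.7625168, -1.3342004, -0.5730133, 0.8244548],
    [3.0891585, 1.8412256, 1.1362435, 0.0, 0.49745762, -1.423728, -0.56955934, 1.0702431],
])

# Build both blocks of columns at once and join them side by side
df = pd.concat(
    [
        pd.DataFrame(np.asarray(theta_halo), columns=halo_names),
        pd.DataFrame(np.asarray(theta_gas), columns=param_names),
    ],
    axis=1,
)
print(df)
```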

The predictor.predict_model_parameters() function can also be just-in-time compiled with jax.jit:

if benchmark:
    predict_jit = jax.jit(predictor.predict_model_parameters)
    print("Not jitted:", end=" ")
    %timeit _ = predictor.predict_model_parameters(theta_halo)
    print("jitted:", end=" ")
    # The first call traces and compiles the function, so run it once before timing
    _ = predict_jit(theta_halo)
    %timeit _ = predict_jit(theta_halo)
Not jitted: 5.76 ms ± 80.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
jitted: 13.9 μs ± 423 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
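JAX dispatches computations asynchronously. The %timeit calls above do not explicitly wait for the results (for example with block_until_ready()), so the jitted timings may partly reflect dispatch overhead rather than the full computation. Below is a minimal sketch of a blocking benchmark. It uses a hypothetical toy_predict function, a small fixed-weight network standing in for the trained predictor (not picasso's model):

```python
import timeit

import jax
import jax.numpy as jnp


def toy_predict(theta_halo):
    """Hypothetical stand-in for a trained predictor: (n_halos, 2) -> (n_halos, 8)."""
    w1 = jnp.linspace(-1.0, 1.0, 2 * 16).reshape(2, 16)
    w2 = jnp.linspace(-0.5, 0.5, 16 * 8).reshape(16, 8)
    return jnp.tanh(theta_halo @ w1) @ w2


theta_halo = jnp.array([[14.5, 4.1], [14.2, 5.3], [13.9, 6.0]])  # illustrative inputs
toy_predict_jit = jax.jit(toy_predict)

# The first call traces and compiles the function; keep it out of the timing
toy_predict_jit(theta_halo).block_until_ready()

# block_until_ready() makes the timer wait for the result, not just the dispatch
n = 200
t_jit = timeit.timeit(lambda: toy_predict_jit(theta_halo).block_until_ready(), number=n) / n
t_eager = timeit.timeit(lambda: toy_predict(theta_halo).block_until_ready(), number=n) / n
print(f"eager: {t_eager * 1e6:.1f} us per call, jitted: {t_jit * 1e6:.1f} us per call")
```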

Predicting gas thermodynamics

With a prediction for \(\vartheta_{\rm gas}\), we can use picasso.polytrop and picasso.nonthermal to predict gas thermodynamics (see Using the picasso analytical gas model). PicassoPredictor objects also offer a wrapper function that predicts all thermodynamic properties directly from an input vector \(\vartheta_{\rm halo}\) and a gravitational potential profile. Assuming the halos above follow NFW profiles, we load their precomputed potential profiles, along with the corresponding radii in units of \(R_{500c}\), from the test data:

r_R500c, phi = profs["r_R500"], profs["phi"]
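For reference, the NFW potential has a closed form, \(\phi(r) = -\frac{G M_{200c}}{R_{200c}} \frac{\ln(1 + c\,x)}{x\,m(c)}\), where \(x = r / R_{200c}\), \(c = c_{200c}\), and \(m(c) = \ln(1 + c) - c / (1 + c)\). Below is a minimal numpy sketch of this formula. It assumes radii in units of \(R_{200c}\) and a potential in units of \(V_{200c}^2 = G M_{200c} / R_{200c}\). The normalization of the stored phi, and its grid in \(r / R_{500c}\), are not specified here, so this is not a drop-in replacement for profs["phi"]:

```python
import numpy as np


def nfw_potential(x, c):
    """NFW potential in units of G M200c / R200c.

    x: radius in units of R200c; c: concentration c200c.
    """
    m_c = np.log1p(c) - c / (1.0 + c)  # dimensionless enclosed-mass factor m(c)
    return -np.log1p(c * x) / (x * m_c)


x = np.logspace(-2, 0.5, 64)  # r / R200c
phi = nfw_potential(x, c=5.0)
print(phi[:3])
```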

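Then, we can predict gas thermodynamics for one halo. The third argument is the radial grid in units of \(R_{500c}\). The fourth, r_R500c[0] / 2, expresses the same radii in units of \(2R_{500c}\), which is consistent with \(B_{\rm nt}\) being defined at \(r = 2R_{500c}\) in the table above: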

rho_g, P_tot, P_th, f_nt = predictor.predict_gas_model(
    theta_halo[0], phi[0], r_R500c[0], r_R500c[0] / 2
)
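Table 1 describes the non-thermal pressure fraction through three quantities: a central plateau \(A_{\rm nt}\), a value \(B_{\rm nt}\) at \(r = 2R_{500c}\), and a power-law index \(C_{\rm nt}\). One functional form with exactly these properties is \(f_{\rm nt}(x) = A_{\rm nt} + (B_{\rm nt} - A_{\rm nt})\,x^{C_{\rm nt}}\), with \(x = r / 2R_{500c}\). The sketch below assumes this form purely for illustration; check picasso.nonthermal for the exact expression the package uses:

```python
import numpy as np


def f_nt(x, log_A, log_B, C):
    """Assumed non-thermal pressure fraction profile (illustrative form).

    x is the radius in units of 2 R500c, so f_nt(0) = A_nt and f_nt(1) = B_nt.
    """
    A, B = 10.0**log_A, 10.0**log_B
    return A + (B - A) * x**C


# Log-scaled parameters for the first halo, copied from the predictions above
log_A, log_B, C = -1.3342006, -0.5730133, 0.8244546
r_R500c = np.logspace(-2, 1, 50)
fnt = f_nt(r_R500c / 2, log_A, log_B, C)  # r / (2 R500c), as in the call above
print(fnt[[0, -1]])
```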

Or for all halos at the same time (this function uses jax.vmap to vectorize the predictions):

rho_g, P_tot, P_th, f_nt = predictor.predict_gas_model(
    theta_halo, phi, r_R500c, r_R500c / 2
)
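The vectorization mentioned above follows the usual jax.vmap pattern: write the computation for a single halo, then map it over a leading halo axis shared by all inputs. Below is a sketch with a hypothetical toy_profile function. It uses an NFW-like shape chosen only for illustration and is not picasso's gas model:

```python
import jax
import jax.numpy as jnp


def toy_profile(theta_halo, r):
    """Hypothetical per-halo model: one halo's (log M, c) and its radii -> a profile."""
    log_m, c = theta_halo[0], theta_halo[1]
    return 10.0 ** (log_m - 14.0) / (c * r * (1.0 + c * r) ** 2)


# Map over the leading (halo) axis of both arguments
toy_profile_batch = jax.vmap(toy_profile, in_axes=(0, 0))

theta_halo = jnp.array([[14.5, 4.1], [14.2, 5.3], [13.9, 6.0]])  # illustrative inputs
r = jnp.tile(jnp.logspace(-2, 0.5, 32), (3, 1))  # one radial grid per halo
profiles = toy_profile_batch(theta_halo, r)  # shape (3, 32)
print(profiles.shape)
```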
fig, axs = plt.subplots(1, 4, figsize=(13, 4))
for ax, q in zip(axs, [rho_g, P_tot, P_th, f_nt]):
    ax.loglog(r_R500c.T, q.T)
    ax.set_xlabel("$r / R_{500c}$")

axs[0].set_ylabel("$\\rho_{\\rm g} / 500 \\rho_{\\rm crit.}$")
axs[1].set_ylabel("$P_{\\rm tot} / P_{500c}$")
axs[2].set_ylabel("$P_{\\rm th} / P_{500c}$")
axs[3].set_ylabel("$f_{\\rm nt}$")
fig.tight_layout()
[Figure: gas density, total pressure, thermal pressure, and non-thermal pressure fraction profiles as a function of \(r / R_{500c}\) for all halos]

Again, predictor.predict_gas_model() can be just-in-time compiled:

if benchmark:
    predict_jit = jax.jit(predictor.predict_gas_model)

    print("1 halo, not jitted:", end=" ")
    %timeit _ = predictor.predict_gas_model(theta_halo[0], phi[0], r_R500c[0], r_R500c[0] / 2)
    print("1 halo, jitted:", end=" ")
    _ = predict_jit(theta_halo[0], phi[0], r_R500c[0], r_R500c[0] / 2)
    %timeit _ = predict_jit(theta_halo[0], phi[0], r_R500c[0], r_R500c[0] / 2)

    print("7 halos, not jitted:", end=" ")
    %timeit _ = predictor.predict_gas_model(theta_halo, phi, r_R500c, r_R500c / 2)
    print("7 halos, jitted:", end=" ")
    _ = predict_jit(theta_halo, phi, r_R500c, r_R500c / 2)
    %timeit _ = predict_jit(theta_halo, phi, r_R500c, r_R500c / 2)
1 halo, not jitted: 7.55 ms ± 185 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1 halo, jitted: 188 μs ± 4.17 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
7 halos, not jitted: 13.3 ms ± 216 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7 halos, jitted: 47.3 μs ± 2.43 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

For an example of how these predictions can be included in a differentiable loss function and used to train a picasso model, see Training a picasso predictor.