Using the picasso trained predictors

This notebook shows how to use the trained models to predict gas thermodynamics from halo properties. For full documentation of the predictor objects and their methods, see picasso.predictors: From halo properties to gas properties.

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from picasso import predictors
from picasso.test_data import halos, profs

# Plot style: seaborn "notebook" context with the "darkgrid" style
sns.set_theme(context="notebook", style="darkgrid")

# Set to False to skip the timing cells below
benchmark = True

We will use the minimal_576 trained model, which takes the logarithm of the halo mass and the halo concentration as inputs:

predictor = predictors.minimal_576
print(predictor.input_names)
['log M200' 'c200']

Predicting gas model parameters

First, we want to compute predictions for the model parameter vector, \(\vartheta_{\rm gas}\). To do so, we only need the vector of scalar halo properties, \(\vartheta_{\rm halo}\). We'll use pre-stored test data (containing seven halos from the simulations presented in Kéruzoré+24) and build the input vector:

# Input vector: one row per halo, columns ordered as predictor.input_names
logM200c = jnp.array(halos["log M200"])
c200c = jnp.array(halos["c200"])
theta_halo = jnp.array([logM200c, c200c]).T
print(theta_halo.shape)
(7, 2)
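
Because the columns of the input vector must follow the order of predictor.input_names, selecting them by name avoids silent ordering mistakes. The sketch below uses a hypothetical helper, build_theta_halo (not part of picasso), and assumes the halo catalogue is a pandas DataFrame whose column names match input_names. The resulting NumPy array can then be passed to jnp.asarray.

```python
import numpy as np
import pandas as pd


def build_theta_halo(halos: pd.DataFrame, input_names) -> np.ndarray:
    """Hypothetical helper: stack the columns of `halos` into an
    (n_halos, n_inputs) float32 array, in the order given by `input_names`."""
    missing = [name for name in input_names if name not in halos.columns]
    if missing:
        raise KeyError(f"Missing halo properties: {missing}")
    return halos[list(input_names)].to_numpy(dtype=np.float32)


# Toy catalogue with made-up values; the columns are deliberately out of order.
toy_halos = pd.DataFrame({"c200": [4.0, 5.5, 3.2], "log M200": [14.1, 14.6, 13.9]})
theta = build_theta_halo(toy_halos, ["log M200", "c200"])
print(theta.shape)  # (3, 2)
```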

We can then use the predictor.predict_model_parameters() function to predict \(\vartheta_{\rm gas}\). For a single halo:

theta_gas_0 = predictor.predict_model_parameters(theta_halo[0])
print(theta_gas_0)
[ 2.9106064   1.6468848   1.1334685   0.          0.7579316  -1.3241944
 -0.5855      0.85590917]

The predictor.predict_model_parameters() function can also be used for several halos at a time:

theta_gas = predictor.predict_model_parameters(theta_halo)
print(f"{theta_gas=}")
print(f"{theta_gas.shape=}")
theta_gas=Array([[ 2.9106064 ,  1.6468848 ,  1.1334685 ,  0.        ,  0.7579316 ,
        -1.3241944 , -0.5855    ,  0.85590917],
       [ 3.0965114 ,  1.8520657 ,  1.1311967 ,  0.        ,  0.48492068,
        -1.4299242 , -0.5614897 ,  1.1043483 ],
       [ 3.0746489 ,  1.8180763 ,  1.1310333 ,  0.        ,  0.3356811 ,
        -1.37233   , -0.5092694 ,  1.0796801 ],
       [ 2.8056588 ,  1.5297322 ,  1.1362283 ,  0.        ,  0.27505645,
        -1.4245594 , -0.48284996,  0.7942678 ],
       [ 2.950333  ,  1.6717771 ,  1.1333015 ,  0.        ,  0.17360152,
        -1.400001  , -0.47301912,  0.89461106],
       [ 2.9969964 ,  1.7104851 ,  1.1341447 ,  0.        ,  0.12555973,
        -1.4094577 , -0.47190595,  0.9506246 ],
       [ 3.2324495 ,  2.008476  ,  1.1400145 ,  0.        ,  0.07529356,
        -1.5716207 , -0.50930977,  1.2982608 ]], dtype=float32)
theta_gas.shape=(7, 8)
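
A network built from matrix products naturally accepts both a single input vector of shape (n_inputs,) and a batch of shape (n_halos, n_inputs). This is because `x @ W` contracts the last axis and keeps any leading ones. The toy two-layer network below is a hypothetical stand-in with random weights, not the minimal_576 architecture. It only illustrates the shape convention seen above.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_hidden, n_out = 2, 16, 8

# Random weights for a hypothetical toy network (not picasso's trained model).
W1, b1 = rng.normal(size=(n_in, n_hidden)), np.zeros(n_hidden)
W2, b2 = rng.normal(size=(n_hidden, n_out)), np.zeros(n_out)


def toy_predict(x):
    """Map (..., n_in) inputs to (..., n_out) outputs; works for one halo or many."""
    h = np.tanh(x @ W1 + b1)  # @ contracts the last axis, leading axes are kept
    return h @ W2 + b2


x_batch = rng.normal(size=(7, n_in))
print(toy_predict(x_batch[0]).shape)  # (8,)
print(toy_predict(x_batch).shape)     # (7, 8)
```
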
df = {"$\\log M_{200c}$": logM200c, "$c_{200c}$": c200c}
for i, name in enumerate(
    [
        "$\\rho_0$",
        "$P_0$",
        "$\\Gamma_0$",
        "$c_\\gamma$",
        "$\\theta_0$",
        "$A_{\\rm nt}$",
        "$B_{\\rm nt}$",
        "$C_{\\rm nt}$",
    ]
):
    df[name] = theta_gas[:, i]
df = pd.DataFrame(df)

pg = sns.PairGrid(df, corner=True, diag_sharey=False)
pg.map_lower(sns.scatterplot)
pg.map_diag(sns.histplot)
[Figure: corner plot of the halo properties (\(\log M_{200c}\), \(c_{200c}\)) and the eight predicted gas model parameters, with pairwise scatter plots below the diagonal and histograms on the diagonal]
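
The same table can also be built without a loop. Pass the 2D parameter array and a list of column names to pd.DataFrame, then insert the halo-property columns at the front. The helper params_to_frame below is a hypothetical sketch, and the inputs are random numbers with the same shapes as in this notebook.

```python
import numpy as np
import pandas as pd

# Parameter names in the order used in the cell above.
PARAM_NAMES = [r"$\rho_0$", r"$P_0$", r"$\Gamma_0$", r"$c_\gamma$",
               r"$\theta_0$", r"$A_{\rm nt}$", r"$B_{\rm nt}$", r"$C_{\rm nt}$"]


def params_to_frame(log_m, c, theta_gas, names=PARAM_NAMES):
    """Hypothetical helper: one row per halo, halo properties first."""
    df = pd.DataFrame(np.asarray(theta_gas), columns=names)
    df.insert(0, r"$c_{200c}$", np.asarray(c))
    df.insert(0, r"$\log M_{200c}$", np.asarray(log_m))
    return df


# Made-up inputs with the same shapes as in this notebook (7 halos, 8 parameters).
rng = np.random.default_rng(1)
frame = params_to_frame(rng.normal(14, 0.3, 7), rng.uniform(3, 8, 7), rng.normal(size=(7, 8)))
print(frame.shape)  # (7, 10)
```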

The predictor.predict_model_parameters() function can also be just-in-time compiled with jax.jit:

if benchmark:
    # Compile the parameter prediction with jax.jit
    predict_jit = jax.jit(predictor.predict_model_parameters)
    print("Not jitted:", end=" ")
    %timeit _ = predictor.predict_model_parameters(theta_halo)
    print("jitted:", end=" ")
    # The first call triggers compilation; run it once before timing
    _ = predict_jit(theta_halo)
    %timeit _ = predict_jit(theta_halo)
Not jitted: 17 ms ± 212 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
jitted: 12.8 µs ± 998 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
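
Two details matter when reading such timings. First, the first call to a jitted function triggers tracing and compilation, which is why a warm-up call precedes %timeit above. Second, JAX dispatches work asynchronously, so a fair timer should wait for each result with block_until_ready(). Below is a minimal sketch of such a timing helper outside IPython. It uses a hypothetical toy function in place of the predictor.

```python
import time

import jax
import jax.numpy as jnp


def toy_predict(x):
    """Hypothetical stand-in for predictor.predict_model_parameters."""
    w = jnp.linspace(-1.0, 1.0, x.shape[-1] * 8).reshape(x.shape[-1], 8)
    return jnp.tanh(x @ w)


def mean_runtime(fn, *args, n_loops=100):
    """Average wall-clock time per call, excluding the first (compiling) call."""
    fn(*args).block_until_ready()  # warm-up: compiles a jitted fn for these shapes
    start = time.perf_counter()
    for _ in range(n_loops):
        fn(*args).block_until_ready()  # wait for the asynchronous computation
    return (time.perf_counter() - start) / n_loops


x = jnp.ones((7, 2))
toy_jit = jax.jit(toy_predict)
print(f"not jitted: {mean_runtime(toy_predict, x):.2e} s")
print(f"jitted:     {mean_runtime(toy_jit, x):.2e} s")
```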

Predicting gas thermodynamics

With a prediction for \(\vartheta_{\rm gas}\), we can use picasso.polytrop and picasso.nonthermal to predict gas thermodynamics (see Using the picasso analytical gas model). PicassoPredictor objects also offer a wrapper function that predicts all thermodynamic properties directly from an input vector \(\vartheta_{\rm halo}\) and a gravitational potential profile. Assuming the halos above are NFW, their potential profiles are stored in the test data:

# Radii in units of R500c and potential profiles, one row per halo
r_R500c, phi = profs["r_R500"], profs["phi"]
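
For reference, the gravitational potential of an NFW halo of mass \(M_{200c}\) and concentration \(c_{200c}\) has a closed form:

\(\Phi(r) = -\frac{G M_{200c}}{m(c_{200c})}\,\frac{\ln(1 + r/r_s)}{r}\), with \(r_s = R_{200c}/c_{200c}\) and \(m(c) = \ln(1+c) - c/(1+c)\).

The sketch below evaluates this textbook expression with NumPy in physical units. The normalization and units of profs["phi"] in the test data may differ, so this is not necessarily how those profiles were computed.

```python
import numpy as np

G_KPC_KMS2_MSUN = 4.30091e-6  # Newton's constant in kpc (km/s)^2 / Msun


def nfw_potential(r, m200, c200, r200, G=G_KPC_KMS2_MSUN):
    """NFW potential Phi(r) = -G M200 / m(c) * ln(1 + r/r_s) / r.

    r and r200 in kpc, m200 in Msun; returns Phi in (km/s)^2,
    with r_s = r200 / c200 and m(c) = ln(1 + c) - c / (1 + c).
    """
    r = np.asarray(r, dtype=float)
    r_s = r200 / c200
    m_c = np.log1p(c200) - c200 / (1.0 + c200)
    return -G * m200 / m_c * np.log1p(r / r_s) / r


# Made-up halo: M200 = 1e14 Msun, c200 = 5, R200 = 1000 kpc.
r = np.logspace(0, 3.5, 50)  # 1 kpc to ~3.2 Mpc
phi_nfw = nfw_potential(r, 1e14, 5.0, 1000.0)
```

The potential is negative and rises towards zero with radius. Near the center, it tends to the finite value \(-G M_{200c} / (m(c_{200c})\, r_s)\).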

Then, we can make predictions of gas thermodynamics for one halo:

rho_g, P_tot, P_th, f_nt = predictor.predict_gas_model(theta_halo[0], phi[0], r_R500c[0], r_R500c[0] / 2)

Or for all halos at the same time (this function uses jax.vmap to vectorize the predictions):

rho_g, P_tot, P_th, f_nt = predictor.predict_gas_model(theta_halo, phi, r_R500c, r_R500c / 2)
fig, axs = plt.subplots(1, 4, figsize=(13, 4))
for ax, q in zip(axs, [rho_g, P_tot, P_th, f_nt]):
    ax.loglog(r_R500c.T, q.T)
    ax.set_xlabel("$r / R_{500c}$")

axs[0].set_ylabel("$\\rho_{\\rm g} / 500 \\rho_{\\rm crit.}$")
axs[1].set_ylabel("$P_{\\rm tot} / P_{500c}$")
axs[2].set_ylabel("$P_{\\rm th} / P_{500c}$")
axs[3].set_ylabel("$f_{\\rm nt}$")
fig.tight_layout()
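[Figure: predicted gas density \(\rho_{\rm g}/500\rho_{\rm crit.}\), total pressure \(P_{\rm tot}/P_{500c}\), thermal pressure \(P_{\rm th}/P_{500c}\) and non-thermal pressure fraction \(f_{\rm nt}\) as functions of \(r/R_{500c}\), one line per halo]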
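
Vectorizing a per-halo function with jax.vmap amounts to mapping it over the leading (halo) axis of every argument. In the sketch below, toy_gas_model is a hypothetical per-halo function with the same argument layout as the calls above: a parameter vector, a potential profile, a radius grid, and a second radius array. Its body is made up and does not reproduce the picasso gas model.

```python
import jax
import jax.numpy as jnp


def toy_gas_model(theta, phi, r, r_half):
    """Hypothetical per-halo model: theta has shape (n_in,), the others (n_r,)."""
    rho = jnp.exp(theta[0] - phi)           # made-up "density" profile
    frac = theta[1] * jnp.tanh(r_half / r)  # made-up "fraction" profile
    return rho, frac


# Map over the leading (halo) axis of all four arguments.
batched_gas_model = jax.vmap(toy_gas_model, in_axes=(0, 0, 0, 0))

n_halos, n_r = 7, 50
theta = jnp.ones((n_halos, 2))
r = jnp.tile(jnp.logspace(-1, 0.5, n_r), (n_halos, 1))
phi = -1.0 / r
rho, frac = batched_gas_model(theta, phi, r, r / 2)
print(rho.shape, frac.shape)  # (7, 50) (7, 50)
```

Each row of the batched output matches a single-halo call on the corresponding rows of the inputs.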

As before, predict_gas_model can be just-in-time compiled:

if benchmark:
    # Compile the full gas-model prediction with jax.jit
    predict_jit = jax.jit(predictor.predict_gas_model)

    print("1 halo, not jitted:", end=" ")
    %timeit _ = predictor.predict_gas_model(theta_halo[0], phi[0], r_R500c[0], r_R500c[0] / 2)
    print("1 halo, jitted:", end=" ")
    # Warm-up call: compiles the jitted function for single-halo input shapes
    _ = predict_jit(theta_halo[0], phi[0], r_R500c[0], r_R500c[0] / 2)
    %timeit _ = predict_jit(theta_halo[0], phi[0], r_R500c[0], r_R500c[0] / 2)

    print("7 halos, not jitted:", end=" ")
    %timeit _ = predictor.predict_gas_model(theta_halo, phi, r_R500c, r_R500c / 2)
    print("7 halos, jitted:", end=" ")
    # New input shapes (all halos) trigger a new compilation, so warm up again
    _ = predict_jit(theta_halo, phi, r_R500c, r_R500c / 2)
    %timeit _ = predict_jit(theta_halo, phi, r_R500c, r_R500c / 2)
1 halo, not jitted: 21.5 ms ± 296 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
1 halo, jitted: 376 µs ± 16.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
7 halos, not jitted: 38.8 ms ± 244 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
7 halos, jitted: 33.9 µs ± 772 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
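
jax.jit compiles a separate version of a function for each new combination of input shapes and dtypes. This is why the warm-up call is repeated before timing the single-halo and all-halo cases. The behavior can be observed with a Python side effect, which only runs while JAX traces the function. The function below is a hypothetical toy, not the predictor.

```python
import jax
import jax.numpy as jnp

n_traces = 0


def toy_model(x):
    global n_traces
    n_traces += 1  # Python side effect: executed only during tracing
    return jnp.sum(x**2, axis=-1)


toy_jit = jax.jit(toy_model)
toy_jit(jnp.ones(2))       # first shape (2,): traced and compiled
toy_jit(jnp.zeros(2))      # same shape and dtype: cached, no retrace
toy_jit(jnp.ones((7, 2)))  # new shape (7, 2): traced again
print(n_traces)  # 2
```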

For an example of how these predictions can be included in a differentiable loss function and used to train a picasso model, see Training a picasso predictor.