Learning a Stream with an Autoencoder

Learning a Stream with an Autoencoder

Open In Colab

[1]:
import pathlib
import pickle

import galax.coordinates as gc
import galax.dynamics as gd
import galax.potential as gp
import jax
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np

import quaxed.numpy as jnp
import unxt as u

import phasecurvefit as pcf
[2]:
# Root JAX PRNG key for this notebook. Derive fresh subkeys with `jr.split`
# before each stochastic call instead of reusing a key.
SEED = 201030
key = jr.key(SEED)

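Data: a Stream

We generate a mock stellar stream with `galax` (a Fardal stream distribution function in an LMJ09 logarithmic potential) and cache it to `mockstream.pkl`, so later runs can skip the simulation.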

[3]:
# Every galax object below is built in the galactic unit system.
usys = u.unitsystems.galactic

# Progenitor: initial phase-space position at t = 0 and its mass.
prog_w0 = gc.PhaseSpaceCoordinate(
    q=u.Q([10, 3, 5], "kpc"),
    p=u.Q([-4, 100, 4], "km/s"),
    t=u.Q(0.0, "Myr"),
)
prog_mass = u.Quantity(2.5e4, "Msun")

# Distribution function used to release stream particles from the progenitor.
df = gd.FardalStreamDF()

# Host potential. q1/q2/q3 and phi shape and orient the halo; see the galax
# `LMJ09LogarithmicPotential` docs for the exact parameterization.
halo_params = {
    "v_c": u.Q(150, "km/s"),
    "r_s": u.Q(2, "kpc"),
    "q1": 1.0,
    "q2": 1.3,
    "q3": 0.9,
    "phi": u.Q(0, "deg"),
}
pot = gp.LMJ09LogarithmicPotential(**halo_params, units=usys)

# galax mock-stream generator pairing the DF with the potential.
mockgen = gd.MockStreamGenerator(df, pot)
[4]:
# Simulating the stream is the slowest step in this notebook, so the result is
# cached on disk and reused on later runs.
# NOTE: the cache is keyed only by its filename. Delete `mockstream.pkl` after
# changing the progenitor, DF, or potential above, or the stale stream is reused.
mockstream_path = pathlib.Path("mockstream.pkl")
mockstream_path.parent.mkdir(parents=True, exist_ok=True)

# Draw the simulation subkey unconditionally, so the keys used further down do
# not depend on whether the cache was hit.
key, stream_key = jr.split(key)

mockstream = prog = None
if mockstream_path.exists():
    try:
        # Only load a cache this notebook wrote itself: unpickling can execute
        # arbitrary code.
        with mockstream_path.open("rb") as handle:
            mockstream, prog = pickle.load(handle)  # noqa: S301
    # Best effort: any unreadable or incompatible cache just means "recompute".
    except Exception as exc:  # noqa: BLE001
        print(f"Loading failed ({exc!r}); running the simulation.")
        mockstream = prog = None
else:
    print(f"No cached stream at {mockstream_path}; running the simulation.")

if mockstream is None:
    mockstream, prog = mockgen.run(
        stream_key,
        u.Q(jnp.linspace(0, 4, 4_000), "Gyr"),  # time grid: 0-4 Gyr, 4000 samples
        prog_w0,
        prog_mass,
    )
    # Finish the (asynchronous) JAX computation before pickling the result.
    mockstream = jax.block_until_ready(mockstream)
    with mockstream_path.open("wb") as handle:
        pickle.dump((mockstream, prog), handle)
[5]:
# Quick look at the simulated stream in the x-y plane, colored by each
# particle's position in the simulator's output order.
fig, ax = plt.subplots(figsize=(12, 10))

x_vals = np.array(mockstream.q.x)
y_vals = np.array(mockstream.q.y)
order_color = jnp.linspace(1, 0, len(mockstream.q))

im = ax.scatter(x_vals, y_vals, s=1, c=order_color, cmap="RdYlBu_r")
fig.colorbar(im, ax=ax, label="data order", shrink=0.5)

ax.set_xlabel(r"$x$ [kpc]", fontsize=14)
ax.set_ylabel(r"$y$ [kpc]", fontsize=14)
ax.set_title("Mock Stellar Stream in LMJ09 Potential", fontsize=20)
plt.show()
../_images/tutorials_stream_autoencoder_7_0.png

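Fitting $\gamma$ and $\vec{x}$ for the Whole Stream

We shuffle the particles, order them with a local-flow walk that starts at the particle nearest the progenitor, and then train an autoencoder. The autoencoder assigns each particle an ordering parameter $\gamma$ and decodes $\gamma$ back into a position $\vec{x}(\gamma)$ along the stream.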

[6]:
# Shuffle the particles so the walk below cannot lean on the simulator's output
# order. It has to recover the ordering from positions and velocities alone.
key, subkey = jr.split(key)
n_particles = len(mockstream.q.x)
order = jr.permutation(subkey, jnp.arange(n_particles))

# Component name -> shuffled values, for positions (qs) and velocities (ps).
qs, ps = (
    {name: getattr(vec, name)[order] for name in vec.components}
    for vec in (mockstream.q, mockstream.p)
)
[7]:
# The walk takes its start index as a static Python int. Find the shuffled
# particle closest to the progenitor and convert the result with `int`.
dist_to_prog = jnp.linalg.norm(mockstream.q[order] - prog.q, axis=1)
start_idx = int(np.argmin(dist_to_prog))

# Walk configuration: KD-tree neighbour search (k=50) scored with the
# momentum-aligned distance metric.
config = pcf.WalkConfig(
    strategy=pcf.strats.KDTree(k=50),
    metric=pcf.metrics.AlignedMomentumDistanceMetric(),
)
metric_scale = u.Q(100, "kpc")
max_dist = u.Q(3, "kpc")  # presumably the largest allowed step; see pcf docs

# Walk outward from the start point in both directions.
walkresult = pcf.walk_local_flow(
    qs,
    ps,
    start_idx=start_idx,
    metric_scale=metric_scale,
    max_dist=max_dist,
    config=config,
    direction="both",
    metadata=pcf.StateMetadata(usys=usys),
)
print(walkresult.gamma_range)

# Fit the path autoencoder to the walk, with fresh subkeys for the model
# initialization and for training.
normalizer = pcf.nn.StandardScalerNormalizer(qs, ps)
key, model_key, train_key = jr.split(key, 3)
model = pcf.nn.PathAutoencoder.make(
    normalizer, gamma_range=walkresult.gamma_range, key=model_key
)
result, opt_state, losses = pcf.nn.train_autoencoder(model, walkresult, key=train_key)
(-1.0, 1.0)
[8]:
# Overlay the local-flow walk on the shuffled stream. Walk points are colored by
# their rank in `walkresult.ordering`; that scatter (not the uniform black
# background) is the one the colorbar describes.
fig, ax = plt.subplots(1, 1, figsize=(10, 8))
ax.scatter(qs["x"], qs["y"], s=1, c="k", label="Stream particles")

ordering = walkresult.ordering
walk_qs = {k: v[ordering] for k, v in walkresult.positions.items()}
timeline = np.linspace(0, 1, len(ordering))  # 0 = first point in the ordering, 1 = last
im = ax.scatter(walk_qs["x"], walk_qs["y"], s=50, c=timeline, cmap="RdYlBu", label="Walk")
ax.plot(walk_qs["x"], walk_qs["y"], c="k", lw=3, ls="--")

# Draw the start point on top so the walk markers cannot hide it.
ax.scatter(
    [qs["x"][start_idx]],
    [qs["y"][start_idx]],
    s=100,
    c="green",
    label="Start Point",
    zorder=5,
)

fig.colorbar(im, ax=ax, label="walk order")
ax.set_xlabel(r"$x$ [kpc]", fontsize=14)
ax.set_ylabel(r"$y$ [kpc]", fontsize=14)
ax.set_title("Mock Stellar Stream in Flattened Logarithmic Potential", fontsize=16)
ax.legend()  # the labels above only appear once a legend is drawn
plt.show()
../_images/tutorials_stream_autoencoder_11_0.png
[9]:
# Loss history returned by `train_autoencoder`, on a logarithmic y-axis.
loss_curve = np.asarray(losses)

fig, ax = plt.subplots(figsize=(10, 5))
ax.semilogy(loss_curve, linewidth=2)
ax.set_title("Training Loss Over Epochs", fontsize=14)
ax.set_xlabel("Epoch", fontsize=12)
ax.set_ylabel("Loss", fontsize=12)
ax.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

print(f"Final loss: {losses[-1]:.6f}")
../_images/tutorials_stream_autoencoder_12_0.png
Final loss: 0.264630
[10]:
# Color every particle by the encoder's gamma, overlay the decoded mean path,
# and flag particles whose membership probability falls below the threshold.
MEMBERSHIP_THRESHOLD = 0.5  # membership probability below which a particle is "rejected"
PATH_FRACTION = 0.95  # fraction of the trained gamma range spanned by the plotted path

fig, ax = plt.subplots(figsize=(12, 10))

# NOTE(review): assumes `walkresult.positions` keeps the row order of `qs`, so
# `all_gamma` / `all_probs` line up with the `qs` arrays plotted below. Confirm.
all_gamma, all_probs = result.model.encode(walkresult.positions, walkresult.velocities)
rejected_membership = all_probs < MEMBERSHIP_THRESHOLD

# Evaluate the decoded path slightly inside the trained gamma range, presumably
# to stay away from the poorly constrained ends. For (-1, 1) this is [-0.95, 0.95].
gamma_lo, gamma_hi = walkresult.gamma_range
gamma_mid = 0.5 * (gamma_lo + gamma_hi)
gamma_half = 0.5 * PATH_FRACTION * (gamma_hi - gamma_lo)
qs_pred = result(jnp.linspace(gamma_mid - gamma_half, gamma_mid + gamma_half, 1_000))

# All particles (members and rejected alike), colored by gamma
im = ax.scatter(
    np.asarray(qs["x"]),
    np.asarray(qs["y"]),
    s=50,
    c=np.asarray(all_gamma),
    cmap="RdYlBu",
    alpha=0.8,
    label="All particles",
)

# Predicted mean path
ax.plot(
    np.asarray(qs_pred["x"]),
    np.asarray(qs_pred["y"]),
    c="k",
    lw=3,
    label="Predicted mean path",
)

# Rejected samples in cyan
ax.scatter(
    np.asarray(qs["x"][rejected_membership]),
    np.asarray(qs["y"][rejected_membership]),
    s=100,
    c="cyan",
    alpha=1.0,
    marker="o",
    edgecolors="black",
    linewidths=0.5,
    label="Rejected samples",
)

# Start point, drawn on top
ax.scatter(
    np.asarray(qs["x"][start_idx]),
    np.asarray(qs["y"][start_idx]),
    s=200,
    c="tab:green",
    marker="*",
    label="Start point",
    linewidths=2,
    zorder=5,
)

ax.set_xlabel(r"$x$ [kpc]", fontsize=14)
ax.set_ylabel(r"$y$ [kpc]", fontsize=14)
ax.set_title("ML Ordering of Stellar Stream with Predicted Mean Path", fontsize=16)
fig.colorbar(im, ax=ax, label=r"$\gamma$ (ordering parameter)")
ax.legend(loc="upper right", fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
../_images/tutorials_stream_autoencoder_13_0.png
[ ]: