JAX Integration

This guide shows how to combine phasecurvefit with JAX transformations: jax.jit for faster repeated calls, vmap for batching, and jax.grad for differentiation.

Basic Usage

phasecurvefit accepts JAX arrays directly; no special setup is needed:

import jax.numpy as jnp
import phasecurvefit as pcf

# Phase-space data
position = {"x": jnp.array([0.0, 1.0, 2.0, 3.0])}
velocity = {"x": jnp.array([1.0, 1.1, 1.2, 1.3])}

# Direct call works
result = pcf.walk_local_flow(position, velocity, start_idx=0, metric_scale=1.0)

Because position and velocity are plain dicts of arrays, they are JAX PyTrees, so they can be passed directly through JAX transformations such as jax.jit, jax.vmap, and jax.grad.
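
To make "PyTree" concrete, the pure-JAX sketch below flattens a position dict into its leaves and maps a function over every array while keeping the dict structure. The second key, "y", is only illustrative; the examples in this guide use a single "x" component.

```python
import jax
import jax.numpy as jnp

# A phase-space dict is a PyTree: its leaves are the arrays it holds.
position = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.5, 0.5, 0.5])}

# Flatten into leaves (ordered by sorted key) plus a description of the structure.
leaves, treedef = jax.tree.flatten(position)
print(len(leaves), treedef.num_leaves)

# Map a function over every leaf; the result keeps the same dict keys.
shifted = jax.tree.map(lambda a: a + 1.0, position)
print(shifted["x"])
```

The same flatten/unflatten mechanism is what lets JAX transformations accept dict arguments.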

JIT Compilation

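Wrap the call in a function decorated with jax.jit so it is compiled once and reused on later calls. Here start_idx and metric_scale are fixed inside the wrapper, so only the arrays are traced: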

import jax
import jax.numpy as jnp
import phasecurvefit as pcf


# Wrap for JIT
@jax.jit
def order_stream(position, velocity):
    return pcf.walk_local_flow(position, velocity, start_idx=0, metric_scale=1.0)


# Data
position = {"x": jnp.array([0.0, 1.0, 2.0, 3.0])}
velocity = {"x": jnp.array([1.0, 1.1, 1.2, 1.3])}

# First call traces and compiles; later calls with the same shapes reuse the compiled version
result = order_stream(position, velocity)

Note: JIT pays off when the same function is called repeatedly with inputs of the same shapes and dtypes. A change in shape or dtype triggers a new trace and compilation.
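
The caching behaviour can be observed directly. This sketch uses a hypothetical total function and counts traces with a Python side effect, which runs only while JAX traces the function, not on cached calls:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # counts how many times JAX traces the function


@jax.jit
def total(position):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing
    return jnp.sum(position["x"])


total({"x": jnp.zeros(4)})  # traces and compiles for shape (4,)
total({"x": jnp.ones(4)})   # same shape and dtype: reuses the compiled version
total({"x": jnp.zeros(5)})  # new shape (5,): traces and compiles again
print(trace_count)  # 2
```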

Vectorization (vmap)

Stack several streams of the same length along a new leading batch axis, then map over that axis to process them in one call:

import jax
import jax.numpy as jnp
import phasecurvefit as pcf
from jaxmore import vmap  # vmap variant that accepts static keyword arguments via static_kw

# Multiple streams
streams_pos = [
    {"x": jnp.array([0.0, 1.0, 2.0])},
    {"x": jnp.array([3.0, 4.0, 5.0])},
    {"x": jnp.array([6.0, 7.0, 8.0])},
]
streams_vel = [
    {"x": jnp.array([1.0, 1.1, 1.2])},
    {"x": jnp.array([1.3, 1.4, 1.5])},
    {"x": jnp.array([1.6, 1.7, 1.8])},
]

# Stack arrays
stacked_pos = {"x": jnp.stack([s["x"] for s in streams_pos])}
stacked_vel = {"x": jnp.stack([s["x"] for s in streams_vel])}

# Apply vmap over batch dimension
batched_fn = vmap(
    pcf.walk_local_flow,
    in_axes=(0, 0),
    static_kw={"start_idx": 0, "metric_scale": 1.0},
)
results = batched_fn(stacked_pos, stacked_vel)
print(f"Processed {results.indices.shape[0]} streams in parallel")
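
If you prefer to stay within core JAX, a common pattern is to bind keyword arguments with functools.partial and pass the result to jax.vmap. The sketch below uses a hypothetical stream_summary function as a stand-in; whether pcf.walk_local_flow works with plain jax.vmap is not covered here. It also shows that a batched output is itself a PyTree with a leading batch axis.

```python
from functools import partial

import jax
import jax.numpy as jnp


def stream_summary(position, velocity, *, scale):
    # Hypothetical per-stream function standing in for walk_local_flow.
    return {
        "extent": jnp.max(position["x"]) - jnp.min(position["x"]),
        "mean_v": scale * jnp.mean(velocity["x"]),
    }


stacked_pos = {"x": jnp.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])}
stacked_vel = {"x": jnp.array([[1.0, 1.1, 1.2], [1.3, 1.4, 1.5], [1.6, 1.7, 1.8]])}

# Bind the keyword argument first, then map both dicts over their leading axis.
batched = jax.vmap(partial(stream_summary, scale=2.0), in_axes=(0, 0))
out = batched(stacked_pos, stacked_vel)
print(out["extent"].shape)  # (3,)
```

Binding with partial keeps scale a Python constant, so it is not batched.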

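For streams of different lengths, which cannot be stacked, map over a list of streams with jax.tree.map. This calls walk_local_flow once per stream in a Python loop rather than in parallel: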

streams = [
    {"q": {"x": jnp.array([0.0, 1.0])}, "p": {"x": jnp.array([1.0, 1.1])}},
    {"q": {"x": jnp.array([3.0, 4.0])}, "p": {"x": jnp.array([1.3, 1.4])}},
]

results = jax.tree.map(
    lambda sd: pcf.walk_local_flow(sd["q"], sd["p"], start_idx=0, metric_scale=1.0),
    streams,
    is_leaf=lambda x: isinstance(x, dict),
)
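
The is_leaf argument tells jax.tree.map to stop at each per-stream dict instead of descending to the arrays inside. A minimal sketch with a trivial stand-in function (counting points) and streams of different lengths:

```python
import jax
import jax.numpy as jnp

# Two streams with 2 and 3 points; they cannot be stacked into one array.
streams = [
    {"q": {"x": jnp.array([0.0, 1.0])}, "p": {"x": jnp.array([1.0, 1.1])}},
    {"q": {"x": jnp.array([3.0, 4.0, 5.0])}, "p": {"x": jnp.array([1.3, 1.4, 1.5])}},
]

# is_leaf treats each dict in the list as one leaf, so the lambda receives a whole stream.
n_points = jax.tree.map(
    lambda sd: sd["q"]["x"].shape[0],
    streams,
    is_leaf=lambda x: isinstance(x, dict),
)
print(n_points)  # [2, 3]

# A plain list comprehension gives the same result; both process one stream at a time.
assert n_points == [len(sd["q"]["x"]) for sd in streams]
```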

Differentiation

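Compute gradients with respect to parameters such as metric_scale. The loss below only illustrates the jax.grad and jax.value_and_grad calls: result.indices is an integer ordering, so a loss built from it alone does not vary smoothly with metric_scale and its gradient is zero. A useful loss must depend on floating-point outputs.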

import jax
import jax.numpy as jnp
import phasecurvefit as pcf

position = {"x": jnp.array([0.0, 1.0, 2.0, 3.0])}
velocity = {"x": jnp.array([1.0, 1.1, 1.2, 1.3])}


# Define a scalar loss
def loss_fn(metric_scale):
    result = pcf.walk_local_flow(
        position, velocity, start_idx=0, metric_scale=metric_scale
    )
    return jnp.sum(result.indices.astype(jnp.float32))


# Compute gradient
grads = jax.grad(loss_fn)(jnp.array(1.5))

# Or get both value and gradient
value, grads = jax.value_and_grad(loss_fn)(jnp.array(1.5))
print(f"Loss: {value}, Gradient: {grads}")
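
Gradients only flow through floating-point values. In this pure-JAX sketch, jnp.argsort stands in for an ordering step (an assumption for illustration; it is not how walk_local_flow orders points). A loss built from the integer ordering has zero gradient, whereas gathering float data with the ordering is differentiable with respect to that data, with the ordering held fixed.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0.0, 2.0, 1.0, 3.0])


def index_loss(scale):
    # The ordering is an integer array, so it carries no derivative information.
    order = jnp.argsort(scale * x)
    return jnp.sum(order.astype(jnp.float32))


def path_length(x):
    # The ordering is computed from x but acts as a fixed permutation under differentiation.
    order = jnp.argsort(x)
    ordered = x[order]  # the gather is differentiable with respect to x
    return jnp.sum(jnp.abs(jnp.diff(ordered)))


g_scale = jax.grad(index_loss)(jnp.array(1.5))  # zero: no gradient through integers
g_x = jax.grad(path_length)(x)  # -1 at the first point, +1 at the last, 0 elsewhere
```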

Performance Tips

Use JAX arrays: Convert NumPy arrays to JAX arrays once, before calling, so the data is not converted and transferred again on every call:

import jax
import jax.numpy as jnp
import numpy as np
import phasecurvefit as pcf

# NumPy data
pos_numpy = {"x": np.array([0.0, 1.0, 2.0, 3.0])}
vel_numpy = {"x": np.array([1.0, 1.1, 1.2, 1.3])}

# Convert to JAX
pos_jax = jax.tree.map(jnp.asarray, pos_numpy)
vel_jax = jax.tree.map(jnp.asarray, vel_numpy)
result = pcf.walk_local_flow(pos_jax, vel_jax, start_idx=0, metric_scale=1.0)
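
jax.device_put also accepts a whole PyTree, so the conversion can be done in a single call. This sketch shows that, and that with JAX's default 32-bit mode (jax_enable_x64 off) float64 NumPy input is stored as float32:

```python
import jax
import jax.numpy as jnp
import numpy as np

pos_numpy = {"x": np.array([0.0, 1.0, 2.0, 3.0])}  # NumPy defaults to float64

# device_put converts every array in the dict and places it on the default device.
pos_jax = jax.device_put(pos_numpy)

# Without jax_enable_x64, JAX stores the data as float32.
print(pos_jax["x"].dtype)
```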

Combine JIT and vmap: For batched operations that run repeatedly, wrap the vmapped function in jax.jit so the whole batch is compiled once:

import jax
import phasecurvefit as pcf
from jaxmore import vmap


@jax.jit
def batch_order(stacked_pos, stacked_vel):
    return vmap(
        pcf.walk_local_flow,
        in_axes=(0, 0),
        static_kw={"start_idx": 0, "metric_scale": 1.0},
    )(stacked_pos, stacked_vel)
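
To check whether jit plus vmap pays off, time a call after a warm-up call and wait for the result with block_until_ready(), because JAX dispatches work asynchronously. A minimal sketch with jnp.linalg.norm standing in for the per-stream function:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def batch_norms(stacked_x):
    # vmap maps the per-row norm over the leading batch axis; jit compiles the whole batch.
    return jax.vmap(jnp.linalg.norm)(stacked_x)


x = jnp.ones((1000, 64))

batch_norms(x).block_until_ready()  # warm-up call: includes compilation time

start = time.perf_counter()
out = batch_norms(x).block_until_ready()  # wait for the device to finish
elapsed = time.perf_counter() - start
print(f"{out.shape[0]} rows in {elapsed * 1e3:.3f} ms")
```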

Hardware Acceleration

The library runs on GPU or TPU with no code changes, provided JAX is installed with support for that hardware:

import jax
import jax.numpy as jnp
import phasecurvefit as pcf

position = {"x": jnp.array([0.0, 1.0, 2.0, 3.0])}
velocity = {"x": jnp.array([1.0, 1.0, 1.0, 1.0])}

# Check available devices
devices = jax.devices()
print(f"Available devices: {devices}")

# JAX runs on its default device, which is a GPU or TPU when one is available
result = pcf.walk_local_flow(position, velocity, start_idx=0, metric_scale=1.0)
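
To see where JAX will run, query the default backend, and use jax.device_put to place a whole dict on a specific device. This sketch targets the CPU, which is always available, so it runs on any machine; on a GPU build, jax.devices("gpu")[0] would select the first GPU instead.

```python
import jax
import jax.numpy as jnp

# "cpu", "gpu", or "tpu", depending on the installed jaxlib and hardware.
print(jax.default_backend())

position = {"x": jnp.array([0.0, 1.0, 2.0, 3.0])}

# Place every array in the dict on a chosen device.
cpu = jax.devices("cpu")[0]
position_cpu = jax.device_put(position, cpu)

# Array.devices() reports the set of devices holding the array.
print(position_cpu["x"].devices())
```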

Debugging Tips

Disable JIT: Inside jax.disable_jit(), jitted functions run eagerly, so print statements and debuggers see concrete values rather than tracers:

import jax

with jax.disable_jit():
    result = pcf.walk_local_flow(position, velocity, start_idx=0, metric_scale=1.0)
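
The difference matters for print-style debugging. In this sketch, with a hypothetical scaled_sum function, a plain print inside a jitted function runs only while tracing and shows an abstract tracer, whereas jax.debug.print runs on every call with concrete values. Under jax.disable_jit() the plain print also sees concrete arrays.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x, scale):
    y = scale * x
    print("print:", y)  # under jit: runs once during tracing and shows a tracer
    jax.debug.print("debug.print: {}", y)  # runs on every call with concrete values
    return jnp.sum(y)


x = jnp.array([1.0, 2.0])
out_jit = scaled_sum(x, 2.0)

with jax.disable_jit():
    out_eager = scaled_sum(x, 2.0)  # both prints now show concrete arrays
```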

Check shapes: Print the shape of every array in a dict:

import jax

print(jax.tree.map(lambda x: x.shape, position))
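
Mismatched lengths between components are easy to miss in a printout. The hypothetical helper below (not part of phasecurvefit) collects the leading dimension of every array in the position and velocity dicts and raises if they differ; it assumes each array holds one value per point.

```python
import jax
import jax.numpy as jnp


def check_point_counts(position, velocity):
    """Hypothetical helper: return the shared number of points or raise ValueError."""
    leaves = jax.tree.leaves((position, velocity))
    counts = {leaf.shape[0] for leaf in leaves}
    if len(counts) != 1:
        raise ValueError(f"inconsistent point counts across arrays: {sorted(counts)}")
    return counts.pop()


position = {"x": jnp.array([0.0, 1.0, 2.0, 3.0])}
velocity = {"x": jnp.array([1.0, 1.1, 1.2, 1.3])}
n = check_point_counts(position, velocity)  # 4

try:
    check_point_counts(position, {"x": jnp.array([1.0, 1.1])})
except ValueError as err:
    print(err)
```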

See Also