JAX Integration
This guide shows how to use phasecurvefit with JAX for faster computation, batching, and differentiation.
Basic Usage
The library accepts JAX arrays directly; no special setup is needed:
import jax.numpy as jnp
import phasecurvefit as pcf
# Phase-space data
position = {"x": jnp.array([0.0, 1.0, 2.0, 3.0])}
velocity = {"x": jnp.array([1.0, 1.1, 1.2, 1.3])}
# Direct call works
result = pcf.walk_local_flow(position, velocity, start_idx=0, metric_scale=1.0)
The dict-based API is JAX PyTree compatible, so the same inputs can be passed through JAX transformations such as jax.jit, jax.vmap, and jax.grad.
JIT Compilation
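To see what PyTree compatibility means in practice, here is a minimal sketch that uses only JAX itself. The second coordinate "y" is an illustrative assumption, not part of the example above; the point is that JAX treats a dict of arrays as a container of array leaves.

```python
import jax
import jax.numpy as jnp

# Phase-space data in the same dict layout as above, extended with
# a second coordinate "y" purely for illustration.
position = {
    "x": jnp.array([0.0, 1.0, 2.0, 3.0]),
    "y": jnp.array([0.0, 0.5, 1.0, 1.5]),
}

# JAX flattens the dict into its array leaves plus a structure description.
leaves, treedef = jax.tree_util.tree_flatten(position)
print(len(leaves), treedef)

# tree.map applies a function to every leaf and rebuilds the same dict layout.
centered = jax.tree.map(lambda a: a - a.mean(), position)
print(centered)
```

Because transformations operate on the leaves and restore the structure afterwards, adding or removing coordinates should not require changes to the surrounding JAX code.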
Wrap the function to enable JIT compilation for faster repeated calls:
import jax
import jax.numpy as jnp
import phasecurvefit as pcf
# Wrap for JIT
@jax.jit
def order_stream(position, velocity):
    return pcf.walk_local_flow(position, velocity, start_idx=0, metric_scale=1.0)
# Data
position = {"x": jnp.array([0.0, 1.0, 2.0, 3.0])}
velocity = {"x": jnp.array([1.0, 1.1, 1.2, 1.3])}
# First call: compiles; subsequent calls use cached version
result = order_stream(position, velocity)
Note: JIT is most beneficial when the same function is called repeatedly with inputs of the same shapes and dtypes. A call with a new shape or dtype triggers recompilation.
Vectorization (vmap)
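The caching behaviour can be observed with a small sketch. The function `scaled_norm` below is a hypothetical stand-in, not part of phasecurvefit; it records each trace in a Python list, which works because Python side effects run only while JAX traces the function.

```python
import jax
import jax.numpy as jnp

trace_log = []  # grows once per trace, not once per call


@jax.jit
def scaled_norm(q, p):
    # Hypothetical stand-in for an expensive per-stream computation.
    trace_log.append(q["x"].shape)  # Python side effect: runs during tracing only
    return jnp.sqrt(jnp.sum(q["x"] ** 2 + p["x"] ** 2))


def make_stream(n):
    # Build a toy (position, velocity) pair with n points.
    q = {"x": jnp.arange(n, dtype=jnp.float32)}
    p = {"x": jnp.ones(n, dtype=jnp.float32)}
    return q, p


scaled_norm(*make_stream(4))  # traces and compiles for length 4
scaled_norm(*make_stream(4))  # same shapes: reuses the cached executable
scaled_norm(*make_stream(8))  # new shape: traces and compiles again
print(trace_log)
```

If stream lengths vary a lot between calls, each distinct length pays the compilation cost once, so padding to a few fixed sizes may be worth considering.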
Process multiple streams in parallel by stacking them along a leading batch axis:
import jax
import jax.numpy as jnp
import phasecurvefit as pcf
# Multiple streams
streams_pos = [
    {"x": jnp.array([0.0, 1.0, 2.0])},
    {"x": jnp.array([3.0, 4.0, 5.0])},
    {"x": jnp.array([6.0, 7.0, 8.0])},
]
streams_vel = [
    {"x": jnp.array([1.0, 1.1, 1.2])},
    {"x": jnp.array([1.3, 1.4, 1.5])},
    {"x": jnp.array([1.6, 1.7, 1.8])},
]
# Stack arrays
stacked_pos = {"x": jnp.stack([s["x"] for s in streams_pos])}
stacked_vel = {"x": jnp.stack([s["x"] for s in streams_vel])}
# Apply vmap over batch dimension
batched_fn = jax.vmap(
    lambda q, p: pcf.walk_local_flow(q, p, start_idx=0, metric_scale=1.0),
    in_axes=(0, 0),
)
results = batched_fn(stacked_pos, stacked_vel)
print(f"Processed {results.indices.shape[0]} streams in parallel")
For a list of independent streams, such as streams of different lengths that cannot be stacked, use jax.tree.map, which processes them one at a time:
streams = [
    {"q": {"x": jnp.array([0.0, 1.0])}, "p": {"x": jnp.array([1.0, 1.1])}},
    {"q": {"x": jnp.array([3.0, 4.0])}, "p": {"x": jnp.array([1.3, 1.4])}},
]
# is_leaf stops at each stream dict, so the lambda receives one whole stream
results = jax.tree.map(
    lambda sd: pcf.walk_local_flow(sd["q"], sd["p"], start_idx=0, metric_scale=1.0),
    streams,
    is_leaf=lambda x: isinstance(x, dict),
)
Differentiation
Compute gradients with respect to parameters:
import jax
import jax.numpy as jnp
import phasecurvefit as pcf
position = {"x": jnp.array([0.0, 1.0, 2.0, 3.0])}
velocity = {"x": jnp.array([1.0, 1.1, 1.2, 1.3])}
# Define a scalar loss
def loss_fn(metric_scale):
    result = pcf.walk_local_flow(
        position, velocity, start_idx=0, metric_scale=metric_scale
    )
    return jnp.sum(result.indices.astype(jnp.float32))
# Compute gradient
grads = jax.grad(loss_fn)(jnp.array(1.5))
# Or get both value and gradient
value, grads = jax.value_and_grad(loss_fn)(jnp.array(1.5))
print(f"Loss: {value}, Gradient: {grads}")
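One caveat when choosing a loss: in JAX, gradients do not flow through integer-valued outputs. If result.indices is an integer array, as the cast above suggests, a loss built only from it may have a zero gradient; whether that applies here depends on the library's internals, which this guide does not specify. The sketch below illustrates the general JAX behaviour with plain arrays. Both `hard_loss` and `soft_loss` are hypothetical names and do not come from phasecurvefit.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0.0, 1.0, 2.0, 3.0])


def hard_loss(scale):
    # Discrete selection: index of the point closest to `scale`.
    idx = jnp.argmin(jnp.abs(x - scale))
    return idx.astype(jnp.float32)


def soft_loss(scale):
    # Smooth surrogate: softmax-weighted average of the indices.
    weights = jax.nn.softmax(-jnp.abs(x - scale))
    return jnp.sum(weights * jnp.arange(x.shape[0]))


hard_grad = jax.grad(hard_loss)(1.4)  # zero: argmin output is an integer
soft_grad = jax.grad(soft_loss)(1.4)  # nonzero: every step is differentiable
print(hard_grad, soft_grad)
```

When a gradient comes back as exactly zero, checking whether the loss depends only on integer outputs is a reasonable first step.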
Performance Tips
Use JAX arrays: Convert NumPy arrays to JAX before calling:
import jax
import jax.numpy as jnp
import numpy as np
import phasecurvefit as pcf
# NumPy data
pos_numpy = {"x": np.array([0.0, 1.0, 2.0, 3.0])}
vel_numpy = {"x": np.array([1.0, 1.1, 1.2, 1.3])}
# Convert to JAX
pos_jax = jax.tree.map(jnp.asarray, pos_numpy)
vel_jax = jax.tree.map(jnp.asarray, vel_numpy)
result = pcf.walk_local_flow(pos_jax, vel_jax, start_idx=0, metric_scale=1.0)
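A side effect of the conversion worth knowing about is the dtype change. The sketch below uses only NumPy and JAX and relies on JAX's default 32-bit mode; setting `jax_enable_x64` is the standard JAX switch for keeping 64-bit precision, and whether it is needed for a given dataset is a judgment call.

```python
import jax
import jax.numpy as jnp
import numpy as np

pos_numpy = {"x": np.array([0.0, 1.0, 2.0, 3.0])}  # NumPy defaults to float64
pos_jax = jax.tree.map(jnp.asarray, pos_numpy)

print(pos_numpy["x"].dtype)  # float64
print(pos_jax["x"].dtype)    # float32 unless 64-bit mode is enabled
print(isinstance(pos_jax["x"], jax.Array))

# To keep float64, enable 64-bit mode at program start, before creating arrays:
# jax.config.update("jax_enable_x64", True)
```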
Combine JIT and vmap: For batched operations that run repeatedly, wrap both:
@jax.jit
def batch_order(stacked_pos, stacked_vel):
    return jax.vmap(
        lambda q, p: pcf.walk_local_flow(q, p, start_idx=0, metric_scale=1.0),
        in_axes=(0, 0),
    )(stacked_pos, stacked_vel)
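The same jit-plus-vmap pattern can be tried without the library. In this sketch `order_one` is a hypothetical stand-in that sorts points by position; it is not how walk_local_flow works. The check at the end confirms that the batched call matches a plain Python loop over the streams.

```python
import jax
import jax.numpy as jnp


def order_one(q, p):
    # Hypothetical stand-in for a per-stream ordering routine:
    # sort points by position along x (velocity is unused here).
    return jnp.argsort(q["x"])


@jax.jit
def order_batch(stacked_q, stacked_p):
    # in_axes=(0, 0) maps over the leading axis of every leaf in both dicts.
    return jax.vmap(order_one, in_axes=(0, 0))(stacked_q, stacked_p)


stacked_q = {"x": jnp.array([[2.0, 0.0, 1.0], [5.0, 4.0, 3.0]])}
stacked_p = {"x": jnp.ones((2, 3))}

batched = order_batch(stacked_q, stacked_p)

# Reference: the same computation, one stream at a time.
looped = jnp.stack([
    order_one({"x": stacked_q["x"][i]}, {"x": stacked_p["x"][i]})
    for i in range(2)
])
print(batched.shape, bool(jnp.array_equal(batched, looped)))
```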
Hardware Acceleration
The library runs on GPU/TPU with no code changes, provided a GPU- or TPU-enabled JAX installation is present:
import jax
import jax.numpy as jnp
import phasecurvefit as pcf
position = {"x": jnp.array([0.0, 1.0, 2.0, 3.0])}
velocity = {"x": jnp.array([1.0, 1.0, 1.0, 1.0])}
# Check available devices
devices = jax.devices()
print(f"Available devices: {devices}")
# Computation automatically runs on GPU/TPU if available
result = pcf.walk_local_flow(position, velocity, start_idx=0, metric_scale=1.0)
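To confirm where data actually lives, or to place it explicitly, the following sketch uses standard JAX calls only. Picking the first listed device as the target is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp

# Name of the default backend JAX selected at startup.
print(jax.default_backend())

x = jnp.array([0.0, 1.0, 2.0, 3.0])
print(x.devices())  # set of devices currently holding the array

# Explicit placement on a chosen device (here simply the first one listed).
target = jax.devices()[0]
x_on_target = jax.device_put(x, target)
print(x_on_target.devices())
```

Arrays created with jnp are placed on the default device, so explicit placement is mainly useful on machines with more than one accelerator.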
Debugging Tips
Disable JIT: For easier debugging, disable JIT compilation:
import jax
import phasecurvefit as pcf
# position and velocity as defined in Basic Usage
with jax.disable_jit():
    result = pcf.walk_local_flow(position, velocity, start_idx=0, metric_scale=1.0)
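Disabling JIT helps because the function then receives concrete arrays instead of abstract tracers. The sketch below demonstrates the difference with a hypothetical stand-in function, `double`, that tries to read a concrete value.

```python
import jax
import jax.numpy as jnp

seen = []  # records how the input looked on each execution of the body


@jax.jit
def double(x):
    # Hypothetical stand-in for a function being debugged.
    try:
        float(x.sum())  # needs a concrete value
        seen.append("concrete")
    except jax.errors.ConcretizationTypeError:
        seen.append("traced")
    return x * 2.0


x = jnp.array([0.0, 1.0, 2.0])
double(x)  # under JIT: x is an abstract tracer
with jax.disable_jit():
    double(x)  # eager: x is a real array, so float() works
print(seen)
```

Inside the disable_jit block, ordinary print statements and debuggers can inspect intermediate values directly.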
Check shapes: Verify array shapes in dicts:
import jax
print(jax.tree.map(lambda x: x.shape, position))
See Also
Algorithm Details - Implementation specifics
JAX Documentation - Official JAX guide
API Reference - Complete API documentation