Algorithm Details#
This page explains the mathematical foundations and implementation details of the phase flow walking algorithm.
Mathematical Foundation#
Walking Algorithm (Metric-Agnostic)#
The walking algorithm is metric-agnostic. It only requires a distance (or score) function that compares the current point to candidate points. The algorithm then:
Computes distances to all unvisited points
Selects the minimum distance
Advances the walk
Stops early when termination criteria are met
The specific distance metric is pluggable and can be replaced without changing the walking logic.
Distance Metrics (Pluggable)#
The default metric is AlignedMomentumDistanceMetric, which combines spatial proximity with momentum alignment:
where:
\(d_0 = \|\mathbf{p}_{\text{candidate}} - \mathbf{p}_{\text{current}}\|\) is the Euclidean distance in position space
\(\cos\theta\) is the cosine similarity between the velocity and the direction to the candidate
\(\lambda \geq 0\) is the momentum weight parameter
Other metrics are available (e.g., spatial-only), and custom metrics can be supplied. See Metrics for the full list and extension guide.
Breaking Down the Aligned-Momentum Components#
Spatial Distance (\(d_0\))#
The Euclidean distance between two points in position space:
This measures how far apart the points are spatially.
Momentum Term (\(\lambda \cdot (1 - \cos\theta)\))#
The momentum term penalizes points that are not in the direction of the current velocity:
Unit direction vector from current to candidate: $\(\hat{\mathbf{u}}_{\text{dir}} = \frac{\mathbf{p}_{\text{candidate}} - \mathbf{p}_{\text{current}}}{d_0}\)$
Unit velocity vector: $\(\hat{\mathbf{v}} = \frac{\mathbf{v}_{\text{current}}}{\|\mathbf{v}_{\text{current}}\|}\)$
Cosine similarity: $\(\cos\theta = \hat{\mathbf{u}}_{\text{dir}} \cdot \hat{\mathbf{v}}\)$
Momentum penalty:
When \(\cos\theta = 1\) (aligned): penalty = 0
When \(\cos\theta = 0\) (perpendicular): penalty = \(\lambda\)
When \(\cos\theta = -1\) (opposite): penalty = \(2\lambda\)
Effect of \(\lambda\) in the Default Metric#
The momentum weight \(\lambda\) controls the balance between spatial and momentum considerations:
\(\lambda = 0\): Pure nearest neighbor (spatial only) $\(d = d_0\)$ Selects the closest point regardless of velocity direction.
\(\lambda \to \infty\): Pure momentum (directional only) $\(d \approx \lambda \cdot (1 - \cos\theta)\)$ Strongly favors points in the velocity direction, even if far away.
\(\lambda \approx 1\): Balanced Both spatial proximity and momentum alignment matter equally.
Physical Interpretation (Default Metric)#
For stellar streams, the momentum term encodes the fact that stars in a stream follow coherent trajectories. A star is more likely to be the next member of the stream if:
Itβs spatially close (\(d_0\) is small)
It lies in the direction the stream is flowing (\(\cos\theta \approx 1\))
The algorithm naturally traces the stream by following the βflowβ of velocities through phase-space.
Algorithm Pseudocode#
Input:
position: dict[str, Array] # N points with shape (N,) per component
velocity: dict[str, Array] # N points with shape (N,) per component
start_idx: int # starting index
metric_scale: float # scale parameter for distance metric
max_dist: float | None # optional gap detection
terminate_indices: Set[int] | None # optional termination indices
n_max: int | None # optional iteration limit
Output:
indices: tuple[int, ...]
skipped_indices: tuple[int, ...]
Procedure:
Initialize:
ordered_arr β [-1, -1, ..., -1] # Array to store ordered indices
visited_mask β [1.0, 1.0, ..., 1.0] # 0 = visited, 1 = unvisited
visited_mask[start_idx] β 0.0
ordered_arr[0] β start_idx
current_idx β start_idx
step β 1
While (step < n_max) AND (current_idx not in terminate_indices):
# Get scalar phase-space data at current index
current_pos β {key: position[key][current_idx] for key in position}
current_vel β {key: velocity[key][current_idx] for key in velocity}
# Vectorized computation to all points (vmap over array index)
distances[i] β metric(current_pos, current_vel, {key: position[key][i] for key}, {key: velocity[key][i] for key})
# Mask visited points with infinity
distances_masked[i] β visited_mask[i] > 0.5 ? distances[i] : infinity
# Find nearest unvisited neighbor
min_dist β min(distances_masked)
best_idx β argmin(distances_masked)
# Check early termination
if min_dist > max_dist:
Break # Gap detected, stop algorithm
# Update state
ordered_arr[step] β best_idx
visited_mask[best_idx] β 0.0
current_idx β best_idx
step β step + 1
# Extract valid indices
skipped β {i : visited_mask[i] > 0.5 for i in 0..N-1}
Return (ordered_arr[:step], skipped)
Gap Filling with Autoencoder#
Due to the momentum condition, the walk algorithm inevitably skips some tracers. To assign \(\gamma\) values to these skipped particles, an autoencoder neural network can interpolate based on phase-space location:
Interpolation Network: Learns \((x, v) \rightarrow (\gamma, p)\) from ordered tracers
Param-Net: Reconstructs positions from \(\gamma\) values
Momentum condition: Ensures alignment with velocity field
See Autoencoder for Gap Filling for details.
Extensions and Variants#
The current implementation supports:
Arbitrary dimensions: Works in 1D, 2D, 3D, or higher
Gap detection:
max_distparameterConditional termination:
terminate_indicesparameterLimited search:
n_maxparameterGap filling: Autoencoder neural network for skipped tracers
Reverse walks:
direction="backward"parameter to trace streams backwards by negating velocitiesBidirectional walks:
combine_results()to trace streams in both directions simultaneously
Reverse Walks#
The direction parameter enables walking through phase-space in the opposite
direction by negating the velocity vectors. This is useful for:
Tracing stellar streams backwards from the tidal tail towards the progenitor
Exploring both branches of a bifurcated stream
Testing bidirectional connectivity in phase-space
To use backward walks:
import jax.numpy as jnp
import phasecurvefit as pcf
# Create sample data
pos = {
"x": jnp.array([0.0, 1.0, 2.0, 3.0, 4.0]),
"y": jnp.array([0.0, 0.5, 1.0, 1.5, 2.0]),
}
vel = {
"x": jnp.array([1.0, 1.0, 1.0, 1.0, 1.0]),
"y": jnp.array([0.5, 0.5, 0.5, 0.5, 0.5]),
}
# Forward walk (default)
result_forward = pcf.walk_local_flow(pos, vel, start_idx=0, metric_scale=1.0)
# Reverse walk from the same starting point
result_reverse = pcf.walk_local_flow(
pos, vel, start_idx=0, metric_scale=1.0, direction="backward"
)
The negated velocities ensure that the algorithm follows the stream in the opposite direction. This is mathematically equivalent to physically reversing time in the dynamical system.
Combining Forward and Reverse Walks#
For stellar streams that extend in both directions from a starting point (e.g.,
from a progenitor or disruption point), the combine_results() function
combines the results of two separate walks into a single coherent ordering:
# Run forward and reverse walks separately
result_forward = pcf.walk_local_flow(
pos, vel, start_idx=2, metric_scale=1.0, direction="forward"
)
result_reverse = pcf.walk_local_flow(
pos, vel, start_idx=2, metric_scale=1.0, direction="backward"
)
# Combine the results
result = pcf.combine_results(result_forward, result_reverse)
# Result indices are ordered: [reverse tail] β [start] β [forward tail]
This can be simplified to:
result = pcf.walk_local_flow(pos, vel, start_idx=2, metric_scale=1.0, direction="both")
This is particularly useful for:
Complete stream tracing: Get a full ordering from one tidal tail to the other
Progenitor analysis: Start from the progenitor and trace both streams
Bifurcated streams: Explore complex geometries extending in multiple directions
Stream validation: Verify connectivity in both directions simultaneously
The combination strategy ensures spatial coherence by placing the starting point
near the center with the reverse tail on one end and the forward tail on the
other. Since you run the walks separately, you have full control over the
parameters for each direction (e.g., different max_dist thresholds).
Spatial Interpolation from Ordering Parameter#
The WalkLocalFlowResult object provides a __call__ method that enables
efficient linear interpolation of spatial positions along the walk ordering
from an ordering parameter \(\gamma \in [0, 1]\):
import jax
import jax.numpy as jnp
import phasecurvefit as pcf
# Create sample data
pos = {
"x": jnp.linspace(0, 10, 50),
"y": jnp.sin(jnp.linspace(0, 2 * jnp.pi, 50)),
}
vel = {
"x": jnp.ones(50),
"y": 2 * jnp.pi * jnp.cos(jnp.linspace(0, 2 * jnp.pi, 50)) / (2 * jnp.pi),
}
# Run walk algorithm
result = pcf.walk_local_flow(pos, vel, start_idx=0, metric_scale=1.0)
# Interpolate positions at specific gamma values
gamma_values = jnp.array([0.0, 0.25, 0.5, 0.75, 1.0])
interpolated_pos = result(gamma_values)
print("Interpolated x:", interpolated_pos["x"]) # Shape (5,)
print("Interpolated y:", interpolated_pos["y"]) # Shape (5,)
The interpolator is JAX-compatible and works with all JAX transformations:
JIT Compilation:
@jax.jit
def interpolate_position(gamma):
return result(gamma)
# Compile once, run efficiently
pos_jitted = interpolate_position(0.5)
Vectorization (vmap):
# Interpolate batch of gamma values
@jax.jit
def interpolate_batch(gamma_array):
return jax.vmap(result)(gamma_array)
gamma_batch = jnp.linspace(0, 1, 100)
pos_batch = interpolate_batch(gamma_batch)
print("Shape:", pos_batch["x"].shape) # (100,)
Automatic Differentiation (grad):
# Compute gradients with respect to gamma
def loss_fn(gamma):
pos = result(gamma)
return jnp.sum(pos["x"] ** 2 + pos["y"] ** 2)
grad_fn = jax.grad(loss_fn)
gradient = grad_fn(0.5)
Composition of Transformations:
# JIT + vmap + grad combination
@jax.jit
def compute_gradients(gamma_array):
return jax.vmap(jax.grad(loss_fn))(gamma_array)
gradients = compute_gradients(jnp.linspace(0, 1, 50))
The interpolator works by:
Mapping \(\gamma \in [0, 1]\) to indices in the visited observation sequence
Finding lower and upper neighboring visited observations
Computing linear interpolation weights
Interpolating each position component across all visited observations
Unvisited observations (marked with index -1) are automatically masked out without changing array shapes, making the interpolation efficient in JIT.
Potential future extensions:
Adaptive \(\lambda\) based on local stream properties
Multiple starting points with graph merging
Probabilistic variants for noisy data
GPU-specific optimizations for very large datasets
References#
Nibauer, J., et al. (2022). βCharting Galactic Accelerations with Stellar Streams and Machine Learning.β arXiv:2201.12042.