phasecurvefit
Get Started
phasecurvefit is a Python library for constructing a single, ordered walk through phase-space data. It was originally built for stellar stream simulations but is general-purpose and applies to any dataset where you want to order observations by proximity and momentum in phase space.
The core approach combines:
Spatial proximity: Finding nearby points in position space
Velocity momentum: Preferring points that align with the current velocity direction
This is particularly useful for data that trace coherent trajectories in phase space, such as stellar streams, but the same approach works well for many other ordered-walk problems.
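To make the two ingredients concrete, here is a minimal JAX sketch of a momentum-weighted score. It assumes a simple form: the Euclidean distance to each candidate plus a penalty `lam * (1 - cos θ)`, where θ is the angle between the current velocity and the displacement to that candidate. The function name `momentum_weighted_distance` and the weight `lam` are hypothetical illustrations. This is not phasecurvefit's API, and the library's `AlignedMomentumDistanceMetric` may use a different formula.

```python
import jax.numpy as jnp


def momentum_weighted_distance(x_cur, v_cur, x_cand, lam=1.0):
    """Hypothetical momentum-weighted score (illustrative sketch only).

    x_cur:  (D,) current position
    v_cur:  (D,) current velocity
    x_cand: (N, D) candidate positions
    lam:    weight of the alignment penalty (distance units)
    Returns an (N,) array; smaller means a better next step.
    """
    disp = x_cand - x_cur                   # displacement to each candidate
    dist = jnp.linalg.norm(disp, axis=-1)   # spatial proximity term
    v_hat = v_cur / jnp.linalg.norm(v_cur)  # unit vector along the velocity
    # Cosine of the angle between velocity and displacement (avoid 0/0)
    cos_theta = (disp @ v_hat) / jnp.where(dist > 0, dist, 1.0)
    # Candidates that point against the current velocity are penalized
    return dist + lam * (1.0 - cos_theta)


# Two candidates at equal distance: one ahead of the velocity, one behind it
x_cur = jnp.array([0.0, 0.0])
v_cur = jnp.array([1.0, 0.0])
cands = jnp.array([[1.0, 0.0], [-1.0, 0.0]])
print(momentum_weighted_distance(x_cur, v_cur, cands))  # [1. 3.]
```

Both candidates are one unit away, but the one ahead of the motion scores lower. This is the behavior that keeps a walk from doubling back along a stream.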
Installation
pip install phasecurvefit[all]
or, with uv:
uv add phasecurvefit --extra all
where "all" enables unit support (through unxt) and KD-tree support (through jaxkd).
To install the latest development version of phasecurvefit directly from the
GitHub repository, use uv:
uv add git+https://github.com/GalacticDynamics/phasecurvefit.git@main
You can customize the branch by replacing main with any other branch name.
To build phasecurvefit from source, clone the repository and install it with uv:
cd /path/to/parent
git clone https://github.com/GalacticDynamics/phasecurvefit.git
cd phasecurvefit
uv pip install -e . # editable mode
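After installing, you can check from Python whether the optional backends pulled in by the "all" extra are importable. This sketch uses only the standard library's `importlib.util.find_spec`. The import names `unxt` and `jaxkd` are taken from the extras described above, and the helper name `optional_backends` is hypothetical.

```python
import importlib.util


def optional_backends(names=("unxt", "jaxkd")):
    """Report which optional modules are importable (hypothetical helper)."""
    # find_spec returns None when a top-level module cannot be located
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, available in optional_backends().items():
        print(f"{name}: {'installed' if available else 'missing'}")
```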
Quick Example
import jax
import jax.numpy as jnp
import phasecurvefit as pcf
# Create phase-space observations as dictionaries
pos = {
    "x": jnp.array([0.0, 1.0, 2.0, 3.0, 4.0]),
    "y": jnp.array([0.0, 0.5, 1.0, 1.5, 2.0]),
}
vel = {
    "x": jnp.array([1.0, 1.0, 1.0, 1.0, 1.0]),
    "y": jnp.array([0.5, 0.5, 0.5, 0.5, 0.5]),
}
# Order observations using walk_local_flow
config = pcf.WalkConfig(strategy=pcf.strats.KDTree(k=3))
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=1.0)
# Train autoencoder for gap filling
key = jax.random.key(0)
normalizer = pcf.nn.StandardScalerNormalizer(pos, vel)
autoencoder = pcf.nn.PathAutoencoder.make(
    normalizer, gamma_range=result.gamma_range, key=key
)
train_cfg = pcf.nn.TrainingConfig(
    n_epochs_encoder=100, n_epochs_both=50, show_pbar=False
)
# Keep the walk result under its own name so result.indices below refers to the walk
trained, _, _ = pcf.nn.train_autoencoder(autoencoder, result, config=train_cfg, key=key)
print(result.indices)  # Array([0, 1, 2, 3, 4])
Features
- JAX-native: Full support for JIT compilation, vectorization, and automatic differentiation
- High performance: Optimized with jax.lax.while_loop for speed
- Gap filling: An autoencoder neural network interpolates skipped tracers
- Flexible: Works in any number of dimensions
- Type-safe: Full type annotations with jaxtyping
- Well-tested: Comprehensive test suite with property-based testing
How It Works
phasecurvefit constructs a single ordered walk through your phase-space data by repeatedly selecting the next point that is closest under the chosen distance metric. Each step involves:
Current position: Where you are in the walk
Candidate points: Remaining unvisited observations
Distance metric: A configurable function that scores proximity
Termination criteria: Optional constraints on walk length or distance thresholds
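The selection loop described above can be sketched in pure JAX with `jax.lax.while_loop`. This is an illustrative greedy nearest-neighbor walk under simplifying assumptions, not phasecurvefit's implementation. It uses plain Euclidean distance on positions stacked into an (N, D) array, with no momentum term. The names `greedy_walk`, `max_dist`, and `n_max` are hypothetical and only mirror the configuration options described on this page. Output slots that the walk never fills are set to -1.

```python
import jax
import jax.numpy as jnp


def greedy_walk(x, start_idx=0, max_dist=jnp.inf, n_max=None):
    """Illustrative greedy nearest-neighbor walk (not the library's code).

    x: (N, D) positions. Returns an (N,) int32 array of visited indices,
    padded with -1 after the walk terminates.
    """
    n = x.shape[0]
    n_max = n if n_max is None else min(n_max, n)  # cap on walk length
    order = jnp.full((n,), -1, dtype=jnp.int32).at[0].set(start_idx)
    visited = jnp.zeros((n,), dtype=bool).at[start_idx].set(True)

    def cond(state):
        _, _, _, step, done = state
        return (step < n_max) & ~done

    def body(state):
        order, visited, cur, step, _ = state
        # Distance from the current point to every point; visited ones excluded
        d = jnp.linalg.norm(x - x[cur], axis=-1)
        d = jnp.where(visited, jnp.inf, d)
        nxt = jnp.argmin(d).astype(jnp.int32)
        # Termination: nothing left, or nearest candidate farther than max_dist
        stop = ~jnp.isfinite(d[nxt]) | (d[nxt] > max_dist)
        order = jnp.where(stop, order, order.at[step].set(nxt))
        visited = jnp.where(stop, visited, visited.at[nxt].set(True))
        cur = jnp.where(stop, cur, nxt)
        return order, visited, cur, step + 1, stop

    init = (
        order,
        visited,
        jnp.array(start_idx, dtype=jnp.int32),
        jnp.array(1, dtype=jnp.int32),
        jnp.array(False),
    )
    order, *_ = jax.lax.while_loop(cond, body, init)
    return order


pos = {"x": jnp.array([0.0, 1.0, 2.0, 3.0, 4.0]),
       "y": jnp.array([0.0, 0.5, 1.0, 1.5, 2.0])}
x = jnp.stack([pos["x"], pos["y"]], axis=-1)  # (N, D) array from the dict
print(greedy_walk(x))  # [0 1 2 3 4]
```

Because the loop state has a fixed shape, the whole walk stays inside a single `while_loop` and can be traced and compiled by JAX.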
The library ships with multiple built-in metrics (e.g., momentum-weighted, spatial-only), and you can implement custom metrics for domain-specific use cases. See the Metrics Guide for full details and examples.
For the mathematical background on momentum-weighted ordering, refer to the NN+p paper.
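As a rough illustration of how metrics can rank the same candidates differently, the sketch below writes two of them as plain JAX functions. Their forms are assumptions. The spatial-only metric is taken as the Euclidean distance in position. The full phase-space metric converts velocity differences to distance units by multiplying by a time scale `tau`, which is the role `metric_scale` is described as playing for `FullPhaseSpaceDistanceMetric`, and then combines the two in quadrature. The function names are hypothetical and do not reproduce phasecurvefit's metric classes.

```python
import jax.numpy as jnp


def spatial_distance(x_cur, x_cand):
    """Hypothetical spatial-only metric: Euclidean distance in position."""
    return jnp.linalg.norm(x_cand - x_cur, axis=-1)


def full_phase_space_distance(x_cur, v_cur, x_cand, v_cand, tau=1.0):
    """Hypothetical full phase-space metric (assumed form).

    tau is a time scale: tau * velocity has units of distance, so position
    and velocity differences can be added in quadrature.
    """
    dx2 = jnp.sum((x_cand - x_cur) ** 2, axis=-1)
    dv2 = jnp.sum((v_cand - v_cur) ** 2, axis=-1)
    return jnp.sqrt(dx2 + tau**2 * dv2)


# Candidate A is closer in position but moves differently; B matches velocity
x0, v0 = jnp.array([0.0, 0.0]), jnp.array([1.0, 0.0])
xs = jnp.array([[3.0, 0.0], [0.0, 4.0]])
vs = jnp.array([[1.0, 4.0], [1.0, 0.0]])
print(spatial_distance(x0, xs))                           # [3. 4.]
print(full_phase_space_distance(x0, v0, xs, vs, tau=1.0))  # [5. 4.]
```

With `tau = 0` the full phase-space form reduces to the spatial one. Larger `tau` gives velocity agreement more weight.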
Configuration Options
- metric: Distance metric to use (default: AlignedMomentumDistanceMetric). Determines how "closeness" is computed. See the Metrics Guide.
- metric_scale: Scale parameter for distance metrics. Its interpretation depends on the metric:
  - AlignedMomentumDistanceMetric: momentum weight (distance units)
  - FullPhaseSpaceDistanceMetric: time scale for velocity-to-position conversion
  - SpatialDistanceMetric: unused (can be any value)
- max_dist: Maximum allowed distance to the next point. The walk stops if no unvisited point is closer than this.
- n_max: Maximum number of points to include in the walk (caps walk length).
- start_idx: Starting index in the data (default: 0).
- terminate_indices: Set of indices where the walk should stop.
- strategy: Neighbor query strategy instance. Options:
  - BruteForce() (default): compute distances to all points
  - KDTree(k=...): spatial KD-tree prefiltering, then metric selection
    - Install the optional dependency: uv add phasecurvefit[kdtree]
    - Uses jaxkd
Example using KD-tree (requires jaxkd):
import phasecurvefit as pcf
# pos and vel are the position and velocity dictionaries from the Quick Example
config = pcf.WalkConfig(strategy=pcf.strats.KDTree(k=2))
result = pcf.walk_local_flow(pos, vel, config=config)
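The KDTree(k=...) strategy described above works in two stages: it narrows the search spatially, then applies the metric only to the surviving candidates. The sketch below mimics that two-stage idea in pure JAX. It does not use jaxkd. Brute-force `jax.lax.top_k` stands in for the KD-tree query, and the second stage uses an assumed momentum-weighted score (distance plus `lam * (1 - cos θ)`). The function `prefilter_then_select` and its parameters are hypothetical, not part of phasecurvefit.

```python
import jax
import jax.numpy as jnp


def prefilter_then_select(x, v, cur, visited, k, lam=1.0):
    """Illustrative two-stage neighbor selection (not the library's code).

    Stage 1 keeps the k spatially nearest unvisited points (a KD-tree would
    answer this query faster; brute-force top_k stands in for it here).
    Stage 2 rescores only those k candidates with a momentum-aware score.
    k must be a static Python int.
    """
    # Stage 1: spatial distances, with already-visited points excluded
    d = jnp.linalg.norm(x - x[cur], axis=-1)
    d = jnp.where(visited, jnp.inf, d)
    # top_k returns the k largest values, so negate to get the k smallest
    _, cand = jax.lax.top_k(-d, k)
    # Stage 2: assumed score = distance + lam * (1 - cos(angle to velocity))
    disp = x[cand] - x[cur]
    dist = d[cand]
    v_hat = v[cur] / jnp.linalg.norm(v[cur])
    cos_theta = (disp @ v_hat) / jnp.where(dist > 0, dist, 1.0)
    score = dist + lam * (1.0 - cos_theta)
    return cand[jnp.argmin(score)]


# Point 2 is spatially closest to point 0 but lies behind the velocity vector
x = jnp.array([[0.0, 0.0], [1.0, 0.0], [-0.9, 0.0], [2.0, 0.0], [5.0, 0.0]])
v = jnp.tile(jnp.array([1.0, 0.0]), (5, 1))
visited = jnp.zeros(5, dtype=bool).at[0].set(True)
print(prefilter_then_select(x, v, 0, visited, k=2))           # 1
print(prefilter_then_select(x, v, 0, visited, k=2, lam=0.0))  # 2
```

A small `k` keeps the metric evaluation cheap. The trade-off is that a point outside the k spatial neighbors can never be chosen, even if the metric would have preferred it.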
Data Format
Phase-space data uses raw Python dictionaries for maximum performance and JAX compatibility:
import jax.numpy as jnp
# Position dictionary: coordinate names → arrays
position = {
    "x": jnp.array([0.0, 1.0, 2.0]),
    "y": jnp.array([0.0, 0.5, 1.0]),
    "z": jnp.array([0.0, 0.1, 0.2]),
}
# Velocity dictionary: same keys → velocity components
velocity = {
    "x": jnp.array([1.0, 1.0, 1.0]),
    "y": jnp.array([0.5, 0.5, 0.5]),
    "z": jnp.array([0.0, 0.0, 0.0]),
}
This dict-based API is designed for:
- Efficient JAX tree operations via jax.tree.map
- Seamless integration with JAX transformations (jit, vmap, grad)
- Minimal overhead in hot loops
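Because dictionaries are JAX pytrees, position and velocity dicts can be passed straight through JAX transformations. The sketch below is generic JAX and independent of phasecurvefit. It uses `jax.tree.map` inside a jitted function to advance every coordinate by one time step. It also differentiates a scalar function of the velocity dict with `jax.grad`. The names `drift` and `mean_speed_sq` and the step size are illustrative.

```python
import jax
import jax.numpy as jnp

position = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.0, 0.5, 1.0])}
velocity = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([0.5, 0.5, 0.5])}


@jax.jit
def drift(pos, vel, dt):
    # jax.tree.map pairs leaves by key: pos["x"] with vel["x"], and so on
    return jax.tree.map(lambda q, p: q + dt * p, pos, vel)


def mean_speed_sq(vel):
    # Stack components into an (N, D) array, then average |v|^2 over points
    v = jnp.stack([vel[k] for k in sorted(vel)], axis=-1)
    return jnp.mean(jnp.sum(v**2, axis=-1))


new_pos = drift(position, velocity, 0.5)
grads = jax.grad(mean_speed_sq)(velocity)  # same dict structure as velocity
print(new_pos["x"])   # [0.5 1.5 2.5]
print(sorted(grads))  # ['x', 'y']
```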
Citation
The core algorithm originates from Nibauer et al. (2022). If you use momentum-weighted ordering or reference the original work in your research, please cite:
@article{nibauer2022charting,
title={Charting Galactic Accelerations with Stellar Streams and Machine Learning},
author={Nibauer, Jacob and others},
journal={arXiv preprint arXiv:2201.12042},
year={2022}
}
If you use phasecurvefit with custom metrics or for general phase-space ordering, please cite this package directly (check the GitHub repository for the latest citation format).
Next Steps
Get up and running in minutes
Explore built-in and custom metrics
Interpolate skipped observations
Full API documentation
Optimize with JIT, vmap, and grad
Interactive tutorials with Jupyter notebooks