phasecurvefit

🚀 Get Started

phasecurvefit is a Python library for constructing a single, ordered walk through phase-space data. It was originally built for stellar stream simulations but is general-purpose and applies to any dataset where you want to order observations by proximity and momentum in phase-space.

The core approach combines:

  • Spatial proximity: Finding nearby points in position space

  • Velocity momentum: Preferring points that align with the current velocity direction

This is particularly useful for coherent trajectories in phase-space, such as stellar streams, but works well for many other ordered-walk problems.


Installation

pip install "phasecurvefit[all]"

or, with uv:

uv add phasecurvefit --extra all

In both cases, the “all” extra enables unit support (through unxt) and KD-tree support (through jaxkd).

To install the latest development version of phasecurvefit directly from the GitHub repository, use uv:

uv add git+https://github.com/GalacticDynamics/phasecurvefit.git@main

You can customize the branch by replacing main with any other branch name.

To build phasecurvefit from source, clone the repository and install it with uv:

cd /path/to/parent
git clone https://github.com/GalacticDynamics/phasecurvefit.git
cd phasecurvefit
uv pip install -e .  # editable mode
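After installing, you can confirm which optional extras are available by probing for their modules with the standard library. This is a minimal sketch, not a phasecurvefit utility; the import names unxt and jaxkd are assumed to match the package names listed above.

```python
import importlib.util

# Import names assumed from the extras described above:
# unxt (unit support) and jaxkd (KD-tree neighbor queries).
OPTIONAL_MODULES = ("unxt", "jaxkd")


def optional_extras_status(modules=OPTIONAL_MODULES):
    """Map each module name to whether it can be imported in this environment."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


for name, available in optional_extras_status().items():
    print(f"{name}: {'installed' if available else 'missing'}")
```

If jaxkd is reported missing, the KDTree strategy will not be usable, but the default BruteForce strategy does not need it.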

Quick Example

import jax
import jax.numpy as jnp
import phasecurvefit as pcf

# Create phase-space observations as dictionaries
pos = {
    "x": jnp.array([0.0, 1.0, 2.0, 3.0, 4.0]),
    "y": jnp.array([0.0, 0.5, 1.0, 1.5, 2.0]),
}
vel = {
    "x": jnp.array([1.0, 1.0, 1.0, 1.0, 1.0]),
    "y": jnp.array([0.5, 0.5, 0.5, 0.5, 0.5]),
}

# Order observations with walk_local_flow (the KDTree strategy requires jaxkd)
config = pcf.WalkConfig(strategy=pcf.strats.KDTree(k=3))
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=1.0)
print(result.indices)  # Array([0, 1, 2, 3, 4])

# Train an autoencoder to fill gaps left by skipped observations.
# Split the PRNG key so model initialization and training use independent keys.
model_key, train_key = jax.random.split(jax.random.key(0))
normalizer = pcf.nn.StandardScalerNormalizer(pos, vel)
autoencoder = pcf.nn.PathAutoencoder.make(
    normalizer, gamma_range=result.gamma_range, key=model_key
)

train_cfg = pcf.nn.TrainingConfig(
    n_epochs_encoder=100, n_epochs_both=50, show_pbar=False
)
result, _, _ = pcf.nn.train_autoencoder(autoencoder, result, config=train_cfg, key=train_key)

Features

  • ✅ JAX-native: Full support for JIT compilation, vectorization, and auto-differentiation

  • ✅ High performance: Optimized with jax.lax.while_loop for speed

  • ✅ Gap filling: Autoencoder neural network interpolates skipped tracers

  • ✅ Flexible: Works in any number of dimensions

  • ✅ Type-safe: Full type annotations with jaxtyping

  • ✅ Well-tested: Comprehensive test suite with property-based testing

How It Works

phasecurvefit constructs a single ordered walk through your phase-space data by iteratively selecting the closest next point under the chosen distance metric, based on:

  1. Current position: Where you are in the walk

  2. Candidate points: Remaining unvisited observations

  3. Distance metric: A configurable function that scores proximity

  4. Termination criteria: Optional constraints on walk length or distance thresholds
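The loop described by these steps can be sketched in a few lines of eager JAX. This is a simplified illustration, not phasecurvefit's implementation (which, per the Features list, uses jax.lax.while_loop); the function greedy_walk is hypothetical and uses a spatial-only Euclidean metric in place of a configurable one.

```python
import jax.numpy as jnp


def greedy_walk(pos, start_idx=0, max_dist=jnp.inf, n_max=None):
    """Greedy nearest-unvisited-neighbor walk over an (N, D) position array.

    Hypothetical, eager sketch using a spatial-only metric. Stops when n_max
    points are in the walk or the closest unvisited point is beyond max_dist.
    """
    n = pos.shape[0]
    n_max = n if n_max is None else min(n_max, n)
    visited = jnp.zeros(n, dtype=bool).at[start_idx].set(True)
    order = [start_idx]
    current = start_idx
    while len(order) < n_max:
        # Distance metric: Euclidean distance from the current point to all points
        dist = jnp.linalg.norm(pos - pos[current], axis=-1)
        # Candidate points: mask out everything already visited
        dist = jnp.where(visited, jnp.inf, dist)
        nxt = int(jnp.argmin(dist))
        # Termination criterion: nothing unvisited is close enough
        if float(dist[nxt]) > max_dist:
            break
        visited = visited.at[nxt].set(True)
        order.append(nxt)
        current = nxt
    return order


pos = jnp.array([[0.0, 0.0], [2.0, 1.0], [1.0, 0.5], [3.0, 1.5]])
print(greedy_walk(pos))                # [0, 2, 1, 3]
print(greedy_walk(pos, max_dist=1.0))  # [0]
```

Note that the walk order follows proximity rather than array order, and a tight max_dist ends the walk as soon as the next step is too long.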

The library ships with multiple built-in metrics (e.g., momentum-weighted, spatial-only), and you can implement custom metrics for domain-specific use cases. See the Metrics Guide for full details and examples.

For the mathematical background on momentum-weighted ordering, refer to the NN+p paper (Nibauer et al. 2022; see Citation below).
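To make the momentum-weighted idea concrete, here is one plausible form of an NN+p-style score, assumed for illustration rather than taken from phasecurvefit's AlignedMomentumDistanceMetric: spatial distance plus a penalty lam * (1 - cos θ) for candidates whose displacement points away from the current velocity. Consult the paper and the Metrics Guide for the exact definition; aligned_momentum_score is a hypothetical name.

```python
import jax.numpy as jnp


def aligned_momentum_score(pos, vel, current, lam):
    """Hypothetical NN+p-style score from point `current` to every point.

    score_j = |x_j - x_i| + lam * (1 - cos(theta_j)), where theta_j is the angle
    between the velocity at x_i and the displacement x_j - x_i. lam is a
    momentum weight in distance units; lam = 0 reduces to spatial distance.
    """
    disp = pos - pos[current]
    dist = jnp.linalg.norm(disp, axis=-1)
    v_hat = vel[current] / jnp.linalg.norm(vel[current])
    # Avoid 0/0 for the current point itself (zero displacement).
    cos_theta = (disp @ v_hat) / jnp.where(dist > 0, dist, 1.0)
    return dist + lam * (1.0 - cos_theta)


# Current point at the origin moving in +x; one neighbor ahead, a closer one behind.
pos = jnp.array([[0.0, 0.0], [1.0, 0.0], [-0.9, 0.0]])
vel = jnp.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])

for lam in (0.0, 1.0):
    score = aligned_momentum_score(pos, vel, 0, lam).at[0].set(jnp.inf)  # skip self
    print(lam, int(jnp.argmin(score)))  # 0.0 -> 2 (closest), 1.0 -> 1 (aligned)
```

With no momentum weight the point behind wins because it is closer; with lam = 1 the penalty for moving against the velocity makes the point ahead the better step.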

Configuration Options

  • metric: Distance metric to use (default: AlignedMomentumDistanceMetric). Determines how “closeness” is computed. See the Metrics Guide.

  • metric_scale: Scale parameter for distance metrics. Interpretation depends on the metric:

    • AlignedMomentumDistanceMetric: momentum weight (distance units)

    • FullPhaseSpaceDistanceMetric: time scale for velocity-to-position conversion

    • SpatialDistanceMetric: unused (can be any value)

  • max_dist: Maximum allowed distance to the next point. The walk stops when no unvisited point lies within this distance.

  • n_max: Maximum number of points to include in the walk (caps walk length).

  • start_idx: Starting index in the data (default: 0).

  • terminate_indices: Set of indices where the walk should stop.

  • strategy: Neighbor query strategy instance. Options:

    • BruteForce() (default): compute distances to all points

    • KDTree(k=...): spatial KD-tree prefiltering, then metric selection. Uses jaxkd; install the optional dependency with uv add phasecurvefit[kdtree]
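The metric_scale entry for FullPhaseSpaceDistanceMetric describes a time scale that converts velocities into position units. A minimal sketch of such a metric, assuming position and velocity differences are combined in quadrature (an assumption, not the documented formula; full_phase_space_distance is a hypothetical name):

```python
import jax.numpy as jnp


def full_phase_space_distance(pos, vel, current, tau):
    """Hypothetical full phase-space distance from point `current` to every point.

    d_j = sqrt(|x_j - x_i|**2 + (tau * |v_j - v_i|)**2)
    tau is a time scale that turns velocity differences into position units.
    """
    dx2 = jnp.sum((pos - pos[current]) ** 2, axis=-1)
    dv2 = jnp.sum((vel - vel[current]) ** 2, axis=-1)
    return jnp.sqrt(dx2 + tau**2 * dv2)


pos = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.5]])
vel = jnp.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0]])

for tau in (0.0, 1.0):
    d = full_phase_space_distance(pos, vel, 0, tau).at[0].set(jnp.inf)  # skip self
    print(tau, int(jnp.argmin(d)))  # 0.0 -> 1 (spatially closest), 1.0 -> 2 (similar velocity)
```

A larger time scale gives velocity differences more weight, so a slightly farther point moving the same way can beat a closer point moving the opposite way.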

Example using KD-tree (requires jaxkd):

import phasecurvefit as pcf

# pos and vel are the position/velocity dictionaries from the Quick Example
config = pcf.WalkConfig(strategy=pcf.strats.KDTree(k=2))
result = pcf.walk_local_flow(pos, vel, config=config)
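The KDTree(k=...) strategy is described as spatial prefiltering followed by metric selection. The sketch below illustrates that two-stage idea and its main trade-off: if k is too small, the candidate the metric would prefer can be filtered out before it is scored. It is not how jaxkd or phasecurvefit implement the query; a brute-force jax.lax.top_k stands in for the tree, and prefiltered_next and its momentum-weighted score are assumptions.

```python
import jax
import jax.numpy as jnp


def prefiltered_next(pos, vel, current, visited, k, lam):
    """Choose the next point among the k spatially nearest unvisited candidates.

    Hypothetical sketch: a brute-force jax.lax.top_k stands in for the KD-tree
    query, then an assumed momentum-weighted score picks among the k candidates.
    """
    disp = pos - pos[current]
    dist = jnp.where(visited, jnp.inf, jnp.linalg.norm(disp, axis=-1))
    neg_dist, cand = jax.lax.top_k(-dist, k)  # k smallest spatial distances
    cand_dist = -neg_dist
    v_hat = vel[current] / jnp.linalg.norm(vel[current])
    cos_theta = (disp[cand] @ v_hat) / jnp.maximum(cand_dist, 1e-12)
    score = cand_dist + lam * (1.0 - cos_theta)
    return int(cand[jnp.argmin(score)])


# Origin moving in +x: two close points behind, one farther point ahead.
pos = jnp.array([[0.0, 0.0], [-0.5, 0.0], [-0.6, 0.0], [1.0, 0.0]])
vel = jnp.tile(jnp.array([1.0, 0.0]), (4, 1))
visited = jnp.array([True, False, False, False])

print(prefiltered_next(pos, vel, 0, visited, k=2, lam=1.0))  # 1: point ahead was filtered out
print(prefiltered_next(pos, vel, 0, visited, k=3, lam=1.0))  # 3: point ahead is a candidate
```

Larger k makes the prefilter less likely to discard the metric's preferred point, at the cost of scoring more candidates per step.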

Data Format

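Phase-space data is passed as plain Python dictionaries that map coordinate names to arrays, with the same keys for positions and velocities. This keeps overhead low and works directly with JAX: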

import jax.numpy as jnp

# Position dictionary: coordinate names → arrays
position = {
    "x": jnp.array([0.0, 1.0, 2.0]),
    "y": jnp.array([0.0, 0.5, 1.0]),
    "z": jnp.array([0.0, 0.1, 0.2]),
}

# Velocity dictionary: same keys → velocity components
velocity = {
    "x": jnp.array([1.0, 1.0, 1.0]),
    "y": jnp.array([0.5, 0.5, 0.5]),
    "z": jnp.array([0.0, 0.0, 0.0]),
}
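Because the two dictionaries must share keys and every component must have the same length, a small validation step can catch mistakes early, and stacking components into an (N, D) array is handy when interfacing with array-based code. Both helpers below are hypothetical, not part of phasecurvefit.

```python
import jax
import jax.numpy as jnp


def check_phase_space(position, velocity):
    """Hypothetical validator: matching keys and equal-length 1-D components."""
    if set(position) != set(velocity):
        raise ValueError(f"key mismatch: {sorted(position)} vs {sorted(velocity)}")
    shapes = {leaf.shape for leaf in jax.tree.leaves((position, velocity))}
    if len(shapes) != 1 or len(next(iter(shapes))) != 1:
        raise ValueError(f"expected equal-length 1-D arrays, got shapes {shapes}")
    return next(iter(shapes))[0]  # number of observations N


def to_array(components, keys):
    """Stack dict components into an (N, D) array, columns ordered by `keys`."""
    return jnp.stack([components[k] for k in keys], axis=-1)


position = {
    "x": jnp.array([0.0, 1.0, 2.0]),
    "y": jnp.array([0.0, 0.5, 1.0]),
    "z": jnp.array([0.0, 0.1, 0.2]),
}
velocity = {
    "x": jnp.array([1.0, 1.0, 1.0]),
    "y": jnp.array([0.5, 0.5, 0.5]),
    "z": jnp.array([0.0, 0.0, 0.0]),
}

n = check_phase_space(position, velocity)           # 3 observations
pos_arr = to_array(position, keys=("x", "y", "z"))  # shape (3, 3)
```

Passing an explicit key order to to_array keeps columns consistent between positions and velocities regardless of how each dictionary was built.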

This dict-based API is designed for:

  • Efficient JAX tree operations via jax.tree.map

  • Seamless integration with JAX transformations (jit, vmap, grad)

  • Minimal overhead in hot loops
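A toy example of these properties: a hypothetical drift function (not part of phasecurvefit) maps over matching keys with jax.tree.map and is then compiled with jit, batched with vmap, and differentiated with grad, all while keeping the dictionary structure.

```python
import jax
import jax.numpy as jnp

position = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.0, 0.5, 1.0])}
velocity = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([0.5, 0.5, 0.5])}


@jax.jit
def drift(pos, vel, dt):
    """Move each point along its velocity for time dt; tree.map pairs matching keys."""
    return jax.tree.map(lambda x, v: x + v * dt, pos, vel)


def total_x_after(dt):
    """Scalar summary of the drifted x coordinates, used to demonstrate grad."""
    return jnp.sum(drift(position, velocity, dt)["x"])


moved = drift(position, velocity, 0.5)  # same dict structure as the inputs
# vmap over a batch of time steps; every output leaf gains a leading axis of size 3
batch = jax.vmap(drift, in_axes=(None, None, 0))(position, velocity, jnp.array([0.0, 1.0, 2.0]))
# d/d(dt) of sum(x + vx * dt) is sum(vx) = 3.0
slope = jax.grad(total_x_after)(1.0)
```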

Citation

The core algorithm originates from Nibauer et al. (2022). If you use momentum-weighted ordering or reference the original work in your research, please cite:

@article{nibauer2022charting,
  title={Charting Galactic Accelerations with Stellar Streams and Machine Learning},
  author={Nibauer, Jacob and others},
  journal={arXiv preprint arXiv:2201.12042},
  year={2022}
}

If you use phasecurvefit with custom metrics or for general phase-space ordering, please cite this package directly (check the GitHub repository for the latest citation format).

Next Steps

  • Quickstart (Quickstart Guide): Get up and running in minutes

  • Distance Metrics (Distance Metrics Guide): Explore built-in and custom metrics

  • Neural Network Gap Filling (Autoencoder for Gap Filling): Interpolate skipped observations

  • API Reference: Full API documentation

  • JAX Integration: Optimize with JIT, vmap, and grad

  • Examples (Tutorials): Interactive tutorials with Jupyter notebooks
