Quickstart Guide#

Get started with phasecurvefit in 5 minutes!

Installation#

Install phasecurvefit using pip or uv:

pip install phasecurvefit
uv add phasecurvefit

Basic Usage#

1. Import the library#

import jax.numpy as jnp
import phasecurvefit as pcf

2. Prepare your phase-space data#

Phase-space data is represented as two dictionaries:

  • position: Maps coordinate names to position arrays

  • velocity: Maps coordinate names to velocity arrays

# Example: 2D stream
position = {
    "x": jnp.array([0.0, 1.0, 2.0, 3.0, 4.0]),
    "y": jnp.array([0.0, 0.5, 1.0, 1.5, 2.0]),
}

velocity = {
    "x": jnp.array([1.0, 1.0, 1.0, 1.0, 1.0]),
    "y": jnp.array([0.5, 0.5, 0.5, 0.5, 0.5]),
}

3. Run the algorithm#

result = pcf.walk_local_flow(
    position,
    velocity,
    start_idx=0,  # Start from first point
    metric_scale=1.0,  # Metric-dependent scale parameter
)

print(result.ordering)
# Array([0, 1, 2, 3, 4])

4. Extract ordered data#

Use the convenience function to get reordered arrays:

ordered_pos, ordered_vel = pcf.order_w(result)

print(ordered_pos["x"])
# Array([0., 1., 2., 3., 4.])

Understanding the Result#

The walk_local_flow function returns a WalkLocalFlowResult NamedTuple with:

  • ordering: Array of indices of the discovered order

  • position: Original position dictionary

  • velocity: Original velocity dictionary

Adjusting the Metric Scale#

The metric_scale parameter controls how the algorithm weighs different aspects of the data. Its interpretation depends on which distance metric you’re using:

# Pure nearest neighbor (spatial only) - metric_scale ignored
result_spatial = pcf.walk_local_flow(position, velocity, start_idx=0, metric_scale=0.0)

# Balanced (default)
result_balanced = pcf.walk_local_flow(position, velocity, start_idx=0, metric_scale=1.0)

# Higher metric_scale value (interpretation metric-dependent)
result_momentum = pcf.walk_local_flow(position, velocity, start_idx=0, metric_scale=5.0)

Walking in Reverse#

Use the direction parameter to trace streams backwards by negating the velocity vectors:

# Default: forward walk following the velocity direction
result_forward = pcf.walk_local_flow(position, velocity, start_idx=0, metric_scale=1.0)

# Reverse: walk against the velocity direction
result_reverse = pcf.walk_local_flow(
    position, velocity, start_idx=0, metric_scale=1.0, direction="backward"
)

This is useful for tracing stellar streams from the tidal tail back towards the progenitor.

Configuring the Query#

Use WalkConfig to configure the distance metric and query strategy:

from phasecurvefit.metrics import AlignedMomentumDistanceMetric

# Configure with aligned momentum metric and KD-tree strategy
config = pcf.WalkConfig(
    metric=AlignedMomentumDistanceMetric(),
    strategy=pcf.strats.KDTree(k=5),  # Only 5 points in the fake dataset
)

result = pcf.walk_local_flow(
    position, velocity, config=config, start_idx=0, metric_scale=1.0
)

Handling Gaps with max_dist#

Use max_dist to stop when there’s a gap in the data:

# Stop if next nearest point is more than 2 units away
result = pcf.walk_local_flow(
    position,
    velocity,
    start_idx=0,
    metric_scale=1.0,
    max_dist=2.0,
)

# Check if any points were skipped
if result.n_skipped > 0:
    print(f"Skipped {result.n_skipped} points")

Working in 3D#

The algorithm works in any number of dimensions:

# 3D helix
t = jnp.linspace(0, 4 * jnp.pi, 100)
position = {
    "x": jnp.cos(t),
    "y": jnp.sin(t),
    "z": t / (2 * jnp.pi),
}

velocity = {
    "x": -jnp.sin(t),
    "y": jnp.cos(t),
    "z": jnp.ones_like(t) / (2 * jnp.pi),
}

result = pcf.walk_local_flow(position, velocity, start_idx=0, metric_scale=2.0)

Bidirectional Walks (Forward and Reverse)#

For streams that extend in both directions from a starting point, run forward and reverse walks separately, then combine them:

# Run forward walk from starting point
result = pcf.walk_local_flow(
    position, velocity, start_idx=2, metric_scale=1.0, direction="both"
)

# Get the combined ordered indices
print(result.indices)  # Indices ordered from reverse tail through start to forward tail

# Extract the ordered positions and velocities
ordered_pos, ordered_vel = pcf.order_w(result)

This is particularly useful for:

  • Tracing complete stellar streams from a central progenitor

  • Exploring both tidal tails simultaneously

  • Verifying stream connectivity in both directions

  • Having different parameters for forward vs reverse walks

JAX Integration#

The algorithm is fully compatible with JAX transformations:

JIT Compilation#

from jax import jit


@jit
def order_stream(pos, vel):
    return pcf.walk_local_flow(pos, vel, start_idx=0, metric_scale=1.0)


result = order_stream(position, velocity)

Vectorization#

from jax import vmap

# Position and velocity for a single stream
position = {
    "x": jnp.array([0.0, 1.0, 2.0, 3.0]),
    "y": jnp.array([0.0, 0.5, 1.0, 1.5]),
}
velocity = {
    "x": jnp.array([1.0, 1.0, 1.0, 1.0]),
    "y": jnp.array([0.5, 0.5, 0.5, 0.5]),
}

# To process multiple streams, use vmap with careful handling of dictionaries
# See the JAX integration guide for detailed examples

Next Steps#