Distance Metrics Guide#
The walk_local_flow algorithm uses distance metrics to determine how to select
the next point in a phase-space trajectory. This guide explains the metric
system and shows how to use and create custom metrics.
Overview#
A distance metric defines how the algorithm measures “closeness” between the current point and candidate next points in phase-space. Different metrics enable different physical interpretations and behaviors.
Metrics are configured via WalkConfig, which composes a metric with a query
strategy (discussed in a separate guide):
import jax.numpy as jnp
import phasecurvefit as pcf
from phasecurvefit.metrics import FullPhaseSpaceDistanceMetric
pos = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.0, 0.5, 1.0])}
vel = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([0.5, 0.5, 0.5])}
config = pcf.WalkConfig(metric=FullPhaseSpaceDistanceMetric())
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=1.0)
Built-in Metrics#
SpatialDistanceMetric#
A position-only metric that computes pure Euclidean distance, completely ignoring velocity information.
Mathematical formulation:
where \(d_0\) is the Euclidean distance between positions. The metric_scale parameter is ignored.
When to use:
Velocity information is unreliable or unavailable
Pure spatial proximity is desired (e.g., spatial clustering)
Comparing against baseline nearest-neighbor approaches
Setting
metric_scale=0withAlignedMomentumDistanceMetricis equivalent, but this metric is more explicit
Usage:
from phasecurvefit.metrics import SpatialDistanceMetric
# Pure nearest-neighbor search in position space
config = pcf.WalkConfig(metric=SpatialDistanceMetric())
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=0.0)
AlignedMomentumDistanceMetric#
The Nearest Neighbors with Momentum (NN+p) metric from Nibauer et al. (2022). This is the default metric.
Mathematical formulation:
where:
\(d_0\) is the Euclidean distance between positions
\(\theta\) is the angle between the current velocity and the direction to the candidate point
\(\lambda\) is the momentum weight parameter
Physical interpretation:
This metric combines spatial proximity with velocity alignment. Points that lie along the current velocity direction receive lower penalties, making the algorithm favor coherent flows in phase-space.
Usage:
import jax.numpy as jnp
import phasecurvefit as pcf
from phasecurvefit.metrics import AlignedMomentumDistanceMetric
pos = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.0, 0.5, 1.0])}
vel = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([0.5, 0.5, 0.5])}
# Aligned momentum metric
config = pcf.WalkConfig(metric=AlignedMomentumDistanceMetric())
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=1.0)
FullPhaseSpaceDistanceMetric#
A true 6D Euclidean distance metric in full phase-space, treating position and velocity symmetrically. This is the default metric.
Mathematical formulation:
where:
\(d_0\) is the Euclidean distance in position space
\(d_v\) is the Euclidean distance in velocity space
\(\tau\) is the time parameter (
metric_scale) that converts velocity differences to position units
Physical interpretation:
This metric computes true Euclidean distance in the 6-dimensional phase space by combining position and velocity differences. The parameter metric_scale (with time units) determines the relative weighting: for example, if positions are measured in kpc and velocities in kpc/Myr, then metric_scale in Myr converts velocity differences to kpc, creating a uniformly scaled phase space.
Unlike AlignedMomentumDistanceMetric, this metric has no directional bias from momentum alignment — it treats all directions in phase space equally.
When to use:
Position and velocity information are equally important
You want true 6D proximity without momentum direction bias
The natural time scale of the system is known
Comparing against full phase-space clustering methods
Usage:
from phasecurvefit.metrics import FullPhaseSpaceDistanceMetric
# Full 6D phase-space distance (this is the default)
# metric_scale represents a time scale (e.g., if pos ~ kpc, vel ~ kpc/Myr, metric_scale ~ Myr)
config = pcf.WalkConfig(metric=FullPhaseSpaceDistanceMetric())
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=1.0)
Comparison with momentum metric:
AlignedMomentumDistanceMetric: Directional — favors points along velocity directionFullPhaseSpaceDistanceMetric: Isotropic — treats all directions equallyBoth reduce to
SpatialDistanceMetricwhenmetric_scale=0
Creating Custom Metrics#
Custom metrics enable alternative distance calculations for specific use cases. For example, you might want:
Full 6D Cartesian distance in phase-space
Weighted combinations of position and velocity
Problem-specific distance measures
The AbstractDistanceMetric Interface#
All metrics must inherit from AbstractDistanceMetric and implement the __call__ method:
import equinox as eqx
from phasecurvefit.metrics import AbstractDistanceMetric
class CustomMetric(AbstractDistanceMetric):
"""Your custom distance metric."""
def __call__(self, current_pos, current_vel, positions, velocities, metric_scale):
"""Compute modified distances."""
# Your distance calculation here
...
Example: 6D Cartesian Metric#
Here’s a complete example of a metric that computes full 6D Cartesian distance:
import equinox as eqx
import jax
import jax.numpy as jnp
from phasecurvefit.metrics import AbstractDistanceMetric
class Full6DMetric(AbstractDistanceMetric):
"""6D Cartesian distance in phase-space.
Treats position and velocity on equal footing, with `metric_scale` serving as
a velocity-to-position scaling factor (units of time).
Distance formula:
d = sqrt(|Δr|² + (λ|Δv|)²)
where Δr is position difference and Δv is velocity difference.
"""
def __call__(self, current_pos, current_vel, positions, velocities, metric_scale):
# Compute position differences (vmap over N points)
pos_diff = jax.tree.map(jnp.subtract, positions, current_pos)
# Sum of squared position differences
pos_dist_sq = sum(jax.tree.leaves(jax.tree.map(jnp.square, pos_diff)))
# Compute velocity differences (vmap over N points)
vel_diff = jax.tree.map(jnp.subtract, velocities, current_vel)
# Sum of squared velocity differences, weighted by metric_scale^2
vel_dist_sq = sum(jax.tree.leaves(jax.tree.map(jnp.square, vel_diff)))
# Combined 6D distance
return jnp.sqrt(pos_dist_sq + (metric_scale**2) * vel_dist_sq)
# Use the custom metric via WalkConfig
pos = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.0, 0.5, 1.0])}
vel = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([0.5, 0.5, 0.5])}
config = pcf.WalkConfig(metric=Full6DMetric())
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=1.0)
Example: Weighted Position Metric#
A metric that ignores velocity entirely and uses weighted position coordinates:
class WeightedPositionMetric(AbstractDistanceMetric):
"""Position-only metric with per-component weights."""
weights: dict[str, float] = eqx.field(static=True)
def __call__(self, current_pos, current_vel, positions, velocities, metric_scale):
# Compute weighted position differences
def weighted_diff_sq(component_name, positions_component):
diff = positions_component - current_pos[component_name]
weight = self.weights.get(component_name, 1.0)
return weight * diff**2
# Sum over all components
weighted_dist_sq = sum(weighted_diff_sq(k, v) for k, v in positions.items())
return jnp.sqrt(weighted_dist_sq)
# Use with custom weights (ignore y-coordinate)
metric = WeightedPositionMetric(weights={"x": 1.0, "y": 0.1})
config = pcf.WalkConfig(metric=metric)
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=0.0)
Units and Metrics#
When using physical units via unxt, ensure your metric correctly handles unit propagation:
import unxt as u
from phasecurvefit.metrics import (
AlignedMomentumDistanceMetric,
FullPhaseSpaceDistanceMetric,
)
# Position in kpc, velocity in km/s
pos = {"x": u.Q([0.0, 1.0, 2.0], "kpc"), "y": u.Q([0.0, 0.5, 1.0], "kpc")}
vel = {"x": u.Q([1.0, 1.0, 1.0], "km/s"), "y": u.Q([0.5, 0.5, 0.5], "km/s")}
# metric_scale must have units of distance for AlignedMomentumDistanceMetric
config = pcf.WalkConfig(metric=AlignedMomentumDistanceMetric())
result = pcf.walk_local_flow(
pos,
vel,
config=config,
start_idx=0,
metric_scale=u.Q(100.0, "kpc"), # Momentum weight in distance units
usys=u.unitsystems.galactic, # Required when using Quantities
)
# For FullPhaseSpaceDistanceMetric, metric_scale has units of time
config_6d = pcf.WalkConfig(metric=FullPhaseSpaceDistanceMetric())
result_6d = pcf.walk_local_flow(
pos,
vel,
config=config_6d,
start_idx=0,
metric_scale=u.Q(
1.0, "Gyr"
), # Time to convert velocity distance to spatial distance
usys=u.unitsystems.galactic, # Required when using Quantities
)
Metric Comparison#
Metric |
Position |
Velocity |
Lambda Meaning |
|---|---|---|---|
|
✓ |
✗ |
Ignored |
|
✓ |
✓ (alignment) |
Momentum penalty weight |
|
✓ |
✓ (magnitude) |
Time scale (velocity → position units) |
When to use each:
FullPhaseSpaceDistanceMetric (default): True 6D distance when position and velocity are equally important and you know the system’s natural time scale. No directional preference.
AlignedMomentumDistanceMetric: For coherent flows (stellar streams, winds) where velocity alignment should bias the ordering.
SpatialDistanceMetric: When velocity is unreliable or you want pure spatial clustering. Good baseline for comparison.
Metric Comparison Example#
Here’s a comparison of different metrics on the same data:
import jax.numpy as jnp
import phasecurvefit as pcf
from phasecurvefit.metrics import (
AlignedMomentumDistanceMetric,
SpatialDistanceMetric,
)
# Sample spiral trajectory
theta = jnp.linspace(0, 4 * jnp.pi, 100)
pos = {
"x": jnp.cos(theta) * jnp.exp(theta / 10),
"y": jnp.sin(theta) * jnp.exp(theta / 10),
}
vel = {
"x": jnp.gradient(pos["x"]),
"y": jnp.gradient(pos["y"]),
}
# Compare metrics
metrics = {
"Momentum": AlignedMomentumDistanceMetric(),
"Spatial": SpatialDistanceMetric(),
}
for name, metric in metrics.items():
config = pcf.WalkConfig(metric=metric)
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=1.0)
n_visited = len([i for i in result.indices if i >= 0])
print(f"{name}: {n_visited}/100 points ordered")
See Also#
Algorithm Guide - Core algorithm details
API Reference - Complete API documentation