Distance Metrics Guide#

The walk_local_flow algorithm uses a distance metric to select the next point in a phase-space trajectory. This guide explains the metric system and shows how to use the built-in metrics and create custom ones.

Overview#

A distance metric defines how the algorithm measures “closeness” between the current point and candidate next points in phase-space. Different metrics enable different physical interpretations and behaviors.
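As a rough illustration of where the metric enters, the sketch below orders points by repeatedly stepping to the closest unvisited candidate. It assumes a greedy selection rule and plain Euclidean distance; `greedy_order` is a hypothetical helper written for this guide, not part of phasecurvefit, and the library's actual walk may differ in its details.

```python
import jax.numpy as jnp


def greedy_order(points, start_idx):
    """Order points by repeatedly stepping to the closest unvisited point.

    Hypothetical helper for illustration only; uses plain Euclidean distance.
    """
    n = points.shape[0]
    visited = jnp.zeros(n, dtype=bool).at[start_idx].set(True)
    order = [start_idx]
    current = start_idx
    for _ in range(n - 1):
        # Distance from the current point to every point
        d = jnp.linalg.norm(points - points[current], axis=-1)
        # Exclude points that are already part of the trajectory
        d = jnp.where(visited, jnp.inf, d)
        current = int(jnp.argmin(d))
        visited = visited.at[current].set(True)
        order.append(current)
    return order


points = jnp.array([[0.0, 0.0], [2.0, 1.0], [1.0, 0.5], [3.0, 1.5]])
order = greedy_order(points, start_idx=0)
print(order)  # [0, 2, 1, 3]
```

Swapping the `jnp.linalg.norm` line for a different distance formula changes which candidate wins at each step, which is the role the metrics below play.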

Metrics are configured via WalkConfig, which composes a metric with a query strategy (discussed in a separate guide):

import jax.numpy as jnp
import phasecurvefit as pcf
from phasecurvefit.metrics import FullPhaseSpaceDistanceMetric

pos = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.0, 0.5, 1.0])}
vel = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([0.5, 0.5, 0.5])}

config = pcf.WalkConfig(metric=FullPhaseSpaceDistanceMetric())
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=1.0)

Built-in Metrics#

SpatialDistanceMetric#

A position-only metric that computes pure Euclidean distance, completely ignoring velocity information.

Mathematical formulation:

\[ d = d_0 \]

where \(d_0\) is the Euclidean distance between positions. The metric_scale parameter is ignored.

When to use:

  • Velocity information is unreliable or unavailable

  • Pure spatial proximity is desired (e.g., spatial clustering)

  • Comparing against baseline nearest-neighbor approaches

  • You would otherwise pass metric_scale=0 to AlignedMomentumDistanceMetric; the result is equivalent, but this metric states the intent explicitly

Usage:

from phasecurvefit.metrics import SpatialDistanceMetric

# Pure nearest-neighbor search in position space
config = pcf.WalkConfig(metric=SpatialDistanceMetric())
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=0.0)

AlignedMomentumDistanceMetric#

The Nearest Neighbors with Momentum (NN+p) metric from Nibauer et al. (2022).

Mathematical formulation:

\[ d = d_0 + \lambda (1 - \cos\theta) \]

where:

  • \(d_0\) is the Euclidean distance between positions

  • \(\theta\) is the angle between the current velocity and the direction to the candidate point

  • \(\lambda\) is the momentum weight parameter (metric_scale)

Physical interpretation:

This metric combines spatial proximity with velocity alignment. Points that lie along the current velocity direction receive lower penalties, making the algorithm favor coherent flows in phase-space.
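The formula can be tried on plain arrays to see the directional penalty at work. This is a minimal sketch, assuming positions are stacked in an (N, D) array; `aligned_momentum_distance` is a hypothetical standalone function, not the library's implementation, and the small epsilon in the denominator is an assumption added to avoid division by zero.

```python
import jax.numpy as jnp


def aligned_momentum_distance(current_pos, current_vel, positions, lam):
    """Sketch of d = d0 + lam * (1 - cos(theta)) for plain (N, D) arrays."""
    delta = positions - current_pos        # vectors to each candidate, shape (N, D)
    d0 = jnp.linalg.norm(delta, axis=-1)   # Euclidean distance to each candidate
    speed = jnp.linalg.norm(current_vel)
    # Cosine of the angle between the velocity and the direction to each candidate
    cos_theta = (delta @ current_vel) / (d0 * speed + 1e-12)
    return d0 + lam * (1.0 - cos_theta)


current_pos = jnp.array([0.0, 0.0])
current_vel = jnp.array([1.0, 0.0])
# Three candidates at the same spatial distance: ahead, sideways, behind
candidates = jnp.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])

d = aligned_momentum_distance(current_pos, current_vel, candidates, lam=2.0)
print(d)  # approximately [1. 3. 5.]
```

All three candidates sit at \(d_0 = 1\), yet the one directly ahead is preferred because it incurs no alignment penalty.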

Usage:

import jax.numpy as jnp
import phasecurvefit as pcf
from phasecurvefit.metrics import AlignedMomentumDistanceMetric

pos = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.0, 0.5, 1.0])}
vel = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([0.5, 0.5, 0.5])}

# Aligned momentum metric
config = pcf.WalkConfig(metric=AlignedMomentumDistanceMetric())
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=1.0)

FullPhaseSpaceDistanceMetric#

A true 6D Euclidean distance metric in full phase-space, treating position and velocity symmetrically. This is the default metric.

Mathematical formulation:

\[ d = \sqrt{d_0^2 + (\tau \cdot d_v)^2} \]

where:

  • \(d_0\) is the Euclidean distance in position space

  • \(d_v\) is the Euclidean distance in velocity space

  • \(\tau\) is the time parameter (metric_scale) that converts velocity differences to position units

Physical interpretation:

This metric computes true Euclidean distance in the 6-dimensional phase space by combining position and velocity differences. The parameter metric_scale (with time units) determines the relative weighting: for example, if positions are measured in kpc and velocities in kpc/Myr, then metric_scale in Myr converts velocity differences to kpc, creating a uniformly scaled phase space.

Unlike AlignedMomentumDistanceMetric, this metric has no directional bias from momentum alignment — it treats all directions in phase space equally.
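A small numerical sketch may help show how metric_scale sets the relative weight of velocity. The numbers are hypothetical (positions in kpc, velocities in kpc/Myr, \(\tau\) in Myr) and `full_phase_space_distance` is a standalone illustration on plain (N, D) arrays, not the library's implementation.

```python
import jax.numpy as jnp


def full_phase_space_distance(current_pos, current_vel, positions, velocities, tau):
    """Sketch of d = sqrt(d0^2 + (tau * dv)^2) for plain (N, D) arrays."""
    d0 = jnp.linalg.norm(positions - current_pos, axis=-1)    # position distance
    dv = jnp.linalg.norm(velocities - current_vel, axis=-1)   # velocity distance
    return jnp.sqrt(d0**2 + (tau * dv) ** 2)


current_pos = jnp.array([0.0, 0.0])    # kpc
current_vel = jnp.array([0.1, 0.0])    # kpc/Myr
# Two candidates at the same position but with different velocities
positions = jnp.array([[3.0, 0.0], [3.0, 0.0]])
velocities = jnp.array([[0.1, 0.0], [0.1, 0.4]])

d = full_phase_space_distance(current_pos, current_vel, positions, velocities, tau=10.0)
print(d)  # approximately [3. 5.]
```

With \(\tau = 10\) Myr, the velocity mismatch of 0.4 kpc/Myr counts as 4 kpc, so the second candidate ends up at distance 5 kpc rather than 3 kpc. A larger \(\tau\) makes velocity differences dominate; a smaller one makes the metric mostly spatial.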

When to use:

  • Position and velocity information are equally important

  • You want true 6D proximity without momentum direction bias

  • The natural time scale of the system is known

  • Comparing against full phase-space clustering methods

Usage:

from phasecurvefit.metrics import FullPhaseSpaceDistanceMetric

# Full 6D phase-space distance (this is the default)
# metric_scale represents a time scale (e.g., if pos ~ kpc, vel ~ kpc/Myr, metric_scale ~ Myr)
config = pcf.WalkConfig(metric=FullPhaseSpaceDistanceMetric())
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=1.0)

Comparison with momentum metric:

  • AlignedMomentumDistanceMetric: Directional — favors points along velocity direction

  • FullPhaseSpaceDistanceMetric: Isotropic — treats all directions equally

  • Both reduce to SpatialDistanceMetric when metric_scale=0
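The last point follows directly from the two formulas given earlier. Substituting metric_scale = 0, that is \(\lambda = 0\) or \(\tau = 0\):

\[ d = d_0 + 0 \cdot (1 - \cos\theta) = d_0, \qquad d = \sqrt{d_0^2 + (0 \cdot d_v)^2} = d_0 \]

where the second equality uses \(d_0 \geq 0\).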

Creating Custom Metrics#

Custom metrics enable alternative distance calculations for specific use cases. For example, you might want:

  • Full 6D Cartesian distance in phase-space

  • Weighted combinations of position and velocity

  • Problem-specific distance measures

The AbstractDistanceMetric Interface#

All metrics must inherit from AbstractDistanceMetric and implement the __call__ method:

import equinox as eqx
from phasecurvefit.metrics import AbstractDistanceMetric


class CustomMetric(AbstractDistanceMetric):
    """Your custom distance metric."""

    def __call__(self, current_pos, current_vel, positions, velocities, metric_scale):
        """Compute modified distances."""
        # Your distance calculation here
        ...

Example: 6D Cartesian Metric#

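Here’s a complete example of a metric that computes full 6D Cartesian distance. It uses the same formula as the built-in FullPhaseSpaceDistanceMetric and is reimplemented here only to illustrate the interface: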

import equinox as eqx
import jax
import jax.numpy as jnp
import phasecurvefit as pcf
from phasecurvefit.metrics import AbstractDistanceMetric


class Full6DMetric(AbstractDistanceMetric):
    """6D Cartesian distance in phase-space.

    Treats position and velocity on equal footing, with `metric_scale` serving as
    a velocity-to-position scaling factor (units of time).

    Distance formula:
        d = sqrt(|Δr|² + (τ|Δv|)²)

    where Δr is the position difference, Δv is the velocity difference, and
    τ is `metric_scale`.
    """

    def __call__(self, current_pos, current_vel, positions, velocities, metric_scale):
        # Position differences; the current point broadcasts against the N candidates
        pos_diff = jax.tree.map(jnp.subtract, positions, current_pos)

        # Sum of squared position differences over all components
        pos_dist_sq = sum(jax.tree.leaves(jax.tree.map(jnp.square, pos_diff)))

        # Velocity differences, broadcast the same way
        vel_diff = jax.tree.map(jnp.subtract, velocities, current_vel)

        # Sum of squared velocity differences (not yet weighted)
        vel_dist_sq = sum(jax.tree.leaves(jax.tree.map(jnp.square, vel_diff)))

        # Combined distance, with the velocity term weighted by metric_scale^2
        return jnp.sqrt(pos_dist_sq + (metric_scale**2) * vel_dist_sq)


# Use the custom metric via WalkConfig
pos = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.0, 0.5, 1.0])}
vel = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([0.5, 0.5, 0.5])}

config = pcf.WalkConfig(metric=Full6DMetric())
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=1.0)
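To see what the pytree operations inside `__call__` compute, the same steps can be run standalone on plain dicts. This sketch assumes the current point arrives as a dict of scalars with the same keys as the candidate arrays; that layout is an assumption for illustration, so check the library documentation for the exact shapes passed to a metric.

```python
import jax
import jax.numpy as jnp

# Candidate positions: one array per component, N = 3 points
positions = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.0, 0.5, 1.0])}
# Current point: one scalar per component (assumed layout)
current_pos = {"x": jnp.array(0.0), "y": jnp.array(0.0)}

# Component-wise differences; each scalar broadcasts against its (N,) array
pos_diff = jax.tree.map(jnp.subtract, positions, current_pos)

# Square each component, then add the components together
pos_dist_sq = sum(jax.tree.leaves(jax.tree.map(jnp.square, pos_diff)))

dist = jnp.sqrt(pos_dist_sq)
print(dist)  # approximately [0. 1.118 2.236]
```

Working on the dict leaves this way keeps the metric independent of how many components the data has or what they are called.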

Example: Weighted Position Metric#

A metric that ignores velocity entirely and uses weighted position coordinates:

class WeightedPositionMetric(AbstractDistanceMetric):
    """Position-only metric with per-component weights."""

    weights: dict[str, float] = eqx.field(static=True)

    def __call__(self, current_pos, current_vel, positions, velocities, metric_scale):
        # Compute weighted position differences
        def weighted_diff_sq(component_name, positions_component):
            diff = positions_component - current_pos[component_name]
            weight = self.weights.get(component_name, 1.0)
            return weight * diff**2

        # Sum over all components
        weighted_dist_sq = sum(weighted_diff_sq(k, v) for k, v in positions.items())

        return jnp.sqrt(weighted_dist_sq)


# Use with custom weights (down-weight the y-coordinate)
metric = WeightedPositionMetric(weights={"x": 1.0, "y": 0.1})
config = pcf.WalkConfig(metric=metric)
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=0.0)

Units and Metrics#

When using physical units via unxt, ensure your metric correctly handles unit propagation:

import unxt as u
from phasecurvefit.metrics import (
    AlignedMomentumDistanceMetric,
    FullPhaseSpaceDistanceMetric,
)

# Position in kpc, velocity in km/s
pos = {"x": u.Q([0.0, 1.0, 2.0], "kpc"), "y": u.Q([0.0, 0.5, 1.0], "kpc")}
vel = {"x": u.Q([1.0, 1.0, 1.0], "km/s"), "y": u.Q([0.5, 0.5, 0.5], "km/s")}

# metric_scale must have units of distance for AlignedMomentumDistanceMetric
config = pcf.WalkConfig(metric=AlignedMomentumDistanceMetric())
result = pcf.walk_local_flow(
    pos,
    vel,
    config=config,
    start_idx=0,
    metric_scale=u.Q(100.0, "kpc"),  # Momentum weight in distance units
    usys=u.unitsystems.galactic,  # Required when using Quantities
)

# For FullPhaseSpaceDistanceMetric, metric_scale has units of time
config_6d = pcf.WalkConfig(metric=FullPhaseSpaceDistanceMetric())
result_6d = pcf.walk_local_flow(
    pos,
    vel,
    config=config_6d,
    start_idx=0,
    # Time scale that converts velocity differences to spatial distances
    metric_scale=u.Q(1.0, "Gyr"),
    usys=u.unitsystems.galactic,  # Required when using Quantities
)
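When choosing a time scale for FullPhaseSpaceDistanceMetric, it can help to check by hand how large a velocity difference becomes once converted to a length. The sketch below does that arithmetic in plain Python with standard conversion constants; `kms_to_kpc` is a hypothetical helper and does not use unxt or phasecurvefit.

```python
KM_PER_KPC = 3.0856775814913673e16  # kilometres in one kiloparsec
SECONDS_PER_GYR = 3.15576e16        # seconds in 1e9 Julian years


def kms_to_kpc(delta_v_kms, tau_gyr):
    """Length in kpc that a velocity difference in km/s maps to under tau in Gyr."""
    return delta_v_kms * tau_gyr * SECONDS_PER_GYR / KM_PER_KPC


length = kms_to_kpc(1.0, 1.0)
print(f"{length:.3f} kpc")  # 1.023 kpc
```

So with metric_scale = 1 Gyr, a velocity difference of 1 km/s weighs about as much as a 1 kpc offset in position, which is comparable to the point spacing in the example above.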

Metric Comparison#

| Metric | Position | Velocity | metric_scale meaning |
| --- | --- | --- | --- |
| SpatialDistanceMetric | ✓ | ✗ | Ignored |
| AlignedMomentumDistanceMetric | ✓ | ✓ (alignment) | Momentum penalty weight |
| FullPhaseSpaceDistanceMetric | ✓ | ✓ (magnitude) | Time scale (velocity → position units) |

When to use each:

  • FullPhaseSpaceDistanceMetric (default): True 6D distance when position and velocity are equally important and you know the system’s natural time scale. No directional preference.

  • AlignedMomentumDistanceMetric: For coherent flows (stellar streams, winds) where velocity alignment should bias the ordering.

  • SpatialDistanceMetric: When velocity is unreliable or you want pure spatial clustering. Good baseline for comparison.

Metric Comparison Example#

Here’s a comparison of different metrics on the same data:

import jax.numpy as jnp
import phasecurvefit as pcf
from phasecurvefit.metrics import (
    AlignedMomentumDistanceMetric,
    SpatialDistanceMetric,
)

# Sample spiral trajectory
theta = jnp.linspace(0, 4 * jnp.pi, 100)
pos = {
    "x": jnp.cos(theta) * jnp.exp(theta / 10),
    "y": jnp.sin(theta) * jnp.exp(theta / 10),
}
vel = {
    "x": jnp.gradient(pos["x"]),
    "y": jnp.gradient(pos["y"]),
}

# Compare metrics
metrics = {
    "Momentum": AlignedMomentumDistanceMetric(),
    "Spatial": SpatialDistanceMetric(),
}

for name, metric in metrics.items():
    config = pcf.WalkConfig(metric=metric)
    result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=1.0)
    # Count non-negative indices (negative entries are treated as unvisited)
    n_visited = int(jnp.sum(jnp.asarray(result.indices) >= 0))
    print(f"{name}: {n_visited}/{len(theta)} points ordered")

See Also#