API Reference#
Complete API documentation for phasecurvefit.
Main Function#
- phasecurvefit.walk_local_flow(xs: dict[str, jaxtyping.Real[Array, 'N'] | jaxtyping.Real[ndarray, 'N'] | jaxtyping.Real[TypedNdArray, 'N']], vs: dict[str, jaxtyping.Real[Array, 'N'] | jaxtyping.Real[ndarray, 'N'] | jaxtyping.Real[TypedNdArray, 'N']], /, *, start_idx: int = 0, metric_scale: jaxtyping.Real[Array, ''] | jaxtyping.Real[ndarray, ''] | numpy.number | int | float | jaxtyping.Real[TypedNdArray, ''] = 1.0, max_dist: float = inf, terminate_indices: collections.abc.Set[int] | None = None, n_max: int | None = None, config: phasecurvefit._src.query_config.WalkConfig = WalkConfig(metric=AlignedMomentumDistanceMetric(), strategy=<phasecurvefit._src.strategies.BruteForce object>), metadata: phasecurvefit._src.algorithm.StateMetadata = StateMetadata({}), direction: typing.Literal['forward', 'backward', 'both'] = 'forward')
Find an ordered path through phase-space using the local flow.
- Parameters:
xs (dict[str, Union[Real[Array, 'N'], Real[ndarray, 'N'], Real[TypedNdArray, 'N']]]) – Position dictionary with 1D array values of shape (N,).
vs (dict[str, Union[Real[Array, 'N'], Real[ndarray, 'N'], Real[TypedNdArray, 'N']]]) – Velocity dictionary with 1D array values of shape (N,).
start_idx (int, optional) – The index of the starting observation (default: 0).
metric_scale (Union[Real[Array, ''], Real[ndarray, ''], number, int, float, Real[TypedNdArray, '']]) – Metric-dependent scale parameter. Default: 1.0. Interpretation depends on the metric:
  - AlignedMomentumDistanceMetric: Momentum weight (distance units)
  - FullPhaseSpaceDistanceMetric: Time scale for velocity-to-position conversion
  - SpatialDistanceMetric: Ignored
max_dist (float) – Maximum allowable distance between neighbors. If the minimum distance exceeds this value, the algorithm terminates, leaving remaining points unvisited. This is key to the algorithm’s ability to skip outliers. Default: jnp.inf (no limit).
terminate_indices (Set[int] | None) – Set of indices at which to terminate the algorithm if reached. Default: None.
n_max (int | None) – Maximum number of iterations. Default: None (process all points).
config (WalkConfig) – Configuration for neighbor queries, containing both the distance metric and the query strategy. Use WalkConfig(metric=..., strategy=...) to customize. Defaults to WalkConfig(), which uses AlignedMomentumDistanceMetric with BruteForce.
metadata (StateMetadata) – Optional metadata to pass through the algorithm state without participating in computation. Useful for unit systems or other context.
direction (Literal['forward', 'backward', 'both']) – Direction to walk the local flow. ‘forward’ walks along the velocity field, ‘backward’ walks against the velocity field, and ‘both’ walks in both directions. Default is ‘forward’.
- Returns:
Result container with:
"indices": ordered indices array with -1 for unvisited observations.
"visited": boolean array indicating visited observations.
"positions": the original xs dict.
"velocities": the original vs dict.
- Return type:
WalkLocalFlowResult
Examples
>>> import jax.numpy as jnp
>>> import phasecurvefit as pcf
Create phase-space data for a simple stream:
>>> pos = {
...     "x": jnp.array([0.0, 1.0, 2.0, 3.0, 4.0]),
...     "y": jnp.array([0.0, 0.1, 0.2, 0.3, 0.4]),
... }
>>> vel = {
...     "x": jnp.array([1.0, 1.0, 1.0, 1.0, 1.0]),
...     "y": jnp.array([0.1, 0.1, 0.1, 0.1, 0.1]),
... }
Run the algorithm starting from index 0:
>>> result = pcf.walk_local_flow(pos, vel, start_idx=0, metric_scale=0.5)
>>> result.ordering
Array([0, 1, 2, 3, 4], dtype=int32)
Walk in the backward direction:
>>> result_backward = pcf.walk_local_flow(
...     pos, vel, start_idx=4, metric_scale=0.5, direction="backward"
... )
>>> result_backward.indices
Array([4, 3, 2, 1, 0], dtype=int32)
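The effect of max_dist and the -1 padding in the returned indices can be pictured with a toy greedy walk. The following is a minimal pure-Python sketch, not the phasecurvefit implementation: greedy_walk is a hypothetical helper that uses plain Euclidean distance, ignores velocities, and assumes 2D points given as tuples.

```python
import math


def greedy_walk(points, start_idx=0, max_dist=math.inf):
    """Toy nearest-neighbour walk over 2D points (hypothetical helper)."""
    n = len(points)
    visited = [False] * n
    order = [-1] * n  # -1 marks slots that were never filled
    current = start_idx
    visited[current] = True
    order[0] = current
    for step in range(1, n):
        # Distance from the current point to every unvisited candidate.
        candidates = [
            (math.dist(points[current], points[j]), j)
            for j in range(n)
            if not visited[j]
        ]
        best_dist, best_j = min(candidates)
        if best_dist > max_dist:
            break  # nearest remaining point is too far: stop here
        visited[best_j] = True
        order[step] = best_j
        current = best_j
    return order


points = [(0.0, 0.0), (1.0, 0.1), (2.0, 0.2), (50.0, 50.0)]
print(greedy_walk(points, start_idx=0, max_dist=5.0))  # [0, 1, 2, -1]
```

With the default max_dist of infinity the same sketch visits the distant point as well, which is why a finite max_dist is what lets a walk leave outliers unvisited.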
- walk_local_flow(positions: collections.abc.Mapping[str, jaxtyping.Real[AbstractQuantity, 'N']], velocities: collections.abc.Mapping[str, jaxtyping.Real[AbstractQuantity, 'N']], /, *, start_idx: int, metric_scale: jaxtyping.Real[AbstractQuantity, ''], max_dist: jaxtyping.Real[AbstractQuantity, ''] = Quantity(Array(inf, dtype=float32, weak_type=True), unit='m'), terminate_indices: set[int] | None = None, n_max: int | None = None, config: phasecurvefit._src.query_config.WalkConfig = WalkConfig(metric=AlignedMomentumDistanceMetric(), strategy=<phasecurvefit._src.strategies.BruteForce object>), direction: Literal['forward', 'backward', 'both'] = 'forward', metadata: phasecurvefit._src.algorithm.StateMetadata = StateMetadata({}), usys: unxt._src.unitsystems.base.AbstractUnitSystem | None = None) -> phasecurvefit._src.algorithm.WalkLocalFlowResult
Implementation for Quantity-valued phase-space data.
- Parameters:
positions (Mapping[str, unxt.AbstractQuantity]) – Position dictionary with Quantity-valued components. All values must have compatible length dimensions.
velocities (Mapping[str, unxt.AbstractQuantity]) – Velocity dictionary with Quantity-valued components. All values must have compatible velocity dimensions (length/time).
start_idx (int) – Index of the starting observation.
metric_scale (unxt.AbstractQuantity) – Metric-dependent scale parameter, given as a scalar Quantity.
max_dist (unxt.AbstractQuantity) – Maximum allowed distance for neighbor selection, given as a scalar Quantity. Observations beyond this distance are not considered. Default is infinity.
terminate_indices (set[int], optional) – Set of observation indices that, when reached, terminate the ordering.
n_max (int, optional) – Maximum number of observations to include in the ordering. If None, all observations are included.
config (WalkConfig) – Configuration for neighbor queries, containing both the distance metric and the query strategy. Use WalkConfig(metric=..., strategy=...) to customize. Defaults to WalkConfig(), which uses AlignedMomentumDistanceMetric with BruteForce.
direction (['forward', 'backward', 'both'], optional) – Direction to walk the local flow. ‘forward’ walks along the velocity field, ‘backward’ walks against the velocity field, and ‘both’ walks in both directions. Default is ‘forward’.
metadata (StateMetadata) – Optional metadata to pass through the algorithm state without participating in computation. Useful for unit systems or other context.
usys (unxt.AbstractUnitSystem, optional) – Unit system to use for consistent unit stripping of Quantities. Default is SI units.
- Returns:
Result container with ordered indices and original data.
- Return type:
WalkLocalFlowResult
Examples
>>> import jax.numpy as jnp
>>> import unxt as u
>>> from phasecurvefit import walk_local_flow
>>> q = {
...     "x": u.Q(jnp.array([0.0, 1.0, 2.0]), "m"),
...     "y": u.Q(jnp.array([0.0, 0.5, 1.0]), "m"),
... }
>>> p = {
...     "x": u.Q(jnp.array([1.0, 1.0, 1.0]), "m/s"),
...     "y": u.Q(jnp.array([0.5, 0.5, 0.5]), "m/s"),
... }
>>> result = walk_local_flow(
...     q, p, start_idx=0, metric_scale=u.Q(1.0, "m"), usys=u.unitsystems.si
... )
>>> result
WalkLocalFlowResult(
  positions={'x': Quantity(f32[3], unit='m'), 'y': Quantity(f32[3], unit='m')},
  velocities={
    'x': Quantity(f32[3], unit='m / s'), 'y': Quantity(f32[3], unit='m / s')
  },
  indices=i32[3],
  gamma_range=(0.0, 1.0)
)
Result Accessor#
Helper function to extract ordered data from results.
- phasecurvefit.order_w(res: WalkLocalFlowResult, /)
Get xs and vs in the ordered sequence from a WalkLocalFlowResult.
Filters out unvisited indices (marked as -1) and returns only the visited observations in the order they were traversed.
- Parameters:
res (WalkLocalFlowResult) – The result from walk_local_flow.
- Return type:
tuple[dict[str, Union[Real[Array, 'N'], Real[ndarray, 'N'], Real[TypedNdArray, 'N']]], dict[str, Union[Real[Array, 'N'], Real[ndarray, 'N'], Real[TypedNdArray, 'N']]]]
- Returns:
xs (dict[str, Array]) – Position arrays reordered according to the algorithm’s output, with unvisited observations removed.
vs (dict[str, Array]) – Velocity arrays reordered according to the algorithm’s output, with unvisited observations removed.
Examples
>>> import jax.numpy as jnp
>>> import phasecurvefit as pcf
>>> pos = {"x": jnp.array([3.0, 1.0, 2.0])}
>>> vel = {"x": jnp.array([1.0, 1.0, 1.0])}
>>> result = pcf.walk_local_flow(pos, vel, start_idx=1, metric_scale=0.0)
>>> ordered_pos, ordered_vel = pcf.order_w(result)
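The filtering and reordering that order_w is described as performing can be sketched in plain Python. This is only an illustration under stated assumptions: order_by_indices is a hypothetical helper working on lists, and the indices value below is made up rather than produced by walk_local_flow.

```python
def order_by_indices(xs, vs, indices):
    """Reorder component dicts by walk order, dropping -1 entries (hypothetical helper)."""
    keep = [i for i in indices if i >= 0]  # visited observations, in walk order
    xs_ordered = {k: [col[i] for i in keep] for k, col in xs.items()}
    vs_ordered = {k: [col[i] for i in keep] for k, col in vs.items()}
    return xs_ordered, vs_ordered


xs = {"x": [3.0, 1.0, 2.0]}
vs = {"x": [1.0, 1.0, 1.0]}
indices = [1, 2, -1]  # assumed walk: index 1, then index 2, one point unvisited
ordered_xs, ordered_vs = order_by_indices(xs, vs, indices)
print(ordered_xs)  # {'x': [1.0, 2.0]}
```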
Distance Metrics#
Pluggable distance metrics for controlling how the algorithm selects the next point. See the Metrics Guide for usage examples.
- class phasecurvefit.metrics.AbstractDistanceMetric
Bases: Module
Abstract base class for distance metrics in phase-space walks.
A distance metric computes modified distances between a current point and all candidate next points, incorporating both spatial and velocity information. Different metrics can implement different weighting schemes or use different phase-space representations.
Examples
>>> import phasecurvefit as pcf
>>> metric = pcf.metrics.AlignedMomentumDistanceMetric()
>>> # Use with walk_local_flow via config=WalkConfig(metric=metric)
- final class phasecurvefit.metrics.AlignedMomentumDistanceMetric
Bases: AbstractDistanceMetric
Default momentum-based distance metric.
Computes modified distance as:
$$ d = d_0 + \lambda (1 - \cos\theta) $$
where $d_0$ is the Euclidean distance in position space, $\theta$ is the angle between the current velocity and the direction to the candidate point, and $\lambda$ controls the relative importance of momentum alignment.
When $\lambda = 0$, reduces to pure nearest-neighbor search in position space. As $\lambda$ increases, points aligned with the current velocity direction are increasingly favored.
This is the original phase-flow walk metric from Nibauer et al. (2022).
Examples
>>> import jax.numpy as jnp
>>> import phasecurvefit as pcf
>>> metric = pcf.metrics.AlignedMomentumDistanceMetric()
>>> pos = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.0, 0.5, 1.0])}
>>> vel = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([0.5, 0.5, 0.5])}
>>> current_pos = {k: v[0] for k, v in pos.items()}
>>> current_vel = {k: v[0] for k, v in vel.items()}
>>> distances = metric(current_pos, current_vel, pos, vel, metric_scale=1.0)
>>> distances.shape
(3,)
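To make the formula for the aligned-momentum distance concrete, here is a minimal pure-Python sketch for a single candidate point. It is not the library's code: aligned_momentum_distance is a hypothetical helper, lam stands for the momentum weight (metric_scale), and the handling of zero-length vectors is an assumption.

```python
import math


def aligned_momentum_distance(q_cur, v_cur, q_cand, lam):
    """d = d_0 + lam * (1 - cos(theta)) for one candidate (hypothetical helper)."""
    offset = [b - a for a, b in zip(q_cur, q_cand)]
    d0 = math.hypot(*offset)  # Euclidean distance in position space
    speed = math.hypot(*v_cur)
    if d0 == 0.0 or speed == 0.0:
        # Assumption: the angle is undefined here, so fall back to d_0 alone.
        return d0
    cos_theta = sum(o * v for o, v in zip(offset, v_cur)) / (d0 * speed)
    return d0 + lam * (1.0 - cos_theta)


current_q, current_v = (0.0, 0.0), (1.0, 0.0)
ahead = aligned_momentum_distance(current_q, current_v, (1.0, 0.0), lam=0.5)
behind = aligned_momentum_distance(current_q, current_v, (-1.0, 0.0), lam=0.5)
print(ahead, behind)  # 1.0 2.0
```

Both candidates are one unit away, but the one lying against the velocity direction is penalized by 2 * lam, so the walk prefers the point ahead. Setting lam=0 makes the two distances equal, matching the pure nearest-neighbor limit.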
- final class phasecurvefit.metrics.SpatialDistanceMetric
Bases: AbstractDistanceMetric
Position-only distance metric.
Computes pure Euclidean distance in position space, ignoring velocity information entirely. This reduces to standard nearest-neighbor search.
$$ d = d_0 $$
where $d_0$ is the Euclidean distance between positions. The metric_scale parameter is ignored.
This metric is useful when:
Velocity information is unreliable or unavailable
Pure spatial proximity is the desired ordering criterion
Comparing against baseline nearest-neighbor approaches
Examples
>>> import jax.numpy as jnp
>>> import phasecurvefit as pcf
>>> metric = pcf.metrics.SpatialDistanceMetric()
>>> pos = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.0, 0.5, 1.0])}
>>> vel = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([0.5, 0.5, 0.5])}
>>> current_pos = {k: v[0] for k, v in pos.items()}
>>> current_vel = {k: v[0] for k, v in vel.items()}
>>> distances = metric(current_pos, current_vel, pos, vel, metric_scale=0.0)
>>> distances.shape
(3,)
- final class phasecurvefit.metrics.FullPhaseSpaceDistanceMetric
Bases: AbstractDistanceMetric
Full 6D phase-space distance metric.
Computes the Euclidean distance in the full 6-dimensional phase space by combining position and velocity differences. The parameter metric_scale (with time units) converts velocity differences to position units.
$$ d = \sqrt{d_0^2 + (\tau \cdot d_v)^2} $$
where:
$d_0$ is the Euclidean distance in position space
$d_v$ is the Euclidean distance in velocity space
$\tau$ is the time parameter (metric_scale) that converts velocity to position units
This metric treats position and velocity symmetrically in phase space, without directional bias from momentum alignment. The metric_scale parameter determines the relative weighting of velocity differences.
Physically, if we think of phase space as having position coordinates measured in kpc and velocity coordinates measured in kpc/Myr, then metric_scale with units of Myr converts velocity differences to kpc, allowing us to compute a true Euclidean distance in a uniformly scaled phase space.
This metric is useful when:
Position and velocity information are equally important
You want true 6D proximity without momentum direction bias
The natural time scale of the system is known
Examples
>>> import jax.numpy as jnp
>>> import phasecurvefit as pcf
>>> metric = pcf.metrics.FullPhaseSpaceDistanceMetric()
>>> pos = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.0, 0.5, 1.0])}
>>> vel = {"x": jnp.array([1.0, 1.5, 2.0]), "y": jnp.array([0.5, 1.0, 1.5])}
>>> current_pos = {k: v[0] for k, v in pos.items()}
>>> current_vel = {k: v[0] for k, v in vel.items()}
>>> # metric_scale=1.0 means 1 unit of velocity diff = 1 unit of position diff
>>> distances = metric(current_pos, current_vel, pos, vel, metric_scale=1.0)
>>> distances.shape
(3,)
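The role of the time scale in the full phase-space distance can be checked by hand with a small pure-Python sketch. The helper name full_phase_space_distance and the numbers are illustrative assumptions; positions are taken to be in kpc, velocities in kpc/Myr, and tau in Myr.

```python
import math


def full_phase_space_distance(q_a, v_a, q_b, v_b, tau):
    """d = sqrt(d_0^2 + (tau * d_v)^2) between two points (hypothetical helper)."""
    d0 = math.dist(q_a, q_b)  # position-space distance
    dv = math.dist(v_a, v_b)  # velocity-space distance
    return math.hypot(d0, tau * dv)


# 3 kpc apart in position, 0.5 kpc/Myr apart in velocity, tau = 8 Myr:
# the velocity term contributes 8 * 0.5 = 4 kpc, so d = sqrt(3^2 + 4^2).
d = full_phase_space_distance((0.0, 0.0), (0.0, 0.0), (3.0, 0.0), (0.0, 0.5), tau=8.0)
print(d)  # 5.0
```

With tau=0 the velocity term vanishes and the sketch returns the plain position distance of 3 kpc, so tau directly sets how much velocity differences count.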
Phase-Space Utilities#
Low-level functions for phase-space operations. Available in the phasecurvefit.w submodule.
- phasecurvefit.w.euclidean_distance(q_a: dict[str, Real[Array, ''] | Real[ndarray, ''] | number | int | float | Real[TypedNdArray, '']], q_b: dict[str, Real[Array, ''] | Real[ndarray, ''] | number | int | float | Real[TypedNdArray, '']], /)
Compute Euclidean distance between two position points in Cartesian space.
This function operates on scalar components only (single points). Use jax.vmap to compute distances for arrays of points.
- Parameters:
q_a (ScalarComponents) – Position dictionaries with scalar Cartesian components (keys: “x”, “y”, “z”).
q_b (ScalarComponents) – Position dictionaries with scalar Cartesian components (keys: “x”, “y”, “z”).
- Returns:
The Euclidean distance.
- Return type:
FLikeSz0
Examples
>>> import jax.numpy as jnp
>>> from phasecurvefit.w import euclidean_distance
>>> q_a = {"x": jnp.array(0.0), "y": jnp.array(0.0)}
>>> q_b = {"x": jnp.array(3.0), "y": jnp.array(4.0)}
>>> float(euclidean_distance(q_a, q_b))
5.0
- phasecurvefit.w.euclidean_distance(q_a: Mapping[str, jaxtyping.Real[AbstractQuantity, '']], q_b: Mapping[str, jaxtyping.Real[AbstractQuantity, '']], /) -> jaxtyping.Real[AbstractQuantity, '']
Euclidean distance between Quantity-valued component dictionaries.
Computes the distance between two phase-space positions represented as dictionaries with unxt Quantity scalar values.
- Parameters:
q_a (Mapping[str, unxt.AbstractQuantity]) – Position dictionaries with Quantity-valued components. Must have the same keys. All values must have compatible length dimensions.
q_b (Mapping[str, unxt.AbstractQuantity]) – Position dictionaries with Quantity-valued components. Must have the same keys. All values must have compatible length dimensions.
- Returns:
The Euclidean distance with the unit of the input components.
- Return type:
unxt.Quantity
Examples
>>> import jax.numpy as jnp
>>> import unxt as u
>>> from phasecurvefit.w import euclidean_distance
>>> q_a = {"x": u.Q(0.0, "m"), "y": u.Q(0.0, "m")}
>>> q_b = {"x": u.Q(3.0, "m"), "y": u.Q(4.0, "m")}
>>> euclidean_distance(q_a, q_b)
Quantity(Array(5., dtype=float32, weak_type=True), unit='m')
- phasecurvefit.w.unit_direction(q_a: dict[str, Real[Array, ''] | Real[ndarray, ''] | number | int | float | Real[TypedNdArray, '']], q_b: dict[str, Real[Array, ''] | Real[ndarray, ''] | number | int | float | Real[TypedNdArray, '']], /)
Compute unit direction vector from position a to b in Cartesian space.
This function operates on scalar components only (single points). Use jax.vmap to compute directions for arrays of points.
- Parameters:
q_a (ScalarComponents) – Position dictionaries with scalar Cartesian components (keys: “x”, “y”, “z”).
q_b (ScalarComponents) – Position dictionaries with scalar Cartesian components (keys: “x”, “y”, “z”).
- Returns:
Dictionary of unit direction Cartesian components.
- Return type:
ScalarComponents
Examples
>>> import jax.numpy as jnp
>>> from phasecurvefit.w import unit_direction
>>> q_a = {"x": jnp.array(0.0), "y": jnp.array(0.0)}
>>> q_b = {"x": jnp.array(3.0), "y": jnp.array(4.0)}
>>> udir = unit_direction(q_a, q_b)
>>> float(udir["x"]), float(udir["y"])
(0.6..., 0.8...)
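The arithmetic behind this result can be reproduced without JAX. The following is a minimal sketch that assumes plain Python floats and two distinct points; unit_direction_sketch is a hypothetical stand-in, not the library function.

```python
import math


def unit_direction_sketch(q_a, q_b):
    """Unit vector pointing from q_a to q_b, component by component."""
    offset = {k: q_b[k] - q_a[k] for k in q_a}
    norm = math.sqrt(sum(c * c for c in offset.values()))
    # Assumption: the two points are distinct, so norm > 0.
    return {k: c / norm for k, c in offset.items()}


udir = unit_direction_sketch({"x": 0.0, "y": 0.0}, {"x": 3.0, "y": 4.0})
print(udir)  # components close to 0.6 and 0.8
```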
- phasecurvefit.w.unit_direction(q_a: Mapping[str, jaxtyping.Real[AbstractQuantity, '']], q_b: Mapping[str, jaxtyping.Real[AbstractQuantity, '']], /) -> Mapping[str, jaxtyping.Real[AbstractQuantity, '']]
Compute unit direction vector from q_a to q_b for Quantity-valued components.
Computes the unit direction vector pointing from position q_a to q_b, where both positions are represented as dictionaries with unxt Quantity scalar values.
- Parameters:
q_a (Mapping[str, unxt.AbstractQuantity]) – Position dictionaries with Quantity-valued components. Must have the same keys. All values must have compatible length dimensions.
q_b (Mapping[str, unxt.AbstractQuantity]) – Position dictionaries with Quantity-valued components. Must have the same keys. All values must have compatible length dimensions.
- Returns:
A dictionary representing the unit direction vector. The components are dimensionless Quantities.
- Return type:
Mapping[str, unxt.AbstractQuantity]
Examples
>>> import jax.numpy as jnp
>>> import unxt as u
>>> from phasecurvefit.w import unit_direction
>>> q_a = {"x": u.Q(0.0, "m"), "y": u.Q(0.0, "m")}
>>> q_b = {"x": u.Q(3.0, "m"), "y": u.Q(4.0, "m")}
>>> unit_direction(q_a, q_b)
{'x': Quantity(Array(0.6, dtype=float32, weak_type=True), unit=''),
 'y': Quantity(Array(0.8, dtype=float32, weak_type=True), unit='')}
- phasecurvefit.w.unit_velocity(velocity: dict[str, Real[Array, ''] | Real[ndarray, ''] | number | int | float | Real[TypedNdArray, '']], /)
Compute the unit velocity vector in Cartesian space.
This function operates on scalar components only (single velocity vector). Use jax.vmap to compute unit velocities for arrays of velocities.
- Parameters:
velocity (ScalarComponents) – Velocity dictionary with scalar Cartesian components (keys: “x”, “y”, “z”).
- Returns:
Dictionary of unit velocity Cartesian components.
- Return type:
ScalarComponents
Examples
>>> import jax.numpy as jnp
>>> from phasecurvefit.w import unit_velocity
>>> vel = {"x": jnp.array(3.0), "y": jnp.array(4.0)}
>>> uvel = unit_velocity(vel)
>>> float(uvel["x"]), float(uvel["y"])
(0.6..., 0.8...)
- phasecurvefit.w.unit_velocity(velocity: Mapping[str, jaxtyping.Real[AbstractQuantity, '']], /) -> Mapping[str, jaxtyping.Real[AbstractQuantity, '']]
Compute unit velocity vector for Quantity-valued components.
Computes the unit velocity vector from a velocity represented as a dictionary with unxt Quantity scalar values.
- Parameters:
velocity (Mapping[str, unxt.AbstractQuantity]) – Velocity dictionary with Quantity-valued components. All values must have compatible velocity dimensions (length/time).
- Returns:
A dictionary representing the unit velocity vector. The components are dimensionless Quantities.
- Return type:
Mapping[str, unxt.AbstractQuantity]
Examples
>>> import jax.numpy as jnp
>>> import unxt as u
>>> from phasecurvefit.w import unit_velocity
>>> vel = {"x": u.Q(3.0, "m/s"), "y": u.Q(4.0, "m/s")}
>>> unit_velocity(vel)
{'x': Quantity(Array(0.6, dtype=float32, ...), unit=''),
 'y': Quantity(Array(0.8, dtype=float32, ...), unit='')}
- phasecurvefit.w.cosine_similarity(vec_a: dict[str, Real[Array, ''] | Real[ndarray, ''] | number | int | float | Real[TypedNdArray, '']], vec_b: dict[str, Real[Array, ''] | Real[ndarray, ''] | number | int | float | Real[TypedNdArray, '']], /)
Compute cosine similarity between two vectors in Cartesian space.
The cosine similarity is defined as the dot product of the vectors. For unit vectors, this equals the cosine of the angle between them.
This function operates on scalar components only (single vectors). Use jax.vmap to compute similarities for arrays of vectors.
- Parameters:
vec_a (ScalarComponents) – Vector dictionaries with scalar Cartesian components (keys: “x”, “y”, “z”).
vec_b (ScalarComponents) – Vector dictionaries with scalar Cartesian components (keys: “x”, “y”, “z”).
- Returns:
The cosine similarity (dot product).
- Return type:
FLikeSz0
Examples
>>> import jax.numpy as jnp
>>> from phasecurvefit.w import cosine_similarity
>>> # Parallel vectors
>>> a = {"x": jnp.array(1.0), "y": jnp.array(0.0)}
>>> b = {"x": jnp.array(1.0), "y": jnp.array(0.0)}
>>> float(cosine_similarity(a, b))
1.0
>>> # Orthogonal vectors
>>> a = {"x": jnp.array(1.0), "y": jnp.array(0.0)}
>>> b = {"x": jnp.array(0.0), "y": jnp.array(1.0)}
>>> float(cosine_similarity(a, b))
0.0
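Because the description above defines the similarity as a dot product that equals the cosine only for unit vectors, it may help to see what normalization changes. The sketch below uses plain Python and two hypothetical helpers, dot and normalize; it says nothing about whether the library normalizes its inputs internally.

```python
import math


def dot(a, b):
    """Dot product of two component dictionaries with the same keys."""
    return sum(a[k] * b[k] for k in a)


def normalize(vec):
    """Scale a component dictionary to unit length (assumes a nonzero vector)."""
    norm = math.sqrt(dot(vec, vec))
    return {k: c / norm for k, c in vec.items()}


a = {"x": 3.0, "y": 4.0}
b = {"x": 6.0, "y": 8.0}  # same direction as a, twice the length

raw = dot(a, b)                            # 50.0: depends on the vector lengths
cosine = dot(normalize(a), normalize(b))   # close to 1.0: depends only on the angle
print(raw, cosine)
```

Passing the outputs of unit_direction and unit_velocity, which are already unit vectors, is the case in which a plain dot product can be read directly as the cosine of the angle.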
- phasecurvefit.w.cosine_similarity(vel_a: Mapping[str, jaxtyping.Real[AbstractQuantity, '']], vel_b: Mapping[str, jaxtyping.Real[AbstractQuantity, '']], /) -> jaxtyping.Real[AbstractQuantity, '']
Compute cosine similarity between Quantity-valued velocity components.
Computes the cosine similarity (dimensionless) between two vectors represented as dictionaries with unxt Quantity scalar values. The result is the cosine of the angle between the two vectors.
- Parameters:
vel_a (Mapping[str, unxt.AbstractQuantity]) – Velocity or direction dictionaries with Quantity-valued components. Must have the same keys. All values must have compatible dimensions.
vel_b (Mapping[str, unxt.AbstractQuantity]) – Velocity or direction dictionaries with Quantity-valued components. Must have the same keys. All values must have compatible dimensions.
- Returns:
The dimensionless cosine similarity between the two vectors.
- Return type:
unxt.Quantity
Examples
>>> import jax.numpy as jnp
>>> import unxt as u
>>> from phasecurvefit.w import cosine_similarity
>>> vel_a = {"x": u.Q(1.0, "m/s"), "y": u.Q(0.0, "m/s")}
>>> vel_b = {"x": u.Q(0.0, "m/s"), "y": u.Q(1.0, "m/s")}
>>> cosine_similarity(vel_a, vel_b)
Quantity(Array(0., dtype=float32, ...), unit='')
- phasecurvefit.w.get_w_at(q: dict[str, Real[Array, 'N'] | Real[ndarray, 'N'] | Real[TypedNdArray, 'N']], p: dict[str, Real[Array, 'N'] | Real[ndarray, 'N'] | Real[TypedNdArray, 'N']], idx: int | Int[Array, ''] | Int[Array, 'N'], /)
Extract a phase-space point at the given index.
This function uses standard phase-space notation where:
q = position (generalized Cartesian coordinates: “x”, “y”, “z”)
p = momentum/velocity (generalized Cartesian momenta: “x”, “y”, “z”)
w = (q, p) = full phase-space point
This extracts a single point (scalar components) from arrays. For extracting multiple points, use jax.vmap or array indexing directly.
- Parameters:
q (VectorComponents) – Position dictionary with 1D array Cartesian values of shape (N,).
p (VectorComponents) – Velocity/momentum dictionary with 1D array Cartesian values of shape (N,).
idx (int | Array) – Index or indices to extract. Can be:
int: Extract a single point (returns scalar arrays)
0-d Array: Extract a single point (returns scalar arrays)
1-d Array: Extract multiple points (returns 1D arrays)
- Returns:
The (position, velocity) tuple at the given index/indices.
- Return type:
Examples
>>> import jax.numpy as jnp
>>> from phasecurvefit.w import get_w_at
>>> pos = {"x": jnp.array([1.0, 2.0, 3.0]), "y": jnp.array([4.0, 5.0, 6.0])}
>>> vel = {"x": jnp.array([0.1, 0.2, 0.3]), "y": jnp.array([0.4, 0.5, 0.6])}
Extract a single point:
>>> q, p = get_w_at(pos, vel, 1)
>>> float(q["x"]), float(p["y"])
(2.0, 0.5)
Extract multiple points:
>>> q, p = get_w_at(pos, vel, jnp.array([0, 2]))
>>> list(q["x"])
[Array(1., dtype=float32), Array(3., dtype=float32)]
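The single-point case amounts to indexing every component array at the same position. A minimal list-based sketch, with take_point as a hypothetical stand-in for get_w_at, looks like this:

```python
def take_point(q, p, idx):
    """Pick the phase-space point w = (q, p) stored at position idx."""
    q_at = {k: col[idx] for k, col in q.items()}
    p_at = {k: col[idx] for k, col in p.items()}
    return q_at, p_at


pos = {"x": [1.0, 2.0, 3.0], "y": [4.0, 5.0, 6.0]}
vel = {"x": [0.1, 0.2, 0.3], "y": [0.4, 0.5, 0.6]}
q, p = take_point(pos, vel, 1)
print(q, p)  # {'x': 2.0, 'y': 5.0} {'x': 0.2, 'y': 0.5}
```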
Types#
- class phasecurvefit.WalkLocalFlowResult(positions: dict[str, Real[Array, 'N'] | Real[ndarray, 'N'] | Real[TypedNdArray, 'N']], velocities: dict[str, Real[Array, 'N'] | Real[ndarray, 'N'] | Real[TypedNdArray, 'N']], indices: Int[Array, 'N'], *, gamma_range: tuple[float, float] = (0.0, 1.0))
Bases: AbstractResult
Result of the local flow walk algorithm.
This class represents the complete output of the phase-flow walk algorithm. It contains the walk ordering, original phase-space data, and provides methods for examining and interpolating along the discovered stream.
- positions
Position dictionary with keys (e.g., “x”, “y”, “z”) and values as 1D arrays of shape (n_obs,). These are the original positions from the input, not reordered.
- velocities
Velocity dictionary with same keys and shape as positions. These are the original velocities from the input, not reordered.
- indices
Ordered indices of visited observations. Shape (n_obs,). Unvisited observations are marked with -1.
The walk order can be extracted by filtering: indices[indices >= 0]. See the ordering property for a convenience accessor.
- Type:
Int[Array, "n_obs"]
- gamma_range
Valid range of the ordering parameter in __call__. Default is (0.0, 1.0). This is a static field and cannot be changed after construction.
Notes
The walk algorithm discovers a path through phase-space by following the local flow defined by the velocity field. The ordering encodes which observations form a coherent sequence along this path.
Key distinction:
indices is an array of length n_obs where the position in the array indicates the order in the walk, and the value at that position is the original observation index. For example:
indices = [3, 7, 1, -1, 5, ...]
#          ^ 1st visited observation is index 3
#             ^ 2nd visited observation is index 7
#                ^ 3rd visited observation is index 1
#                   ^ 4th observation was not visited
#                       ^ 5th visited observation is index 5
Properties provide convenient access to:
visited: Boolean mask of visited observations
ordering: Indices in walk order (filtered non-negative)
ordered: Positions/velocities reordered by walk
skipped_indices: Indices of unvisited observations
The interpolation method (__call__()) enables smooth spatial interpolation along the discovered path using a continuous ordering parameter $\gamma \in [0, 1]$.
Examples
Basic Usage: Extract Ordering and Properties
>>> import jax.numpy as jnp
>>> import phasecurvefit as pcf
>>> pos = {
...     "x": jnp.linspace(0, 10, 20),
...     "y": jnp.sin(jnp.linspace(0, 2 * 3.14159, 20)),
... }
>>> vel = {"x": jnp.ones(20), "y": jnp.cos(jnp.linspace(0, 2 * 3.14159, 20))}
>>> result = pcf.walk_local_flow(pos, vel, start_idx=0, metric_scale=1.0)
>>> result.indices
Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19], dtype=int32)
>>> result.n_visited
Array(20, dtype=int32)
>>> result.n_skipped
Array(0, dtype=int32)
Accessing Ordered Data
>>> qs_ordered, vs_ordered = result.ordered
>>> qs_ordered["x"].shape
(20,)
Spatial Interpolation with Gamma Parameter
The walk result can be called as a function to interpolate spatial positions from an ordering parameter $\gamma \in [0, 1]$:
>>> gamma = jnp.array([0.0, 0.5, 1.0])
>>> positions_interp = result(gamma)
>>> positions_interp["x"]
Array([ 0.,  5., 10.], dtype=float32)
Scalar Interpolation
>>> pos_at_midpoint = result(0.5)
>>> pos_at_midpoint["x"]
Array(5., dtype=float32)
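As a rough mental model for calling the result with an ordering parameter, the sketch below interpolates one ordered component linearly between evenly spaced knots on [0, 1]. This is an assumption made for illustration only: the interpolation scheme the library actually uses is not specified here, and interpolate_path is a hypothetical helper that expects at least two values and gamma within [0, 1].

```python
def interpolate_path(values, gamma):
    """Piecewise-linear interpolation with evenly spaced knots on [0, 1]."""
    n = len(values)
    t = gamma * (n - 1)        # fractional knot position
    i = min(int(t), n - 2)     # left knot, clamped so i + 1 stays in range
    frac = t - i
    return (1.0 - frac) * values[i] + frac * values[i + 1]


ordered_x = [0.0, 5.0, 10.0]
print([interpolate_path(ordered_x, g) for g in (0.0, 0.5, 1.0)])  # [0.0, 5.0, 10.0]
```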
JAX Transformations: JIT Compilation
The interpolator is JIT-compatible for efficient compilation:
>>> import jax
>>> @jax.jit
... def get_position(gamma):
...     return result(gamma)
>>> get_position(0.25)
{'x': Array(2.5, dtype=float32), 'y': Array(0.9897884, dtype=float32)}
JAX Transformations: Vectorization with vmap
Interpolate multiple gamma values efficiently:
>>> gamma_batch = jnp.linspace(0, 1, 100)
>>> @jax.jit
... def interpolate_many(gammas):
...     return jax.vmap(result)(gammas)
>>> positions_batch = interpolate_many(gamma_batch)
>>> positions_batch["x"].shape
(100,)
JAX Transformations: Automatic Differentiation
Compute gradients of positions with respect to the ordering parameter:
>>> def loss(gamma):
...     pos = result(gamma)
...     return jnp.sum(pos["x"] ** 2 + pos["y"] ** 2)
>>> grad_fn = jax.grad(loss)
>>> grad_at_half = grad_fn(0.5)
Composition: JIT + vmap + grad
Combine transformations for maximum efficiency:
>>> @jax.jit
... def compute_gradients(gammas):
...     return jax.vmap(jax.grad(loss))(gammas)
>>> compute_gradients(jnp.linspace(0, 1, 50))
Array([ 0. , 5.6351056, 11.270211 , ...], dtype=float32)
>>> result.visited.shape
(20,)
- indices: Int[Array, 'N']
- property visited: Bool[Array, 'N']
Boolean array indicating which observations were visited.
- property n_visited: Int[Array, '']
Number of observations that were visited (not skipped).
- property n_skipped: Int[Array, '']
Number of observations that were not visited (skipped).
- property all_visited: Bool[Array, '']
Whether all observations were visited (no skips).
- property skipped_indices: Int[Array, 'n_skipped']
Indices of skipped observations (marked as -1 in indices).
- property ordering: Int[Array, 'n_visited']
Indices of visited observations in the order they were visited.
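The relationship between these properties and the raw indices array can be sketched in plain Python. summarize_walk is a hypothetical helper, and it assumes that the visited mask and skipped_indices are expressed in terms of original observation indices, which the descriptions above suggest but do not state outright.

```python
def summarize_walk(indices):
    """Derive the summary quantities from an indices list padded with -1."""
    n_obs = len(indices)
    ordering = [i for i in indices if i >= 0]  # walk order, padding removed
    visited = [False] * n_obs
    for i in ordering:
        visited[i] = True  # assumption: mask is indexed by observation
    skipped = [i for i in range(n_obs) if not visited[i]]
    return {
        "ordering": ordering,
        "visited": visited,
        "n_visited": len(ordering),
        "n_skipped": n_obs - len(ordering),
        "all_visited": len(ordering) == n_obs,
        "skipped_indices": skipped,
    }


summary = summarize_walk([2, 0, 3, -1, -1])
print(summary["ordering"], summary["skipped_indices"])  # [2, 0, 3] [1, 4]
```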
- phasecurvefit.ScalarComponents : TypeAlias = Mapping[str, FLikeSz0]#
Type alias for dictionaries mapping component names to scalar JAX arrays.
Used for single phase-space points. Keys are coordinate/component names (e.g., “x”, “y”, “z”), values are 0-dimensional JAX arrays.
Example:
position: ScalarComponents = {
    "x": jnp.array(1.0),
    "y": jnp.array(2.0),
}
- phasecurvefit.VectorComponents : TypeAlias = Mapping[str, FLikeSzN]#
Type alias for dictionaries mapping component names to 1D JAX arrays.
Used for arrays of phase-space points. Keys are coordinate/component names (e.g., “x”, “y”, “z”), values are 1-dimensional JAX arrays of shape (N,).
Example:
position: VectorComponents = {
    "x": jnp.array([0.0, 1.0, 2.0]),
    "y": jnp.array([0.0, 1.0, 2.0]),
}
Autoencoder Module#
Neural network for interpolating skipped tracers. See Autoencoder Guide for details.
Classes#
- class phasecurvefit.nn.PathAutoencoder(encoder: OrderingNet, decoder: TrackNet, normalizer: AbstractNormalizer)
Bases: AbstractAutoencoder
Autoencoder combining OrderingNet and TrackNet.
This autoencoder is trained to assign $\gamma$ values to stream tracers that were skipped by the phase-flow walk algorithm. It consists of two parts:
Interpolation Network: Maps phase-space coordinates $(x, v) \to (\gamma, p)$ where $\gamma \in [0, 1]$ is the ordering parameter and $p \in [0, 1]$ is the membership probability.
Param-Net (Decoder): Maps $\gamma \to x$, reconstructing the position from the ordering parameter.
- Parameters:
encoder (OrderingNet)
decoder (TrackNet)
normalizer (AbstractNormalizer)
- encoder: OrderingNet
- decoder: TrackNet
- normalizer: AbstractNormalizer
- classmethod make(normalizer: AbstractNormalizer, *, gamma_range: tuple[float, float], ordering_width_size: int = 100, ordering_depth: int = 2, track_width_size: int = 128, track_depth: int = 3, key: Key[Array, ''] | UInt32[Array, '2'])
- decode(gamma: Float[Array, 'N'], /, *, key: Key[Array, ''] | UInt32[Array, '2'] | None = None)
Decode $\gamma$ to reconstructed position.
- Parameters:
gamma (Array) – Ordering parameter, typically in the range defined by gamma_range, shape (N,). Some decoders may support extrapolation beyond this range.
key (PRNGKeyArray, optional) – JAX random key for stochastic decoding (if applicable).
- Returns:
position – Reconstructed dict of positions.
- Return type:
VectorComponents
- encode(qs: dict[str, Real[Array, 'N'] | Real[ndarray, 'N'] | Real[TypedNdArray, 'N']], ps: dict[str, Real[Array, 'N'] | Real[ndarray, 'N'] | Real[TypedNdArray, 'N']], /, *, key: Key[Array, ''] | UInt32[Array, '2'] | None = None)
Encode phase-space coordinates to ($\gamma$, $p$).
- Parameters:
qs (VectorComponents) – Position components, each a 1D array of shape (N,).
ps (VectorComponents) – Velocity components, each a 1D array of shape (N,).
key (PRNGKeyArray, optional) – JAX random key for stochastic encoding (if applicable).
- Return type:
tuple[Float[Array, 'N'], Float[Array, 'N']]
- Returns:
gamma (Array) – Ordering parameter in [0, 1], shape (N,).
prob (Array) – Membership probability in [0, 1], shape (N,).
- class phasecurvefit.nn.OrderingNet(in_size: int = 6, width_size: int = 100, depth: int = 2, *, gamma_range: tuple[float, float] = (0.0, 1.0), key: Key[Array, ''] | UInt32[Array, '2'])
Bases: Module
Interpolation network: $(x, v) \mapsto (\gamma, p)$.
This network takes N-D phase-space coordinates and outputs:
$\gamma \in [0, 1]$: The ordering parameter along the stream
$p \in [0, 1]$: The membership probability (1 = likely stream member)
The architecture follows Appendix B.3 of Nibauer et al. (2022).
Uses scan-over-layers for improved compilation speed. See: https://docs.kidger.site/equinox/tricks/#improve-compilation-speed-with-scan-over-layers
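As a rough illustration of how two output heads can be kept inside the stated intervals, the sketch below squashes raw head outputs with a logistic sigmoid and rescales $\gamma$ to gamma_range. The sigmoid and the helper names bounded_gamma and bounded_prob are assumptions made for illustration; the source does not state which squashing function OrderingNet uses.

```python
import math


def _sigmoid(z: float) -> float:
    """Numerically stable logistic function mapping any real z into [0, 1]."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def bounded_gamma(raw: float, gamma_range: tuple[float, float] = (0.0, 1.0)) -> float:
    """Map a raw head output onto gamma_range (hypothetical helper)."""
    lo, hi = gamma_range
    return lo + (hi - lo) * _sigmoid(raw)


def bounded_prob(raw: float) -> float:
    """Map a raw head output onto a membership probability (hypothetical helper)."""
    return _sigmoid(raw)


# A raw output of 0 lands on the midpoint of the requested range.
gamma = bounded_gamma(0.0, gamma_range=(-1.0, 1.0))
# A strongly positive raw output gives a probability close to 1.
prob = bounded_prob(3.0)
```

Whatever squashing function is used, the point is that the heads cannot emit values outside the documented intervals, regardless of the magnitude of the hidden activations.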
- Parameters:
in_size (int) – Number of spatial + kinematic dimensions (6 for 3D: x, y, z, vx, vy, vz).
width_size (int) – The size of each hidden layer.
depth (int, optional) –
The number of hidden layers, not including the input layer or output heads. For example, depth=2 results in a network with layers:
[Linear(in_size, width_size), Linear(width_size, width_size), Linear(width_size, out_size), (output_heads)]
key (PRNGKeyArray) – JAX random key for initialization.
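To make the depth convention concrete, the following sketch lists the (fan_in, fan_out) pair of each Linear layer for a given depth. The helper mlp_layer_shapes is hypothetical and not part of phasecurvefit, and out_size stands in for whatever width feeds the output heads, which the source does not specify.

```python
def mlp_layer_shapes(
    in_size: int, width_size: int, depth: int, out_size: int
) -> list[tuple[int, int]]:
    """Return (fan_in, fan_out) for each Linear layer of a plain MLP.

    `depth` counts hidden layers, so the MLP has depth + 1 Linear layers.
    `out_size` is an assumed name for the width passed to the output heads.
    """
    if depth < 0:
        raise ValueError("depth must be non-negative")
    if depth == 0:
        return [(in_size, out_size)]
    shapes = [(in_size, width_size)]
    shapes += [(width_size, width_size)] * (depth - 1)
    shapes.append((width_size, out_size))
    return shapes


# depth=2 reproduces the three Linear layers listed in the parameter description.
shapes = mlp_layer_shapes(in_size=6, width_size=100, depth=2, out_size=100)
```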
-
in_size:
int
-
width_size:
int
-
depth:
int
-
mlp:
MLP
-
gamma_head:
Linear
-
prob_head:
Linear
- class phasecurvefit.nn.TrackNet(out_size: int = 3, width_size: int = 100, depth: int = 3, *, key: Key[Array, ''] | UInt32[Array, '2'])
Bases:
Module
Param-Net (decoder): maps $\gamma \to$ position (x, y, z).
This network reconstructs the stream track position from the ordering parameter $\gamma$. It serves as the second half of the autoencoder.
The architecture follows Appendix B.1 of Nibauer et al. (2022).
Uses scan-over-layers for improved compilation speed. See: https://docs.kidger.site/equinox/tricks/#improve-compilation-speed-with-scan-over-layers
- Parameters:
key (PRNGKeyArray) – JAX random key for initialization.
out_size (int) – Number of spatial dimensions (2 for 2D, 3 for 3D) of the reconstructed track position.
width_size (int, optional) – Size of hidden layers. Default: 100.
depth (int, optional) – Number of hidden layers. Default: 3.
-
out_size:
int
-
width_size:
int
-
depth:
int
-
mlp:
MLP
- class phasecurvefit.nn.TrainingConfig(*, optimizer: ~optax._src.base.GradientTransformation = (<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), batch_size: int = 100, show_pbar: bool = True, member_threshold: float = 0.5, n_epochs_encoder: int = 800, lambda_prob: float = 1.0, n_epochs_decoder: int = 100, n_epochs_both: int = 200, lambda_q: float = 1.0, lambda_p: tuple[float, float] = (1.0, 5.0), weight_by_density: bool | ~collections.abc.Mapping[str, object] = False, freeze_encoder_final_training: bool = False)
Bases:
object
Configuration for three-phase autoencoder training.
- Parameters:
-
optimizer:
GradientTransformation= (<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>) Optax optimizer for training.
-
batch_size:
int= 100 Batch size for training.
-
show_pbar:
bool= True Show an epoch progress bar via tqdm.
-
member_threshold:
float= 0.5 Membership p > threshold for identifying stream members.
-
n_epochs_encoder:
int= 800 Number of epochs for Phase 1 training (OrderingNet)
-
lambda_prob:
float= 1.0 Weight for probability loss terms.
-
n_epochs_decoder:
int= 100 Number of epochs for Phase 2 training (TrackNet)
-
n_epochs_both:
int= 200 Number of epochs for Phase 3 training (encoder and decoder jointly)
-
lambda_q:
float= 1.0 Weight for phase-2 spatial training.
-
weight_by_density:
bool|Mapping[str,object] = False Whether to inverse density weight the samples. USE WITH CARE.
-
freeze_encoder_final_training:
bool= False Whether to freeze the encoder during the final training phase.
- property n_epochs: int
Return the total number of epochs.
- encoderonly_config()
Construct the OrderingNet config.
- Return type:
OrderingTrainingConfig
- decoderonly_config()
Construct the TrackNet config.
- Return type:
TrackTrainingConfig
- autoencoder_config()
Construct the Autoencoder config.
- Return type:
EncoderDecoderTrainingConfig
Training Functions#
- phasecurvefit.nn.train_autoencoder(model: PathAutoencoder, all_ws: Float[Array, 'N TwoF'], ordering_indices: Int[Array, 'N'], /, *, config: TrainingConfig | None = None, key: Key[Array, ''] | UInt32[Array, '2'])
Train the PathAutoencoder in two phases.
This function orchestrates the complete two-phase training procedure:
Phase 1 (OrderingNet/Encoder): Trains the encoder to predict $\gamma$ (ordering parameter) and $p$ (membership probability) from phase-space coordinates. Uses the ordering from the walk algorithm as supervision.
Phase 2 (TrackNet/Decoder): Trains the decoder to reconstruct spatial positions from $\gamma$ while aligning with velocity directions. Uses the trained encoder to filter stream members based on membership probability threshold.
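The Phase 1 supervision can be pictured as turning walk ordering indices into per-tracer targets. The sketch below is one plausible construction, not the library's confirmed implementation: it assumes that entry i of ordering_indices holds the rank of tracer i along the walk (or -1 if skipped), and that the $\gamma$ target is that rank rescaled linearly onto gamma_range. The helper name ordering_targets is hypothetical.

```python
def ordering_targets(ordering_indices, gamma_range=(0.0, 1.0)):
    """Build per-tracer gamma targets and ordered/skipped flags.

    Assumption: ordering_indices[i] is the rank of tracer i along the walk,
    with -1 marking tracers that the walk skipped.
    """
    lo, hi = gamma_range
    valid = [rank for rank in ordering_indices if rank >= 0]
    # Guard against division by zero when at most one tracer is ordered.
    span = max(max(valid) if valid else 0, 1)

    gamma_target = []
    is_ordered = []
    for rank in ordering_indices:
        if rank >= 0:
            gamma_target.append(lo + (hi - lo) * rank / span)
            is_ordered.append(True)
        else:
            gamma_target.append(None)  # no gamma supervision for skipped tracers
            is_ordered.append(False)
    return gamma_target, is_ordered


gamma_target, is_ordered = ordering_targets([0, 2, -1, 1, 4, 3])
```

Under these assumptions, skipped tracers contribute no $\gamma$ target but can still serve as negative examples for the membership probability.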
- Parameters:
model (EncoderExternalDecoder) – Untrained or partially trained autoencoder model to train. Its encoder will be updated.
all_ws (Array, shape (N, 2*n_dims)) – All phase-space coordinates (positions + velocities) in normalized form.
ordering_indices (Array, shape (N,)) – Ordering indices from walk algorithm. Valid indices (>= 0) indicate ordered tracers; -1 indicates skipped/unordered tracers.
config (OrderingTrainingConfig, optional) – Complete training configuration for both phases. If None (default), uses default configuration.
decoder_kwargs (Mapping, optional) – Keyword arguments passed to decoder function creation. For running-mean decoder, can include 'window_size'. If None, uses defaults.
key (PRNGKeyArray) – Random key for training (split internally for each phase).
- Return type:
- Returns:
result (AutoencoderResult) – Result containing the fully trained autoencoder and ordering data.
opt_states (dict[str, optax.OptState]) – Dictionary with ‘encoder’, ‘decoder’ and ‘both’ optimizer states.
losses (Array, shape (n_epochs_encoder + n_epochs_both,)) – Concatenated training losses from both phases.
- phasecurvefit.nn.train_autoencoder(model: AbstractAutoencoder, walk_results: WalkLocalFlowResult, /, *, config: TrainingConfig | None = None, key: Key[Array, ''] | UInt32[Array, '2']) -> tuple[AutoencoderResult, dict[str, PyTree], Float[Array, '{config.n_epochs}']]
- phasecurvefit.nn.train_autoencoder(model: EncoderExternalDecoder, all_ws: Float[Array, 'N TwoF'], ordering_indices: Int[Array, 'N'], /, *, config: OrderingTrainingConfig | TrainingConfig | None = None, key: Key[Array, ''] | UInt32[Array, '2']) -> tuple[AutoencoderResult, dict[str, PyTree], Float[Array, '{config.n_epochs}']]
Train the EncoderExternalDecoder encoder and create running-mean decoder.
This function provides a simplified training workflow:
1. Train the encoder (OrderingNet) using supervised learning from ordering indices
2. Create a running-mean decoder using the trained encoder and training data
Unlike train_autoencoder for PathAutoencoder, this does not train a decoder network. Instead, it uses the provided (or default) decoder function.
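For intuition, a running-mean decoder can be thought of as a boxcar average of training positions over a window in $\gamma$. The sketch below is a generic version of that idea, not phasecurvefit's RunningMeanDecoder: the function name running_mean_decode, the treatment of window_size as a full window width, and the nearest-neighbor fallback for empty windows are all assumptions.

```python
def running_mean_decode(gamma_query, gamma_train, positions_train, window_size=0.05):
    """Average the training positions whose gamma lies inside the window.

    Assumptions: window_size is the full width of the window centered on
    gamma_query, and an empty window falls back to the nearest training point.
    """
    half = window_size / 2.0
    inside = [
        pos
        for g, pos in zip(gamma_train, positions_train)
        if abs(g - gamma_query) <= half
    ]
    if not inside:
        nearest = min(
            range(len(gamma_train)), key=lambda i: abs(gamma_train[i] - gamma_query)
        )
        inside = [positions_train[nearest]]
    n_dims = len(inside[0])
    # Component-wise mean over the selected positions.
    return tuple(sum(pos[d] for pos in inside) / len(inside) for d in range(n_dims))


gamma_train = [0.0, 0.25, 0.5, 0.75, 1.0]
positions_train = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0), (3.0, 0.0), (4.0, 0.0)]
track_point = running_mean_decode(0.5, gamma_train, positions_train, window_size=0.5)
```

Because such a decoder has no trainable parameters, only the encoder needs an optimizer, which is consistent with the single optimizer state this overload returns.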
- Returns:
result (AutoencoderResult) – Result containing the trained autoencoder and ordering data.
opt_state (dict[str, PyTree]) – Optimizer state from encoder training (wrapped in dict for consistency).
losses (Array, shape (n_epochs,)) – Training losses from encoder training.
Examples
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> import phasecurvefit as pcf
>>> key = jr.key(0)
>>> N = 50
>>> positions = {"x": jnp.arange(N, dtype=float), "y": jnp.zeros(N)}
>>> velocities = {"x": jnp.ones(N), "y": jnp.zeros(N)}
>>> ordering = jnp.arange(N)
>>> model = pcf.nn.EncoderExternalDecoder(
...     pcf.nn.OrderingNet(in_size=4, width_size=32, depth=2, key=jr.key(1)),
...     pcf.nn.RunningMeanDecoder(window_size=0.05),
...     pcf.nn.StandardScalerNormalizer(positions, velocities),
... )
Train (with minimal epochs for demonstration)
>>> qs_norm, ps_norm = model.normalizer.transform(positions, velocities)
>>> ws_norm = jnp.concat([qs_norm, ps_norm], axis=1)
>>> config = pcf.nn.OrderingTrainingConfig(
...     n_epochs=10, batch_size=16, show_pbar=False
... )
>>> result, opt_state, losses = pcf.nn.train_autoencoder(
...     model, ws_norm, ordering, config=config, key=jr.key(2)
... )
>>> losses.shape
(10,)
- phasecurvefit.nn.fill_ordering_gaps(model: AbstractAutoencoder, result: AbstractResult, /, prob_threshold: float = 0.5)
Use trained autoencoder to fill gaps in phase-flow walk ordering.
This function predicts $\gamma$ values for all tracers (including those skipped by phase-flow walk) and returns a complete ordering.
- Parameters:
model (PathAutoencoder) – Trained autoencoder model.
result (AbstractResult) – Result from walk_local_flow.
prob_threshold (float, optional) – Minimum membership probability to include. Default: 0.5.
- Returns:
result – Complete ordering including previously skipped tracers.
- Return type:
AutoencoderResult
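Conceptually, gap filling amounts to keeping tracers whose predicted membership probability reaches prob_threshold and sorting them by predicted $\gamma$. The sketch below shows that logic on plain lists. The helper complete_ordering is hypothetical, and the inclusive comparison (>=) is an assumption based on the phrase "minimum membership probability"; it is not taken from the library's source.

```python
def complete_ordering(gamma, prob, prob_threshold=0.5):
    """Return tracer indices with prob >= prob_threshold, sorted by increasing gamma."""
    kept = [i for i, p in enumerate(prob) if p >= prob_threshold]
    return sorted(kept, key=lambda i: gamma[i])


# Stand-in encoder outputs for four tracers (illustrative values only).
gamma = [0.9, 0.1, 0.5, 0.3]
prob = [0.8, 0.95, 0.2, 0.6]
order = complete_ordering(gamma, prob, prob_threshold=0.5)
```

Here tracer 2 is dropped because its probability is below the threshold, and the remaining tracers are ordered along the stream by $\gamma$.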
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> import phasecurvefit as pcf
>>> pos = {"x": jnp.linspace(0, 5, 20), "y": jnp.zeros(20)}
>>> vel = {"x": jnp.ones(20), "y": jnp.zeros(20)}
>>> result = pcf.walk_local_flow(pos, vel, start_idx=0, metric_scale=1.0)
>>> keys = jax.random.split(jax.random.key(0), 2)
>>> normalizer = pcf.nn.StandardScalerNormalizer(pos, vel)
>>> model = pcf.nn.PathAutoencoder.make(
...     normalizer, gamma_range=result.gamma_range, key=keys[0]
... )
>>> cfg = pcf.nn.TrainingConfig(show_pbar=False)
>>> result, *_ = pcf.nn.train_autoencoder(model, result, config=cfg, key=keys[1])