Autoencoder for Gap Filling

The walk algorithm skips some tracers due to the momentum condition. This guide explains how to use an autoencoder to assign ordering values (\(\gamma\)) to these skipped tracers.

Problem and Solution

Problem: because of the momentum condition, the walk inevitably skips tracers that do not lie along the velocity direction, so those tracers get no ordering value \(\gamma\).
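To make the skipping concrete, here is a minimal sketch of one plausible form of a momentum condition. A candidate next tracer is accepted only if the step toward it points along the current velocity (positive cosine). The function name `passes_momentum_condition` and the `min_cos` threshold are hypothetical, and the actual criterion used by `pcf.walk_local_flow` may differ.

```python
import numpy as np


def passes_momentum_condition(x_cur, v_cur, x_next, min_cos=0.0):
    """Hypothetical check: accept the step x_cur -> x_next only if it points
    along the current velocity, i.e. cos(angle(step, v_cur)) > min_cos."""
    step = np.asarray(x_next, dtype=float) - np.asarray(x_cur, dtype=float)
    v = np.asarray(v_cur, dtype=float)
    cos = step @ v / (np.linalg.norm(step) * np.linalg.norm(v))
    return bool(cos > min_cos)


x0 = np.array([0.0, 0.0])
v0 = np.array([1.0, 0.0])       # current tracer moves in +x
ahead = np.array([1.0, 0.2])    # roughly along the flow
behind = np.array([-0.1, 1.0])  # slightly behind and off to the side

print(passes_momentum_condition(x0, v0, ahead))   # True
print(passes_momentum_condition(x0, v0, behind))  # False
```

A tracer that only ever appears "behind" or "beside" the walk's current position under such a test is never visited, which is exactly the gap the autoencoder is meant to fill.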

Solution: An autoencoder with two networks:

  • Encoder: \((x, v) \rightarrow (\gamma, p)\), predicting the ordering \(\gamma\) and the membership probability \(p\)

  • Decoder: \(\gamma \rightarrow x\), reconstructing the position from the ordering
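The input and output shapes of the two networks can be sketched with two small NumPy multilayer perceptrons. This is an illustrative stand-in, not the architecture of `pcf.nn.PathAutoencoder`. The layer widths, the `tanh` squashing of \(\gamma\) into \([-1, 1]\), and the sigmoid for \(p\) are all assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)


def init_mlp(sizes, rng):
    """Random weights and zero biases for a small fully connected network."""
    return [
        (rng.normal(0.0, 1.0 / np.sqrt(m), (m, n)), np.zeros(n))
        for m, n in zip(sizes[:-1], sizes[1:])
    ]


def mlp(params, h):
    """Forward pass: tanh hidden layers followed by a linear output layer."""
    for W, b in params[:-1]:
        h = np.tanh(h @ W + b)
    W, b = params[-1]
    return h @ W + b


n_dim = 2
encoder = init_mlp([2 * n_dim, 32, 32, 2], rng)  # (x, v) -> two raw outputs
decoder = init_mlp([1, 32, 32, n_dim], rng)      # gamma -> x


def encode(x, v):
    out = mlp(encoder, np.concatenate([x, v], axis=-1))
    gamma = np.tanh(out[..., 0])             # ordering squashed into [-1, 1]
    p = 1.0 / (1.0 + np.exp(-out[..., 1]))   # membership probability in (0, 1)
    return gamma, p


def decode(gamma):
    return mlp(decoder, gamma[..., None])    # one scalar gamma per tracer -> position


x = rng.normal(size=(50, n_dim))
v = rng.normal(size=(50, n_dim))
gamma, p = encode(x, v)
x_hat = decode(gamma)
print(gamma.shape, p.shape, x_hat.shape)  # (50,) (50,) (50, 2)
```

The key design point is the bottleneck: the decoder sees only one scalar per tracer, so \(\gamma\) is forced to act as a coordinate along the stream.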

The encoder learns from the walk-ordered tracers and generalizes to predict \(\gamma\) for skipped tracers.
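This generalization step is ordinary supervised interpolation: fit on tracers that have a walk ordering, then evaluate on those that don't. In this sketch a linear least-squares fit deliberately stands in for the neural encoder, \(\gamma\) is assumed to be evenly spaced in \([-1, 1]\) along the walk, and every fifth tracer is pretended to have been skipped.

```python
import numpy as np

# Toy stream built like the Quick Start data below
t = np.linspace(0, np.pi, 50)
x = np.stack([np.linspace(0, 5, 50), np.sin(t)], axis=1)
v = np.stack([np.ones(50), np.cos(t)], axis=1)
gamma_true = np.linspace(-1, 1, 50)  # assumed walk ordering, evenly spaced

# Pretend the walk skipped every fifth tracer
skipped = np.zeros(50, dtype=bool)
skipped[::5] = True

# Features (x, v, 1); a linear least-squares fit stands in for the encoder
features = np.hstack([x, v, np.ones((50, 1))])
w, *_ = np.linalg.lstsq(features[~skipped], gamma_true[~skipped], rcond=None)

# "Generalize": predict gamma for the tracers the fit never saw
gamma_pred = features[skipped] @ w
print(np.max(np.abs(gamma_pred - gamma_true[skipped])))
```

On this toy stream \(\gamma\) happens to be exactly linear in the \(x\) coordinate, so a linear fit recovers the skipped values; real streams bend and overlap, which is why a nonlinear encoder is used instead.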

Quick Start

import jax
import jax.numpy as jnp
import phasecurvefit as pcf

# Get initial ordering from walk
pos = {"x": jnp.linspace(0, 5, 50), "y": jnp.sin(jnp.linspace(0, jnp.pi, 50))}
vel = {"x": jnp.ones(50), "y": jnp.cos(jnp.linspace(0, jnp.pi, 50))}
walkresult = pcf.walk_local_flow(pos, vel, start_idx=0, metric_scale=1.0)

# Create normalizer and autoencoder
key = jax.random.key(0)
# Use separate keys for initialization and training instead of reusing one
model_key, train_key = jax.random.split(key)
normalizer = pcf.nn.StandardScalerNormalizer(pos, vel)
autoencoder = pcf.nn.PathAutoencoder.make(
    normalizer, gamma_range=walkresult.gamma_range, key=model_key
)

# Train autoencoder
config = pcf.nn.TrainingConfig(show_pbar=False)
result, _, losses = pcf.nn.train_autoencoder(
    autoencoder, walkresult, config=config, key=train_key
)

gamma = result.gamma
ordered_all = result.indices

How It Works

  1. Initialization: Walk assigns \(\gamma \in [-1, 1]\) to ordered tracers

  2. Phase 1: Encoder learns to predict \(\gamma\) from phase-space coordinates

  3. Phase 2: Both networks train together under a momentum constraint, which enforces velocity alignment

  4. Membership: Network outputs probability \(p\) to distinguish stream from background
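One way to read the Phase 2 objective is as a weighted sum of a reconstruction term (weighted by `lambda_q`) and a velocity-alignment term (weighted by `lambda_p`). The sketch below assumes that the alignment term penalizes \(1 - \cos\) of the angle between the decoder tangent \(dx/d\gamma\) and the tracer velocity \(v\). This form is an assumption, not taken from the library. The membership term for \(p\) is omitted, and a hand-written curve stands in for the trained decoder.

```python
import numpy as np


def reconstruction_loss(decode, gamma, x):
    """Mean squared distance between tracers and their decoded positions."""
    return np.mean(np.sum((decode(gamma) - x) ** 2, axis=-1))


def momentum_loss(decode, gamma, v, eps=1e-4):
    """Assumed alignment penalty: 1 - cos(angle) between the decoder tangent
    dx/dgamma (central finite difference) and the tracer velocity v."""
    tangent = (decode(gamma + eps) - decode(gamma - eps)) / (2 * eps)
    cos = np.sum(tangent * v, axis=-1) / (
        np.linalg.norm(tangent, axis=-1) * np.linalg.norm(v, axis=-1)
    )
    return np.mean(1.0 - cos)


def phase2_loss(decode, gamma, x, v, lambda_q=1.0, lambda_p=150.0):
    """Weighted sum of the two assumed Phase 2 terms (weights are illustrative)."""
    rec = reconstruction_loss(decode, gamma, x)
    mom = momentum_loss(decode, gamma, v)
    return lambda_q * rec + lambda_p * mom


def decoder(g):
    """Stand-in decoder: gamma in [-1, 1] -> point on y = sin(pi * x / 5)."""
    x0 = 2.5 * (g + 1.0)
    return np.stack([x0, np.sin(np.pi * x0 / 5)], axis=-1)


gamma = np.linspace(-1, 1, 50)
x = decoder(gamma)
# Velocities tangent to the curve, pointing toward increasing gamma
v = np.stack([np.ones(50), (np.pi / 5) * np.cos(np.pi * x[:, 0] / 5)], axis=-1)

print(phase2_loss(decoder, gamma, x, v))   # close to 0: gamma increases with the flow
print(phase2_loss(decoder, gamma, x, -v))  # close to 300 = 2 * lambda_p: gamma runs against the flow
```

Flipping the velocities leaves the reconstruction term unchanged but drives the alignment term to its maximum of 2, which illustrates why a large `lambda_p` pushes \(\gamma\) to increase in the direction of motion.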

Customizing Training

The defaults should work for most cases; to change them, pass keyword arguments to pcf.nn.TrainingConfig:

config = pcf.nn.TrainingConfig(
    n_epochs_encoder=800,  # Encoder-only epochs
    n_epochs_decoder=100,  # Decoder-only epochs
    n_epochs_both=200,  # Joint encoder + decoder epochs
    batch_size=100,  # Batch size for training
    lambda_prob=1.0,  # Probability loss weight
    lambda_q=1.0,  # Spatial reconstruction loss weight
    lambda_p=(1.0, 150.0),  # Velocity alignment loss weight range
    show_pbar=False,
)

result, _, losses = pcf.nn.train_autoencoder(
    autoencoder, walkresult, config=config, key=key
)

Key parameters:

  • lambda_p: (min, max) range for the velocity-alignment weight; a higher maximum (100-150) enforces stronger velocity alignment in Phase 2

  • n_epochs_encoder: Should be ~200-500 for good initial interpolation

  • batch_size: Larger batches are more stable but require more memory

  • lambda_q: Weight for spatial reconstruction loss in Phase 2
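The text above describes `lambda_p` only as a weight range. As a hypothetical illustration, the sketch below ramps it linearly from the lower to the upper bound across the Phase 2 epochs, and shows how `batch_size` splits a shuffled set of tracers into minibatches. Both the linear schedule and the helper names `lambda_p_schedule` and `minibatches` are assumptions, not the documented behavior of `pcf.nn.train_autoencoder`.

```python
import numpy as np


def lambda_p_schedule(lambda_p=(1.0, 150.0), n_epochs=200):
    """Hypothetical linear ramp of the velocity-alignment weight over Phase 2."""
    lo, hi = lambda_p
    return np.linspace(lo, hi, n_epochs)


def minibatches(n_tracers, batch_size, rng):
    """Shuffle tracer indices and split them into batches of at most batch_size."""
    perm = rng.permutation(n_tracers)
    return [perm[i:i + batch_size] for i in range(0, n_tracers, batch_size)]


weights = lambda_p_schedule()
batches = minibatches(250, 100, np.random.default_rng(0))
print(weights[0], weights[-1], len(weights))  # 1.0 150.0 200
print([len(b) for b in batches])              # [100, 100, 50]
```

Under a schedule like this, the early Phase 2 epochs are dominated by reconstruction and the alignment constraint tightens gradually. This is one common reason to specify a weight as a range rather than as a single value.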