CEBRA Best Practices with 🐟iTuna#

This notebook is based on the “Best Practices for Training CEBRA models” notebook.

This demo shows a complete workflow for training CEBRA models and evaluating their consistency with iTuna. We cover:

  1. Setting up a CEBRA model

  2. Loading neural data

  3. Train/validation splits

  4. Consistency evaluation with ConsistencyEnsemble

  5. Visualizing and interpreting results

  6. Grid search for hyperparameters

Prerequisites#

pip install ituna "cebra[datasets,integrations]"
import numpy as np
import matplotlib.pyplot as plt
from cebra import CEBRA
import cebra.datasets

import ituna
from ituna import ConsistencyEnsemble, metrics

1. Set up a CEBRA Model#

CEBRA is a self-supervised representation learning method for neural data. It learns embeddings that capture the temporal structure of neural activity.

CEBRA models are identifiable up to an affine transformation, so we use metrics.Linear() (which includes the intercept) as our indeterminacy class.
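
To make that choice concrete, here is a toy illustration (separate from the workflow, and not necessarily how iTuna implements the metric internally): an embedding and an affinely transformed copy of it are perfectly consistent under a linear fit with intercept.

# Toy example: two embeddings that differ only by an affine map are
# equivalent under the Linear indeterminacy class. We mimic the idea
# with scikit-learn's LinearRegression.
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
Z1 = rng.normal(size=(1000, 3))            # "embedding" from model 1
A, b = rng.normal(size=(3, 3)), rng.normal(size=3)
Z2 = Z1 @ A + b                            # same embedding, affinely transformed

# Fitting Z1 -> Z2 with an intercept recovers R^2 ~ 1.0, i.e. fully consistent
print(LinearRegression().fit(Z1, Z2).score(Z1, Z2))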

# Define a CEBRA-Time model
cebra_model = CEBRA(
    model_architecture="offset10-model",
    batch_size=512,
    learning_rate=3e-4,
    temperature=1.12,
    max_iterations=500,
    conditional="time",  # CEBRA-Time: self-supervised, no behavior labels
    output_dimension=3,  # 3D embedding for easy visualization
    distance="cosine",
    device="cuda_if_available",
    verbose=True,
    time_offsets=10,  # positive pairs drawn up to 10 samples apart
)

2. Load the Data#

We’ll use the rat hippocampus dataset from CEBRA’s built-in datasets. It contains neural recordings from the hippocampus during spatial navigation.

# Load hippocampus dataset
hippocampus = cebra.datasets.init("rat-hippocampus-single-achilles")

neural_data = hippocampus.neural.numpy()
position_labels = hippocampus.continuous_index.numpy()

print(f"Neural data shape: {neural_data.shape}")
print(f"Position labels shape: {position_labels.shape}")

3. Create Train/Validation Split#

For proper evaluation, we split the data temporally into training and validation sets.

# Time-based split (80% train, 20% validation)
split_idx = int(len(neural_data) * 0.8)

train_data = neural_data[:split_idx]
val_data = neural_data[split_idx:]

train_labels = position_labels[:split_idx]
val_labels = position_labels[split_idx:]

print(f"Train data: {train_data.shape}")
print(f"Validation data: {val_data.shape}")

4. Fit with ConsistencyEnsemble#

Now we wrap the CEBRA model in a ConsistencyEnsemble to train multiple instances and evaluate consistency.

# Create ConsistencyEnsemble with Linear indeterminacy (for CEBRA)
ensemble = ConsistencyEnsemble(
    estimator=cebra_model,
    consistency_transform=metrics.PairwiseConsistency(
        indeterminacy=metrics.Linear(),  # CEBRA is identifiable up to linear transform
        symmetric=False,
        include_diagonal=True,
    ),
    random_states=5,  # Train 5 models
)

# Fit on training data
ensemble.fit(train_data)

# Evaluate consistency
train_score = ensemble.score(train_data)
print(f"Train consistency score: {train_score:.4f}")

# Also check on validation data
val_score = ensemble.score(val_data)
print(f"Validation consistency score: {val_score:.4f}")

5. Visualize Embeddings#

Let’s visualize the learned embeddings colored by position.

# Get aligned embeddings
train_embeddings = ensemble.transform(train_data)
val_embeddings = ensemble.transform(val_data)

print(f"Train embedding shape: {train_embeddings.shape}")
print(f"Validation embedding shape: {val_embeddings.shape}")

# Plot 3D embeddings
fig = plt.figure(figsize=(12, 5))

# Train embeddings
ax1 = fig.add_subplot(121, projection="3d")
ax1.scatter(
    train_embeddings[:, 0],
    train_embeddings[:, 1],
    train_embeddings[:, 2],
    c=train_labels[:, 0],
    cmap="rainbow",
    s=1,
    alpha=0.5,
)
ax1.set_title(f"Train (consistency: {train_score:.3f})")
ax1.set_xlabel("Dim 1")
ax1.set_ylabel("Dim 2")
ax1.set_zlabel("Dim 3")

# Validation embeddings
ax2 = fig.add_subplot(122, projection="3d")
ax2.scatter(
    val_embeddings[:, 0],
    val_embeddings[:, 1],
    val_embeddings[:, 2],
    c=val_labels[:, 0],
    cmap="rainbow",
    s=1,
    alpha=0.5,
)
ax2.set_title(f"Validation (consistency: {val_score:.3f})")
ax2.set_xlabel("Dim 1")
ax2.set_ylabel("Dim 2")
ax2.set_zlabel("Dim 3")

plt.tight_layout()
plt.show()

6. Analyze Pairwise Consistency#

We can examine the consistency between individual model pairs.

# Get detailed pairwise scores
pairs, scores = train_embeddings.scores

print("Pairwise consistency scores:")
for (i, j), score in zip(pairs, scores):
    print(f"  Model {i} -> Model {j}: {score:.4f}")

print(f"\nMean pairwise score: {np.mean(scores):.4f}")
print(f"Std pairwise score: {np.std(scores):.4f}")

7. Grid Search with Consistency#

We can use sklearn’s GridSearchCV with iTuna to find hyperparameters that yield consistent representations.

from sklearn.model_selection import GridSearchCV

# Define parameter grid
param_grid = {
    "estimator__temperature": [0.5, 1.0, 1.5],
    "estimator__output_dimension": [3, 8],
}

# Create base ensemble
base_ensemble = ConsistencyEnsemble(
    estimator=CEBRA(
        model_architecture="offset10-model",
        batch_size=512,
        learning_rate=3e-4,
        max_iterations=200,  # Fewer iterations for grid search
        conditional="time",
        distance="cosine",
        device="cuda_if_available",
        verbose=False,
        time_offsets=10,
    ),
    consistency_transform=metrics.PairwiseConsistency(
        indeterminacy=metrics.Linear(),
        symmetric=False,
    ),
    random_states=3,
)

# Run grid search
# Note: with no `scoring` argument, GridSearchCV uses the estimator's own
# score() method, so the consistency score is the optimization target.
grid_search = GridSearchCV(
    base_ensemble,
    param_grid,
    cv=2,
    verbose=1,
    n_jobs=1,
)

# Fit (this will take a while) - uncomment to run:
# grid_search.fit(train_data)

# Then inspect the best configuration:
# print(f"Best parameters: {grid_search.best_params_}")
# print(f"Best consistency score: {grid_search.best_score_:.4f}")

Using Backends for Large Experiments#

For large-scale experiments with many hyperparameters, use iTuna’s caching backends to avoid re-training.

# Enable disk caching for grid search
with ituna.config.config_context(DEFAULT_BACKEND="disk_cache"):
    # Models will be cached, so re-running is fast
    ensemble.fit(train_data)
    print(f"Consistency: {ensemble.score(train_data):.4f}")

Summary#

Key takeaways for CEBRA with iTuna:

  1. Use metrics.Linear() for CEBRA - CEBRA embeddings are identifiable up to linear transformations

  2. Train multiple seeds - Use random_states=5 or more for robust consistency estimates

  3. Check both train and validation - High consistency on both suggests stable representations

  4. Use caching for grid search - Enable disk_cache backend to avoid re-training

  5. Look for consistency scores > 0.9 - these generally indicate reliable, reproducible embeddings

For more examples, see:

  • ituna-experiments/cebra/ - Extended CEBRA experiments

  • iTuna Reference.ipynb - Comprehensive reference notebook