CEBRA Best Practices with 🐟iTuna#
This notebook is based on the “Best Practices for Training CEBRA models” notebook.
This demo shows a complete workflow for training CEBRA models with consistency evaluation using iTuna. We cover:
- Setting up a CEBRA model
- Loading neural data
- Train/validation splits
- Consistency evaluation with ConsistencyEnsemble
- Visualizing and interpreting results
- Grid search for hyperparameters
Prerequisites#
pip install ituna cebra[datasets,integrations]
import numpy as np
import matplotlib.pyplot as plt
from cebra import CEBRA
import cebra.datasets
import ituna
from ituna import ConsistencyEnsemble, metrics
1. Set up a CEBRA Model#
CEBRA is a self-supervised representation learning method for neural data. It learns embeddings that capture the temporal structure of neural activity.
CEBRA models are identifiable up to an affine transformation, so we use metrics.Linear() (which includes the intercept) as our indeterminacy class.
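To make the indeterminacy concrete, here is a minimal sketch in plain NumPy (not iTuna API): two copies of the same embedding that differ only by a linear map plus an intercept can be aligned almost perfectly with a least-squares fit. This is the class of transformations that metrics.Linear() discounts when comparing models.
# Illustration only: an affine map relates two "runs" of the same embedding
rng = np.random.default_rng(0)
emb_a = rng.normal(size=(1000, 3))   # embedding from one model seed
A = rng.normal(size=(3, 3))          # arbitrary linear transform
b = rng.normal(size=3)               # arbitrary intercept
emb_b = emb_a @ A + b                # same geometry, different coordinates

# Least-squares alignment with an intercept column recovers the map
X = np.hstack([emb_a, np.ones((len(emb_a), 1))])
coef, *_ = np.linalg.lstsq(X, emb_b, rcond=None)
print(f"Alignment residual: {np.linalg.norm(X @ coef - emb_b):.2e}")  # ~0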
# Define a CEBRA-Time model
cebra_model = CEBRA(
    model_architecture="offset10-model",
    batch_size=512,
    learning_rate=3e-4,
    temperature=1.12,
    max_iterations=500,
    conditional="time",
    output_dimension=3,
    distance="cosine",
    device="cuda_if_available",
    verbose=True,
    time_offsets=10,
)
2. Load the Data#
We’ll use the rat hippocampus dataset from CEBRA’s built-in datasets. It contains neural recordings from the hippocampus of a single rat during spatial navigation on a linear track, together with the animal’s position as continuous labels.
# Load hippocampus dataset
hippocampus = cebra.datasets.init("rat-hippocampus-single-achilles")
neural_data = hippocampus.neural.numpy()
position_labels = hippocampus.continuous_index.numpy()
print(f"Neural data shape: {neural_data.shape}")
print(f"Position labels shape: {position_labels.shape}")
3. Create Train/Validation Split#
For proper evaluation, we split the data temporally into training and validation sets: a time-based split keeps the validation samples from time points the model never saw, whereas a random split would leak temporal structure between the two sets.
# Time-based split (80% train, 20% validation)
split_idx = int(len(neural_data) * 0.8)
train_data = neural_data[:split_idx]
val_data = neural_data[split_idx:]
train_labels = position_labels[:split_idx]
val_labels = position_labels[split_idx:]
print(f"Train data: {train_data.shape}")
print(f"Validation data: {val_data.shape}")
4. Fit with ConsistencyEnsemble#
Now we wrap the CEBRA model in a ConsistencyEnsemble to train multiple instances and evaluate consistency.
# Create ConsistencyEnsemble with Linear indeterminacy (for CEBRA)
ensemble = ConsistencyEnsemble(
    estimator=cebra_model,
    consistency_transform=metrics.PairwiseConsistency(
        indeterminacy=metrics.Linear(),  # CEBRA is identifiable up to an affine (linear + intercept) transform
        symmetric=False,
        include_diagonal=True,
    ),
    random_states=5,  # Train 5 models
)
# Fit on training data
ensemble.fit(train_data)
# Evaluate consistency
train_score = ensemble.score(train_data)
print(f"Train consistency score: {train_score:.4f}")
# Also check on validation data
val_score = ensemble.score(val_data)
print(f"Validation consistency score: {val_score:.4f}")
5. Visualize Embeddings#
Let’s visualize the learned embeddings colored by position.
# Get aligned embeddings
train_embeddings = ensemble.transform(train_data)
val_embeddings = ensemble.transform(val_data)
print(f"Train embedding shape: {train_embeddings.shape}")
print(f"Validation embedding shape: {val_embeddings.shape}")
# Plot 3D embeddings
fig = plt.figure(figsize=(12, 5))
# Train embeddings
ax1 = fig.add_subplot(121, projection="3d")
scatter1 = ax1.scatter(
    train_embeddings[:, 0],
    train_embeddings[:, 1],
    train_embeddings[:, 2],
    c=train_labels[:, 0],
    cmap="rainbow",
    s=1,
    alpha=0.5,
)
ax1.set_title(f"Train (consistency: {train_score:.3f})")
ax1.set_xlabel("Dim 1")
ax1.set_ylabel("Dim 2")
ax1.set_zlabel("Dim 3")
# Validation embeddings
ax2 = fig.add_subplot(122, projection="3d")
scatter2 = ax2.scatter(
    val_embeddings[:, 0],
    val_embeddings[:, 1],
    val_embeddings[:, 2],
    c=val_labels[:, 0],
    cmap="rainbow",
    s=1,
    alpha=0.5,
)
ax2.set_title(f"Validation (consistency: {val_score:.3f})")
ax2.set_xlabel("Dim 1")
ax2.set_ylabel("Dim 2")
ax2.set_zlabel("Dim 3")
# Colorbars map point color back to position
fig.colorbar(scatter1, ax=ax1, shrink=0.6, label="Position")
fig.colorbar(scatter2, ax=ax2, shrink=0.6, label="Position")
plt.tight_layout()
plt.show()
6. Analyze Pairwise Consistency#
We can examine the consistency between individual model pairs.
# Get detailed pairwise scores
pairs, scores = train_embeddings.scores
print("Pairwise consistency scores:")
for (i, j), score in zip(pairs, scores):
    print(f"  Model {i} -> Model {j}: {score:.4f}")

print(f"\nMean pairwise score: {np.mean(scores):.4f}")
print(f"Std pairwise score: {np.std(scores):.4f}")
7. Grid Search with Consistency#
We can use sklearn’s GridSearchCV with iTuna to find hyperparameters that yield consistent representations.
from sklearn.model_selection import GridSearchCV
# Define parameter grid
param_grid = {
    "estimator__temperature": [0.5, 1.0, 1.5],
    "estimator__output_dimension": [3, 8],
}
# Create base ensemble
base_ensemble = ConsistencyEnsemble(
    estimator=CEBRA(
        model_architecture="offset10-model",
        batch_size=512,
        learning_rate=3e-4,
        max_iterations=200,  # Fewer iterations for grid search
        conditional="time",
        distance="cosine",
        device="cuda_if_available",
        verbose=False,
        time_offsets=10,
    ),
    consistency_transform=metrics.PairwiseConsistency(
        indeterminacy=metrics.Linear(),
        symmetric=False,
    ),
    random_states=3,
)
# Run grid search
# Note: scoring=None falls back to ConsistencyEnsemble.score(),
# so the consistency score itself is the optimization target
grid_search = GridSearchCV(
    base_ensemble,
    param_grid,
    cv=2,
    scoring=None,  # use the ensemble's own (R2-style) consistency score
    verbose=1,
    n_jobs=1,
)
# Fit (this will take a while)
# grid_search.fit(train_data)
# Uncomment after running grid search:
# print(f"Best parameters: {grid_search.best_params_}")
# print(f"Best consistency score: {grid_search.best_score_:.4f}")
Using Backends for Large Experiments#
For large-scale experiments with many hyperparameters, use iTuna’s caching backends to avoid re-training.
# Enable disk caching for grid search
with ituna.config.config_context(DEFAULT_BACKEND="disk_cache"):
    # Models will be cached, so re-running is fast
    ensemble.fit(train_data)
    print(f"Consistency: {ensemble.score(train_data):.4f}")
Summary#
Key takeaways for CEBRA with iTuna:
- Use metrics.Linear() for CEBRA - CEBRA embeddings are identifiable up to linear transformations
- Train multiple seeds - Use random_states=5 or more for robust consistency estimates
- Check both train and validation - High consistency on both suggests stable representations
- Use caching for grid search - Enable the disk_cache backend to avoid re-training
- Consistency score > 0.9 - Generally indicates reliable, reproducible embeddings
For more examples, see:
- ituna-experiments/cebra/ - Extended CEBRA experiments
- iTuna Reference.ipynb - Comprehensive reference notebook