multivelovae.VAEChrom

class multivelovae.VAEChrom(adata, adata_atac=None, dim_z=None, batch_key=None, ref_batch=None, batch_hvg_key=None, var_to_regress=None, device='cpu', hidden_size=256, full_vb=False, parallel_arch=True, t_network=True, four_basis=False, run_2nd_stage=True, tmax=1, init_method='steady', init_tprior=None, tprior=None, deming_std=False, rna_only=False, rna_only_idx=[], learning_rate=None, early_stop_thred=None, checkpoints=[None, None], plot_init=False, gene_plot=[], cluster_key='clusters', figure_path='figures', embed=None, vram_constrained=False)

MultiVeloVAE model for joint multi-omics velocity inference.

This is the main class implementing the MultiVeloVAE model, which integrates chromatin accessibility and RNA sequencing data to infer cellular dynamics through a variational autoencoder framework with a mechanistic ODE model.

The model learns a low-dimensional latent representation of cell state, a latent time for each cell, and the parameters of the velocity ODEs that explain the observed data. A conditional VAE design handles batch effects, and the model accepts both multi-omic and RNA-only samples.

The model consists of:

1. An encoder that maps data to a distribution over latent variables

2. A decoder that generates predictions using a mechanistic ODE model

3. Training procedures including two-stage refinement
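To make the idea of a mechanistic ODE decoder concrete, here is a toy simulation of one gene. The equations are an assumption: a generic chromatin/unspliced/spliced form, not taken from this page, which does not state the ODEs that VAEChrom actually solves. The function name and every rate value are hypothetical.

```python
def simulate_gene(alpha_c=1.0, alpha=2.0, beta=1.0, gamma=0.5, t_end=50.0, dt=0.01):
    """Forward-Euler integration of an illustrative three-species ODE.

    c: chromatin accessibility, u: unspliced RNA, s: spliced RNA.
    ASSUMPTION: this functional form is a sketch, not the library's model.
    """
    c, u, s = 0.0, 0.0, 0.0              # start from a closed, silent gene
    n_steps = int(round(t_end / dt))
    for _ in range(n_steps):
        dc = alpha_c * (1.0 - c)         # chromatin opens toward a scaled maximum of 1
        du = alpha * c - beta * u        # transcription gated by accessibility
        ds = beta * u - gamma * s        # splicing minus degradation
        c, u, s = c + dt * dc, u + dt * du, s + dt * ds
    velocity = beta * u - gamma * s      # RNA velocity is ds/dt at the final state
    return c, u, s, velocity


c, u, s, velocity = simulate_gene()
```

Under these assumed equations the trajectory settles at c = 1, u = alpha/beta and s = alpha/gamma, where the velocity vanishes. In a model of this kind, the rate parameters and the per-cell time are what training would estimate.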

Attributes:

adata: AnnData object containing RNA data

adata_atac: AnnData object containing chromatin accessibility data

encoder: Neural network encoder

decoder: Neural network decoder with ODE model

config: Dictionary containing model configuration and hyperparameters

Other attributes for training and inference

Args:
adata (AnnData):

Input AnnData object for RNA data.

adata_atac (AnnData, optional):

Input AnnData object for chromatin data. Defaults to None for RNA-only.

dim_z (int, optional):

Latent cell state dimension. Defaults to None.

batch_key (str, optional):

Key of batch labels in adata.obs. Defaults to None.

ref_batch (int, optional):

Index to use as the reference batch. Defaults to None.

batch_hvg_key (str, optional):

Prefix of key for batch highly-variable genes in adata.var. Defaults to None.

var_to_regress (str or list, optional):

Continuous variable(s) to be regressed out. Defaults to None.

device (torch.device, optional):

Device used for training. Defaults to ‘cpu’.

hidden_size (int, optional):

The width of the hidden layers of the encoder and decoder. Defaults to 256.

full_vb (bool, optional):

Whether to use full variational Bayes for rate parameters. Defaults to False.

parallel_arch (bool, optional):

Whether to use parallel architecture for the indicator functions. Defaults to True.

t_network (bool, optional):

Whether to use a neural network to estimate the time distribution. Defaults to True.

four_basis (bool, optional):

Whether to enable BasisVAE to model genes as induction and repression. Defaults to False.

run_2nd_stage (bool, optional):

Whether to run the second stage of training. Defaults to True.

tmax (float, optional):

Maximum time, specifies the time range for initialization. Defaults to 1.

init_method (str, optional):

{‘steady’, ‘tprior’}, initialization method. Defaults to ‘steady’.

init_tprior (str, optional):

Key in adata.obs for initialization. Defaults to None.

tprior (str, optional):

Key in adata.obs containing the informative time prior. Defaults to None.

deming_std (bool, optional):

Whether to use Deming residual for the loss function std. Defaults to False.

rna_only (bool, optional):

Whether to run in RNA-only mode. Defaults to False.

rna_only_idx (list, optional):

List of indices of RNA-only samples. Defaults to [].

learning_rate (float, optional):

Learning rate for training. Defaults to None.

early_stop_thred (float, optional):

Early stopping threshold for training. Defaults to None.

checkpoints (list, optional):

List of two .pt files with pretrained parameters. Defaults to [None, None].

plot_init (bool, optional):

Whether to plot initialization results. Defaults to False.

gene_plot (list, optional):

List of gene names to plot. Defaults to [].

cluster_key (str, optional):

Key in adata.obs for plot colors. Defaults to ‘clusters’.

figure_path (str, optional):

Path to save figures. Defaults to ‘figures’.

embed (str, optional):

Key in adata.obsm of 2D embedding (tsne, umap, etc.). Defaults to None.

vram_constrained (bool, optional):

Whether to enable VRAM-constrained mode. Defaults to False.
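As a minimal sketch of how these arguments fit together, the snippet below collects a few of them in a dictionary and passes them to the constructor. The keyword names and defaults come from the signature above. The values marked hypothetical (latent dimension, obs column, embedding key) and the helper names merged_kwargs and build_vae are assumptions made for illustration. The package is imported lazily, so the snippet loads even where multivelovae is not installed.

```python
# Keyword arguments mirror the documented signature; values are illustrative.
vae_kwargs = {
    "dim_z": 5,                  # hypothetical latent dimension
    "batch_key": "batch",        # hypothetical column name in adata.obs
    "device": "cpu",             # documented default
    "hidden_size": 256,          # documented default
    "init_method": "steady",     # documented default
    "cluster_key": "clusters",   # documented default
    "figure_path": "figures",    # documented default
    "embed": "umap",             # hypothetical key for a 2D embedding in adata.obsm
}


def merged_kwargs(**overrides):
    """Return the illustrative defaults with caller overrides applied."""
    return {**vae_kwargs, **overrides}


def build_vae(adata, adata_atac=None, **overrides):
    """Construct a VAEChrom from RNA (and optionally chromatin) AnnData objects."""
    import multivelovae  # lazy import: only needed when a model is actually built

    return multivelovae.VAEChrom(adata, adata_atac=adata_atac, **merged_kwargs(**overrides))
```

A call such as build_vae(adata, adata_atac, device="cuda:0") would then override only the device and keep the remaining settings.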

__init__(adata, adata_atac=None, dim_z=None, batch_key=None, ref_batch=None, batch_hvg_key=None, var_to_regress=None, device='cpu', hidden_size=256, full_vb=False, parallel_arch=True, t_network=True, four_basis=False, run_2nd_stage=True, tmax=1, init_method='steady', init_tprior=None, tprior=None, deming_std=False, rna_only=False, rna_only_idx=[], learning_rate=None, early_stop_thred=None, checkpoints=[None, None], plot_init=False, gene_plot=[], cluster_key='clusters', figure_path='figures', embed=None, vram_constrained=False)

MultiVeloVAE Model

Args:
adata (anndata.AnnData):

Input AnnData object for RNA data.

adata_atac (anndata.AnnData, optional):

Input AnnData object for chromatin data. Defaults to None for RNA-only.

dim_z (int, optional):

Latent cell state dimension. Defaults to None.

batch_key (str, optional):

Key of batch labels in adata.obs. Defaults to None.

ref_batch (int, optional):

Index to use as the reference batch. Defaults to None.

batch_hvg_key (str, optional):

Prefix of key for batch highly-variable genes in adata.var. Defaults to None.

var_to_regress (str or list, optional):

Continuous variable(s) to be regressed out. Defaults to None.

device (torch.device, optional):

Device used for training. Defaults to ‘cpu’.

hidden_size (int, optional):

The width of the hidden layers of the encoder and decoder. Defaults to 256.

full_vb (bool, optional):

Whether to use the full variational Bayes feature to estimate rate parameter uncertainty. Defaults to False.

parallel_arch (bool, optional):

Whether to use parallel architecture for the indicator functions. Defaults to True.

t_network (bool, optional):

Whether to use a neural network to estimate the time distribution. Defaults to True.

four_basis (bool, optional):

Whether to enable BasisVAE to model genes as induction and repression. Defaults to False.

run_2nd_stage (bool, optional):

Whether to run the second stage of training. Defaults to True.

tmax (float, optional):

Maximum time, specifies the time range for initialization. Defaults to 1.

init_method (str, optional):

{‘steady’, ‘tprior’}, initialization method. Defaults to ‘steady’.

init_tprior (str, optional):

Key in adata.obs storing the capture time or any prior time information. This is used in initialization. Defaults to None.

tprior (str, optional):

Key in adata.obs containing the informative time prior. This is used in model training. Defaults to None.

deming_std (bool, optional):

Whether to use Deming residual for the standard deviation of the loss function. Defaults to False.

rna_only (bool, optional):

Whether to run in RNA-only mode. Defaults to False.

rna_only_idx (list, optional):

List of indices of RNA-only samples. Defaults to [].

learning_rate (float, optional):

Learning rate for training. Defaults to None.

early_stop_thred (float, optional):

Early stopping threshold for training. Defaults to None.

checkpoints (list, optional):

Contains a list of two .pt files containing pretrained or saved model parameters. Defaults to [None, None].

plot_init (bool, optional):

Whether to plot the initialization results. Defaults to False.

gene_plot (list, optional):

List of gene names to plot. Defaults to [].

cluster_key (str, optional):

Key in adata.obs containing the cluster labels for plot colors. Defaults to ‘clusters’.

figure_path (str, optional):

Path to save the figures. Defaults to ‘figures’.

embed (str, optional):

Key in adata.obsm of 2D embedding (tsne, umap, etc.). Defaults to None.

vram_constrained (bool, optional):

Whether to enable VRAM-constrained mode. Defaults to False.
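The time-related arguments interact: init_method selects between 'steady' and 'tprior', while init_tprior and tprior name columns in adata.obs. A small pre-flight check, sketched below, can catch mismatches before a model is built. The helper check_time_prior_args is hypothetical and not part of the package. The rule that init_method='tprior' requires init_tprior is an assumption inferred from the argument descriptions.

```python
def check_time_prior_args(obs_columns, init_method="steady", init_tprior=None, tprior=None):
    """Validate time-prior arguments against the available adata.obs columns.

    obs_columns: iterable of column names, e.g. list(adata.obs.columns).
    """
    allowed = {"steady", "tprior"}
    if init_method not in allowed:
        raise ValueError(f"init_method must be one of {sorted(allowed)}, got {init_method!r}")
    # ASSUMPTION: initializing from a time prior needs a column to read it from.
    if init_method == "tprior" and init_tprior is None:
        raise ValueError("init_method='tprior' needs init_tprior to name a column in adata.obs")
    columns = set(obs_columns)
    for name, key in (("init_tprior", init_tprior), ("tprior", tprior)):
        if key is not None and key not in columns:
            raise KeyError(f"{name}={key!r} is not a column in adata.obs")
    return True
```

With an AnnData object at hand, this would be called as check_time_prior_args(adata.obs.columns, init_method="tprior", init_tprior="capture_time"), where "capture_time" stands for whatever column holds the prior time.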

Methods

__init__(adata[, adata_atac, dim_z, ...])

MultiVeloVAE Model

encode_batch(adata, adata_atac, batch_key, ...)

eval(dataset, Xembed[, testid, test_mode, ...])

forward(data_in[, data_in_e, c0, u0, s0, ...])

get_prior(adata)

init_regressor(adata, var_to_regress)

plot_initial(gene_plot[, figure_path, embed])

pred_all(dataset[, mode, output, gene_idx, ...])

prepare_dataset([c, u, s, only_full])

Prepare dataset for training and testing

save_anndata([file_path, file_name])

Save variables to the AnnData object

save_model([file_path, enc_name, dec_name])

Save the trained models

save_state_dict([loss_test, reset])

set_device(device)

set_lr(adata, adata_atac, learning_rate)

set_mode(mode[, net])

split_train_test(N)

test(dataset[, batch, covar, k, sample, ...])

Predict latent variables and modality outputs for new data

train([config, plot, gene_plot, ...])

The high-level API for training

train_epoch(train_loader, test_set, optimizer)

update_config(config)

update_std_noise(dataset)

update_x0(dataset[, save])

vae_risk(q_tx, p_t, q_zx, p_z, q_ex, c, u, ...)
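The methods above suggest a typical sequence: train, save the networks, then write results back to AnnData. The sketch below strings these calls together using only the parameter names shown in the table (config, file_path, enc_name, dec_name, file_name). The helper train_and_save, the output directory and the file names are hypothetical. Valid config keys are not listed on this page, so an empty dictionary is passed unless the caller supplies one.

```python
from pathlib import Path


def train_and_save(vae, out_dir="output", config=None):
    """Train a VAEChrom-like object and persist its networks and results.

    `vae` is expected to expose train, save_model and save_anndata
    with the keyword names listed in the methods table.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)   # make sure the target folder exists

    vae.train(config=dict(config or {}))          # config keys are not documented here
    vae.save_model(
        file_path=str(out_dir),
        enc_name="encoder",                       # hypothetical file name
        dec_name="decoder",                       # hypothetical file name
    )
    vae.save_anndata(
        file_path=str(out_dir),
        file_name="adata_out.h5ad",               # hypothetical file name
    )
    return out_dir
```

Keeping the three calls in one helper makes the order explicit: parameters are saved only after training has finished, and the AnnData output lands in the same directory.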