multivelovae.VAEChrom
- class multivelovae.VAEChrom(adata, adata_atac=None, dim_z=None, batch_key=None, ref_batch=None, batch_hvg_key=None, var_to_regress=None, device='cpu', hidden_size=256, full_vb=False, parallel_arch=True, t_network=True, four_basis=False, run_2nd_stage=True, tmax=1, init_method='steady', init_tprior=None, tprior=None, deming_std=False, rna_only=False, rna_only_idx=[], learning_rate=None, early_stop_thred=None, checkpoints=[None, None], plot_init=False, gene_plot=[], cluster_key='clusters', figure_path='figures', embed=None, vram_constrained=False)
MultiVeloVAE model for joint multi-omics velocity inference.
This is the main class implementing the MultiVeloVAE model, which integrates chromatin accessibility and RNA sequencing data to infer cellular dynamics through a variational autoencoder framework with a mechanistic ODE model.
The model learns a low-dimensional latent representation of cell state, a latent time for each cell, and the parameters of velocity ODE equations that explain the observed data. Batch effects are handled through a conditional VAE design, and both multi-omic and RNA-only samples are supported.
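This page does not spell out the "velocity ODE equations". For illustration only, the sketch below assumes the chromatin (c) → unspliced (u) → spliced (s) cascade from the original MultiVelo formulation, with constant rates, chromatin held in an opening state (k_c = 1), and simple forward-Euler integration. The function and rate names are hypothetical and are not part of the multivelovae API.

```python
import numpy as np

def simulate_cus(alpha_c, alpha, beta, gamma, t_max=10.0, n_steps=10000, k_c=1.0):
    """Forward-Euler integration of an assumed chromatin -> unspliced -> spliced cascade.

    Assumed (illustrative) equations with constant rates:
        dc/dt = alpha_c * (k_c - c)
        du/dt = alpha * c - beta * u
        ds/dt = beta * u - gamma * s
    Returns an array of shape (n_steps + 1, 3) with columns (c, u, s), starting at zero.
    """
    dt = t_max / n_steps
    c = u = s = 0.0
    traj = np.zeros((n_steps + 1, 3))
    for i in range(1, n_steps + 1):
        dc = alpha_c * (k_c - c)
        du = alpha * c - beta * u
        ds = beta * u - gamma * s
        # Update all three states simultaneously from the previous step's values.
        c, u, s = c + dt * dc, u + dt * du, s + dt * ds
        traj[i] = (c, u, s)
    return traj

traj = simulate_cus(alpha_c=2.0, alpha=3.0, beta=1.0, gamma=0.5, t_max=40.0, n_steps=40000)
```

Given enough time, the trajectory approaches the steady state c = k_c, u = αk_c/β, s = αk_c/γ. A practical decoder might instead evaluate a closed-form solution at each cell's latent time rather than integrating numerically.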
The model consists of:
1. An encoder that maps the data to a distribution over latent variables.
2. A decoder that generates predictions using a mechanistic ODE model.
3. Training procedures, including two-stage refinement.
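The encoder side of a conditional VAE can be sketched generically. The example below assumes batch conditioning by concatenating a one-hot batch vector to the encoder input, a toy linear encoder producing a mean and log-variance, and the standard reparameterization trick for sampling. All names are hypothetical. NumPy does not differentiate, so a real model would rely on an autograd framework such as PyTorch for the gradients that the trick makes possible.

```python
import numpy as np

def one_hot(labels, n_batches):
    """Encode integer batch labels as one-hot rows."""
    labels = np.asarray(labels)
    out = np.zeros((labels.size, n_batches))
    out[np.arange(labels.size), labels] = 1.0
    return out

def encode(x, batch_onehot, w_mu, w_logvar):
    """Toy linear encoder conditioned on batch via input concatenation."""
    h = np.concatenate([x, batch_onehot], axis=1)
    return h @ w_mu, h @ w_logvar

def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps

# Toy usage: 5 cells, 4 input features, 2 batches, 3 latent dimensions.
rng = np.random.default_rng(0)
x = rng.random((5, 4))
b = one_hot([0, 0, 1, 1, 0], n_batches=2)
w_mu = rng.standard_normal((6, 3))
w_logvar = rng.standard_normal((6, 3)) * 0.1
mu, logvar = encode(x, b, w_mu, w_logvar)
z = reparameterize(mu, logvar, rng)
```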
- Attributes:
- adata: AnnData object containing RNA data.
- adata_atac: AnnData object containing chromatin accessibility data.
- encoder: Neural network encoder.
- decoder: Neural network decoder with the ODE model.
- config: Dictionary containing the model configuration and hyperparameters.
- Other attributes used for training and inference.
- Args:
- adata (AnnData):
Input AnnData object for RNA data.
- adata_atac (AnnData, optional):
Input AnnData object for chromatin data. Defaults to None for RNA-only.
- dim_z (int, optional):
Latent cell state dimension. Defaults to None.
- batch_key (str, optional):
Key of batch labels in adata.obs. Defaults to None.
- ref_batch (int, optional):
Index to use as the reference batch. Defaults to None.
- batch_hvg_key (str, optional):
Prefix of key for batch highly-variable genes in adata.var. Defaults to None.
- var_to_regress (str or list, optional):
Continuous variable(s) to be regressed out. Defaults to None.
- device (torch.device, optional):
Device used for training. Defaults to ‘cpu’.
- hidden_size (int, optional):
The width of the hidden layers of the encoder and decoder. Defaults to 256.
- full_vb (bool, optional):
Whether to use full variational Bayes for rate parameters. Defaults to False.
- parallel_arch (bool, optional):
Whether to use parallel architecture for the indicator functions. Defaults to True.
- t_network (bool, optional):
Whether to use a neural network to estimate the time distribution. Defaults to True.
- four_basis (bool, optional):
Whether to enable BasisVAE to model genes as induction and repression. Defaults to False.
- run_2nd_stage (bool, optional):
Whether to run the second stage of training. Defaults to True.
- tmax (float, optional):
Maximum time, which specifies the time range for initialization. Defaults to 1.
- init_method (str, optional):
{‘steady’, ‘tprior’}, initialization method. Defaults to ‘steady’.
- init_tprior (str, optional):
Key in adata.obs storing the capture time or other prior time information, used for initialization. Defaults to None.
- tprior (str, optional):
Key in adata.obs containing the informative time prior, used during model training. Defaults to None.
- deming_std (bool, optional):
Whether to use Deming residual for the loss function std. Defaults to False.
- rna_only (bool, optional):
Whether to run in RNA-only mode. Defaults to False.
- rna_only_idx (list, optional):
List of indices of RNA-only samples. Defaults to [].
- learning_rate (float, optional):
Learning rate for training. Defaults to None.
- early_stop_thred (float, optional):
Early stopping threshold for training. Defaults to None.
- checkpoints (list, optional):
List of two .pt files with pretrained parameters. Defaults to [None, None].
- plot_init (bool, optional):
Whether to plot initialization results. Defaults to False.
- gene_plot (list, optional):
List of gene names to plot. Defaults to [].
- cluster_key (str, optional):
Key in adata.obs for plot colors. Defaults to ‘clusters’.
- figure_path (str, optional):
Path to save figures. Defaults to ‘figures’.
- embed (str, optional):
Key in adata.obsm of 2D embedding (tsne, umap, etc.). Defaults to None.
- vram_constrained (bool, optional):
Whether to enable VRAM-constrained mode. Defaults to False.
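The rna_only and rna_only_idx options imply that chromatin terms must be skipped for cells without ATAC data. One way to do this is to mask the chromatin reconstruction loss per cell, as sketched below. Here rna_only_idx is treated as a list of batch (sample) indices, and each cell's batch label decides whether it has chromatin data. That interpretation, the Gaussian likelihood, and the helper names are assumptions, not the package's implementation.

```python
import numpy as np

def gaussian_nll(x, x_hat, sigma):
    """Element-wise Gaussian negative log-likelihood."""
    return 0.5 * ((x - x_hat) / sigma) ** 2 + np.log(sigma) + 0.5 * np.log(2 * np.pi)

def masked_recon_loss(c, c_hat, u, u_hat, s, s_hat, batch_labels, rna_only_idx,
                      sigma_c=1.0, sigma_u=1.0, sigma_s=1.0):
    """Mean per-cell reconstruction loss (hypothetical helper).

    The chromatin term is dropped for cells whose batch label is listed in
    rna_only_idx (assumed here to hold batch indices of RNA-only samples).
    """
    has_atac = ~np.isin(batch_labels, rna_only_idx)
    loss_rna = (gaussian_nll(u, u_hat, sigma_u).sum(axis=1)
                + gaussian_nll(s, s_hat, sigma_s).sum(axis=1))
    # Boolean mask zeroes the chromatin loss of RNA-only cells.
    loss_atac = gaussian_nll(c, c_hat, sigma_c).sum(axis=1) * has_atac
    return float(np.mean(loss_rna + loss_atac))
```

Masking inside the loss keeps all cells in one mini-batch, so RNA-only and multi-omic cells can share the encoder and the RNA part of the decoder.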
- __init__(adata, adata_atac=None, dim_z=None, batch_key=None, ref_batch=None, batch_hvg_key=None, var_to_regress=None, device='cpu', hidden_size=256, full_vb=False, parallel_arch=True, t_network=True, four_basis=False, run_2nd_stage=True, tmax=1, init_method='steady', init_tprior=None, tprior=None, deming_std=False, rna_only=False, rna_only_idx=[], learning_rate=None, early_stop_thred=None, checkpoints=[None, None], plot_init=False, gene_plot=[], cluster_key='clusters', figure_path='figures', embed=None, vram_constrained=False)
MultiVeloVAE Model
- Args:
- adata (anndata.AnnData):
Input AnnData object for RNA data.
- adata_atac (anndata.AnnData, optional):
Input AnnData object for chromatin data. Defaults to None for RNA-only.
- dim_z (int, optional):
Latent cell state dimension. Defaults to None.
- batch_key (str, optional):
Key of batch labels in adata.obs. Defaults to None.
- ref_batch (int, optional):
Index to use as the reference batch. Defaults to None.
- batch_hvg_key (str, optional):
Prefix of key for batch highly-variable genes in adata.var. Defaults to None.
- var_to_regress (str or list, optional):
Continuous variable(s) to be regressed out. Defaults to None.
- device (torch.device, optional):
Device used for training. Defaults to ‘cpu’.
- hidden_size (int, optional):
The width of the hidden layers of the encoder and decoder. Defaults to 256.
- full_vb (bool, optional):
Whether to use the full variational Bayes feature to estimate rate parameter uncertainty. Defaults to False.
- parallel_arch (bool, optional):
Whether to use parallel architecture for the indicator functions. Defaults to True.
- t_network (bool, optional):
Whether to use a neural network to estimate the time distribution. Defaults to True.
- four_basis (bool, optional):
Whether to enable BasisVAE to model genes as induction and repression. Defaults to False.
- run_2nd_stage (bool, optional):
Whether to run the second stage of training. Defaults to True.
- tmax (float, optional):
Maximum time, which specifies the time range for initialization. Defaults to 1.
- init_method (str, optional):
{‘steady’, ‘tprior’}, initialization method. Defaults to ‘steady’.
- init_tprior (str, optional):
Key in adata.obs storing the capture time or any prior time information. This is used in initialization. Defaults to None.
- tprior (str, optional):
Key in adata.obs containing the informative time prior. This is used in model training. Defaults to None.
- deming_std (bool, optional):
Whether to use Deming residual for the standard deviation of the loss function. Defaults to False.
- rna_only (bool, optional):
Whether to run in RNA-only mode. Defaults to False.
- rna_only_idx (list, optional):
List of indices of RNA-only samples. Defaults to [].
- learning_rate (float, optional):
Learning rate for training. Defaults to None.
- early_stop_thred (float, optional):
Early stopping threshold for training. Defaults to None.
- checkpoints (list, optional):
A list of two .pt files containing pretrained or saved model parameters. Defaults to [None, None].
- plot_init (bool, optional):
Whether to plot the initialization results. Defaults to False.
- gene_plot (list, optional):
List of gene names to plot. Defaults to [].
- cluster_key (str, optional):
Key in adata.obs containing the cluster labels for plot colors. Defaults to ‘clusters’.
- figure_path (str, optional):
Path to save the figures. Defaults to ‘figures’.
- embed (str, optional):
Key in adata.obsm of 2D embedding (tsne, umap, etc.). Defaults to None.
- vram_constrained (bool, optional):
Whether to enable VRAM-constrained mode. Defaults to False.
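With init_method='tprior', the capture times stored under init_tprior must be placed on the model's time axis, whose range is set by tmax. A simple min-max rescaling onto [0, tmax] is one plausible approach. The sketch below is a hypothetical helper, not the package's method, and it assumes the prior times are already numeric (labels such as "day3" would need to be mapped to numbers first).

```python
import numpy as np

def rescale_time_prior(t_prior, tmax=1.0):
    """Min-max rescale numeric prior times (e.g. capture days) onto [0, tmax]."""
    t = np.asarray(t_prior, dtype=float)
    span = t.max() - t.min()
    if span == 0:
        # All cells share one capture time: no ordering information to keep.
        return np.zeros_like(t)
    return (t - t.min()) / span * tmax

t_init = rescale_time_prior([0, 2, 4, 4, 2], tmax=20.0)
```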
Methods
- __init__(adata[, adata_atac, dim_z, ...]): MultiVeloVAE Model.
- encode_batch(adata, adata_atac, batch_key, ...)
- eval(dataset, Xembed[, testid, test_mode, ...])
- forward(data_in[, data_in_e, c0, u0, s0, ...])
- get_prior(adata)
- init_regressor(adata, var_to_regress)
- plot_initial(gene_plot[, figure_path, embed])
- pred_all(dataset[, mode, output, gene_idx, ...])
- prepare_dataset([c, u, s, only_full]): Prepare the dataset for training and testing.
- save_anndata([file_path, file_name]): Save variables to the AnnData object.
- save_model([file_path, enc_name, dec_name]): Save the trained models.
- save_state_dict([loss_test, reset])
- set_device(device)
- set_lr(adata, adata_atac, learning_rate)
- set_mode(mode[, net])
- split_train_test(N)
- test(dataset[, batch, covar, k, sample, ...]): Predict latent variables and modality outputs for new data.
- train([config, plot, gene_plot, ...]): The high-level API for training.
- train_epoch(train_loader, test_set, optimizer)
- update_config(config)
- update_std_noise(dataset)
- update_x0(dataset[, save])
- vae_risk(q_tx, p_t, q_zx, p_z, q_ex, c, u, ...)
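The argument names of vae_risk suggest posterior and prior distributions for latent time (q_tx, p_t) and cell state (q_zx, p_z), which points to the usual VAE objective: a reconstruction term plus KL divergences. The sketch below is a generic negative ELBO under the assumption that every distribution is a diagonal Gaussian given as a (mean, std) pair. It omits q_ex and the other inputs, whose roles are not documented here, and the function names are hypothetical.

```python
import numpy as np

def kl_diag_gaussian(mu_q, std_q, mu_p, std_p):
    """KL(N(mu_q, std_q^2) || N(mu_p, std_p^2)), summed over the last axis."""
    kl = (np.log(std_p / std_q)
          + (std_q ** 2 + (mu_q - mu_p) ** 2) / (2 * std_p ** 2)
          - 0.5)
    return kl.sum(axis=-1)

def negative_elbo(recon_nll, q_t, p_t, q_z, p_z, kl_weight=1.0):
    """Per-cell reconstruction NLL plus weighted KL terms for time t and state z.

    recon_nll has shape (n_cells,); q_* and p_* are (mean, std) tuples whose
    arrays have shape (n_cells, dim).
    """
    kl_t = kl_diag_gaussian(*q_t, *p_t)
    kl_z = kl_diag_gaussian(*q_z, *p_z)
    return float(np.mean(recon_nll + kl_weight * (kl_t + kl_z)))
```

The kl_weight factor is included because annealing the KL term is a common VAE training choice; whether multivelovae does this is not stated on this page.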