
API Reference

BaseGeneModel

Bases: ABC, Module

Base class for modeling gene expression using distributions with gene-specific parameters that are shared across cells.

Parameters:

    n_genes (int, required): The number of genes to model.
Source code in iicd_workshop_2024/gene_model.py
class BaseGeneModel(abc.ABC, torch.nn.Module):
    """
    Base class for modeling gene expression using distributions with
    gene-specific parameters that are shared across cells.

    Args:
        n_genes (int): The number of genes to model.
    """

    def __init__(self, n_genes):
        super().__init__()
        self.n_genes = n_genes

    @property
    def distribution_name(self):
        """
        Get the name of the distribution used for modeling the gene expression.
        """
        return self.get_distribution().__class__.__name__.lower()

    @abc.abstractmethod
    def get_distribution(self, gene_idx=None) -> dist.Distribution:
        """
        Get the distribution that models the gene expression.

        Args:
            gene_idx (int or list[int] or None): If None, return the distribution over all genes. Otherwise, return the distribution
                of the specified gene or list of genes (given by their indices).

        Returns:
            dist.Distribution or list[dist.Distribution]: The distribution(s) of the gene(s).
        """
        pass

    def loss(self, data) -> torch.Tensor:
        """
        Return the negative log-likelihood of the data given the model.

        Args:
            data (torch.Tensor): Batch of expression values with shape (cells, genes).

        Returns:
            torch.Tensor: The negative log-likelihood of the data given the model.
        """
        return -self.get_distribution().log_prob(data).mean()

    def fit(self, adata, epochs=100, batch_size=128, lr=1e-2):
        """
        Fit the model to the data.

        Args:
            adata (AnnData): Annotated data matrix.
            epochs (int): Number of epochs to train the model.
            batch_size (int): Batch size.
            lr (float): Learning rate.
        """
        fit(self, adata, epochs=epochs, batch_size=batch_size, lr=lr)
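
Example

A minimal sketch of a concrete subclass (hypothetical, not part of the package; assumes BaseGeneModel is imported from iicd_workshop_2024.gene_model) that models each gene's counts with a Poisson distribution whose rate is a learned per-gene parameter:

import torch
import torch.distributions as dist

class PoissonGeneModel(BaseGeneModel):
    """Hypothetical example: one learned Poisson rate per gene."""

    def __init__(self, n_genes):
        super().__init__(n_genes)
        # Unconstrained parameter; softplus below keeps the rate positive.
        self.log_rate = torch.nn.Parameter(torch.zeros(n_genes))

    def get_distribution(self, gene_idx=None) -> dist.Distribution:
        rate = torch.nn.functional.softplus(self.log_rate)
        if gene_idx is not None:
            rate = rate[gene_idx]
        return dist.Poisson(rate)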

distribution_name property

Get the name of the distribution used for modeling the gene expression.

fit(adata, epochs=100, batch_size=128, lr=0.01)

Fit the model to the data.

Parameters:

    adata (AnnData, required): Annotated data matrix.
    epochs (int, default 100): Number of epochs to train the model.
    batch_size (int, default 128): Batch size.
    lr (float, default 0.01): Learning rate.
Source code in iicd_workshop_2024/gene_model.py
def fit(self, adata, epochs=100, batch_size=128, lr=1e-2):
    """
    Fit the model to the data.

    Args:
        adata (AnnData): Annotated data matrix.
        epochs (int): Number of epochs to train the model.
        batch_size (int): Batch size.
        lr (float): Learning rate.
    """
    fit(self, adata, epochs=epochs, batch_size=batch_size, lr=lr)

get_distribution(gene_idx=None) abstractmethod

Get the distribution that models the gene expression.

Parameters:

    gene_idx (int or list[int] or None, default None): If None, return the
        distribution over all genes. Otherwise, return the distribution of the
        specified gene or list of genes (given by their indices).

Returns:

    dist.Distribution or list[dist.Distribution]: The distribution(s) of the gene(s).

Source code in iicd_workshop_2024/gene_model.py
@abc.abstractmethod
def get_distribution(self, gene_idx=None) -> dist.Distribution:
    """
    Get the distribution that models the gene expression.

    Args:
        gene_idx (int or list[int] or None): If None, return the distribution over all genes. Otherwise, return the distribution
            of the specified gene or list of genes (given by their indices).

    Returns:
        dist.Distribution or list[dist.Distribution]: The distribution(s) of the gene(s).
    """
    pass
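
For example, with a concrete model (such as the hypothetical PoissonGeneModel sketched above), the distribution can be queried for all genes at once or per gene:

model = PoissonGeneModel(n_genes=2000)
all_genes = model.get_distribution()   # batched distribution over all 2000 genes
gene_5 = model.get_distribution(5)     # distribution of the gene at index 5
print(model.distribution_name)         # "poisson"
print(gene_5.mean, gene_5.stddev)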

loss(data)

Return the negative log-likelihood of the data given the model.

Parameters:

    data (torch.Tensor, required): Batch of expression values with shape (cells, genes).

Returns:

    torch.Tensor: The negative log-likelihood of the data given the model.

Source code in iicd_workshop_2024/gene_model.py
def loss(self, data) -> torch.Tensor:
    """
    Return the negative log-likelihood of the data given the model.

    Args:
        data (torch.Tensor): Batch of expression values with shape (cells, genes).

    Returns:
        torch.Tensor: The negative log-likelihood of the data given the model.
    """
    return -self.get_distribution().log_prob(data).mean()
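
As an illustration with toy data (again assuming the hypothetical PoissonGeneModel from above):

x = torch.poisson(torch.ones(64, model.n_genes) * 3.0)  # 64 cells of simulated counts
print(model.loss(x))  # scalar: mean NLL over all cells and genes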

plot_gene_distribution(model, adata, genes, n_cols=3)

Plot the learned distributions and the empirical distributions of the genes.

Parameters:

    model (BaseGeneModel, required): The gene model.
    adata (AnnData, required): The annotated data matrix.
    genes (list[str], required): The list of genes to plot.
    n_cols (int, default 3): The number of columns in the plot.
Source code in iicd_workshop_2024/gene_model.py
def plot_gene_distribution(model: BaseGeneModel, adata, genes, n_cols=3):
    """
    Plot the learned distributions and the empirical distributions of the genes.

    Args:
        model (BaseGeneModel): The gene model.
        adata (AnnData): The annotated data matrix.
        genes (list[str]): The list of genes to plot.
        n_cols (int): The number of columns in the plot.
    """
    n_rows = int(np.ceil(len(genes) / n_cols))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3), squeeze=False)
    for i, gene in enumerate(genes):
        ax = axs[i // n_cols, i % n_cols]
        gene_idx = adata.var["gene_symbols"].tolist().index(gene)
        sns.histplot(adata.X[:, gene_idx].toarray(), stat="density", discrete=True, ax=ax)
        max_value = adata.X[:, gene_idx].max().item()
        if model.distribution_name in ["poisson", "negativebinomial"]:
            x = torch.arange(0, max_value + 1)
        else:
            x = torch.linspace(
                min(
                    -5,
                    model.get_distribution(gene_idx).mean.item()
                    - 2 * model.get_distribution(gene_idx).stddev.item(),
                ),
                max_value,
                200,
            )
        y = model.get_distribution(gene_idx).log_prob(x).exp().detach().numpy()
        sns.lineplot(x=x, y=y, ax=ax, color="red")
        ax.set_title(gene + f" (idx={gene_idx})")
    plt.tight_layout()
    plt.show()
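
For example (the gene symbols here are illustrative; they must appear in adata.var["gene_symbols"]):

plot_gene_distribution(model, adata, genes=["CD3D", "MS4A1", "NKG7"], n_cols=3)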

fit(model, adata, epochs=100, batch_size=128, lr=0.01)

Fit the model to the data.

Parameters:

    model (nn.Module, required): The model to fit.
    adata (AnnData, required): The annotated data matrix.
    epochs (int, default 100): Number of epochs to train the model.
    batch_size (int, default 128): Batch size.
    lr (float, default 0.01): Learning rate.
Source code in iicd_workshop_2024/inference.py
def fit(model, adata, epochs=100, batch_size=128, lr=1e-2):
    """
    Fit the model to the data.

    Args:
        model (nn.Module): The model to fit.
        adata (AnnData): The annotated data matrix.
        epochs (int): Number of epochs to train the model.
        batch_size (int): Batch size.
        lr (float): Learning rate.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    data_X = adata.X
    # densify sparse matrices so the DataLoader can batch rows
    if scipy.sparse.issparse(data_X):
        data_X = data_X.toarray()
    data_loader = torch.utils.data.DataLoader(data_X, batch_size=batch_size, shuffle=True)
    pbar = tqdm.tqdm(total=epochs * len(data_loader))
    for _ in range(epochs):
        for x in data_loader:
            optimizer.zero_grad()
            loss = model.loss(x).mean()
            loss.backward()
            optimizer.step()
            pbar.set_postfix(loss=loss.item())
            pbar.update()
    pbar.close()
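
A sketch of end-to-end usage (the file name, gene symbols, and PoissonGeneModel are illustrative; assumes adata.X holds raw counts):

import scanpy as sc

adata = sc.read_h5ad("pbmc.h5ad")
model = PoissonGeneModel(n_genes=adata.n_vars)
fit(model, adata, epochs=50, batch_size=256, lr=1e-2)
plot_gene_distribution(model, adata, genes=["CD3D", "MS4A1", "NKG7"])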

DenseNN

Bases: Module

A simple feedforward neural network with ReLU activation function.

Parameters:

    n_input (int, required): The number of input features.
    n_output (int, required): The number of output features.
    n_hidden (int, default 128): The number of hidden units in each hidden layer.
    n_layers (int, default 1): The number of hidden layers.
Source code in iicd_workshop_2024/neural_network.py
class DenseNN(torch.nn.Module):
    """
    A simple feedforward neural network with ReLU activation function.

    Args:
        n_input (int): The number of input features.
        n_output (int): The number of output features.
        n_hidden (int): The number of hidden units in each hidden layer.
        n_layers (int): The number of hidden layers.
    """

    def __init__(self, n_input: int, n_output: int, n_hidden: int = 128, n_layers: int = 1):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        self.layers.append(torch.nn.Linear(n_input, n_hidden))
        for _ in range(n_layers - 1):
            self.layers.append(torch.nn.Linear(n_hidden, n_hidden))
        self.layers.append(torch.nn.Linear(n_hidden, n_output))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the neural network.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor.
        """
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x))
        return self.layers[-1](x)

forward(x)

Forward pass of the neural network.

Parameters:

    x (torch.Tensor, required): The input tensor.

Returns:

    torch.Tensor: The output tensor.

Source code in iicd_workshop_2024/neural_network.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass of the neural network.

    Args:
        x (torch.Tensor): The input tensor.

    Returns:
        torch.Tensor: The output tensor.
    """
    for layer in self.layers[:-1]:
        x = torch.relu(layer(x))
    return self.layers[-1](x)
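
A quick shape check (sizes here are illustrative):

net = DenseNN(n_input=2000, n_output=10, n_hidden=128, n_layers=2)
x = torch.randn(32, 2000)
print(net(x).shape)  # torch.Size([32, 10])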