import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
# TODO:
# - make this "exchangeable" by shuffling all columns but the last (requires handling packed_sequence separately)
class RNN(nn.Module):
    """
    Bidirectional GRU followed by a two-layer MLP head.

    The final hidden state of every GRU layer and direction is concatenated
    into a single vector per batch element, which the MLP maps to
    ``output_size`` features.
    """

    def __init__(
        self, input_size, output_size, num_layers=2, dropout=0.0, hidden_size=84
    ):
        """
        :param input_size: the input size of the GRU layer, e.g. num_individuals*ploidy
            or num_individuals depending if the data is phased or not
        :param output_size: the output size of the network
        :param num_layers: the number of stacked GRU layers
        :param dropout: dropout probability applied between the two linear
            layers of the MLP head (it is not passed to the GRU itself)
        :param hidden_size: the number of hidden units per GRU direction; the
            default of 84 reproduces the previously hard-coded width
        """
        super().__init__()
        self.rnn = nn.GRU(
            input_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Each layer contributes one final hidden state per direction, hence
        # 2 * hidden_size * num_layers input features for the head.
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden_size * num_layers, 256),
            nn.Dropout(dropout),
            nn.Linear(256, output_size),
        )

    def forward(self, x):
        """
        :param x: batched GRU input; the GRU is built with ``batch_first=True``,
            so a dense input is (batch, sequence_length, input_size)
        :return: tensor of shape (batch, output_size)
        """
        _, hn = self.rnn(x)  # (2 * layers, batch, hidden_size)
        batch_size = hn.shape[1]
        # Move batch to the front, then flatten layers/directions/features.
        hn = hn.permute(1, 0, 2).reshape(batch_size, -1)
        return self.mlp(hn)
class FrozenLayerNorm(nn.Module):
    """
    Parameter-free layer normalisation for inputs of arbitrary size.

    The first dimension is treated as the batch dimension; the mean and
    variance are computed per batch element over all remaining dimensions.
    No affine weight or bias is applied.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Normalise over everything except the leading (batch) dimension.
        normalized_shape = x.shape[1:]
        return F.layer_norm(x, normalized_shape)
# TODO: SBI has a built-in exchangeable layer, why not use this?
class SymmetricLayer(nn.Module):
    """
    Reduce the input along one axis with a permutation-invariant function.

    The reduction is one of ``"max"``, ``"mean"`` or ``"sum"``; the reduced
    axis is kept with size 1.
    """

    def __init__(self, axis, func="max"):
        """
        :param axis: the dimension to reduce over
        :param func: name of the reduction: "max", "mean" or "sum"
        """
        super().__init__()
        self.func = func
        self.axis = axis

    def forward(self, x):
        """
        :param x: input tensor
        :return: ``x`` reduced along ``self.axis`` (kept as a size-1 dimension)
        :raises ValueError: if ``self.func`` is not a supported reduction name
        """
        reduction = self.func
        if reduction == "sum":
            return x.sum(dim=self.axis, keepdim=True)
        if reduction == "mean":
            return x.mean(dim=self.axis, keepdim=True)
        if reduction == "max":
            # Reducing along a dim with torch.max yields (values, indices);
            # only the values are wanted here.
            values, _ = x.max(dim=self.axis, keepdim=True)
            return values
        raise ValueError("func must be one of 'max', 'mean', or 'sum'")
# TODO cleanup:
# - BUG: this isn't working if sample sizes aren't equal across pops
# - should work with a single pop given the ts_processor (make as general as possible)
# - let the number of channels/kernel sizes/etc be settable
# - remove the need to specify the unpadded input shapes, this can be figured out in forward
# - the logic in forward requires a batch dimension
# - could use the built-in symmetric layer from SBI
class ExchangeableCNN(nn.Module):
    """
    This implements the Exchangeable CNN or permutation-invariant CNN from:
        Chan et al. 2018, https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7687905/

    which builds in the invariance of the haplotype matrices to permutations
    of the individuals. Main difference is that the first cnn has wider kernel
    and stride to capture the long range LD.

    If input features come from multiple populations that may differ in
    num_snps and/or num_individuals, then provide the per-population row and
    column counts via ``input_rows`` / ``input_cols``. The forward pass will
    then drop every entry equal to -1 (the padding value used to bring each
    haplotype matrix to the shape of the largest in the set) and reshape what
    remains to the unpadded shape.

    It has two cnn layers, followed by a symmetric layer that pools over the
    individual axis and a feature extractor (fully connected network). Each
    CNN layer has a 2D convolution with kernel and stride height = 1, an ELU
    activation, and a parameter-free layer normalisation.

    If the number of populations is greater than one, each population is run
    through the CNN and the symmetric layer separately, and the pooled outputs
    are concatenated along the last axis (same as pg-gan by Mathieson et al.).
    A global average pool then reduces the result to
    (batch_size, outchannels2, 1, 1) before it is passed to the feature
    extractor.
    """

    def __init__(
        self,
        output_dim=64,
        input_rows=None,
        input_cols=None,
        channels=2,
        symmetric_func="max",
    ):
        """
        :param output_dim: The desired dimension of the final 1D output vector
            to be used as the embedded data for training
        :param input_rows: The number of rows (samples) per population genotype matrix
        :param input_cols: The number of cols (SNPs) per population genotype matrix
        :param channels: The number of channels in the input matrices. HaplotypeMatrices
            have 2 channels and BinnedHaplotypeMatrices have 1 channel
        :param symmetric_func: String denoting which symmetric function to use in our
            permutation invariant layers
        :raises ValueError: if ``input_rows`` and ``input_cols`` are both given
            but differ in length
        """
        super().__init__()
        self.outchannels1 = 32
        self.outchannels2 = 160
        self.kernel_size1 = (1, 5)
        self.kernel_size2 = (1, 5)
        self.stride1 = (1, 2)
        self.stride2 = (1, 2)
        self.activation = nn.ELU

        # Unpadded (channels, rows, cols) per population; stays None unless
        # both row and column counts were supplied.
        self.unmasked_x_shps = None
        if input_rows is not None and input_cols is not None:
            # Raise explicitly: an assert would vanish under ``python -O`` and
            # zip() would then silently truncate to the shorter sequence.
            if len(input_rows) != len(input_cols):
                raise ValueError(
                    "input_rows and input_cols must have the same length, got "
                    f"{len(input_rows)} and {len(input_cols)}"
                )
            self.unmasked_x_shps = [
                (channels, r, c) for r, c in zip(input_rows, input_cols)
            ]

        self.cnn = nn.Sequential(
            nn.Conv2d(
                channels, self.outchannels1, self.kernel_size1, stride=self.stride1
            ),
            self.activation(),
            FrozenLayerNorm(),
            nn.Conv2d(
                self.outchannels1,
                self.outchannels2,
                self.kernel_size2,
                stride=self.stride2,
            ),
            self.activation(),
            FrozenLayerNorm(),
        )
        # Pool over axis 2 (the row / individual axis of the CNN output).
        self.symmetric = SymmetricLayer(axis=2, func=symmetric_func)
        self.globalpool = nn.AdaptiveAvgPool2d((1, 1))
        self.feature_extractor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.outchannels2, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim),
        )

    def forward(self, x):
        """
        :param x: either a 5D tensor indexed as (batch, population, channels,
            rows, cols) padded with -1, used when unpadded shapes were given at
            construction, or a tensor fed to the CNN directly
        :return: tensor of shape (batch, output_dim)
        """
        # If unmasked_x_shps is not None this means we have multiple
        # populations and thus could have padded values of -1. Pull out all
        # values of each population's feature matrix EXCEPT those equal to -1.
        if self.unmasked_x_shps is not None and len(x.shape) == 5:
            batch_size = x.shape[0]
            pooled = []
            for i, shape in enumerate(self.unmasked_x_shps):
                pop = x[:, i, :, :, :]
                # Boolean-mask indexing keeps the surviving entries in
                # row-major order; the view fails if their count does not
                # match batch_size * prod(shape), e.g. when real data
                # contains -1.
                unpadded = pop[pop != -1].view(batch_size, *shape)
                pooled.append(self.symmetric(self.cnn(unpadded)))
            x = torch.cat(pooled, dim=-1)
            x = self.globalpool(x)
            return self.feature_extractor(x)
        # Otherwise we know there are no padded values and can just run the
        # input data through the network.
        return self.feature_extractor(self.globalpool(self.symmetric(self.cnn(x))))
class SummaryStatisticsEmbedding(nn.Module):
    """
    Embed summary statistics of a tree sequence.

    This is simply an identity layer that takes in a batch of summary
    statistics (e.g., SFS) and outputs the same values, flattened to one
    vector per batch element.

    The first dimension of the input is treated as the batch dimension:

    - single population SFS: (batch, num_samples + 1)
    - joint SFS: (batch, num_samples_pop1 + 1, num_samples_pop2 + 1)
    """

    def __init__(self, output_dim=None):
        """
        :param output_dim: accepted for signature parity with the other
            embedding networks; it is not used, since the output width is
            determined by the input
        """
        super().__init__()
        self.identity = nn.Identity()

    def forward(self, x):
        """
        :param x: a ``torch.Tensor``, or array-like data (numpy array, nested
            list, ...) which is converted to a float32 tensor
        :return: tensor of shape (batch, -1) holding the flattened statistics
        """
        if not isinstance(x, torch.Tensor):
            # as_tensor matches from_numpy(...).float() for numpy arrays and
            # additionally accepts plain Python sequences.
            x = torch.as_tensor(x, dtype=torch.float32)
        return self.identity(x.reshape(x.shape[0], -1))

    def embedding(self, x):
        """
        Consistent with other embedding networks, provide an embedding method
        that returns the same output as forward() (computed without gradient
        tracking) since this is an identity layer
        """
        with torch.no_grad():
            return self.forward(x)
class SPIDNA_embedding_network(nn.Module):
    """
    SPIDNA architecture for processing genetic data.

    Parameters
    ----------
    output_dim : int
        Dimension of the output feature vector
    num_block : int
        Number of SPIDNA blocks in the network
    num_feature : int
        Number of features in the convolutional layers
    """

    def __init__(self, output_dim=64, num_block=3, num_feature=64):
        super().__init__()
        # Each block slices the first `output_dim` of its `num_feature`
        # channels, so fewer features than outputs cannot work.
        if num_feature < output_dim:
            raise ValueError(
                f"num_feature ({num_feature}) must be >= output_dim ({output_dim})"
            )
        self.output_dim = output_dim
        self.num_feature = num_feature

        # Unpadded (1, 3) convolutions, one for the position row and one for
        # the SNP rows. Normalisation is done functionally in forward().
        self.conv_pos = nn.Conv2d(1, num_feature, (1, 3))
        self.conv_snp = nn.Conv2d(1, num_feature, (1, 3))

        # Stack of SPIDNA blocks applied sequentially.
        self.blocks = nn.ModuleList()
        for _ in range(num_block):
            self.blocks.append(SPIDNABlock(num_feature, output_dim))

    def forward(self, x):
        """
        Run the network and return the accumulated block outputs.

        Row 0 along dimension 1 is read as positions and the remaining rows
        as SNP data.

        :param x: input tensor -- NOTE(review): the indexing below appears to
            expect (batch, 1 + samples, snps) at this point; confirm against
            the caller, including what the 5D case is meant to look like
        :return: tensor of shape (batch, output_dim)
        """
        # Drop dimension 1 of a 5D input when it has size 1.
        if len(x.shape) == 5:
            x = x.squeeze(1)
        n_batch = x.shape[0]

        # Split into the position row and the SNP rows, adding a channel axis.
        pos_in = x[:, 0, :].view(n_batch, 1, 1, -1)  # (batch, 1, 1, snps)
        snp_in = x[:, 1:, :].unsqueeze(1)  # (batch, 1, samples, snps)

        # Position branch: conv -> layer norm -> ReLU, then broadcast over the
        # sample axis so it lines up with the SNP branch.
        pos_feat = self.conv_pos(pos_in)
        pos_feat = F.relu(F.layer_norm(pos_feat, pos_feat.shape[1:]))
        pos_feat = pos_feat.expand(-1, -1, snp_in.size(2), -1)

        # SNP branch: conv -> layer norm -> ReLU.
        snp_feat = self.conv_snp(snp_in)
        snp_feat = F.relu(F.layer_norm(snp_feat, snp_feat.shape[1:]))

        # Stack both branches along the channel axis.
        x = torch.cat((pos_feat, snp_feat), dim=1)

        # Running sum of the per-block outputs, starting from zero.
        output = torch.zeros(x.size(0), self.output_dim, device=x.device)
        for block in self.blocks:
            x, output = block(x, output)
        return output

    def embedding(self, x):
        """
        Return forward(x) without gradient tracking; non-tensor input is
        first converted from a numpy array to a float tensor.
        """
        if not isinstance(x, torch.Tensor):
            x = torch.from_numpy(x).float()
        with torch.no_grad():
            return self.forward(x)
class SPIDNABlock(nn.Module):
    """
    SPIDNA architecture for processing genetic data, basic unit.

    One block normalises its input, convolves it, adds a summary of the
    result to a running output vector, and returns a pooled feature map for
    the next block.
    """

    def __init__(self, num_feature, output_dim):
        """
        :param num_feature: number of channels produced by the convolution;
            the block expects ``2 * num_feature`` input channels
        :param output_dim: width of the running output vector
        """
        super().__init__()
        self.num_feature = num_feature
        self.output_dim = output_dim
        # Unpadded (1, 3) convolution; normalisation is applied functionally
        # in forward() because its shape depends on the input.
        self.phi = nn.Conv2d(num_feature * 2, num_feature, (1, 3))
        self.maxpool = nn.MaxPool2d((1, 2))
        self.fc = nn.Linear(output_dim, output_dim)

    def forward(self, x, output):
        """
        :param x: feature map with ``2 * num_feature`` channels
        :param output: running output of shape (batch, output_dim)
        :return: tuple of (next feature map, updated running output)
        """
        # Layer norm over channel/height/width first, then the convolution.
        feats = self.phi(F.layer_norm(x, x.shape[1:]))

        # Mean over dimension 2 (the samples axis), kept for broadcasting.
        sample_mean = feats.mean(dim=2, keepdim=True)

        # Summarise the first `output_dim` channels by also averaging over the
        # last axis, and add the projected summary to the running output.
        summary = sample_mean[:, : self.output_dim, :, :].mean(dim=3).squeeze(2)
        output = output + self.fc(summary)

        # Broadcast the sample mean back and append it as extra channels.
        tiled = sample_mean.expand(-1, -1, feats.size(2), -1)
        merged = torch.cat((feats, tiled), dim=1)

        # Max-pool along the last axis, then ReLU.
        return F.relu(self.maxpool(merged)), output
class ReLERNN(nn.Module):
    """
    This module constructs a bi-directional GRU based RNN following the architecture
    from https://github.com/kr-colab/ReLERNN/blob/master/ReLERNN/networks.py#L7.

    It processes haplotype data along with corresponding positional information to produce
    a feature embedding. Its output is consistent with the other embedding networks in this module.

    Parameters
    ----------
    input_size : int
        The input size for the GRU layer (typically num_individuals * ploidy).
    n_snps : int
        The number of SNPs (sequence positions) in the input data; this fixes
        the input width of the positional linear layer.
    output_size : int, optional
        The dimension of the final embedded feature vector (default: 64).
    shuffle_genotypes : bool, optional
        If True, randomly permute the haplotype features on every forward
        pass while the module is in training mode (default: False).

    Input
    -----
    x : torch.Tensor, shape (batch, n_snps, 1 + input_size)
        The first feature along the last dimension is assumed to be positional data while
        the remaining features are the haplotype representation.

    Output
    ------
    torch.Tensor, shape (batch, output_size)
        The embedded feature vector.
    """

    def __init__(self, input_size, n_snps, output_size=64, shuffle_genotypes=False):
        """
        :param input_size: the input size of the GRU layer, e.g. num_individuals*ploidy
            or num_individuals depending if the data is phased or not
        :param n_snps: the number of SNPs in the input data
        :param output_size: the dimension of the final embedded feature vector
        :param shuffle_genotypes: whether to shuffle the genotypes (default: False; training only)
        """
        super().__init__()
        self.shuffle_genotypes = shuffle_genotypes
        self.rnn = nn.GRU(
            input_size, 84, num_layers=1, batch_first=True, bidirectional=True
        )
        # 168 = 84 hidden units x 2 directions.
        self.fc1 = nn.Sequential(nn.Linear(168, 256), nn.Dropout(0.35))
        self.fc_pos = nn.Sequential(nn.Linear(n_snps, 256))
        # 512 = 256 haplotype features + 256 positional features.
        self.feature_ext = nn.Sequential(
            nn.Linear(512, 64), nn.Dropout(0.35), nn.Linear(64, output_size)
        )

    def forward(self, x):
        """
        :param x: tensor of shape (batch, n_snps, 1 + input_size); feature 0
            of the last dimension is the position, the rest is haplotype data
        :return: tensor of shape (batch, output_size)
        """
        pos = x[..., 0]  # (batch, n_snps)
        haps = x[..., 1:]  # (batch, n_snps, input_size)

        # Optionally permute the haplotype features; training mode only.
        if self.shuffle_genotypes and self.training:
            # Build the permutation on the data's device so indexing does not
            # need a host-to-device copy of the index tensor.
            perm = torch.randperm(haps.shape[-1], device=haps.device)
            haps = haps[..., perm]

        # Process haplotype data via GRU and keep the final hidden states.
        _, hn = self.rnn(haps)  # (num_layers * num_directions, batch, 84)
        hn = hn.permute(1, 0, 2).reshape(x.shape[0], -1)  # (batch, 168)
        hapout = self.fc1(hn)  # (batch, 256)
        posout = self.fc_pos(pos)  # (batch, 256)

        # Concatenate processed haplotype and position features.
        catout = torch.cat([hapout, posout], dim=-1)  # (batch, 512)

        # Final embedding extraction.
        return self.feature_ext(catout)  # (batch, output_size)