import dinf
import tskit
import torch
import numpy as np
import allel
import pandas as pd
from dinf.misc import ts_individuals
[docs]
class BaseProcessor:
def __init__(self, config: dict, default: dict):
for key in config:
if key == "class_name": continue
assert key in default, f"Option {key} not available for processor"
for key, default in default.items():
setattr(self, key, config.get(key, default))
[docs]
class genotypes_and_distances(BaseProcessor):
"""
Genotype matrix and distance to next SNP
"""
default_config = {
"max_snps": 2000,
"min_freq": 0.0,
"max_freq": 1.0,
"polarised": True,
"phased": False,
"position_scaling": 1e3,
}
def __init__(self, config: dict):
super().__init__(config, self.default_config)
def __call__(self, ts: tskit.TreeSequence) -> np.ndarray:
geno = ts.genotype_matrix()
freq = geno.sum(axis=1) / geno.shape[1]
keep = np.logical_and(
freq >= self.min_freq,
freq <= self.max_freq,
)
if not self.polarised:
geno[freq > 0.5] = 1 - geno[freq > 0.5]
if not self.phased:
individual_map = np.zeros((ts.num_samples, ts.num_individuals))
for i, ind in enumerate(ts.individuals()):
individual_map[ind.nodes, i] = 1.0
geno = geno @ individual_map
pos = (
np.append(ts.sites_position[0], np.diff(ts.sites_position))
/ self.position_scaling
)
geno = np.concatenate([geno, pos.reshape(ts.num_sites, -1)], axis=-1)
# filter SNPs
geno = geno[keep]
geno = geno[: self.max_snps]
# the output tensor is (snps, individuals) with no padding,
# and will be ragged across simulations, with the SNP dimension not
# exceeding `max_snps`
return geno
# FIXME: bug if pops have different sizes?
[docs]
class tskit_sfs(BaseProcessor):
"""
Site frequency spectrum processor that handles both single and multiple populations.
For single population: returns normalized SFS
For multiple populations: returns normalized joint SFS
"""
default_config = {
"sample_sets": None,
"windows": None,
"mode": "site",
"span_normalise": False,
"polarised": False,
"normalised": True,
"log1p": False,
}
def __init__(self, config: dict):
super().__init__(config, self.default_config)
def __call__(self, ts: tskit.TreeSequence) -> torch.Tensor:
sampled_pops = [
i for i in range(ts.num_populations) if len(ts.samples(population=i)) > 0
]
sample_sets = [ts.samples(population=i) for i in sampled_pops]
sfs = ts.allele_frequency_spectrum(
sample_sets=sample_sets,
windows=self.windows,
mode=self.mode,
span_normalise=self.span_normalise,
polarised=self.polarised,
)
if self.normalised:
sfs /= np.sum(sfs)
if self.log1p:
sfs = np.log1p(sfs)
return sfs
[docs]
class SPIDNA_processor(BaseProcessor):
default_config = {
"maf": 0.05,
"relative_position": True,
"n_snps": 400,
"polarised": True,
"phased": True,
}
def __init__(self, config: dict):
super().__init__(config, self.default_config)
def __call__(self, ts: tskit.TreeSequence) -> np.ndarray:
# Extract genotype matrix and positions
snp = ts.genotype_matrix() # Shape: (n_variants, n_samples)
pos = ts.sites_position
# Handle relative positions
if self.relative_position:
abs_pos = np.array(pos)
pos = abs_pos - np.concatenate(([0], abs_pos[:-1]))
pos = pos / ts.sequence_length # Normalize positions to [0, 1] range
# MAF filtering
if self.maf != 0:
num_sample = ts.num_samples
row_sum = np.sum(
snp, axis=1
) # Sum along rows since matrix isn't transposed
keep = np.logical_and.reduce(
[
row_sum != 0,
row_sum != num_sample,
row_sum > num_sample * self.maf,
num_sample - row_sum > num_sample * self.maf,
]
)
snp = snp[keep]
pos = pos[keep]
if not self.polarised:
freq = snp.sum(axis=1) / snp.shape[1]
snp[freq > 0.5] = 1 - snp[freq > 0.5]
if not self.phased:
individual_map = np.zeros((ts.num_samples, ts.num_individuals))
for i, ind in enumerate(ts.individuals()):
individual_map[ind.nodes, i] = 1.0
snp = snp @ individual_map
# Handle case where we have fewer than n_snps SNPs
n_samples = snp.shape[1]
n_snps = snp.shape[0]
if n_snps < self.n_snps:
# Pad with -1 to reach n_snps SNPs (consistent with cnn_extract padding)
snp_padded = np.full((self.n_snps, n_samples), -1, dtype=snp.dtype)
snp_padded[:n_snps, :] = snp[:, :n_samples]
snp = snp_padded
# Pad positions with -1 to indicate padding
pos_padded = np.full(self.n_snps, -1, dtype=pos.dtype)
pos_padded[:n_snps] = pos
pos = pos_padded
else:
# We have enough SNPs, just take the first n_snps
snp = snp[:self.n_snps, :n_samples]
pos = pos[:self.n_snps]
# Create output tensor matching legacy format
# First create position channel (1, n_snps)
pos_channel = pos.reshape(1, -1)
# Stack channels
output_val = np.concatenate(
[
pos_channel, # Shape: (1, n_snps)
snp.T, # Now transpose only at the end to match expected output format
]
)
return output_val.astype(np.float32)
[docs]
class ReLERNN_processor(BaseProcessor):
default_config = {
"n_snps": 2000,
"min_freq": 0.0,
"max_freq": 1.0,
"phased": True,
"polarised": True,
}
def __init__(self, config: dict):
super().__init__(config, self.default_config)
if not self.phased:
raise ValueError("Unphased features not implemented")
def __call__(self, ts: tskit.TreeSequence) -> np.ndarray:
geno = ts.genotype_matrix()
freq = geno.sum(axis=1) / geno.shape[1]
keep = np.logical_and(
freq >= self.min_freq,
freq <= self.max_freq,
)
if not self.polarised:
geno[freq > 0.5] = 1 - geno[freq > 0.5]
geno = geno * 2 - 1 # recode ancestral to -1, derived to 1
pos = ts.sites_position / ts.sequence_length
geno = np.concatenate([pos.reshape(ts.num_sites, -1), geno], axis=-1)
# filter SNPs
geno = geno[keep]
geno = geno[:self.n_snps]
# Pad with zeros if the number of rows is less than max_snps
if geno.shape[0] < self.n_snps:
pad_rows = self.n_snps - geno.shape[0]
geno = np.pad(
geno, ((0, pad_rows), (0, 0)), mode="constant", constant_values=0
)
return geno
[docs]
class tskit_windowed_sfs_plus_ld(BaseProcessor):
"""
Summary statistics processor that returns a vector of the mean r2 across distances and the mean afs
where the mean is taken over windows.
Mean currently only for the single population case.
"""
default_config = {
"sample_sets": None,
"mode": "site",
"span_normalise": False,
"polarised": True,
"window_size": 1_000_000,
}
def __init__(self, config: dict):
super().__init__(config, self.default_config)
# TODO: this was copied from Yuxin's code, need to clean it up.
# not sure if we need the circular option at all?
def _LD(self, haplotype, pos_vec, size_chr, circular=True, distance_bins=None):
"""
Compute LD on a subsampled set of SNP pairs and return a DataFrame containing the mean r^2
per distance bin.
Gap sizes follow powers of 2. For each gap, SNP pairs are sampled with a random phase shift,
then LD is computed and squared. Results are binned by physical distance and averaged.
Parameters
----------
haplotype : array(n_SNP, n_samples)
pos_vec : array(n_SNP,)
Absolute genomic positions in [0, size_chr].
size_chr : int
Chromosome length.
circular : bool
Whether to consider the chromosome circular. If circular, the maximum distance between
two SNPs is half the chromosome. (Currently not used in distance computation.)
distance_bins : int or sequence of numbers, optional
If an int k is given, LD is averaged in k-1 logarithmic bins between 10^2 and size_chr,
with 0 inserted as the first edge. If a sequence is given, those values are used as
the bin edges directly. If None, 19 log-spaced bins between 10^2 and size_chr are used,
with 0 inserted as the first edge.
Returns
-------
pandas.DataFrame
Index: pandas.IntervalIndex of distance bins. Columns: 'mean_r2' containing the
mean of r^2 within each distance bin.
Notes
-----
- Subsampling is stochastic due to random shifts when forming SNP pairs.
- LD per pair is computed as allel.rogers_huff_r(...)**2.
"""
# Set up distance bins if not provided (kept here for potential grouping)
if distance_bins is None or isinstance(distance_bins, int):
if isinstance(distance_bins, int):
n_bins = distance_bins - 1
else:
n_bins = 19
distance_bins = np.logspace(2, np.log10(size_chr), n_bins)
distance_bins = np.insert(distance_bins, 0, [0])
n_SNP, n_samples = haplotype.shape
gaps = (2 ** np.arange(0, np.log2(n_SNP), 1)).astype(int)
selected_snps = []
for gap in gaps:
snps = np.arange(0, n_SNP, gap) + np.random.randint(0, (n_SNP - 1) % gap + 1)
snp_pairs = np.unique([((snps[i] + i) % n_SNP, (snps[i + 1] + i) % n_SNP) for i in range(len(snps) - 1)], axis=0)
snp_pairs = snp_pairs[snp_pairs[:, 0] < snp_pairs[:, 1]]
last_pair = snp_pairs[-1]
max_value = n_SNP - gap - 1
while len(snp_pairs) <= min(300, max_value):
random_shift = np.random.randint(1, n_SNP) % n_SNP
new_pair = (last_pair + random_shift) % n_SNP
snp_pairs = np.unique(
np.concatenate([snp_pairs, new_pair.reshape(1, 2)]), axis=0
)
last_pair = new_pair
snp_pairs = snp_pairs[snp_pairs[:, 0] < snp_pairs[:, 1]]
selected_snps.append(snp_pairs)
# Collect r2 values into a DataFrame.
agg_bins = {"snp_dist": ["mean"], "r2": ["mean"]}
ld = pd.DataFrame()
for i, snps_pos in enumerate(selected_snps):
sd = pd.DataFrame((np.diff(pos_vec[snps_pos])), columns=["snp_dist"])
sd["dist_group"] = pd.cut(sd.snp_dist, bins=distance_bins)
sr = [allel.rogers_huff_r(snps) ** 2 for snps in haplotype[snps_pos]]
sd["r2"] = sr
sd["gap_id"] = i
ld = pd.concat([ld, sd])
ld2 = ld.dropna().groupby("dist_group",observed=False).agg(agg_bins)
# Flatten the MultiIndex columns and rename explicitly
ld2.columns = ['_'.join(col).strip() for col in ld2.columns.values]
ld2 = ld2.rename(columns={
'snp_dist_mean': 'mean_dist',
'r2_mean': 'mean_r2'
})
return ld2[['mean_r2']]
def __call__(self, ts: tskit.TreeSequence) -> np.ndarray:
# Get number of populations with samples
sampled_pops = [
i for i in range(ts.num_populations) if len(ts.samples(population=i)) > 0
]
assert len(sampled_pops) == 1, "Only single population case currently supported"
sequence_length = ts.sequence_length
windows = np.arange(0, sequence_length, self.window_size)
# iterate over windows such that the last window is the remainder
# start at 0, end at sequence_length, step by 1_000_000
ld_stats = []
afs_stats = []
for i in range(len(windows)):
if i == len(windows) - 1:
window_end = sequence_length
else:
window_end = windows[i + 1]
window_start = windows[i]
ts_win = ts.keep_intervals([(window_start, window_end)]).trim()
ts_win_positions = ts_win.tables.sites.position
# calculate LD stats for each window
ld_stats.append(
self._LD(
ts_win.genotype_matrix(),
ts_win_positions,
window_end - window_start,
)
)
# get the AFS for each window
afs = ts_win.allele_frequency_spectrum(
mode="site",
polarised=self.polarised,
span_normalise=self.span_normalise,
)
# normalize the AFS
# afs = afs / np.sum(afs)
afs_stats.append(afs)
# calculate the mean r2 for each window at each distance
mean_r2_values = (
pd.concat(ld_stats)["mean_r2"]
.groupby(level=0, observed=False)
.mean() # Automatically skips NaNs by default
.fillna(0) # Replace any remaining NaNs with 0
)
# calculate the mean afs for each window
mean_afs_values = np.stack(afs_stats).mean(axis=0)[1:-1]
sum_stats = np.concatenate((mean_r2_values, mean_afs_values))
return sum_stats