
new (?) statistic to efficiently calculate shared times between every pair of samples at every tree #1619

@mmosmond


Hi all, I've been experimenting with different ways to calculate the shared time (the time from the root down to the most recent common ancestor) between every pair of samples in every tree of a tree sequence. This is a useful quantity because, under Brownian motion, the covariance of the samples' characteristics (e.g., traits, locations) is proportional to it. At first I couldn't figure out how to do this with an existing statistic, so I used the general_stat function to write my own:
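The Brownian-motion motivation can be checked numerically on a small example. The sketch below is plain NumPy (no tskit) and uses a hypothetical hand-built four-sample tree. It computes shared times as root time minus MRCA time, simulates Brownian motion down the branches, and compares the empirical covariance of the sample traits with the shared-time matrix. The tree, the rate `sigma2` and the number of replicates are illustrative assumptions.

```python
import numpy as np

# Hypothetical toy tree (not from the issue): samples are nodes 0-3 at time 0;
# node 4 joins samples 0 and 1 at time 1, node 5 joins samples 2 and 3 at
# time 2, and node 6 is the root at time 3. parent[u] == -1 marks the root.
parent = [4, 4, 5, 5, 6, 6, -1]
time = np.array([0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0])
samples = [0, 1, 2, 3]


def path_to_root(u):
    """Nodes from u up to the root, inclusive."""
    path = [u]
    while parent[path[-1]] != -1:
        path.append(parent[path[-1]])
    return path


def shared_time_matrix():
    """Shared time of samples a and b: root time minus the time of their MRCA."""
    root_time = time[path_to_root(samples[0])[-1]]
    k = len(samples)
    st = np.zeros((k, k))
    for a in range(k):
        path_a = path_to_root(samples[a])
        for b in range(k):
            path_b = set(path_to_root(samples[b]))
            mrca = next(u for u in path_a if u in path_b)  # lowest shared node
            st[a, b] = root_time - time[mrca]
    return st


def simulate_traits(num_reps, sigma2=1.0, seed=1):
    """Brownian motion started at 0 at the root; one trait value per sample."""
    rng = np.random.default_rng(seed)
    # branch length above each node (0 for the root, which has no branch)
    lengths = np.array([time[p] - time[u] if p != -1 else 0.0
                        for u, p in enumerate(parent)])
    # independent Gaussian increment on every branch, variance sigma2 * length
    increments = rng.normal(0.0, np.sqrt(sigma2 * lengths),
                            size=(num_reps, len(parent)))
    # a sample's trait is the sum of the increments along its path to the root
    return np.column_stack([increments[:, path_to_root(s)].sum(axis=1)
                            for s in samples])


st = shared_time_matrix()
cov = np.cov(simulate_traits(200_000), rowvar=False)
print(st)
print(np.round(cov, 2))  # should be close to st, since sigma2 = 1
```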

import numpy as np

def shared_times(ts):
    """Use the general_stat function to calculate shared branch lengths between all samples in all trees.

    Parameters
    ts: tskit tree sequence

    Returns
    A (n x k x k) numpy array, with n the number of trees in the sequence and k the number of sample nodes.
    In each tree, the (k x k) matrix gives the shared evolutionary time of each pair of sample nodes.
    """

    k = ts.num_samples  # number of samples
    W = np.identity(k)  # sample i (i in 0, 1, ..., k-1) gets a row of weights with 1 in column i and 0 elsewhere

    def f(x):
        # x[i] is 1 if sample i is below the node; the flattened outer product
        # marks which pairs of samples share the branch above that node
        return (x.reshape(-1, 1) * x).flatten()

    # polarised=True: only the samples below each branch are counted
    # strict=False: f(total weight) is not zero, so the strictness check must be off
    return ts.general_stat(
        W, f, k**2, mode='branch', windows='trees', polarised=True, strict=False
    ).reshape(ts.num_trees, k, k)
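As I understand general_stat's branch mode (polarised, span-normalised, one window per tree), each output is, per tree, a sum over branches of branch length times f evaluated at the total weight below the branch. With identity weights that weight is the 0/1 indicator of which samples are below the node, so the outer product marks the pairs sharing that branch and the sum is the shared-time matrix. The sketch below reproduces that sum by hand in plain NumPy on a hypothetical toy tree (same tree as a root-minus-MRCA calculation would give 3 on the diagonal); it illustrates the idea and is not tskit's implementation.

```python
import numpy as np

# Hypothetical toy tree: samples 0-3 at time 0, node 4 = MRCA of 0 and 1
# (time 1), node 5 = MRCA of 2 and 3 (time 2), node 6 = root (time 3).
parent = [4, 4, 5, 5, 6, 6, -1]
time = np.array([0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0])
k = 4               # samples are nodes 0..k-1
W = np.identity(k)  # sample i carries weight 1 in column i, as in shared_times


def f(x):
    # entry (i, j) is 1 if samples i and j are both below the node
    return np.outer(x, x).flatten()


def branch_stat(W, f):
    """Sum over branches of (branch length) * f(total weight below the branch)."""
    num_nodes = len(parent)
    # x[u] = summed weights of the samples below node u; in this toy tree
    # children have smaller IDs than parents, so one forward pass suffices
    x = np.zeros((num_nodes, W.shape[1]))
    x[:k] = W
    for u in range(num_nodes):
        if parent[u] != -1:
            x[parent[u]] += x[u]
    total = np.zeros(W.shape[1] ** 2)
    for u in range(num_nodes):
        if parent[u] != -1:  # the root has no branch above it
            total += (time[parent[u]] - time[u]) * f(x[u])
    return total.reshape(k, k)


print(branch_stat(W, f))
```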

For example:

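import msprime
# 5 diploid individuals (msprime's default ploidy is 2), i.e. 10 sample nodes
ts = msprime.sim_ancestry(samples=5, population_size=1e4, sequence_length=1e4, recombination_rate=1e-8, random_seed=1)
sts = shared_times(ts)  # shape (ts.num_trees, 10, 10)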

The same result can also be obtained by looping over trees and pairs of samples and subtracting each pairwise TMRCA from the root time:

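def shared_times(ts):
    k = ts.num_samples
    samples = ts.samples()  # sample node IDs, in the same order as the general_stat weights
    sts = np.zeros((ts.num_trees, k, k))
    for t, tree in enumerate(ts.trees()):
        T = tree.time(tree.root)  # assumes each tree has a single root
        for i in range(k):
            # diagonal: time from the root down to sample i itself
            sts[t, i, i] = T - tree.time(samples[i])
            for j in range(i):
                st = T - tree.tmrca(samples[i], samples[j])
                sts[t, i, j] = st
                sts[t, j, i] = st
    return sts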

but this is considerably slower when there are many trees (although it seems to be faster when there are few trees with many samples, e.g., ts = msprime.sim_ancestry(samples=200, population_size=1e4, random_seed=1), which has no recombination and therefore a single tree; I'm not sure why).

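Coming back to this much later (today), I realized that the divergence statistic can do the same with a few tweaks. In branch mode, the divergence between two single samples is the total branch length on the path between them, which is twice their TMRCA when both samples are at time 0, so subtracting it from twice the root time and halving gives the shared time: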

def shared_times(ts):

    k = ts.num_samples

    # get 2*tmrcas
    sample_sets = [[i] for i in ts.samples()]
    # indexes refer to positions in sample_sets (0..k-1), not node IDs;
    # compare each sample with each other only once (upper triangle, incl. diagonal)
    indexes = [(i, j) for i in range(k) for j in range(i, k)]
    divs = ts.divergence(sample_sets=sample_sets, indexes=indexes, mode='branch', windows='trees')
    divs[np.isnan(divs)] = 0  # divergence of a single sample with itself is nan, so set it to 0

    # convert to matrices
    divs_mat = np.zeros((ts.num_trees, k, k))
    for ix, (i, j) in enumerate(indexes):
        divs_mat[:, i, j] = divs[:, ix]  # fill the upper triangle
        divs_mat[:, j, i] = divs[:, ix]  # fill the lower triangle symmetrically

    # convert to shared times
    sts = np.zeros(divs_mat.shape)
    for i, div in enumerate(divs_mat):
        sts[i] = (np.max(div) - div) / 2  # max is 2*(root time), so this is root time - tmrca

    return sts
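The two Python loops at the end of that function (filling the matrices and converting to shared times) can be vectorized with NumPy. This sketch assumes the flattened divergence array is laid out as above (upper triangle including the diagonal, row-major), which is the same order np.triu_indices returns, and that all samples are at time 0, so the largest entry in each tree is twice the root time. The helper name divergences_to_shared_times is hypothetical.

```python
import numpy as np


def divergences_to_shared_times(divs, k):
    """Convert per-tree branch divergences to shared-time matrices.

    divs: array of shape (num_trees, k * (k + 1) // 2), one column per pair
    (i, j) with i <= j in row-major order, i.e. the order of
    [(i, j) for i in range(k) for j in range(i, k)]; self-pairs may be NaN.
    """
    divs = np.nan_to_num(divs, nan=0.0)  # copy; self-divergences NaN -> 0
    rows, cols = np.triu_indices(k)      # same row-major upper-triangle order
    d = np.zeros((divs.shape[0], k, k))
    d[:, rows, cols] = divs
    d[:, cols, rows] = divs              # mirror into the lower triangle
    # d_ij = 2 * tmrca_ij and the per-tree maximum is 2 * root time,
    # so (max - d_ij) / 2 = root time - tmrca_ij, the shared time
    two_T = d.max(axis=(1, 2), keepdims=True)
    return (two_T - d) / 2
```

Inside shared_times this would replace everything after the ts.divergence call, returning divergences_to_shared_times(divs, k).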

This latter method seems to be the fastest for small tree sequences (e.g., ts = msprime.sim_ancestry(samples=5, population_size=1e4, sequence_length=1e6, recombination_rate=1e-8, random_seed=1)) but loses this advantage as tree sequences get larger (e.g., msprime.sim_ancestry(samples=100, population_size=1e4, sequence_length=1e6, recombination_rate=1e-8, random_seed=1)). It also uses more than twice the RAM of the general_stat method, which quickly becomes important for larger tree sequences.
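To see where the extra RAM could come from, here is a rough count of the float64 result arrays each method holds when it returns. The helper output_bytes is hypothetical, and it ignores tskit's internal working memory and any temporaries, so measured peaks will differ.

```python
def output_bytes(num_trees, k, itemsize=8):
    """Rough bytes held by the float64 result arrays of the two methods."""
    full = num_trees * k * k * itemsize             # one (num_trees, k, k) array
    flat = num_trees * k * (k + 1) // 2 * itemsize  # upper triangle incl. diagonal
    general_stat_method = full                      # reshape returns a view, no copy
    divergence_method = flat + 2 * full             # divs + divs_mat + sts
    return general_stat_method, divergence_method


g, d = output_bytes(num_trees=1000, k=100)
print(g, d, d / g)
```

For k = 100 samples and 1,000 trees this gives 80 MB versus about 200 MB. In general the ratio is 2 + (k + 1)/(2k), which tends to about 2.5 as k grows.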

I'm wondering whether my general_stat matrix formulation would be useful as a statistic of its own: it is related to the divergence statistic but calculated differently (I think), sometimes faster, sometimes less memory-intensive, and semantically simpler (it returns the matrices directly). If so, I guess I'd also have to think about what happens as you vary the general_stat options, e.g., mode='node'. Anyway, I just wanted to put this out there in case it is helpful. Thanks for building all this stuff!

(This is my first time writing an issue so sorry if I've messed up somehow -- I'd be keen to hear how to do this better!)
