Updated 'statistics_tests.py' to conform to mutations API

Harrison Deng 2023-03-23 16:24:25 -05:00
parent f26cc5618c
commit 0da81f26dd
10 changed files with 1308 additions and 0 deletions

.gitignore (vendored): 5 additions

@@ -238,3 +238,8 @@ $RECYCLE.BIN/
# Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option)
data
*-data
lineage_report.csv
cache
*.gb

.vscode/launch.json (vendored, new file): 53 additions

@@ -0,0 +1,53 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": true
},
{
"name": "VGAAT: run all",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/vgaat/vgaat.py",
"args": [
"${workspaceFolder}/test_data/",
"${workspaceFolder}/MN908947_3.gb",
"unknown",
"${workspaceFolder}/nml_lineage_report.csv",
"--log", "DEBUG",
"--alpha", "0.05",
"--output", "${workspaceFolder}/results.md",
"--threads", "1",
],
"console": "integratedTerminal",
"justMyCode": true,
},
{
"name": "VGAAT: run all - NO CACHE",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/vgaat/vgaat.py",
"args": [
"${workspaceFolder}/test_data/",
"${workspaceFolder}/MN908947_3.gb",
"unknown",
"${workspaceFolder}/nml_lineage_report.csv",
"--log", "DEBUG",
"--alpha", "0.05",
"--output", "${workspaceFolder}/results.md",
"--threads", "1",
"--clear-cache", "True"
],
"console": "integratedTerminal",
"justMyCode": true,
}
]
}

.vscode/settings.json (vendored, new file): 23 additions

@@ -0,0 +1,23 @@
{
"cSpell.words": [
"Biopython",
"funcs",
"genbank",
"heatmaps",
"matplotlib",
"NCBI",
"orjson",
"percache",
"pyplot",
"scipy",
"utrs",
"vcfpy",
"vcfs",
"VGAAT"
],
"python.formatting.provider": "black",
"python.linting.prospectorEnabled": false,
"python.linting.enabled": true,
"python.linting.flake8Enabled": true,
"python.analysis.typeCheckingMode": "basic"
}

tox.ini (new file): 3 additions

@@ -0,0 +1,3 @@
[flake8]
max-line-length = 88
extend-ignore = E203
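# Note: 88 matches black's default line length, and E203 (whitespace
# before ':') is ignored because black's slice formatting conflicts with it.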

vgaat/filters.py (new file): 12 additions

@@ -0,0 +1,12 @@
def filter_by_alpha(result_groups: dict[tuple, tuple], alpha=0.05):
filtered_result_groups = {}
for identifier, result_group in result_groups.items():
test_alias, viral_strand, regions, synonymity, fishers = identifier
filtered_result_data = []
title, desc, result_data = result_group
for info in result_data:
if (not isinstance(info["p"], str)) and info["p"] < float(alpha):
filtered_result_data.append(dict(info))
if len(filtered_result_data) > 0:
filtered_result_groups[identifier] = (title, desc, filtered_result_data)
return filtered_result_groups
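
A minimal usage sketch for filter_by_alpha. The identifier tuple and result fields below are hypothetical, mirroring the shapes produced by utils.Tester and utils.row_pairwise_dependency_test:

from filters import filter_by_alpha

# Hypothetical result group keyed by (test alias, lineage, region, synonymity, fishers).
result_groups = {
    ("site count vs sample type", "all", "all", True, True): (
        "Is mutation site count dependent on sample type?",
        "Example description.",
        [
            {"type": "exact", "p": 0.004, "table": None},
            {"type": "chi2", "p": 0.3, "table": None},
            {"type": "chi2", "p": "Error", "table": None},
        ],
    )
}

filtered = filter_by_alpha(result_groups, alpha=0.05)
# Only the p=0.004 entry survives: string p-values ("Error") and
# p >= alpha are both dropped.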

vgaat/mutations.py (new file): 553 additions

@@ -0,0 +1,553 @@
from collections import defaultdict
from typing import Callable, Union
from Bio.Seq import MutableSeq, translate
from Bio import GenBank
from frozendict import frozendict
from utils import accession_to_genbank_filename
import logging
from functools import cache
class Mutation:
def __init__(
self,
vcf_record,
sample_id,
sample_type,
sample_day,
viral_lineage,
alt,
sample_call_name="unknown",
):
self._sample_viral_lineage = viral_lineage
self._patient_id = sample_id
self._sample_type = sample_type
self._sample_day = sample_day
self._start = vcf_record.begin
self._end = vcf_record.end or self.start
self._reference = vcf_record.REF
self._call = vcf_record.call_for_sample[sample_call_name]
self._genotype = self.call.gt_bases[0] # Haploid
self._all_alts = vcf_record.ALT
self._alt = alt.value
self._ref_obs = self.call.data["RO"]
self._depth = self.call.data["DP"]
if len(self.genotype) > len(self.ref):
self._variant_type = "insertion"
elif len(self.genotype) < len(self.ref):
self._variant_type = "deletion"
elif len(self.genotype) == len(self.ref):
if len(self.genotype) == 1:
self._variant_type = "SNV"
else:
self._variant_type = "MNV"
self._hash = hash(
(
self._patient_id,
self._sample_type,
self._sample_day,
self._start,
self._end,
self._reference,
self._alt,
)
)
@property
def patient_id(self):
return self._patient_id
@property
def sample_type(self):
return self._sample_type
@property
def sample_day(self):
return self._sample_day
@property
def sample_viral_lineage(self):
return self._sample_viral_lineage
@property
def start(self):
return self._start
@property
def end(self):
return self._end
@property
def ref(self):
return self._reference
@property
def call(self):
return self._call
@property
def genotype(self):
return self._genotype
@property
def alternatives(self):
return self._all_alts
@property
def ref_obs(self):
return self._ref_obs
@property
def depth(self):
return self._depth
@property
def variant_type(self):
return self._variant_type
def __len__(self):
return self.depth
def __hash__(self):
return self._hash
    def __eq__(self, other):
        return (
            self.patient_id == other.patient_id
            and self.sample_type == other.sample_type
            and self.sample_day == other.sample_day
            and self.start == other.start
            and self.end == other.end
            and self.ref == other.ref
            and self.call == other.call
            and self.genotype == other.genotype
            and self.alternatives == other.alternatives
            and self.ref_obs == other.ref_obs
            and self.depth == other.depth
            and self.variant_type == other.variant_type
            and self._alt == other._alt
        )
@staticmethod
def vcf_reads_to_mutations(
vcf_reads: dict, accession, sample_lineage, sample_call_name="unknown"
):
for sample_info, vcf_records in vcf_reads.items():
sample_id, sample_type, sample_day = sample_info
for vcf_record in vcf_records:
if vcf_record.CHROM != accession:
continue
for alternate in vcf_record.ALT:
yield Mutation(
vcf_record,
sample_id,
sample_type,
sample_day,
sample_lineage[sample_info]
if sample_info in sample_lineage
else None,
alternate,
sample_call_name,
)
class CategorizedMutations:
def __init__(self, genbank_record, mutations: list[Mutation]):
self._genbank_record = genbank_record
self._primitive_category_groups = {}
self._primitive_categories = {}
self._custom_category_groups = {}
self._custom_categories = {}
genes_and_utrs_locs = {}
logging.debug("Categorized mutations being constructed.")
logging.debug("Cataloguing features...")
for genbank_feature in self.get_genbank_record().features:
f_key = genbank_feature.key
if f_key == "gene" or f_key.endswith("UTR"):
f_qual = f_key
f_start, f_end = genbank_feature.location.split("..")
f_start = int(f_start)
f_end = int(f_end)
for qualifier in genbank_feature.qualifiers:
if qualifier.key == "/gene=":
f_qual = qualifier.value
break
if (f_start, f_end) in genes_and_utrs_locs:
raise Exception("Feature with duplicate location detected.")
genes_and_utrs_locs[(f_start, f_end)] = f_qual
logging.debug("Done")
self._accession = genbank_record.accession
self._primitive_category_groups["locations"] = defaultdict(set)
self._primitive_category_groups["allele change"] = defaultdict(set)
self._primitive_category_groups["sample type"] = defaultdict(set)
self._primitive_category_groups["sample day"] = defaultdict(set)
self._primitive_category_groups["variant type"] = defaultdict(set)
self._primitive_category_groups["patient"] = defaultdict(set)
self._primitive_category_groups["viral lineage"] = defaultdict(set)
self._primitive_category_groups["regions"] = defaultdict(set)
self._primitive_category_groups["patient earliest"] = defaultdict(set)
self._primitive_categories["non-synonymous"] = set()
self._primitive_categories["all"] = set(mutations)
logging.debug("Organizing mutations...")
num_of_mutations = len(mutations)
done = 0
for mutation in mutations:
self._primitive_category_groups["locations"][
(mutation.start, mutation.end)
].add(mutation)
self._primitive_category_groups["allele change"][
(mutation.ref, mutation.genotype)
].add(mutation)
self._primitive_category_groups["sample type"][mutation.sample_type].add(
mutation
)
self._primitive_category_groups["variant type"][mutation.variant_type].add(
mutation
)
self._primitive_category_groups["sample day"][mutation.sample_day].add(
mutation
)
if mutation.sample_viral_lineage:
self._primitive_category_groups["viral lineage"][
mutation.sample_viral_lineage
].add(mutation)
            # Invariant: for each patient, the "patient earliest" set only
            # holds mutations sharing the earliest sample day seen so far.
            # Base case: an empty set trivially satisfies this. When a
            # mutation from an earlier day arrives, every member of the set
            # is from a later day (by the invariant), so the set is cleared
            # and restarted with the new mutation. A mutation from the same
            # day is added; one from a later day is ignored. The final set
            # therefore holds all mutations from the patient's earliest
            # sampled day.
            if self._primitive_category_groups["patient earliest"][mutation.patient_id]:
                earliest_day = next(
                    iter(
                        self._primitive_category_groups["patient earliest"][
                            mutation.patient_id
                        ]
                    )
                ).sample_day
            else:
                earliest_day = None
            if earliest_day is None or mutation.sample_day < earliest_day:
                self._primitive_category_groups["patient earliest"][
                    mutation.patient_id
                ].clear()
                self._primitive_category_groups["patient earliest"][
                    mutation.patient_id
                ].add(mutation)
            elif mutation.sample_day == earliest_day:
                self._primitive_category_groups["patient earliest"][
                    mutation.patient_id
                ].add(mutation)
self._primitive_category_groups["patient"][mutation.patient_id].add(
mutation
)
mutable_sequence = MutableSeq(self.ref_sequence)
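            # Overwrite the reference span with the genotype, position by
            # position. This assumes the genotype covers every position in
            # start..end, which holds for SNVs/MNVs but not necessarily
            # for indels.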
for relative_mutation, mutation_position in enumerate(
range(mutation.start, mutation.end + 1)
):
mutable_sequence[mutation_position] = mutation.genotype[
relative_mutation
]
if translate(mutable_sequence) != translate(self.ref_sequence):
self._primitive_categories["non-synonymous"].add(mutation)
for location in genes_and_utrs_locs.keys():
start, end = location
if start <= mutation.start and end >= mutation.end:
self._primitive_category_groups["regions"][
(genes_and_utrs_locs[location], location)
].add(mutation)
done += 1
logging.info(
f"Organized {done}/{num_of_mutations} \
({done/num_of_mutations:.2%}) mutations"
)
self._primitive_categories["non-synonymous"] = set()
self._primitive_categories["all"] = set(mutations)
self._primitive_category_groups = recursive_freeze_dict(
self._primitive_category_groups
)
def get_filename(self):
"""
Returns the accession filename.
"""
return accession_to_genbank_filename(self._accession)
def get_accession(self):
return self._accession
def get_genbank_record(self):
return self._genbank_record
@property
def ref_sequence(self):
return self._genbank_record.sequence
    def pget_category_or_category_group(self, property_name: str):
        if property_name in self._primitive_category_groups:
            return self._primitive_category_groups[property_name]
        return self._primitive_categories[property_name]
def pget_category(self, property_name: str):
return self._primitive_categories[property_name]
def pget_category_group(self, property_name: str):
return self._primitive_category_groups[property_name]
@property
def all(self):
return self._primitive_category_groups["all"]
@property
def total_mutations(self):
        return len(self.all)
@cache
def mutations_per_site_for_pcategory_group(self, category_group: str):
categorized_mutations = self.pget_category_group(category_group)
        if isinstance(categorized_mutations, dict):
            resulting_dict = {}
            for key, value in categorized_mutations.items():
                resulting_dict[key] = self.count_mutations_per_site(tuple(value))
            return resulting_dict
        else:
            return self.count_mutations_per_site(tuple(categorized_mutations))
@cache
def flattened_pcategory_group(self, category: str) -> set:
categorized_mutations = self.pget_category_group(category)
flattened = set()
for key, set_value in categorized_mutations.items():
flattened.update(set_value)
return flattened
@cache
def mutation_sites_for_pccategory_group(self, category: str):
categorized_mutations = self.pcget_or_create_category_groups(category)
resulting_dict = {}
for key, value in categorized_mutations.items():
            resulting_dict[key] = CategorizedMutations.count_mutation_sites(
                tuple(value)
            )
return resulting_dict
@cache
def mutations_for_pccategory_group(
self, category: str, assembler: Union[Callable, None] = None, **kwargs: dict
):
        dictionary_set = self.pcget_or_create_category_groups(
            category, assembler, **kwargs
        )
return CategorizedMutations.count_mutations_for_category_group(dictionary_set)
@cache
def punion_category_groups_with_category(self, category_groups: str, category: str):
category_set = set(self.pget_category(category))
each = self.pget_category_group(category_groups)
        result = {}
        for subcat, subcat_mutations in each.items():
            result[subcat] = subcat_mutations.union(category_set)
return result
@cache
def pintersect_category_groups_with_category(
self, category_groups: str, category: str
):
category_set = set(self.pget_category(category))
each = self.pget_category_group(category_groups)
result = {}
for subcat, subcat_mutations in each.items():
result[subcat] = subcat_mutations.intersection(category_set)
return result
def pcget_or_create_category_groups(
self, category_group: str, assembler: Union[Callable, None] = None, **kwargs
):
if category_group in self._primitive_category_groups:
return self.pget_category_group(category_group)
if category_group not in self._custom_category_groups:
if not assembler:
raise Exception(
"Assembler was not provided for a non-existent custom category "
f'groups "{category_group}".'
)
result = assembler(**kwargs)
if not isinstance(result, dict):
raise Exception("Assembler result was not a group of categories...")
self._custom_category_groups[category_group] = result
return self._custom_category_groups[category_group]
@cache
def pcunion_category_groups_with_categories(
self,
union_name: str,
category_group: str,
category: tuple[Mutation],
assembler: Union[Callable, None] = None,
        assembler_args: Union[frozendict, None] = None,
):
category_set = set(category)
if union_name in self._custom_category_groups:
return self._custom_category_groups[union_name]
        subcategories = self.pcget_or_create_category_groups(
            category_group, assembler, **(assembler_args or {})
        )
result = {}
for subcat, mutations in subcategories.items():
result[subcat] = mutations.union(category_set)
self._custom_category_groups[union_name] = result
return result
@cache
def pcintersect_category_groups_with_categories(
self,
intersect_name: str,
category_group: str,
category: tuple[Mutation],
assembler: Union[Callable, None] = None,
        assembler_args: Union[frozendict, None] = None,
):
category_set = set(category)
if intersect_name in self._custom_category_groups:
return self._custom_category_groups[intersect_name]
        subcategories = self.pcget_or_create_category_groups(
            category_group, assembler, **(assembler_args or {})
        )
result = {}
for subcat, mutations in subcategories.items():
result[subcat] = mutations.intersection(category_set)
self._custom_category_groups[intersect_name] = result
return result
@cache
def pcbuild_contingency(
self, category_group_a_name: str, category_group_b_name: str, modifier: Callable
):
category_group_a = self.pcget_or_create_category_groups(category_group_a_name)
category_group_b = self.pcget_or_create_category_groups(category_group_b_name)
proto_contingency = defaultdict(dict)
for a_group, mutations_cat_group_a in category_group_a.items():
for b_group, mutations_cat_group_b in category_group_b.items():
proto_contingency[a_group][b_group] = modifier(
self.intersect(
tuple(mutations_cat_group_a), tuple(mutations_cat_group_b)
)
)
return proto_contingency
def clear_custom_categories(self):
self._custom_category_groups.clear()
@staticmethod
@cache
def count_mutations_per_site(mutations: tuple[Mutation]):
mutation_sites = defaultdict(int)
for mutation in mutations:
for position in range(mutation.start, mutation.end + 1):
mutation_sites[position] += 1
return mutation_sites
@staticmethod
@cache
def count_mutation_sites(mutations: tuple[Mutation]):
mutation_sites = set()
for mutation in mutations:
for nucleotide_position in range(mutation.start, mutation.end + 1):
mutation_sites.add(nucleotide_position)
return len(mutation_sites)
@staticmethod
@cache
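    # Takes tuples rather than sets so functools.cache can hash the arguments.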
    def intersect(mutation: tuple[Mutation], *mutations: tuple[Mutation]):
intersect = set(mutation)
for add_mut in mutations:
intersect = intersect & set(add_mut)
return intersect
@staticmethod
@cache
    def union(mutation: tuple[Mutation], *mutations: tuple[Mutation]):
union = set(mutation)
for add_mut in mutations:
union = union | set(add_mut)
return union
@staticmethod
def count_mutations_for_category_group(
group_of_mutations: dict[str, set[Mutation]]
):
resulting_dict = {}
for key, value in group_of_mutations.items():
resulting_dict[key] = len(value)
return resulting_dict
@staticmethod
def build_with_args(
        categorized_vcf_records, genbank_file, sample_call_name="unknown"
):
genbank_record = None
with open(genbank_file) as genbank_fd:
genbank_record = next(
GenBank.parse(genbank_fd)
) # Assume there's only one accession per file.
return CategorizedMutations(
genbank_record,
list(
Mutation.vcf_reads_to_mutations(
categorized_vcf_records.get_all_vcf(),
genbank_record.version, # type: ignore
categorized_vcf_records.get_all_lineage_info(),
sample_call_name,
)
),
)
class MutationOrganizer:
def __init__(self, categorized_vcf_records, on_change=None):
self.genbank_path = "unknown"
self.sample_call_name = "unknown"
self._on_change = on_change
self._categorized_vcf_records = categorized_vcf_records
def update(self, genbank_path, sample, **kwargs):
self.genbank_path = genbank_path
self.sample_call_name = sample
if self._on_change:
self._on_change(self, **kwargs)
def build(self):
return CategorizedMutations.build_with_args(
self._categorized_vcf_records, self.genbank_path, self.sample_call_name
)
def recursive_freeze_dict(to_freeze: dict) -> frozendict:
for key, value in to_freeze.items():
if isinstance(value, dict):
to_freeze[key] = recursive_freeze_dict(value)
elif isinstance(value, list):
to_freeze[key] = tuple(value)
return frozendict(to_freeze)
def mutations_collection_to_set(
data: Union[tuple[Mutation], set[Mutation]]
) -> set[Mutation]:
return set(data) if isinstance(data, tuple) else data
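
Why the freezing step matters: functools.cache keys on its arguments, so anything passed through the cached helpers must be hashable. A small sketch of what recursive_freeze_dict produces (assuming vgaat/ is on the import path):

from mutations import recursive_freeze_dict

nested = {"regions": {("gene A", (1, 10)): [1, 2]}}
frozen = recursive_freeze_dict(nested)
# frozendict({'regions': frozendict({('gene A', (1, 10)): (1, 2)})})
hash(frozen)  # now hashable: inner dicts -> frozendict, lists -> tuples
# Caveat: set values are left as (unhashable) sets, which is why the
# @cache'd helpers above take tuples of Mutation instead.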

vgaat/statistics_tests.py (new file): 115 additions

@@ -0,0 +1,115 @@
import pandas as pd
from mutations import CategorizedMutations
from variants import CategorizedVariantRecords
import utils
from collections import defaultdict
def mutation_site_count_vs_sample_type(
cat_muts: CategorizedMutations,
cat_vars: CategorizedVariantRecords,
viral_strand,
region,
non_synonymous,
allow_fishers,
):
# Create list of earliest mutations
earliest_mutations_category = "earliest mutations"
cat_muts.pcintersect_category_groups_with_categories(
earliest_mutations_category,
"sample type",
tuple(cat_muts.flattened_pcategory_group("patient earliest"))
)
working_category_name = earliest_mutations_category
next_working_category_name = earliest_mutations_category
if viral_strand != "all":
next_working_category_name += f" - {viral_strand}"
cat_muts.pcintersect_category_groups_with_categories(
next_working_category_name,
working_category_name,
tuple(cat_muts.pget_category_group("viral lineage")[viral_strand]),
)
working_category_name = next_working_category_name
if region != "all":
next_working_category_name += f" - {region}"
cat_muts.pcintersect_category_groups_with_categories(
next_working_category_name,
working_category_name,
tuple(cat_muts.pget_category_group("regions")[region]),
)
working_category_name = next_working_category_name
if non_synonymous:
next_working_category_name += " - non-synonymous"
cat_muts.pcintersect_category_groups_with_categories(
next_working_category_name,
working_category_name,
tuple(cat_muts.pget_category("non-synonymous")),
)
working_category_name = next_working_category_name
mutated_site_counts_per_sample_type = (
cat_muts.mutation_sites_for_pccategory_group(working_category_name)
)
reference_site_counts_per_sample_type = dict(mutated_site_counts_per_sample_type)
for key, value in reference_site_counts_per_sample_type.items():
reference_site_counts_per_sample_type[key] = len(cat_muts.ref_sequence) - value
conti_table_dict = {
"Mutated": mutated_site_counts_per_sample_type,
"Reference": reference_site_counts_per_sample_type,
}
    contingency_table = pd.DataFrame(conti_table_dict)
    result = utils.row_pairwise_dependency_test(contingency_table, allow_fishers)
return (
"Is mutation site count dependent on sample type?",
"Here, we will count the sites of mutation in the "
"reference sequence and compare that number for each "
"type of mutation. The following is the table we will "
"we will use: ",
utils.publish_results(result),
)
# TODO Write individual nucleotide test
def individual_mutation_count_vs_sample_type(
cat_muts: CategorizedMutations,
cat_vars: CategorizedVariantRecords,
viral_strand,
region,
non_synonymous,
allow_fishers,
):
results = []
    for location, location_based_mutations in cat_muts.pget_category_group(
        "locations"
    ).items():
        contingency_table_per_location = defaultdict(dict)
        for sample_type, sample_type_based_mutations in cat_muts.pget_category_group(
            "sample type"
        ).items():
            contingency_table_per_location[f"Mutations at {location}"][
                sample_type
            ] = len(
                cat_muts.intersect(
                    tuple(location_based_mutations), tuple(sample_type_based_mutations)
                )
            )
            contingency_table_per_location["Reference"][sample_type] = len(
                location_based_mutations
            ) - len(cat_muts.pget_category_group("patient"))
utils.publish_results(contingency_table_per_location, results)
return "Location based mutations vs Sample type", results
# TODO Write distribution difference test (?)
tests = [
(
"Correlation between the number of mutation locations and sample types",
mutation_site_count_vs_sample_type,
)
]
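
For orientation: pd.DataFrame on a dict of dicts makes the outer keys the columns and the inner keys the index, so the contingency table above has sample types as rows and Mutated/Reference as columns. A toy sketch with made-up counts for a roughly 29,903 nt reference:

import pandas as pd

conti_table_dict = {
    "Mutated": {"B": 12, "N": 30, "E": 7},
    "Reference": {"B": 29891, "N": 29873, "E": 29896},
}
table = pd.DataFrame(conti_table_dict)
#    Mutated  Reference
# B       12      29891
# N       30      29873
# E        7      29896
# Rows are sample types, which is what row_pairwise_dependency_test
# compares pairwise.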

vgaat/utils.py (new file): 204 additions

@@ -0,0 +1,204 @@
import hashlib
import os
import shutil
import sys
from scipy.stats import chi2_contingency, fisher_exact
import pandas as pd
import itertools
from multiprocessing.pool import ThreadPool
import logging
def accession_to_genbank_filename(accession):
return accession.replace(".", "_") + ".gb"
def genbank_file_accession(gb_name):
return gb_name.replace("_", ".").replace(".gb", "")
def row_pairwise_dependency_test(table: pd.DataFrame, allow_fishers=True):
pairs = []
indices = table.index
results = {}
for i in range(len(indices) - 1):
for j in range(i + 1, len(indices)):
pairs.append((indices[i], indices[j]))
for a, b in pairs:
row_pair_unmodified = table.loc[[a, b]]
row_pair = row_pair_unmodified.loc[
:, (row_pair_unmodified != 0).any(axis=0)
] # Magic to drop columns that are empty...
row_pair = row_pair.loc[(row_pair != 0).any(axis=1)]
if row_pair.shape[0] == 2 and row_pair.shape[1] == 2 and allow_fishers:
odds_ratio, p = fisher_exact(row_pair)
results[f"{a} - {b}"] = {
"type": "exact",
}
        elif row_pair.shape[1] >= 2 and row_pair.shape[0] >= 2:
chi2, p, dof, expected = chi2_contingency(row_pair)
results[f"{a} - {b}"] = {
"type": "chi2",
}
elif row_pair.shape[1] == 0 and row_pair.shape[0] == 0:
continue
else:
results[f"{a} - {b}"] = {
"type": "chi2",
"original_table": row_pair_unmodified,
}
p = "Error"
results[f"{a} - {b}"]["p"] = p
results[f"{a} - {b}"]["table"] = row_pair
return results
# TODO Clean this up...
def publish_results(results: dict, results_list=None):
results_list = results_list or []
for result_info in results.values():
results_list.append(result_info)
return results_list
class Tester:
def __init__(
self,
test_funcs,
viral_strands,
regions,
synonymity,
fishers,
categorized_mutations,
categorized_variants,
max_threads=16,
):
self.tests = test_funcs
self.viral_strands = viral_strands
self.regions = regions
self.synonymity = synonymity
self.fishers = fishers
self.categorized_mutations = categorized_mutations
self.categorized_variants = categorized_variants
self._max_threads = max_threads
self._results = {}
def run_all_async(self):
self._thread_pool = ThreadPool(processes=self._max_threads)
param_product = itertools.product(
self.tests, self.viral_strands, self.regions, self.synonymity, self.fishers
)
total = len(self.tests) * \
len(self.viral_strands) * \
len(self.regions) * \
len(self.synonymity) * \
len(self.fishers)
runs = self._thread_pool.imap_unordered(
self.run_test,
param_product,
chunksize=self._max_threads # TODO Perhaps add more sophisticated param...
)
for running_test in runs:
identifier, res = running_test
self._results[identifier] = res
logging.info(f"Test progress: {(len(self._results))/total:.1%}")
def run_test(self, params):
test_tuple, viral_strand, regions, synonymity, fishers = params
test_alias, test_func = test_tuple
logging.debug(
"Running {0} with parameters: {1} {2} {3} {4}".format(
test_alias, viral_strand, regions, synonymity, fishers
)
)
res = test_func(
self.categorized_mutations,
self.categorized_variants,
viral_strand,
regions,
synonymity,
fishers
)
logging.debug("Completed running {0} with parameters: {1} {2} {3} {4}".format(
test_alias, viral_strand, regions, synonymity, fishers
))
return (test_alias, viral_strand, regions, synonymity, fishers), res
def get_result_list(self, test_alias: str, viral_strand: str, regions,
synonymity: bool, fishers: bool):
return self._results[test_alias, viral_strand, regions, synonymity, fishers]
def get_all_results(self):
return self._results
def write_markdown_results(result_groups: dict[tuple, tuple], md_results_path=None):
if md_results_path:
large_dir = md_results_path + ".large"
if os.path.exists(large_dir):
shutil.rmtree(large_dir)
writer_p = sys.stdout
if md_results_path:
writer_p = open(md_results_path, "w")
for test_id, result_list in result_groups.items():
result_group_title, result_group_desc, results = result_list
writer_p.write(f"# {result_group_title}\n\n")
writer_p.write(f"{result_group_desc}\n\n")
writer_p.write("Run with the following parameters:\n")
        writer_p.write(
            "Test name: {0}; Lineage: {1}; Region: {2}; "
            "Allow synonymous: {3}; Allow Exact (Fishers): {4};\n\n".format(*test_id)
        )
for result in results:
writer_p.write(f"P-value: {result['p']:.5%}; \
Test Type: {result['type']}\n\n")
if result["table"].shape[0] < 10 and result["table"].shape[1] < 10:
writer_p.write(f"{result['table'].to_markdown()}\n---\n\n")
else:
writer_p.write(
f"Table was too large {result['table'].shape} to display.\n"
)
if md_results_path:
                large_table_dir = os.path.join(
                    md_results_path + ".large",
                    os.path.sep.join(map(str, test_id))
                    .replace(" ", "")
                    .replace("\t", "")
                    .replace("'", "")
                    .replace('"', "")
                    .replace("(", "")
                    .replace(")", "")
                    .replace(",", "_"),
                )
                filename = hashlib.shake_128(
                    result["table"].to_markdown().encode("utf-8")
                ).hexdigest(8)
filepath = os.path.join(
large_table_dir,
filename
)
same_num = 0
while os.path.exists(filepath + str(same_num) + ".csv"):
same_num += 1
large_table_path = filepath + str(same_num) + ".csv"
relative_table_path = large_table_path.replace(
os.path.abspath(md_results_path),
"." + os.path.sep + os.path.basename(md_results_path)
)
os.makedirs(large_table_dir, exist_ok=True)
result["table"].to_csv(path_or_buf=large_table_path)
writer_p.write(
"Table stored as CSV file. See at:\n"
f"[{relative_table_path}]({relative_table_path})"
)
writer_p.write("\n")
if md_results_path:
writer_p.close()
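
A sketch of the pairwise test on a hypothetical 2x2 table (assuming vgaat/ is on the import path): for each pair of rows it drops all-zero rows and columns, runs Fisher's exact test on a 2x2 when allowed, and falls back to a chi-squared test otherwise.

import pandas as pd
from utils import row_pairwise_dependency_test

table = pd.DataFrame(
    {"Mutated": {"B": 12, "N": 30}, "Reference": {"B": 29891, "N": 29873}}
)
results = row_pairwise_dependency_test(table, allow_fishers=True)
# {'B - N': {'type': 'exact', 'p': <float>, 'table': <the 2x2 DataFrame>}}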

vgaat/variants.py (new file): 164 additions

@@ -0,0 +1,164 @@
import csv
import re
import os
from typing import Iterable
import vcfpy
from collections import defaultdict
import logging
class VariantRecordsOrganizer:
def __init__(self, on_change=None):
self._on_change = on_change
self._vcfs_dir = ""
self._lineage_report_csv_path = ""
self._vcf_filename_regex = ""
def update(self, vcfs_dir, lineage_report_csv_path, vcf_filename_regex, **kwargs):
self._vcfs_dir = vcfs_dir
self._lineage_report_csv_path = lineage_report_csv_path
self._vcf_filename_regex = vcf_filename_regex
if self._on_change:
self._on_change(self, **kwargs)
def build(self):
return CategorizedVariantRecords.from_files(
self._vcfs_dir, self._lineage_report_csv_path, self._vcf_filename_regex
)
class VariantRecordGroup:
def __init__(self, patient_id, sample_type, sample_day, records: list):
self._patient_id = patient_id
self._sample_type = sample_type
self._sample_day = sample_day
self._variants = records
@property
def vcf_info(self):
return self._patient_id, self._sample_type, self._sample_day
@property
def variants(self):
return self._variants
def __hash__(self):
return hash(self.vcf_info)
def __eq__(self, o):
return self.vcf_info == o.vcf_info
def __iter__(self):
return iter(self.variants)
class CategorizedVariantRecords:
def __init__(self, vcf_data, lineage_info):
self._primitive_category_groups = {}
self._primitive_category_groups["sample day"] = defaultdict(set)
self._primitive_category_groups["patient id"] = defaultdict(set)
self._primitive_category_groups["sample type"] = defaultdict(set)
self._primitive_category_groups["viral lineage"] = defaultdict(set)
self._all_vcf = vcf_data
self._all_lineage_info = lineage_info
for vcf_info, records in vcf_data.items():
patient_id, sample_type, sample_day = vcf_info
self._primitive_category_groups["sample day"][sample_day].add(records)
self._primitive_category_groups["patient id"][patient_id].add(records)
self._primitive_category_groups["sample type"][sample_type].add(records)
if vcf_info in lineage_info:
self._primitive_category_groups["viral lineage"][
lineage_info[vcf_info]
].add(records)
def get_category_groups(self, category: str):
return self._primitive_category_groups[category]
def get_all_vcf(self):
return self._all_vcf
def get_all_lineage_info(self):
return self._all_lineage_info
@staticmethod
def intersect_each_with_one(each: dict, one: Iterable):
result = {}
for key, value in each.items():
result[key] = value & one
return result
@staticmethod
def get_sample_info_from_filename(
vcf_filename, filename_regex=r"([BNE]{1,2})(\d+)(?:-D(\d+))?"
):
filename_match = re.match(filename_regex, vcf_filename)
if filename_match is None:
            logging.warning(
                f'The regex "{filename_regex}" did not match "{vcf_filename}"'
            )
return None
sample_type = filename_match.group(1)
sample_day = (
filename_match.group(3) or "1"
) # We assume it's day 1 in the case there were no annotations.
sample_id = filename_match.group(2)
return sample_id, sample_type, sample_day
@staticmethod
def read_lineage_csv(
file_name: str, sample_info_column: str, viral_strand_column: str
):
lineage_info = {}
with open(file_name) as lineage_csv_fd:
reader = csv.reader(lineage_csv_fd)
header = None
for row in reader:
if header is None:
header = row
continue
sample_info = CategorizedVariantRecords.get_sample_info_from_filename(
row[header.index("taxon")]
)
viral_strand_ver = row[header.index(viral_strand_column)]
if viral_strand_ver:
lineage_info[sample_info] = viral_strand_ver
else:
lineage_info[sample_info] = "unknown"
return lineage_info
@staticmethod
    def count_samples_in_dict_of_sets(dict_samples: dict[str, set]):
result = {}
for key, value in dict_samples.items():
result[key] = len(value)
return result
@staticmethod
def from_files(
vcfs_path,
lineage_report_csv_path,
filename_regex=r"([BNE]{1,2})(\d+)(?:-D(\d+))?",
):
lineage_dict = CategorizedVariantRecords.read_lineage_csv(
lineage_report_csv_path, "taxon", "scorpio_call"
)
vcf_records = {}
for vcf_filename in os.listdir(vcfs_path):
vcf_full_path = os.path.abspath(os.path.join(vcfs_path, vcf_filename))
if os.path.isdir(vcf_full_path):
continue
if os.path.splitext(vcf_full_path)[1] != ".vcf":
continue
sample_info = CategorizedVariantRecords.get_sample_info_from_filename(
vcf_filename, filename_regex
)
if sample_info is None:
continue
patient_id, sample_type, sample_day = sample_info
with open(vcf_full_path) as vcf_file_desc:
curr_list = []
for vcf_record in vcfpy.Reader(vcf_file_desc):
curr_list.append(vcf_record)
vcf_records[sample_info] = VariantRecordGroup(*sample_info, curr_list)
return CategorizedVariantRecords(vcf_records, lineage_dict)
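
The default filename regex encodes sample type, patient id, and an optional day suffix; a quick check with hypothetical filenames:

import re

pattern = r"([BNE]{1,2})(\d+)(?:-D(\d+))?"
print(re.match(pattern, "B12-D3.vcf").groups())  # ('B', '12', '3')
print(re.match(pattern, "NE7.vcf").groups())     # ('NE', '7', None); day defaults to "1"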

vgaat/vgaat.py (new executable file): 176 additions

@@ -0,0 +1,176 @@
#!/usr/bin/env python3
import os
import asyncio
import percache
import tempfile
import argparse
import logging
import variants
import mutations
import statistics_tests
import pickle
import shutil
import utils
import filters
async def main(args):
logging.basicConfig(level=args.log.upper())
logging.debug(f'Caching to "{args.cache_dir}".')
if args.clear_cache:
logging.info("Clearing previous cache...")
shutil.rmtree(args.cache_dir)
logging.info("Cache cleared")
logging.info('Using call name "{0}"'.format(args.call_name))
vcf_dir_path = os.path.abspath(args.vcf_dir)
logging.info(f'Fetching VCF files from "{vcf_dir_path}"')
lineage_file = os.path.abspath(args.lineage_file)
logging.info(f'Fetching Lineage file from "{lineage_file}"')
variant_organizer = variants.VariantRecordsOrganizer()
variant_organizer.update(vcf_dir_path, lineage_file, args.sample_filename_re)
logging.info("Building categorized variants...")
categorized_variants = variant_organizer.build()
logging.info("Done")
mutation_organizer = mutations.MutationOrganizer(categorized_variants)
logging.info(f"Using GenBank file from {args.ref_genbank}")
mutation_organizer.update(args.ref_genbank, args.call_name)
categorized_mutations_cache_path = os.path.join(
args.cache_dir, os.path.basename(vcf_dir_path), "categorized_mutations.pickle"
)
if not os.path.exists(categorized_mutations_cache_path):
logging.info("Building categorized mutations...")
categorized_mutations = mutation_organizer.build()
        os.makedirs(os.path.dirname(categorized_mutations_cache_path), exist_ok=True)
with open(categorized_mutations_cache_path, "wb") as fd:
pickle.dump(categorized_mutations, fd)
else:
logging.info(
f"Loading categorized mutations from {categorized_mutations_cache_path}"
)
with open(categorized_mutations_cache_path, "rb") as fd:
categorized_mutations = pickle.load(fd)
logging.info("Done")
# TODO Add all categories as parameters
# TODO How do we create a unanimous test suite???
tester = utils.Tester(
statistics_tests.tests,
["all", *categorized_mutations.pget_category_group("viral lineage").keys()],
["all", *categorized_mutations.pget_category_group("regions").keys()],
[True, False],
[not args.disable_fishers],
categorized_mutations,
categorized_variants,
max_threads=args.threads,
)
results_cache_path = os.path.join(
args.cache_dir, os.path.basename(vcf_dir_path), "results.pickle"
)
if not os.path.exists(results_cache_path):
logging.info("Running all tests...")
tester.run_all_async()
results = tester.get_all_results()
        os.makedirs(os.path.dirname(results_cache_path), exist_ok=True)
with open(results_cache_path, "wb") as fd:
pickle.dump(results, fd)
else:
logging.info(f"Loading test results from {results_cache_path}")
with open(results_cache_path, "rb") as fd:
results = pickle.load(fd)
logging.info(f"Applying alpha filter of {args.alpha}")
results = filters.filter_by_alpha(results, args.alpha)
if not args.output:
logging.debug("Outputting results to stdout...")
utils.write_markdown_results(results)
else:
logging.debug(f'Outputting to "{args.output}"...')
utils.write_markdown_results(results, md_results_path=args.output)
logging.debug("Done.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="VGAAT",
description="Virus Genome Association Analytics Tools (VGAAT) \
is a Python tool set containing a variety of association \
analyses that can be run across large collections of VCFs.",
)
parser.add_argument(
"vcf_dir", metavar="i", help="Path to directory containing VCF files"
)
parser.add_argument(
"ref_genbank",
metavar="a",
help="The path to the NCBI GenBank file containing the reference used to \
produce the VCF calls.",
)
parser.add_argument(
"call_name",
metavar="c",
help="The call name to use when reading the VCF files.",
)
parser.add_argument(
"lineage_file",
metavar="l",
help="The CSV file containing information on the samples lineage.",
)
parser.add_argument(
"--sample-filename-re",
metavar="-S",
help="The regex used to interpret the individual sample filenames.",
default=r"([BNE]{1,2})(\d+)(?:-D(\d+))?",
)
parser.add_argument(
"--log",
metavar="-L",
help="Sets the verbosity of the program.",
default="INFO",
)
parser.add_argument(
"--cache-dir",
metavar="-C",
help="Set data cache location. Choose a persistent location if you'd like to \
persist data after a run.",
default="./tmp/VGAAT/data_cache",
)
parser.add_argument(
"--disable-fishers",
metavar="-X",
help="Disables use of the Fisher's Exact Test even when it is possible.",
default=False,
)
parser.add_argument(
"--threads",
metavar="-T",
help="Number of threads to use when performing statistical tests.",
default=16,
type=int,
)
parser.add_argument(
"--clear-cache", metavar="-S", help="Clears cache and then runs.", default=False
)
parser.add_argument(
"--alpha",
metavar="A",
help="Filter results to be within given alpha value.",
default=0.05,
)
parser.add_argument(
"--output",
metavar="-o",
help="Where to output the results.",
)
# TODO Complete adding output and file format options
args = parser.parse_args()
try:
cache = percache.Cache(os.path.join(tempfile.gettempdir(), "cache"))
except PermissionError:
cache = percache.Cache(args.cache_dir)
asyncio.run(main(args))