From 0da81f26dddad19c4f6a7365860f01081c1b55ea Mon Sep 17 00:00:00 2001 From: Harrison Date: Thu, 23 Mar 2023 16:24:25 -0500 Subject: [PATCH] Updated 'statistics_tests.py' to conform to mutations API --- .gitignore | 5 + .vscode/launch.json | 53 ++++ .vscode/settings.json | 23 ++ tox.ini | 3 + vgaat/filters.py | 12 + vgaat/mutations.py | 553 ++++++++++++++++++++++++++++++++++++++ vgaat/statistics_tests.py | 115 ++++++++ vgaat/utils.py | 204 ++++++++++++++ vgaat/variants.py | 164 +++++++++++ vgaat/vgaat.py | 176 ++++++++++++ 10 files changed, 1308 insertions(+) create mode 100644 .vscode/launch.json create mode 100644 .vscode/settings.json create mode 100644 tox.ini create mode 100644 vgaat/filters.py create mode 100644 vgaat/mutations.py create mode 100644 vgaat/statistics_tests.py create mode 100644 vgaat/utils.py create mode 100644 vgaat/variants.py create mode 100755 vgaat/vgaat.py diff --git a/.gitignore b/.gitignore index 207f105..ca93118 100644 --- a/.gitignore +++ b/.gitignore @@ -238,3 +238,8 @@ $RECYCLE.BIN/ # Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option) +data +*-data +lineage_report.csv +cache +*.gb \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..8de3c3d --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,53 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Current File", + "type": "python", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": true + }, + { + "name": "VGAAT: run all", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/vgaat/vgaat.py", + "args": [ + "${workspaceFolder}/test_data/", + "${workspaceFolder}/MN908947_3.gb", + "unknown", + "${workspaceFolder}/nml_lineage_report.csv", + "--log", "DEBUG", + "--alpha", "0.05", + "--output", "${workspaceFolder}/results.md", + "--threads", "1", + ], + "console": "integratedTerminal", + "justMyCode": true, + }, + { + "name": "VGAAT: run all - NO CACHE", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/vgaat/vgaat.py", + "args": [ + "${workspaceFolder}/test_data/", + "${workspaceFolder}/MN908947_3.gb", + "unknown", + "${workspaceFolder}/nml_lineage_report.csv", + "--log", "DEBUG", + "--alpha", "0.05", + "--output", "${workspaceFolder}/results.md", + "--threads", "1", + "--clear-cache", "True" + ], + "console": "integratedTerminal", + "justMyCode": true, + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..9659530 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,23 @@ +{ + "cSpell.words": [ + "Biopython", + "funcs", + "genbank", + "heatmaps", + "matplotlib", + "NCBI", + "orjson", + "percache", + "pyplot", + "scipy", + "utrs", + "vcfpy", + "vcfs", + "VGAAT" + ], + "python.formatting.provider": "black", + "python.linting.prospectorEnabled": false, + "python.linting.enabled": true, + "python.linting.flake8Enabled": true, + "python.analysis.typeCheckingMode": "basic" +} diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..e0ea542 --- /dev/null +++ b/tox.ini @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 88 +extend-ignore = E203 \ No newline at end of file diff --git a/vgaat/filters.py 
b/vgaat/filters.py
new file mode 100644
index 0000000..4eb21e1
--- /dev/null
+++ b/vgaat/filters.py
@@ -0,0 +1,12 @@
+def filter_by_alpha(result_groups: dict[str, list[dict]], alpha=0.05):
+    filtered_result_groups = {}
+    for identifier, result_group in result_groups.items():
+        test_alias, viral_strand, regions, synonymity, fishers = identifier
+        filtered_result_data = []
+        title, desc, result_data = result_group
+        for info in result_data:
+            if (not isinstance(info["p"], str)) and info["p"] < float(alpha):
+                filtered_result_data.append(dict(info))
+        if len(filtered_result_data) > 0:
+            filtered_result_groups[identifier] = (title, desc, filtered_result_data)
+    return filtered_result_groups
diff --git a/vgaat/mutations.py b/vgaat/mutations.py
new file mode 100644
index 0000000..128e10f
--- /dev/null
+++ b/vgaat/mutations.py
@@ -0,0 +1,553 @@
+from collections import defaultdict
+from typing import Callable, Union
+from Bio.Seq import MutableSeq, translate
+from Bio import GenBank
+from frozendict import frozendict
+from utils import accession_to_genbank_filename
+import logging
+from functools import cache
+
+
+class Mutation:
+    def __init__(
+        self,
+        vcf_record,
+        sample_id,
+        sample_type,
+        sample_day,
+        viral_lineage,
+        alt,
+        sample_call_name="unknown",
+    ):
+        self._sample_viral_lineage = viral_lineage
+        self._patient_id = sample_id
+        self._sample_type = sample_type
+        self._sample_day = sample_day
+        self._start = vcf_record.begin
+        self._end = vcf_record.end or self.start
+        self._reference = vcf_record.REF
+        self._call = vcf_record.call_for_sample[sample_call_name]
+        self._genotype = self.call.gt_bases[0]  # Haploid
+        self._all_alts = vcf_record.ALT
+        self._alt = alt.value
+        self._ref_obs = self.call.data["RO"]
+        self._depth = self.call.data["DP"]
+
+        if len(self.genotype) > len(self.ref):
+            self._variant_type = "insertion"
+        elif len(self.genotype) < len(self.ref):
+            self._variant_type = "deletion"
+        else:  # Equal lengths: single- or multi-nucleotide variant.
+            if len(self.genotype) == 1:
+                self._variant_type = "SNV"
+            else:
+                self._variant_type = "MNV"
+
+        self._hash = hash(
+            (
+                self._patient_id,
+                self._sample_type,
+                self._sample_day,
+                self._start,
+                self._end,
+                self._reference,
+                self._alt,
+            )
+        )
+
+    @property
+    def patient_id(self):
+        return self._patient_id
+
+    @property
+    def sample_type(self):
+        return self._sample_type
+
+    @property
+    def sample_day(self):
+        return self._sample_day
+
+    @property
+    def sample_viral_lineage(self):
+        return self._sample_viral_lineage
+
+    @property
+    def start(self):
+        return self._start
+
+    @property
+    def end(self):
+        return self._end
+
+    @property
+    def ref(self):
+        return self._reference
+
+    @property
+    def call(self):
+        return self._call
+
+    @property
+    def genotype(self):
+        return self._genotype
+
+    @property
+    def alternatives(self):
+        return self._all_alts
+
+    @property
+    def ref_obs(self):
+        return self._ref_obs
+
+    @property
+    def depth(self):
+        return self._depth
+
+    @property
+    def variant_type(self):
+        return self._variant_type
+
+    def __len__(self):
+        return self.depth
+
+    def __hash__(self):
+        return self._hash
+
+    def __eq__(self, other):
+        return (
+            self.patient_id == other.patient_id
+            and self.sample_type == other.sample_type
+            and self.sample_day == other.sample_day
+            and self.start == other.start
+            and self.end == other.end
+            and self.ref == other.ref
+            and self.call == other.call
+            and self.genotype == other.genotype
+            and self.alternatives == other.alternatives
+            and self.ref_obs == other.ref_obs
+            and self.depth == other.depth
+            and self.variant_type == other.variant_type
+            and self._alt == other._alt
+        )
+
+    @staticmethod
+    def vcf_reads_to_mutations(
+        vcf_reads: dict, accession, sample_lineage, sample_call_name="unknown"
+    ):
+        for sample_info, vcf_records in vcf_reads.items():
+            sample_id, sample_type, sample_day = sample_info
+            for vcf_record in vcf_records:
+                if vcf_record.CHROM != accession:
+                    continue
+                for alternate in vcf_record.ALT:
+                    yield Mutation(
+                        vcf_record,
+                        sample_id,
+                        sample_type,
+                        sample_day,
+                        sample_lineage.get(sample_info),
+                        alternate,
+                        sample_call_name,
+                    )
+
+
+class CategorizedMutations:
+    def __init__(self, genbank_record, mutations: list[Mutation]):
+        self._genbank_record = genbank_record
+        self._primitive_category_groups = {}
+        self._primitive_categories = {}
+        self._custom_category_groups = {}
+        self._custom_categories = {}
+        genes_and_utrs_locs = {}
+
+        logging.debug("Categorized mutations being constructed.")
+        logging.debug("Cataloguing features...")
+        for genbank_feature in self.get_genbank_record().features:
+            f_key = genbank_feature.key
+            if f_key == "gene" or f_key.endswith("UTR"):
+                f_qual = f_key
+                # Assumes simple "start..end" locations; compound locations
+                # (e.g. join/complement) are not handled here.
+                f_start, f_end = genbank_feature.location.split("..")
+                f_start = int(f_start)
+                f_end = int(f_end)
+                for qualifier in genbank_feature.qualifiers:
+                    if qualifier.key == "/gene=":
+                        f_qual = qualifier.value
+                        break
+                if (f_start, f_end) in genes_and_utrs_locs:
+                    raise Exception("Feature with duplicate location detected.")
+                genes_and_utrs_locs[(f_start, f_end)] = f_qual
+        logging.debug("Done")
+
+        # The VERSION string (e.g. "MN908947.3") is what matches both the VCF
+        # CHROM column and the cached GenBank filename.
+        self._accession = genbank_record.version
+
+        self._primitive_category_groups["locations"] = defaultdict(set)
+        self._primitive_category_groups["allele change"] = defaultdict(set)
+        self._primitive_category_groups["sample type"] = defaultdict(set)
+        self._primitive_category_groups["sample day"] = defaultdict(set)
+        self._primitive_category_groups["variant type"] = defaultdict(set)
+        self._primitive_category_groups["patient"] = defaultdict(set)
+        self._primitive_category_groups["viral lineage"] = defaultdict(set)
+        self._primitive_category_groups["regions"] = defaultdict(set)
+        self._primitive_category_groups["patient earliest"] = defaultdict(set)
+
+        self._primitive_categories["non-synonymous"] = set()
+        self._primitive_categories["all"] = set(mutations)
+
+        logging.debug("Organizing mutations...")
+        num_of_mutations = len(mutations)
+        done = 0
+        for mutation in mutations:
+            self._primitive_category_groups["locations"][
+                (mutation.start, mutation.end)
+            ].add(mutation)
+            self._primitive_category_groups["allele change"][
+                (mutation.ref, mutation.genotype)
+            ].add(mutation)
+            self._primitive_category_groups["sample type"][mutation.sample_type].add(
+                mutation
+            )
+            self._primitive_category_groups["variant type"][mutation.variant_type].add(
+                mutation
+            )
+            self._primitive_category_groups["sample day"][mutation.sample_day].add(
+                mutation
+            )
+            if mutation.sample_viral_lineage:
+                self._primitive_category_groups["viral lineage"][
+                    mutation.sample_viral_lineage
+                ].add(mutation)
+            # Invariant: all mutations in a patient's "patient earliest" set
+            # share one sample day, and by the end of the loop that day is the
+            # patient's earliest. The empty set satisfies this trivially. For
+            # each new mutation we peek at any member of the set: if that
+            # member has a later day, every member does (by the invariant), so
+            # we clear the set and restart it with the new mutation; if the
+            # days are equal, we add the mutation; if the new day is later, we
+            # skip it.
+            if self._primitive_category_groups["patient earliest"][
+                mutation.patient_id
+            ]:
+                earliest_day = next(
+                    iter(
+                        self._primitive_category_groups["patient earliest"][
+                            mutation.patient_id
+                        ]
+                    )
+                ).sample_day
+            else:
+                earliest_day = None
+            if earliest_day is None or earliest_day > mutation.sample_day:
+                self._primitive_category_groups["patient earliest"][
+                    mutation.patient_id
+                ].clear()
+                earliest_day = mutation.sample_day
+            if mutation.sample_day == earliest_day:
+                self._primitive_category_groups["patient earliest"][
+                    mutation.patient_id
+                ].add(mutation)
+            self._primitive_category_groups["patient"][mutation.patient_id].add(
+                mutation
+            )
+            # Substitute the genotype into a copy of the reference and compare
+            # translations. This assumes the genotype spans the same positions
+            # as the reference allele (SNV/MNV); indels would need separate
+            # handling.
+            mutable_sequence = MutableSeq(self.ref_sequence)
+            for relative_mutation, mutation_position in enumerate(
+                range(mutation.start, mutation.end + 1)
+            ):
+                mutable_sequence[mutation_position] = mutation.genotype[
+                    relative_mutation
+                ]
+            if translate(mutable_sequence) != translate(self.ref_sequence):
+                self._primitive_categories["non-synonymous"].add(mutation)
+
+            for location in genes_and_utrs_locs.keys():
+                start, end = location
+                if start <= mutation.start and end >= mutation.end:
+                    self._primitive_category_groups["regions"][
+                        (genes_and_utrs_locs[location], location)
+                    ].add(mutation)
+            done += 1
+            logging.info(
+                f"Organized {done}/{num_of_mutations} "
+                f"({done/num_of_mutations:.2%}) mutations"
+            )
+
+        self._primitive_category_groups = recursive_freeze_dict(
+            self._primitive_category_groups
+        )
+
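+    # The primitive category groups map a key to the (frozen) set of
+    # Mutation objects in that category. Illustrative lookups (the keys
+    # depend on the input data):
+    #
+    #     cat_muts.pget_category_group("sample type")["B"]
+    #     cat_muts.pget_category_group("regions")[("S", (21563, 25384))]
+    #
+    # "patient earliest" holds, per patient, only the mutations observed on
+    # that patient's earliest sampled day.
+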
+ """ + return accession_to_genbank_filename(self._accession) + + def get_accession(self): + return self._accession + + def get_genbank_record(self): + return self._genbank_record + + @property + def ref_sequence(self): + return self._genbank_record.sequence + + def pget_category_or_category_group(self, property_name: str): + return ( + self._primitive_category_groups[property_name] + or self._primitive_categories[property_name] + ) + + def pget_category(self, property_name: str): + return self._primitive_categories[property_name] + + def pget_category_group(self, property_name: str): + return self._primitive_category_groups[property_name] + + @property + def all(self): + return self._primitive_category_groups["all"] + + @property + def total_mutations(self): + return len(self.all()) + + @cache + def mutations_per_site_for_pcategory_group(self, category_group: str): + categorized_mutations = self.pget_category_group(category_group) + if isinstance(categorized_mutations, dict): + resulting_dict = {} + for key, value in categorized_mutations.items(): + resulting_dict[key] = self.count_mutations_per_site(value) + return resulting_dict + else: + return self.count_mutations_per_site(categorized_mutations) + + @cache + def flattened_pcategory_group(self, category: str) -> set: + categorized_mutations = self.pget_category_group(category) + flattened = set() + for key, set_value in categorized_mutations.items(): + flattened.update(set_value) + return flattened + + @cache + def mutation_sites_for_pccategory_group(self, category: str): + categorized_mutations = self.pcget_or_create_category_groups(category) + resulting_dict = {} + for key, value in categorized_mutations.items(): + resulting_dict[key] = CategorizedMutations.count_mutation_sites(value) + return resulting_dict + + @cache + def mutations_for_pccategory_group( + self, category: str, assembler: Union[Callable, None] = None, **kwargs: dict + ): + dictionary_set = self.pcget_or_create_category_groups(category) + return CategorizedMutations.count_mutations_for_category_group(dictionary_set) + + @cache + def punion_category_groups_with_category(self, category_groups: str, category: str): + category_set = set(self.pget_category(category)) + each = self.pget_category_group(category_groups) + one = category_set + result = {} + for subcat, subcat_mutations in each.items(): + result[subcat] = subcat_mutations.union(one) + return result + + @cache + def pintersect_category_groups_with_category( + self, category_groups: str, category: str + ): + category_set = set(self.pget_category(category)) + each = self.pget_category_group(category_groups) + result = {} + for subcat, subcat_mutations in each.items(): + result[subcat] = subcat_mutations.intersection(category_set) + return result + + def pcget_or_create_category_groups( + self, category_group: str, assembler: Union[Callable, None] = None, **kwargs + ): + if category_group in self._primitive_category_groups: + return self.pget_category_group(category_group) + if category_group not in self._custom_category_groups: + if not assembler: + raise Exception( + "Assembler was not provided for a non-existent custom category " + f'groups "{category_group}".' 
+ ) + result = assembler(**kwargs) + if not isinstance(result, dict): + raise Exception("Assembler result was not a group of categories...") + self._custom_category_groups[category_group] = result + return self._custom_category_groups[category_group] + + @cache + def pcunion_category_groups_with_categories( + self, + union_name: str, + category_group: str, + category: tuple[Mutation], + assembler: Union[Callable, None] = None, + assembler_args: dict = {}, + ): + category_set = set(category) + if union_name in self._custom_category_groups: + return self._custom_category_groups[union_name] + + subcategories = self.pcget_or_create_category_groups( + category_group, assembler, **assembler_args + ) + result = {} + for subcat, mutations in subcategories.items(): + result[subcat] = mutations.union(category_set) + + self._custom_category_groups[union_name] = result + return result + + @cache + def pcintersect_category_groups_with_categories( + self, + intersect_name: str, + category_group: str, + category: tuple[Mutation], + assembler: Union[Callable, None] = None, + assembler_args: dict = {}, + ): + category_set = set(category) + + if intersect_name in self._custom_category_groups: + return self._custom_category_groups[intersect_name] + + subcategories = self.pcget_or_create_category_groups( + category_group, assembler, **assembler_args + ) + result = {} + for subcat, mutations in subcategories.items(): + result[subcat] = mutations.intersection(category_set) + self._custom_category_groups[intersect_name] = result + return result + + @cache + def pcbuild_contingency( + self, category_group_a_name: str, category_group_b_name: str, modifier: Callable + ): + category_group_a = self.pcget_or_create_category_groups(category_group_a_name) + category_group_b = self.pcget_or_create_category_groups(category_group_b_name) + proto_contingency = defaultdict(dict) + for a_group, mutations_cat_group_a in category_group_a.items(): + for b_group, mutations_cat_group_b in category_group_b.items(): + proto_contingency[a_group][b_group] = modifier( + self.intersect( + tuple(mutations_cat_group_a), tuple(mutations_cat_group_b) + ) + ) + return proto_contingency + + def clear_custom_categories(self): + self._custom_category_groups.clear() + + @staticmethod + @cache + def count_mutations_per_site(mutations: tuple[Mutation]): + mutation_sites = defaultdict(int) + for mutation in mutations: + for position in range(mutation.start, mutation.end + 1): + mutation_sites[position] += 1 + return mutation_sites + + @staticmethod + @cache + def count_mutation_sites(mutations: tuple[Mutation]): + mutation_sites = set() + for mutation in mutations: + for nucleotide_position in range(mutation.start, mutation.end + 1): + mutation_sites.add(nucleotide_position) + return len(mutation_sites) + + @staticmethod + @cache + def intersect(mutation: tuple[Mutation], *mutations: tuple[tuple[Mutation]]): + intersect = set(mutation) + for add_mut in mutations: + intersect = intersect & set(add_mut) + return intersect + + @staticmethod + @cache + def union(mutation: tuple[Mutation], *mutations: tuple[tuple[Mutation]]): + union = set(mutation) + for add_mut in mutations: + union = union | set(add_mut) + return union + + @staticmethod + def count_mutations_for_category_group( + group_of_mutations: dict[str, set[Mutation]] + ): + resulting_dict = {} + for key, value in group_of_mutations.items(): + resulting_dict[key] = len(value) + return resulting_dict + + @staticmethod + def build_with_args( + categorized_vcf_records, genbank_file, 
sample_call_name="sample_call_name" + ): + genbank_record = None + with open(genbank_file) as genbank_fd: + genbank_record = next( + GenBank.parse(genbank_fd) + ) # Assume there's only one accession per file. + + return CategorizedMutations( + genbank_record, + list( + Mutation.vcf_reads_to_mutations( + categorized_vcf_records.get_all_vcf(), + genbank_record.version, # type: ignore + categorized_vcf_records.get_all_lineage_info(), + sample_call_name, + ) + ), + ) + + +class MutationOrganizer: + def __init__(self, categorized_vcf_records, on_change=None): + self.genbank_path = "unknown" + self.sample_call_name = "unknown" + self._on_change = on_change + self._categorized_vcf_records = categorized_vcf_records + + def update(self, genbank_path, sample, **kwargs): + self.genbank_path = genbank_path + self.sample_call_name = sample + if self._on_change: + self._on_change(self, **kwargs) + + def build(self): + return CategorizedMutations.build_with_args( + self._categorized_vcf_records, self.genbank_path, self.sample_call_name + ) + + +def recursive_freeze_dict(to_freeze: dict) -> frozendict: + for key, value in to_freeze.items(): + if isinstance(value, dict): + to_freeze[key] = recursive_freeze_dict(value) + elif isinstance(value, list): + to_freeze[key] = tuple(value) + return frozendict(to_freeze) + + +def mutations_collection_to_set( + data: Union[tuple[Mutation], set[Mutation]] +) -> set[Mutation]: + return set(data) if isinstance(data, tuple) else data diff --git a/vgaat/statistics_tests.py b/vgaat/statistics_tests.py new file mode 100644 index 0000000..8588954 --- /dev/null +++ b/vgaat/statistics_tests.py @@ -0,0 +1,115 @@ +import pandas as pd +from mutations import CategorizedMutations +from variants import CategorizedVariantRecords +import utils +from collections import defaultdict + + +def mutation_site_count_vs_sample_type( + cat_muts: CategorizedMutations, + cat_vars: CategorizedVariantRecords, + viral_strand, + region, + non_synonymous, + allow_fishers, +): + # Create list of earliest mutations + earliest_mutations_category = "earliest mutations" + cat_muts.pcintersect_category_groups_with_categories( + earliest_mutations_category, + "sample type", + tuple(cat_muts.flattened_pcategory_group("patient earliest")) + ) + + working_category_name = earliest_mutations_category + next_working_category_name = earliest_mutations_category + if viral_strand != "all": + next_working_category_name += f" - {viral_strand}" + cat_muts.pcintersect_category_groups_with_categories( + next_working_category_name, + working_category_name, + tuple(cat_muts.pget_category_group("viral lineage")[viral_strand]), + ) + working_category_name = next_working_category_name + if region != "all": + next_working_category_name += f" - {region}" + cat_muts.pcintersect_category_groups_with_categories( + next_working_category_name, + working_category_name, + tuple(cat_muts.pget_category_group("regions")[region]), + ) + working_category_name = next_working_category_name + + if non_synonymous: + next_working_category_name += " - non-synonymous" + cat_muts.pcintersect_category_groups_with_categories( + next_working_category_name, + working_category_name, + tuple(cat_muts.pget_category("non-synonymous")), + ) + working_category_name = next_working_category_name + + mutated_site_counts_per_sample_type = ( + cat_muts.mutation_sites_for_pccategory_group(working_category_name) + ) + reference_site_counts_per_sample_type = dict(mutated_site_counts_per_sample_type) + + for key, value in 
+    for key, value in reference_site_counts_per_sample_type.items():
+        reference_site_counts_per_sample_type[key] = len(cat_muts.ref_sequence) - value
+
+    conti_table_dict = {
+        "Mutated": mutated_site_counts_per_sample_type,
+        "Reference": reference_site_counts_per_sample_type,
+    }
+
+    contingency_table = pd.DataFrame(conti_table_dict)
+    result = utils.row_pairwise_dependency_test(contingency_table, allow_fishers)
+
+    return (
+        "Is mutation site count dependent on sample type?",
+        "Here, we will count the sites of mutation in the "
+        "reference sequence and compare that count across "
+        "sample types. The following is the table we will "
+        "use: ",
+        utils.publish_results(result),
+    )
+
+
+# TODO Write individual nucleotide test
+def individual_mutation_count_vs_sample_type(
+    cat_muts: CategorizedMutations,
+    cat_vars: CategorizedVariantRecords,
+    viral_strand,
+    region,
+    non_synonymous,
+    allow_fishers,
+):
+    results = []
+    for location, location_based_mutations in cat_muts.pget_category_group(
+        "locations"
+    ).items():
+        contingency_table_per_location = defaultdict(dict)
+        for sample_type, sample_type_based_mutations in cat_muts.pget_category_group(
+            "sample type"
+        ).items():
+            contingency_table_per_location[f"Mutations at {location}"][
+                sample_type
+            ] = len(
+                cat_muts.intersect(
+                    tuple(location_based_mutations),
+                    tuple(sample_type_based_mutations),
+                )
+            )
+            # TODO This stub is not registered in `tests` below, and the
+            # "Reference" row still needs a sensible definition.
+            contingency_table_per_location["Reference"][sample_type] = len(
+                location_based_mutations
+            ) - len(cat_muts.pget_category_group("patient"))
+        utils.publish_results(contingency_table_per_location, results)
+
+    return "Location based mutations vs Sample type", results
+
+
+# TODO Write distribution difference test (?)
+
+tests = [
+    (
+        "Correlation between the number of mutation locations and sample types",
+        mutation_site_count_vs_sample_type,
+    )
+]
diff --git a/vgaat/utils.py b/vgaat/utils.py
new file mode 100644
index 0000000..c919d6c
--- /dev/null
+++ b/vgaat/utils.py
@@ -0,0 +1,204 @@
+import hashlib
+import os
+import shutil
+import sys
+from scipy.stats import chi2_contingency, fisher_exact
+import pandas as pd
+import itertools
+from multiprocessing.pool import ThreadPool
+import logging
+
+
+def accession_to_genbank_filename(accession):
+    return accession.replace(".", "_") + ".gb"
+
+
+def genbank_file_accession(gb_name):
+    return gb_name.replace("_", ".").replace(".gb", "")
+
+
+def row_pairwise_dependency_test(table: pd.DataFrame, allow_fishers=True):
+    pairs = []
+    indices = table.index
+    results = {}
+    for i in range(len(indices) - 1):
+        for j in range(i + 1, len(indices)):
+            pairs.append((indices[i], indices[j]))
+    for a, b in pairs:
+        row_pair_unmodified = table.loc[[a, b]]
+        # Drop all-zero columns, then all-zero rows.
+        row_pair = row_pair_unmodified.loc[
+            :, (row_pair_unmodified != 0).any(axis=0)
+        ]
+        row_pair = row_pair.loc[(row_pair != 0).any(axis=1)]
+        if row_pair.shape[0] == 2 and row_pair.shape[1] == 2 and allow_fishers:
+            odds_ratio, p = fisher_exact(row_pair)
+            results[f"{a} - {b}"] = {
+                "type": "exact",
+            }
+        elif row_pair.shape[1] >= 2 and row_pair.shape[0] >= 2:
+            chi2, p, dof, expected = chi2_contingency(row_pair)
+            results[f"{a} - {b}"] = {
+                "type": "chi2",
+            }
+        elif row_pair.shape[1] == 0 and row_pair.shape[0] == 0:
+            continue
+        else:
+            results[f"{a} - {b}"] = {
+                "type": "chi2",
+                "original_table": row_pair_unmodified,
+            }
+            p = "Error"
+        results[f"{a} - {b}"]["p"] = p
+        results[f"{a} - {b}"]["table"] = row_pair
+    return results
+
+
+# TODO Clean this up...
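+# Usage sketch for row_pairwise_dependency_test (values made up for
+# illustration): a non-degenerate 2x2 row pair uses Fisher's exact test when
+# allowed, while wider tables fall back to the chi-square test.
+#
+#     table = pd.DataFrame(
+#         {"Mutated": {"B": 12, "N": 30}, "Reference": {"B": 88, "N": 70}}
+#     )
+#     results = row_pairwise_dependency_test(table)
+#     results["B - N"]["type"]  # "exact"
+#     results["B - N"]["p"]     # two-sided p-value
+#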
+def publish_results(results: dict, results_list=None): + results_list = results_list or [] + for result_info in results.values(): + results_list.append(result_info) + return results_list + + +class Tester: + def __init__( + self, + test_funcs, + viral_strands, + regions, + synonymity, + fishers, + categorized_mutations, + categorized_variants, + max_threads=16, + ): + self.tests = test_funcs + self.viral_strands = viral_strands + self.regions = regions + self.synonymity = synonymity + self.fishers = fishers + self.categorized_mutations = categorized_mutations + self.categorized_variants = categorized_variants + self._max_threads = max_threads + self._results = {} + + def run_all_async(self): + self._thread_pool = ThreadPool(processes=self._max_threads) + param_product = itertools.product( + self.tests, self.viral_strands, self.regions, self.synonymity, self.fishers + ) + total = len(self.tests) * \ + len(self.viral_strands) * \ + len(self.regions) * \ + len(self.synonymity) * \ + len(self.fishers) + runs = self._thread_pool.imap_unordered( + self.run_test, + param_product, + chunksize=self._max_threads # TODO Perhaps add more sophisticated param... + ) + for running_test in runs: + identifier, res = running_test + self._results[identifier] = res + logging.info(f"Test progress: {(len(self._results))/total:.1%}") + + def run_test(self, params): + test_tuple, viral_strand, regions, synonymity, fishers = params + test_alias, test_func = test_tuple + logging.debug( + "Running {0} with parameters: {1} {2} {3} {4}".format( + test_alias, viral_strand, regions, synonymity, fishers + ) + ) + res = test_func( + self.categorized_mutations, + self.categorized_variants, + viral_strand, + regions, + synonymity, + fishers + ) + logging.debug("Completed running {0} with parameters: {1} {2} {3} {4}".format( + test_alias, viral_strand, regions, synonymity, fishers + )) + return (test_alias, viral_strand, regions, synonymity, fishers), res + + def get_result_list(self, test_alias: str, viral_strand: str, regions, + synonymity: bool, fishers: bool): + return self._results[test_alias, viral_strand, regions, synonymity, fishers] + + def get_all_results(self): + return self._results + + +def write_markdown_results(result_groups: dict[str, list[dict]], md_results_path=None): + if md_results_path: + large_dir = md_results_path + ".large" + if os.path.exists(large_dir): + shutil.rmtree(large_dir) + + writer_p = sys.stdout + if md_results_path: + writer_p = open(md_results_path, "w") + for test_id, result_list in result_groups.items(): + result_group_title, result_group_desc, results = result_list + writer_p.write(f"# {result_group_title}\n\n") + writer_p.write(f"{result_group_desc}\n\n") + writer_p.write("Run with the following parameters:\n") + writer_p.write( + """ Test name: {0}; \ + Lineage: {1}; \ + Region: {2}; \ + Allow synonymous: {3}; \ + Allow Exact (Fishers): {4};\ + \n\n""".format(*test_id) + ) + + for result in results: + writer_p.write(f"P-value: {result['p']:.5%}; \ + Test Type: {result['type']}\n\n") + if result["table"].shape[0] < 10 and result["table"].shape[1] < 10: + writer_p.write(f"{result['table'].to_markdown()}\n---\n\n") + else: + writer_p.write( + f"Table was too large {result['table'].shape} to display.\n" + ) + if md_results_path: + large_table_dir = os.path.join( + md_results_path + ".large", + os.path.sep.join(map(str, test_id)). + replace(" ", ""). + replace("\t", ""). + replace("'", ""). + replace("\"", ""). + replace("(", ""). + replace(")", ""). 
+ replace(",", "_") + ) + filename = str( + hashlib.shake_128( + result["table"].to_markdown().encode("utf-8") + ).hexdigest(8) + ) + filepath = os.path.join( + large_table_dir, + filename + ) + same_num = 0 + while os.path.exists(filepath + str(same_num) + ".csv"): + same_num += 1 + large_table_path = filepath + str(same_num) + ".csv" + relative_table_path = large_table_path.replace( + os.path.abspath(md_results_path), + "." + os.path.sep + os.path.basename(md_results_path) + ) + os.makedirs(large_table_dir, exist_ok=True) + result["table"].to_csv(path_or_buf=large_table_path) + writer_p.write( + "Table stored as CSV file. See at:\n" + f"[{relative_table_path}]({relative_table_path})" + ) + writer_p.write("\n") + if md_results_path: + writer_p.close() diff --git a/vgaat/variants.py b/vgaat/variants.py new file mode 100644 index 0000000..9ef3312 --- /dev/null +++ b/vgaat/variants.py @@ -0,0 +1,164 @@ +import csv +import re +import os +from typing import Iterable +import vcfpy +from collections import defaultdict +import logging + + +class VariantRecordsOrganizer: + def __init__(self, on_change=None): + self._on_change = on_change + self._vcfs_dir = "" + self._lineage_report_csv_path = "" + self._vcf_filename_regex = "" + + def update(self, vcfs_dir, lineage_report_csv_path, vcf_filename_regex, **kwargs): + self._vcfs_dir = vcfs_dir + self._lineage_report_csv_path = lineage_report_csv_path + self._vcf_filename_regex = vcf_filename_regex + if self._on_change: + self._on_change(self, **kwargs) + + def build(self): + return CategorizedVariantRecords.from_files( + self._vcfs_dir, self._lineage_report_csv_path, self._vcf_filename_regex + ) + + +class VariantRecordGroup: + def __init__(self, patient_id, sample_type, sample_day, records: list): + self._patient_id = patient_id + self._sample_type = sample_type + self._sample_day = sample_day + self._variants = records + + @property + def vcf_info(self): + return self._patient_id, self._sample_type, self._sample_day + + @property + def variants(self): + return self._variants + + def __hash__(self): + return hash(self.vcf_info) + + def __eq__(self, o): + return self.vcf_info == o.vcf_info + + def __iter__(self): + return iter(self.variants) + + +class CategorizedVariantRecords: + def __init__(self, vcf_data, lineage_info): + self._primitive_category_groups = {} + self._primitive_category_groups["sample day"] = defaultdict(set) + self._primitive_category_groups["patient id"] = defaultdict(set) + self._primitive_category_groups["sample type"] = defaultdict(set) + self._primitive_category_groups["viral lineage"] = defaultdict(set) + self._all_vcf = vcf_data + self._all_lineage_info = lineage_info + for vcf_info, records in vcf_data.items(): + patient_id, sample_type, sample_day = vcf_info + self._primitive_category_groups["sample day"][sample_day].add(records) + self._primitive_category_groups["patient id"][patient_id].add(records) + self._primitive_category_groups["sample type"][sample_type].add(records) + if vcf_info in lineage_info: + self._primitive_category_groups["viral lineage"][ + lineage_info[vcf_info] + ].add(records) + + def get_category_groups(self, category: str): + return self._primitive_category_groups[category] + + def get_all_vcf(self): + return self._all_vcf + + def get_all_lineage_info(self): + return self._all_lineage_info + + @staticmethod + def intersect_each_with_one(each: dict, one: Iterable): + result = {} + for key, value in each.items(): + result[key] = value & one + return result + + @staticmethod + def 
get_sample_info_from_filename(
+        vcf_filename, filename_regex=r"([BNE]{1,2})(\d+)(?:-D(\d+))?"
+    ):
+        filename_match = re.match(filename_regex, vcf_filename)
+        if filename_match is None:
+            logging.warning(
+                f'The regex "{filename_regex}" did not match "{vcf_filename}"'
+            )
+            return None
+        sample_type = filename_match.group(1)
+        # Assume day 1 when the filename carries no day annotation.
+        sample_day = filename_match.group(3) or "1"
+        sample_id = filename_match.group(2)
+        return sample_id, sample_type, sample_day
+
+    @staticmethod
+    def read_lineage_csv(
+        file_name: str, sample_info_column: str, viral_strand_column: str
+    ):
+        lineage_info = {}
+        with open(file_name) as lineage_csv_fd:
+            reader = csv.reader(lineage_csv_fd)
+            header = None
+            for row in reader:
+                if header is None:
+                    header = row
+                    continue
+                sample_info = CategorizedVariantRecords.get_sample_info_from_filename(
+                    row[header.index(sample_info_column)]
+                )
+                viral_strand_ver = row[header.index(viral_strand_column)]
+                if viral_strand_ver:
+                    lineage_info[sample_info] = viral_strand_ver
+                else:
+                    lineage_info[sample_info] = "unknown"
+        return lineage_info
+
+    @staticmethod
+    def count_samples_in_dict_of_sets(dict_samples: dict[str, list]):
+        result = {}
+        for key, value in dict_samples.items():
+            result[key] = len(value)
+        return result
+
+    @staticmethod
+    def from_files(
+        vcfs_path,
+        lineage_report_csv_path,
+        filename_regex=r"([BNE]{1,2})(\d+)(?:-D(\d+))?",
+    ):
+        lineage_dict = CategorizedVariantRecords.read_lineage_csv(
+            lineage_report_csv_path, "taxon", "scorpio_call"
+        )
+        vcf_records = {}
+        for vcf_filename in os.listdir(vcfs_path):
+            vcf_full_path = os.path.abspath(os.path.join(vcfs_path, vcf_filename))
+            if os.path.isdir(vcf_full_path):
+                continue
+            if os.path.splitext(vcf_full_path)[1] != ".vcf":
+                continue
+            sample_info = CategorizedVariantRecords.get_sample_info_from_filename(
+                vcf_filename, filename_regex
+            )
+            if sample_info is None:
+                continue
+
+            patient_id, sample_type, sample_day = sample_info
+            with open(vcf_full_path) as vcf_file_desc:
+                curr_list = []
+                for vcf_record in vcfpy.Reader.from_stream(vcf_file_desc):
+                    curr_list.append(vcf_record)
+            vcf_records[sample_info] = VariantRecordGroup(*sample_info, curr_list)
+        return CategorizedVariantRecords(vcf_records, lineage_dict)
diff --git a/vgaat/vgaat.py b/vgaat/vgaat.py
new file mode 100755
index 0000000..fbbb89f
--- /dev/null
+++ b/vgaat/vgaat.py
@@ -0,0 +1,176 @@
+#!/usr/bin/env python3
+
+import os
+import asyncio
+import percache
+import tempfile
+import argparse
+import logging
+import variants
+import mutations
+import statistics_tests
+import pickle
+import shutil
+import utils
+import filters
+
+
+async def main(args):
+    logging.basicConfig(level=args.log.upper())
+    logging.debug(f'Caching to "{args.cache_dir}".')
+    if args.clear_cache:
+        logging.info("Clearing previous cache...")
+        shutil.rmtree(args.cache_dir, ignore_errors=True)
+        logging.info("Cache cleared")
+
+    logging.info('Using call name "{0}"'.format(args.call_name))
+
+    vcf_dir_path = os.path.abspath(args.vcf_dir)
+    logging.info(f'Fetching VCF files from "{vcf_dir_path}"')
+    lineage_file = os.path.abspath(args.lineage_file)
+    logging.info(f'Fetching Lineage file from "{lineage_file}"')
+    variant_organizer = variants.VariantRecordsOrganizer()
+    variant_organizer.update(vcf_dir_path, lineage_file, args.sample_filename_re)
+    logging.info("Building categorized variants...")
+    categorized_variants = variant_organizer.build()
+    logging.info("Done")
+    mutation_organizer = mutations.MutationOrganizer(categorized_variants)
+
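+    # The categorized mutations built below are pickled under
+    # "<cache_dir>/<basename of the VCF directory>/", so repeat runs over the
+    # same input directory reuse them instead of re-parsing every VCF file;
+    # --clear-cache wipes the cache directory first.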
+    logging.info(f"Using GenBank file from {args.ref_genbank}")
+    mutation_organizer.update(args.ref_genbank, args.call_name)
+    categorized_mutations_cache_path = os.path.join(
+        args.cache_dir, os.path.basename(vcf_dir_path), "categorized_mutations.pickle"
+    )
+    if not os.path.exists(categorized_mutations_cache_path):
+        logging.info("Building categorized mutations...")
+        categorized_mutations = mutation_organizer.build()
+        os.makedirs(os.path.dirname(categorized_mutations_cache_path), exist_ok=True)
+        with open(categorized_mutations_cache_path, "wb") as fd:
+            pickle.dump(categorized_mutations, fd)
+    else:
+        logging.info(
+            f"Loading categorized mutations from {categorized_mutations_cache_path}"
+        )
+        with open(categorized_mutations_cache_path, "rb") as fd:
+            categorized_mutations = pickle.load(fd)
+    logging.info("Done")
+
+    # TODO Add all categories as parameters
+    # TODO How do we create a unanimous test suite???
+    tester = utils.Tester(
+        statistics_tests.tests,
+        ["all", *categorized_mutations.pget_category_group("viral lineage").keys()],
+        ["all", *categorized_mutations.pget_category_group("regions").keys()],
+        [True, False],
+        [not args.disable_fishers],
+        categorized_mutations,
+        categorized_variants,
+        max_threads=args.threads,
+    )
+    results_cache_path = os.path.join(
+        args.cache_dir, os.path.basename(vcf_dir_path), "results.pickle"
+    )
+    if not os.path.exists(results_cache_path):
+        logging.info("Running all tests...")
+        tester.run_all_async()
+        results = tester.get_all_results()
+        os.makedirs(os.path.dirname(results_cache_path), exist_ok=True)
+        with open(results_cache_path, "wb") as fd:
+            pickle.dump(results, fd)
+    else:
+        logging.info(f"Loading test results from {results_cache_path}")
+        with open(results_cache_path, "rb") as fd:
+            results = pickle.load(fd)
+
+    logging.info(f"Applying alpha filter of {args.alpha}")
+    results = filters.filter_by_alpha(results, args.alpha)
+
+    if not args.output:
+        logging.debug("Outputting results to stdout...")
+        utils.write_markdown_results(results)
+    else:
+        logging.debug(f'Outputting to "{args.output}"...')
+        utils.write_markdown_results(results, md_results_path=args.output)
+    logging.debug("Done.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        prog="VGAAT",
+        description="Virus Genome Association Analytics Tools (VGAAT) "
+        "is a python program tool set containing a variety of associative "
+        "algorithms that may be run upon large amounts of VCFs.",
+    )
+    parser.add_argument(
+        "vcf_dir", help="Path to directory containing VCF files"
+    )
+    parser.add_argument(
+        "ref_genbank",
+        help="The path to the NCBI GenBank file containing the reference used to "
+        "produce the VCF calls.",
+    )
+    parser.add_argument(
+        "call_name",
+        help="The call name to use when reading the VCF files.",
+    )
+    parser.add_argument(
+        "lineage_file",
+        help="The CSV file containing information on the samples' lineage.",
+    )
+    parser.add_argument(
+        "--sample-filename-re",
+        help="The regex used to interpret the individual sample filenames.",
+        default=r"([BNE]{1,2})(\d+)(?:-D(\d+))?",
+    )
+    parser.add_argument(
+        "--log",
+        help="Sets the verbosity of the program.",
+        default="INFO",
+    )
+    parser.add_argument(
+        "--cache-dir",
+        help="Set data cache location. Choose a persistent location if you'd like "
+        "to persist data after a run.",
+        default="./tmp/VGAAT/data_cache",
+    )
+    parser.add_argument(
+        "--disable-fishers",
+        action="store_true",
+        help="Disables use of the Fisher's Exact Test even when it is possible.",
+    )
+    parser.add_argument(
+        "--threads",
+        help="Number of threads to use when performing statistical tests.",
+        default=16,
+        type=int,
+    )
+    parser.add_argument(
+        "--clear-cache",
+        help="Clears cache and then runs.",
+        default=False,
+        # Accept explicit "--clear-cache True"/"--clear-cache False" values.
+        type=lambda value: str(value).lower() in ("1", "true", "yes"),
+    )
+    parser.add_argument(
+        "--alpha",
+        metavar="A",
+        help="Filter results to be within given alpha value.",
+        default=0.05,
+        type=float,
+    )
+    parser.add_argument(
+        "--output",
+        help="Where to output the results.",
+    )
+    # TODO Complete adding output and file format options
+
+    args = parser.parse_args()
+    try:
+        cache = percache.Cache(os.path.join(tempfile.gettempdir(), "cache"))
+    except PermissionError:
+        cache = percache.Cache(args.cache_dir)
+    asyncio.run(main(args))
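
Example invocation (mirrors the bundled VS Code launch configuration; the data
paths are placeholders for your own inputs):

    python vgaat/vgaat.py test_data/ MN908947_3.gb unknown nml_lineage_report.csv \
        --log DEBUG --alpha 0.05 --output results.md --threads 1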