import logging
from collections import defaultdict
from collections.abc import Mapping
from functools import cache
from typing import Callable, Union

from Bio import GenBank
from Bio.Seq import MutableSeq, translate
from frozendict import frozendict

from utils import accession_to_genbank_filename


class Mutation:
    """One alternate allele of one VCF record, observed in one sample.

    Instances are hashable so they can be stored in the sets used by
    ``CategorizedMutations``. The hash covers the sample identity
    (patient, sample type, sample day), the position and the ref/alt
    alleles; equality additionally compares the call details.

    NOTE(review): the attribute names read from ``vcf_record`` (``begin``,
    ``end``, ``REF``, ``ALT``, ``call_for_sample``) look like a vcfpy
    record -- confirm against the caller.
    """

    def __init__(
        self,
        vcf_record,
        sample_id,
        sample_type,
        sample_day,
        viral_lineage,
        alt,
        sample_call_name="unknown",
    ):
        """Capture the fields of ``vcf_record`` relevant to one alternate.

        Args:
            vcf_record: record providing ``begin``, ``end``, ``REF``,
                ``ALT`` and ``call_for_sample``.
            sample_id: patient identifier the sample belongs to.
            sample_type: type label of the sample.
            sample_day: day the sample was taken.
            viral_lineage: lineage of the sample, or ``None`` if unknown.
            alt: the alternate allele; its ``value`` attribute is stored.
            sample_call_name: key into ``vcf_record.call_for_sample``.
        """
        self._sample_viral_lineage = viral_lineage
        self._patient_id = sample_id
        self._sample_type = sample_type
        self._sample_day = sample_day
        self._start = vcf_record.begin
        # A falsy ``end`` falls back to the start position.
        self._end = vcf_record.end or self.start
        self._reference = vcf_record.REF
        self._call = vcf_record.call_for_sample[sample_call_name]
        self._genotype = self.call.gt_bases[0]  # Haploid
        self._all_alts = vcf_record.ALT
        self._alt = alt.value
        self._ref_obs = self.call.data["RO"]
        self._depth = self.call.data["DP"]

        # Classify by comparing the called genotype with the reference.
        genotype_length = len(self.genotype)
        reference_length = len(self.ref)
        if genotype_length > reference_length:
            self._variant_type = "insertion"
        elif genotype_length < reference_length:
            self._variant_type = "deletion"
        elif genotype_length == 1:
            self._variant_type = "SNV"
        else:
            self._variant_type = "MNV"

        # Computed once; the hashed fields are never reassigned afterwards.
        self._hash = hash(
            (
                self._patient_id,
                self._sample_type,
                self._sample_day,
                self._start,
                self._end,
                self._reference,
                self._alt,
            )
        )

    @property
    def patient_id(self):
        return self._patient_id

    @property
    def sample_type(self):
        return self._sample_type

    @property
    def sample_day(self):
        return self._sample_day

    @property
    def sample_viral_lineage(self):
        return self._sample_viral_lineage

    @property
    def start(self):
        return self._start

    @property
    def end(self):
        return self._end

    @property
    def ref(self):
        return self._reference

    @property
    def call(self):
        return self._call

    @property
    def genotype(self):
        return self._genotype

    @property
    def alternatives(self):
        return self._all_alts

    @property
    def ref_obs(self):
        return self._ref_obs

    @property
    def depth(self):
        return self._depth

    @property
    def variant_type(self):
        return self._variant_type

    def __len__(self):
        """Return the read depth (the ``DP`` value of the call)."""
        return self.depth

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        """Compare every captured field, including all hashed ones.

        Returns ``NotImplemented`` for non-``Mutation`` operands so Python
        can fall back to its default comparison instead of raising.
        """
        if not isinstance(other, Mutation):
            return NotImplemented
        return (
            self.patient_id == other.patient_id
            and self.sample_type == other.sample_type
            and self.sample_day == other.sample_day
            and self.start == other.start
            and self.end == other.end
            and self.ref == other.ref
            and self.call == other.call
            and self.genotype == other.genotype
            and self.alternatives == other.alternatives
            and self.ref_obs == other.ref_obs
            and self.depth == other.depth
            and self.variant_type == other.variant_type
            and self._alt == other._alt
        )

    @staticmethod
    def vcf_reads_to_mutations(
        vcf_reads: dict, accession, sample_lineage, sample_call_name="unknown"
    ):
        """Yield one ``Mutation`` per alternate allele on ``accession``.

        Args:
            vcf_reads: maps ``(sample_id, sample_type, sample_day)`` to an
                iterable of VCF records.
            accession: only records whose ``CHROM`` equals this are used.
            sample_lineage: container keyed by the same sample tuples;
                samples missing from it get a lineage of ``None``.
            sample_call_name: forwarded to ``Mutation``.
        """
        for sample_info, vcf_records in vcf_reads.items():
            sample_id, sample_type, sample_day = sample_info
            # The lineage depends only on the sample, so resolve it once.
            lineage = (
                sample_lineage[sample_info]
                if sample_info in sample_lineage
                else None
            )
            for vcf_record in vcf_records:
                if vcf_record.CHROM != accession:
                    continue
                for alternate in vcf_record.ALT:
                    yield Mutation(
                        vcf_record,
                        sample_id,
                        sample_type,
                        sample_day,
                        lineage,
                        alternate,
                        sample_call_name,
                    )


class CategorizedMutations:
    """Index a collection of mutations by several criteria.

    "Primitive" category groups (name -> key -> set of mutations) and
    primitive categories (name -> set of mutations) are built once in
    ``__init__``. "Custom" category groups are created on demand and kept
    in ``_custom_category_groups``. Methods prefixed with ``p`` read
    primitive data only; methods prefixed with ``pc`` resolve a group
    name against the primitive groups first and the custom groups second.

    NOTE(review): several methods are wrapped in ``functools.cache``,
    which keys on ``self`` and therefore keeps instances alive for as
    long as the cache exists.
    """

    _PRIMITIVE_GROUP_NAMES = (
        "locations",
        "allele change",
        "sample type",
        "sample day",
        "variant type",
        "patient",
        "viral lineage",
        "regions",
        "patient earliest",
    )

    def __init__(self, genbank_record, mutations: list[Mutation]):
        """Build every primitive category and category group.

        Args:
            genbank_record: parsed GenBank record providing ``features``,
                ``accession`` and ``sequence``.
            mutations: the mutations to organise.

        Raises:
            Exception: if two gene/UTR features share the same location.
        """
        self._genbank_record = genbank_record
        self._primitive_category_groups = {}
        self._primitive_categories = {}
        self._custom_category_groups = {}
        self._custom_categories = {}

        logging.debug("Categorized mutations being constructed.")
        logging.debug("Cataloguing features...")
        genes_and_utrs_locs = self._catalogue_regions(genbank_record)
        logging.debug("Done")

        self._accession = genbank_record.accession

        groups = {
            name: defaultdict(set) for name in self._PRIMITIVE_GROUP_NAMES
        }
        non_synonymous = set()
        self._primitive_categories["non-synonymous"] = non_synonymous
        self._primitive_categories["all"] = set(mutations)

        logging.debug("Organizing mutations...")
        num_of_mutations = len(mutations)
        reference_sequence = self.ref_sequence
        # The reference never changes, so translate it once, not per mutation.
        reference_translation = translate(reference_sequence)

        for done, mutation in enumerate(mutations, start=1):
            groups["locations"][(mutation.start, mutation.end)].add(mutation)
            groups["allele change"][(mutation.ref, mutation.genotype)].add(
                mutation
            )
            groups["sample type"][mutation.sample_type].add(mutation)
            groups["variant type"][mutation.variant_type].add(mutation)
            groups["sample day"][mutation.sample_day].add(mutation)
            if mutation.sample_viral_lineage:
                groups["viral lineage"][mutation.sample_viral_lineage].add(
                    mutation
                )

            self._track_patient_earliest(
                groups["patient earliest"][mutation.patient_id], mutation
            )
            groups["patient"][mutation.patient_id].add(mutation)

            # Apply the genotype onto a copy of the reference and compare
            # the translations of the whole sequence.
            mutated_sequence = MutableSeq(reference_sequence)
            for offset, position in enumerate(
                range(mutation.start, mutation.end + 1)
            ):
                mutated_sequence[position] = mutation.genotype[offset]
            if translate(mutated_sequence) != reference_translation:
                non_synonymous.add(mutation)

            # NOTE(review): feature bounds are taken verbatim from the
            # GenBank location string while mutation positions come from
            # the VCF record -- confirm both use the same coordinate base.
            for location, region_name in genes_and_utrs_locs.items():
                region_start, region_end = location
                if (
                    region_start <= mutation.start
                    and region_end >= mutation.end
                ):
                    groups["regions"][(region_name, location)].add(mutation)

            logging.info(
                f"Organized {done}/{num_of_mutations} "
                f"({done / num_of_mutations:.2%}) mutations"
            )

        # The categories computed above are kept as-is: resetting
        # "non-synonymous" here would throw the computed result away.
        self._primitive_category_groups = recursive_freeze_dict(groups)

    @staticmethod
    def _catalogue_regions(genbank_record) -> dict:
        """Map ``(start, end)`` of each gene/UTR feature to its name.

        The name is the value of the ``/gene=`` qualifier when present,
        otherwise the feature key itself.
        """
        regions = {}
        for feature in genbank_record.features:
            feature_key = feature.key
            if feature_key != "gene" and not feature_key.endswith("UTR"):
                continue
            start, end = (int(bound) for bound in feature.location.split(".."))
            name = feature_key
            for qualifier in feature.qualifiers:
                if qualifier.key == "/gene=":
                    name = qualifier.value
                    break
            if (start, end) in regions:
                raise Exception("Feature with duplicate location detected.")
            regions[(start, end)] = name
        return regions

    @staticmethod
    def _track_patient_earliest(earliest: set, mutation: Mutation) -> None:
        """Keep ``earliest`` holding only the mutations of the earliest day.

        Invariant: every mutation in ``earliest`` has the same sample day,
        and that day is the smallest seen so far for the patient. A strictly
        earlier mutation replaces the whole set, a same-day mutation joins
        it, and a later one is ignored.
        """
        if not earliest:
            earliest.add(mutation)
            return
        earliest_day = next(iter(earliest)).sample_day
        if mutation.sample_day < earliest_day:
            earliest.clear()
            earliest.add(mutation)
        elif mutation.sample_day == earliest_day:
            earliest.add(mutation)

    def get_filename(self):
        """
        Returns the accession filename.
        """
        return accession_to_genbank_filename(self._accession)

    def get_accession(self):
        return self._accession

    def get_genbank_record(self):
        return self._genbank_record

    @property
    def ref_sequence(self):
        return self._genbank_record.sequence

    def pget_category_or_category_group(self, property_name: str):
        """Return the primitive group named ``property_name``, else the
        primitive category of that name.

        Raises:
            KeyError: if neither exists.
        """
        if property_name in self._primitive_category_groups:
            group = self._primitive_category_groups[property_name]
            # Preserve the original fall-through for an empty group.
            if group or property_name not in self._primitive_categories:
                return group
        return self._primitive_categories[property_name]

    def pget_category(self, property_name: str):
        return self._primitive_categories[property_name]

    def pget_category_group(self, property_name: str):
        return self._primitive_category_groups[property_name]

    @property
    def all(self):
        """The set of every mutation given to the constructor."""
        return self._primitive_categories["all"]

    @property
    def total_mutations(self):
        """Number of distinct mutations given to the constructor."""
        return len(self.all)

    @cache
    def mutations_per_site_for_pcategory_group(self, category_group: str):
        """Per-position mutation counts for each key of a primitive group."""
        categorized_mutations = self.pget_category_group(category_group)
        # frozendict is not guaranteed to subclass dict, so accept Mapping.
        if isinstance(categorized_mutations, (dict, Mapping)):
            # The counting helper is cached, so it needs a hashable argument.
            return {
                key: self.count_mutations_per_site(frozenset(value))
                for key, value in categorized_mutations.items()
            }
        return self.count_mutations_per_site(frozenset(categorized_mutations))

    @cache
    def flattened_pcategory_group(self, category: str) -> set:
        """Union of every mutation set in a primitive category group."""
        flattened = set()
        for mutation_set in self.pget_category_group(category).values():
            flattened.update(mutation_set)
        return flattened

    @cache
    def mutation_sites_for_pccategory_group(self, category: str):
        """Number of distinct mutated positions for each key of a group.

        Raises:
            Exception: if ``category`` is neither a primitive nor an
                already-created custom group.
        """
        categorized_mutations = self.pcget_or_create_category_groups(category)
        # The counting helper is cached, so it needs a hashable argument.
        return {
            key: CategorizedMutations.count_mutation_sites(frozenset(value))
            for key, value in categorized_mutations.items()
        }

    @cache
    def mutations_for_pccategory_group(
        self, category: str, assembler: Union[Callable, None] = None, **kwargs: dict
    ):
        """Number of mutations for each key of a category group.

        ``assembler`` and ``kwargs`` are forwarded to
        ``pcget_or_create_category_groups`` so a missing custom group can
        be built. Because of ``@cache`` every argument must be hashable.
        """
        dictionary_set = self.pcget_or_create_category_groups(
            category, assembler, **kwargs
        )
        return CategorizedMutations.count_mutations_for_category_group(
            dictionary_set
        )

    @cache
    def punion_category_groups_with_category(
        self, category_groups: str, category: str
    ):
        """Union each set of a primitive group with a primitive category."""
        category_set = set(self.pget_category(category))
        return {
            subcat: subcat_mutations.union(category_set)
            for subcat, subcat_mutations in self.pget_category_group(
                category_groups
            ).items()
        }

    @cache
    def pintersect_category_groups_with_category(
        self, category_groups: str, category: str
    ):
        """Intersect each set of a primitive group with a primitive category."""
        category_set = set(self.pget_category(category))
        return {
            subcat: subcat_mutations.intersection(category_set)
            for subcat, subcat_mutations in self.pget_category_group(
                category_groups
            ).items()
        }

    def pcget_or_create_category_groups(
        self, category_group: str, assembler: Union[Callable, None] = None, **kwargs
    ):
        """Return a primitive or custom group, building the latter on demand.

        Args:
            category_group: name of the group.
            assembler: called as ``assembler(**kwargs)`` to build a custom
                group that does not exist yet; must return a ``dict``.

        Raises:
            Exception: if the group is unknown and no assembler was given,
                or if the assembler did not return a ``dict``.
        """
        if category_group in self._primitive_category_groups:
            return self.pget_category_group(category_group)
        if category_group not in self._custom_category_groups:
            if not assembler:
                raise Exception(
                    "Assembler was not provided for a non-existent custom category "
                    f'groups "{category_group}".'
                )
            result = assembler(**kwargs)
            if not isinstance(result, dict):
                raise Exception("Assembler result was not a group of categories...")
            self._custom_category_groups[category_group] = result
        return self._custom_category_groups[category_group]

    def _pccombine_category_groups_with_categories(
        self,
        result_name: str,
        category_group: str,
        category,
        assembler: Union[Callable, None],
        assembler_args: Union[dict, None],
        combine: Callable,
    ):
        """Combine each set of a group with ``category`` and store the
        result as custom group ``result_name`` (reused if already there)."""
        category_set = set(category)
        if result_name in self._custom_category_groups:
            return self._custom_category_groups[result_name]
        subcategories = self.pcget_or_create_category_groups(
            category_group, assembler, **(assembler_args or {})
        )
        result = {
            subcat: combine(mutations, category_set)
            for subcat, mutations in subcategories.items()
        }
        self._custom_category_groups[result_name] = result
        return result

    # Not wrapped in @cache: the result is already memoized by name in
    # ``_custom_category_groups``, and @cache would reject a dict
    # ``assembler_args`` and survive ``clear_custom_categories``.
    def pcunion_category_groups_with_categories(
        self,
        union_name: str,
        category_group: str,
        category: tuple[Mutation],
        assembler: Union[Callable, None] = None,
        assembler_args: Union[dict, None] = None,
    ):
        """Union ``category`` into every set of ``category_group`` and
        register the result as custom group ``union_name``.

        If ``union_name`` already exists, the stored group is returned.
        """
        return self._pccombine_category_groups_with_categories(
            union_name,
            category_group,
            category,
            assembler,
            assembler_args,
            lambda mutations, others: mutations.union(others),
        )

    # Not wrapped in @cache for the same reasons as the union variant.
    def pcintersect_category_groups_with_categories(
        self,
        intersect_name: str,
        category_group: str,
        category: tuple[Mutation],
        assembler: Union[Callable, None] = None,
        assembler_args: Union[dict, None] = None,
    ):
        """Intersect ``category`` with every set of ``category_group`` and
        register the result as custom group ``intersect_name``.

        If ``intersect_name`` already exists, the stored group is returned.
        """
        return self._pccombine_category_groups_with_categories(
            intersect_name,
            category_group,
            category,
            assembler,
            assembler_args,
            lambda mutations, others: mutations.intersection(others),
        )

    @cache
    def pcbuild_contingency(
        self, category_group_a_name: str, category_group_b_name: str, modifier: Callable
    ):
        """Build a two-level table over the keys of two category groups.

        Cell ``[a][b]`` holds ``modifier`` applied to the intersection of
        the mutations under key ``a`` of the first group and key ``b`` of
        the second.
        """
        category_group_a = self.pcget_or_create_category_groups(category_group_a_name)
        category_group_b = self.pcget_or_create_category_groups(category_group_b_name)
        proto_contingency = defaultdict(dict)
        for a_group, mutations_cat_group_a in category_group_a.items():
            # Hashable form needed by the cached ``intersect`` helper.
            a_mutations = tuple(mutations_cat_group_a)
            for b_group, mutations_cat_group_b in category_group_b.items():
                proto_contingency[a_group][b_group] = modifier(
                    self.intersect(a_mutations, tuple(mutations_cat_group_b))
                )
        return proto_contingency

    def clear_custom_categories(self):
        """Drop every custom category group and the results derived from them.

        The ``@cache``-wrapped ``pc*`` methods may hold results computed
        from custom groups, so their caches are cleared as well. Those
        caches are shared per class, so this affects all instances.
        """
        self._custom_category_groups.clear()
        CategorizedMutations.mutation_sites_for_pccategory_group.cache_clear()
        CategorizedMutations.mutations_for_pccategory_group.cache_clear()
        CategorizedMutations.pcbuild_contingency.cache_clear()

    @staticmethod
    @cache
    def count_mutations_per_site(mutations: tuple[Mutation]):
        """Count, per position, how many mutations span that position.

        ``mutations`` must be hashable (tuple or frozenset) because the
        result is cached. Positions run from ``start`` to ``end`` inclusive.
        """
        mutation_sites = defaultdict(int)
        for mutation in mutations:
            for position in range(mutation.start, mutation.end + 1):
                mutation_sites[position] += 1
        return mutation_sites

    @staticmethod
    @cache
    def count_mutation_sites(mutations: tuple[Mutation]):
        """Count the distinct positions covered by ``mutations``.

        ``mutations`` must be hashable (tuple or frozenset) because the
        result is cached.
        """
        return len(
            {
                position
                for mutation in mutations
                for position in range(mutation.start, mutation.end + 1)
            }
        )

    @staticmethod
    @cache
    def intersect(mutation: tuple[Mutation], *mutations: tuple[tuple[Mutation]]):
        """Return the set of mutations common to every given collection."""
        return set(mutation).intersection(*mutations)

    @staticmethod
    @cache
    def union(mutation: tuple[Mutation], *mutations: tuple[tuple[Mutation]]):
        """Return the set of mutations present in any given collection."""
        return set(mutation).union(*mutations)

    @staticmethod
    def count_mutations_for_category_group(
        group_of_mutations: dict[str, set[Mutation]]
    ):
        """Map each key of the group to the size of its mutation set."""
        return {key: len(value) for key, value in group_of_mutations.items()}

    @staticmethod
    def build_with_args(
        categorized_vcf_records, genbank_file, sample_call_name="sample_call_name"
    ):
        """Parse ``genbank_file`` and categorize the matching VCF records.

        Only the first record of the GenBank file is used; its ``version``
        selects which VCF records (by ``CHROM``) become mutations.
        """
        with open(genbank_file) as genbank_fd:
            # Assume there's only one accession per file.
            genbank_record = next(GenBank.parse(genbank_fd))
        return CategorizedMutations(
            genbank_record,
            list(
                Mutation.vcf_reads_to_mutations(
                    categorized_vcf_records.get_all_vcf(),
                    genbank_record.version,  # type: ignore
                    categorized_vcf_records.get_all_lineage_info(),
                    sample_call_name,
                )
            ),
        )


class MutationOrganizer:
    """Hold the settings needed to build a ``CategorizedMutations``."""

    def __init__(self, categorized_vcf_records, on_change=None):
        """
        Args:
            categorized_vcf_records: source of VCF records and lineage info.
            on_change: optional callback run by ``update`` as
                ``on_change(self, **kwargs)``.
        """
        self.genbank_path = "unknown"
        self.sample_call_name = "unknown"
        self._on_change = on_change
        self._categorized_vcf_records = categorized_vcf_records

    def update(self, genbank_path, sample, **kwargs):
        """Store the new settings, then notify the change callback if set."""
        self.genbank_path = genbank_path
        self.sample_call_name = sample
        if self._on_change:
            self._on_change(self, **kwargs)

    def build(self):
        """Build a ``CategorizedMutations`` from the current settings."""
        return CategorizedMutations.build_with_args(
            self._categorized_vcf_records, self.genbank_path, self.sample_call_name
        )


def recursive_freeze_dict(to_freeze: dict) -> frozendict:
    """Return a ``frozendict`` copy of ``to_freeze``.

    Nested dicts are frozen recursively and lists become tuples; every
    other value is kept as-is. The input mapping is left untouched.
    """
    frozen = {}
    for key, value in to_freeze.items():
        if isinstance(value, dict):
            value = recursive_freeze_dict(value)
        elif isinstance(value, list):
            value = tuple(value)
        frozen[key] = value
    return frozendict(frozen)


def mutations_collection_to_set(
    data: Union[tuple[Mutation], set[Mutation]]
) -> set[Mutation]:
    """Return ``data`` as a set, converting only when it is a tuple."""
    return set(data) if isinstance(data, tuple) else data