# mutation-case-controller/vgaat/mutations.py
#
# 554 lines
# 20 KiB
# Python

from collections import defaultdict
from typing import Callable, Union
from Bio.Seq import MutableSeq, translate
from Bio import GenBank
from frozendict import frozendict
from utils import accession_to_genbank_filename
import logging
from functools import cache
class Mutation:
    """One ALT allele of one VCF record, as observed in one sample.

    The VCF record object is expected to expose ``begin``, ``end``, ``REF``,
    ``ALT``, ``CHROM`` and ``call_for_sample`` (NOTE(review): these names match
    vcfpy's ``Record`` API — confirm against the reader actually used).
    """

    def __init__(
        self,
        vcf_record,
        sample_id,
        sample_type,
        sample_day,
        viral_lineage,
        alt,
        sample_call_name="unknown",
    ):
        """Build a mutation for allele ``alt`` of ``vcf_record``.

        Args:
            vcf_record: VCF record providing position, REF/ALT and calls.
            sample_id: Identifier stored (and exposed) as ``patient_id``.
            sample_type: Sample type label.
            sample_day: Sampling day of the sample.
            viral_lineage: Lineage of the sample, or ``None`` if unknown.
            alt: One entry of ``vcf_record.ALT``; its ``.value`` is stored.
            sample_call_name: Key into ``vcf_record.call_for_sample``.
        """
        self._sample_viral_lineage = viral_lineage
        self._patient_id = sample_id
        self._sample_type = sample_type
        self._sample_day = sample_day
        self._start = vcf_record.begin
        # Fall back to a single-position span when the record has no end.
        self._end = vcf_record.end or self.start
        self._reference = vcf_record.REF
        self._call = vcf_record.call_for_sample[sample_call_name]
        self._genotype = self.call.gt_bases[0]  # Haploid: first called allele.
        self._all_alts = vcf_record.ALT
        self._alt = alt.value
        self._ref_obs = self.call.data["RO"]
        self._depth = self.call.data["DP"]
        # Classify by comparing the called allele's length with REF's length.
        if len(self.genotype) > len(self.ref):
            self._variant_type = "insertion"
        elif len(self.genotype) < len(self.ref):
            self._variant_type = "deletion"
        elif len(self.genotype) == 1:
            self._variant_type = "SNV"
        else:
            self._variant_type = "MNV"
        # Hash on the identifying subset of fields; __eq__ compares a superset,
        # so equal mutations always hash equally.
        self._hash = hash(
            (
                self._patient_id,
                self._sample_type,
                self._sample_day,
                self._start,
                self._end,
                self._reference,
                self._alt,
            )
        )

    @property
    def patient_id(self):
        return self._patient_id

    @property
    def sample_type(self):
        return self._sample_type

    @property
    def sample_day(self):
        return self._sample_day

    @property
    def sample_viral_lineage(self):
        return self._sample_viral_lineage

    @property
    def start(self):
        return self._start

    @property
    def end(self):
        return self._end

    @property
    def ref(self):
        return self._reference

    @property
    def call(self):
        return self._call

    @property
    def genotype(self):
        return self._genotype

    @property
    def alternatives(self):
        return self._all_alts

    @property
    def ref_obs(self):
        return self._ref_obs

    @property
    def depth(self):
        return self._depth

    @property
    def variant_type(self):
        return self._variant_type

    def __len__(self):
        """Return the read depth (the call's ``DP`` value)."""
        return self.depth

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        """Compare every stored field; non-Mutation operands are not comparable."""
        if not isinstance(other, Mutation):
            return NotImplemented
        return (
            self.patient_id == other.patient_id
            and self.sample_type == other.sample_type
            and self.sample_day == other.sample_day
            and self.start == other.start
            and self.end == other.end
            and self.ref == other.ref
            and self.call == other.call
            and self.genotype == other.genotype
            and self.alternatives == other.alternatives
            and self.ref_obs == other.ref_obs
            and self.depth == other.depth
            and self.variant_type == other.variant_type
            and self._alt == other._alt
        )

    def __repr__(self):
        return (
            f"{type(self).__name__}(patient_id={self.patient_id!r}, "
            f"sample_type={self.sample_type!r}, sample_day={self.sample_day!r}, "
            f"start={self.start!r}, end={self.end!r}, ref={self.ref!r}, "
            f"alt={self._alt!r}, genotype={self.genotype!r})"
        )

    @staticmethod
    def vcf_reads_to_mutations(
        vcf_reads: dict, accession, sample_lineage, sample_call_name="unknown"
    ):
        """Yield one Mutation per ALT allele of every record on ``accession``.

        Args:
            vcf_reads: Maps ``(sample_id, sample_type, sample_day)`` to an
                iterable of VCF records for that sample.
            accession: Only records whose ``CHROM`` equals this are used.
            sample_lineage: Maps the same sample key to a lineage; samples
                missing from it get ``None``.
            sample_call_name: Forwarded to :class:`Mutation`.
        """
        for sample_info, vcf_records in vcf_reads.items():
            sample_id, sample_type, sample_day = sample_info
            for vcf_record in vcf_records:
                if vcf_record.CHROM != accession:
                    continue
                for alternate in vcf_record.ALT:
                    yield Mutation(
                        vcf_record,
                        sample_id,
                        sample_type,
                        sample_day,
                        sample_lineage[sample_info]
                        if sample_info in sample_lineage
                        else None,
                        alternate,
                        sample_call_name,
                    )
class CategorizedMutations:
    """Index mutations into named categories and groups of categories.

    A *category* is a set of mutations (e.g. ``"all"``, ``"non-synonymous"``).
    A *category group* maps sub-category keys to sets of mutations (e.g.
    ``"patient"`` maps each patient id to that patient's mutations).
    Primitive categories/groups are built once in ``__init__`` (the groups are
    then frozen with ``recursive_freeze_dict``). Custom groups are built on
    demand by caller-supplied assemblers and dropped by
    :meth:`clear_custom_categories`. Methods prefixed ``p`` read primitive data
    only; ``pc`` methods accept primitive or custom groups.
    """

    def __init__(self, genbank_record, mutations: list[Mutation]):
        """Categorize ``mutations`` against the features of ``genbank_record``.

        Args:
            genbank_record: Parsed ``Bio.GenBank`` record; its ``features``,
                ``sequence`` and ``accession`` are read.
            mutations: Mutations to index.

        Raises:
            Exception: if two gene/UTR features share the same location.
        """
        self._genbank_record = genbank_record
        self._primitive_category_groups = {}
        self._primitive_categories = {}
        self._custom_category_groups = {}
        self._custom_categories = {}
        logging.debug("Categorized mutations being constructed.")
        logging.debug("Cataloguing features...")
        genes_and_utrs_locs = self._catalogue_regions(genbank_record)
        logging.debug("Done")
        self._accession = genbank_record.accession
        for group_name in (
            "locations",
            "allele change",
            "sample type",
            "sample day",
            "variant type",
            "patient",
            "viral lineage",
            "regions",
            "patient earliest",
        ):
            self._primitive_category_groups[group_name] = defaultdict(set)
        self._primitive_categories["non-synonymous"] = set()
        self._primitive_categories["all"] = set(mutations)

        logging.debug("Organizing mutations...")
        groups = self._primitive_category_groups
        num_of_mutations = len(mutations)
        # Translating the whole reference is expensive: do it once, not once
        # per mutation, and only if there is anything to compare against it.
        reference_protein = translate(self.ref_sequence) if num_of_mutations else None
        for done, mutation in enumerate(mutations, start=1):
            groups["locations"][(mutation.start, mutation.end)].add(mutation)
            groups["allele change"][(mutation.ref, mutation.genotype)].add(mutation)
            groups["sample type"][mutation.sample_type].add(mutation)
            groups["variant type"][mutation.variant_type].add(mutation)
            groups["sample day"][mutation.sample_day].add(mutation)
            if mutation.sample_viral_lineage:
                groups["viral lineage"][mutation.sample_viral_lineage].add(mutation)
            self._track_patient_earliest(mutation)
            groups["patient"][mutation.patient_id].add(mutation)
            if self._is_non_synonymous(mutation, reference_protein):
                self._primitive_categories["non-synonymous"].add(mutation)
            # A mutation belongs to every gene/UTR that fully contains it.
            for (start, end), region_name in genes_and_utrs_locs.items():
                if start <= mutation.start and mutation.end <= end:
                    groups["regions"][(region_name, (start, end))].add(mutation)
            logging.info(
                "Organized %d/%d (%.2f%%) mutations",
                done,
                num_of_mutations,
                100 * done / num_of_mutations,
            )
        self._primitive_category_groups = recursive_freeze_dict(
            self._primitive_category_groups
        )

    @staticmethod
    def _catalogue_regions(genbank_record) -> dict:
        """Map ``(start, end)`` of each gene/UTR feature to its name.

        The name is the value of the feature's ``/gene=`` qualifier when
        present, otherwise the feature key (e.g. the UTR key).

        Raises:
            Exception: if two such features share the same location.
        """
        regions = {}
        for feature in genbank_record.features:
            feature_key = feature.key
            if feature_key != "gene" and not feature_key.endswith("UTR"):
                continue
            # NOTE(review): assumes a plain "start..end" location string;
            # complement()/join()/partial (<, >) locations would fail here.
            start_text, end_text = feature.location.split("..")
            location = (int(start_text), int(end_text))
            region_name = next(
                (q.value for q in feature.qualifiers if q.key == "/gene="),
                feature_key,
            )
            if location in regions:
                raise Exception("Feature with duplicate location detected.")
            regions[location] = region_name
        return regions

    def _track_patient_earliest(self, mutation):
        """Keep, per patient, only the mutations from the earliest sample day.

        Invariant: all mutations in a patient's set share one sample day, and
        it is the smallest day seen so far for that patient. A strictly earlier
        day replaces the set; the same day joins it; a later day is ignored.
        """
        earliest = self._primitive_category_groups["patient earliest"][
            mutation.patient_id
        ]
        earliest_day = next(iter(earliest)).sample_day if earliest else None
        if earliest_day is None or mutation.sample_day < earliest_day:
            earliest.clear()
            earliest.add(mutation)
        elif mutation.sample_day == earliest_day:
            earliest.add(mutation)

    def _is_non_synonymous(self, mutation, reference_protein) -> bool:
        """Return True if applying ``mutation`` changes the translated reference.

        The called genotype replaces the inclusive ``start..end`` span of the
        reference sequence. NOTE(review): positions are used directly as
        indices into the sequence (i.e. assumed 0-based), and the whole genome
        is translated from its first base — confirm both match intent.
        """
        mutated = MutableSeq(self.ref_sequence)
        # Slice assignment lets indels change the length instead of raising
        # IndexError (genotype shorter than the span) or silently dropping
        # inserted bases (genotype longer than the span).
        mutated[mutation.start : mutation.end + 1] = mutation.genotype
        return translate(mutated) != reference_protein

    def get_filename(self):
        """
        Returns the accession filename.
        """
        return accession_to_genbank_filename(self._accession)

    def get_accession(self):
        """Return the accession read from the GenBank record."""
        return self._accession

    def get_genbank_record(self):
        """Return the GenBank record the mutations were categorized against."""
        return self._genbank_record

    @property
    def ref_sequence(self):
        """Reference sequence of the GenBank record."""
        return self._genbank_record.sequence

    def pget_category_or_category_group(self, property_name: str):
        """Return the primitive group named ``property_name``, else the category.

        Raises:
            KeyError: if neither a primitive group nor category has that name.
        """
        if property_name in self._primitive_category_groups:
            return self._primitive_category_groups[property_name]
        return self._primitive_categories[property_name]

    def pget_category(self, property_name: str):
        """Return the primitive category (a set of mutations) by name."""
        return self._primitive_categories[property_name]

    def pget_category_group(self, property_name: str):
        """Return the primitive category group (subcategory -> set) by name."""
        return self._primitive_category_groups[property_name]

    @property
    def all(self):
        """Set of every mutation passed to the constructor."""
        return self._primitive_categories["all"]

    @property
    def total_mutations(self):
        """Number of distinct mutations in :attr:`all`."""
        return len(self.all)

    @cache
    def mutations_per_site_for_pcategory_group(self, category_group: str):
        """Count, per nucleotide position, the mutations covering it.

        For a primitive category group, returns ``{subcategory: {position:
        count}}``; for a primitive category (e.g. ``"all"``), returns a single
        ``{position: count}`` mapping.
        """
        categorized_mutations = self.pget_category_or_category_group(category_group)
        # The counting helper is @cache'd, so it needs hashable input: tuples.
        if isinstance(categorized_mutations, (set, frozenset)):
            return self.count_mutations_per_site(tuple(categorized_mutations))
        return {
            key: self.count_mutations_per_site(tuple(value))
            for key, value in categorized_mutations.items()
        }

    @cache
    def flattened_pcategory_group(self, category: str) -> set:
        """Return the union of every sub-category of a primitive group."""
        return set().union(*self.pget_category_group(category).values())

    @cache
    def mutation_sites_for_pccategory_group(self, category: str):
        """Return ``{subcategory: number of distinct mutated positions}``."""
        categorized_mutations = self.pcget_or_create_category_groups(category)
        return {
            key: CategorizedMutations.count_mutation_sites(tuple(value))
            for key, value in categorized_mutations.items()
        }

    @cache
    def mutations_for_pccategory_group(
        self, category: str, assembler: Union[Callable, None] = None, **kwargs: dict
    ):
        """Return ``{subcategory: mutation count}`` for a primitive/custom group.

        ``assembler`` and ``kwargs`` are forwarded so a missing custom group is
        built on first use. Because of ``@cache`` all arguments must be hashable.
        """
        dictionary_set = self.pcget_or_create_category_groups(
            category, assembler, **kwargs
        )
        return CategorizedMutations.count_mutations_for_category_group(dictionary_set)

    @cache
    def punion_category_groups_with_category(self, category_groups: str, category: str):
        """Union each sub-category of a primitive group with a primitive category."""
        category_set = set(self.pget_category(category))
        return {
            subcat: subcat_mutations.union(category_set)
            for subcat, subcat_mutations in self.pget_category_group(
                category_groups
            ).items()
        }

    @cache
    def pintersect_category_groups_with_category(
        self, category_groups: str, category: str
    ):
        """Intersect each sub-category of a primitive group with a category."""
        category_set = set(self.pget_category(category))
        return {
            subcat: subcat_mutations.intersection(category_set)
            for subcat, subcat_mutations in self.pget_category_group(
                category_groups
            ).items()
        }

    def pcget_or_create_category_groups(
        self, category_group: str, assembler: Union[Callable, None] = None, **kwargs
    ):
        """Return a primitive or custom group, building a custom one if needed.

        Raises:
            Exception: if the custom group does not exist and no assembler was
                given, or if the assembler does not return a dict.
        """
        if category_group in self._primitive_category_groups:
            return self.pget_category_group(category_group)
        if category_group not in self._custom_category_groups:
            if not assembler:
                raise Exception(
                    "Assembler was not provided for a non-existent custom category "
                    f'groups "{category_group}".'
                )
            result = assembler(**kwargs)
            if not isinstance(result, dict):
                raise Exception("Assembler result was not a group of categories...")
            self._custom_category_groups[category_group] = result
        return self._custom_category_groups[category_group]

    @cache
    def pcunion_category_groups_with_categories(
        self,
        union_name: str,
        category_group: str,
        category: tuple[Mutation],
        assembler: Union[Callable, None] = None,
        assembler_args: Union[dict, None] = None,
    ):
        """Store and return ``category_group`` with ``category`` unioned into
        every sub-category, as the custom group ``union_name``.

        An existing custom group named ``union_name`` is returned unchanged.
        ``assembler_args`` (default: none) is forwarded to the assembler; it
        must be hashable (e.g. a frozendict) because of ``@cache``.
        """
        category_set = set(category)
        if union_name in self._custom_category_groups:
            return self._custom_category_groups[union_name]
        subcategories = self.pcget_or_create_category_groups(
            category_group, assembler, **(assembler_args or {})
        )
        result = {
            subcat: mutations.union(category_set)
            for subcat, mutations in subcategories.items()
        }
        self._custom_category_groups[union_name] = result
        return result

    @cache
    def pcintersect_category_groups_with_categories(
        self,
        intersect_name: str,
        category_group: str,
        category: tuple[Mutation],
        assembler: Union[Callable, None] = None,
        assembler_args: Union[dict, None] = None,
    ):
        """Store and return ``category_group`` with every sub-category
        intersected with ``category``, as the custom group ``intersect_name``.

        An existing custom group named ``intersect_name`` is returned unchanged.
        ``assembler_args`` (default: none) is forwarded to the assembler; it
        must be hashable (e.g. a frozendict) because of ``@cache``.
        """
        category_set = set(category)
        if intersect_name in self._custom_category_groups:
            return self._custom_category_groups[intersect_name]
        subcategories = self.pcget_or_create_category_groups(
            category_group, assembler, **(assembler_args or {})
        )
        result = {
            subcat: mutations.intersection(category_set)
            for subcat, mutations in subcategories.items()
        }
        self._custom_category_groups[intersect_name] = result
        return result

    @cache
    def pcbuild_contingency(
        self, category_group_a_name: str, category_group_b_name: str, modifier: Callable
    ):
        """Build ``{a_subcat: {b_subcat: modifier(shared mutations)}}``.

        ``modifier`` receives the set of mutations present in both the
        ``a_subcat`` and ``b_subcat`` sub-categories. Both groups must be
        primitive or already-created custom groups.
        """
        category_group_a = self.pcget_or_create_category_groups(category_group_a_name)
        category_group_b = self.pcget_or_create_category_groups(category_group_b_name)
        proto_contingency = defaultdict(dict)
        for a_group, mutations_cat_group_a in category_group_a.items():
            for b_group, mutations_cat_group_b in category_group_b.items():
                proto_contingency[a_group][b_group] = modifier(
                    self.intersect(
                        tuple(mutations_cat_group_a), tuple(mutations_cat_group_b)
                    )
                )
        return proto_contingency

    def clear_custom_categories(self):
        """Drop all custom category groups and every cached result built on them.

        The method caches are shared by all instances, so this also discards
        other instances' cached entries. That costs recomputation only and does
        not change any results.
        """
        self._custom_category_groups.clear()
        for cached_method in (
            CategorizedMutations.mutation_sites_for_pccategory_group,
            CategorizedMutations.mutations_for_pccategory_group,
            CategorizedMutations.pcunion_category_groups_with_categories,
            CategorizedMutations.pcintersect_category_groups_with_categories,
            CategorizedMutations.pcbuild_contingency,
        ):
            cached_method.cache_clear()

    @staticmethod
    @cache
    def count_mutations_per_site(mutations: tuple[Mutation]):
        """Return ``{position: number of mutations spanning it}`` (inclusive)."""
        mutation_sites = defaultdict(int)
        for mutation in mutations:
            for position in range(mutation.start, mutation.end + 1):
                mutation_sites[position] += 1
        return mutation_sites

    @staticmethod
    @cache
    def count_mutation_sites(mutations: tuple[Mutation]):
        """Return how many distinct positions the mutations span (inclusive)."""
        return len(
            {
                position
                for mutation in mutations
                for position in range(mutation.start, mutation.end + 1)
            }
        )

    @staticmethod
    @cache
    def intersect(mutation: tuple[Mutation], *mutations: tuple[tuple[Mutation]]):
        """Return the set of mutations present in every given tuple."""
        intersect = set(mutation)
        for add_mut in mutations:
            intersect = intersect & set(add_mut)
        return intersect

    @staticmethod
    @cache
    def union(mutation: tuple[Mutation], *mutations: tuple[tuple[Mutation]]):
        """Return the set of mutations present in any given tuple."""
        union = set(mutation)
        for add_mut in mutations:
            union = union | set(add_mut)
        return union

    @staticmethod
    def count_mutations_for_category_group(
        group_of_mutations: dict[str, set[Mutation]]
    ):
        """Return ``{subcategory: number of mutations}``."""
        return {key: len(value) for key, value in group_of_mutations.items()}

    @staticmethod
    def build_with_args(
        categorized_vcf_records, genbank_file, sample_call_name="sample_call_name"
    ):
        """Parse ``genbank_file`` and categorize the matching VCF mutations.

        Only the first GenBank record in the file is used, and only VCF records
        whose CHROM equals that record's ``version`` become mutations.
        """
        genbank_record = None
        with open(genbank_file) as genbank_fd:
            genbank_record = next(
                GenBank.parse(genbank_fd)
            )  # Assume there's only one accession per file.
        return CategorizedMutations(
            genbank_record,
            list(
                Mutation.vcf_reads_to_mutations(
                    categorized_vcf_records.get_all_vcf(),
                    genbank_record.version,  # type: ignore
                    categorized_vcf_records.get_all_lineage_info(),
                    sample_call_name,
                )
            ),
        )
class MutationOrganizer:
    """Hold the settings used to build a ``CategorizedMutations`` instance.

    ``update`` replaces the GenBank path and sample call name, then invokes the
    optional ``on_change`` callback with this organizer and any extra kwargs.
    """

    def __init__(self, categorized_vcf_records, on_change=None):
        self._categorized_vcf_records = categorized_vcf_records
        self._on_change = on_change
        self.sample_call_name = "unknown"
        self.genbank_path = "unknown"

    def update(self, genbank_path, sample, **kwargs):
        """Store new settings and notify the ``on_change`` callback, if set."""
        self.genbank_path, self.sample_call_name = genbank_path, sample
        callback = self._on_change
        if not callback:
            return
        callback(self, **kwargs)

    def build(self):
        """Build categorized mutations from the current settings."""
        return CategorizedMutations.build_with_args(
            self._categorized_vcf_records,
            self.genbank_path,
            self.sample_call_name,
        )
def recursive_freeze_dict(to_freeze: dict) -> frozendict:
    """Return an immutable copy of ``to_freeze``.

    Nested dicts (including ``defaultdict``) are frozen recursively and list
    values become tuples. Other values, such as sets, are kept as-is. The
    input mapping is left untouched, so it may itself be immutable.
    """
    frozen = {}
    for key, value in to_freeze.items():
        if isinstance(value, dict):
            value = recursive_freeze_dict(value)
        elif isinstance(value, list):
            value = tuple(value)
        frozen[key] = value
    return frozendict(frozen)
def mutations_collection_to_set(
    data: Union[tuple[Mutation], set[Mutation]]
) -> set[Mutation]:
    """Return ``data`` as a set of mutations.

    A tuple is copied into a new set; any other value (normally already a set)
    is returned as the same object.
    """
    if not isinstance(data, tuple):
        return data
    return set(data)