Implemented annotated local typing method without testing

This commit is contained in:
2025-02-04 16:19:00 +00:00
parent 341ca933a3
commit ff8a1aff08
21 changed files with 27726 additions and 374 deletions

View File

@@ -0,0 +1,71 @@
import asyncio
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import AbstractContextManager
from typing import Any, Set, Union
from Bio.Align import PairwiseAligner
from queue import Queue
from autobigs.engine.structures.alignment import AlignmentStats, PairwiseAlignment
class AsyncPairwiseAlignmentEngine(AbstractContextManager):
def __enter__(self):
self._thread_pool = ThreadPoolExecutor(self._max_threads)
return self
def __init__(self, aligner: PairwiseAligner, max_threads: int = 4):
self._max_threads = max_threads
self._aligner = aligner
self._work_left: Set[Future] = set()
self._work_complete: Queue[Future] = Queue()
def align(self, reference: str, query: str, **associated_data):
work = self._thread_pool.submit(
self.work, reference, query, **associated_data)
work.add_done_callback(self._on_complete)
self._work_left.add(work)
def _on_complete(self, future: Future):
self._work_complete.put(future)
def work(self, reference, query, **associated_data):
alignment_results = sorted(self._aligner.align(reference, query))[0]
top_alignment_stats = alignment_results.counts()
top_alignment_gaps = top_alignment_stats.gaps
top_alignment_identities = top_alignment_stats.identities
top_alignment_mismatches = top_alignment_stats.mismatches
top_alignment_score = alignment_results.score # type: ignore
return PairwiseAlignment(
alignment_results.sequences[0],
alignment_results.sequences[1],
alignment_results.indices[0],
alignment_results.indices[1],
AlignmentStats(
percent_identity=top_alignment_identities/alignment_results.length,
mismatches=top_alignment_mismatches,
gaps=top_alignment_gaps,
score=top_alignment_score
)), associated_data
async def next_completed(self) -> Union[tuple[PairwiseAlignment, dict[str, Any]], None]:
if self._work_complete.empty() and len(self._work_left):
return None
future_now: Future = await asyncio.wrap_future(self._work_complete.get())
completed: tuple[PairwiseAlignment, dict[str, Any]] = (future_now).result()
self._work_left.remove(future_now)
return completed
def __exit__(self, exc_type, exc_value, traceback):
self.shutdown()
def __aiter__(self):
return self
async def __anext__(self):
result = await self.next_completed()
if result is None:
raise StopAsyncIteration
return result
def shutdown(self):
self._thread_pool.shutdown(wait=True, cancel_futures=True)

View File

@@ -1,15 +1,21 @@
from abc import abstractmethod
import asyncio
from collections import defaultdict
from contextlib import AbstractAsyncContextManager
import csv
from os import path
from typing import Any, AsyncGenerator, AsyncIterable, Iterable, Mapping, Sequence, Union
import os
import shutil
import tempfile
from typing import Any, AsyncGenerator, AsyncIterable, Iterable, Mapping, Sequence, Set, Union
from aiohttp import ClientSession, ClientTimeout
from autobigs.engine.data.local.fasta import read_fasta
from autobigs.engine.data.structures.genomics import NamedString
from autobigs.engine.data.structures.mlst import Allele, NamedMLSTProfile, PartialAllelicMatchProfile, MLSTProfile
from autobigs.engine.analysis.aligners import AsyncPairwiseAlignmentEngine
from autobigs.engine.reading import read_fasta
from autobigs.engine.structures.alignment import PairwiseAlignment
from autobigs.engine.structures.genomics import NamedString
from autobigs.engine.structures.mlst import Allele, NamedMLSTProfile, AlignmentStats, MLSTProfile
from autobigs.engine.exceptions.database import NoBIGSdbExactMatchesException, NoBIGSdbMatchesException, NoSuchBIGSdbDatabaseException
from Bio.Align import PairwiseAligner
@@ -17,26 +23,26 @@ from Bio.Align import PairwiseAligner
class BIGSdbMLSTProfiler(AbstractAsyncContextManager):
@abstractmethod
def fetch_mlst_allele_variants(self, sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
def determine_mlst_allele_variants(self, query_sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
pass
@abstractmethod
async def fetch_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile:
async def determine_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile:
pass
@abstractmethod
async def profile_string(self, sequence_strings: Iterable[str]) -> MLSTProfile:
async def profile_string(self, query_sequence_strings: Iterable[str]) -> MLSTProfile:
pass
@abstractmethod
def profile_multiple_strings(self, named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
def profile_multiple_strings(self, query_named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
pass
@abstractmethod
async def close(self):
pass
class OnlineBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
def __init__(self, database_api: str, database_name: str, schema_id: int):
self._database_name = database_name
@@ -47,11 +53,13 @@ class OnlineBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
async def __aenter__(self):
return self
async def fetch_mlst_allele_variants(self, sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
async def determine_mlst_allele_variants(self, query_sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
# See https://bigsdb.pasteur.fr/api/db/pubmlst_bordetella_seqdef/schemes
uri_path = "sequence"
if not isinstance(query_sequence_strings, Iterable):
raise ValueError("Invalid data type for parameter \"sequence_strings\".")
for sequence_string in sequence_strings:
for sequence_string in query_sequence_strings:
async with self._http_client.post(uri_path, json={
"sequence": sequence_string,
"partial_matches": True
@@ -70,10 +78,11 @@ class OnlineBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
for allele_loci, partial_match in partial_matches.items():
if len(partial_match) <= 0:
continue
partial_match_profile = PartialAllelicMatchProfile(
partial_match_profile = AlignmentStats(
percent_identity=float(partial_match["identity"]),
mismatches=int(partial_match["mismatches"]),
gaps=int(partial_match["gaps"])
gaps=int(partial_match["gaps"]),
score=int(partial_match["score"])
)
yield Allele(
allele_locus=allele_loci,
@@ -83,7 +92,7 @@ class OnlineBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
else:
raise NoBIGSdbMatchesException(self._database_name, self._schema_id)
async def fetch_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile:
async def determine_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile:
uri_path = "designations"
allele_request_dict: dict[str, list[dict[str, str]]] = defaultdict(list)
if isinstance(alleles, AsyncIterable):
@@ -97,7 +106,7 @@ class OnlineBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
}
async with self._http_client.post(uri_path, json=request_json) as response:
response_json: dict = await response.json()
allele_map: dict[str, Allele] = {}
allele_set: Set[Allele] = set()
response_json.setdefault("fields", dict())
schema_fields_returned: dict[str, str] = response_json["fields"]
schema_fields_returned.setdefault("ST", "unknown")
@@ -106,17 +115,17 @@ class OnlineBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
for exact_match_locus, exact_match_alleles in schema_exact_matches.items():
if len(exact_match_alleles) > 1:
raise ValueError(f"Unexpected number of alleles returned for exact match (Expected 1, retrieved {len(exact_match_alleles)})")
allele_map[exact_match_locus] = Allele(exact_match_locus, exact_match_alleles[0]["allele_id"], None)
if len(allele_map) == 0:
allele_set.add(Allele(exact_match_locus, exact_match_alleles[0]["allele_id"], None))
if len(allele_set) == 0:
raise ValueError("Passed in no alleles.")
return MLSTProfile(dict(allele_map), schema_fields_returned["ST"], schema_fields_returned["clonal_complex"])
return MLSTProfile(allele_set, schema_fields_returned["ST"], schema_fields_returned["clonal_complex"])
async def profile_string(self, sequence_strings: Iterable[str]) -> MLSTProfile:
alleles = self.fetch_mlst_allele_variants(sequence_strings)
return await self.fetch_mlst_st(alleles)
async def profile_string(self, query_sequence_strings: Iterable[str]) -> MLSTProfile:
alleles = self.determine_mlst_allele_variants(query_sequence_strings)
return await self.determine_mlst_st(alleles)
async def profile_multiple_strings(self, named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
async for named_strings in named_string_groups:
async def profile_multiple_strings(self, query_named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
async for named_strings in query_named_string_groups:
for named_string in named_strings:
try:
yield NamedMLSTProfile(named_string.name, (await self.profile_string([named_string.sequence])))
@@ -131,20 +140,36 @@ class OnlineBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
async def __aexit__(self, exc_type, exc_value, traceback):
await self.close()
class LazyPersistentCachedBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
def __init__(self, database_api: str, database_name: str, schema_id: int, cache_path: str):
class LocalBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
async def __aenter__(self):
if self._prepare:
await self.update_scheme_locis()
await asyncio.gather(
self.download_alleles_cache_data(),
self.download_scheme_profiles()
)
await self.load_scheme_profiles()
return self
def __init__(self, database_api: str, database_name: str, schema_id: int, cache_path: Union[str, None] = None, prepare: bool =True):
self._database_api = database_api
self._database_name = database_name
self._schema_id = schema_id
self._base_url = f"{database_api}/db/{self._database_name}/schemes/{self._schema_id}/"
self._base_url = f"{self._database_api}/db/{self._database_name}/schemes/{self._schema_id}/"
self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(10000))
self._cache_path = cache_path
if cache_path is None:
self._cache_path = tempfile.mkdtemp("BIGSdb")
self._cleanup_required = True
else:
self._cache_path = cache_path
self._cleanup_required = False
self._loci: list[str] = []
self._profiles = {}
self._profiles_st_map = {}
self._prepare = prepare
async def load_scheme_locis(self):
async def update_scheme_locis(self):
self._loci.clear()
async with self._http_client.get("") as schema_response:
async with self._http_client.get(f"/api/db/{self._database_name}/schemes/{self._schema_id}") as schema_response:
schema_json = await schema_response.json()
for locus in schema_json["loci"]:
locus_name = path.basename(locus)
@@ -152,14 +177,14 @@ class LazyPersistentCachedBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
self._loci.sort()
async def load_scheme_profiles(self):
self._profiles.clear()
self._profiles_st_map.clear()
with open(self.get_scheme_profile_path()) as profile_cache_handle:
reader = csv.DictReader(profile_cache_handle, delimiter="\t")
for line in reader:
alleles = []
for locus in self._loci:
alleles.append(line[locus])
self._profiles[tuple(alleles)] = (line["ST"], line["clonal_complex"])
self._profiles_st_map[tuple(alleles)] = (line["ST"], line["clonal_complex"])
def get_locus_cache_path(self, locus) -> str:
return path.join(self._cache_path, locus + "." + "fasta")
@@ -170,8 +195,8 @@ class LazyPersistentCachedBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
async def download_alleles_cache_data(self):
for locus in self._loci:
with open(self.get_locus_cache_path(locus), "wb") as fasta_handle:
async with self._http_client.get(f"/db/{self._database_name}/loci/{locus}/alleles_fasta") as fasta_response:
async for chunk, eof in fasta_response.content.iter_chunks(): # TODO maybe allow chunking to be configurable
async with self._http_client.get(f"/api/db/{self._database_name}/loci/{locus}/alleles_fasta") as fasta_response:
async for chunk, eof in fasta_response.content.iter_chunks():
fasta_handle.write(chunk)
async def download_scheme_profiles(self):
@@ -179,34 +204,41 @@ class LazyPersistentCachedBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
async with self._http_client.get("profiles_csv") as profiles_response:
async for chunk, eof in profiles_response.content.iter_chunks():
profile_cache_handle.write(chunk)
await self.load_scheme_profiles()
async def fetch_mlst_allele_variants(self, sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
async def determine_mlst_allele_variants(self, query_sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
aligner = PairwiseAligner("blastn")
aligner.mode = "local"
for sequence_string in sequence_strings:
for locus in self._loci:
async for fasta_seq in read_fasta(self.get_locus_cache_path(locus)):
allele_variant = fasta_seq.name
alignment_results = aligner.align(sequence_string, fasta_seq.sequence)
top_alignment = sorted(alignment_results)[0]
top_alignment_stats = top_alignment.counts()
top_alignment_gaps = top_alignment_stats.gaps
top_alignment_identities = top_alignment_stats.identities
top_alignment_mismatches = top_alignment_stats.mismatches
if top_alignment_gaps == 0 and top_alignment_mismatches == 0:
yield Allele(locus, allele_variant, None)
with AsyncPairwiseAlignmentEngine(aligner) as aligner_engine:
for query_sequence_string in query_sequence_strings:
for locus in self._loci:
async for allele_variant in read_fasta(self.get_locus_cache_path(locus)):
aligner_engine.align(allele_variant.sequence, query_sequence_string, variant_name=allele_variant.name, full=True)
break # start a bunch of full alignments for each variant to select segments
alignment_rankings: dict[str, set[tuple[PairwiseAlignment, str]]] = defaultdict(set)
async for alignment_result, additional_information in aligner_engine:
result_variant_name = additional_information["variant_name"]
result_locus, variant_id = result_variant_name.split("_")
full_alignment = additional_information["full"]
if full_alignment:
if alignment_result.alignment_stats.gaps == 0 and alignment_result.alignment_stats.mismatches == 0:
# I.e., 100% exactly the same
yield Allele(result_locus, variant_id, None)
continue
else:
yield Allele(
locus,
allele_variant,
PartialAllelicMatchProfile(
percent_identity=top_alignment_identities/top_alignment.length,
mismatches=top_alignment_mismatches,
gaps=top_alignment_gaps
)
)
alignment_rankings[result_locus].add((alignment_result, variant_id))
interest_sequence = full_alignment[alignment_result.query_indices[0]:alignment_result.query_indices[-1]]
async for allele_variant in read_fasta(self.get_locus_cache_path(result_locus)):
if result_variant_name == allele_variant.name:
continue # Skip if we just finished aligning this
aligner_engine.align(allele_variant.sequence, interest_sequence, variant_name=result_variant_name.name, full=False)
else:
alignment_rankings[result_locus].add((alignment_result, variant_id))
for final_locus, alignments in alignment_rankings.items():
closest_alignment, closest_variant_id = sorted(alignments, key=lambda index: index[0].alignment_stats.score)[0]
yield Allele(final_locus, closest_variant_id, closest_alignment.alignment_stats)
async def fetch_mlst_st(self, alleles):
async def determine_mlst_st(self, alleles):
allele_variants: dict[str, Allele] = {}
if isinstance(alleles, AsyncIterable):
async for allele in alleles:
@@ -218,15 +250,15 @@ class LazyPersistentCachedBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
for locus in self._loci:
ordered_profile.append(allele_variants[locus].allele_variant)
st, clonal_complex = self._profiles[tuple(ordered_profile)]
return MLSTProfile(allele_variants, st, clonal_complex)
st, clonal_complex = self._profiles_st_map[tuple(ordered_profile)]
return MLSTProfile(set(allele_variants.values()), st, clonal_complex)
async def profile_string(self, sequence_strings: Iterable[str]) -> MLSTProfile:
alleles = self.fetch_mlst_allele_variants(sequence_strings)
return await self.fetch_mlst_st(alleles)
async def profile_string(self, query_sequence_strings: Iterable[str]) -> MLSTProfile:
alleles = self.determine_mlst_allele_variants(query_sequence_strings)
return await self.determine_mlst_st(alleles)
async def profile_multiple_strings(self, named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
async for named_strings in named_string_groups:
async def profile_multiple_strings(self, query_named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
async for named_strings in query_named_string_groups:
for named_string in named_strings:
try:
yield NamedMLSTProfile(named_string.name, await self.profile_string([named_string.sequence]))
@@ -237,6 +269,8 @@ class LazyPersistentCachedBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
async def close(self):
await self._http_client.close()
if self._cleanup_required:
shutil.rmtree(self._cache_path)
async def __aexit__(self, exc_type, exc_value, traceback):
await self.close()
@@ -290,12 +324,16 @@ class BIGSdbIndex(AbstractAsyncContextManager):
self._seqdefdb_schemas[seqdef_db_name] = schema_descriptions
return self._seqdefdb_schemas[seqdef_db_name] # type: ignore
async def build_profiler_from_seqdefdb(self, dbseqdef_name: str, schema_id: int) -> OnlineBIGSdbMLSTProfiler:
return OnlineBIGSdbMLSTProfiler(await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name), dbseqdef_name, schema_id)
async def build_profiler_from_seqdefdb(self, dbseqdef_name: str, schema_id: int) -> RemoteBIGSdbMLSTProfiler:
return RemoteBIGSdbMLSTProfiler(await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name), dbseqdef_name, schema_id)
async def close(self):
await self._http_client.close()
async def __aexit__(self, exc_type, exc_value, traceback):
await self.close()
def get_BIGSdb_MLST_profiler(local: bool, database_api: str, database_name: str, schema_id: int):
if local:
return LocalBIGSdbMLSTProfiler(database_api=database_api, database_name=database_name, schema_id=schema_id)
return RemoteBIGSdbMLSTProfiler(database_api=database_api, database_name=database_name, schema_id=schema_id)

View File

@@ -1,25 +0,0 @@
from dataclasses import dataclass
from typing import Mapping, Sequence, Union
@dataclass(frozen=True)
class PartialAllelicMatchProfile:
percent_identity: float
mismatches: int
gaps: int
@dataclass(frozen=True)
class Allele:
allele_locus: str
allele_variant: str
partial_match_profile: Union[None, PartialAllelicMatchProfile]
@dataclass(frozen=True)
class MLSTProfile:
alleles: Mapping[str, Allele]
sequence_type: str
clonal_complex: str
@dataclass(frozen=True)
class NamedMLSTProfile:
name: str
mlst_profile: Union[None, MLSTProfile]

View File

@@ -1,9 +1,9 @@
import asyncio
from io import TextIOWrapper
from typing import Any, AsyncGenerator, Generator, Iterable, Sequence, Union
from typing import Any, AsyncGenerator, Iterable, Union
from Bio import SeqIO
from autobigs.engine.data.structures.genomics import NamedString
from autobigs.engine.structures.genomics import NamedString
async def read_fasta(handle: Union[str, TextIOWrapper]) -> AsyncGenerator[NamedString, Any]:
fasta_sequences = asyncio.to_thread(SeqIO.parse, handle=handle, format="fasta")

View File

@@ -0,0 +1,17 @@
from dataclasses import dataclass
from numbers import Number
@dataclass(frozen=True)
class AlignmentStats:
percent_identity: float
mismatches: int
gaps: int
score: int
@dataclass(frozen=True)
class PairwiseAlignment:
reference: str
query: str
reference_indices: list[Number]
query_indices: list[Number]
alignment_stats: AlignmentStats

View File

@@ -0,0 +1,33 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Collection, Iterable, Mapping, Sequence, Union
from autobigs.engine.structures.alignment import AlignmentStats
@dataclass(frozen=True)
class Allele:
allele_locus: str
allele_variant: str
partial_match_profile: Union[None, AlignmentStats]
@dataclass(frozen=True)
class MLSTProfile:
alleles: Collection[Allele]
sequence_type: str
clonal_complex: str
@dataclass(frozen=True)
class NamedMLSTProfile:
name: str
mlst_profile: Union[None, MLSTProfile]
def alleles_to_mapping(alleles: Iterable[Allele]):
result = defaultdict(list)
for allele in alleles:
result[allele.allele_locus].append(allele.allele_variant)
result = dict(result)
for locus, variant in result.items():
if len(variant) == 1:
result[locus] = variant[0]
return result

View File

@@ -2,19 +2,13 @@ import csv
from os import PathLike
from typing import AsyncIterable, Mapping, Sequence, Union
from autobigs.engine.data.structures.mlst import Allele, MLSTProfile
from autobigs.engine.structures.mlst import Allele, MLSTProfile
def dict_loci_alleles_variants_from_loci(alleles_map: Mapping[str, Sequence[Allele]]):
def dict_loci_alleles_variants_from_loci(alleles_map: Mapping[str, Allele]):
result_dict: dict[str, Union[list[str], str]] = {}
for loci, alleles in alleles_map.items():
if len(alleles) == 1:
result_dict[loci] = alleles[0].allele_variant
else:
result_locis = list()
for allele in alleles:
result_locis.append(allele.allele_variant)
result_dict[loci] = result_locis
result_dict[loci] = alleles.allele_variant
return result_dict