Began implementing LazyPersistentCachedBIGSdbMLSTProfiler

This commit is contained in:
2025-01-27 22:03:49 +00:00
parent ba1f0aa318
commit 3e3898334f
5 changed files with 230 additions and 86 deletions

View File

@@ -1,16 +1,43 @@
from abc import abstractmethod
from collections import defaultdict
from contextlib import AbstractAsyncContextManager
from numbers import Number
from typing import Any, AsyncGenerator, AsyncIterable, Collection, Generator, Iterable, Mapping, Sequence, Union
import csv
from os import path
from typing import Any, AsyncGenerator, AsyncIterable, Iterable, Mapping, Sequence, Union
from aiohttp import ClientSession, ClientTimeout
from autobigs.engine.data.local.fasta import read_fasta
from autobigs.engine.data.structures.genomics import NamedString
from autobigs.engine.data.structures.mlst import Allele, PartialAllelicMatchProfile, MLSTProfile
from autobigs.engine.data.structures.mlst import Allele, NamedMLSTProfile, PartialAllelicMatchProfile, MLSTProfile
from autobigs.engine.exceptions.database import NoBIGSdbExactMatchesException, NoBIGSdbMatchesException, NoSuchBIGSdbDatabaseException
from Bio.Align import PairwiseAligner
class BIGSdbMLSTProfiler(AbstractAsyncContextManager):
    """Interface for MLST profilers that work against a BIGSdb scheme.

    Implementations are async context managers; every method below is
    abstract and must be provided by the concrete profiler.
    """

    @abstractmethod
    def fetch_mlst_allele_variants(self, sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
        """Return an async generator of the alleles found for ``sequence_strings``."""

    @abstractmethod
    async def fetch_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile:
        """Build the ``MLSTProfile`` for ``alleles`` (sync or async iterable)."""

    @abstractmethod
    async def profile_string(self, sequence_strings: Iterable[str]) -> MLSTProfile:
        """Produce one ``MLSTProfile`` from the given sequence strings."""

    @abstractmethod
    def profile_multiple_strings(self, named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
        """Return an async generator of ``NamedMLSTProfile`` results for the named strings.

        ``stop_on_fail`` is interpreted by the implementation; the concrete
        profilers in this file re-raise ``NoBIGSdbMatchesException`` when it
        is true and otherwise yield a result whose profile is ``None``.
        """

    @abstractmethod
    async def close(self):
        """Release whatever resources the profiler holds."""
class OnlineBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
def __init__(self, database_api: str, database_name: str, schema_id: int):
self._database_name = database_name
self._schema_id = schema_id
@@ -20,86 +47,193 @@ class BIGSdbMLSTProfiler(AbstractAsyncContextManager):
# Async context-manager entry: hands back the profiler itself and does no other work.
async def __aenter__(self):
return self
# NOTE(review): this span is a diff rendering with the +/- markers and the
# indentation lost, so two versions of the method appear back to back. The
# first signature and the first request body (the ones using `exact`) look like
# the removed version; the second signature and the `for sequence_string in
# sequence_strings` body look like the added one — confirm against the repository.
#
# Both versions POST a sequence to the scheme's "sequence" endpoint and yield one
# Allele per match found in the JSON response.
async def fetch_mlst_allele_variants(self, sequence_string: str, exact: bool) -> AsyncGenerator[Allele, Any]:
async def fetch_mlst_allele_variants(self, sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
# See https://bigsdb.pasteur.fr/api/db/pubmlst_bordetella_seqdef/schemes
uri_path = "sequence"
# First request body: partial matches are requested only when `exact` is false.
async with self._http_client.post(uri_path, json={
"sequence": sequence_string,
"partial_matches": not exact
}) as response:
sequence_response: dict = await response.json()
if "exact_matches" in sequence_response:
# loci -> list of alleles with id and loci
exact_matches: dict[str, Sequence[dict[str, str]]] = sequence_response["exact_matches"]
for allele_loci, alleles in exact_matches.items():
for allele in alleles:
alelle_id = allele["allele_id"]
yield Allele(allele_loci=allele_loci, allele_variant=alelle_id, partial_match_profile=None)
elif "partial_matches" in sequence_response:
# Only partial matches came back; that is an error when exact matching was demanded.
if exact:
raise NoBIGSdbExactMatchesException(self._database_name, self._schema_id)
partial_matches: dict[str, dict[str, Union[str, float, int]]] = sequence_response["partial_matches"]
for allele_loci, partial_match in partial_matches.items():
# Loci whose partial-match entry is empty are skipped.
if len(partial_match) <= 0:
continue
partial_match_profile = PartialAllelicMatchProfile(
percent_identity=float(partial_match["identity"]),
mismatches=int(partial_match["mismatches"]),
bitscore=float(partial_match["bitscore"]),
gaps=int(partial_match["gaps"])
)
yield Allele(
allele_loci=allele_loci,
allele_variant=str(partial_match["allele"]),
partial_match_profile=partial_match_profile
)
else:
# Neither "exact_matches" nor "partial_matches" is present in the response.
raise NoBIGSdbMatchesException(self._database_name, self._schema_id)
# Second body: one request per input sequence, always asking for partial matches.
for sequence_string in sequence_strings:
async with self._http_client.post(uri_path, json={
"sequence": sequence_string,
"partial_matches": True
}) as response:
sequence_response: dict = await response.json()
if "exact_matches" in sequence_response:
# loci -> list of alleles with id and loci
exact_matches: dict[str, Sequence[dict[str, str]]] = sequence_response["exact_matches"]
for allele_loci, alleles in exact_matches.items():
for allele in alleles:
alelle_id = allele["allele_id"]
# Exact matches carry no partial-match profile.
yield Allele(allele_locus=allele_loci, allele_variant=alelle_id, partial_match_profile=None)
elif "partial_matches" in sequence_response:
partial_matches: dict[str, dict[str, Union[str, float, int]]] = sequence_response["partial_matches"]
for allele_loci, partial_match in partial_matches.items():
# Loci whose partial-match entry is empty are skipped.
if len(partial_match) <= 0:
continue
# This version no longer reads "bitscore" from the response.
partial_match_profile = PartialAllelicMatchProfile(
percent_identity=float(partial_match["identity"]),
mismatches=int(partial_match["mismatches"]),
gaps=int(partial_match["gaps"])
)
yield Allele(
allele_locus=allele_loci,
allele_variant=str(partial_match["allele"]),
partial_match_profile=partial_match_profile
)
else:
# Neither "exact_matches" nor "partial_matches" is present in the response.
raise NoBIGSdbMatchesException(self._database_name, self._schema_id)
# Resolves a set of alleles into an MLSTProfile by POSTing them to the scheme's
# "designations" endpoint. Accepts either a regular or an async iterable of Allele.
# Raises ValueError when the response yields no exact-match alleles, and (in the
# `exact_match_locus` loop) when a locus comes back with more than one allele.
#
# NOTE(review): diff markers and indentation were lost in this rendering. Lines
# that differ only by `allele_loci` vs `allele_locus`, the two `allele_map`
# declarations, and the two `for exact_match_...` loops appear to be the removed
# and added versions of the same statements — confirm against the repository.
async def fetch_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile:
uri_path = "designations"
# Request payload: locus name -> list of {"allele": <variant>} entries.
allele_request_dict: dict[str, list[dict[str, str]]] = defaultdict(list)
if isinstance(alleles, AsyncIterable):
async for allele in alleles:
allele_request_dict[allele.allele_loci].append({"allele": str(allele.allele_variant)})
allele_request_dict[allele.allele_locus].append({"allele": str(allele.allele_variant)})
else:
for allele in alleles:
allele_request_dict[allele.allele_loci].append({"allele": str(allele.allele_variant)})
allele_request_dict[allele.allele_locus].append({"allele": str(allele.allele_variant)})
request_json = {
"designations": allele_request_dict
}
async with self._http_client.post(uri_path, json=request_json) as response:
response_json: dict = await response.json()
allele_map: dict[str, list[Allele]] = defaultdict(list)
allele_map: dict[str, Allele] = {}
# A response without "fields", "ST" or "clonal_complex" is reported as "unknown"
# rather than treated as an error.
response_json.setdefault("fields", dict())
schema_fields_returned: dict[str, str] = response_json["fields"]
schema_fields_returned.setdefault("ST", "unknown")
schema_fields_returned.setdefault("clonal_complex", "unknown")
schema_exact_matches: dict = response_json["exact_matches"]
# List-valued variant: every returned allele is appended under its locus.
for exact_match_loci, exact_match_alleles in schema_exact_matches.items():
for exact_match_allele in exact_match_alleles:
allele_map[exact_match_loci].append(Allele(exact_match_loci, exact_match_allele["allele_id"], None))
# Single-valued variant: more than one allele for a locus is rejected, and the
# first entry is stored as that locus's allele.
for exact_match_locus, exact_match_alleles in schema_exact_matches.items():
if len(exact_match_alleles) > 1:
raise ValueError(f"Unexpected number of alleles returned for exact match (Expected 1, retrieved {len(exact_match_alleles)})")
allele_map[exact_match_locus] = Allele(exact_match_locus, exact_match_alleles[0]["allele_id"], None)
if len(allele_map) == 0:
raise ValueError("Passed in no alleles.")
return MLSTProfile(dict(allele_map), schema_fields_returned["ST"], schema_fields_returned["clonal_complex"])
# Convenience wrapper: feeds the allele generator from fetch_mlst_allele_variants
# straight into fetch_mlst_st and returns the resulting MLSTProfile.
# NOTE(review): two signature/first-line pairs are shown because the diff markers
# were lost; the `(string, exact)` pair looks like the removed version and the
# `sequence_strings` pair like the added one — confirm against the repository.
async def profile_string(self, string: str, exact: bool = False) -> MLSTProfile:
alleles = self.fetch_mlst_allele_variants(string, exact)
async def profile_string(self, sequence_strings: Iterable[str]) -> MLSTProfile:
alleles = self.fetch_mlst_allele_variants(sequence_strings)
return await self.fetch_mlst_st(alleles)
# Profiles many named sequences, yielding one result per NamedString.
# When profiling raises NoBIGSdbMatchesException the exception is re-raised if
# `stop_on_fail` is true; otherwise a result with a None profile is yielded and
# iteration continues.
#
# NOTE(review): two full definitions follow because the diff markers were lost.
# The first (groups of NamedString, yielding NamedMLSTProfile) appears to be the
# added version; the second (flat NamedString stream, `exact` flag, yielding
# tuples) appears to be the removed one — confirm against the repository.
async def profile_multiple_strings(self, named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
async for named_strings in named_string_groups:
for named_string in named_strings:
try:
# Each named string is profiled on its own, wrapped in a one-element list.
yield NamedMLSTProfile(named_string.name, (await self.profile_string([named_string.sequence])))
except NoBIGSdbMatchesException as e:
if stop_on_fail:
raise e
yield NamedMLSTProfile(named_string.name, None)
async def profile_multiple_strings(self, namedStrings: AsyncIterable[NamedString], exact: bool = False, stop_on_fail: bool = False) -> AsyncGenerator[tuple[str, Union[MLSTProfile, None]], Any]:
async for named_string in namedStrings:
try:
yield (named_string.name, await self.profile_string(named_string.sequence, exact))
except NoBIGSdbMatchesException as e:
if stop_on_fail:
raise e
yield (named_string.name, None)
# Closes the HTTP client held in self._http_client.
async def close(self):
await self._http_client.close()
# Async context-manager exit: always closes the profiler; the exception details
# passed in are ignored and nothing is returned, so exceptions are not suppressed.
async def __aexit__(self, exc_type, exc_value, traceback):
await self.close()
class LazyPersistentCachedBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
    """MLST profiler that types sequences against files cached under ``cache_path``.

    Allele FASTA files and the scheme's profile table are downloaded with
    ``download_alleles_cache_data`` / ``download_scheme_profiles`` and loaded
    with ``load_scheme_locis`` / ``load_scheme_profiles``. None of these are
    invoked automatically; the caller decides when to download and load.
    """

    def __init__(self, database_api: str, database_name: str, schema_id: int, cache_path: str):
        """Store the connection settings and open an HTTP session rooted at the scheme URL.

        Args:
            database_api: Base URL of the BIGSdb API.
            database_name: Name of the sequence definition database.
            schema_id: Identifier of the MLST scheme within the database.
            cache_path: Directory that holds (or will hold) the cached files.
        """
        self._database_api = database_api
        self._database_name = database_name
        self._schema_id = schema_id
        self._base_url = f"{database_api}/db/{self._database_name}/schemes/{self._schema_id}/"
        self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(10000))
        self._cache_path = cache_path
        self._loci: list[str] = []
        # Ordered allele variants (one per locus, in self._loci order) -> (ST, clonal complex).
        self._profiles: dict[tuple[str, ...], tuple[str, str]] = {}

    async def load_scheme_locis(self):
        """Replace the locus list with the sorted locus names reported by the scheme endpoint."""
        self._loci.clear()
        async with self._http_client.get("") as schema_response:
            schema_json = await schema_response.json()
        # Each entry is reduced to its last path component, which is used as the locus name.
        self._loci.extend(path.basename(locus) for locus in schema_json["loci"])
        self._loci.sort()

    async def load_scheme_profiles(self):
        """Rebuild the in-memory profile table from the cached, tab-delimited profile file.

        ``load_scheme_locis`` must have populated the locus list first, because
        the lookup key is built from the locus columns in that order.
        """
        self._profiles.clear()
        # newline="" is what the csv module asks for when it is handed a file object.
        with open(self.get_scheme_profile_path(), newline="") as profile_cache_handle:
            reader = csv.DictReader(profile_cache_handle, delimiter="\t")
            for line in reader:
                profile_key = tuple(line[locus] for locus in self._loci)
                # A table without a clonal_complex column falls back to "unknown",
                # the same placeholder the online profiler uses for a missing field.
                self._profiles[profile_key] = (line["ST"], line.get("clonal_complex", "unknown"))

    def get_locus_cache_path(self, locus) -> str:
        """Return the path of the cached FASTA file for ``locus``."""
        return path.join(self._cache_path, f"{locus}.fasta")

    def get_scheme_profile_path(self):
        """Return the path of the cached profile table."""
        return path.join(self._cache_path, "profiles.csv")

    async def _stream_to_file(self, url: str, destination: str) -> None:
        """Write the body of a GET request for ``url`` to ``destination`` chunk by chunk."""
        # NOTE(review): the response status is not checked, so an error body would be
        # written into the cache as-is — confirm whether that is acceptable.
        with open(destination, "wb") as cache_handle:
            async with self._http_client.get(url) as response:
                async for chunk, _ in response.content.iter_chunks():  # TODO maybe allow chunking to be configurable
                    cache_handle.write(chunk)

    async def download_alleles_cache_data(self):
        """Download the allele FASTA of every known locus into the cache directory."""
        # NOTE(review): this path starts with "/" while the session base URL already
        # contains a path; verify it resolves to the intended API endpoint.
        for locus in self._loci:
            await self._stream_to_file(
                f"/db/{self._database_name}/loci/{locus}/alleles_fasta",
                self.get_locus_cache_path(locus),
            )

    async def download_scheme_profiles(self):
        """Download the scheme's profile table into the cache directory."""
        await self._stream_to_file("profiles_csv", self.get_scheme_profile_path())

    async def fetch_mlst_allele_variants(self, sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
        """Align each sequence against every cached allele and yield one ``Allele`` per comparison.

        A comparison with no gaps and no mismatches is yielded without a partial
        match profile; any other comparison carries a ``PartialAllelicMatchProfile``.
        Comparisons for which the aligner returns no alignment are skipped.
        """
        aligner = PairwiseAligner("blastn")
        aligner.mode = "local"
        for sequence_string in sequence_strings:
            for locus in self._loci:
                async for fasta_seq in read_fasta(self.get_locus_cache_path(locus)):
                    allele_variant = fasta_seq.name
                    # NOTE(review): sorting materialises every alignment and picks the
                    # smallest by the Alignment ordering — confirm this is the intended
                    # "top" alignment.
                    ranked_alignments = sorted(aligner.align(sequence_string, fasta_seq.sequence))
                    if not ranked_alignments:
                        # Nothing aligned; indexing [0] here would raise IndexError.
                        continue
                    top_alignment = ranked_alignments[0]
                    stats = top_alignment.counts()
                    if stats.gaps == 0 and stats.mismatches == 0:
                        yield Allele(locus, allele_variant, None)
                    else:
                        # NOTE(review): percent_identity is a fraction of the alignment
                        # length here, while the online profiler stores the API's
                        # "identity" value unchanged — confirm both use the same scale.
                        yield Allele(
                            locus,
                            allele_variant,
                            PartialAllelicMatchProfile(
                                percent_identity=stats.identities / top_alignment.length,
                                mismatches=stats.mismatches,
                                gaps=stats.gaps
                            )
                        )

    async def fetch_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile:
        """Look up the sequence type and clonal complex for ``alleles`` in the loaded profiles.

        When several alleles share a locus, the last one seen wins. A combination
        that is absent from the profile table yields "unknown" for both fields,
        mirroring the online profiler.

        Raises:
            ValueError: If no alleles were passed in.
            KeyError: If a scheme locus has no allele among ``alleles``.
        """
        allele_variants: dict[str, Allele] = {}
        if isinstance(alleles, AsyncIterable):
            async for allele in alleles:
                allele_variants[allele.allele_locus] = allele
        else:
            for allele in alleles:
                allele_variants[allele.allele_locus] = allele
        if not allele_variants:
            raise ValueError("Passed in no alleles.")
        ordered_profile = tuple(allele_variants[locus].allele_variant for locus in self._loci)
        st, clonal_complex = self._profiles.get(ordered_profile, ("unknown", "unknown"))
        return MLSTProfile(allele_variants, st, clonal_complex)

    async def profile_string(self, sequence_strings: Iterable[str]) -> MLSTProfile:
        """Fetch the alleles for ``sequence_strings`` and resolve them into an ``MLSTProfile``."""
        alleles = self.fetch_mlst_allele_variants(sequence_strings)
        return await self.fetch_mlst_st(alleles)

    async def profile_multiple_strings(self, named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
        """Profile every named string individually and yield one ``NamedMLSTProfile`` each.

        When profiling raises ``NoBIGSdbMatchesException`` it is re-raised if
        ``stop_on_fail`` is true; otherwise a result with a ``None`` profile is yielded.
        """
        async for named_strings in named_string_groups:
            for named_string in named_strings:
                try:
                    yield NamedMLSTProfile(named_string.name, await self.profile_string([named_string.sequence]))
                except NoBIGSdbMatchesException:
                    if stop_on_fail:
                        raise
                    yield NamedMLSTProfile(named_string.name, None)

    async def close(self):
        """Close the underlying HTTP session."""
        await self._http_client.close()

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Close the profiler on context exit.

        ``AbstractAsyncContextManager`` declares ``__aexit__`` abstract, so the
        class cannot be instantiated without this override.
        """
        await self.close()
@@ -156,8 +290,8 @@ class BIGSdbIndex(AbstractAsyncContextManager):
self._seqdefdb_schemas[seqdef_db_name] = schema_descriptions
return self._seqdefdb_schemas[seqdef_db_name] # type: ignore
# Builds a profiler for the given sequence definition database and scheme, using
# the API URL returned by get_bigsdb_api_from_seqdefdb for that database.
# NOTE(review): two signature/return pairs are shown because the diff markers were
# lost; the pair naming BIGSdbMLSTProfiler looks like the removed version and the
# pair naming OnlineBIGSdbMLSTProfiler like the added one — confirm against the
# repository.
async def build_profiler_from_seqdefdb(self, dbseqdef_name: str, schema_id: int) -> BIGSdbMLSTProfiler:
return BIGSdbMLSTProfiler(await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name), dbseqdef_name, schema_id)
async def build_profiler_from_seqdefdb(self, dbseqdef_name: str, schema_id: int) -> OnlineBIGSdbMLSTProfiler:
return OnlineBIGSdbMLSTProfiler(await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name), dbseqdef_name, schema_id)
# Closes the HTTP client held in self._http_client.
async def close(self):
await self._http_client.close()

View File

@@ -5,17 +5,21 @@ from typing import Mapping, Sequence, Union
# Statistics describing a non-exact allele match.
# NOTE(review): any decorator on this class lies above the visible hunk; the
# sibling classes shown here are frozen dataclasses — confirm this one is too.
class PartialAllelicMatchProfile:
# Identity value of the match; callers in this commit pass both the API's
# "identity" field and an identities/length ratio — TODO confirm the scale.
percent_identity: float
# Number of mismatching positions in the match.
mismatches: int
# NOTE(review): no constructor call added in this commit passes `bitscore`, so
# this line is presumably the removed side of the diff — confirm.
bitscore: float
# Number of gaps in the match.
gaps: int
# Immutable record of one allele call for a locus.
# NOTE(review): `allele_loci` and `allele_locus` appear together because the diff
# markers were lost; the code added in this commit uses `allele_locus`, so
# `allele_loci` is presumably the removed name — confirm against the repository.
@dataclass(frozen=True)
class Allele:
allele_loci: str
# Name of the locus this allele belongs to.
allele_locus: str
# Identifier of the allele variant at that locus.
allele_variant: str
# None for an exact match; otherwise the statistics of the partial match.
partial_match_profile: Union[None, PartialAllelicMatchProfile]
# Immutable MLST result: the alleles per locus plus the scheme fields derived from them.
# NOTE(review): two `alleles` annotations appear because the diff markers were
# lost; the profilers in this commit build a locus -> single Allele mapping, so
# the Sequence-valued annotation is presumably the removed one — confirm.
@dataclass(frozen=True)
class MLSTProfile:
alleles: Mapping[str, Sequence[Allele]]
# Locus name -> the allele called at that locus.
alleles: Mapping[str, Allele]
# Sequence type; the profilers use "unknown" when none is reported.
sequence_type: str
# Clonal complex; the profilers use "unknown" when none is reported.
clonal_complex: str
# Immutable pairing of a name with the MLST profile computed for it.
@dataclass(frozen=True)
class NamedMLSTProfile:
# Name of the profiled input (the profilers pass the NamedString's name).
name: str
# The computed profile, or None when profiling found no matches.
mlst_profile: Union[None, MLSTProfile]