192 lines
9.5 KiB
Python

from abc import abstractmethod
from collections import defaultdict
from contextlib import AbstractAsyncContextManager
import csv
from os import path
from typing import Any, AsyncGenerator, AsyncIterable, Iterable, Mapping, Sequence, Union
from aiohttp import ClientSession, ClientTimeout
from autobigs.engine.data.local.fasta import read_fasta
from autobigs.engine.data.structures.genomics import NamedString
from autobigs.engine.data.structures.mlst import Allele, NamedMLSTProfile, PartialAllelicMatchProfile, MLSTProfile
from autobigs.engine.exceptions.database import NoBIGSdbExactMatchesException, NoBIGSdbMatchesException, NoSuchBIGSdbDatabaseException
from Bio.Align import PairwiseAligner
class BIGSdbMLSTProfiler(AbstractAsyncContextManager):
    """Abstract interface for MLST profilers backed by a BIGSdb database.

    Concrete subclasses provide allele lookup, sequence-type lookup and
    resource cleanup. The class derives from ``AbstractAsyncContextManager``,
    so implementations are meant to be used with ``async with``.
    """

    @abstractmethod
    def fetch_mlst_allele_variants(self, sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
        """Produce the alleles matched by each of the given sequence strings."""

    @abstractmethod
    async def fetch_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile:
        """Resolve a (sync or async) collection of alleles into an MLST profile."""

    @abstractmethod
    async def profile_string(self, sequence_strings: Iterable[str]) -> MLSTProfile:
        """Build a single MLST profile from the given sequence strings."""

    @abstractmethod
    def profile_multiple_strings(self, named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
        """Produce one named MLST profile per named sequence in the given groups.

        ``stop_on_fail`` selects whether a failed lookup aborts the whole run;
        the exact failure handling is defined by the implementation.
        """

    @abstractmethod
    async def close(self):
        """Release any resources held by the profiler."""
class OnlineBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
    """MLST profiler that queries one scheme of a BIGSdb REST API over HTTP.

    All requests are issued relative to
    ``{database_api}/db/{database_name}/schemes/{schema_id}/`` through an
    ``aiohttp`` session created in the constructor. Use the instance with
    ``async with`` or call :meth:`close` to shut the session down.
    """

    def __init__(self, database_api: str, database_name: str, schema_id: int):
        """Bind the profiler to a single scheme of a single database.

        Args:
            database_api: Root URL of the BIGSdb REST API.
            database_name: Name of the database within that API.
            schema_id: Identifier of the scheme to query.
        """
        super().__init__()
        self._database_name = database_name
        self._schema_id = schema_id
        self._base_url = f"{database_api}/db/{self._database_name}/schemes/{self._schema_id}/"
        # The single positional argument of ClientTimeout is the total timeout.
        self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(total=10000))

    async def __aenter__(self):
        """Return the profiler itself; the session already exists."""
        return self

    async def fetch_mlst_allele_variants(self, sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
        """Yield the alleles the scheme reports for each sequence.

        Every sequence is posted to the ``sequence`` endpoint with partial
        matching enabled. When the response carries exact matches, one
        ``Allele`` is yielded per exact match. Otherwise, one ``Allele`` with a
        ``PartialAllelicMatchProfile`` is yielded per non-empty partial match.

        Args:
            sequence_strings: Sequences to look up, one request per sequence.

        Raises:
            NoBIGSdbMatchesException: If a response contains neither
                ``exact_matches`` nor ``partial_matches``.
        """
        # See https://bigsdb.pasteur.fr/api/db/pubmlst_bordetella_seqdef/schemes
        uri_path = "sequence"
        for sequence_string in sequence_strings:
            request_body = {
                "sequence": sequence_string,
                "partial_matches": True,
            }
            async with self._http_client.post(uri_path, json=request_body) as response:
                sequence_response: dict = await response.json()
            # The JSON has been read in full, so the matches are processed
            # after leaving the response context instead of yielding inside it.
            if "exact_matches" in sequence_response:
                # locus -> list of alleles, each carrying an "allele_id"
                exact_matches: dict[str, Sequence[dict[str, str]]] = sequence_response["exact_matches"]
                for allele_locus, matched_alleles in exact_matches.items():
                    for matched_allele in matched_alleles:
                        yield Allele(
                            allele_locus=allele_locus,
                            allele_variant=matched_allele["allele_id"],
                            partial_match_profile=None,
                        )
            elif "partial_matches" in sequence_response:
                partial_matches: dict[str, dict[str, Union[str, float, int]]] = sequence_response["partial_matches"]
                for allele_locus, partial_match in partial_matches.items():
                    if not partial_match:
                        # An empty entry carries no allele to report.
                        continue
                    partial_match_profile = PartialAllelicMatchProfile(
                        percent_identity=float(partial_match["identity"]),
                        mismatches=int(partial_match["mismatches"]),
                        gaps=int(partial_match["gaps"]),
                    )
                    yield Allele(
                        allele_locus=allele_locus,
                        allele_variant=str(partial_match["allele"]),
                        partial_match_profile=partial_match_profile,
                    )
            else:
                raise NoBIGSdbMatchesException(self._database_name, self._schema_id)

    async def fetch_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile:
        """Look up the sequence type for a set of alleles.

        The alleles are grouped by locus and posted to the ``designations``
        endpoint. The ``ST`` and ``clonal_complex`` fields default to
        ``"unknown"`` when the response does not provide them.

        Args:
            alleles: Alleles to submit, as a regular or an async iterable.

        Returns:
            An ``MLSTProfile`` built from the exact matches in the response.

        Raises:
            ValueError: If no alleles were supplied, if the response lists no
                exact matches, or if a locus does not have exactly one
                matched allele.
            KeyError: If the response has no ``exact_matches`` entry.
        """
        uri_path = "designations"
        if isinstance(alleles, AsyncIterable):
            submitted = [allele async for allele in alleles]
        else:
            submitted = list(alleles)
        if not submitted:
            # Nothing to designate: fail before spending a request on it.
            raise ValueError("Passed in no alleles.")

        allele_request_dict: dict[str, list[dict[str, str]]] = defaultdict(list)
        for allele in submitted:
            allele_request_dict[allele.allele_locus].append({"allele": str(allele.allele_variant)})
        request_json = {
            "designations": allele_request_dict
        }
        async with self._http_client.post(uri_path, json=request_json) as response:
            response_json: dict = await response.json()

        schema_fields_returned: dict[str, str] = response_json.setdefault("fields", dict())
        schema_fields_returned.setdefault("ST", "unknown")
        schema_fields_returned.setdefault("clonal_complex", "unknown")

        # NOTE(review): NoBIGSdbExactMatchesException is imported by this file
        # but never raised; it may be the intended error when "exact_matches"
        # is absent - confirm its constructor before switching from KeyError.
        schema_exact_matches: dict = response_json["exact_matches"]
        allele_map: dict[str, Allele] = {}
        for exact_match_locus, exact_match_alleles in schema_exact_matches.items():
            # "!= 1" also rejects an empty list, which would otherwise fail
            # with an IndexError on the [0] lookup below.
            if len(exact_match_alleles) != 1:
                raise ValueError(f"Unexpected number of alleles returned for exact match (Expected 1, retrieved {len(exact_match_alleles)})")
            allele_map[exact_match_locus] = Allele(exact_match_locus, exact_match_alleles[0]["allele_id"], None)
        if not allele_map:
            raise ValueError("Passed in no alleles.")
        return MLSTProfile(allele_map, schema_fields_returned["ST"], schema_fields_returned["clonal_complex"])

    async def profile_string(self, sequence_strings: Iterable[str]) -> MLSTProfile:
        """Profile a set of sequences: match alleles, then resolve the sequence type."""
        alleles = self.fetch_mlst_allele_variants(sequence_strings)
        return await self.fetch_mlst_st(alleles)

    async def profile_multiple_strings(self, named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
        """Yield a named profile for every named sequence in every group.

        Each named sequence is profiled on its own. When no match is found the
        profile is reported as ``None`` unless ``stop_on_fail`` is set.

        Args:
            named_string_groups: Async iterable of groups of named sequences.
            stop_on_fail: Re-raise the first ``NoBIGSdbMatchesException``
                instead of yielding a ``None`` profile for that sequence.

        Raises:
            NoBIGSdbMatchesException: Only when ``stop_on_fail`` is true.
        """
        async for named_strings in named_string_groups:
            for named_string in named_strings:
                # Only the lookup is guarded; yielding happens outside the
                # ``try`` so exceptions raised at the yield point are not
                # mistaken for lookup failures.
                try:
                    profile = await self.profile_string([named_string.sequence])
                except NoBIGSdbMatchesException:
                    if stop_on_fail:
                        raise
                    profile = None
                yield NamedMLSTProfile(named_string.name, profile)

    async def close(self):
        """Close the underlying HTTP session."""
        await self._http_client.close()

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Close the profiler when leaving the ``async with`` block."""
        await self.close()
class BIGSdbIndex(AbstractAsyncContextManager):
    """Directory of sequence-definition databases across known BIGSdb APIs.

    Lookups are cached on the instance: the database-to-API mapping is cached
    as a whole and scheme listings are cached per database. The HTTP session
    is created in the constructor and closed by :meth:`close` or on leaving an
    ``async with`` block.
    """

    # Root URLs of the BIGSdb REST APIs that are searched for databases.
    KNOWN_BIGSDB_APIS = {
        "https://bigsdb.pasteur.fr/api",
        "https://rest.pubmlst.org"
    }

    def __init__(self):
        """Create the HTTP session and start with empty caches."""
        self._http_client = ClientSession()
        # Database name -> API root it was found on; None until first fetched.
        self._known_seqdef_dbs_origin: Union[Mapping[str, str], None] = None
        # Database name -> (scheme description -> scheme id).
        self._seqdefdb_schemas: dict[str, Union[Mapping[str, int], None]] = dict()
        super().__init__()

    async def __aenter__(self):
        """Return the index itself; the session already exists."""
        return self

    async def get_known_seqdef_dbs(self, force: bool = False) -> Mapping[str, str]:
        """Map every database whose name ends in ``seqdef`` to its API root.

        Args:
            force: Query the APIs again even if a cached mapping exists.

        Returns:
            Mapping of database name to the API root URL that listed it.
        """
        if force or self._known_seqdef_dbs_origin is None:
            discovered: dict[str, str] = {}
            for api_root in BIGSdbIndex.KNOWN_BIGSDB_APIS:
                async with self._http_client.get(f"{api_root}/db") as response:
                    database_groups = await response.json()
                    for group in database_groups:
                        for entry in group["databases"]:
                            database_name = entry["name"]
                            if str(database_name).endswith("seqdef"):
                                discovered[database_name] = api_root
            # Publish the cache only after every API has been read.
            self._known_seqdef_dbs_origin = discovered
        return self._known_seqdef_dbs_origin

    async def get_bigsdb_api_from_seqdefdb(self, seqdef_db_name: str) -> str:
        """Return the API root URL that hosts the given database.

        Raises:
            NoSuchBIGSdbDatabaseException: If the database is not among the
                known sequence-definition databases.
        """
        origins = await self.get_known_seqdef_dbs()
        if seqdef_db_name in origins:
            return origins[seqdef_db_name]
        raise NoSuchBIGSdbDatabaseException(seqdef_db_name)

    async def get_schemas_for_seqdefdb(self, seqdef_db_name: str, force: bool = False) -> Mapping[str, int]:
        """Map scheme descriptions to scheme ids for the given database.

        The id is taken from the last path segment of each scheme's URL.

        Args:
            seqdef_db_name: Name of the sequence-definition database.
            force: Query the API again even if a cached listing exists.

        Raises:
            NoSuchBIGSdbDatabaseException: If the database is unknown.
        """
        if not force and seqdef_db_name in self._seqdefdb_schemas:
            return self._seqdefdb_schemas[seqdef_db_name]  # type: ignore  # cached entries are never None
        api_root = await self.get_bigsdb_api_from_seqdefdb(seqdef_db_name)
        schemes_url = f"{api_root}/db/{seqdef_db_name}/schemes"
        async with self._http_client.get(schemes_url) as response:
            listing = await response.json()
            ids_by_description: dict[str, int] = {}
            for scheme in listing["schemes"]:
                # The scheme URL ends in ".../schemes/<id>".
                scheme_id = int(str(scheme["scheme"]).rsplit("/", 1)[-1])
                description: str = scheme["description"]
                ids_by_description[description] = scheme_id
            self._seqdefdb_schemas[seqdef_db_name] = ids_by_description
            return self._seqdefdb_schemas[seqdef_db_name]  # type: ignore

    async def build_profiler_from_seqdefdb(self, dbseqdef_name: str, schema_id: int) -> OnlineBIGSdbMLSTProfiler:
        """Create an online profiler for one scheme of the given database.

        Raises:
            NoSuchBIGSdbDatabaseException: If the database is unknown.
        """
        api_root = await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name)
        return OnlineBIGSdbMLSTProfiler(api_root, dbseqdef_name, schema_id)

    async def close(self):
        """Close the underlying HTTP session."""
        await self._http_client.close()

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Close the index when leaving the ``async with`` block."""
        await self.close()