192 lines
9.5 KiB
Python
192 lines
9.5 KiB
Python
from abc import abstractmethod
|
|
from collections import defaultdict
|
|
from contextlib import AbstractAsyncContextManager
|
|
import csv
|
|
from os import path
|
|
from typing import Any, AsyncGenerator, AsyncIterable, Iterable, Mapping, Sequence, Union
|
|
|
|
from aiohttp import ClientSession, ClientTimeout
|
|
|
|
from autobigs.engine.data.local.fasta import read_fasta
|
|
from autobigs.engine.data.structures.genomics import NamedString
|
|
from autobigs.engine.data.structures.mlst import Allele, NamedMLSTProfile, PartialAllelicMatchProfile, MLSTProfile
|
|
from autobigs.engine.exceptions.database import NoBIGSdbExactMatchesException, NoBIGSdbMatchesException, NoSuchBIGSdbDatabaseException
|
|
|
|
from Bio.Align import PairwiseAligner
|
|
|
|
class BIGSdbMLSTProfiler(AbstractAsyncContextManager):
    """Abstract interface for MLST profilers that work against a BIGSdb source.

    Being an ``AbstractAsyncContextManager``, a concrete profiler is meant to
    be used with ``async with``; subclasses must implement every method below
    (and ``__aexit__``, which the base context manager leaves abstract).
    """

    @abstractmethod
    def fetch_mlst_allele_variants(
        self,
        sequence_strings: Iterable[str],
    ) -> AsyncGenerator[Allele, Any]:
        """Asynchronously generate the alleles matched by the given sequences.

        :param sequence_strings: the sequence strings to look up.
        :return: an async generator of ``Allele`` objects.
        """

    @abstractmethod
    async def fetch_mlst_st(
        self,
        alleles: Union[AsyncIterable[Allele], Iterable[Allele]],
    ) -> MLSTProfile:
        """Build an ``MLSTProfile`` from a collection of alleles.

        :param alleles: a synchronous or asynchronous iterable of alleles.
        :return: the resulting ``MLSTProfile``.
        """

    @abstractmethod
    async def profile_string(self, sequence_strings: Iterable[str]) -> MLSTProfile:
        """Produce an ``MLSTProfile`` directly from sequence strings.

        :param sequence_strings: the sequence strings to profile.
        :return: the resulting ``MLSTProfile``.
        """

    @abstractmethod
    def profile_multiple_strings(
        self,
        named_string_groups: AsyncIterable[Iterable[NamedString]],
        stop_on_fail: bool = False,
    ) -> AsyncGenerator[NamedMLSTProfile, Any]:
        """Asynchronously generate a named profile for every named string.

        :param named_string_groups: async iterable of groups of named strings.
        :param stop_on_fail: whether a profiling failure should abort the run
            instead of being reported per item (exact handling is up to the
            implementation).
        :return: an async generator of ``NamedMLSTProfile`` objects.
        """

    @abstractmethod
    async def close(self):
        """Release whatever resources the profiler holds."""
|
|
|
|
class OnlineBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
    """MLST profiler that queries a remote BIGSdb REST API over HTTP.

    Every request is issued relative to
    ``{database_api}/db/{database_name}/schemes/{schema_id}/``.
    The instance owns an ``aiohttp.ClientSession``; use it as an async
    context manager (or call :meth:`close`) so the session gets closed.
    """

    def __init__(self, database_api: str, database_name: str, schema_id: int):
        """Create a profiler bound to one database and one scheme.

        :param database_api: root URL of the BIGSdb REST API.
        :param database_name: name of the sequence definition database.
        :param schema_id: numeric identifier of the scheme to query.
        """
        super().__init__()
        self._database_name = database_name
        self._schema_id = schema_id
        # NOTE(review): the trailing slash presumably matters for how aiohttp
        # joins the relative request paths ("sequence", "designations") onto
        # the base URL - confirm against the aiohttp ``base_url`` docs.
        self._base_url = f"{database_api}/db/{self._database_name}/schemes/{self._schema_id}/"
        # The first field of ClientTimeout is ``total`` (seconds).
        self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(total=10000))

    async def __aenter__(self):
        """Return the profiler itself for use inside ``async with``."""
        return self

    async def fetch_mlst_allele_variants(self, sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
        """Yield the alleles the remote database reports for each sequence.

        Each sequence is POSTed to the ``sequence`` endpoint with partial
        matching enabled. If the response contains ``exact_matches`` those
        are yielded (without a partial match profile); otherwise, if it
        contains ``partial_matches``, every non-empty partial match is
        yielded together with its identity/mismatch/gap statistics.

        :param sequence_strings: the sequences to look up.
        :raises NoBIGSdbMatchesException: if a response contains neither
            ``exact_matches`` nor ``partial_matches``.
        """
        # See https://bigsdb.pasteur.fr/api/db/pubmlst_bordetella_seqdef/schemes
        uri_path = "sequence"

        for sequence_string in sequence_strings:
            request_body = {
                "sequence": sequence_string,
                "partial_matches": True
            }
            # Read the whole body and release the response before yielding,
            # so nothing stays open while the consumer holds the generator.
            async with self._http_client.post(uri_path, json=request_body) as response:
                sequence_response: dict = await response.json()

            if "exact_matches" in sequence_response:
                # locus -> list of alleles, each carrying its allele id
                exact_matches: dict[str, Sequence[dict[str, str]]] = sequence_response["exact_matches"]
                for allele_locus, alleles in exact_matches.items():
                    for allele in alleles:
                        allele_id = allele["allele_id"]
                        yield Allele(allele_locus=allele_locus, allele_variant=allele_id, partial_match_profile=None)
            elif "partial_matches" in sequence_response:
                partial_matches: dict[str, dict[str, Union[str, float, int]]] = sequence_response["partial_matches"]
                for allele_locus, partial_match in partial_matches.items():
                    if not partial_match:
                        # Nothing was reported for this locus.
                        continue
                    partial_match_profile = PartialAllelicMatchProfile(
                        percent_identity=float(partial_match["identity"]),
                        mismatches=int(partial_match["mismatches"]),
                        gaps=int(partial_match["gaps"])
                    )
                    yield Allele(
                        allele_locus=allele_locus,
                        allele_variant=str(partial_match["allele"]),
                        partial_match_profile=partial_match_profile
                    )
            else:
                raise NoBIGSdbMatchesException(self._database_name, self._schema_id)

    async def fetch_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile:
        """Resolve a set of alleles into an ``MLSTProfile``.

        The alleles are grouped by locus and POSTed to the ``designations``
        endpoint. The ``ST`` and ``clonal_complex`` fields of the response
        default to ``"unknown"`` when absent.

        :param alleles: synchronous or asynchronous iterable of alleles.
        :return: profile built from the exact matches in the response.
        :raises ValueError: if no alleles were supplied, if the response
            holds no exact matches, or if a locus does not come back with
            exactly one allele.
        """
        uri_path = "designations"
        allele_request_dict: dict[str, list[dict[str, str]]] = defaultdict(list)
        if isinstance(alleles, AsyncIterable):
            async for allele in alleles:
                allele_request_dict[allele.allele_locus].append({"allele": str(allele.allele_variant)})
        else:
            for allele in alleles:
                allele_request_dict[allele.allele_locus].append({"allele": str(allele.allele_variant)})

        if not allele_request_dict:
            # Fail before issuing a request that cannot yield a profile.
            raise ValueError("Passed in no alleles.")

        request_json = {
            "designations": allele_request_dict
        }
        async with self._http_client.post(uri_path, json=request_json) as response:
            response_json: dict = await response.json()

        response_json.setdefault("fields", dict())
        schema_fields_returned: dict[str, str] = response_json["fields"]
        schema_fields_returned.setdefault("ST", "unknown")
        schema_fields_returned.setdefault("clonal_complex", "unknown")

        allele_map: dict[str, Allele] = {}
        schema_exact_matches: dict = response_json["exact_matches"]
        for exact_match_locus, exact_match_alleles in schema_exact_matches.items():
            # Exactly one allele per locus is expected; an empty list would
            # otherwise surface as an IndexError on the [0] access below.
            if len(exact_match_alleles) != 1:
                raise ValueError(f"Unexpected number of alleles returned for exact match (Expected 1, retrieved {len(exact_match_alleles)})")
            allele_map[exact_match_locus] = Allele(exact_match_locus, exact_match_alleles[0]["allele_id"], None)

        if not allele_map:
            raise ValueError("Passed in no alleles.")
        return MLSTProfile(dict(allele_map), schema_fields_returned["ST"], schema_fields_returned["clonal_complex"])

    async def profile_string(self, sequence_strings: Iterable[str]) -> MLSTProfile:
        """Look up the alleles for the sequences and resolve their profile.

        :param sequence_strings: the sequences to profile.
        :return: the ``MLSTProfile`` produced by :meth:`fetch_mlst_st`.
        """
        alleles = self.fetch_mlst_allele_variants(sequence_strings)
        return await self.fetch_mlst_st(alleles)

    async def profile_multiple_strings(self, named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
        """Yield one ``NamedMLSTProfile`` per named string in every group.

        Each named string is profiled on its own. When the database reports
        no matches, the item is yielded with ``None`` as its profile unless
        ``stop_on_fail`` is set.

        :param named_string_groups: async iterable of groups of named strings.
        :param stop_on_fail: re-raise ``NoBIGSdbMatchesException`` instead of
            yielding a ``None`` profile.
        :raises NoBIGSdbMatchesException: when a lookup finds no matches and
            ``stop_on_fail`` is true.
        """
        async for named_strings in named_string_groups:
            for named_string in named_strings:
                # Keep the try body to the call that can fail, so exceptions
                # thrown into the generator at the yield are not swallowed.
                try:
                    profile = await self.profile_string([named_string.sequence])
                except NoBIGSdbMatchesException:
                    if stop_on_fail:
                        raise
                    profile = None
                yield NamedMLSTProfile(named_string.name, profile)

    async def close(self):
        """Close the underlying HTTP session."""
        await self._http_client.close()

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Close the profiler when leaving the ``async with`` block."""
        await self.close()
|
|
|
|
class BIGSdbIndex(AbstractAsyncContextManager):
    """Index of the sequence definition databases offered by known BIGSdb APIs.

    Lookups are cached on the instance: the database -> API mapping after
    the first listing, and scheme listings per database. The instance owns
    an ``aiohttp.ClientSession`` that is closed by :meth:`close` or on
    leaving an ``async with`` block.
    """

    # Root URLs of the BIGSdb REST APIs that get queried for databases.
    KNOWN_BIGSDB_APIS = {
        "https://bigsdb.pasteur.fr/api",
        "https://rest.pubmlst.org"
    }

    def __init__(self):
        """Open the HTTP session and start with empty caches."""
        self._http_client = ClientSession()
        # database name -> API root URL; None until the first listing.
        self._known_seqdef_dbs_origin: Union[Mapping[str, str], None] = None
        # database name -> (scheme description -> scheme id)
        self._seqdefdb_schemas: dict[str, Union[Mapping[str, int], None]] = dict()
        super().__init__()

    async def __aenter__(self):
        """Return the index itself for use inside ``async with``."""
        return self

    async def get_known_seqdef_dbs(self, force: bool = False) -> Mapping[str, str]:
        """Map every known ``*seqdef`` database name to the API hosting it.

        :param force: query the APIs again even if a cached mapping exists.
        :return: mapping of database name to API root URL.
        """
        cached = self._known_seqdef_dbs_origin
        if not force and cached is not None:
            return cached

        origins = {}
        for api_root in BIGSdbIndex.KNOWN_BIGSDB_APIS:
            async with self._http_client.get(f"{api_root}/db") as response:
                database_groups = await response.json()
                for group in database_groups:
                    for info in group["databases"]:
                        db_name = info["name"]
                        # Only sequence definition databases are indexed.
                        if str(db_name).endswith("seqdef"):
                            origins[db_name] = api_root

        self._known_seqdef_dbs_origin = dict(origins)
        return self._known_seqdef_dbs_origin

    async def get_bigsdb_api_from_seqdefdb(self, seqdef_db_name: str) -> str:
        """Return the API root URL that hosts the given database.

        :param seqdef_db_name: name of the sequence definition database.
        :raises NoSuchBIGSdbDatabaseException: if the name is not among the
            known databases.
        """
        origins = await self.get_known_seqdef_dbs()
        if seqdef_db_name in origins:
            return origins[seqdef_db_name]
        raise NoSuchBIGSdbDatabaseException(seqdef_db_name)

    async def get_schemas_for_seqdefdb(self, seqdef_db_name: str, force: bool = False) -> Mapping[str, int]:
        """Map scheme descriptions to scheme ids for the given database.

        :param seqdef_db_name: name of the sequence definition database.
        :param force: query the API again even if a cached listing exists.
        :return: mapping of scheme description to numeric scheme id.
        """
        if not force and seqdef_db_name in self._seqdefdb_schemas:
            # Present in the cache, hence a real mapping rather than None.
            return self._seqdefdb_schemas[seqdef_db_name]  # type: ignore

        api_root = await self.get_bigsdb_api_from_seqdefdb(seqdef_db_name)
        uri_path = f"{api_root}/db/{seqdef_db_name}/schemes"
        async with self._http_client.get(uri_path) as response:
            listing = await response.json()
            ids_by_description: dict[str, int] = {}
            for entry in listing["schemes"]:
                # The id is the last "/"-separated segment of the "scheme" value.
                scheme_id = int(str(entry["scheme"]).rsplit("/", 1)[-1])
                description: str = entry["description"]
                ids_by_description[description] = scheme_id
            self._seqdefdb_schemas[seqdef_db_name] = ids_by_description
            return ids_by_description

    async def build_profiler_from_seqdefdb(self, dbseqdef_name: str, schema_id: int) -> OnlineBIGSdbMLSTProfiler:
        """Create an online profiler for one scheme of the given database.

        :param dbseqdef_name: name of the sequence definition database.
        :param schema_id: numeric identifier of the scheme.
        :raises NoSuchBIGSdbDatabaseException: if the database is unknown.
        """
        api_root = await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name)
        return OnlineBIGSdbMLSTProfiler(api_root, dbseqdef_name, schema_id)

    async def close(self):
        """Close the underlying HTTP session."""
        await self._http_client.close()

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Close the index when leaving the ``async with`` block."""
        await self.close()
|
|
|