Compare commits

..

No commits in common. "34bf02c75a53593a0b271e7f8217dff46b74e46d" and "af7edf0942cba82aaa342776014b1f03e8ae5d5b" have entirely different histories.

View File

@ -9,7 +9,7 @@ import shutil
import tempfile import tempfile
from typing import Any, AsyncGenerator, AsyncIterable, Coroutine, Iterable, Mapping, Sequence, Set, Union from typing import Any, AsyncGenerator, AsyncIterable, Coroutine, Iterable, Mapping, Sequence, Set, Union
from aiohttp import ClientOSError, ClientSession, ClientTimeout, ServerDisconnectedError from aiohttp import ClientSession, ClientTimeout
from autobigs.engine.reading import read_fasta from autobigs.engine.reading import read_fasta
from autobigs.engine.structures.alignment import PairwiseAlignment from autobigs.engine.structures.alignment import PairwiseAlignment
@ -43,12 +43,11 @@ class BIGSdbMLSTProfiler(AbstractAsyncContextManager):
class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler): class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
def __init__(self, database_api: str, database_name: str, scheme_id: int, retry_requests: int = 5): def __init__(self, database_api: str, database_name: str, scheme_id: int):
self._retry_limit = retry_requests
self._database_name = database_name self._database_name = database_name
self._scheme_id = scheme_id self._scheme_id = scheme_id
self._base_url = f"{database_api}/db/{self._database_name}/schemes/{self._scheme_id}/" self._base_url = f"{database_api}/db/{self._database_name}/schemes/{self._scheme_id}/"
self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(300)) self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(60))
async def __aenter__(self): async def __aenter__(self):
return self return self
@ -58,59 +57,40 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
uri_path = "sequence" uri_path = "sequence"
if isinstance(query_sequence_strings, str) or isinstance(query_sequence_strings, NamedString): if isinstance(query_sequence_strings, str) or isinstance(query_sequence_strings, NamedString):
query_sequence_strings = [query_sequence_strings] query_sequence_strings = [query_sequence_strings]
for sequence_string in query_sequence_strings: for sequence_string in query_sequence_strings:
attempts = 0 async with self._http_client.post(uri_path, json={
success = False "sequence": sequence_string if isinstance(sequence_string, str) else sequence_string.sequence,
last_error = None "partial_matches": True
while not success and attempts < self._retry_limit: }) as response:
attempts += 1 sequence_response: dict = await response.json()
request = self._http_client.post(uri_path, json={
"sequence": sequence_string if isinstance(sequence_string, str) else sequence_string.sequence,
"partial_matches": True
})
try:
async with request as response:
sequence_response: dict = await response.json()
if "exact_matches" in sequence_response: if "exact_matches" in sequence_response:
# loci -> list of alleles with id and loci # loci -> list of alleles with id and loci
exact_matches: dict[str, Sequence[dict[str, str]]] = sequence_response["exact_matches"] exact_matches: dict[str, Sequence[dict[str, str]]] = sequence_response["exact_matches"]
for allele_loci, alleles in exact_matches.items(): for allele_loci, alleles in exact_matches.items():
for allele in alleles: for allele in alleles:
alelle_id = allele["allele_id"] alelle_id = allele["allele_id"]
result_allele = Allele(allele_locus=allele_loci, allele_variant=alelle_id, partial_match_profile=None) result_allele = Allele(allele_locus=allele_loci, allele_variant=alelle_id, partial_match_profile=None)
yield result_allele if isinstance(sequence_string, str) else (sequence_string.name, result_allele) yield result_allele if isinstance(sequence_string, str) else (sequence_string.name, result_allele)
elif "partial_matches" in sequence_response: elif "partial_matches" in sequence_response:
partial_matches: dict[str, dict[str, Union[str, float, int]]] = sequence_response["partial_matches"] partial_matches: dict[str, dict[str, Union[str, float, int]]] = sequence_response["partial_matches"]
for allele_loci, partial_match in partial_matches.items(): for allele_loci, partial_match in partial_matches.items():
if len(partial_match) <= 0: if len(partial_match) <= 0:
continue continue
partial_match_profile = AlignmentStats( partial_match_profile = AlignmentStats(
percent_identity=float(partial_match["identity"]), percent_identity=float(partial_match["identity"]),
mismatches=int(partial_match["mismatches"]), mismatches=int(partial_match["mismatches"]),
gaps=int(partial_match["gaps"]), gaps=int(partial_match["gaps"]),
match_metric=int(partial_match["bitscore"]) match_metric=int(partial_match["bitscore"])
) )
result_allele = Allele( result_allele = Allele(
allele_locus=allele_loci, allele_locus=allele_loci,
allele_variant=str(partial_match["allele"]), allele_variant=str(partial_match["allele"]),
partial_match_profile=partial_match_profile partial_match_profile=partial_match_profile
) )
yield result_allele if isinstance(sequence_string, str) else (sequence_string.name, result_allele) yield result_allele if isinstance(sequence_string, str) else (sequence_string.name, result_allele)
else:
raise NoBIGSdbMatchesException(self._database_name, self._scheme_id, sequence_string.name if isinstance(sequence_string, NamedString) else None)
except (ConnectionError, ServerDisconnectedError, ClientOSError) as e: # Errors we will retry
last_error = e
success = False
await asyncio.sleep(5) # In case the connection issue is due to rate issues
else: else:
success = True raise NoBIGSdbMatchesException(self._database_name, self._scheme_id, sequence_string.name if isinstance(sequence_string, NamedString) else None)
if not success and last_error is not None:
try:
raise last_error
except (ConnectionError, ServerDisconnectedError, ClientOSError) as e: # Non-fatal errors
yield Allele("error", "error", None)
async def determine_mlst_st(self, alleles: Union[AsyncIterable[Union[Allele, tuple[str, Allele]]], Iterable[Union[Allele, tuple[str, Allele]]]]) -> Union[MLSTProfile, NamedMLSTProfile]: async def determine_mlst_st(self, alleles: Union[AsyncIterable[Union[Allele, tuple[str, Allele]]], Iterable[Union[Allele, tuple[str, Allele]]]]) -> Union[MLSTProfile, NamedMLSTProfile]:
uri_path = "designations" uri_path = "designations"
@ -133,42 +113,22 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
request_json = { request_json = {
"designations": allele_request_dict "designations": allele_request_dict
} }
async with self._http_client.post(uri_path, json=request_json) as response:
attempts = 0 response_json: dict = await response.json()
success = False allele_set: Set[Allele] = set()
last_error = None response_json.setdefault("fields", dict())
while attempts < self._retry_limit and not success: scheme_fields_returned: dict[str, str] = response_json["fields"]
attempts += 1 scheme_fields_returned.setdefault("ST", "unknown")
try: scheme_fields_returned.setdefault("clonal_complex", "unknown")
async with self._http_client.post(uri_path, json=request_json) as response: scheme_exact_matches: dict = response_json["exact_matches"]
response_json: dict = await response.json() for exact_match_locus, exact_match_alleles in scheme_exact_matches.items():
allele_set: Set[Allele] = set() allele_set.add(Allele(exact_match_locus, exact_match_alleles[0]["allele_id"], None))
response_json.setdefault("fields", dict()) if len(allele_set) == 0:
scheme_fields_returned: dict[str, str] = response_json["fields"] raise ValueError("Passed in no alleles.")
scheme_fields_returned.setdefault("ST", "unknown") result_mlst_profile = MLSTProfile(allele_set, scheme_fields_returned["ST"], scheme_fields_returned["clonal_complex"])
scheme_fields_returned.setdefault("clonal_complex", "unknown") if len(names_list) > 0:
scheme_exact_matches: dict = response_json["exact_matches"] result_mlst_profile = NamedMLSTProfile(str(tuple(names_list)) if len(set(names_list)) > 1 else names_list[0], result_mlst_profile)
for exact_match_locus, exact_match_alleles in scheme_exact_matches.items(): return result_mlst_profile
allele_set.add(Allele(exact_match_locus, exact_match_alleles[0]["allele_id"], None))
if len(allele_set) == 0:
raise ValueError("Passed in no alleles.")
result_mlst_profile = MLSTProfile(allele_set, scheme_fields_returned["ST"], scheme_fields_returned["clonal_complex"])
if len(names_list) > 0:
result_mlst_profile = NamedMLSTProfile(str(tuple(names_list)) if len(set(names_list)) > 1 else names_list[0], result_mlst_profile)
return result_mlst_profile
except (ConnectionError, ServerDisconnectedError, ClientOSError) as e:
last_error = e
success = False
await asyncio.sleep(5)
else:
success = True
try:
if last_error is not None:
raise last_error
except (ConnectionError, ServerDisconnectedError, ClientOSError) as e:
result_mlst_profile = NamedMLSTProfile((str(tuple(names_list)) if len(set(names_list)) > 1 else names_list[0]) + ":Error", None)
raise ValueError("Last error was not recorded.")
async def profile_string(self, query_sequence_strings: Iterable[Union[NamedString, str]]) -> Union[NamedMLSTProfile, MLSTProfile]: async def profile_string(self, query_sequence_strings: Iterable[Union[NamedString, str]]) -> Union[NamedMLSTProfile, MLSTProfile]:
alleles = self.determine_mlst_allele_variants(query_sequence_strings) alleles = self.determine_mlst_allele_variants(query_sequence_strings)