Added retry functionality for allele variant determination

This commit is contained in:
Harrison Deng 2025-03-13 15:54:35 +00:00
parent af7edf0942
commit 8ffc7c7fb5

@ -43,7 +43,8 @@ class BIGSdbMLSTProfiler(AbstractAsyncContextManager):
class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
def __init__(self, database_api: str, database_name: str, scheme_id: int):
def __init__(self, database_api: str, database_name: str, scheme_id: int, retry_requests: int = 5):
self._retry_limit = retry_requests
self._database_name = database_name
self._scheme_id = scheme_id
self._base_url = f"{database_api}/db/{self._database_name}/schemes/{self._scheme_id}/"
@ -57,41 +58,56 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
uri_path = "sequence"
if isinstance(query_sequence_strings, str) or isinstance(query_sequence_strings, NamedString):
query_sequence_strings = [query_sequence_strings]
for sequence_string in query_sequence_strings:
async with self._http_client.post(uri_path, json={
"sequence": sequence_string if isinstance(sequence_string, str) else sequence_string.sequence,
"partial_matches": True
}) as response:
sequence_response: dict = await response.json()
attempts = 0
success = False
last_error = None
while not success and attempts < self._retry_limit:
request = self._http_client.post(uri_path, json={
"sequence": sequence_string if isinstance(sequence_string, str) else sequence_string.sequence,
"partial_matches": True
})
try:
async with request as response:
sequence_response: dict = await response.json()
if "exact_matches" in sequence_response:
# loci -> list of alleles with id and loci
exact_matches: dict[str, Sequence[dict[str, str]]] = sequence_response["exact_matches"]
for allele_loci, alleles in exact_matches.items():
for allele in alleles:
alelle_id = allele["allele_id"]
result_allele = Allele(allele_locus=allele_loci, allele_variant=alelle_id, partial_match_profile=None)
yield result_allele if isinstance(sequence_string, str) else (sequence_string.name, result_allele)
elif "partial_matches" in sequence_response:
partial_matches: dict[str, dict[str, Union[str, float, int]]] = sequence_response["partial_matches"]
for allele_loci, partial_match in partial_matches.items():
if len(partial_match) <= 0:
continue
partial_match_profile = AlignmentStats(
percent_identity=float(partial_match["identity"]),
mismatches=int(partial_match["mismatches"]),
gaps=int(partial_match["gaps"]),
match_metric=int(partial_match["bitscore"])
)
result_allele = Allele(
allele_locus=allele_loci,
allele_variant=str(partial_match["allele"]),
partial_match_profile=partial_match_profile
)
yield result_allele if isinstance(sequence_string, str) else (sequence_string.name, result_allele)
if "exact_matches" in sequence_response:
# loci -> list of alleles with id and loci
exact_matches: dict[str, Sequence[dict[str, str]]] = sequence_response["exact_matches"]
for allele_loci, alleles in exact_matches.items():
for allele in alleles:
alelle_id = allele["allele_id"]
result_allele = Allele(allele_locus=allele_loci, allele_variant=alelle_id, partial_match_profile=None)
yield result_allele if isinstance(sequence_string, str) else (sequence_string.name, result_allele)
elif "partial_matches" in sequence_response:
partial_matches: dict[str, dict[str, Union[str, float, int]]] = sequence_response["partial_matches"]
for allele_loci, partial_match in partial_matches.items():
if len(partial_match) <= 0:
continue
partial_match_profile = AlignmentStats(
percent_identity=float(partial_match["identity"]),
mismatches=int(partial_match["mismatches"]),
gaps=int(partial_match["gaps"]),
match_metric=int(partial_match["bitscore"])
)
result_allele = Allele(
allele_locus=allele_loci,
allele_variant=str(partial_match["allele"]),
partial_match_profile=partial_match_profile
)
yield result_allele if isinstance(sequence_string, str) else (sequence_string.name, result_allele)
else:
raise NoBIGSdbMatchesException(self._database_name, self._scheme_id, sequence_string.name if isinstance(sequence_string, NamedString) else None)
except ConnectionResetError as e:
last_error = e
success = False
await asyncio.sleep(5) # In case the connection issue is due to rate issues
else:
raise NoBIGSdbMatchesException(self._database_name, self._scheme_id, sequence_string.name if isinstance(sequence_string, NamedString) else None)
success = True
if not success and last_error is not None:
raise last_error
async def determine_mlst_st(self, alleles: Union[AsyncIterable[Union[Allele, tuple[str, Allele]]], Iterable[Union[Allele, tuple[str, Allele]]]]) -> Union[MLSTProfile, NamedMLSTProfile]:
uri_path = "designations"
allele_request_dict: dict[str, list[dict[str, str]]] = defaultdict(list)
@ -113,6 +129,7 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
request_json = {
"designations": allele_request_dict
}
async with self._http_client.post(uri_path, json=request_json) as response:
response_json: dict = await response.json()
allele_set: Set[Allele] = set()