Fixed concurrent profile_multiple_strings implementation
All checks were successful
autoBIGS.engine/pipeline/head This commit looks good

This commit is contained in:
Harrison Deng 2025-02-26 05:16:24 +00:00
parent 27ae89fde7
commit 4b34036d17
2 changed files with 20 additions and 13 deletions

View File

@ -7,7 +7,7 @@ from os import path
import os
import shutil
import tempfile
from typing import Any, AsyncGenerator, AsyncIterable, Iterable, Mapping, Sequence, Set, Union
from typing import Any, AsyncGenerator, AsyncIterable, Coroutine, Iterable, Mapping, Sequence, Set, Union
from aiohttp import ClientSession, ClientTimeout
@ -135,20 +135,24 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
return await self.determine_mlst_st(alleles)
async def profile_multiple_strings(self, query_named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
    """Profile every group of named strings and yield results as they finish.

    One ``self.profile_string`` coroutine is created per group read from
    ``query_named_string_groups``. The coroutines are then handed to
    ``asyncio.as_completed``, so results are yielded in completion order,
    which is not necessarily the order in which the groups were supplied.

    Args:
        query_named_string_groups: Asynchronous iterable of groups of named
            strings; each group is passed as-is to ``profile_string``.
        stop_on_fail: When true, the first ``NoBIGSdbMatchesException`` is
            re-raised to the caller. When false, a placeholder
            ``NamedMLSTProfile(name, None)`` is yielded for the failing query
            instead (presumably ``None`` stands for the missing profile).

    Yields:
        A ``NamedMLSTProfile`` for each group.

    Raises:
        NoBIGSdbMatchesException: A query had no match and ``stop_on_fail``
            is true.
        ValueError: A query had no match and the exception does not carry
            the name of the query that caused it.
        TypeError: ``profile_string`` returned something other than a
            ``NamedMLSTProfile``.
    """
    tasks: list[Coroutine[Any, Any, Union[NamedMLSTProfile, MLSTProfile]]] = []
    async for named_strings in query_named_string_groups:
        tasks.append(self.profile_string(named_strings))
    for task in asyncio.as_completed(tasks):
        try:
            # The await has to live inside the try: it is the await that
            # surfaces NoBIGSdbMatchesException from profile_string. Awaiting
            # before the try made the handler below unreachable.
            named_mlst_profile = await task
        except NoBIGSdbMatchesException as e:
            if stop_on_fail:
                # Bare raise keeps the original traceback intact.
                raise
            causal_name = e.get_causal_query_name()
            if causal_name is None:
                raise ValueError("Missing query name despite requiring names.") from e
            yield NamedMLSTProfile(causal_name, None)
            continue
        # Kept outside the try so that only the await is guarded and
        # exceptions thrown into the generator at a yield are not swallowed.
        if not isinstance(named_mlst_profile, NamedMLSTProfile):
            raise TypeError("MLST profile is not named.")
        yield named_mlst_profile
async def close(self):
await self._http_client.close()

View File

@ -102,6 +102,7 @@ class TestBIGSdbMLSTProfiler:
continue
scrambled = gene_scrambler(str(target_sequence.seq), 0.125)
async for partial_match in profiler.determine_mlst_allele_variants([scrambled]):
assert isinstance(partial_match, Allele)
assert partial_match.partial_match_profile is not None
mlst_targets.remove(gene)
@ -119,6 +120,7 @@ class TestBIGSdbMLSTProfiler:
dummy_alleles = bad_profile.alleles
async with bigsdb.get_BIGSdb_MLST_profiler(local_db, database_api, database_name, scheme_id) as dummy_profiler:
mlst_profile = await dummy_profiler.determine_mlst_st(dummy_alleles)
assert isinstance(mlst_profile, MLSTProfile)
assert mlst_profile.clonal_complex == "unknown"
assert mlst_profile.sequence_type == "unknown"
@ -207,5 +209,6 @@ class TestBIGSdbIndex:
async with await bigsdb_index.build_profiler_from_seqdefdb(local, "pubmlst_bordetella_seqdef", 3) as profiler:
assert isinstance(profiler, BIGSdbMLSTProfiler)
profile = await profiler.profile_string(sequence)
assert isinstance(profile, MLSTProfile)
assert profile.clonal_complex == "ST-2 complex"
assert profile.sequence_type == "1"