Fixed concurrent profile_multiple_strings implementation
All checks were successful
autoBIGS.engine/pipeline/head This commit looks good

This commit is contained in:
Harrison Deng 2025-02-26 05:16:24 +00:00
parent 27ae89fde7
commit 4b34036d17
2 changed files with 20 additions and 13 deletions

View File

@ -7,7 +7,7 @@ from os import path
import os import os
import shutil import shutil
import tempfile import tempfile
from typing import Any, AsyncGenerator, AsyncIterable, Iterable, Mapping, Sequence, Set, Union from typing import Any, AsyncGenerator, AsyncIterable, Coroutine, Iterable, Mapping, Sequence, Set, Union
from aiohttp import ClientSession, ClientTimeout from aiohttp import ClientSession, ClientTimeout
@ -135,20 +135,24 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
return await self.determine_mlst_st(alleles) return await self.determine_mlst_st(alleles)
async def profile_multiple_strings(self, query_named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]: async def profile_multiple_strings(self, query_named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
tasks = [] tasks: list[Coroutine[Any, Any, Union[NamedMLSTProfile, MLSTProfile]]] = []
async for named_strings in query_named_string_groups: async for named_strings in query_named_string_groups:
tasks.append(self.profile_string(named_strings)) tasks.append(self.profile_string(named_strings))
for task in asyncio.as_completed(tasks): for task in asyncio.as_completed(tasks):
try: named_mlst_profile = await task
yield await task try:
except NoBIGSdbMatchesException as e: if isinstance(named_mlst_profile, NamedMLSTProfile):
if stop_on_fail: yield named_mlst_profile
raise e else:
causal_name = e.get_causal_query_name() raise TypeError("MLST profile is not named.")
if causal_name is None: except NoBIGSdbMatchesException as e:
raise ValueError("Missing query name despite requiring names.") if stop_on_fail:
else: raise e
yield NamedMLSTProfile(causal_name, None) causal_name = e.get_causal_query_name()
if causal_name is None:
raise ValueError("Missing query name despite requiring names.")
else:
yield NamedMLSTProfile(causal_name, None)
async def close(self): async def close(self):
await self._http_client.close() await self._http_client.close()

View File

@ -102,6 +102,7 @@ class TestBIGSdbMLSTProfiler:
continue continue
scrambled = gene_scrambler(str(target_sequence.seq), 0.125) scrambled = gene_scrambler(str(target_sequence.seq), 0.125)
async for partial_match in profiler.determine_mlst_allele_variants([scrambled]): async for partial_match in profiler.determine_mlst_allele_variants([scrambled]):
assert isinstance(partial_match, Allele)
assert partial_match.partial_match_profile is not None assert partial_match.partial_match_profile is not None
mlst_targets.remove(gene) mlst_targets.remove(gene)
@ -119,6 +120,7 @@ class TestBIGSdbMLSTProfiler:
dummy_alleles = bad_profile.alleles dummy_alleles = bad_profile.alleles
async with bigsdb.get_BIGSdb_MLST_profiler(local_db, database_api, database_name, scheme_id) as dummy_profiler: async with bigsdb.get_BIGSdb_MLST_profiler(local_db, database_api, database_name, scheme_id) as dummy_profiler:
mlst_profile = await dummy_profiler.determine_mlst_st(dummy_alleles) mlst_profile = await dummy_profiler.determine_mlst_st(dummy_alleles)
assert isinstance(mlst_profile, MLSTProfile)
assert mlst_profile.clonal_complex == "unknown" assert mlst_profile.clonal_complex == "unknown"
assert mlst_profile.sequence_type == "unknown" assert mlst_profile.sequence_type == "unknown"
@ -207,5 +209,6 @@ class TestBIGSdbIndex:
async with await bigsdb_index.build_profiler_from_seqdefdb(local, "pubmlst_bordetella_seqdef", 3) as profiler: async with await bigsdb_index.build_profiler_from_seqdefdb(local, "pubmlst_bordetella_seqdef", 3) as profiler:
assert isinstance(profiler, BIGSdbMLSTProfiler) assert isinstance(profiler, BIGSdbMLSTProfiler)
profile = await profiler.profile_string(sequence) profile = await profiler.profile_string(sequence)
assert isinstance(profile, MLSTProfile)
assert profile.clonal_complex == "ST-2 complex" assert profile.clonal_complex == "ST-2 complex"
assert profile.sequence_type == "1" assert profile.sequence_type == "1"