Fixed concurrent profile_multiple_strings implementation
All checks were successful
autoBIGS.engine/pipeline/head This commit looks good

This commit is contained in:
Harrison Deng 2025-02-26 05:16:24 +00:00
parent 27ae89fde7
commit 4b34036d17
2 changed files with 20 additions and 13 deletions

View File

@@ -7,7 +7,7 @@ from os import path
import os import os
import shutil import shutil
import tempfile import tempfile
from typing import Any, AsyncGenerator, AsyncIterable, Iterable, Mapping, Sequence, Set, Union from typing import Any, AsyncGenerator, AsyncIterable, Coroutine, Iterable, Mapping, Sequence, Set, Union
from aiohttp import ClientSession, ClientTimeout from aiohttp import ClientSession, ClientTimeout
@@ -135,12 +135,16 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
return await self.determine_mlst_st(alleles) return await self.determine_mlst_st(alleles)
async def profile_multiple_strings(self, query_named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]: async def profile_multiple_strings(self, query_named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
tasks = [] tasks: list[Coroutine[Any, Any, Union[NamedMLSTProfile, MLSTProfile]]] = []
async for named_strings in query_named_string_groups: async for named_strings in query_named_string_groups:
tasks.append(self.profile_string(named_strings)) tasks.append(self.profile_string(named_strings))
for task in asyncio.as_completed(tasks): for task in asyncio.as_completed(tasks):
named_mlst_profile = await task
try: try:
yield await task if isinstance(named_mlst_profile, NamedMLSTProfile):
yield named_mlst_profile
else:
raise TypeError("MLST profile is not named.")
except NoBIGSdbMatchesException as e: except NoBIGSdbMatchesException as e:
if stop_on_fail: if stop_on_fail:
raise e raise e

View File

@@ -102,6 +102,7 @@ class TestBIGSdbMLSTProfiler:
continue continue
scrambled = gene_scrambler(str(target_sequence.seq), 0.125) scrambled = gene_scrambler(str(target_sequence.seq), 0.125)
async for partial_match in profiler.determine_mlst_allele_variants([scrambled]): async for partial_match in profiler.determine_mlst_allele_variants([scrambled]):
assert isinstance(partial_match, Allele)
assert partial_match.partial_match_profile is not None assert partial_match.partial_match_profile is not None
mlst_targets.remove(gene) mlst_targets.remove(gene)
@@ -119,6 +120,7 @@ class TestBIGSdbMLSTProfiler:
dummy_alleles = bad_profile.alleles dummy_alleles = bad_profile.alleles
async with bigsdb.get_BIGSdb_MLST_profiler(local_db, database_api, database_name, scheme_id) as dummy_profiler: async with bigsdb.get_BIGSdb_MLST_profiler(local_db, database_api, database_name, scheme_id) as dummy_profiler:
mlst_profile = await dummy_profiler.determine_mlst_st(dummy_alleles) mlst_profile = await dummy_profiler.determine_mlst_st(dummy_alleles)
assert isinstance(mlst_profile, MLSTProfile)
assert mlst_profile.clonal_complex == "unknown" assert mlst_profile.clonal_complex == "unknown"
assert mlst_profile.sequence_type == "unknown" assert mlst_profile.sequence_type == "unknown"
@@ -207,5 +209,6 @@ class TestBIGSdbIndex:
async with await bigsdb_index.build_profiler_from_seqdefdb(local, "pubmlst_bordetella_seqdef", 3) as profiler: async with await bigsdb_index.build_profiler_from_seqdefdb(local, "pubmlst_bordetella_seqdef", 3) as profiler:
assert isinstance(profiler, BIGSdbMLSTProfiler) assert isinstance(profiler, BIGSdbMLSTProfiler)
profile = await profiler.profile_string(sequence) profile = await profiler.profile_string(sequence)
assert isinstance(profile, MLSTProfile)
assert profile.clonal_complex == "ST-2 complex" assert profile.clonal_complex == "ST-2 complex"
assert profile.sequence_type == "1" assert profile.sequence_type == "1"