From 4b34036d17fec4988a28a50bc8fa7b09d883c38a Mon Sep 17 00:00:00 2001 From: Harrison Deng Date: Wed, 26 Feb 2025 05:16:24 +0000 Subject: [PATCH] Fixed concurrent profile_multiple_strings implementation --- src/autobigs/engine/analysis/bigsdb.py | 30 +++++++++++-------- tests/autobigs/engine/analysis/test_bigsdb.py | 3 ++ 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/autobigs/engine/analysis/bigsdb.py b/src/autobigs/engine/analysis/bigsdb.py index d186753..1195c30 100644 --- a/src/autobigs/engine/analysis/bigsdb.py +++ b/src/autobigs/engine/analysis/bigsdb.py @@ -7,7 +7,7 @@ from os import path import os import shutil import tempfile -from typing import Any, AsyncGenerator, AsyncIterable, Iterable, Mapping, Sequence, Set, Union +from typing import Any, AsyncGenerator, AsyncIterable, Coroutine, Iterable, Mapping, Sequence, Set, Union from aiohttp import ClientSession, ClientTimeout @@ -135,20 +135,24 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler): return await self.determine_mlst_st(alleles) async def profile_multiple_strings(self, query_named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]: - tasks = [] + tasks: list[Coroutine[Any, Any, Union[NamedMLSTProfile, MLSTProfile]]] = [] async for named_strings in query_named_string_groups: tasks.append(self.profile_string(named_strings)) - for task in asyncio.as_completed(tasks): - try: - yield await task - except NoBIGSdbMatchesException as e: - if stop_on_fail: - raise e - causal_name = e.get_causal_query_name() - if causal_name is None: - raise ValueError("Missing query name despite requiring names.") - else: - yield NamedMLSTProfile(causal_name, None) + for task in asyncio.as_completed(tasks): + named_mlst_profile = await task + try: + if isinstance(named_mlst_profile, NamedMLSTProfile): + yield named_mlst_profile + else: + raise TypeError("MLST profile is not named.") + except NoBIGSdbMatchesException as e: + if stop_on_fail: + raise e + causal_name = e.get_causal_query_name() + if causal_name is None: + raise ValueError("Missing query name despite requiring names.") + else: + yield NamedMLSTProfile(causal_name, None) async def close(self): await self._http_client.close() diff --git a/tests/autobigs/engine/analysis/test_bigsdb.py b/tests/autobigs/engine/analysis/test_bigsdb.py index ed01fd3..233e311 100644 --- a/tests/autobigs/engine/analysis/test_bigsdb.py +++ b/tests/autobigs/engine/analysis/test_bigsdb.py @@ -102,6 +102,7 @@ class TestBIGSdbMLSTProfiler: continue scrambled = gene_scrambler(str(target_sequence.seq), 0.125) async for partial_match in profiler.determine_mlst_allele_variants([scrambled]): + assert isinstance(partial_match, Allele) assert partial_match.partial_match_profile is not None mlst_targets.remove(gene) @@ -119,6 +120,7 @@ class TestBIGSdbMLSTProfiler: dummy_alleles = bad_profile.alleles async with bigsdb.get_BIGSdb_MLST_profiler(local_db, database_api, database_name, scheme_id) as dummy_profiler: mlst_profile = await dummy_profiler.determine_mlst_st(dummy_alleles) + assert isinstance(mlst_profile, MLSTProfile) assert mlst_profile.clonal_complex == "unknown" assert mlst_profile.sequence_type == "unknown" @@ -207,5 +209,6 @@ class TestBIGSdbIndex: async with await bigsdb_index.build_profiler_from_seqdefdb(local, "pubmlst_bordetella_seqdef", 3) as profiler: assert isinstance(profiler, BIGSdbMLSTProfiler) profile = await profiler.profile_string(sequence) + assert isinstance(profile, MLSTProfile) assert profile.clonal_complex == "ST-2 complex" assert profile.sequence_type == "1"