From b8cebb8ba43ccab5ad5d62f9ef1c7cb5ceb2bf42 Mon Sep 17 00:00:00 2001 From: Harrison Deng Date: Wed, 19 Feb 2025 15:49:46 +0000 Subject: [PATCH] Infrastructure for concurrent processing implemented --- src/autobigs/engine/analysis/bigsdb.py | 68 +++++++++++++--------- src/autobigs/engine/exceptions/database.py | 6 +- src/autobigs/engine/reading.py | 5 +- 3 files changed, 50 insertions(+), 29 deletions(-) diff --git a/src/autobigs/engine/analysis/bigsdb.py b/src/autobigs/engine/analysis/bigsdb.py index 0b52ce9..d9e11e9 100644 --- a/src/autobigs/engine/analysis/bigsdb.py +++ b/src/autobigs/engine/analysis/bigsdb.py @@ -22,15 +22,15 @@ from Bio.Align import PairwiseAligner class BIGSdbMLSTProfiler(AbstractAsyncContextManager): @abstractmethod - def determine_mlst_allele_variants(self, query_sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]: + def determine_mlst_allele_variants(self, query_sequence_strings: Union[Iterable[Union[NamedString, str]], Union[NamedString, str]]) -> AsyncGenerator[Union[Allele, tuple[str, Allele]], Any]: pass @abstractmethod - async def determine_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile: + async def determine_mlst_st(self, alleles: Union[AsyncIterable[Union[Allele, tuple[str, Allele]]], Iterable[Union[Allele, tuple[str, Allele]]]]) -> Union[MLSTProfile, NamedMLSTProfile]: pass @abstractmethod - async def profile_string(self, query_sequence_strings: Iterable[str]) -> MLSTProfile: + async def profile_string(self, query_sequence_strings: Iterable[Union[NamedString, str]]) -> Union[NamedMLSTProfile, MLSTProfile]: pass @abstractmethod @@ -52,14 +52,14 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler): async def __aenter__(self): return self - async def determine_mlst_allele_variants(self, query_sequence_strings: Union[Iterable[str], str]) -> AsyncGenerator[Allele, Any]: + async def determine_mlst_allele_variants(self, query_sequence_strings: Union[Iterable[Union[NamedString, str]], Union[NamedString, str]]) -> AsyncGenerator[Union[Allele, tuple[str, Allele]], Any]: # See https://bigsdb.pasteur.fr/api/db/pubmlst_bordetella_seqdef/schemes uri_path = "sequence" - if isinstance(query_sequence_strings, str): + if isinstance(query_sequence_strings, str) or isinstance(query_sequence_strings, NamedString): query_sequence_strings = [query_sequence_strings] for sequence_string in query_sequence_strings: async with self._http_client.post(uri_path, json={ - "sequence": sequence_string, + "sequence": sequence_string if isinstance(sequence_string, str) else sequence_string.sequence, "partial_matches": True }) as response: sequence_response: dict = await response.json() @@ -70,7 +70,8 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler): for allele_loci, alleles in exact_matches.items(): for allele in alleles: alelle_id = allele["allele_id"] - yield Allele(allele_locus=allele_loci, allele_variant=alelle_id, partial_match_profile=None) + result_allele = Allele(allele_locus=allele_loci, allele_variant=alelle_id, partial_match_profile=None) + yield result_allele if isinstance(sequence_string, str) else (sequence_string.name, result_allele) elif "partial_matches" in sequence_response: partial_matches: dict[str, dict[str, Union[str, float, int]]] = sequence_response["partial_matches"] for allele_loci, partial_match in partial_matches.items(): @@ -82,23 +83,33 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler): gaps=int(partial_match["gaps"]), match_metric=int(partial_match["bitscore"]) ) - yield Allele( + result_allele = Allele( allele_locus=allele_loci, allele_variant=str(partial_match["allele"]), partial_match_profile=partial_match_profile ) + yield result_allele if isinstance(sequence_string, str) else (sequence_string.name, result_allele) else: - raise NoBIGSdbMatchesException(self._database_name, self._schema_id) + raise NoBIGSdbMatchesException(self._database_name, self._schema_id, sequence_string.name if isinstance(sequence_string, NamedString) else None) - async def determine_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile: + async def determine_mlst_st(self, alleles: Union[AsyncIterable[Union[Allele, tuple[str, Allele]]], Iterable[Union[Allele, tuple[str, Allele]]]]) -> Union[MLSTProfile, NamedMLSTProfile]: uri_path = "designations" allele_request_dict: dict[str, list[dict[str, str]]] = defaultdict(list) + names_list = [] + def insert_allele_to_request_dict(allele: Union[Allele, tuple[str, Allele]]): + if isinstance(allele, Allele): + allele_val = allele + else: + allele_val = allele[1] + names_list.append(allele[0]) + allele_request_dict[allele_val.allele_locus].append({"allele": str(allele_val.allele_variant)}) + if isinstance(alleles, AsyncIterable): async for allele in alleles: - allele_request_dict[allele.allele_locus].append({"allele": str(allele.allele_variant)}) + insert_allele_to_request_dict(allele) else: for allele in alleles: - allele_request_dict[allele.allele_locus].append({"allele": str(allele.allele_variant)}) + insert_allele_to_request_dict(allele) request_json = { "designations": allele_request_dict } @@ -111,30 +122,33 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler): schema_fields_returned.setdefault("clonal_complex", "unknown") schema_exact_matches: dict = response_json["exact_matches"] for exact_match_locus, exact_match_alleles in schema_exact_matches.items(): - if len(exact_match_alleles) > 1: - raise ValueError(f"Unexpected number of alleles returned for exact match (Expected 1, retrieved {len(exact_match_alleles)})") allele_set.add(Allele(exact_match_locus, exact_match_alleles[0]["allele_id"], None)) if len(allele_set) == 0: raise ValueError("Passed in no alleles.") - return MLSTProfile(allele_set, schema_fields_returned["ST"], schema_fields_returned["clonal_complex"]) + result_mlst_profile = MLSTProfile(allele_set, schema_fields_returned["ST"], schema_fields_returned["clonal_complex"]) + if len(names_list) > 0: + result_mlst_profile = NamedMLSTProfile(str(tuple(names_list)), result_mlst_profile) + return result_mlst_profile - async def profile_string(self, query_sequence_strings: Iterable[str]) -> MLSTProfile: + async def profile_string(self, query_sequence_strings: Iterable[Union[NamedString, str]]) -> Union[NamedMLSTProfile, MLSTProfile]: alleles = self.determine_mlst_allele_variants(query_sequence_strings) return await self.determine_mlst_st(alleles) async def profile_multiple_strings(self, query_named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]: + tasks = [] async for named_strings in query_named_string_groups: - names: list[str] = list() - sequences: list[str] = list() - for named_string in named_strings: - names.append(named_string.name) - sequences.append(named_string.sequence) - try: - yield NamedMLSTProfile("-".join(names), (await self.profile_string(sequences))) - except NoBIGSdbMatchesException as e: - if stop_on_fail: - raise e - yield NamedMLSTProfile("-".join(names), None) + tasks.append(self.profile_string(named_strings)) + for task in asyncio.as_completed(tasks): + try: + yield await task + except NoBIGSdbMatchesException as e: + if stop_on_fail: + raise e + causal_name = e.get_causal_query_name() + if causal_name is None: + raise ValueError("Missing query name despite requiring names.") + else: + yield NamedMLSTProfile(causal_name, None) async def close(self): await self._http_client.close() diff --git a/src/autobigs/engine/exceptions/database.py b/src/autobigs/engine/exceptions/database.py index c60190a..10787d2 100644 --- a/src/autobigs/engine/exceptions/database.py +++ b/src/autobigs/engine/exceptions/database.py @@ -5,8 +5,12 @@ class BIGSDbDatabaseAPIException(Exception): class NoBIGSdbMatchesException(BIGSDbDatabaseAPIException): - def __init__(self, database_name: str, database_schema_id: int, *args): + def __init__(self, database_name: str, database_schema_id: int, query_name: Union[None, str], *args): + self._query_name = query_name super().__init__(f"No matches found with schema with ID {database_schema_id} in the database \"{database_name}\".", *args) + + def get_causal_query_name(self) -> Union[str, None]: + return self._query_name class NoBIGSdbExactMatchesException(NoBIGSdbMatchesException): def __init__(self, database_name: str, database_schema_id: int, *args): diff --git a/src/autobigs/engine/reading.py b/src/autobigs/engine/reading.py index 6618427..9949da4 100644 --- a/src/autobigs/engine/reading.py +++ b/src/autobigs/engine/reading.py @@ -13,5 +13,8 @@ async def read_fasta(handle: Union[str, TextIOWrapper]) -> Iterable[NamedString] return results async def read_multiple_fastas(handles: Iterable[Union[str, TextIOWrapper]]) -> AsyncGenerator[Iterable[NamedString], Any]: + tasks = [] for handle in handles: - yield await read_fasta(handle) \ No newline at end of file + tasks.append(read_fasta(handle)) + for task in asyncio.as_completed(tasks): + yield await task \ No newline at end of file