Infrastructure for concurrent processing implemented
All checks were successful
autoBIGS.engine/pipeline/head: This commit looks good

Harrison Deng 2025-02-19 15:49:46 +00:00
parent 7384895578
commit b8cebb8ba4
3 changed files with 50 additions and 29 deletions


@@ -22,15 +22,15 @@ from Bio.Align import PairwiseAligner
 class BIGSdbMLSTProfiler(AbstractAsyncContextManager):

     @abstractmethod
-    def determine_mlst_allele_variants(self, query_sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
+    def determine_mlst_allele_variants(self, query_sequence_strings: Union[Iterable[Union[NamedString, str]], Union[NamedString, str]]) -> AsyncGenerator[Union[Allele, tuple[str, Allele]], Any]:
         pass

     @abstractmethod
-    async def determine_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile:
+    async def determine_mlst_st(self, alleles: Union[AsyncIterable[Union[Allele, tuple[str, Allele]]], Iterable[Union[Allele, tuple[str, Allele]]]]) -> Union[MLSTProfile, NamedMLSTProfile]:
         pass

     @abstractmethod
-    async def profile_string(self, query_sequence_strings: Iterable[str]) -> MLSTProfile:
+    async def profile_string(self, query_sequence_strings: Iterable[Union[NamedString, str]]) -> Union[NamedMLSTProfile, MLSTProfile]:
         pass

     @abstractmethod
@@ -52,14 +52,14 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
     async def __aenter__(self):
         return self

-    async def determine_mlst_allele_variants(self, query_sequence_strings: Union[Iterable[str], str]) -> AsyncGenerator[Allele, Any]:
+    async def determine_mlst_allele_variants(self, query_sequence_strings: Union[Iterable[Union[NamedString, str]], Union[NamedString, str]]) -> AsyncGenerator[Union[Allele, tuple[str, Allele]], Any]:
         # See https://bigsdb.pasteur.fr/api/db/pubmlst_bordetella_seqdef/schemes
         uri_path = "sequence"
-        if isinstance(query_sequence_strings, str):
+        if isinstance(query_sequence_strings, str) or isinstance(query_sequence_strings, NamedString):
             query_sequence_strings = [query_sequence_strings]
         for sequence_string in query_sequence_strings:
             async with self._http_client.post(uri_path, json={
-                "sequence": sequence_string,
+                "sequence": sequence_string if isinstance(sequence_string, str) else sequence_string.sequence,
                 "partial_matches": True
             }) as response:
                 sequence_response: dict = await response.json()
@@ -70,7 +70,8 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
                     for allele_loci, alleles in exact_matches.items():
                         for allele in alleles:
                             alelle_id = allele["allele_id"]
-                            yield Allele(allele_locus=allele_loci, allele_variant=alelle_id, partial_match_profile=None)
+                            result_allele = Allele(allele_locus=allele_loci, allele_variant=alelle_id, partial_match_profile=None)
+                            yield result_allele if isinstance(sequence_string, str) else (sequence_string.name, result_allele)
                 elif "partial_matches" in sequence_response:
                     partial_matches: dict[str, dict[str, Union[str, float, int]]] = sequence_response["partial_matches"]
                     for allele_loci, partial_match in partial_matches.items():
@@ -82,23 +83,33 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
                             gaps=int(partial_match["gaps"]),
                             match_metric=int(partial_match["bitscore"])
                         )
-                        yield Allele(
+                        result_allele = Allele(
                             allele_locus=allele_loci,
                             allele_variant=str(partial_match["allele"]),
                             partial_match_profile=partial_match_profile
                         )
+                        yield result_allele if isinstance(sequence_string, str) else (sequence_string.name, result_allele)
                 else:
-                    raise NoBIGSdbMatchesException(self._database_name, self._schema_id)
+                    raise NoBIGSdbMatchesException(self._database_name, self._schema_id, sequence_string.name if isinstance(sequence_string, NamedString) else None)

-    async def determine_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile:
+    async def determine_mlst_st(self, alleles: Union[AsyncIterable[Union[Allele, tuple[str, Allele]]], Iterable[Union[Allele, tuple[str, Allele]]]]) -> Union[MLSTProfile, NamedMLSTProfile]:
         uri_path = "designations"
         allele_request_dict: dict[str, list[dict[str, str]]] = defaultdict(list)
+        names_list = []
+        def insert_allele_to_request_dict(allele: Union[Allele, tuple[str, Allele]]):
+            if isinstance(allele, Allele):
+                allele_val = allele
+            else:
+                allele_val = allele[1]
+                names_list.append(allele[0])
+            allele_request_dict[allele_val.allele_locus].append({"allele": str(allele_val.allele_variant)})
+
         if isinstance(alleles, AsyncIterable):
             async for allele in alleles:
-                allele_request_dict[allele.allele_locus].append({"allele": str(allele.allele_variant)})
+                insert_allele_to_request_dict(allele)
         else:
             for allele in alleles:
-                allele_request_dict[allele.allele_locus].append({"allele": str(allele.allele_variant)})
+                insert_allele_to_request_dict(allele)
         request_json = {
             "designations": allele_request_dict
         }
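Note on the hunk above: determine_mlst_st now accepts either bare Allele objects or (name, Allele) pairs, collecting the names separately while it builds the designation payload. The following is a stripped-down, runnable sketch of that splitting logic only; the Allele dataclass is a stand-in for the engine's own type, collect_designations is an illustrative helper name rather than code from this repository, and the locus names and variants are made-up example values.

from collections import defaultdict
from dataclasses import dataclass
from typing import Union

@dataclass(frozen=True)
class Allele:
    # Stand-in for the engine's Allele structure (locus and variant only).
    allele_locus: str
    allele_variant: str

def collect_designations(alleles: list[Union[Allele, tuple[str, Allele]]]):
    names_list: list[str] = []
    allele_request_dict: dict[str, list[dict[str, str]]] = defaultdict(list)
    for allele in alleles:
        if isinstance(allele, Allele):
            allele_val = allele
        else:
            # Named input: remember which query this allele came from.
            names_list.append(allele[0])
            allele_val = allele[1]
        allele_request_dict[allele_val.allele_locus].append({"allele": str(allele_val.allele_variant)})
    return names_list, dict(allele_request_dict)

names, designations = collect_designations([("seq1", Allele("adk", "1")), ("seq1", Allele("fumC", "4"))])
# names == ["seq1", "seq1"]
# designations == {"adk": [{"allele": "1"}], "fumC": [{"allele": "4"}]}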
@@ -111,30 +122,33 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
             schema_fields_returned.setdefault("clonal_complex", "unknown")
             schema_exact_matches: dict = response_json["exact_matches"]
             for exact_match_locus, exact_match_alleles in schema_exact_matches.items():
+                if len(exact_match_alleles) > 1:
+                    raise ValueError(f"Unexpected number of alleles returned for exact match (Expected 1, retrieved {len(exact_match_alleles)})")
                 allele_set.add(Allele(exact_match_locus, exact_match_alleles[0]["allele_id"], None))
             if len(allele_set) == 0:
                 raise ValueError("Passed in no alleles.")
-            return MLSTProfile(allele_set, schema_fields_returned["ST"], schema_fields_returned["clonal_complex"])
+            result_mlst_profile = MLSTProfile(allele_set, schema_fields_returned["ST"], schema_fields_returned["clonal_complex"])
+            if len(names_list) > 0:
+                result_mlst_profile = NamedMLSTProfile(str(tuple(names_list)), result_mlst_profile)
+            return result_mlst_profile

-    async def profile_string(self, query_sequence_strings: Iterable[str]) -> MLSTProfile:
+    async def profile_string(self, query_sequence_strings: Iterable[Union[NamedString, str]]) -> Union[NamedMLSTProfile, MLSTProfile]:
         alleles = self.determine_mlst_allele_variants(query_sequence_strings)
         return await self.determine_mlst_st(alleles)

     async def profile_multiple_strings(self, query_named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
+        tasks = []
         async for named_strings in query_named_string_groups:
-            names: list[str] = list()
-            sequences: list[str] = list()
-            for named_string in named_strings:
-                names.append(named_string.name)
-                sequences.append(named_string.sequence)
-            try:
-                yield NamedMLSTProfile("-".join(names), (await self.profile_string(sequences)))
-            except NoBIGSdbMatchesException as e:
-                if stop_on_fail:
-                    raise e
-                yield NamedMLSTProfile("-".join(names), None)
+            tasks.append(self.profile_string(named_strings))
+        for task in asyncio.as_completed(tasks):
+            try:
+                yield await task
+            except NoBIGSdbMatchesException as e:
+                if stop_on_fail:
+                    raise e
+                causal_name = e.get_causal_query_name()
+                if causal_name is None:
+                    raise ValueError("Missing query name despite requiring names.")
+                else:
+                    yield NamedMLSTProfile(causal_name, None)

     async def close(self):
         await self._http_client.close()
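profile_multiple_strings above now queues one profile_string coroutine per input group and drains them with asyncio.as_completed, so profiles come back in completion order and a NoBIGSdbMatchesException only aborts the run when stop_on_fail is set. Below is a minimal, self-contained sketch of that error-tolerant fan-out pattern; profile, profile_all, NoMatch, and the sample inputs are placeholders for illustration, not names from the engine.

import asyncio

class NoMatch(Exception):
    """Placeholder for NoBIGSdbMatchesException."""

async def profile(name: str) -> str:
    # Placeholder for a network-bound profile_string call.
    await asyncio.sleep(0.01)
    if name == "unmatched":
        raise NoMatch(name)
    return f"profile for {name}"

async def profile_all(names: list[str], stop_on_fail: bool = False) -> list[str]:
    results = []
    tasks = [profile(name) for name in names]
    # Results arrive in completion order, not submission order;
    # failures are handled per task instead of ending the whole run.
    for task in asyncio.as_completed(tasks):
        try:
            results.append(await task)
        except NoMatch as e:
            if stop_on_fail:
                raise
            results.append(f"no matches for {e}")
    return results

print(asyncio.run(profile_all(["sample_a", "unmatched", "sample_b"])))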


@@ -5,8 +5,12 @@ class BIGSDbDatabaseAPIException(Exception):
 class NoBIGSdbMatchesException(BIGSDbDatabaseAPIException):
-    def __init__(self, database_name: str, database_schema_id: int, *args):
+    def __init__(self, database_name: str, database_schema_id: int, query_name: Union[None, str], *args):
+        self._query_name = query_name
         super().__init__(f"No matches found with schema with ID {database_schema_id} in the database \"{database_name}\".", *args)

+    def get_causal_query_name(self) -> Union[str, None]:
+        return self._query_name
+
 class NoBIGSdbExactMatchesException(NoBIGSdbMatchesException):
     def __init__(self, database_name: str, database_schema_id: int, *args):
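NoBIGSdbMatchesException now optionally carries the name of the query that produced no matches, which profile_multiple_strings uses to yield a NamedMLSTProfile with a None profile instead of losing track of the failed input. A small round-trip sketch built from the signatures shown in this diff; the database name, schema ID, and query name below are illustrative values only.

from typing import Union

class BIGSDbDatabaseAPIException(Exception):
    pass

class NoBIGSdbMatchesException(BIGSDbDatabaseAPIException):
    def __init__(self, database_name: str, database_schema_id: int, query_name: Union[None, str], *args):
        self._query_name = query_name
        super().__init__(f"No matches found with schema with ID {database_schema_id} in the database \"{database_name}\".", *args)

    def get_causal_query_name(self) -> Union[str, None]:
        return self._query_name

try:
    raise NoBIGSdbMatchesException("pubmlst_example_seqdef", 1, "sample_01")
except NoBIGSdbMatchesException as e:
    # The caller can now report which query failed, not just that one did.
    print(e, "caused by:", e.get_causal_query_name())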


@@ -13,5 +13,8 @@ async def read_fasta(handle: Union[str, TextIOWrapper]) -> Iterable[NamedString]:
     return results

 async def read_multiple_fastas(handles: Iterable[Union[str, TextIOWrapper]]) -> AsyncGenerator[Iterable[NamedString], Any]:
+    tasks = []
     for handle in handles:
-        yield await read_fasta(handle)
+        tasks.append(read_fasta(handle))
+    for task in asyncio.as_completed(tasks):
+        yield await task
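read_multiple_fastas now starts every read_fasta call up front and yields each file's records as its parse finishes, via asyncio.as_completed, rather than strictly in input order. A self-contained sketch of that async-generator shape; read_records, read_many, and the paths are stand-ins for the Biopython-backed read_fasta and real FASTA handles.

import asyncio
from typing import Any, AsyncGenerator, Iterable

async def read_records(path: str) -> list[str]:
    # Stand-in for read_fasta: parse one file and return its records.
    await asyncio.sleep(0.01)
    return [f"{path}: record"]

async def read_many(paths: Iterable[str]) -> AsyncGenerator[list[str], Any]:
    tasks = []
    for path in paths:
        tasks.append(read_records(path))
    # Yield each file's records as soon as its read completes.
    for task in asyncio.as_completed(tasks):
        yield await task

async def main() -> None:
    async for records in read_many(["a.fasta", "b.fasta"]):
        print(records)

asyncio.run(main())

One consequence of this design is that downstream consumers can no longer assume results arrive in the same order as the input handles.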