diff --git a/Jenkinsfile b/Jenkinsfile
index fb4051e..d46f6b1 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -49,7 +49,7 @@ pipeline {
             steps {
                 sh returnStatus: true, script: 'python -m twine upload -u __token__ -p ${TOKEN} --non-interactive --disable-progress-bar --verbose dist/*'
             }
-        }
+        }-
         }
     }
 }
diff --git a/pyproject.toml b/pyproject.toml
index 89a9a00..f13a2f1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,8 @@ requires-python = ">=3.12"
 description = "A library to rapidly fetch fetch MLST profiles given sequences for various diseases."
 
 [project.urls]
-Repository = "https://github.com/RealYHD/autoBIGS.engine"
+Homepage = "https://github.com/RealYHD/autoBIGS.engine"
+Source = "https://github.com/RealYHD/autoBIGS.engine"
 Issues = "https://github.com/RealYHD/autoBIGS.engine/issues"
 
 [tool.setuptools_scm]
diff --git a/src/autobigs/engine/data/remote/databases/bigsdb.py b/src/autobigs/engine/data/remote/databases/bigsdb.py
index 9c81195..f7ba79a 100644
--- a/src/autobigs/engine/data/remote/databases/bigsdb.py
+++ b/src/autobigs/engine/data/remote/databases/bigsdb.py
@@ -1,16 +1,43 @@
+from abc import abstractmethod
 from collections import defaultdict
 from contextlib import AbstractAsyncContextManager
-from numbers import Number
-from typing import Any, AsyncGenerator, AsyncIterable, Collection, Generator, Iterable, Mapping, Sequence, Union
+import csv
+from os import path
+from typing import Any, AsyncGenerator, AsyncIterable, Iterable, Mapping, Sequence, Union
 from aiohttp import ClientSession, ClientTimeout
+from autobigs.engine.data.local.fasta import read_fasta
 from autobigs.engine.data.structures.genomics import NamedString
-from autobigs.engine.data.structures.mlst import Allele, PartialAllelicMatchProfile, MLSTProfile
+from autobigs.engine.data.structures.mlst import Allele, NamedMLSTProfile, PartialAllelicMatchProfile, MLSTProfile
 from autobigs.engine.exceptions.database import NoBIGSdbExactMatchesException, NoBIGSdbMatchesException, NoSuchBIGSdbDatabaseException
+from Bio.Align import PairwiseAligner
 
 class BIGSdbMLSTProfiler(AbstractAsyncContextManager):
 
+    @abstractmethod
+    def fetch_mlst_allele_variants(self, sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
+        pass
+
+    @abstractmethod
+    async def fetch_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile:
+        pass
+
+    @abstractmethod
+    async def profile_string(self, sequence_strings: Iterable[str]) -> MLSTProfile:
+        pass
+
+    @abstractmethod
+    def profile_multiple_strings(self, named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
+        pass
+
+    @abstractmethod
+    async def close(self):
+        pass
+
+class OnlineBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
+
     def __init__(self, database_api: str, database_name: str, schema_id: int):
         self._database_name = database_name
         self._schema_id = schema_id
@@ -20,86 +47,193 @@ class BIGSdbMLSTProfiler(AbstractAsyncContextManager):
     async def __aenter__(self):
         return self
 
-    async def fetch_mlst_allele_variants(self, sequence_string: str, exact: bool) -> AsyncGenerator[Allele, Any]:
+    async def fetch_mlst_allele_variants(self, sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
         # See https://bigsdb.pasteur.fr/api/db/pubmlst_bordetella_seqdef/schemes
         uri_path = "sequence"
-        async with self._http_client.post(uri_path, json={
-            "sequence": sequence_string,
-            "partial_matches": not exact
-        }) as response:
-            sequence_response: dict = await response.json()
-
-            if "exact_matches" in sequence_response:
-                # loci -> list of alleles with id and loci
-                exact_matches: dict[str, Sequence[dict[str, str]]] = sequence_response["exact_matches"]
-                for allele_loci, alleles in exact_matches.items():
-                    for allele in alleles:
-                        alelle_id = allele["allele_id"]
-                        yield Allele(allele_loci=allele_loci, allele_variant=alelle_id, partial_match_profile=None)
-            elif "partial_matches" in sequence_response:
-                if exact:
-                    raise NoBIGSdbExactMatchesException(self._database_name, self._schema_id)
-                partial_matches: dict[str, dict[str, Union[str, float, int]]] = sequence_response["partial_matches"]
-                for allele_loci, partial_match in partial_matches.items():
-                    if len(partial_match) <= 0:
-                        continue
-                    partial_match_profile = PartialAllelicMatchProfile(
-                        percent_identity=float(partial_match["identity"]),
-                        mismatches=int(partial_match["mismatches"]),
-                        bitscore=float(partial_match["bitscore"]),
-                        gaps=int(partial_match["gaps"])
-                    )
-                    yield Allele(
-                        allele_loci=allele_loci,
-                        allele_variant=str(partial_match["allele"]),
-                        partial_match_profile=partial_match_profile
-                    )
-            else:
-                raise NoBIGSdbMatchesException(self._database_name, self._schema_id)
-
+        for sequence_string in sequence_strings:
+            async with self._http_client.post(uri_path, json={
+                "sequence": sequence_string,
+                "partial_matches": True
+            }) as response:
+                sequence_response: dict = await response.json()
+                if "exact_matches" in sequence_response:
+                    # loci -> list of alleles with id and loci
+                    exact_matches: dict[str, Sequence[dict[str, str]]] = sequence_response["exact_matches"]
+                    for allele_loci, alleles in exact_matches.items():
+                        for allele in alleles:
+                            alelle_id = allele["allele_id"]
+                            yield Allele(allele_locus=allele_loci, allele_variant=alelle_id, partial_match_profile=None)
+                elif "partial_matches" in sequence_response:
+                    partial_matches: dict[str, dict[str, Union[str, float, int]]] = sequence_response["partial_matches"]
+                    for allele_loci, partial_match in partial_matches.items():
+                        if len(partial_match) <= 0:
+                            continue
+                        partial_match_profile = PartialAllelicMatchProfile(
+                            percent_identity=float(partial_match["identity"]),
+                            mismatches=int(partial_match["mismatches"]),
+                            gaps=int(partial_match["gaps"])
+                        )
+                        yield Allele(
+                            allele_locus=allele_loci,
+                            allele_variant=str(partial_match["allele"]),
+                            partial_match_profile=partial_match_profile
+                        )
+                else:
+                    raise NoBIGSdbMatchesException(self._database_name, self._schema_id)
 
     async def fetch_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile:
         uri_path = "designations"
         allele_request_dict: dict[str, list[dict[str, str]]] = defaultdict(list)
         if isinstance(alleles, AsyncIterable):
             async for allele in alleles:
-                allele_request_dict[allele.allele_loci].append({"allele": str(allele.allele_variant)})
+                allele_request_dict[allele.allele_locus].append({"allele": str(allele.allele_variant)})
         else:
             for allele in alleles:
-                allele_request_dict[allele.allele_loci].append({"allele": str(allele.allele_variant)})
+                allele_request_dict[allele.allele_locus].append({"allele": str(allele.allele_variant)})
         request_json = {
             "designations": allele_request_dict
         }
         async with self._http_client.post(uri_path, json=request_json) as response:
             response_json: dict = await response.json()
-            allele_map: dict[str, list[Allele]] = defaultdict(list)
+            allele_map: dict[str, Allele] = {}
             response_json.setdefault("fields", dict())
             schema_fields_returned: dict[str, str] = response_json["fields"]
             schema_fields_returned.setdefault("ST", "unknown")
             schema_fields_returned.setdefault("clonal_complex", "unknown")
             schema_exact_matches: dict = response_json["exact_matches"]
-            for exact_match_loci, exact_match_alleles in schema_exact_matches.items():
-                for exact_match_allele in exact_match_alleles:
-                    allele_map[exact_match_loci].append(Allele(exact_match_loci, exact_match_allele["allele_id"], None))
+            for exact_match_locus, exact_match_alleles in schema_exact_matches.items():
+                if len(exact_match_alleles) > 1:
+                    raise ValueError(f"Unexpected number of alleles returned for exact match (Expected 1, retrieved {len(exact_match_alleles)})")
+                allele_map[exact_match_locus] = Allele(exact_match_locus, exact_match_alleles[0]["allele_id"], None)
             if len(allele_map) == 0:
                 raise ValueError("Passed in no alleles.")
             return MLSTProfile(dict(allele_map), schema_fields_returned["ST"], schema_fields_returned["clonal_complex"])
 
-    async def profile_string(self, string: str, exact: bool = False) -> MLSTProfile:
-        alleles = self.fetch_mlst_allele_variants(string, exact)
+    async def profile_string(self, sequence_strings: Iterable[str]) -> MLSTProfile:
+        alleles = self.fetch_mlst_allele_variants(sequence_strings)
         return await self.fetch_mlst_st(alleles)
 
+    async def profile_multiple_strings(self, named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
+        async for named_strings in named_string_groups:
+            for named_string in named_strings:
+                try:
+                    yield NamedMLSTProfile(named_string.name, (await self.profile_string([named_string.sequence])))
+                except NoBIGSdbMatchesException as e:
+                    if stop_on_fail:
+                        raise e
+                    yield NamedMLSTProfile(named_string.name, None)
 
-    async def profile_multiple_strings(self, namedStrings: AsyncIterable[NamedString], exact: bool = False, stop_on_fail: bool = False) -> AsyncGenerator[tuple[str, Union[MLSTProfile, None]], Any]:
-        async for named_string in namedStrings:
-            try:
-                yield (named_string.name, await self.profile_string(named_string.sequence, exact))
-            except NoBIGSdbMatchesException as e:
-                if stop_on_fail:
-                    raise e
-                yield (named_string.name, None)
+    async def close(self):
+        await self._http_client.close()
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        await self.close()
+
+class LazyPersistentCachedBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
+
+    def __init__(self, database_api: str, database_name: str, schema_id: int, cache_path: str):
+        self._database_api = database_api
+        self._database_name = database_name
+        self._schema_id = schema_id
+        self._base_url = f"{database_api}/db/{self._database_name}/schemes/{self._schema_id}/"
+        self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(10000))
+        self._cache_path = cache_path
+        self._loci: list[str] = []
+        self._profiles = {}
+
+    async def load_scheme_locis(self):
+        self._loci.clear()
+        async with self._http_client.get("") as schema_response:
+            schema_json = await schema_response.json()
+            for locus in schema_json["loci"]:
+                locus_name = path.basename(locus)
+                self._loci.append(locus_name)
+        self._loci.sort()
+
+    async def load_scheme_profiles(self):
+        self._profiles.clear()
+        with open(self.get_scheme_profile_path()) as profile_cache_handle:
+            reader = csv.DictReader(profile_cache_handle, delimiter="\t")
+            for line in reader:
+                alleles = []
+                for locus in self._loci:
+                    alleles.append(line[locus])
+                self._profiles[tuple(alleles)] = (line["ST"], line["clonal_complex"])
+
+    def get_locus_cache_path(self, locus) -> str:
+        return path.join(self._cache_path, locus + "." + "fasta")
+
+    def get_scheme_profile_path(self):
+        return path.join(self._cache_path, "profiles.csv")
+
+    async def download_alleles_cache_data(self):
+        for locus in self._loci:
+            with open(self.get_locus_cache_path(locus), "wb") as fasta_handle:
+                async with self._http_client.get(f"/db/{self._database_name}/loci/{locus}/alleles_fasta") as fasta_response:
+                    async for chunk, eof in fasta_response.content.iter_chunks(): # TODO maybe allow chunking to be configurable
+                        fasta_handle.write(chunk)
+
+    async def download_scheme_profiles(self):
+        with open(self.get_scheme_profile_path(), "wb") as profile_cache_handle:
+            async with self._http_client.get("profiles_csv") as profiles_response:
+                async for chunk, eof in profiles_response.content.iter_chunks():
+                    profile_cache_handle.write(chunk)
+
+    async def fetch_mlst_allele_variants(self, sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
+        aligner = PairwiseAligner("blastn")
+        aligner.mode = "local"
+        for sequence_string in sequence_strings:
+            for locus in self._loci:
+                async for fasta_seq in read_fasta(self.get_locus_cache_path(locus)):
+                    allele_variant = fasta_seq.name
+                    alignment_results = aligner.align(sequence_string, fasta_seq.sequence)
+                    top_alignment = sorted(alignment_results)[0]
+                    top_alignment_stats = top_alignment.counts()
+                    top_alignment_gaps = top_alignment_stats.gaps
+                    top_alignment_identities = top_alignment_stats.identities
+                    top_alignment_mismatches = top_alignment_stats.mismatches
+                    if top_alignment_gaps == 0 and top_alignment_mismatches == 0:
+                        yield Allele(locus, allele_variant, None)
+                    else:
+                        yield Allele(
+                            locus,
+                            allele_variant,
+                            PartialAllelicMatchProfile(
+                                percent_identity=top_alignment_identities/top_alignment.length,
+                                mismatches=top_alignment_mismatches,
+                                gaps=top_alignment_gaps
+                            )
+                        )
+
+    async def fetch_mlst_st(self, alleles):
+        allele_variants: dict[str, Allele] = {}
+        if isinstance(alleles, AsyncIterable):
+            async for allele in alleles:
+                allele_variants[allele.allele_locus] = allele
+        else:
+            for allele in alleles:
+                allele_variants[allele.allele_locus] = allele
+        ordered_profile = []
+        for locus in self._loci:
+            ordered_profile.append(allele_variants[locus].allele_variant)
+
+        st, clonal_complex = self._profiles[tuple(ordered_profile)]
+        return MLSTProfile(allele_variants, st, clonal_complex)
+
+    async def profile_string(self, sequence_strings: Iterable[str]) -> MLSTProfile:
+        alleles = self.fetch_mlst_allele_variants(sequence_strings)
+        return await self.fetch_mlst_st(alleles)
+
+    async def profile_multiple_strings(self, named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
+        async for named_strings in named_string_groups:
+            for named_string in named_strings:
+                try:
+                    yield NamedMLSTProfile(named_string.name, await self.profile_string([named_string.sequence]))
+                except NoBIGSdbMatchesException as e:
+                    if stop_on_fail:
+                        raise e
+                    yield NamedMLSTProfile(named_string.name, None)
 
     async def close(self):
         await self._http_client.close()
@@ -156,8 +290,8 @@ class BIGSdbIndex(AbstractAsyncContextManager):
         self._seqdefdb_schemas[seqdef_db_name] = schema_descriptions
         return self._seqdefdb_schemas[seqdef_db_name] # type: ignore
 
-    async def build_profiler_from_seqdefdb(self, dbseqdef_name: str, schema_id: int) -> BIGSdbMLSTProfiler:
-        return BIGSdbMLSTProfiler(await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name), dbseqdef_name, schema_id)
+    async def build_profiler_from_seqdefdb(self, dbseqdef_name: str, schema_id: int) -> OnlineBIGSdbMLSTProfiler:
+        return OnlineBIGSdbMLSTProfiler(await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name), dbseqdef_name, schema_id)
 
     async def close(self):
         await self._http_client.close()
diff --git a/src/autobigs/engine/data/structures/mlst.py b/src/autobigs/engine/data/structures/mlst.py
index 1c52d74..88b173c 100644
--- a/src/autobigs/engine/data/structures/mlst.py
+++ b/src/autobigs/engine/data/structures/mlst.py
@@ -5,17 +5,21 @@ from typing import Mapping, Sequence, Union
 class PartialAllelicMatchProfile:
     percent_identity: float
     mismatches: int
-    bitscore: float
     gaps: int
 
 @dataclass(frozen=True)
 class Allele:
-    allele_loci: str
+    allele_locus: str
     allele_variant: str
     partial_match_profile: Union[None, PartialAllelicMatchProfile]
 
 @dataclass(frozen=True)
 class MLSTProfile:
-    alleles: Mapping[str, Sequence[Allele]]
+    alleles: Mapping[str, Allele]
     sequence_type: str
     clonal_complex: str
+
+@dataclass(frozen=True)
+class NamedMLSTProfile:
+    name: str
+    mlst_profile: Union[None, MLSTProfile]
\ No newline at end of file
diff --git a/tests/autobigs/engine/data/remote/databases/test_bigsdb.py b/tests/autobigs/engine/data/remote/databases/test_bigsdb.py
index f7cc7f2..74414a8 100644
--- a/tests/autobigs/engine/data/remote/databases/test_bigsdb.py
+++ b/tests/autobigs/engine/data/remote/databases/test_bigsdb.py
@@ -6,7 +6,7 @@ import pytest
 from autobigs.engine.data.structures.genomics import NamedString
 from autobigs.engine.data.structures.mlst import Allele, MLSTProfile
 from autobigs.engine.exceptions.database import NoBIGSdbExactMatchesException, NoBIGSdbMatchesException
-from autobigs.engine.data.remote.databases.bigsdb import BIGSdbIndex, BIGSdbMLSTProfiler
+from autobigs.engine.data.remote.databases.bigsdb import BIGSdbIndex, OnlineBIGSdbMLSTProfiler
 
 def gene_scrambler(gene: str, mutation_site_count: Union[int, float], alphabet: Sequence[str] = ["A", "T", "C", "G"]):
     rand = random.Random(gene)
@@ -20,19 +20,19 @@ def gene_scrambler(gene: str, mutation_site_count: Union[int, float], alphabet:
 async def test_institutpasteur_profiling_results_in_exact_matches_when_exact():
     sequence = str(SeqIO.read("tests/resources/tohama_I_bpertussis.fasta", "fasta").seq)
-    async with BIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler:
+    async with OnlineBIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler:
         targets_left = {"adk", "fumC", "glyA", "tyrB", "icd", "pepA", "pgm"}
-        async for exact_match in dummy_profiler.fetch_mlst_allele_variants(sequence_string=sequence, exact=True):
+        async for exact_match in dummy_profiler.fetch_mlst_allele_variants(sequence_strings=[sequence]):
             assert isinstance(exact_match, Allele)
             assert exact_match.allele_variant == '1' # All of Tohama I has allele id I
-            targets_left.remove(exact_match.allele_loci)
+            targets_left.remove(exact_match.allele_locus)
 
         assert len(targets_left) == 0
 
 
 async def test_institutpasteur_sequence_profiling_non_exact_returns_non_exact():
     sequences = list(SeqIO.parse("tests/resources/tohama_I_bpertussis_coding.fasta", "fasta"))
     mlst_targets = {"adk", "fumc", "glya", "tyrb", "icd", "pepa", "pgm"}
-    async with BIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as profiler:
+    async with OnlineBIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as profiler:
         for sequence in sequences:
             match = re.fullmatch(r".*\[gene=([\w\d]+)\].*", sequence.description)
             if match is None:
@@ -41,7 +41,7 @@ async def test_institutpasteur_sequence_profiling_non_exact_returns_non_exact():
             if gene.lower() not in mlst_targets:
                 continue
             scrambled = gene_scrambler(str(sequence.seq), 0.125)
-            async for partial_match in profiler.fetch_mlst_allele_variants(scrambled, False):
+            async for partial_match in profiler.fetch_mlst_allele_variants(scrambled):
                 assert partial_match.partial_match_profile is not None
                 mlst_targets.remove(gene.lower())
 
@@ -60,7 +60,7 @@ async def test_institutpasteur_profiling_results_in_correct_mlst_st():
         ]
         for dummy_allele in dummy_alleles:
             yield dummy_allele
-    async with BIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler:
+    async with OnlineBIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler:
         mlst_st_data = await dummy_profiler.fetch_mlst_st(dummy_allele_generator())
         assert mlst_st_data is not None
         assert isinstance(mlst_st_data, MLSTProfile)
@@ -77,7 +77,7 @@ async def test_institutpasteur_profiling_non_exact_results_in_list_of_mlsts():
         Allele("pepA", "1", None),
         Allele("pgm", "5", None),
     ]
-    async with BIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler:
+    async with OnlineBIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler:
         mlst_profile = await dummy_profiler.fetch_mlst_st(dummy_alleles)
         assert mlst_profile.clonal_complex == "unknown"
         assert mlst_profile.sequence_type == "unknown"
@@ -85,7 +85,7 @@ async def test_institutpasteur_profiling_non_exact_results_in_list_of_mlsts():
 
 async def test_institutpasteur_sequence_profiling_is_correct():
     sequence = str(SeqIO.read("tests/resources/tohama_I_bpertussis.fasta", "fasta").seq)
-    async with BIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler:
+    async with OnlineBIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler:
         profile = await dummy_profiler.profile_string(sequence)
         assert profile is not None
         assert isinstance(profile, MLSTProfile)
@@ -104,8 +104,8 @@ async def test_pubmlst_profiling_results_in_exact_matches_when_exact():
         Allele("recA", "5", None),
     }
     sequence = str(SeqIO.read("tests/resources/FDAARGOS_1560.fasta", "fasta").seq)
-    async with BIGSdbMLSTProfiler(database_api="https://rest.pubmlst.org/", database_name="pubmlst_hinfluenzae_seqdef", schema_id=1) as dummy_profiler:
-        exact_matches = dummy_profiler.fetch_mlst_allele_variants(sequence_string=sequence, exact=True)
+    async with OnlineBIGSdbMLSTProfiler(database_api="https://rest.pubmlst.org/", database_name="pubmlst_hinfluenzae_seqdef", schema_id=1) as dummy_profiler:
+        exact_matches = dummy_profiler.fetch_mlst_allele_variants(sequence_strings=sequence)
         async for exact_match in exact_matches:
             assert isinstance(exact_match, Allele)
             dummy_alleles.remove(exact_match)
@@ -125,7 +125,7 @@ async def test_pubmlst_profiling_results_in_correct_st():
         ]
         for dummy_allele in dummy_alleles:
             yield dummy_allele
-    async with BIGSdbMLSTProfiler(database_api="https://rest.pubmlst.org/", database_name="pubmlst_hinfluenzae_seqdef", schema_id=1) as dummy_profiler:
+    async with OnlineBIGSdbMLSTProfiler(database_api="https://rest.pubmlst.org/", database_name="pubmlst_hinfluenzae_seqdef", schema_id=1) as dummy_profiler:
         mlst_st_data = await dummy_profiler.fetch_mlst_st(generate_dummy_targets())
         assert mlst_st_data is not None
         assert isinstance(mlst_st_data, MLSTProfile)
@@ -134,7 +134,7 @@ async def test_pubmlst_profiling_results_in_correct_st():
 
 async def test_pubmlst_sequence_profiling_is_correct():
     sequence = str(SeqIO.read("tests/resources/FDAARGOS_1560.fasta", "fasta").seq)
-    async with BIGSdbMLSTProfiler(database_api="https://rest.pubmlst.org/", database_name="pubmlst_hinfluenzae_seqdef", schema_id=1) as dummy_profiler:
+    async with OnlineBIGSdbMLSTProfiler(database_api="https://rest.pubmlst.org/", database_name="pubmlst_hinfluenzae_seqdef", schema_id=1) as dummy_profiler:
         profile = await dummy_profiler.profile_string(sequence)
         assert profile is not None
         assert isinstance(profile, MLSTProfile)
@@ -167,9 +167,10 @@ async def test_bigsdb_profile_multiple_strings_same_string_twice():
     dummy_sequences = [NamedString("seq1", sequence), NamedString("seq2", sequence)]
     async def generate_async_iterable_sequences():
         for dummy_sequence in dummy_sequences:
-            yield dummy_sequence
-    async with BIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler:
-        async for name, profile in dummy_profiler.profile_multiple_strings(generate_async_iterable_sequences()):
+            yield [dummy_sequence]
+    async with OnlineBIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler:
+        async for named_profile in dummy_profiler.profile_multiple_strings(generate_async_iterable_sequences()):
+            name, profile = named_profile.name, named_profile.mlst_profile
             assert profile is not None
             assert isinstance(profile, MLSTProfile)
             assert profile.clonal_complex == "ST-2 complex"
@@ -180,9 +181,11 @@ async def test_bigsdb_profile_multiple_strings_exactmatch_fail_second_no_stop():
     dummy_sequences = [NamedString("seq1", valid_seq), NamedString("should_fail", gene_scrambler(valid_seq, 0.3)), NamedString("seq3", valid_seq)]
     async def generate_async_iterable_sequences():
         for dummy_sequence in dummy_sequences:
-            yield dummy_sequence
-    async with BIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler:
-        async for name, profile in dummy_profiler.profile_multiple_strings(generate_async_iterable_sequences(), True):
+            yield [dummy_sequence]
+    async with OnlineBIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler:
+        async for name_profile in dummy_profiler.profile_multiple_strings(generate_async_iterable_sequences(), True):
+            name, profile = name_profile.name, name_profile.mlst_profile
+
             if name == "should_fail":
                 assert profile is None
             else:
@@ -196,9 +199,10 @@ async def test_bigsdb_profile_multiple_strings_nonexact_second_no_stop():
     dummy_sequences = [NamedString("seq1", valid_seq), NamedString("should_fail", gene_scrambler(valid_seq, 0.3)), NamedString("seq3", valid_seq)]
     async def generate_async_iterable_sequences():
         for dummy_sequence in dummy_sequences:
-            yield dummy_sequence
-    async with BIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler:
-        async for name, profile in dummy_profiler.profile_multiple_strings(generate_async_iterable_sequences(), False):
+            yield [dummy_sequence]
+    async with OnlineBIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler:
+        async for named_profile in dummy_profiler.profile_multiple_strings(generate_async_iterable_sequences(), False):
+            name, profile = named_profile.name, named_profile.mlst_profile
             if name == "should_fail":
                 assert profile is not None
                 assert profile.clonal_complex == "unknown"
@@ -216,10 +220,11 @@ async def test_bigsdb_profile_multiple_strings_fail_second_stop():
     dummy_sequences = [NamedString("seq1", valid_seq), NamedString("should_fail", invalid_seq), NamedString("seq3", valid_seq)]
     async def generate_async_iterable_sequences():
        for dummy_sequence in dummy_sequences:
-            yield dummy_sequence
-    async with BIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler:
+            yield [dummy_sequence]
+    async with OnlineBIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler:
         with pytest.raises(NoBIGSdbMatchesException):
-            async for name, profile in dummy_profiler.profile_multiple_strings(generate_async_iterable_sequences(), exact=True, stop_on_fail=True):
+            async for named_profile in dummy_profiler.profile_multiple_strings(generate_async_iterable_sequences(), stop_on_fail=True):
+                name, profile = named_profile.name, named_profile.mlst_profile
                 if name == "should_fail":
                     pytest.fail("Exception should have been thrown, no exception was thrown.")
                 else:
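
Usage note (illustrative only, not part of the patch): a minimal sketch of how the refactored asynchronous API could be driven end to end. It assumes BIGSdbIndex can be constructed with no arguments and uses a placeholder sequence string; the grouped-input shape and the NamedMLSTProfile results mirror the updated tests above.

import asyncio

from autobigs.engine.data.remote.databases.bigsdb import BIGSdbIndex
from autobigs.engine.data.structures.genomics import NamedString


async def main():
    # Assumption: BIGSdbIndex takes no constructor arguments; adjust if the real signature differs.
    async with BIGSdbIndex() as index:
        # build_profiler_from_seqdefdb now returns an OnlineBIGSdbMLSTProfiler.
        async with await index.build_profiler_from_seqdefdb("pubmlst_bordetella_seqdef", 3) as profiler:
            # profile_string now takes an iterable of sequence strings ("ACGT..." is a placeholder).
            profile = await profiler.profile_string(["ACGT..."])
            print(profile.sequence_type, profile.clonal_complex)

            # profile_multiple_strings now consumes groups of NamedStrings and yields NamedMLSTProfile.
            async def groups():
                yield [NamedString("example", "ACGT...")]

            async for named_profile in profiler.profile_multiple_strings(groups()):
                print(named_profile.name, named_profile.mlst_profile)


asyncio.run(main())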