diff --git a/src/automlst/engine/remote/databases/bigsdb.py b/src/automlst/engine/remote/databases/bigsdb.py index 99baaf7..7bcd5d2 100644 --- a/src/automlst/engine/remote/databases/bigsdb.py +++ b/src/automlst/engine/remote/databases/bigsdb.py @@ -59,28 +59,29 @@ class BIGSdbMLSTProfiler(AbstractAsyncContextManager): - async def fetch_mlst_st(self, alleles: AsyncIterable[Allele]) -> MLSTProfile: + async def fetch_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile: uri_path = "designations" allele_request_dict: dict[str, list[dict[str, str]]] = defaultdict(list) - async for allele in alleles: - allele_request_dict[allele.allele_loci].append({"allele": str(allele.allele_variant)}) - + if isinstance(alleles, AsyncIterable): + async for allele in alleles: + allele_request_dict[allele.allele_loci].append({"allele": str(allele.allele_variant)}) + else: + for allele in alleles: + allele_request_dict[allele.allele_loci].append({"allele": str(allele.allele_variant)}) request_json = { "designations": allele_request_dict } async with self._http_client.post(uri_path, json=request_json) as response: response_json: dict = await response.json() - if "exact_matches" not in response_json: - raise NoBIGSdbExactMatchesException(self._database_name, self._schema_id) - schema_exact_matches: dict = response_json["exact_matches"] - response_json.setdefault("fields", dict) + allele_map: dict[str, list[Allele]] = defaultdict(list) + response_json.setdefault("fields", dict()) schema_fields_returned: dict[str, str] = response_json["fields"] schema_fields_returned.setdefault("ST", "unknown") schema_fields_returned.setdefault("clonal_complex", "unknown") - allele_map: dict[str, list[Allele]] = defaultdict(list) + schema_exact_matches: dict = response_json["exact_matches"] for exact_match_loci, exact_match_alleles in schema_exact_matches.items(): for exact_match_allele in exact_match_alleles: - allele_map[exact_match_loci].append(Allele(exact_match_loci, exact_match_allele["allele_id"])) + allele_map[exact_match_loci].append(Allele(exact_match_loci, exact_match_allele["allele_id"], None)) return MLSTProfile(allele_map, schema_fields_returned["ST"], schema_fields_returned["clonal_complex"]) async def profile_string(self, string: str, exact: bool = False) -> MLSTProfile: diff --git a/tests/automlst/engine/remote/databases/test_bigsdb.py b/tests/automlst/engine/remote/databases/test_bigsdb.py index 895f09d..0a2f581 100644 --- a/tests/automlst/engine/remote/databases/test_bigsdb.py +++ b/tests/automlst/engine/remote/databases/test_bigsdb.py @@ -30,7 +30,7 @@ async def test_institutpasteur_profiling_results_in_exact_matches_when_exact(): assert len(targets_left) == 0 async def test_institutpasteur_sequence_profiling_non_exact_returns_non_exact(): - sequences = SeqIO.parse("tests/resources/tohama_I_bpertussis_coding.fasta", "fasta") + sequences = list(SeqIO.parse("tests/resources/tohama_I_bpertussis_coding.fasta", "fasta")) mlst_targets = {"adk", "fumc", "glya", "tyrb", "icd", "pepa", "pgm"} async with BIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as profiler: for sequence in sequences: @@ -40,13 +40,12 @@ async def test_institutpasteur_sequence_profiling_non_exact_returns_non_exact(): gene = match.group(1) if gene.lower() not in mlst_targets: continue - scrambled = gene_scrambler(str(sequence.seq), 0.2) + scrambled = gene_scrambler(str(sequence.seq), 0.125) async for partial_match in profiler.fetch_mlst_allele_variants(scrambled, False): assert partial_match.partial_match_profile is not None - assert partial_match.allele_variant == '1' mlst_targets.remove(gene.lower()) - assert len(mlst_targets) == 0 + assert len(mlst_targets) == 0 async def test_institutpasteur_profiling_results_in_correct_mlst_st(): async def dummy_allele_generator(): @@ -68,6 +67,22 @@ async def test_institutpasteur_profiling_results_in_correct_mlst_st(): assert mlst_st_data.clonal_complex == "ST-2 complex" assert mlst_st_data.sequence_type == "1" +async def test_institutpasteur_profiling_non_exact_results_in_list_of_mlsts(): + dummy_alleles = [ + Allele("adk", "1", None), + Allele("fumC", "2", None), + Allele("glyA", "36", None), + Allele("tyrB", "4", None), + Allele("icd", "4", None), + Allele("pepA", "1", None), + Allele("pgm", "5", None), + ] + async with BIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler: + mlst_profile = await dummy_profiler.fetch_mlst_st(dummy_alleles) + assert mlst_profile.clonal_complex == "unknown" + assert mlst_profile.sequence_type == "unknown" + + async def test_institutpasteur_sequence_profiling_is_correct(): sequence = str(SeqIO.read("tests/resources/tohama_I_bpertussis.fasta", "fasta").seq) async with BIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler: