From f20a656f45446620a1a47c711bb5e2f46bb6761e Mon Sep 17 00:00:00 2001 From: Harrison Deng Date: Fri, 10 Jan 2025 16:00:27 +0000 Subject: [PATCH] Fixed multiple string typing failure handling --- src/automlst/engine/exceptions/database.py | 4 +-- .../engine/remote/databases/bigsdb.py | 6 ++-- .../engine/remote/databases/test_bigsdb.py | 32 +++++++++++++++---- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/src/automlst/engine/exceptions/database.py b/src/automlst/engine/exceptions/database.py index 8b549c8..c60190a 100644 --- a/src/automlst/engine/exceptions/database.py +++ b/src/automlst/engine/exceptions/database.py @@ -6,9 +6,9 @@ class BIGSDbDatabaseAPIException(Exception): class NoBIGSdbMatchesException(BIGSDbDatabaseAPIException): def __init__(self, database_name: str, database_schema_id: int, *args): - super().__init__(f"No exact match found with schema with ID {database_schema_id} in the database \"{database_name}\".", *args) + super().__init__(f"No matches found with schema with ID {database_schema_id} in the database \"{database_name}\".", *args) -class NoBIGSdbExactMatchesException(BIGSDbDatabaseAPIException): +class NoBIGSdbExactMatchesException(NoBIGSdbMatchesException): def __init__(self, database_name: str, database_schema_id: int, *args): super().__init__(f"No exact match found with schema with ID {database_schema_id} in the database \"{database_name}\".", *args) diff --git a/src/automlst/engine/remote/databases/bigsdb.py b/src/automlst/engine/remote/databases/bigsdb.py index 7bcd5d2..6cd877f 100644 --- a/src/automlst/engine/remote/databases/bigsdb.py +++ b/src/automlst/engine/remote/databases/bigsdb.py @@ -82,7 +82,9 @@ class BIGSdbMLSTProfiler(AbstractAsyncContextManager): for exact_match_loci, exact_match_alleles in schema_exact_matches.items(): for exact_match_allele in exact_match_alleles: allele_map[exact_match_loci].append(Allele(exact_match_loci, exact_match_allele["allele_id"], None)) - return MLSTProfile(allele_map, schema_fields_returned["ST"], schema_fields_returned["clonal_complex"]) + if len(allele_map) == 0: + raise ValueError("Passed in no alleles.") + return MLSTProfile(dict(allele_map), schema_fields_returned["ST"], schema_fields_returned["clonal_complex"]) async def profile_string(self, string: str, exact: bool = False) -> MLSTProfile: alleles = self.fetch_mlst_allele_variants(string, exact) @@ -93,7 +95,7 @@ class BIGSdbMLSTProfiler(AbstractAsyncContextManager): async for named_string in namedStrings: try: yield (named_string.name, await self.profile_string(named_string.sequence, exact)) - except NoBIGSdbExactMatchesException as e: + except NoBIGSdbMatchesException as e: if stop_on_fail: raise e yield (named_string.name, None) diff --git a/tests/automlst/engine/remote/databases/test_bigsdb.py b/tests/automlst/engine/remote/databases/test_bigsdb.py index 0a2f581..f649281 100644 --- a/tests/automlst/engine/remote/databases/test_bigsdb.py +++ b/tests/automlst/engine/remote/databases/test_bigsdb.py @@ -5,7 +5,7 @@ from Bio import SeqIO import pytest from automlst.engine.data.genomics import NamedString from automlst.engine.data.mlst import Allele, MLSTProfile -from automlst.engine.exceptions.database import NoBIGSdbExactMatchesException +from automlst.engine.exceptions.database import NoBIGSdbExactMatchesException, NoBIGSdbMatchesException from automlst.engine.remote.databases.bigsdb import BIGSdbIndex, BIGSdbMLSTProfiler def gene_scrambler(gene: str, mutation_site_count: Union[int, float], alphabet: Sequence[str] = ["A", "T", "C", "G"]): @@ -175,15 +175,14 @@ async def test_bigsdb_profile_multiple_strings_same_string_twice(): assert profile.clonal_complex == "ST-2 complex" assert profile.sequence_type == "1" -async def test_bigsdb_profile_multiple_strings_fail_second_no_stop(): +async def test_bigsdb_profile_multiple_strings_exactmatch_fail_second_no_stop(): valid_seq = str(SeqIO.read("tests/resources/tohama_I_bpertussis.fasta", "fasta").seq) - invalid_seq = str(SeqIO.read("tests/resources/FDAARGOS_1560.fasta", "fasta").seq) - dummy_sequences = [NamedString("seq1", valid_seq), NamedString("should_fail", invalid_seq), NamedString("seq3", valid_seq)] + dummy_sequences = [NamedString("seq1", valid_seq), NamedString("should_fail", gene_scrambler(valid_seq, 0.3)), NamedString("seq3", valid_seq)] async def generate_async_iterable_sequences(): for dummy_sequence in dummy_sequences: yield dummy_sequence async with BIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler: - async for name, profile in dummy_profiler.profile_multiple_strings(generate_async_iterable_sequences()): + async for name, profile in dummy_profiler.profile_multiple_strings(generate_async_iterable_sequences(), True): if name == "should_fail": assert profile is None else: @@ -192,6 +191,25 @@ async def test_bigsdb_profile_multiple_strings_fail_second_no_stop(): assert profile.clonal_complex == "ST-2 complex" assert profile.sequence_type == "1" +async def test_bigsdb_profile_multiple_strings_nonexact_second_no_stop(): + valid_seq = str(SeqIO.read("tests/resources/tohama_I_bpertussis.fasta", "fasta").seq) + dummy_sequences = [NamedString("seq1", valid_seq), NamedString("should_fail", gene_scrambler(valid_seq, 0.3)), NamedString("seq3", valid_seq)] + async def generate_async_iterable_sequences(): + for dummy_sequence in dummy_sequences: + yield dummy_sequence + async with BIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler: + async for name, profile in dummy_profiler.profile_multiple_strings(generate_async_iterable_sequences(), False): + if name == "should_fail": + assert profile is not None + assert profile.clonal_complex == "unknown" + assert profile.sequence_type == "unknown" + assert len(profile.alleles) > 0 + else: + assert profile is not None + assert isinstance(profile, MLSTProfile) + assert profile.clonal_complex == "ST-2 complex" + assert profile.sequence_type == "1" + async def test_bigsdb_profile_multiple_strings_fail_second_stop(): valid_seq = str(SeqIO.read("tests/resources/tohama_I_bpertussis.fasta", "fasta").seq) invalid_seq = str(SeqIO.read("tests/resources/FDAARGOS_1560.fasta", "fasta").seq) @@ -200,8 +218,8 @@ async def test_bigsdb_profile_multiple_strings_fail_second_stop(): for dummy_sequence in dummy_sequences: yield dummy_sequence async with BIGSdbMLSTProfiler(database_api="https://bigsdb.pasteur.fr/api", database_name="pubmlst_bordetella_seqdef", schema_id=3) as dummy_profiler: - with pytest.raises(NoBIGSdbExactMatchesException): - async for name, profile in dummy_profiler.profile_multiple_strings(generate_async_iterable_sequences(), stop_on_fail=True): + with pytest.raises(NoBIGSdbMatchesException): + async for name, profile in dummy_profiler.profile_multiple_strings(generate_async_iterable_sequences(), exact=True, stop_on_fail=True): if name == "should_fail": pytest.fail("Exception should have been thrown, no exception was thrown.") else: