From b4845fab348180f4ec45438b129bc4e7165d160e Mon Sep 17 00:00:00 2001 From: Harrison Deng Date: Thu, 6 Feb 2025 21:15:50 +0000 Subject: [PATCH] Added automatic handling of strings instead of arrays of sequences to typing --- src/autobigs/engine/analysis/bigsdb.py | 15 +++++++-------- tests/autobigs/engine/analysis/test_bigsdb.py | 9 +++++++-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/autobigs/engine/analysis/bigsdb.py b/src/autobigs/engine/analysis/bigsdb.py index 0685da2..f96be17 100644 --- a/src/autobigs/engine/analysis/bigsdb.py +++ b/src/autobigs/engine/analysis/bigsdb.py @@ -48,17 +48,16 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler): self._database_name = database_name self._schema_id = schema_id self._base_url = f"{database_api}/db/{self._database_name}/schemes/{self._schema_id}/" - self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(10000)) + self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(60)) async def __aenter__(self): return self - async def determine_mlst_allele_variants(self, query_sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]: + async def determine_mlst_allele_variants(self, query_sequence_strings: Union[Iterable[str], str]) -> AsyncGenerator[Allele, Any]: # See https://bigsdb.pasteur.fr/api/db/pubmlst_bordetella_seqdef/schemes uri_path = "sequence" - if not isinstance(query_sequence_strings, Iterable): - raise ValueError("Invalid data type for parameter \"sequence_strings\".") - + if isinstance(query_sequence_strings, str): + query_sequence_strings = [query_sequence_strings] for sequence_string in query_sequence_strings: async with self._http_client.post(uri_path, json={ "sequence": sequence_string, @@ -156,7 +155,7 @@ class LocalBIGSdbMLSTProfiler(BIGSdbMLSTProfiler): self._database_name = database_name self._schema_id = schema_id self._base_url = f"{self._database_api}/db/{self._database_name}/schemes/{self._schema_id}/" - self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(10000)) + self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(60)) if cache_path is None: self._cache_path = tempfile.mkdtemp("BIGSdb") self._cleanup_required = True @@ -324,8 +323,8 @@ class BIGSdbIndex(AbstractAsyncContextManager): self._seqdefdb_schemas[seqdef_db_name] = schema_descriptions return self._seqdefdb_schemas[seqdef_db_name] # type: ignore - async def build_profiler_from_seqdefdb(self, dbseqdef_name: str, schema_id: int) -> RemoteBIGSdbMLSTProfiler: - return RemoteBIGSdbMLSTProfiler(await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name), dbseqdef_name, schema_id) + async def build_profiler_from_seqdefdb(self, local: bool, dbseqdef_name: str, schema_id: int) -> BIGSdbMLSTProfiler: + return get_BIGSdb_MLST_profiler(local, await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name), dbseqdef_name, schema_id) async def close(self): await self._http_client.close() diff --git a/tests/autobigs/engine/analysis/test_bigsdb.py b/tests/autobigs/engine/analysis/test_bigsdb.py index 8d0bc15..1984ca7 100644 --- a/tests/autobigs/engine/analysis/test_bigsdb.py +++ b/tests/autobigs/engine/analysis/test_bigsdb.py @@ -201,10 +201,15 @@ class TestBIGSdbIndex: assert database_name.endswith("seqdef") assert databases["pubmlst_bordetella_seqdef"] == "https://bigsdb.pasteur.fr/api" - async def test_bigsdb_index_instantiates_correct_profiler(self): + @pytest.mark.parametrize("local", [ + (True), + (False) + ]) + async def test_bigsdb_index_instantiates_correct_profiler(self, local): sequence = str(SeqIO.read("tests/resources/tohama_I_bpertussis.fasta", "fasta").seq) async with BIGSdbIndex() as bigsdb_index: - async with await bigsdb_index.build_profiler_from_seqdefdb("pubmlst_bordetella_seqdef", 3) as profiler: + async with await bigsdb_index.build_profiler_from_seqdefdb(local, "pubmlst_bordetella_seqdef", 3) as profiler: + assert isinstance(profiler, BIGSdbMLSTProfiler) profile = await profiler.profile_string(sequence) assert profile.clonal_complex == "ST-2 complex" assert profile.sequence_type == "1"