Added automatic handling of strings instead of arrays of sequences to typing

This commit is contained in:
Harrison Deng 2025-02-06 21:15:50 +00:00
parent fe999f1cab
commit b4845fab34
2 changed files with 14 additions and 10 deletions

View File

@ -48,17 +48,16 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
self._database_name = database_name self._database_name = database_name
self._schema_id = schema_id self._schema_id = schema_id
self._base_url = f"{database_api}/db/{self._database_name}/schemes/{self._schema_id}/" self._base_url = f"{database_api}/db/{self._database_name}/schemes/{self._schema_id}/"
self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(10000)) self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(60))
async def __aenter__(self): async def __aenter__(self):
return self return self
async def determine_mlst_allele_variants(self, query_sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]: async def determine_mlst_allele_variants(self, query_sequence_strings: Union[Iterable[str], str]) -> AsyncGenerator[Allele, Any]:
# See https://bigsdb.pasteur.fr/api/db/pubmlst_bordetella_seqdef/schemes # See https://bigsdb.pasteur.fr/api/db/pubmlst_bordetella_seqdef/schemes
uri_path = "sequence" uri_path = "sequence"
if not isinstance(query_sequence_strings, Iterable): if isinstance(query_sequence_strings, str):
raise ValueError("Invalid data type for parameter \"sequence_strings\".") query_sequence_strings = [query_sequence_strings]
for sequence_string in query_sequence_strings: for sequence_string in query_sequence_strings:
async with self._http_client.post(uri_path, json={ async with self._http_client.post(uri_path, json={
"sequence": sequence_string, "sequence": sequence_string,
@ -156,7 +155,7 @@ class LocalBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
self._database_name = database_name self._database_name = database_name
self._schema_id = schema_id self._schema_id = schema_id
self._base_url = f"{self._database_api}/db/{self._database_name}/schemes/{self._schema_id}/" self._base_url = f"{self._database_api}/db/{self._database_name}/schemes/{self._schema_id}/"
self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(10000)) self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(60))
if cache_path is None: if cache_path is None:
self._cache_path = tempfile.mkdtemp("BIGSdb") self._cache_path = tempfile.mkdtemp("BIGSdb")
self._cleanup_required = True self._cleanup_required = True
@ -324,8 +323,8 @@ class BIGSdbIndex(AbstractAsyncContextManager):
self._seqdefdb_schemas[seqdef_db_name] = schema_descriptions self._seqdefdb_schemas[seqdef_db_name] = schema_descriptions
return self._seqdefdb_schemas[seqdef_db_name] # type: ignore return self._seqdefdb_schemas[seqdef_db_name] # type: ignore
async def build_profiler_from_seqdefdb(self, dbseqdef_name: str, schema_id: int) -> RemoteBIGSdbMLSTProfiler: async def build_profiler_from_seqdefdb(self, local: bool, dbseqdef_name: str, schema_id: int) -> BIGSdbMLSTProfiler:
return RemoteBIGSdbMLSTProfiler(await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name), dbseqdef_name, schema_id) return get_BIGSdb_MLST_profiler(local, await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name), dbseqdef_name, schema_id)
async def close(self): async def close(self):
await self._http_client.close() await self._http_client.close()

View File

@ -201,10 +201,15 @@ class TestBIGSdbIndex:
assert database_name.endswith("seqdef") assert database_name.endswith("seqdef")
assert databases["pubmlst_bordetella_seqdef"] == "https://bigsdb.pasteur.fr/api" assert databases["pubmlst_bordetella_seqdef"] == "https://bigsdb.pasteur.fr/api"
async def test_bigsdb_index_instantiates_correct_profiler(self): @pytest.mark.parametrize("local", [
(True),
(False)
])
async def test_bigsdb_index_instantiates_correct_profiler(self, local):
sequence = str(SeqIO.read("tests/resources/tohama_I_bpertussis.fasta", "fasta").seq) sequence = str(SeqIO.read("tests/resources/tohama_I_bpertussis.fasta", "fasta").seq)
async with BIGSdbIndex() as bigsdb_index: async with BIGSdbIndex() as bigsdb_index:
async with await bigsdb_index.build_profiler_from_seqdefdb("pubmlst_bordetella_seqdef", 3) as profiler: async with await bigsdb_index.build_profiler_from_seqdefdb(local, "pubmlst_bordetella_seqdef", 3) as profiler:
assert isinstance(profiler, BIGSdbMLSTProfiler)
profile = await profiler.profile_string(sequence) profile = await profiler.profile_string(sequence)
assert profile.clonal_complex == "ST-2 complex" assert profile.clonal_complex == "ST-2 complex"
assert profile.sequence_type == "1" assert profile.sequence_type == "1"