Compare commits
	
		
			9 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 29fcf8c176 | |||
| 8264242fa5 | |||
| f8d92a4aad | |||
| 34bf02c75a | |||
| 3cb10a4609 | |||
| 1776f5aa51 | |||
| 96d715fdcb | |||
| e088d1080b | |||
| 8ffc7c7fb5 | 
							
								
								
									
										3
									
								
								.vscode/extensions.json
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.vscode/extensions.json
									
									
									
									
										vendored
									
									
								
							| @@ -1,5 +1,6 @@ | |||||||
| { | { | ||||||
|     "recommendations": [ |     "recommendations": [ | ||||||
|         "piotrpalarz.vscode-gitignore-generator" |         "piotrpalarz.vscode-gitignore-generator", | ||||||
|  |         "gruntfuggly.todo-tree" | ||||||
|     ] |     ] | ||||||
| } | } | ||||||
| @@ -9,13 +9,13 @@ import shutil | |||||||
| import tempfile | import tempfile | ||||||
| from typing import Any, AsyncGenerator, AsyncIterable, Coroutine, Iterable, Mapping, Sequence, Set, Union | from typing import Any, AsyncGenerator, AsyncIterable, Coroutine, Iterable, Mapping, Sequence, Set, Union | ||||||
|  |  | ||||||
| from aiohttp import ClientSession, ClientTimeout | from aiohttp import ClientOSError, ClientSession, ClientTimeout, ServerDisconnectedError | ||||||
|  |  | ||||||
| from autobigs.engine.reading import read_fasta | from autobigs.engine.reading import read_fasta | ||||||
| from autobigs.engine.structures.alignment import PairwiseAlignment | from autobigs.engine.structures.alignment import PairwiseAlignment | ||||||
| from autobigs.engine.structures.genomics import NamedString | from autobigs.engine.structures.genomics import NamedString | ||||||
| from autobigs.engine.structures.mlst import Allele, NamedMLSTProfile, AlignmentStats, MLSTProfile | from autobigs.engine.structures.mlst import Allele, NamedMLSTProfile, AlignmentStats, MLSTProfile | ||||||
| from autobigs.engine.exceptions.database import NoBIGSdbExactMatchesException, NoBIGSdbMatchesException, NoSuchBIGSdbDatabaseException | from autobigs.engine.exceptions.database import BIGSdbResponseNotOkay, NoBIGSdbExactMatchesException, NoBIGSdbMatchesException, NoSuchBIGSdbDatabaseException | ||||||
|  |  | ||||||
| from Bio.Align import PairwiseAligner | from Bio.Align import PairwiseAligner | ||||||
|  |  | ||||||
| @@ -43,11 +43,12 @@ class BIGSdbMLSTProfiler(AbstractAsyncContextManager): | |||||||
|  |  | ||||||
| class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler): | class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler): | ||||||
|  |  | ||||||
|     def __init__(self, database_api: str, database_name: str, scheme_id: int): |     def __init__(self, database_api: str, database_name: str, scheme_id: int, retry_requests: int = 5): | ||||||
|  |         self._retry_limit = retry_requests | ||||||
|         self._database_name = database_name |         self._database_name = database_name | ||||||
|         self._scheme_id = scheme_id |         self._scheme_id = scheme_id | ||||||
|         self._base_url = f"{database_api}/db/{self._database_name}/schemes/{self._scheme_id}/" |         self._base_url = f"{database_api}/db/{self._database_name}/schemes/{self._scheme_id}/" | ||||||
|         self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(60)) |         self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(300)) | ||||||
|  |  | ||||||
|     async def __aenter__(self): |     async def __aenter__(self): | ||||||
|         return self |         return self | ||||||
| @@ -57,40 +58,62 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler): | |||||||
|         uri_path = "sequence" |         uri_path = "sequence" | ||||||
|         if isinstance(query_sequence_strings, str) or isinstance(query_sequence_strings, NamedString): |         if isinstance(query_sequence_strings, str) or isinstance(query_sequence_strings, NamedString): | ||||||
|             query_sequence_strings = [query_sequence_strings] |             query_sequence_strings = [query_sequence_strings] | ||||||
|         for sequence_string in query_sequence_strings: |  | ||||||
|             async with self._http_client.post(uri_path, json={ |  | ||||||
|                 "sequence": sequence_string if isinstance(sequence_string, str) else sequence_string.sequence, |  | ||||||
|                 "partial_matches": True |  | ||||||
|             }) as response: |  | ||||||
|                 sequence_response: dict = await response.json() |  | ||||||
|          |          | ||||||
|                 if "exact_matches" in sequence_response: |         for sequence_string in query_sequence_strings: | ||||||
|                     # loci -> list of alleles with id and loci |             attempts = 0 | ||||||
|                     exact_matches: dict[str, Sequence[dict[str, str]]] = sequence_response["exact_matches"]   |             success = False | ||||||
|                     for allele_loci, alleles in exact_matches.items(): |             last_error = None | ||||||
|                         for allele in alleles: |             while not success and attempts < self._retry_limit: | ||||||
|                             alelle_id = allele["allele_id"] |                 attempts += 1 | ||||||
|                             result_allele = Allele(allele_locus=allele_loci, allele_variant=alelle_id, partial_match_profile=None) |                 request = self._http_client.post(uri_path, json={ | ||||||
|                             yield result_allele if isinstance(sequence_string, str) else (sequence_string.name, result_allele) |                     "sequence": sequence_string if isinstance(sequence_string, str) else sequence_string.sequence, | ||||||
|                 elif "partial_matches" in sequence_response: |                     "partial_matches": True | ||||||
|                     partial_matches: dict[str, dict[str, Union[str, float, int]]] = sequence_response["partial_matches"]  |                 }) | ||||||
|                     for allele_loci, partial_match in partial_matches.items(): |                 try: | ||||||
|                         if len(partial_match) <= 0: |                     async with request as response: | ||||||
|                             continue |                         sequence_response: dict = await response.json() | ||||||
|                         partial_match_profile = AlignmentStats( |  | ||||||
|                             percent_identity=float(partial_match["identity"]), |                         if "exact_matches" in sequence_response: | ||||||
|                             mismatches=int(partial_match["mismatches"]), |                             # loci -> list of alleles with id and loci | ||||||
|                             gaps=int(partial_match["gaps"]), |                             exact_matches: dict[str, Sequence[dict[str, str]]] = sequence_response["exact_matches"]   | ||||||
|                             match_metric=int(partial_match["bitscore"]) |                             for allele_loci, alleles in exact_matches.items(): | ||||||
|                         ) |                                 for allele in alleles: | ||||||
|                         result_allele = Allele( |                                     alelle_id = allele["allele_id"] | ||||||
|                             allele_locus=allele_loci, |                                     result_allele = Allele(allele_locus=allele_loci, allele_variant=alelle_id, partial_match_profile=None) | ||||||
|                             allele_variant=str(partial_match["allele"]), |                                     yield result_allele if isinstance(sequence_string, str) else (sequence_string.name, result_allele) | ||||||
|                             partial_match_profile=partial_match_profile |                         elif "partial_matches" in sequence_response: | ||||||
|                         ) |                             partial_matches: dict[str, dict[str, Union[str, float, int]]] = sequence_response["partial_matches"]  | ||||||
|                         yield result_allele if isinstance(sequence_string, str) else (sequence_string.name, result_allele) |                             for allele_loci, partial_match in partial_matches.items(): | ||||||
|  |                                 if len(partial_match) <= 0: | ||||||
|  |                                     continue | ||||||
|  |                                 partial_match_profile = AlignmentStats( | ||||||
|  |                                     percent_identity=float(partial_match["identity"]), | ||||||
|  |                                     mismatches=int(partial_match["mismatches"]), | ||||||
|  |                                     gaps=int(partial_match["gaps"]), | ||||||
|  |                                     match_metric=int(partial_match["bitscore"]) | ||||||
|  |                                 ) | ||||||
|  |                                 result_allele = Allele( | ||||||
|  |                                     allele_locus=allele_loci, | ||||||
|  |                                     allele_variant=str(partial_match["allele"]), | ||||||
|  |                                     partial_match_profile=partial_match_profile | ||||||
|  |                                 ) | ||||||
|  |                                 yield result_allele if isinstance(sequence_string, str) else (sequence_string.name, result_allele) | ||||||
|  |                         else: | ||||||
|  |                             if response.status == 200: | ||||||
|  |                                 raise NoBIGSdbMatchesException(self._database_name, self._scheme_id, sequence_string.name if isinstance(sequence_string, NamedString) else None) | ||||||
|  |                             else: | ||||||
|  |                                 raise BIGSdbResponseNotOkay(sequence_response) | ||||||
|  |                 except (ConnectionError, ServerDisconnectedError, ClientOSError) as e: # Errors we will retry | ||||||
|  |                     last_error = e | ||||||
|  |                     success = False | ||||||
|  |                     await asyncio.sleep(5) # In case the connection issue is due to rate issues | ||||||
|                 else: |                 else: | ||||||
|                     raise NoBIGSdbMatchesException(self._database_name, self._scheme_id, sequence_string.name if isinstance(sequence_string, NamedString) else None) |                     success = True | ||||||
|  |             if not success and last_error is not None: | ||||||
|  |                 try: | ||||||
|  |                     raise last_error | ||||||
|  |                 except (ConnectionError, ServerDisconnectedError, ClientOSError) as e: # Non-fatal errors | ||||||
|  |                     yield Allele("error", "error", None) | ||||||
|  |  | ||||||
|     async def determine_mlst_st(self, alleles: Union[AsyncIterable[Union[Allele, tuple[str, Allele]]], Iterable[Union[Allele, tuple[str, Allele]]]]) -> Union[MLSTProfile, NamedMLSTProfile]: |     async def determine_mlst_st(self, alleles: Union[AsyncIterable[Union[Allele, tuple[str, Allele]]], Iterable[Union[Allele, tuple[str, Allele]]]]) -> Union[MLSTProfile, NamedMLSTProfile]: | ||||||
|         uri_path = "designations" |         uri_path = "designations" | ||||||
| @@ -113,22 +136,42 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler): | |||||||
|         request_json = { |         request_json = { | ||||||
|             "designations": allele_request_dict |             "designations": allele_request_dict | ||||||
|         } |         } | ||||||
|         async with self._http_client.post(uri_path, json=request_json) as response: |  | ||||||
|             response_json: dict = await response.json() |         attempts = 0 | ||||||
|             allele_set: Set[Allele] = set() |         success = False | ||||||
|             response_json.setdefault("fields", dict()) |         last_error = None | ||||||
|             scheme_fields_returned: dict[str, str] = response_json["fields"] |         while attempts < self._retry_limit and not success: | ||||||
|             scheme_fields_returned.setdefault("ST", "unknown") |             attempts += 1 | ||||||
|             scheme_fields_returned.setdefault("clonal_complex", "unknown") |             try: | ||||||
|             scheme_exact_matches: dict = response_json["exact_matches"] |                 async with self._http_client.post(uri_path, json=request_json) as response: | ||||||
|             for exact_match_locus, exact_match_alleles in scheme_exact_matches.items(): |                     response_json: dict = await response.json() | ||||||
|                 allele_set.add(Allele(exact_match_locus, exact_match_alleles[0]["allele_id"], None)) |                     allele_set: Set[Allele] = set() | ||||||
|             if len(allele_set) == 0: |                     response_json.setdefault("fields", dict()) | ||||||
|                 raise ValueError("Passed in no alleles.") |                     scheme_fields_returned: dict[str, str] = response_json["fields"] | ||||||
|             result_mlst_profile = MLSTProfile(allele_set, scheme_fields_returned["ST"], scheme_fields_returned["clonal_complex"]) |                     scheme_fields_returned.setdefault("ST", "unknown") | ||||||
|             if len(names_list) > 0: |                     scheme_fields_returned.setdefault("clonal_complex", "unknown") | ||||||
|                 result_mlst_profile = NamedMLSTProfile(str(tuple(names_list)) if len(set(names_list)) > 1 else names_list[0], result_mlst_profile) |                     scheme_exact_matches: dict = response_json["exact_matches"] | ||||||
|             return result_mlst_profile |                     for exact_match_locus, exact_match_alleles in scheme_exact_matches.items(): | ||||||
|  |                         allele_set.add(Allele(exact_match_locus, exact_match_alleles[0]["allele_id"], None)) | ||||||
|  |                     if len(allele_set) == 0: | ||||||
|  |                         raise ValueError("Passed in no alleles.") | ||||||
|  |                     result_mlst_profile = MLSTProfile(allele_set, scheme_fields_returned["ST"], scheme_fields_returned["clonal_complex"]) | ||||||
|  |                     if len(names_list) > 0: | ||||||
|  |                         result_mlst_profile = NamedMLSTProfile(str(tuple(names_list)) if len(set(names_list)) > 1 else names_list[0], result_mlst_profile) | ||||||
|  |                     return result_mlst_profile | ||||||
|  |             except (ConnectionError, ServerDisconnectedError, ClientOSError) as e: | ||||||
|  |                 last_error = e | ||||||
|  |                 success = False | ||||||
|  |                 await asyncio.sleep(5) | ||||||
|  |             else: | ||||||
|  |                 success = True | ||||||
|  |         try: | ||||||
|  |             if last_error is not None: | ||||||
|  |                 raise last_error | ||||||
|  |         except (ConnectionError, ServerDisconnectedError, ClientOSError) as e: | ||||||
|  |                 result_mlst_profile = NamedMLSTProfile((str(tuple(names_list)) if len(set(names_list)) > 1 else names_list[0]) + ":Error", None) | ||||||
|  |         raise ValueError("Last error was not recorded.") | ||||||
|  |  | ||||||
|                      |                      | ||||||
|     async def profile_string(self, query_sequence_strings: Iterable[Union[NamedString, str]]) -> Union[NamedMLSTProfile, MLSTProfile]: |     async def profile_string(self, query_sequence_strings: Iterable[Union[NamedString, str]]) -> Union[NamedMLSTProfile, MLSTProfile]: | ||||||
|         alleles = self.determine_mlst_allele_variants(query_sequence_strings) |         alleles = self.determine_mlst_allele_variants(query_sequence_strings) | ||||||
| @@ -212,6 +255,16 @@ class BIGSdbIndex(AbstractAsyncContextManager): | |||||||
|     async def build_profiler_from_seqdefdb(self, local: bool, dbseqdef_name: str, scheme_id: int) -> BIGSdbMLSTProfiler: |     async def build_profiler_from_seqdefdb(self, local: bool, dbseqdef_name: str, scheme_id: int) -> BIGSdbMLSTProfiler: | ||||||
|         return get_BIGSdb_MLST_profiler(local, await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name), dbseqdef_name, scheme_id) |         return get_BIGSdb_MLST_profiler(local, await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name), dbseqdef_name, scheme_id) | ||||||
|      |      | ||||||
|  |     async def get_scheme_loci(self, dbseqdef_name: str, scheme_id: int) -> list[str]: | ||||||
|  |         uri_path = f"{await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name)}/db/{dbseqdef_name}/schemes/{scheme_id}" | ||||||
|  |         async with self._http_client.get(uri_path) as response: | ||||||
|  |             response_json = await response.json() | ||||||
|  |             loci = response_json["loci"] | ||||||
|  |             results = [] | ||||||
|  |             for locus in loci: | ||||||
|  |                 results.append(path.basename(locus)) | ||||||
|  |             return results | ||||||
|  |  | ||||||
|     async def close(self): |     async def close(self): | ||||||
|         await self._http_client.close() |         await self._http_client.close() | ||||||
|  |  | ||||||
|   | |||||||
| @@ -3,11 +3,13 @@ from typing import Union | |||||||
| class BIGSDbDatabaseAPIException(Exception): | class BIGSDbDatabaseAPIException(Exception): | ||||||
|     pass |     pass | ||||||
|  |  | ||||||
|  | class BIGSdbResponseNotOkay(BIGSDbDatabaseAPIException): | ||||||
|  |     pass | ||||||
|  |  | ||||||
| class NoBIGSdbMatchesException(BIGSDbDatabaseAPIException): | class NoBIGSdbMatchesException(BIGSDbDatabaseAPIException): | ||||||
|     def __init__(self, database_name: str, database_scheme_id: int, query_name: Union[None, str], *args): |     def __init__(self, database_name: str, database_scheme_id: int, query_name: Union[None, str], *args): | ||||||
|         self._query_name = query_name |         self._query_name = query_name | ||||||
|         super().__init__(f"No matches found with scheme with ID {database_scheme_id}  in the database \"{database_name}\".", *args) |         super().__init__(f"No matches found with scheme with ID {database_scheme_id} in the database \"{database_name}\".", *args) | ||||||
|      |      | ||||||
|     def get_causal_query_name(self) -> Union[str, None]: |     def get_causal_query_name(self) -> Union[str, None]: | ||||||
|         return self._query_name |         return self._query_name | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
| import csv | import csv | ||||||
| from os import PathLike | from os import PathLike | ||||||
| from typing import AsyncIterable, Collection, Mapping, Sequence, Union | from typing import AsyncIterable, Collection, Iterable, Mapping, Sequence, Union | ||||||
|  |  | ||||||
| from autobigs.engine.structures.mlst import Allele, MLSTProfile, NamedMLSTProfile | from autobigs.engine.structures.mlst import Allele, MLSTProfile, NamedMLSTProfile | ||||||
|  |  | ||||||
| @@ -17,7 +17,7 @@ def alleles_to_text_map(alleles: Collection[Allele]) -> Mapping[str, Union[Seque | |||||||
|             result[locus] = tuple(result[locus]) # type: ignore |             result[locus] = tuple(result[locus]) # type: ignore | ||||||
|     return dict(result) |     return dict(result) | ||||||
|  |  | ||||||
| async def write_mlst_profiles_as_csv(mlst_profiles_iterable: AsyncIterable[NamedMLSTProfile], handle: Union[str, bytes, PathLike[str], PathLike[bytes]]) -> Sequence[str]: | async def write_mlst_profiles_as_csv(mlst_profiles_iterable: AsyncIterable[NamedMLSTProfile], handle: Union[str, bytes, PathLike[str], PathLike[bytes]], allele_names: Iterable[str]) -> Sequence[str]: | ||||||
|     failed = list() |     failed = list() | ||||||
|     with open(handle, "w", newline='') as filehandle: |     with open(handle, "w", newline='') as filehandle: | ||||||
|         header = None |         header = None | ||||||
| @@ -30,7 +30,7 @@ async def write_mlst_profiles_as_csv(mlst_profiles_iterable: AsyncIterable[Named | |||||||
|                 continue |                 continue | ||||||
|             allele_mapping = alleles_to_text_map(mlst_profile.alleles) |             allele_mapping = alleles_to_text_map(mlst_profile.alleles) | ||||||
|             if writer is None: |             if writer is None: | ||||||
|                 header = ["id", "st", "clonal-complex", *sorted(allele_mapping.keys())] |                 header = ["id", "st", "clonal-complex", *sorted(allele_names)] | ||||||
|                 writer = csv.DictWriter(filehandle, fieldnames=header) |                 writer = csv.DictWriter(filehandle, fieldnames=header) | ||||||
|                 writer.writeheader() |                 writer.writeheader() | ||||||
|             row_dictionary = { |             row_dictionary = { | ||||||
|   | |||||||
| @@ -222,3 +222,12 @@ class TestBIGSdbIndex: | |||||||
|                 assert isinstance(profile, MLSTProfile) |                 assert isinstance(profile, MLSTProfile) | ||||||
|                 assert profile.clonal_complex == "ST-2 complex" |                 assert profile.clonal_complex == "ST-2 complex" | ||||||
|                 assert profile.sequence_type == "1" |                 assert profile.sequence_type == "1" | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     @pytest.mark.parametrize(["bigsdb_name", "scheme_id", "expected"], [ | ||||||
|  |         ("pubmlst_bordetella_seqdef", 3, ["adk", "fumC", "glyA", "tyrB", "icd", "pepA", "pgm"]) | ||||||
|  |     ]) | ||||||
|  |     async def test_bigsdb_index_fetches_loci_names(self, bigsdb_name, scheme_id, expected): | ||||||
|  |         async with BIGSdbIndex() as bigsdb_index: | ||||||
|  |             loci = await bigsdb_index.get_scheme_loci(bigsdb_name, scheme_id) | ||||||
|  |             assert set(loci) == set(expected) | ||||||
| @@ -27,7 +27,7 @@ async def test_column_order_is_same_as_expected_file(dummy_alphabet_mlst_profile | |||||||
|     dummy_profiles = [dummy_alphabet_mlst_profile] |     dummy_profiles = [dummy_alphabet_mlst_profile] | ||||||
|     with tempfile.TemporaryDirectory() as temp_dir: |     with tempfile.TemporaryDirectory() as temp_dir: | ||||||
|         output_path = path.join(temp_dir, "out.csv") |         output_path = path.join(temp_dir, "out.csv") | ||||||
|         await write_mlst_profiles_as_csv(iterable_to_asynciterable(dummy_profiles), output_path) |         await write_mlst_profiles_as_csv(iterable_to_asynciterable(dummy_profiles), output_path, ["A", "D", "B", "C"]) | ||||||
|         with open(output_path) as csv_handle: |         with open(output_path) as csv_handle: | ||||||
|             csv_reader = reader(csv_handle) |             csv_reader = reader(csv_handle) | ||||||
|             lines = list(csv_reader) |             lines = list(csv_reader) | ||||||
| @@ -38,7 +38,7 @@ async def test_csv_writing_sample_name_not_repeated_when_single_sequence(dummy_a | |||||||
|     dummy_profiles = [dummy_alphabet_mlst_profile] |     dummy_profiles = [dummy_alphabet_mlst_profile] | ||||||
|     with tempfile.TemporaryDirectory() as temp_dir: |     with tempfile.TemporaryDirectory() as temp_dir: | ||||||
|         output_path = path.join(temp_dir, "out.csv") |         output_path = path.join(temp_dir, "out.csv") | ||||||
|         await write_mlst_profiles_as_csv(iterable_to_asynciterable(dummy_profiles), output_path) |         await write_mlst_profiles_as_csv(iterable_to_asynciterable(dummy_profiles), output_path, ["A", "D", "B", "C"]) | ||||||
|         with open(output_path) as csv_handle: |         with open(output_path) as csv_handle: | ||||||
|             csv_reader = reader(csv_handle) |             csv_reader = reader(csv_handle) | ||||||
|             lines = list(csv_reader) |             lines = list(csv_reader) | ||||||
| @@ -63,7 +63,7 @@ async def test_csv_writing_includes_asterisk_for_non_exact(dummy_alphabet_mlst_p | |||||||
|     dummy_profiles = [dummy_alphabet_mlst_profile] |     dummy_profiles = [dummy_alphabet_mlst_profile] | ||||||
|     with tempfile.TemporaryDirectory() as temp_dir: |     with tempfile.TemporaryDirectory() as temp_dir: | ||||||
|         output_path = path.join(temp_dir, "out.csv") |         output_path = path.join(temp_dir, "out.csv") | ||||||
|         await write_mlst_profiles_as_csv(iterable_to_asynciterable(dummy_profiles), output_path) |         await write_mlst_profiles_as_csv(iterable_to_asynciterable(dummy_profiles), output_path, ["A", "D", "B", "C"]) | ||||||
|         with open(output_path) as csv_handle: |         with open(output_path) as csv_handle: | ||||||
|             csv_reader = reader(csv_handle) |             csv_reader = reader(csv_handle) | ||||||
|             lines = list(csv_reader) |             lines = list(csv_reader) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user