192 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			192 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from abc import abstractmethod
 | |
| from collections import defaultdict
 | |
| from contextlib import AbstractAsyncContextManager
 | |
| import csv
 | |
| from os import path
 | |
| from typing import Any, AsyncGenerator, AsyncIterable, Iterable, Mapping, Sequence, Union
 | |
| 
 | |
| from aiohttp import ClientSession, ClientTimeout
 | |
| 
 | |
| from autobigs.engine.data.local.fasta import read_fasta
 | |
| from autobigs.engine.data.structures.genomics import NamedString
 | |
| from autobigs.engine.data.structures.mlst import Allele, NamedMLSTProfile, PartialAllelicMatchProfile, MLSTProfile
 | |
| from autobigs.engine.exceptions.database import NoBIGSdbExactMatchesException, NoBIGSdbMatchesException, NoSuchBIGSdbDatabaseException
 | |
| 
 | |
| from Bio.Align import PairwiseAligner
 | |
| 
 | |
class BIGSdbMLSTProfiler(AbstractAsyncContextManager):
    """Abstract interface for turning sequence strings into MLST profiles.

    Subclasses are async context managers (through
    ``AbstractAsyncContextManager``) and cannot be instantiated until every
    method marked ``@abstractmethod`` below has been implemented.
    """

    @abstractmethod
    def fetch_mlst_allele_variants(self, sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
        """Asynchronously produce the ``Allele`` objects found in ``sequence_strings``."""

    @abstractmethod
    async def fetch_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile:
        """Resolve a (sync or async) collection of alleles into an ``MLSTProfile``."""

    @abstractmethod
    async def profile_string(self, sequence_strings: Iterable[str]) -> MLSTProfile:
        """Build a single ``MLSTProfile`` from the given sequence strings."""

    @abstractmethod
    def profile_multiple_strings(self, named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
        """Asynchronously produce ``NamedMLSTProfile`` results for groups of named strings.

        ``stop_on_fail`` lets the caller choose between aborting and continuing
        when profiling a string fails.
        """

    @abstractmethod
    async def close(self):
        """Release whatever resources the profiler holds."""
 | |
| 
 | |
class OnlineBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
    """MLST profiler that queries one scheme of a BIGSdb database over HTTP.

    Every request is sent relative to
    ``{database_api}/db/{database_name}/schemes/{schema_id}/`` through a single
    ``aiohttp.ClientSession`` created in ``__init__``. Use the profiler as an
    ``async with`` context manager, or call :meth:`close`, so that session is
    released.
    """

    def __init__(self, database_api: str, database_name: str, schema_id: int):
        """Create the profiler and its HTTP session.

        Args:
            database_api: Base URL of the BIGSdb REST API. A "/" is inserted
                after it, so it is expected without a trailing slash.
            database_name: Name of the database on that API.
            schema_id: Numeric id of the scheme to profile against.
        """
        self._database_name = database_name
        self._schema_id = schema_id
        # The trailing "/" lets the relative paths used below ("sequence",
        # "designations") resolve beneath the scheme URL.
        self._base_url = f"{database_api}/db/{self._database_name}/schemes/{self._schema_id}/"
        # ClientTimeout's first positional parameter is the total timeout in seconds.
        self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(10000))

    async def __aenter__(self):
        return self

    async def fetch_mlst_allele_variants(self, sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
        """Query the scheme's ``sequence`` endpoint once per input string.

        For each sequence, every exact match in the response is yielded as an
        ``Allele`` with ``partial_match_profile=None``. If the response has no
        ``exact_matches`` but does have ``partial_matches``, one ``Allele`` per
        non-empty locus entry is yielded with a ``PartialAllelicMatchProfile``.

        Raises:
            NoBIGSdbMatchesException: if a response contains neither
                ``exact_matches`` nor ``partial_matches``.
        """
        # See https://bigsdb.pasteur.fr/api/db/pubmlst_bordetella_seqdef/schemes
        uri_path = "sequence"

        for sequence_string in sequence_strings:
            async with self._http_client.post(uri_path, json={
                "sequence": sequence_string,
                "partial_matches": True
            }) as response:
                sequence_response: dict = await response.json()
            # The body is fully read; parse and yield outside the response
            # context so the connection is not held while the consumer works.

            if "exact_matches" in sequence_response:
                # locus -> list of matching alleles (each carrying an "allele_id")
                exact_matches: dict[str, Sequence[dict[str, str]]] = sequence_response["exact_matches"]
                for allele_locus, locus_alleles in exact_matches.items():
                    for allele in locus_alleles:
                        allele_id = allele["allele_id"]
                        yield Allele(allele_locus=allele_locus, allele_variant=allele_id, partial_match_profile=None)
            elif "partial_matches" in sequence_response:
                partial_matches: dict[str, dict[str, Union[str, float, int]]] = sequence_response["partial_matches"]
                for allele_locus, partial_match in partial_matches.items():
                    if len(partial_match) <= 0:
                        # Skip loci whose partial-match entry is empty.
                        continue
                    partial_match_profile = PartialAllelicMatchProfile(
                        percent_identity=float(partial_match["identity"]),
                        mismatches=int(partial_match["mismatches"]),
                        gaps=int(partial_match["gaps"])
                    )
                    yield Allele(
                        allele_locus=allele_locus,
                        allele_variant=str(partial_match["allele"]),
                        partial_match_profile=partial_match_profile
                    )
            else:
                raise NoBIGSdbMatchesException(self._database_name, self._schema_id)

    async def fetch_mlst_st(self, alleles: Union[AsyncIterable[Allele], Iterable[Allele]]) -> MLSTProfile:
        """Submit alleles to the scheme's ``designations`` endpoint and build a profile.

        Args:
            alleles: Sync or async iterable of ``Allele``. Each contributes
                ``{"allele": str(allele_variant)}`` under its locus.

        Returns:
            An ``MLSTProfile`` whose alleles are taken from the response's
            ``exact_matches``: one ``Allele`` per locus, with no partial-match
            profile, so partial-match details on the inputs are not carried
            over. It also carries the response's ``ST`` and ``clonal_complex``
            fields, each defaulting to ``"unknown"`` when absent.

        Raises:
            ValueError: if ``alleles`` is empty (raised before any request is
                sent), if a locus in the response does not have exactly one
                exact-match allele, or if the response lists no exact matches.
            KeyError: if the response has no ``exact_matches`` key.
        """
        uri_path = "designations"
        allele_request_dict: dict[str, list[dict[str, str]]] = defaultdict(list)
        if isinstance(alleles, AsyncIterable):
            async for allele in alleles:
                allele_request_dict[allele.allele_locus].append({"allele": str(allele.allele_variant)})
        else:
            for allele in alleles:
                allele_request_dict[allele.allele_locus].append({"allele": str(allele.allele_variant)})
        if not allele_request_dict:
            # Fail fast: there is nothing to look up, so don't spend a request
            # on it (and don't trip over the shape of an error response).
            raise ValueError("Passed in no alleles.")
        request_json = {
            "designations": allele_request_dict
        }
        async with self._http_client.post(uri_path, json=request_json) as response:
            response_json: dict = await response.json()

        schema_fields_returned: dict[str, str] = response_json.setdefault("fields", dict())
        schema_fields_returned.setdefault("ST", "unknown")
        schema_fields_returned.setdefault("clonal_complex", "unknown")
        schema_exact_matches: dict = response_json["exact_matches"]
        allele_map: dict[str, Allele] = {}
        for exact_match_locus, exact_match_alleles in schema_exact_matches.items():
            # Also guard the empty case: indexing [0] below would otherwise
            # surface as a bare IndexError.
            if len(exact_match_alleles) != 1:
                raise ValueError(f"Unexpected number of alleles returned for exact match (Expected 1, retrieved {len(exact_match_alleles)})")
            allele_map[exact_match_locus] = Allele(exact_match_locus, exact_match_alleles[0]["allele_id"], None)
        if len(allele_map) == 0:
            raise ValueError("BIGSdb returned no exact allele matches for the submitted alleles.")
        return MLSTProfile(allele_map, schema_fields_returned["ST"], schema_fields_returned["clonal_complex"])

    async def profile_string(self, sequence_strings: Iterable[str]) -> MLSTProfile:
        """Identify alleles in ``sequence_strings`` and resolve them to one profile.

        Allele lookup is streamed directly into :meth:`fetch_mlst_st`, so
        exceptions raised while fetching alleles propagate from here.
        """
        alleles = self.fetch_mlst_allele_variants(sequence_strings)
        return await self.fetch_mlst_st(alleles)

    async def profile_multiple_strings(self, named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
        """Profile every named string of every group, yielding one result per string.

        Each ``NamedString`` is profiled on its own (strings within a group are
        not combined). When ``NoBIGSdbMatchesException`` is raised for a
        string, a ``NamedMLSTProfile`` with a ``None`` profile is yielded
        instead, unless ``stop_on_fail`` is true, in which case the exception
        is re-raised.
        """
        async for named_strings in named_string_groups:
            for named_string in named_strings:
                try:
                    profile = await self.profile_string([named_string.sequence])
                except NoBIGSdbMatchesException:
                    if stop_on_fail:
                        raise
                    profile = None
                yield NamedMLSTProfile(named_string.name, profile)

    async def close(self):
        """Close the underlying HTTP session."""
        await self._http_client.close()

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()
 | |
| 
 | |
class BIGSdbIndex(AbstractAsyncContextManager):
    """Index of BIGSdb databases whose names end in "seqdef", across known APIs.

    Lookups are cached on the instance; pass ``force=True`` to refetch. The
    index owns one ``aiohttp.ClientSession``: use it as an ``async with``
    context manager, or call :meth:`close`, to release it.
    """
    KNOWN_BIGSDB_APIS = {
        "https://bigsdb.pasteur.fr/api",
        "https://rest.pubmlst.org"
    }

    def __init__(self):
        self._http_client = ClientSession()
        # database name -> base URL of the API listing it; None until first fetched.
        self._known_seqdef_dbs_origin: Union[Mapping[str, str], None] = None
        # database name -> {scheme description: scheme id}
        self._seqdefdb_schemas: dict[str, Mapping[str, int]] = dict()
        super().__init__()

    async def __aenter__(self):
        return self

    async def get_known_seqdef_dbs(self, force: bool = False) -> Mapping[str, str]:
        """Return a mapping of seqdef database name -> API base URL hosting it.

        Each API's ``/db`` listing is fetched and every database whose name
        ends in "seqdef" is recorded. If several APIs list the same name, the
        API that sorts last wins, deterministically.

        Args:
            force: Refetch even if a cached result exists.
        """
        if self._known_seqdef_dbs_origin is not None and not force:
            return self._known_seqdef_dbs_origin
        known_seqdef_dbs: dict[str, str] = dict()
        # Iterate in sorted order: iteration order of a set of str depends on
        # hash randomization, so a name listed by more than one API would
        # otherwise be attributed to a different API from run to run.
        for known_bigsdb in sorted(BIGSdbIndex.KNOWN_BIGSDB_APIS):
            async with self._http_client.get(f"{known_bigsdb}/db") as response:
                response_json_databases = await response.json()
            for database_group in response_json_databases:
                for database_info in database_group["databases"]:
                    if str(database_info["name"]).endswith("seqdef"):
                        known_seqdef_dbs[database_info["name"]] = known_bigsdb
        self._known_seqdef_dbs_origin = known_seqdef_dbs
        return self._known_seqdef_dbs_origin

    async def get_bigsdb_api_from_seqdefdb(self, seqdef_db_name: str) -> str:
        """Return the API base URL that hosts ``seqdef_db_name``.

        Raises:
            NoSuchBIGSdbDatabaseException: if no known API lists the database.
        """
        known_databases = await self.get_known_seqdef_dbs()
        if seqdef_db_name not in known_databases:
            raise NoSuchBIGSdbDatabaseException(seqdef_db_name)
        return known_databases[seqdef_db_name]

    async def get_schemas_for_seqdefdb(self, seqdef_db_name: str, force: bool = False) -> Mapping[str, int]:
        """Return ``{scheme description: scheme id}`` for a seqdef database.

        The scheme id is parsed as the last "/"-separated segment of each
        scheme's URL. Results are cached per database unless ``force`` is true.
        """
        if seqdef_db_name in self._seqdefdb_schemas and not force:
            return self._seqdefdb_schemas[seqdef_db_name]
        uri_path = f"{await self.get_bigsdb_api_from_seqdefdb(seqdef_db_name)}/db/{seqdef_db_name}/schemes"
        async with self._http_client.get(uri_path) as response:
            response_json = await response.json()
        schema_descriptions: dict[str, int] = dict()
        for scheme_definition in response_json["schemes"]:
            scheme_id: int = int(str(scheme_definition["scheme"]).split("/")[-1])
            scheme_desc: str = scheme_definition["description"]
            schema_descriptions[scheme_desc] = scheme_id
        self._seqdefdb_schemas[seqdef_db_name] = schema_descriptions
        return schema_descriptions

    async def build_profiler_from_seqdefdb(self, dbseqdef_name: str, schema_id: int) -> "OnlineBIGSdbMLSTProfiler":
        """Create an ``OnlineBIGSdbMLSTProfiler`` for a scheme of the named seqdef database.

        The returned profiler owns its own HTTP session; the caller is
        responsible for closing it.
        """
        return OnlineBIGSdbMLSTProfiler(await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name), dbseqdef_name, schema_id)

    async def close(self):
        """Close the index's HTTP session."""
        await self._http_client.close()

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()
 | |
|     
 |