diff --git a/src/autobigs/engine/analysis/bigsdb.py b/src/autobigs/engine/analysis/bigsdb.py index f96be17..0e9af59 100644 --- a/src/autobigs/engine/analysis/bigsdb.py +++ b/src/autobigs/engine/analysis/bigsdb.py @@ -139,141 +139,6 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler): async def __aexit__(self, exc_type, exc_value, traceback): await self.close() -class LocalBIGSdbMLSTProfiler(BIGSdbMLSTProfiler): - async def __aenter__(self): - if self._prepare: - await self.update_scheme_locis() - await asyncio.gather( - self.download_alleles_cache_data(), - self.download_scheme_profiles() - ) - await self.load_scheme_profiles() - return self - - def __init__(self, database_api: str, database_name: str, schema_id: int, cache_path: Union[str, None] = None, prepare: bool =True): - self._database_api = database_api - self._database_name = database_name - self._schema_id = schema_id - self._base_url = f"{self._database_api}/db/{self._database_name}/schemes/{self._schema_id}/" - self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(60)) - if cache_path is None: - self._cache_path = tempfile.mkdtemp("BIGSdb") - self._cleanup_required = True - else: - self._cache_path = cache_path - self._cleanup_required = False - self._loci: list[str] = [] - self._profiles_st_map = {} - self._prepare = prepare - - async def update_scheme_locis(self): - self._loci.clear() - async with self._http_client.get(f"/api/db/{self._database_name}/schemes/{self._schema_id}") as schema_response: - schema_json = await schema_response.json() - for locus in schema_json["loci"]: - locus_name = path.basename(locus) - self._loci.append(locus_name) - self._loci.sort() - - async def load_scheme_profiles(self): - self._profiles_st_map.clear() - with open(self.get_scheme_profile_path()) as profile_cache_handle: - reader = csv.DictReader(profile_cache_handle, delimiter="\t") - for line in reader: - alleles = [] - for locus in self._loci: - alleles.append(line[locus]) - self._profiles_st_map[tuple(alleles)] = (line["ST"], line["clonal_complex"]) - - def get_locus_cache_path(self, locus) -> str: - return path.join(self._cache_path, locus + "." + "fasta") - - def get_scheme_profile_path(self): - return path.join(self._cache_path, "profiles.csv") - - async def download_alleles_cache_data(self): - for locus in self._loci: - with open(self.get_locus_cache_path(locus), "wb") as fasta_handle: - async with self._http_client.get(f"/api/db/{self._database_name}/loci/{locus}/alleles_fasta") as fasta_response: - async for chunk, eof in fasta_response.content.iter_chunks(): - fasta_handle.write(chunk) - - async def download_scheme_profiles(self): - with open(self.get_scheme_profile_path(), "wb") as profile_cache_handle: - async with self._http_client.get("profiles_csv") as profiles_response: - async for chunk, eof in profiles_response.content.iter_chunks(): - profile_cache_handle.write(chunk) - await self.load_scheme_profiles() - - async def determine_mlst_allele_variants(self, query_sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]: - aligner = PairwiseAligner("blastn") - aligner.mode = "local" - with AsyncBiopythonPairwiseAlignmentEngine(aligner, max_threads=4) as aligner_engine: - for query_sequence_string in query_sequence_strings: - for locus in self._loci: - async for allele_variant in read_fasta(self.get_locus_cache_path(locus)): - aligner_engine.align(allele_variant.sequence, query_sequence_string, variant_name=allele_variant.name, full=True) - break # start a bunch of full alignments for each variant to select segments - alignment_rankings: dict[str, set[tuple[PairwiseAlignment, str]]] = defaultdict(set) - async for alignment_result, additional_information in aligner_engine: - result_variant_name = additional_information["variant_name"] - result_locus, variant_id = result_variant_name.split("_") - full_alignment = additional_information["full"] - if full_alignment: - if alignment_result.alignment_stats.gaps == 0 and alignment_result.alignment_stats.mismatches == 0: - # I.e., 100% exactly the same - yield Allele(result_locus, variant_id, None) - continue - else: - alignment_rankings[result_locus].add((alignment_result, variant_id)) - interest_sequence = full_alignment[alignment_result.query_indices[0]:alignment_result.query_indices[-1]] - async for allele_variant in read_fasta(self.get_locus_cache_path(result_locus)): - if result_variant_name == allele_variant.name: - continue # Skip if we just finished aligning this - aligner_engine.align(allele_variant.sequence, interest_sequence, variant_name=result_variant_name.name, full=False) - else: - alignment_rankings[result_locus].add((alignment_result, variant_id)) - for final_locus, alignments in alignment_rankings.items(): - closest_alignment, closest_variant_id = sorted(alignments, key=lambda index: index[0].alignment_stats.match_metric)[0] - yield Allele(final_locus, closest_variant_id, closest_alignment.alignment_stats) - - async def determine_mlst_st(self, alleles): - allele_variants: dict[str, Allele] = {} - if isinstance(alleles, AsyncIterable): - async for allele in alleles: - allele_variants[allele.allele_locus] = allele - else: - for allele in alleles: - allele_variants[allele.allele_locus] = allele - ordered_profile = [] - for locus in self._loci: - ordered_profile.append(allele_variants[locus].allele_variant) - - st, clonal_complex = self._profiles_st_map[tuple(ordered_profile)] - return MLSTProfile(set(allele_variants.values()), st, clonal_complex) - - async def profile_string(self, query_sequence_strings: Iterable[str]) -> MLSTProfile: - alleles = self.determine_mlst_allele_variants(query_sequence_strings) - return await self.determine_mlst_st(alleles) - - async def profile_multiple_strings(self, query_named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]: - async for named_strings in query_named_string_groups: - for named_string in named_strings: - try: - yield NamedMLSTProfile(named_string.name, await self.profile_string([named_string.sequence])) - except NoBIGSdbMatchesException as e: - if stop_on_fail: - raise e - yield NamedMLSTProfile(named_string.name, None) - - async def close(self): - await self._http_client.close() - if self._cleanup_required: - shutil.rmtree(self._cache_path) - - async def __aexit__(self, exc_type, exc_value, traceback): - await self.close() - class BIGSdbIndex(AbstractAsyncContextManager): KNOWN_BIGSDB_APIS = { "https://bigsdb.pasteur.fr/api", @@ -334,5 +199,5 @@ class BIGSdbIndex(AbstractAsyncContextManager): def get_BIGSdb_MLST_profiler(local: bool, database_api: str, database_name: str, schema_id: int): if local: - return LocalBIGSdbMLSTProfiler(database_api=database_api, database_name=database_name, schema_id=schema_id) + raise NotImplementedError() return RemoteBIGSdbMLSTProfiler(database_api=database_api, database_name=database_name, schema_id=schema_id) \ No newline at end of file