Replaced local profiler with a not implemented exception
This commit is contained in:
parent
897f7ee922
commit
175a51f968
@ -139,141 +139,6 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
|
|||||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
await self.close()
|
await self.close()
|
||||||
|
|
||||||
class LocalBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
|
|
||||||
async def __aenter__(self):
|
|
||||||
if self._prepare:
|
|
||||||
await self.update_scheme_locis()
|
|
||||||
await asyncio.gather(
|
|
||||||
self.download_alleles_cache_data(),
|
|
||||||
self.download_scheme_profiles()
|
|
||||||
)
|
|
||||||
await self.load_scheme_profiles()
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __init__(self, database_api: str, database_name: str, schema_id: int, cache_path: Union[str, None] = None, prepare: bool =True):
|
|
||||||
self._database_api = database_api
|
|
||||||
self._database_name = database_name
|
|
||||||
self._schema_id = schema_id
|
|
||||||
self._base_url = f"{self._database_api}/db/{self._database_name}/schemes/{self._schema_id}/"
|
|
||||||
self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(60))
|
|
||||||
if cache_path is None:
|
|
||||||
self._cache_path = tempfile.mkdtemp("BIGSdb")
|
|
||||||
self._cleanup_required = True
|
|
||||||
else:
|
|
||||||
self._cache_path = cache_path
|
|
||||||
self._cleanup_required = False
|
|
||||||
self._loci: list[str] = []
|
|
||||||
self._profiles_st_map = {}
|
|
||||||
self._prepare = prepare
|
|
||||||
|
|
||||||
async def update_scheme_locis(self):
|
|
||||||
self._loci.clear()
|
|
||||||
async with self._http_client.get(f"/api/db/{self._database_name}/schemes/{self._schema_id}") as schema_response:
|
|
||||||
schema_json = await schema_response.json()
|
|
||||||
for locus in schema_json["loci"]:
|
|
||||||
locus_name = path.basename(locus)
|
|
||||||
self._loci.append(locus_name)
|
|
||||||
self._loci.sort()
|
|
||||||
|
|
||||||
async def load_scheme_profiles(self):
|
|
||||||
self._profiles_st_map.clear()
|
|
||||||
with open(self.get_scheme_profile_path()) as profile_cache_handle:
|
|
||||||
reader = csv.DictReader(profile_cache_handle, delimiter="\t")
|
|
||||||
for line in reader:
|
|
||||||
alleles = []
|
|
||||||
for locus in self._loci:
|
|
||||||
alleles.append(line[locus])
|
|
||||||
self._profiles_st_map[tuple(alleles)] = (line["ST"], line["clonal_complex"])
|
|
||||||
|
|
||||||
def get_locus_cache_path(self, locus) -> str:
|
|
||||||
return path.join(self._cache_path, locus + "." + "fasta")
|
|
||||||
|
|
||||||
def get_scheme_profile_path(self):
|
|
||||||
return path.join(self._cache_path, "profiles.csv")
|
|
||||||
|
|
||||||
async def download_alleles_cache_data(self):
|
|
||||||
for locus in self._loci:
|
|
||||||
with open(self.get_locus_cache_path(locus), "wb") as fasta_handle:
|
|
||||||
async with self._http_client.get(f"/api/db/{self._database_name}/loci/{locus}/alleles_fasta") as fasta_response:
|
|
||||||
async for chunk, eof in fasta_response.content.iter_chunks():
|
|
||||||
fasta_handle.write(chunk)
|
|
||||||
|
|
||||||
async def download_scheme_profiles(self):
|
|
||||||
with open(self.get_scheme_profile_path(), "wb") as profile_cache_handle:
|
|
||||||
async with self._http_client.get("profiles_csv") as profiles_response:
|
|
||||||
async for chunk, eof in profiles_response.content.iter_chunks():
|
|
||||||
profile_cache_handle.write(chunk)
|
|
||||||
await self.load_scheme_profiles()
|
|
||||||
|
|
||||||
async def determine_mlst_allele_variants(self, query_sequence_strings: Iterable[str]) -> AsyncGenerator[Allele, Any]:
|
|
||||||
aligner = PairwiseAligner("blastn")
|
|
||||||
aligner.mode = "local"
|
|
||||||
with AsyncBiopythonPairwiseAlignmentEngine(aligner, max_threads=4) as aligner_engine:
|
|
||||||
for query_sequence_string in query_sequence_strings:
|
|
||||||
for locus in self._loci:
|
|
||||||
async for allele_variant in read_fasta(self.get_locus_cache_path(locus)):
|
|
||||||
aligner_engine.align(allele_variant.sequence, query_sequence_string, variant_name=allele_variant.name, full=True)
|
|
||||||
break # start a bunch of full alignments for each variant to select segments
|
|
||||||
alignment_rankings: dict[str, set[tuple[PairwiseAlignment, str]]] = defaultdict(set)
|
|
||||||
async for alignment_result, additional_information in aligner_engine:
|
|
||||||
result_variant_name = additional_information["variant_name"]
|
|
||||||
result_locus, variant_id = result_variant_name.split("_")
|
|
||||||
full_alignment = additional_information["full"]
|
|
||||||
if full_alignment:
|
|
||||||
if alignment_result.alignment_stats.gaps == 0 and alignment_result.alignment_stats.mismatches == 0:
|
|
||||||
# I.e., 100% exactly the same
|
|
||||||
yield Allele(result_locus, variant_id, None)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
alignment_rankings[result_locus].add((alignment_result, variant_id))
|
|
||||||
interest_sequence = full_alignment[alignment_result.query_indices[0]:alignment_result.query_indices[-1]]
|
|
||||||
async for allele_variant in read_fasta(self.get_locus_cache_path(result_locus)):
|
|
||||||
if result_variant_name == allele_variant.name:
|
|
||||||
continue # Skip if we just finished aligning this
|
|
||||||
aligner_engine.align(allele_variant.sequence, interest_sequence, variant_name=result_variant_name.name, full=False)
|
|
||||||
else:
|
|
||||||
alignment_rankings[result_locus].add((alignment_result, variant_id))
|
|
||||||
for final_locus, alignments in alignment_rankings.items():
|
|
||||||
closest_alignment, closest_variant_id = sorted(alignments, key=lambda index: index[0].alignment_stats.match_metric)[0]
|
|
||||||
yield Allele(final_locus, closest_variant_id, closest_alignment.alignment_stats)
|
|
||||||
|
|
||||||
async def determine_mlst_st(self, alleles):
|
|
||||||
allele_variants: dict[str, Allele] = {}
|
|
||||||
if isinstance(alleles, AsyncIterable):
|
|
||||||
async for allele in alleles:
|
|
||||||
allele_variants[allele.allele_locus] = allele
|
|
||||||
else:
|
|
||||||
for allele in alleles:
|
|
||||||
allele_variants[allele.allele_locus] = allele
|
|
||||||
ordered_profile = []
|
|
||||||
for locus in self._loci:
|
|
||||||
ordered_profile.append(allele_variants[locus].allele_variant)
|
|
||||||
|
|
||||||
st, clonal_complex = self._profiles_st_map[tuple(ordered_profile)]
|
|
||||||
return MLSTProfile(set(allele_variants.values()), st, clonal_complex)
|
|
||||||
|
|
||||||
async def profile_string(self, query_sequence_strings: Iterable[str]) -> MLSTProfile:
|
|
||||||
alleles = self.determine_mlst_allele_variants(query_sequence_strings)
|
|
||||||
return await self.determine_mlst_st(alleles)
|
|
||||||
|
|
||||||
async def profile_multiple_strings(self, query_named_string_groups: AsyncIterable[Iterable[NamedString]], stop_on_fail: bool = False) -> AsyncGenerator[NamedMLSTProfile, Any]:
|
|
||||||
async for named_strings in query_named_string_groups:
|
|
||||||
for named_string in named_strings:
|
|
||||||
try:
|
|
||||||
yield NamedMLSTProfile(named_string.name, await self.profile_string([named_string.sequence]))
|
|
||||||
except NoBIGSdbMatchesException as e:
|
|
||||||
if stop_on_fail:
|
|
||||||
raise e
|
|
||||||
yield NamedMLSTProfile(named_string.name, None)
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
await self._http_client.close()
|
|
||||||
if self._cleanup_required:
|
|
||||||
shutil.rmtree(self._cache_path)
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
|
||||||
await self.close()
|
|
||||||
|
|
||||||
class BIGSdbIndex(AbstractAsyncContextManager):
|
class BIGSdbIndex(AbstractAsyncContextManager):
|
||||||
KNOWN_BIGSDB_APIS = {
|
KNOWN_BIGSDB_APIS = {
|
||||||
"https://bigsdb.pasteur.fr/api",
|
"https://bigsdb.pasteur.fr/api",
|
||||||
@ -334,5 +199,5 @@ class BIGSdbIndex(AbstractAsyncContextManager):
|
|||||||
|
|
||||||
def get_BIGSdb_MLST_profiler(local: bool, database_api: str, database_name: str, schema_id: int):
|
def get_BIGSdb_MLST_profiler(local: bool, database_api: str, database_name: str, schema_id: int):
|
||||||
if local:
|
if local:
|
||||||
return LocalBIGSdbMLSTProfiler(database_api=database_api, database_name=database_name, schema_id=schema_id)
|
raise NotImplementedError()
|
||||||
return RemoteBIGSdbMLSTProfiler(database_api=database_api, database_name=database_name, schema_id=schema_id)
|
return RemoteBIGSdbMLSTProfiler(database_api=database_api, database_name=database_name, schema_id=schema_id)
|
Loading…
x
Reference in New Issue
Block a user