Compare commits
	
		
			9 Commits
		
	
	
		
	
| Author | SHA1 | Date | |
|---|---|---|---|
| | 29fcf8c176 | | |
| | 8264242fa5 | | |
| | f8d92a4aad | | |
| | 34bf02c75a | | |
| | 3cb10a4609 | | |
| | 1776f5aa51 | | |
| | 96d715fdcb | | |
| | e088d1080b | | |
| | 8ffc7c7fb5 | | |
							
								
								
									
										3
									
								
								.vscode/extensions.json
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.vscode/extensions.json
									
									
									
									
										vendored
									
									
								
							@@ -1,5 +1,6 @@
 | 
			
		||||
{
 | 
			
		||||
    "recommendations": [
 | 
			
		||||
        "piotrpalarz.vscode-gitignore-generator"
 | 
			
		||||
        "piotrpalarz.vscode-gitignore-generator",
 | 
			
		||||
        "gruntfuggly.todo-tree"
 | 
			
		||||
    ]
 | 
			
		||||
}
 | 
			
		||||
@@ -9,13 +9,13 @@ import shutil
 | 
			
		||||
import tempfile
 | 
			
		||||
from typing import Any, AsyncGenerator, AsyncIterable, Coroutine, Iterable, Mapping, Sequence, Set, Union
 | 
			
		||||
 | 
			
		||||
from aiohttp import ClientSession, ClientTimeout
 | 
			
		||||
from aiohttp import ClientOSError, ClientSession, ClientTimeout, ServerDisconnectedError
 | 
			
		||||
 | 
			
		||||
from autobigs.engine.reading import read_fasta
 | 
			
		||||
from autobigs.engine.structures.alignment import PairwiseAlignment
 | 
			
		||||
from autobigs.engine.structures.genomics import NamedString
 | 
			
		||||
from autobigs.engine.structures.mlst import Allele, NamedMLSTProfile, AlignmentStats, MLSTProfile
 | 
			
		||||
from autobigs.engine.exceptions.database import NoBIGSdbExactMatchesException, NoBIGSdbMatchesException, NoSuchBIGSdbDatabaseException
 | 
			
		||||
from autobigs.engine.exceptions.database import BIGSdbResponseNotOkay, NoBIGSdbExactMatchesException, NoBIGSdbMatchesException, NoSuchBIGSdbDatabaseException
 | 
			
		||||
 | 
			
		||||
from Bio.Align import PairwiseAligner
 | 
			
		||||
 | 
			
		||||
@@ -43,11 +43,12 @@ class BIGSdbMLSTProfiler(AbstractAsyncContextManager):
 | 
			
		||||
 | 
			
		||||
class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, database_api: str, database_name: str, scheme_id: int):
 | 
			
		||||
    def __init__(self, database_api: str, database_name: str, scheme_id: int, retry_requests: int = 5):
 | 
			
		||||
        self._retry_limit = retry_requests
 | 
			
		||||
        self._database_name = database_name
 | 
			
		||||
        self._scheme_id = scheme_id
 | 
			
		||||
        self._base_url = f"{database_api}/db/{self._database_name}/schemes/{self._scheme_id}/"
 | 
			
		||||
        self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(60))
 | 
			
		||||
        self._http_client = ClientSession(self._base_url, timeout=ClientTimeout(300))
 | 
			
		||||
 | 
			
		||||
    async def __aenter__(self):
 | 
			
		||||
        return self
 | 
			
		||||
@@ -57,11 +58,19 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
 | 
			
		||||
        uri_path = "sequence"
 | 
			
		||||
        if isinstance(query_sequence_strings, str) or isinstance(query_sequence_strings, NamedString):
 | 
			
		||||
            query_sequence_strings = [query_sequence_strings]
 | 
			
		||||
        
 | 
			
		||||
        for sequence_string in query_sequence_strings:
 | 
			
		||||
            async with self._http_client.post(uri_path, json={
 | 
			
		||||
            attempts = 0
 | 
			
		||||
            success = False
 | 
			
		||||
            last_error = None
 | 
			
		||||
            while not success and attempts < self._retry_limit:
 | 
			
		||||
                attempts += 1
 | 
			
		||||
                request = self._http_client.post(uri_path, json={
 | 
			
		||||
                    "sequence": sequence_string if isinstance(sequence_string, str) else sequence_string.sequence,
 | 
			
		||||
                    "partial_matches": True
 | 
			
		||||
            }) as response:
 | 
			
		||||
                })
 | 
			
		||||
                try:
 | 
			
		||||
                    async with request as response:
 | 
			
		||||
                        sequence_response: dict = await response.json()
 | 
			
		||||
 | 
			
		||||
                        if "exact_matches" in sequence_response:
 | 
			
		||||
@@ -90,7 +99,21 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
 | 
			
		||||
                                )
 | 
			
		||||
                                yield result_allele if isinstance(sequence_string, str) else (sequence_string.name, result_allele)
 | 
			
		||||
                        else:
 | 
			
		||||
                            if response.status == 200:
 | 
			
		||||
                                raise NoBIGSdbMatchesException(self._database_name, self._scheme_id, sequence_string.name if isinstance(sequence_string, NamedString) else None)
 | 
			
		||||
                            else:
 | 
			
		||||
                                raise BIGSdbResponseNotOkay(sequence_response)
 | 
			
		||||
                except (ConnectionError, ServerDisconnectedError, ClientOSError) as e: # Errors we will retry
 | 
			
		||||
                    last_error = e
 | 
			
		||||
                    success = False
 | 
			
		||||
                    await asyncio.sleep(5) # In case the connection issue is due to rate issues
 | 
			
		||||
                else:
 | 
			
		||||
                    success = True
 | 
			
		||||
            if not success and last_error is not None:
 | 
			
		||||
                try:
 | 
			
		||||
                    raise last_error
 | 
			
		||||
                except (ConnectionError, ServerDisconnectedError, ClientOSError) as e: # Non-fatal errors
 | 
			
		||||
                    yield Allele("error", "error", None)
 | 
			
		||||
 | 
			
		||||
    async def determine_mlst_st(self, alleles: Union[AsyncIterable[Union[Allele, tuple[str, Allele]]], Iterable[Union[Allele, tuple[str, Allele]]]]) -> Union[MLSTProfile, NamedMLSTProfile]:
 | 
			
		||||
        uri_path = "designations"
 | 
			
		||||
@@ -113,6 +136,13 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
 | 
			
		||||
        request_json = {
 | 
			
		||||
            "designations": allele_request_dict
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        attempts = 0
 | 
			
		||||
        success = False
 | 
			
		||||
        last_error = None
 | 
			
		||||
        while attempts < self._retry_limit and not success:
 | 
			
		||||
            attempts += 1
 | 
			
		||||
            try:
 | 
			
		||||
                async with self._http_client.post(uri_path, json=request_json) as response:
 | 
			
		||||
                    response_json: dict = await response.json()
 | 
			
		||||
                    allele_set: Set[Allele] = set()
 | 
			
		||||
@@ -129,6 +159,19 @@ class RemoteBIGSdbMLSTProfiler(BIGSdbMLSTProfiler):
 | 
			
		||||
                    if len(names_list) > 0:
 | 
			
		||||
                        result_mlst_profile = NamedMLSTProfile(str(tuple(names_list)) if len(set(names_list)) > 1 else names_list[0], result_mlst_profile)
 | 
			
		||||
                    return result_mlst_profile
 | 
			
		||||
            except (ConnectionError, ServerDisconnectedError, ClientOSError) as e:
 | 
			
		||||
                last_error = e
 | 
			
		||||
                success = False
 | 
			
		||||
                await asyncio.sleep(5)
 | 
			
		||||
            else:
 | 
			
		||||
                success = True
 | 
			
		||||
        try:
 | 
			
		||||
            if last_error is not None:
 | 
			
		||||
                raise last_error
 | 
			
		||||
        except (ConnectionError, ServerDisconnectedError, ClientOSError) as e:
 | 
			
		||||
                result_mlst_profile = NamedMLSTProfile((str(tuple(names_list)) if len(set(names_list)) > 1 else names_list[0]) + ":Error", None)
 | 
			
		||||
        raise ValueError("Last error was not recorded.")
 | 
			
		||||
 | 
			
		||||
                    
 | 
			
		||||
    async def profile_string(self, query_sequence_strings: Iterable[Union[NamedString, str]]) -> Union[NamedMLSTProfile, MLSTProfile]:
 | 
			
		||||
        alleles = self.determine_mlst_allele_variants(query_sequence_strings)
 | 
			
		||||
@@ -212,6 +255,16 @@ class BIGSdbIndex(AbstractAsyncContextManager):
 | 
			
		||||
    async def build_profiler_from_seqdefdb(self, local: bool, dbseqdef_name: str, scheme_id: int) -> BIGSdbMLSTProfiler:
 | 
			
		||||
        return get_BIGSdb_MLST_profiler(local, await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name), dbseqdef_name, scheme_id)
 | 
			
		||||
    
 | 
			
		||||
    async def get_scheme_loci(self, dbseqdef_name: str, scheme_id: int) -> list[str]:
        """Fetch the loci names listed for one scheme of a seqdef database.

        The scheme record is requested from the API root resolved for
        ``dbseqdef_name``; every entry of its ``"loci"`` list is reduced to
        its final path component.

        :param dbseqdef_name: name of the sequence definition database
        :param scheme_id: identifier of the scheme within that database
        :return: the final path component of each ``"loci"`` entry, in the
            order the response lists them
        """
        api_root = await self.get_bigsdb_api_from_seqdefdb(dbseqdef_name)
        scheme_url = f"{api_root}/db/{dbseqdef_name}/schemes/{scheme_id}"
        async with self._http_client.get(scheme_url) as response:
            scheme_record = await response.json()
            # Only the trailing segment of each entry is kept as the locus name.
            return [path.basename(locus_entry) for locus_entry in scheme_record["loci"]]
 | 
			
		||||
 | 
			
		||||
    async def close(self):
 | 
			
		||||
        await self._http_client.close()
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,8 @@ from typing import Union
 | 
			
		||||
class BIGSDbDatabaseAPIException(Exception):
    """Common base class for the BIGSdb database API exceptions in this module."""
 | 
			
		||||
 | 
			
		||||
class BIGSdbResponseNotOkay(BIGSDbDatabaseAPIException):
    """Signals a BIGSdb API response that could not be treated as successful.

    Adds no behaviour of its own; constructor arguments are handled by
    ``Exception`` unchanged.
    """
 | 
			
		||||
 | 
			
		||||
class NoBIGSdbMatchesException(BIGSDbDatabaseAPIException):
 | 
			
		||||
    def __init__(self, database_name: str, database_scheme_id: int, query_name: Union[None, str], *args):
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,7 @@
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
import csv
 | 
			
		||||
from os import PathLike
 | 
			
		||||
from typing import AsyncIterable, Collection, Mapping, Sequence, Union
 | 
			
		||||
from typing import AsyncIterable, Collection, Iterable, Mapping, Sequence, Union
 | 
			
		||||
 | 
			
		||||
from autobigs.engine.structures.mlst import Allele, MLSTProfile, NamedMLSTProfile
 | 
			
		||||
 | 
			
		||||
@@ -17,7 +17,7 @@ def alleles_to_text_map(alleles: Collection[Allele]) -> Mapping[str, Union[Seque
 | 
			
		||||
            result[locus] = tuple(result[locus]) # type: ignore
 | 
			
		||||
    return dict(result)
 | 
			
		||||
 | 
			
		||||
async def write_mlst_profiles_as_csv(mlst_profiles_iterable: AsyncIterable[NamedMLSTProfile], handle: Union[str, bytes, PathLike[str], PathLike[bytes]]) -> Sequence[str]:
 | 
			
		||||
async def write_mlst_profiles_as_csv(mlst_profiles_iterable: AsyncIterable[NamedMLSTProfile], handle: Union[str, bytes, PathLike[str], PathLike[bytes]], allele_names: Iterable[str]) -> Sequence[str]:
 | 
			
		||||
    failed = list()
 | 
			
		||||
    with open(handle, "w", newline='') as filehandle:
 | 
			
		||||
        header = None
 | 
			
		||||
@@ -30,7 +30,7 @@ async def write_mlst_profiles_as_csv(mlst_profiles_iterable: AsyncIterable[Named
 | 
			
		||||
                continue
 | 
			
		||||
            allele_mapping = alleles_to_text_map(mlst_profile.alleles)
 | 
			
		||||
            if writer is None:
 | 
			
		||||
                header = ["id", "st", "clonal-complex", *sorted(allele_mapping.keys())]
 | 
			
		||||
                header = ["id", "st", "clonal-complex", *sorted(allele_names)]
 | 
			
		||||
                writer = csv.DictWriter(filehandle, fieldnames=header)
 | 
			
		||||
                writer.writeheader()
 | 
			
		||||
            row_dictionary = {
 | 
			
		||||
 
 | 
			
		||||
@@ -222,3 +222,12 @@ class TestBIGSdbIndex:
 | 
			
		||||
                assert isinstance(profile, MLSTProfile)
 | 
			
		||||
                assert profile.clonal_complex == "ST-2 complex"
 | 
			
		||||
                assert profile.sequence_type == "1"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.parametrize(["bigsdb_name", "scheme_id", "expected"], [
        ("pubmlst_bordetella_seqdef", 3, ["adk", "fumC", "glyA", "tyrB", "icd", "pepA", "pgm"])
    ])
    async def test_bigsdb_index_fetches_loci_names(self, bigsdb_name, scheme_id, expected):
        """The loci fetched for a scheme match the expected names, ignoring order."""
        async with BIGSdbIndex() as index:
            fetched_loci = await index.get_scheme_loci(bigsdb_name, scheme_id)
            # Compared as sets: neither ordering nor duplicates are asserted.
            assert set(fetched_loci) == set(expected)
 | 
			
		||||
@@ -27,7 +27,7 @@ async def test_column_order_is_same_as_expected_file(dummy_alphabet_mlst_profile
 | 
			
		||||
    dummy_profiles = [dummy_alphabet_mlst_profile]
 | 
			
		||||
    with tempfile.TemporaryDirectory() as temp_dir:
 | 
			
		||||
        output_path = path.join(temp_dir, "out.csv")
 | 
			
		||||
        await write_mlst_profiles_as_csv(iterable_to_asynciterable(dummy_profiles), output_path)
 | 
			
		||||
        await write_mlst_profiles_as_csv(iterable_to_asynciterable(dummy_profiles), output_path, ["A", "D", "B", "C"])
 | 
			
		||||
        with open(output_path) as csv_handle:
 | 
			
		||||
            csv_reader = reader(csv_handle)
 | 
			
		||||
            lines = list(csv_reader)
 | 
			
		||||
@@ -38,7 +38,7 @@ async def test_csv_writing_sample_name_not_repeated_when_single_sequence(dummy_a
 | 
			
		||||
    dummy_profiles = [dummy_alphabet_mlst_profile]
 | 
			
		||||
    with tempfile.TemporaryDirectory() as temp_dir:
 | 
			
		||||
        output_path = path.join(temp_dir, "out.csv")
 | 
			
		||||
        await write_mlst_profiles_as_csv(iterable_to_asynciterable(dummy_profiles), output_path)
 | 
			
		||||
        await write_mlst_profiles_as_csv(iterable_to_asynciterable(dummy_profiles), output_path, ["A", "D", "B", "C"])
 | 
			
		||||
        with open(output_path) as csv_handle:
 | 
			
		||||
            csv_reader = reader(csv_handle)
 | 
			
		||||
            lines = list(csv_reader)
 | 
			
		||||
@@ -63,7 +63,7 @@ async def test_csv_writing_includes_asterisk_for_non_exact(dummy_alphabet_mlst_p
 | 
			
		||||
    dummy_profiles = [dummy_alphabet_mlst_profile]
 | 
			
		||||
    with tempfile.TemporaryDirectory() as temp_dir:
 | 
			
		||||
        output_path = path.join(temp_dir, "out.csv")
 | 
			
		||||
        await write_mlst_profiles_as_csv(iterable_to_asynciterable(dummy_profiles), output_path)
 | 
			
		||||
        await write_mlst_profiles_as_csv(iterable_to_asynciterable(dummy_profiles), output_path, ["A", "D", "B", "C"])
 | 
			
		||||
        with open(output_path) as csv_handle:
 | 
			
		||||
            csv_reader = reader(csv_handle)
 | 
			
		||||
            lines = list(csv_reader)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user