diff --git a/src/nsbdiagnosistoolkit/cli/aggregator.py b/src/nsbdiagnosistoolkit/cli/aggregator.py index fd2b07e..d5a1535 100644 --- a/src/nsbdiagnosistoolkit/cli/aggregator.py +++ b/src/nsbdiagnosistoolkit/cli/aggregator.py @@ -7,17 +7,15 @@ from nsbdiagnosistoolkit.engine.local.fasta import read_fasta from nsbdiagnosistoolkit.engine.remote.databases.institutpasteur.profiling import InstitutPasteurProfiler -async def aggregate_sequences(fastas: Iterable[str], abifs: Iterable[str]) -> AsyncGenerator[str, Any]: +async def read_all_fastas(fastas: Iterable[str]) -> AsyncGenerator[NamedString, Any]: for fasta_path in fastas: async for fasta in read_fasta(fasta_path): - yield fasta.sequence - for abif_path in abifs: - abif_data = await read_abif(abif_path) - yield "".join(abif_data.sequence) + yield fasta -async def profile_all_genetic_strings(strings: AsyncIterable[str], database_name: str) -> Sequence[MLSTProfile]: + +async def profile_all_genetic_strings(strings: AsyncIterable[NamedString], database_name: str) -> Sequence[tuple[str, MLSTProfile]]: profiles = list() async with InstitutPasteurProfiler(database_name=database_name) as profiler: - async for string in strings: - profiles.append(await profiler.profile_string(string)) + async for named_string in strings: + profiles.append((named_string.name, await profiler.profile_string(named_string.sequence))) return profiles \ No newline at end of file diff --git a/src/nsbdiagnosistoolkit/cli/root.py b/src/nsbdiagnosistoolkit/cli/root.py index c573965..75dc033 100644 --- a/src/nsbdiagnosistoolkit/cli/root.py +++ b/src/nsbdiagnosistoolkit/cli/root.py @@ -57,14 +57,17 @@ parser.add_argument( def cli(): args = parser.parse_args() - gen_strings = aggregator.aggregate_sequences(args.fastas, args.abifs) + gen_strings = aggregator.read_all_fastas(args.fastas) os.makedirs(args.out, exist_ok=True) if args.institut_pasteur_db is not None: mlst_profiles = aggregator.profile_all_genetic_strings( gen_strings, args.institut_pasteur_db) asyncio.run(write_mlst_profiles_as_csv( - asyncio.run(mlst_profiles), str(path.join(args.out, "MLST_" + args.run_name + ".csv")))) + asyncio.run(mlst_profiles), + str(path.join(args.out, "MLST_" + args.run_name + ".csv") + ) + )) if __name__ == "__main__": - cli() \ No newline at end of file + cli() diff --git a/src/nsbdiagnosistoolkit/engine/local/csv.py b/src/nsbdiagnosistoolkit/engine/local/csv.py index 6165606..9d17168 100644 --- a/src/nsbdiagnosistoolkit/engine/local/csv.py +++ b/src/nsbdiagnosistoolkit/engine/local/csv.py @@ -1,7 +1,7 @@ import csv from io import TextIOWrapper from os import PathLike -from typing import AsyncIterable, Iterable, Mapping, Sequence, Union +from typing import AsyncIterable, Iterable, Mapping, Sequence, Tuple, Union from nsbdiagnosistoolkit.engine.data.MLST import Allele, MLSTProfile @@ -15,16 +15,17 @@ def loci_alleles_variants_from_loci(alleles_map: Mapping[str, Sequence[Allele]]) return result_dict -async def write_mlst_profiles_as_csv(mlst_profiles_iterable: Iterable[MLSTProfile], handle: Union[str, bytes, PathLike[str], PathLike[bytes]]): +async def write_mlst_profiles_as_csv(mlst_profiles_iterable: Iterable[tuple[str, MLSTProfile]], handle: Union[str, bytes, PathLike[str], PathLike[bytes]]): mlst_profiles = list(mlst_profiles_iterable) - header = ["st", "clonal-complex", *mlst_profiles[0].alleles.keys()] + header = ["name", "st", "clonal-complex", *mlst_profiles[0][1].alleles.keys()] with open(handle, "w", newline='') as filehandle: writer = csv.DictWriter(filehandle, fieldnames=header) writer.writeheader() - for mlst_profile in mlst_profiles: + for name, mlst_profile in mlst_profiles: row_dictionary = { "st": mlst_profile.sequence_type, "clonal-complex": mlst_profile.clonal_complex, + "name": name, **loci_alleles_variants_from_loci(mlst_profile.alleles) }