diff --git a/src/modvcfsamples/cli.py b/src/modvcfsamples/cli.py index 505e827..113cc96 100644 --- a/src/modvcfsamples/cli.py +++ b/src/modvcfsamples/cli.py @@ -3,38 +3,54 @@ import os from typing import Union from modvcfsamples import sample -def run(vcfs: list[str], only: list[str], gt: Union[int, None], output_dir: str): + +def run( + vcfs: list[str], + only: list[str], + gt: Union[int, None], + filename_replace: Union[str, None], + output_dir: str, +): for vcf in vcfs: vcf_records, header = sample.get_records_from_vcf(vcf) modified_vcfs = vcf_records modified_header = header if len(only) > 0: - modified_vcfs, modified_header = sample.keep_specific_call_data(modified_vcfs, modified_header, *only) + modified_vcfs, modified_header = sample.keep_specific_call_data( + modified_vcfs, modified_header, *only + ) if gt is not None: - modified_vcfs, modified_header = sample.normalize_gt_to_length(modified_vcfs, modified_header, gt) - sample.write_records_to_vcf(modified_vcfs, modified_header, os.path.join(output_dir, os.path.basename(vcf))) + modified_vcfs = sample.normalize_gt_to_length(modified_vcfs, gt) + + if filename_replace is not None: + modified_vcfs, modified_header = sample.replace_sample_names( + modified_vcfs, + modified_header, + filename_replace, + os.path.basename(os.path.splitext(vcf)[0]), + ) + + sample.write_records_to_vcf( + modified_vcfs, + modified_header, + os.path.join(output_dir, os.path.basename(vcf)), + ) + def main(): parser = argparse.ArgumentParser() parser.add_argument( - "vcfs", - help="The VCFs to run filtering on", - nargs="+", - metavar="I", - type=str + "vcfs", help="The VCFs to run filtering on", nargs="+", metavar="I", type=str ) parser.add_argument( - "output_dir", - help="The output directory", - metavar="O", - type=str + "output_dir", help="The output directory", metavar="O", type=str ) parser.add_argument( "--gt-norm", "-g", help="Resizes haploid genotypes to n-ploid by repeating it.", type=int, - required=False + required=False, ) parser.add_argument( "--only", @@ -42,11 +58,20 @@ def main(): help="Remove everything but the sample datatype", action="append", type=str, - required=False + required=False, + ) + parser.add_argument( + "--filename-replace", + "-e", + help="Replaces (parts of or entire) sample names with file names. " + "Ex. '-e unknown' will replace all samples called 'unknown' with the " + "filename the samples came from.", + type=str, + required=False, ) args = parser.parse_args() - run(args.vcfs, args.only, args.gt_norm, args.output_dir) + run(args.vcfs, args.only, args.gt_norm, args.filename_replace, args.output_dir) if __name__ == "__main__": diff --git a/src/modvcfsamples/sample.py b/src/modvcfsamples/sample.py index daddc4e..0f4673f 100644 --- a/src/modvcfsamples/sample.py +++ b/src/modvcfsamples/sample.py @@ -17,9 +17,13 @@ def get_records_from_vcf(path: str): return vcf_records, reader.header -def keep_specific_call_data(records: list[vcfpy.Record], header: vcfpy.Header, *datatypes: str): +def keep_specific_call_data( + records: list[vcfpy.Record], header: vcfpy.Header, *datatypes: str +): lines_kept = [line for line in header.lines if line.key != "FORMAT"] - lines_kept.extend([line for line in header.get_lines("FORMAT") if line.id in datatypes]) + lines_kept.extend( + [line for line in header.get_lines("FORMAT") if line.id in datatypes] + ) modified_header = vcfpy.Header(lines_kept, header.samples) modified_records = [] @@ -45,18 +49,19 @@ def keep_specific_call_data(records: list[vcfpy.Record], header: vcfpy.Header, * modified_records.append(modified_record) return modified_records, modified_header -def normalize_gt_to_length(records: list[vcfpy.Record], header: vcfpy.Header, num: int): + +def normalize_gt_to_length(records: list[vcfpy.Record], num: int): modified_records = [] for record in records: modified_calls = [] for call in record.calls: - gt_parts = call.data['GT'].replace("/", "|").split("|") + gt_parts = call.data["GT"].replace("/", "|").split("|") modified_call = deepcopy(call) if len(gt_parts) > 1: # TODO Add logging and output if gt_parts is longer. pass else: - modified_call.data['GT'] = "|".join([gt_parts[0]] * num) + modified_call.data["GT"] = "|".join([gt_parts[0]] * num) modified_calls.append(modified_call) modified_record = vcfpy.Record( record.CHROM, @@ -71,9 +76,44 @@ def normalize_gt_to_length(records: list[vcfpy.Record], header: vcfpy.Header, nu modified_calls, ) modified_records.append(modified_record) - return modified_records, header + return modified_records -def write_records_to_vcf(records: Iterable[vcfpy.Record], header: vcfpy.Header, path: str): + +def replace_sample_names( + records: list[vcfpy.Record], header: vcfpy.Header, to_replace: str, replaced: str +): + modified_header = vcfpy.Header( + header.lines, + vcfpy.SamplesInfos( + [sample.replace(to_replace, replaced) for sample in header.samples.names] + ), + ) + modified_records = [] + for record in records: + modified_calls = [] + for call in record.calls: + modified_calls.append( + Call(call.sample.replace(to_replace, replaced), call.data, call.site) + ) + modified_record = vcfpy.Record( + record.CHROM, + record.POS, + record.ID, + record.REF, + record.ALT, + record.QUAL, + record.FILTER, + record.INFO, + record.FORMAT, + modified_calls, + ) + modified_records.append(modified_record) + return modified_records, modified_header + + +def write_records_to_vcf( + records: Iterable[vcfpy.Record], header: vcfpy.Header, path: str +): os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "w") as vcf_stream: writer = vcfpy.Writer.from_stream(vcf_stream, header) diff --git a/tests/modvcfsamples/test_sample.py b/tests/modvcfsamples/test_sample.py index 2f1002b..a2bec09 100644 --- a/tests/modvcfsamples/test_sample.py +++ b/tests/modvcfsamples/test_sample.py @@ -2,6 +2,7 @@ from modvcfsamples.sample import ( keep_specific_call_data, get_records_from_vcf, normalize_gt_to_length, + replace_sample_names, ) import os @@ -32,7 +33,7 @@ def test_normalize_gt_to_length_not_empty(): records, header = get_records_from_vcf( os.path.abspath("tests/resources/test_files_shortened_haploid.vcf") ) - modified_records, _ = normalize_gt_to_length(records, header, 4) + modified_records = normalize_gt_to_length(records, 4) assert len(modified_records) > 0 @@ -40,7 +41,37 @@ def test_normalize_gt_to_length_gt_normalized(): records, header = get_records_from_vcf( os.path.abspath("tests/resources/test_files_shortened_haploid.vcf") ) - modified_records, _ = normalize_gt_to_length(records, header, 4) + modified_records = normalize_gt_to_length(records, 4) for modified_record in modified_records: for call in modified_record.calls: assert len(call.data["GT"].split("|")) == 4 or "/" in call.data["GT"] + +def test_replace_sample_names_record_modified_correctly(): + records, header = get_records_from_vcf( + os.path.abspath("tests/resources/test_files_shortened_haploid.vcf") + ) + modified_records, modified_headers = replace_sample_names(records, header, "Gambian", "different") + for modified_record in modified_records: + call_name_replaced = False + for call in modified_record.calls: + assert call.sample != "Gambian" + if call.sample == "different": + call_name_replaced = True + assert call_name_replaced + + +def test_replace_sample_names_header_modified_correctly(): + records, header = get_records_from_vcf( + os.path.abspath("tests/resources/test_files_shortened_haploid.vcf") + ) + modified_records, modified_headers = replace_sample_names(records, header, "Gambian", "different") + replaced_found = False + original_found = False + for name in modified_headers.samples.names: + if name == "Gambian": + original_found = True + + if name == "different": + replaced_found = True + assert replaced_found + assert not original_found \ No newline at end of file