Added automatic sample name replacement
Some checks failed
ydeng/modvcfsamples/pipeline/head There was a failure building this commit

This commit is contained in:
Harrison Deng 2023-06-28 05:51:52 +00:00
parent 02f52560eb
commit 2d1d69f95f
3 changed files with 121 additions and 25 deletions

View File

@ -3,38 +3,54 @@ import os
from typing import Union
from modvcfsamples import sample
def run(vcfs: list[str], only: list[str], gt: Union[int, None], output_dir: str):
def run(
vcfs: list[str],
only: list[str],
gt: Union[int, None],
filename_replace: Union[str, None],
output_dir: str,
):
for vcf in vcfs:
vcf_records, header = sample.get_records_from_vcf(vcf)
modified_vcfs = vcf_records
modified_header = header
if len(only) > 0:
modified_vcfs, modified_header = sample.keep_specific_call_data(modified_vcfs, modified_header, *only)
modified_vcfs, modified_header = sample.keep_specific_call_data(
modified_vcfs, modified_header, *only
)
if gt is not None:
modified_vcfs, modified_header = sample.normalize_gt_to_length(modified_vcfs, modified_header, gt)
sample.write_records_to_vcf(modified_vcfs, modified_header, os.path.join(output_dir, os.path.basename(vcf)))
modified_vcfs = sample.normalize_gt_to_length(modified_vcfs, gt)
if filename_replace is not None:
modified_vcfs, modified_header = sample.replace_sample_names(
modified_vcfs,
modified_header,
filename_replace,
os.path.basename(os.path.splitext(vcf)[0]),
)
sample.write_records_to_vcf(
modified_vcfs,
modified_header,
os.path.join(output_dir, os.path.basename(vcf)),
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"vcfs",
help="The VCFs to run filtering on",
nargs="+",
metavar="I",
type=str
"vcfs", help="The VCFs to run filtering on", nargs="+", metavar="I", type=str
)
parser.add_argument(
"output_dir",
help="The output directory",
metavar="O",
type=str
"output_dir", help="The output directory", metavar="O", type=str
)
parser.add_argument(
"--gt-norm",
"-g",
help="Resizes haploid genotypes to n-ploid by repeating it.",
type=int,
required=False
required=False,
)
parser.add_argument(
"--only",
@ -42,11 +58,20 @@ def main():
help="Remove everything but the sample datatype",
action="append",
type=str,
required=False
required=False,
)
parser.add_argument(
"--filename-replace",
"-e",
help="Replaces (parts of or entire) sample names with file names. "
"Ex. '-e unknown' will replace all samples called 'unknown' with the "
"filename the samples came from.",
type=str,
required=False,
)
args = parser.parse_args()
run(args.vcfs, args.only, args.gt_norm, args.output_dir)
run(args.vcfs, args.only, args.gt_norm, args.filename_replace, args.output_dir)
if __name__ == "__main__":

View File

@ -17,9 +17,13 @@ def get_records_from_vcf(path: str):
return vcf_records, reader.header
def keep_specific_call_data(records: list[vcfpy.Record], header: vcfpy.Header, *datatypes: str):
def keep_specific_call_data(
records: list[vcfpy.Record], header: vcfpy.Header, *datatypes: str
):
lines_kept = [line for line in header.lines if line.key != "FORMAT"]
lines_kept.extend([line for line in header.get_lines("FORMAT") if line.id in datatypes])
lines_kept.extend(
[line for line in header.get_lines("FORMAT") if line.id in datatypes]
)
modified_header = vcfpy.Header(lines_kept, header.samples)
modified_records = []
@ -45,18 +49,19 @@ def keep_specific_call_data(records: list[vcfpy.Record], header: vcfpy.Header, *
modified_records.append(modified_record)
return modified_records, modified_header
def normalize_gt_to_length(records: list[vcfpy.Record], header: vcfpy.Header, num: int):
def normalize_gt_to_length(records: list[vcfpy.Record], num: int):
modified_records = []
for record in records:
modified_calls = []
for call in record.calls:
gt_parts = call.data['GT'].replace("/", "|").split("|")
gt_parts = call.data["GT"].replace("/", "|").split("|")
modified_call = deepcopy(call)
if len(gt_parts) > 1:
# TODO Add logging and output if gt_parts is longer.
pass
else:
modified_call.data['GT'] = "|".join([gt_parts[0]] * num)
modified_call.data["GT"] = "|".join([gt_parts[0]] * num)
modified_calls.append(modified_call)
modified_record = vcfpy.Record(
record.CHROM,
@ -71,9 +76,44 @@ def normalize_gt_to_length(records: list[vcfpy.Record], header: vcfpy.Header, nu
modified_calls,
)
modified_records.append(modified_record)
return modified_records, header
return modified_records
def write_records_to_vcf(records: Iterable[vcfpy.Record], header: vcfpy.Header, path: str):
def replace_sample_names(
records: list[vcfpy.Record], header: vcfpy.Header, to_replace: str, replaced: str
):
modified_header = vcfpy.Header(
header.lines,
vcfpy.SamplesInfos(
[sample.replace(to_replace, replaced) for sample in header.samples.names]
),
)
modified_records = []
for record in records:
modified_calls = []
for call in record.calls:
modified_calls.append(
Call(call.sample.replace(to_replace, replaced), call.data, call.site)
)
modified_record = vcfpy.Record(
record.CHROM,
record.POS,
record.ID,
record.REF,
record.ALT,
record.QUAL,
record.FILTER,
record.INFO,
record.FORMAT,
modified_calls,
)
modified_records.append(modified_record)
return modified_records, modified_header
def write_records_to_vcf(
records: Iterable[vcfpy.Record], header: vcfpy.Header, path: str
):
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w") as vcf_stream:
writer = vcfpy.Writer.from_stream(vcf_stream, header)

View File

@ -2,6 +2,7 @@ from modvcfsamples.sample import (
keep_specific_call_data,
get_records_from_vcf,
normalize_gt_to_length,
replace_sample_names,
)
import os
@ -32,7 +33,7 @@ def test_normalize_gt_to_length_not_empty():
records, header = get_records_from_vcf(
os.path.abspath("tests/resources/test_files_shortened_haploid.vcf")
)
modified_records, _ = normalize_gt_to_length(records, header, 4)
modified_records = normalize_gt_to_length(records, 4)
assert len(modified_records) > 0
@ -40,7 +41,37 @@ def test_normalize_gt_to_length_gt_normalized():
records, header = get_records_from_vcf(
os.path.abspath("tests/resources/test_files_shortened_haploid.vcf")
)
modified_records, _ = normalize_gt_to_length(records, header, 4)
modified_records = normalize_gt_to_length(records, 4)
for modified_record in modified_records:
for call in modified_record.calls:
assert len(call.data["GT"].split("|")) == 4 or "/" in call.data["GT"]
def test_replace_sample_names_record_modified_correctly():
records, header = get_records_from_vcf(
os.path.abspath("tests/resources/test_files_shortened_haploid.vcf")
)
modified_records, modified_headers = replace_sample_names(records, header, "Gambian", "different")
for modified_record in modified_records:
call_name_replaced = False
for call in modified_record.calls:
assert call.sample != "Gambian"
if call.sample == "different":
call_name_replaced = True
assert call_name_replaced
def test_replace_sample_names_header_modified_correctly():
records, header = get_records_from_vcf(
os.path.abspath("tests/resources/test_files_shortened_haploid.vcf")
)
modified_records, modified_headers = replace_sample_names(records, header, "Gambian", "different")
replaced_found = False
original_found = False
for name in modified_headers.samples.names:
if name == "Gambian":
original_found = True
if name == "different":
replaced_found = True
assert replaced_found
assert not original_found