generated from ydeng/python-program
Added automatic sample name replacement
Some checks failed
ydeng/modvcfsamples/pipeline/head There was a failure building this commit
Some checks failed
ydeng/modvcfsamples/pipeline/head There was a failure building this commit
This commit is contained in:
parent
02f52560eb
commit
2d1d69f95f
@ -3,38 +3,54 @@ import os
|
||||
from typing import Union
|
||||
from modvcfsamples import sample
|
||||
|
||||
def run(
    vcfs: list[str],
    only: Union[list[str], None],
    gt: Union[int, None],
    filename_replace: Union[str, None],
    output_dir: str,
):
    """Apply the requested sample modifications to each VCF and write the result.

    Every input file is read, passed through the enabled steps in a fixed
    order (FORMAT filtering, genotype normalisation, sample renaming) and
    written to ``output_dir`` under the input file's base name.

    Args:
        vcfs: Paths of the VCF files to process.
        only: FORMAT datatypes to keep; ``None`` or an empty list skips the
            filtering step.
        gt: Target length passed to ``sample.normalize_gt_to_length``;
            ``None`` skips the step.
        filename_replace: Substring of sample names to replace with the VCF's
            file name (base name without its last extension); ``None`` skips
            the step.
        output_dir: Directory the modified VCFs are written to.
    """
    for vcf in vcfs:
        vcf_records, header = sample.get_records_from_vcf(vcf)
        modified_vcfs = vcf_records
        modified_header = header

        # argparse's "append" action leaves the attribute as None when the
        # flag is never given, so test truthiness rather than calling len().
        if only:
            modified_vcfs, modified_header = sample.keep_specific_call_data(
                modified_vcfs, modified_header, *only
            )

        if gt is not None:
            modified_vcfs = sample.normalize_gt_to_length(modified_vcfs, gt)

        if filename_replace is not None:
            modified_vcfs, modified_header = sample.replace_sample_names(
                modified_vcfs,
                modified_header,
                filename_replace,
                os.path.basename(os.path.splitext(vcf)[0]),
            )

        sample.write_records_to_vcf(
            modified_vcfs,
            modified_header,
            os.path.join(output_dir, os.path.basename(vcf)),
        )
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"vcfs",
|
||||
help="The VCFs to run filtering on",
|
||||
nargs="+",
|
||||
metavar="I",
|
||||
type=str
|
||||
"vcfs", help="The VCFs to run filtering on", nargs="+", metavar="I", type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
"output_dir",
|
||||
help="The output directory",
|
||||
metavar="O",
|
||||
type=str
|
||||
"output_dir", help="The output directory", metavar="O", type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gt-norm",
|
||||
"-g",
|
||||
help="Resizes haploid genotypes to n-ploid by repeating it.",
|
||||
type=int,
|
||||
required=False
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--only",
|
||||
@ -42,11 +58,20 @@ def main():
|
||||
help="Remove everything but the sample datatype",
|
||||
action="append",
|
||||
type=str,
|
||||
required=False
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--filename-replace",
|
||||
"-e",
|
||||
help="Replaces (parts of or entire) sample names with file names. "
|
||||
"Ex. '-e unknown' will replace all samples called 'unknown' with the "
|
||||
"filename the samples came from.",
|
||||
type=str,
|
||||
required=False,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
run(args.vcfs, args.only, args.gt_norm, args.output_dir)
|
||||
run(args.vcfs, args.only, args.gt_norm, args.filename_replace, args.output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,9 +17,13 @@ def get_records_from_vcf(path: str):
|
||||
return vcf_records, reader.header
|
||||
|
||||
|
||||
def keep_specific_call_data(records: list[vcfpy.Record], header: vcfpy.Header, *datatypes: str):
|
||||
def keep_specific_call_data(
|
||||
records: list[vcfpy.Record], header: vcfpy.Header, *datatypes: str
|
||||
):
|
||||
lines_kept = [line for line in header.lines if line.key != "FORMAT"]
|
||||
lines_kept.extend([line for line in header.get_lines("FORMAT") if line.id in datatypes])
|
||||
lines_kept.extend(
|
||||
[line for line in header.get_lines("FORMAT") if line.id in datatypes]
|
||||
)
|
||||
modified_header = vcfpy.Header(lines_kept, header.samples)
|
||||
modified_records = []
|
||||
|
||||
@ -45,18 +49,19 @@ def keep_specific_call_data(records: list[vcfpy.Record], header: vcfpy.Header, *
|
||||
modified_records.append(modified_record)
|
||||
return modified_records, modified_header
|
||||
|
||||
def normalize_gt_to_length(records: list[vcfpy.Record], header: vcfpy.Header, num: int):
|
||||
|
||||
def normalize_gt_to_length(records: list[vcfpy.Record], num: int):
|
||||
modified_records = []
|
||||
for record in records:
|
||||
modified_calls = []
|
||||
for call in record.calls:
|
||||
gt_parts = call.data['GT'].replace("/", "|").split("|")
|
||||
gt_parts = call.data["GT"].replace("/", "|").split("|")
|
||||
modified_call = deepcopy(call)
|
||||
if len(gt_parts) > 1:
|
||||
# TODO Add logging and output if gt_parts is longer.
|
||||
pass
|
||||
else:
|
||||
modified_call.data['GT'] = "|".join([gt_parts[0]] * num)
|
||||
modified_call.data["GT"] = "|".join([gt_parts[0]] * num)
|
||||
modified_calls.append(modified_call)
|
||||
modified_record = vcfpy.Record(
|
||||
record.CHROM,
|
||||
@ -71,9 +76,44 @@ def normalize_gt_to_length(records: list[vcfpy.Record], header: vcfpy.Header, nu
|
||||
modified_calls,
|
||||
)
|
||||
modified_records.append(modified_record)
|
||||
return modified_records, header
|
||||
return modified_records
|
||||
|
||||
def write_records_to_vcf(records: Iterable[vcfpy.Record], header: vcfpy.Header, path: str):
|
||||
|
||||
def replace_sample_names(
    records: list[vcfpy.Record], header: vcfpy.Header, to_replace: str, replaced: str
):
    """Return copies of the records and header with sample names rewritten.

    Every occurrence of ``to_replace`` inside a sample name is substituted
    with ``replaced`` (plain ``str.replace`` semantics), both in the header's
    sample list and in the sample attached to each call.

    Args:
        records: Records whose calls should be renamed.
        header: Header whose sample names should be renamed.
        to_replace: Substring to look for in each sample name.
        replaced: Text inserted in place of ``to_replace``.

    Returns:
        A ``(records, header)`` tuple holding the newly built records and the
        newly built header; the header lines themselves are reused as-is.
    """
    renamed_samples = vcfpy.SamplesInfos(
        [name.replace(to_replace, replaced) for name in header.samples.names]
    )
    modified_header = vcfpy.Header(header.lines, renamed_samples)

    modified_records = []
    for record in records:
        # Rebuild each call so that only its sample name differs.
        renamed_calls = [
            Call(call.sample.replace(to_replace, replaced), call.data, call.site)
            for call in record.calls
        ]
        modified_records.append(
            vcfpy.Record(
                record.CHROM,
                record.POS,
                record.ID,
                record.REF,
                record.ALT,
                record.QUAL,
                record.FILTER,
                record.INFO,
                record.FORMAT,
                renamed_calls,
            )
        )
    return modified_records, modified_header
|
||||
|
||||
|
||||
def write_records_to_vcf(
|
||||
records: Iterable[vcfpy.Record], header: vcfpy.Header, path: str
|
||||
):
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
with open(path, "w") as vcf_stream:
|
||||
writer = vcfpy.Writer.from_stream(vcf_stream, header)
|
||||
|
@ -2,6 +2,7 @@ from modvcfsamples.sample import (
|
||||
keep_specific_call_data,
|
||||
get_records_from_vcf,
|
||||
normalize_gt_to_length,
|
||||
replace_sample_names,
|
||||
)
|
||||
import os
|
||||
|
||||
@ -32,7 +33,7 @@ def test_normalize_gt_to_length_not_empty():
|
||||
records, header = get_records_from_vcf(
|
||||
os.path.abspath("tests/resources/test_files_shortened_haploid.vcf")
|
||||
)
|
||||
modified_records, _ = normalize_gt_to_length(records, header, 4)
|
||||
modified_records = normalize_gt_to_length(records, 4)
|
||||
assert len(modified_records) > 0
|
||||
|
||||
|
||||
@ -40,7 +41,37 @@ def test_normalize_gt_to_length_gt_normalized():
|
||||
records, header = get_records_from_vcf(
|
||||
os.path.abspath("tests/resources/test_files_shortened_haploid.vcf")
|
||||
)
|
||||
modified_records, _ = normalize_gt_to_length(records, header, 4)
|
||||
modified_records = normalize_gt_to_length(records, 4)
|
||||
for modified_record in modified_records:
|
||||
for call in modified_record.calls:
|
||||
assert len(call.data["GT"].split("|")) == 4 or "/" in call.data["GT"]
|
||||
|
||||
def test_replace_sample_names_record_modified_correctly():
    """Each record's calls carry the new sample name and never the old one."""
    vcf_path = os.path.abspath("tests/resources/test_files_shortened_haploid.vcf")
    records, header = get_records_from_vcf(vcf_path)
    modified_records, _ = replace_sample_names(
        records, header, "Gambian", "different"
    )
    for modified_record in modified_records:
        call_samples = [call.sample for call in modified_record.calls]
        # The original name must be gone from every call of the record...
        assert "Gambian" not in call_samples
        # ...and at least one call must now carry the replacement name.
        assert "different" in call_samples
|
||||
|
||||
|
||||
def test_replace_sample_names_header_modified_correctly():
    """The header's sample list holds the new name and not the original."""
    vcf_path = os.path.abspath("tests/resources/test_files_shortened_haploid.vcf")
    records, header = get_records_from_vcf(vcf_path)
    _, modified_headers = replace_sample_names(
        records, header, "Gambian", "different"
    )
    header_names = list(modified_headers.samples.names)
    assert "different" in header_names
    assert "Gambian" not in header_names
|
Loading…
x
Reference in New Issue
Block a user