"""_summary_
author: Herman Huanan Zhao
email: hermanzhaozzzz@gmail.com
homepage: https://github.com/hermanzhaozzzz
_description_
example 1:
bioat list
<in shell>:
$ bioat list
<in python console>:
>>> from bioat.cli import Cli
>>> bioat = Cli()
>>> bioat.list()
>>> print(bioat.list())
example 2:
_example_
"""
import gzip
import os.path
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from Bio.Seq import Seq
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle
from tabulate import tabulate
from bioat.exceptions import (
BioatFileFormatError,
BioatInvalidInputError,
BioatInvalidOptionError,
)
from bioat.lib.libalignment import instantiate_pairwise_aligner
from bioat.lib.libcolor import convert_hex_to_rgb, make_color_list, map_color
from bioat.lib.libcrispr import TARGET_SEQ_LIB, run_target_seq_align
from bioat.lib.libpandas import set_option
from bioat.logger import LoggerManager
# Module-level logger manager; the class body and the methods below call
# `lm.set_names(...)` and log through `lm.logger`.
lm = LoggerManager(mod_name="bioat.target_seq")
# Import-time call into bioat.lib.libpandas.
# NOTE(review): presumably applies the project's pandas options with ERROR-level
# logging; confirm against `bioat.lib.libpandas.set_option`.
set_option(log_level='ERROR')
[docs]
class TargetSeq:
"""Target Deep Sequencing toolbox."""
lm.set_names(cls_name="TargetSeq")
def __init__(self):
    """Create a TargetSeq toolbox instance; no attributes are initialised."""
[docs]
def region_heatmap(
    self,
    input_table: str,
    output_fig: str,
    target_seq: str | None = None,  # sgRNA
    reference_seq: str | None = None,  # target locus
    input_table_header: bool = True,
    output_fig_fmt: str = 'pdf',
    output_fig_dpi: int = 100,
    show_indel: bool = True,
    show_index: bool = True,
    box_border: bool = False,
    box_space: float = 0.03,
    min_color: tuple = (250, 239, 230),
    max_color: tuple = (154, 104, 57),
    min_ratio: float = 0.001,
    max_ratio: float = 0.99,
    region_extend_length: int = 5,
    local_alignment_scoring_matrix: tuple = (5, -4, -24, -8),
    local_alignment_min_score: int = 15,
    PAM_priority_weight: float = 1.0,
    get_built_in_target_seq: bool = False,
    log_level: str = 'INFO'
):
    """Plot region mutation information using a table generated by `bioat bam mpileup_to_table`.

    The table is read with pandas, the target sequence is aligned against the
    reference sequence and against its reverse complement, the better-scoring
    alignment is kept, and the rows around the aligned window are drawn as a
    stack of panels (reference index, target sequence, reference sequence,
    per-base counts and per-base ratios) that is written to `output_fig`.

    The table columns `chr_index`, `ref_base`, `A`, `G`, `C`, `T`,
    `del_count` and `insert_count` are accessed by name.

    Args:
        input_table (str): Path to the table generated by `bioat bam mpileup_to_table`.
            This table should contain base mutation information for a short
            genome region (no more than 1k nt).
        output_fig (str): Path to the output figure file.
        target_seq (str, optional): Target sequence to align against the reference
            sequence, or a key of the built-in target sequence library. A PAM is
            marked with a pair of `^`. Examples:
            - 'GAGTCCGAGCAGAAGAAGAA^GGG^' for SpCas9-BE (PAM: ^GGG^).
            - '^TTTA^GCCCCAATAATCCCCACATGTCA' for cpf1-BE (PAM: ^TTTA^).
            - 'TGCTAGTAACCACGTTCTCCTGATCAAATATCACTCTCCTACTTACAGGA' for no PAM.
            When None, the reference sequence itself is used. Defaults to None.
        reference_seq (str, optional): Custom reference sequence that replaces the
            `ref_base` column of the table. Can be a FASTA file (optionally
            `.gz`; the first line is skipped as header) or a DNA sequence.
            Defaults to None.
        input_table_header (bool, optional): Whether `input_table` contains a header
            line. Defaults to True.
        output_fig_fmt (str, optional): Format passed to `savefig` ("pdf" or "png").
            Defaults to "pdf".
        output_fig_dpi (int, optional): DPI for the output figure. Defaults to 100.
        show_indel (bool, optional): Whether to add the Del/Ins count and ratio
            panels. Defaults to True.
        show_index (bool, optional): Accepted for interface compatibility; this
            implementation always draws the index row. Defaults to True.
        box_border (bool, optional): Whether to draw box borders. Defaults to False.
        box_space (float, optional): Horizontal gap between two neighbouring boxes.
            Defaults to 0.03.
        min_color (tuple, optional): Minimum color of the heatmap in RGB format.
            Defaults to (250, 239, 230).
        max_color (tuple, optional): Maximum color of the heatmap in RGB format.
            Defaults to (154, 104, 57).
        min_ratio (float, optional): Start of the color-break scale; ratios below
            it are meant to be shown as white. Defaults to 0.001.
        max_ratio (float, optional): End of the color-break scale. Defaults to 0.99.
        region_extend_length (int, optional): Number of positions to add on either
            side of the aligned window. Defaults to 5.
        local_alignment_scoring_matrix (tuple, optional): Alignment scoring
            parameters, unpacked into `instantiate_pairwise_aligner`:
            (<align_match_score>, <align_mismatch_score>, <align_gap_open_score>,
            <align_gap_extension_score>). Defaults to (5, -4, -24, -8).
        local_alignment_min_score (int, optional): Minimum score that at least one
            of the two alignments must reach. Defaults to 15.
        PAM_priority_weight (float, optional): Weight stored with the parsed PAM.
            The alignment calls in this function do not receive the PAM
            information. Defaults to 1.0.
        get_built_in_target_seq (bool, optional): Set to True to log the built-in
            target sequence library and exit. Defaults to False.
        log_level (str, optional): Logging level. One of 'CRITICAL', 'ERROR',
            'WARNING', 'INFO', 'DEBUG', 'NOTSET'. Defaults to "INFO".

    Returns:
        None

    Raises:
        SystemExit: With code 0 after listing the built-in target sequences;
            with code 1 when `target_seq` has an unsupported type or when
            neither alignment reaches `local_alignment_min_score`.
        BioatFileFormatError: When the reference sequence still contains '>'
            after the header line has been dropped.

    Examples:
        Generate the mutation table (shell)::

            samtools mpileup test_sorted.bam --reference HK4-AOut-1.ref.upper.fa | gzip > test_sorted.mpileup.gz
            bioat bam mpileup_to_table test_sorted.mpileup.gz > test_sorted.mpileup.info.tsv

        Generate the region heatmap (shell)::

            bioat target_seq region_heatmap --input_table test_sorted.mpileup.info.tsv --output_fig test_sorted.mpileup.info.pdf
    """
    lm.set_names(func_name="region_heatmap")
    # NOTE(review): region_heatmap_compare calls `lm.set_level(...)` for the same
    # purpose; confirm which of the two names LoggerManager actually provides.
    lm.set_log_level(log_level)

    # ------------------------------------------------------------------
    # 1. resolve target_seq (sgRNA)
    # ------------------------------------------------------------------
    if get_built_in_target_seq:
        lm.logger.info(
            "You can use <key> in built-in <target_seq> to represent your target_seq:\n"
            + "\t<key>\t<target_seq>\n"
            + "".join([f"\t{k}\t{v}\n" for k, v in TARGET_SEQ_LIB.items()])
        )
        lm.logger.info("exit because of the defination for get_built_in_target_seq")
        # SystemExit is raised directly: the `exit` builtin is injected by the
        # `site` module and is not guaranteed to exist.
        raise SystemExit(0)

    lm.logger.debug(f"load target_seq = {target_seq}")
    if target_seq in TARGET_SEQ_LIB:
        lm.logger.debug(
            f"target_seq information is abtained from the <key>={target_seq}"
        )
        target_seq = TARGET_SEQ_LIB[target_seq]
        lm.logger.debug(f"target_seq is refered to {target_seq}")
    elif target_seq is None:
        pass
    elif isinstance(target_seq, str):
        lm.logger.debug(f"target_seq is refered to {target_seq}")
    else:
        lm.logger.error(
            "<target_seq> should be set as:"
            "DNA sequence / <key> in <get_built_in_target_seq> / None,"
            f"but your <target_seq> is {target_seq}"
        )
        raise SystemExit(1)

    # ------------------------------------------------------------------
    # 2. aligner and input table
    # ------------------------------------------------------------------
    aligner = instantiate_pairwise_aligner(*local_alignment_scoring_matrix)

    # NOTE(review): columns are accessed by name below, so a headerless table
    # (integer column labels) will presumably fail later; confirm intended use.
    read_kwargs = {} if input_table_header else {"header": None}
    df_bases = pd.read_csv(input_table, sep="\t", **read_kwargs)
    lm.logger.debug(
        "df_bases.head(10):\n"
        + tabulate(df_bases.head(10), headers="keys", tablefmt="psql")
    )

    # ------------------------------------------------------------------
    # 3. reference sequence and its reverse complement
    # ------------------------------------------------------------------
    if reference_seq:
        # FASTA file (plain or gzip) or literal sequence -> sequence string
        if os.path.isfile(reference_seq):
            opener = gzip.open if reference_seq.endswith('.gz') else open
            # the context manager closes the handle (it used to be leaked)
            with opener(reference_seq, 'rt') as handle:
                next(handle, None)  # skip the FASTA header line
                reference_seq = ''.join(line.rstrip() for line in handle)
        if '>' in reference_seq:
            # a remaining '>' means more than one record / not a plain sequence
            lm.logger.error(
                "reference_seq must hold exactly one sequence, but '>' was found in it"
            )
            raise BioatFileFormatError
        ref_seq = Seq(reference_seq)
        lm.logger.debug(f"load ref_seq from parameter:\n{ref_seq}")
    else:
        ref_seq = Seq("".join(df_bases.ref_base))
        lm.logger.debug(f"load ref_seq from df_bases.ref_base:\n{ref_seq}")
    lm.logger.debug(f"Seq object for ref_seq:\n{ref_seq}")
    ref_seq_rc = ref_seq.reverse_complement()
    lm.logger.debug(
        f"Seq object for ref_seq_rc (reverse_complement):\n{ref_seq_rc}"
    )

    # ------------------------------------------------------------------
    # 4. split target_seq into protospacer and PAM
    # ------------------------------------------------------------------
    if target_seq is None:
        target_seq = ref_seq
        PAM = None
    elif '^' in target_seq:
        if target_seq.find('^') == 0:
            # '^PAM^protospacer': the PAM sits at the 5' end
            PAM, target_seq = target_seq[1:].strip().split('^')
            target_seq = Seq(target_seq.upper())
            PAM = {'PAM': PAM, 'position': 0, 'weight': PAM_priority_weight}
        else:
            # 'protospacer^PAM^': the PAM sits at the 3' end
            target_seq, PAM = target_seq[:-1].strip().split('^')
            target_seq = Seq(target_seq.upper())
            PAM = {'PAM': PAM, 'position': len(target_seq), 'weight': PAM_priority_weight}
    else:
        lm.logger.debug(
            f"no PAM sequence is defined from target_seq = {target_seq}"
        )
        target_seq = Seq(target_seq)
        PAM = None
    lm.logger.info(f"parse target_seq: target_seq={target_seq}, PAM={PAM}")

    # ------------------------------------------------------------------
    # 5. forward / reverse alignment
    #    (used to decide whether target_seq (sgRNA) was designed on the
    #    forward or on the reverse strand)
    # ------------------------------------------------------------------
    # Both modes run the same alignment call; the PAM is only reported and is
    # not handed to run_target_seq_align.
    align_mode = "PAM" if PAM else "PAMless"
    lm.logger.info(f"use {align_mode} mode to align")
    fwd = run_target_seq_align(ref_seq, target_seq, aligner)
    lm.logger.debug(f"final_align_forward:\n{fwd}")
    rev = run_target_seq_align(ref_seq_rc, target_seq, aligner)
    lm.logger.debug(f"final_align_reverse:\n{rev}")

    def _aligned_window(result):
        """Cut the three alignment strings down to [ref_aln_start, ref_aln_end]."""
        window = slice(result['ref_aln_start'], result['ref_aln_end'] + 1)
        alignment = result['alignment']
        return (
            alignment['reference_seq'][window],
            alignment['aln_info'][window],
            alignment['target_seq'][window],
        )

    # forward alignment info
    reference_seq, align_info, target_seq = _aligned_window(fwd)
    aln_score = fwd['aln_score']
    lm.logger.info(
        f"Forward best alignment:\n"
        f"reference : {reference_seq}\n"
        f"align_info : {align_info}\n"
        f"target_seq : {target_seq}\n"
        f"align_score: {aln_score}"
    )
    # reverse alignment info
    lm.logger.info(rev)
    reference_seq_rev, align_info_rev, target_seq_rev = _aligned_window(rev)
    aln_score_rev = rev['aln_score']
    lm.logger.info(
        f"Reverse best alignment:\n"
        f"reference : {reference_seq_rev}\n"
        f"align_info : {align_info_rev}\n"
        f"target_seq : {target_seq_rev}\n"
        f"align_score: {aln_score_rev}"
    )

    # ------------------------------------------------------------------
    # 6. choose the alignment direction
    # ------------------------------------------------------------------
    if max(fwd['aln_score'], rev['aln_score']) < local_alignment_min_score:
        lm.logger.critical("Alignment Error!")
        raise SystemExit(1)
    if fwd['aln_score'] >= rev['aln_score']:
        aln_direction = "Forward Alignment"
        aln = fwd
    else:
        aln_direction = "Reverse Alignment"
        aln = rev
        reference_seq = reference_seq_rev
        target_seq = target_seq_rev
    is_reverse = "Reverse" in aln_direction

    # ------------------------------------------------------------------
    # 7. project the aligned target onto reference positions
    # ------------------------------------------------------------------
    ref_seq_length = len(ref_seq)
    # aligned target base for each reference position ("" = not aligned)
    target_seq_aln = [""] * ref_seq_length
    # target bases that face a gap in the reference (only logged below)
    target_seq_aln_insert = [""] * ref_seq_length
    DNA_rev_cmp_dict = {
        "A": "T",
        "T": "A",
        "C": "G",
        "G": "C",
        "N": "N",
        "-": "-"
    }
    ref_gap_count = 0
    aln_start = aln['ref_aln_start']
    aln_end = aln['ref_aln_end']

    lm.logger.debug(f"target_seq_aln (before update) = {target_seq_aln}")
    lm.logger.debug(
        f"target_seq_aln_insert (before update) = {target_seq_aln_insert}"
    )
    for idx, ref_base in enumerate(reference_seq):
        # reference  : GCCTCTGGAGAGGGAGGAGGG
        # align_info : |.|.|||..|..|||||.||-
        # target_seq : GGCACTGCGGCTGGAGGTGG-
        if ref_base != "-":
            # NOTE(review): the previous code accumulated the bases inserted
            # before this position but reset that string right before using
            # it, so only the aligned base was ever stored. That effective
            # behaviour is kept; confirm whether the prefix was intended.
            target_seq_aln[aln_start + idx - ref_gap_count] = target_seq[idx]
        else:
            ref_gap_count += 1
            target_seq_aln_insert[aln_start + idx] = target_seq[idx]
        lm.logger.debug(
            "update target_seq_aln: show ''.join(target_seq_aln)\n"
            f"target_seq_aln = {''.join(target_seq_aln)}\n"
            f"target_seq = {target_seq}"
        )
    lm.logger.debug("update target_seq_aln: Done")
    lm.logger.debug(f"target_seq_aln (after update) = {target_seq_aln}")
    lm.logger.debug(
        f"target_seq_aln_insert (after update) = {target_seq_aln_insert}"
    )

    if aln_direction == "Reverse Alignment":
        # illustrate the reverse alignment in forward orientation
        target_seq_aln = target_seq_aln[::-1]
        target_seq_aln_insert = target_seq_aln_insert[::-1]
        lm.logger.debug("use reverse alignment")
        lm.logger.debug(f"target_seq_aln = {target_seq_aln}")
        lm.logger.debug(f"target_seq_aln_insert = {target_seq_aln_insert}")

    # ------------------------------------------------------------------
    # 8. select the rows to plot
    # ------------------------------------------------------------------
    # NOTE(review): for a reverse hit aln_start / aln_end were obtained against
    # ref_seq_rc but are applied to the forward-ordered table rows here;
    # confirm that run_target_seq_align reports forward coordinates.
    possible_target_region_start = max(aln_start - region_extend_length, 0)
    possible_target_region_end = min(aln_end + region_extend_length, ref_seq_length)
    plot_region = (possible_target_region_start, possible_target_region_end)
    df_bases_select = df_bases.iloc[plot_region[0]: plot_region[1], :]
    lm.logger.debug(
        "df_bases_select:\n"
        + tabulate(df_bases_select, headers="keys", tablefmt="psql")
    )

    # ------------------------------------------------------------------
    # 9. plot settings
    # ------------------------------------------------------------------
    matplotlib.use('Agg')
    np.set_printoptions(suppress=True)

    # panel geometry
    panel_box_width = 0.4
    panel_box_heigth = 0.4
    panel_space = 0.05
    panel_box_space = box_space

    # colors of the sequence rows
    base_colors = {"A": "#04E3E3", "T": "#F9B874", "C": "#B9E76B", "G": "#F53798", "N": "#AAAAAA", "-": "#AAAAAA"}

    # color breaks of the heatmap rows: steps of 1/20 from min_ratio up to max_ratio
    color_break_num = 20
    break_step = 1.0 / color_break_num
    color_break = np.round(np.arange(min_ratio, max_ratio, break_step), 5)
    # gradient from min_color to max_color, with white prepended for the lowest bin
    color_list = make_color_list(min_color, max_color, len(color_break) - 1, "HEX")
    color_list = ["#FFFFFF"] + color_list
    lm.logger.debug(f"color_list = {color_list}")

    total_box_count = plot_region[1] - plot_region[0]

    # denominators of the ratio rows; zeros become 1 to avoid division by zero
    base_names = ["A", "G", "C", "T"]
    base_sum_count = df_bases_select.loc[:, ["A", "G", "C", "T"]].sum(axis=1).astype(int)
    lm.logger.debug(f"base_sum_count =\n{base_sum_count.values}")
    total_sum_count = df_bases_select[["A", "G", "C", "T", "del_count", "insert_count"]].sum(axis=1).astype(int)
    lm.logger.debug(f"total_sum_count =\n{total_sum_count.values}")
    base_sum_count[base_sum_count == 0] = 1
    total_sum_count[total_sum_count == 0] = 1

    lm.logger.debug(f"aln =\n{aln}")

    # ------------------------------------------------------------------
    # 10. panels: index, target, reference, counts, ratios
    # ------------------------------------------------------------------
    # There is one spacing value fewer than there are panels: nothing follows
    # the last panel.
    if show_indel:
        panel_height_coef = [.5, .9, .9] + [.5] * 6 + [.5] * 6
        panel_space_coef = [1.] * 3 + [.3] * 3 + [1., .3, 1.] + [.3] * 3 + [1., .3]
    else:
        panel_height_coef = [.5, .9, .9] + [.5] * 4 + [.5] * 4
        panel_space_coef = [1.] * 3 + [0.3] * 3 + [1.] + [0.3] * 3

    plot_data_list = [
        ["Ref_index", df_bases_select.chr_index],
        ["Target_seq", target_seq_aln[plot_region[0]: plot_region[1]]],
        ["Ref_seq", df_bases_select.ref_base],
    ]
    plot_data_list += [[base, np.array(df_bases_select[base])] for base in base_names]
    if show_indel:
        plot_data_list += [
            ["Del", np.array(df_bases_select["del_count"])],
            ["Ins", np.array(df_bases_select["insert_count"])],
        ]
    plot_data_list += [
        [f"{base} %", list(df_bases_select[base] / base_sum_count)] for base in base_names
    ]
    if show_indel:
        plot_data_list += [
            ["Del %", list(df_bases_select["del_count"] / total_sum_count)],
            ["Ins %", list(df_bases_select["insert_count"] / total_sum_count)],
        ]

    box_height_list = np.array(panel_height_coef) * panel_box_heigth
    panel_space_list = np.array(panel_space_coef) * panel_space
    lm.logger.debug(f"box_height_list = {box_height_list}")
    lm.logger.debug(f"panel_space_list = {panel_space_list}")

    # total figure size; two extra box widths leave room for the row labels
    figure_width = total_box_count * panel_box_width + \
        (total_box_count - 1) * panel_box_space + panel_box_width * 2
    figure_height = sum(box_height_list) + sum(panel_space_list)
    lm.logger.debug(f"figure_width = {figure_width}")
    lm.logger.debug(f"figure_height = {figure_height}")

    # x of every column; column 0 holds the row label, data start at column 1
    box_x_vec = np.arange(0, figure_width + panel_box_width, panel_box_width + panel_box_space)
    box_x_vec = box_x_vec[:(len(ref_seq) + 1)]

    # box border style
    if box_border:
        box_edgecolor, box_linestyle, box_linewidth = "#AAAAAA", "-", 2
    else:
        box_edgecolor, box_linestyle, box_linewidth = "#FFFFFF", "None", 0
    lm.logger.debug(f"box_border_plot_state = {box_border}")
    lm.logger.debug(f"box_edgecolor = {box_edgecolor}")
    lm.logger.debug(f"box_linestyle = {box_linestyle}")
    lm.logger.debug(f"box_linewidth = {box_linewidth}")

    # ------------------------------------------------------------------
    # 11. draw
    # ------------------------------------------------------------------
    current_y = 0
    lm.logger.debug("start to plot")
    fig = plt.figure(figsize=(figure_width * 1.1, figure_height * 1.1))
    try:
        lm.logger.debug(
            f"set new figure, figsize = ({figure_width * 1.1}, {figure_height * 1.1})"
        )
        # keep matplotlib's own debug output out of the log
        plt.set_loglevel("info")
        ax = fig.add_subplot(111, aspect="equal")
        plt.xlim([0, figure_width])
        plt.ylim([-figure_height, 0])
        plt.axis("off")
        lm.logger.debug(f"plt.xlim([0, {figure_width}])")
        lm.logger.debug(f"plt.ylim([-{figure_height}, 0])")
        lm.logger.debug('plt.axis("off")')

        text_list = []  # (x, y, text, fontsize)
        patches = []

        def _add_box(box_x, top_y, height, facecolor, fill):
            """Queue one rectangle whose top-left corner is (box_x, top_y)."""
            patches.append(
                Rectangle(
                    xy=(box_x, top_y - height),
                    width=panel_box_width,
                    height=height,
                    fill=fill,
                    alpha=1,
                    linestyle=box_linestyle,
                    linewidth=box_linewidth,
                    edgecolor=box_edgecolor,
                    facecolor=facecolor
                )
            )

        for panel_index, (panel_name, panel_values) in enumerate(plot_data_list):
            box_height = box_height_list[panel_index]
            center_y = current_y - box_height * 0.5
            # row label in the first column
            text_list.append((box_x_vec[0], center_y, panel_name, 10))

            if panel_name == "Ref_index":
                # no boxes, only the index text
                for index, box_value in enumerate(panel_values):
                    box_x = box_x_vec[index + 1]
                    text_list.append((box_x + panel_box_width * 0.5, center_y, str(box_value), 7))

            elif panel_name in ["Target_seq", "Ref_seq"]:
                for index, box_value in enumerate(panel_values):
                    box_color = None
                    if box_value != "":
                        if is_reverse:
                            if panel_name == "Ref_seq":
                                # show the complementary strand; symbols that are
                                # not in the table are kept as they are instead
                                # of raising TypeError inside join
                                box_value = "".join(DNA_rev_cmp_dict.get(x, x) for x in box_value)
                            box_color = base_colors.get(box_value[0])
                        else:
                            box_color = base_colors.get(box_value[-1])
                    # empty cells and unknown symbols get an unfilled white box
                    box_fill = bool(box_color)
                    if not box_fill:
                        box_color = "#FFFFFF"
                    box_x = box_x_vec[index + 1]
                    _add_box(box_x, current_y, box_height, box_color, box_fill)
                    text_list.append((box_x + 0.5 * panel_box_width, center_y, str(box_value), 16))

            else:
                if panel_name in ["A", "G", "C", "T", "Del", "Ins"]:
                    # count rows: colored by ratio, labelled with the raw count
                    if panel_name in ["Del", "Ins"]:
                        box_ratio = panel_values / total_sum_count
                    else:
                        box_ratio = panel_values / base_sum_count
                    box_color_list = map_color(box_ratio, color_break, color_list)
                    lm.logger.debug(f"get box_color_list for {panel_name}")
                    box_labels = [str(box_value) for box_value in panel_values]
                else:
                    # ratio rows: colored by ratio, labelled as percentage
                    box_color_list = map_color(panel_values, color_break, color_list)
                    box_labels = [round(box_value * 100, 4) for box_value in panel_values]
                for index, box_label in enumerate(box_labels):
                    box_x = box_x_vec[index + 1]
                    _add_box(box_x, current_y, box_height, box_color_list[index], True)
                    text_list.append((box_x + 0.5 * panel_box_width, center_y, box_label, 6))

            # move to the next row; the last panel has no spacing entry
            if panel_index < len(panel_space_list):
                current_y = current_y - (box_height + panel_space_list[panel_index])

        lm.logger.debug("plot rectangles")
        ax.add_collection(PatchCollection(patches, match_original=True))
        lm.logger.debug("plot text on each rectangle")
        for text_x, text_y, text_info, text_fontsize in text_list:
            plt.text(
                x=text_x, y=text_y, s=text_info, horizontalalignment='center', verticalalignment='center',
                fontsize=text_fontsize, fontname="Arial"
            )

        fig.savefig(fname=output_fig, bbox_inches='tight', dpi=output_fig_dpi, format=output_fig_fmt)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly
        plt.close(fig)
[docs]
def region_heatmap_compare(
self,
input_tables: str,
labels: str | None = None,
target_seq: str | None = None, # sgRNA
reference_seq: str | None = None, # target locus
output_fig_heatmap: str | None = None,
output_fig_count_ratio: str | None = None,
output_table_heatmap: str | None = None,
output_table_count_ratio: str | None = None,
output_fig_fmt: str = "pdf",
input_table_header: bool = True,
to_base: tuple = ("A", "G", "C", "T", "Ins", "Del"),
heatmap_mut_direction: tuple = ("CT", "GA"),
count_ratio="all",
region_extend_length: int = 5,
output_fig_dpi: int = 100,
show_indel: bool = True,
show_index: bool = True,
block_ref: bool = True,
box_border: bool = False,
box_space: float = 0.03,
min_color: tuple = (250, 239, 230),
max_color: tuple = (154, 104, 57),
min_ratio: float = 0.001,
max_ratio: float = 0.99,
local_alignment_scoring_matrix: tuple = (5, -4, -24, -8),
local_alignment_min_score: int = 15,
PAM_priority_weight: float = 1.0,
get_built_in_target_seq: bool = False,
log_level: str = "INFO",
):
"""Plot region mutation information for multiple conditions.
This function generates a comparison of mutation information across multiple conditions
using tables generated by `bioat bam mpileup_to_table`.
Args:
input_tables (str):
Paths to input tables generated by `bioat bam mpileup_to_table`, separated by commas.
Each table should contain mutation information for a short genomic region (≤1k nt).
labels (str):
Labels for the panels, separated by commas.
target_seq (str, optional):
Sequence to align against the reference sequence in `mpileup.table`. Examples:
- 'GAGTCCGAGCAGAAGAAGAA^GGG^' for SpCas9-BE (PAM: ^GGG^).
- '^TTTA^GCCCCAATAATCCCCACATGTCA' for cpf1-BE (PAM: ^TTTA^).
- 'TGCTAGTAACCACGTTCTCCTGATCAAATATCACTCTCCTACTTACAGGA' for no PAM.
Defaults to None.
reference_seq (str, optional):
Custom reference sequence to overwrite the one in `mpileup.table`.
Can be a FASTA file or a DNA sequence. Defaults to None.
output_fig_heatmap (str, optional):
Path to the heatmap output figure. Defaults to None.
output_fig_count_ratio (str, optional):
Path to the count/ratio output figure. Defaults to None.
output_table_heatmap (str, optional):
Path to the heatmap output table. Defaults to None.
output_table_count_ratio (str, optional):
Path to the count/ratio output table. Defaults to None.
output_fig_fmt (str, optional):
Format of the output figures, either "pdf" or "png". Defaults to "pdf".
input_table_header (bool, optional):
Whether the input tables have headers. Defaults to True.
to_base (str, optional):
Reference bases to convert to, separated by commas. Defaults to "A,G,C,T,Ins,Del".
heatmap_mut_direction (str, optional):
Mutation directions to plot, specified as [from Base][to Base], separated by commas.
Defaults to "CT,GA".
count_ratio (str, optional):
Type of plot to generate: "count", "ratio", or "all". Defaults to "all".
region_extend_length (int, optional):
Number of base pairs to extend on both sides of the region. Defaults to 0.
output_fig_dpi (int, optional):
DPI for the output figures. Defaults to 300.
show_indel (bool, optional):
Whether to show indel information in the output figures. Defaults to True.
show_index (bool, optional):
Whether to display index information in the output figures. Defaults to True.
block_ref (bool, optional):
Whether to hide colors for reference sites in the output figures. Defaults to True.
box_border (bool, optional):
Whether to display box borders in the output figures. Defaults to True.
box_space (int, optional):
Space size between two boxes in the heatmap. Defaults to 1.
min_color (tuple, optional):
Minimum color for the heatmap in RGB format. Defaults to (255, 255, 255).
max_color (tuple, optional):
Maximum color for the heatmap in RGB format. Defaults to (0, 0, 0).
min_ratio (float, optional):
Mutation ratios below this value will appear white. Defaults to 0.0.
max_ratio (float, optional):
Mutation ratios above this value will be capped. Defaults to 1.0.
local_alignment_scoring_matrix (tuple, optional):
Alignment scoring parameters as a tuple:
(<align_match_score>, <align_mismatch_score>, <align_gap_open_score>, <align_gap_extension_score>).
Defaults to None.
local_alignment_min_score (int, optional):
Minimum alignment score to consider as a valid alignment. Defaults to 0.
PAM_priority_weight (float, optional):
Weight multiplier for PAM alignment scores. Defaults to 1.0.
get_built_in_target_seq (bool, optional):
Set to True to return built-in target sequence information. Defaults to False.
log_level (str, optional):
Logging level. One of 'CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET'. Defaults to "INFO".
Returns:
None
Examples:
>>> # Generate mutation tables
>>> samtools mpileup test_sorted.bam --reference HK4-AOut-1.ref.upper.fa | gzip > test_sorted.mpileup.gz
>>> bioat bam mpileup_to_table test_sorted.mpileup.gz > test_sorted.mpileup.info1.tsv
>>> bioat bam mpileup_to_table test_sorted.mpileup.gz > test_sorted.mpileup.info2.tsv
>>> bioat bam mpileup_to_table test_sorted.mpileup.gz > test_sorted.mpileup.info3.tsv
>>> # Generate region heatmap comparison
>>> bioat target_seq region_heatmap_compare --input_tables test_sorted.mpileup.info1.tsv,test_sorted.mpileup.info2.tsv,test_sorted.mpileup.info3.tsv --labels condition1,condition2,condition3 --target_seq HEK4
"""
lm.set_names(func_name="region_heatmap_compare")
lm.set_level(log_level)
base_color_dict = {"A": "#04E3E3", "T": "#F9B874", "C": "#B9E76B", "G": "#F53798", "N": "#DDEAF6"}
def plot_agct(x):
    """Pick the hex fill colour used to draw one sequence cell.

    Args:
        x: Cell content. Within this method it is applied to the values of
            the Target_seq / Ref_seq rows.

    Returns:
        str: '#FFFFFF' for an empty cell, the matching ``base_color_dict``
        entry for 'A', 'G', 'C', 'T' or 'N', and '#AAAAAA' for anything else.
    """
    # Positions left/right of the target_seq (sgRNA) that were not aligned
    # hold an empty string and are drawn in white.
    if x == '':
        return '#FFFFFF'
    # Anything that is not a known base symbol falls back to grey.
    if x not in ('A', 'G', 'C', 'T', 'N'):
        return '#AAAAAA'
    return base_color_dict[x]
# check whether return built-in information
if get_built_in_target_seq:
lm.logger.info(
"You can use <key> in built-in <target_seq> to represent your target_seq:\n"
+ "\t<key>\t<target_seq>\n"
+ "".join([f"\t{k}\t{v}\n" for k, v in TARGET_SEQ_LIB.items()])
)
lm.logger.info("exit because of the defination for get_built_in_target_seq")
exit(0)
else:
# set target_seq(sgRNA) info
lm.logger.debug(f"load target_seq = {target_seq}")
if target_seq in TARGET_SEQ_LIB:
lm.logger.debug(
f"target_seq information is abtained from the <key>={target_seq}"
)
target_seq = TARGET_SEQ_LIB[target_seq]
lm.logger.debug(f"target_seq is refered to {target_seq}")
elif target_seq is None:
pass
elif isinstance(target_seq, str):
lm.logger.debug(f"target_seq is refered to {target_seq}")
else:
lm.logger.error(
"<target_seq> should be set as:"
"DNA sequence / <key> in <get_built_in_target_seq> / None,"
f"but your <target_seq> is {target_seq}"
)
exit(1)
# check if nothing to output
outputs = (output_fig_heatmap, output_fig_count_ratio, output_table_heatmap, output_table_count_ratio)
if not [i for i in outputs if i is not None]:
lm.logger.error(
"outputs should be set as free combination of (output_fig_heatmap, output_fig_count_ratio, "
"output_table_heatmap, output_table_count_ratio), but none of outputs was defined"
)
exit(1)
# check parameter: to_base
# check to_base: tuple for A G C T Ins Del or combine like A G T Ins
default_bases = ('A', 'G', 'C', 'T', 'Ins', 'Del')
lm.logger.debug("check to_base parameter")
for base in to_base:
if base not in default_bases:
lm.logger.error("check to_base parameter")
lm.logger.error(f"{base} not in {default_bases}")
raise BioatInvalidOptionError
lm.logger.debug(f"to_base was defined as {to_base}")
# check parameter: count_ratio
default_count_ratio_choices = ('count', 'ratio', 'all')
lm.logger.debug("check count_ratio parameter")
if count_ratio not in default_count_ratio_choices:
lm.logger.error("check count_ratio parameter")
lm.logger.error(f"{count_ratio} not in {default_count_ratio_choices}")
raise BioatInvalidOptionError
lm.logger.debug(f"count_ratio was defined as {count_ratio}")
# parse parameter: input_tables
if isinstance(input_tables, tuple):
input_tables = [i.strip() for i in input_tables]
elif isinstance(input_tables, str):
input_tables = input_tables.replace(' ', '').replace('\t', '').strip().split(',')
else:
raise BioatInvalidInputError
lm.logger.debug(f"parse tables... {input_tables}")
# parse parameter: labels, if labels is None, labels value will be input_tables names
if isinstance(labels, tuple):
labels = [i.strip() for i in labels]
elif isinstance(labels, str):
labels = labels.replace(' ', '').replace('\t', '').strip().split(',') if labels else input_tables
else:
raise BioatInvalidInputError
lm.logger.debug(f"parse labels... {labels}")
# set alignment
aligner = instantiate_pairwise_aligner(*local_alignment_scoring_matrix)
# load & merge bmat files & set treatment names
ls = []
for input_table, label in zip(input_tables, labels):
df = pd.read_csv(input_table, sep="\t") if input_table_header \
else pd.read_csv(input_table, sep="\t", header=None)
df['label'] = label
ls.append(df)
df_bases = pd.concat(ls)
lm.logger.debug(
"df_bases.head(10):\n"
+ tabulate(df_bases.head(8), headers="keys", tablefmt="psql")
)
# parse parameter: reference_seq, if reference_seq is None, reference_seq value will be derived from df_bases
if reference_seq:
# fa file or seq str to seq str
if os.path.isfile(reference_seq):
f = open(reference_seq, 'rt') if not reference_seq.endswith('.gz') else gzip.open(reference_seq, 'rt')
reference_seq = ''.join([i.rstrip() for i in f.readlines()[1:]])
if reference_seq.count('>') > 0:
raise BioatFileFormatError
ref_seq = Seq(reference_seq)
lm.logger.debug(f"load ref_seq from parameter:\n{ref_seq}")
else:
ref_seq = Seq("".join(df_bases.query('label==@labels[0]').ref_base))
lm.logger.debug(f"load ref_seq from df_bases.ref_base:\n{ref_seq}")
lm.logger.debug(f"Seq object for ref_seq:\n{ref_seq}")
ref_seq_rc = ref_seq.reverse_complement()
lm.logger.debug(
f"Seq object for ref_seq_rc (reverse_complement):\n{ref_seq_rc}"
)
# parse parameter: target_seq
if target_seq is None:
target_seq = ref_seq
PAM = None
elif '^' in target_seq:
if target_seq.find('^') == 0:
PAM, target_seq = target_seq[1:].strip().split('^')
target_seq = Seq(target_seq.upper())
PAM = {'PAM': PAM, 'position': 0, 'weight': PAM_priority_weight}
else:
target_seq, PAM = target_seq[:-1].strip().split('^')
target_seq = Seq(target_seq.upper())
PAM = {'PAM': PAM, 'position': len(target_seq), 'weight': PAM_priority_weight}
else:
lm.logger.debug(
f"no PAM sequence is defined from target_seq = {target_seq}"
)
target_seq = Seq(target_seq)
PAM = None
lm.logger.info(f"parse target_seq: target_seq={target_seq}, PAM={PAM}")
# fwd alignment & rev alignment
# to determine whether the target_seq (sgRNA) was designed on the fwd or the rev strand
if not PAM:
lm.logger.info("use PAMless mode to align")
fwd = run_target_seq_align(ref_seq, target_seq, aligner)
lm.logger.debug(f"final_align_forward:\n{fwd}")
rev = run_target_seq_align(ref_seq_rc, target_seq, aligner)
lm.logger.debug(f"final_align_reverse:\n{rev}")
else:
lm.logger.info("use PAM mode to align")
fwd = run_target_seq_align(ref_seq, target_seq, aligner, PAM=PAM)
lm.logger.debug(f"final_align_forward:\n{fwd}")
rev = run_target_seq_align(ref_seq_rc, target_seq, aligner, PAM=PAM)
lm.logger.debug(f"final_align_reverse:\n{rev}")
# get fwd alignment info
reference_seq = fwd['alignment']['reference_seq'][fwd['ref_aln_start']: fwd['ref_aln_end'] + 1]
align_info = fwd['alignment']['aln_info'][fwd['ref_aln_start']: fwd['ref_aln_end'] + 1]
target_seq = fwd['alignment']['target_seq'][fwd['ref_aln_start']: fwd['ref_aln_end'] + 1]
aln_score = fwd['aln_score']
lm.logger.info(
f"Forward best alignment:\n"
f"reference : {reference_seq}\n"
f"align_info : {align_info}\n"
f"target_seq : {target_seq}\n"
f"align_score: {aln_score}"
)
# get rev alignment info
reference_seq_rev = rev['alignment']['reference_seq'][rev['ref_aln_start']: rev['ref_aln_end'] + 1]
align_info_rev = rev['alignment']['aln_info'][rev['ref_aln_start']: rev['ref_aln_end'] + 1]
target_seq_rev = rev['alignment']['target_seq'][rev['ref_aln_start']: rev['ref_aln_end'] + 1]
aln_score_rev = rev['aln_score']
lm.logger.info(
f"Reverse best alignment:\n"
f"reference : {reference_seq_rev}\n"
f"align_info : {align_info_rev}\n"
f"target_seq : {target_seq_rev}\n"
f"align_score: {aln_score_rev}"
)
# define align direction and final align res
# make alignment info
if any([x >= local_alignment_min_score for x in (fwd['aln_score'], rev['aln_score'])]):
if fwd['aln_score'] >= rev['aln_score']:
aln_direction = "Forward Alignment"
aln = fwd
else:
aln_direction = "Reverse Alignment"
aln = rev
reference_seq = reference_seq_rev
align_info = align_info_rev
target_seq = target_seq_rev
aln_score = aln_score_rev
else:
lm.logger.critical("Alignment Error!")
exit(1)
# mark aligned all bases in target_seq
ref_seq_length = len(ref_seq)
target_seq_aln = [""] * ref_seq_length
# mark inserted bases in target_seq
target_seq_aln_insert = [""] * ref_seq_length
DNA_rev_cmp_dict = {
"A": "T",
"T": "A",
"C": "G",
"G": "C",
"N": "N",
"-": "-"
}
reference_gap_count = reference_seq.count("-")
target_seq_gap_count = target_seq.count("-")
ref_gap_count = 0
ref_del_str = ""
aln_start = aln['ref_aln_start']
aln_end = aln['ref_aln_end']
# update target_seq_aln
lm.logger.debug(f"target_seq_aln (before update) = {target_seq_aln}")
lm.logger.debug(
f"target_seq_aln_insert (before update) = {target_seq_aln_insert}"
)
for idx, ref_base in enumerate(reference_seq):
# reference : GCCTCTGGAGAGGGAGGAGGG
# align_info : |.|.|||..|..|||||.||-
# target_seq : GGCACTGCGGCTGGAGGTGG-
if ref_base != "-":
ref_gap_count += 0
ref_del_str = ""
target_seq_aln[aln_start + idx - ref_gap_count] = ref_del_str + target_seq[idx]
else:
ref_gap_count += 1
ref_del_str += target_seq[idx] # for continuous gap
target_seq_aln_insert[aln_start + idx] = target_seq[idx]
lm.logger.debug(
"update target_seq_aln: show ''.join(target_seq_aln)\n"
f"target_seq_aln = {''.join(target_seq_aln)}\n"
f"target_seq = {target_seq}"
)
lm.logger.debug("update target_seq_aln: Done")
lm.logger.debug(f"target_seq_aln (after update) = {target_seq_aln}")
lm.logger.debug(
f"target_seq_aln_insert (after update) = {target_seq_aln_insert}"
)
# if reverse alignment
if aln_direction == "Reverse Alignment":
# illustrate reverse alignment as forward alignment
target_seq_aln = target_seq_aln[::-1]
target_seq_aln_insert = target_seq_aln_insert[::-1]
lm.logger.debug("use reverse alignment")
lm.logger.debug(f"target_seq_aln = {target_seq_aln}")
lm.logger.debug(f"target_seq_aln_insert = {target_seq_aln_insert}")
possible_target_region_start = max(aln_start - region_extend_length, 0)
possible_target_region_end = min(aln_end + region_extend_length, ref_seq_length) # 266: 0~255 266 - 1
plot_region = (possible_target_region_start, possible_target_region_end)
df_bases_select = df_bases.iloc[plot_region[0]: plot_region[1], :]
lm.logger.debug(
"df_bases_select:\n"
+ tabulate(df_bases_select, headers="keys", tablefmt="psql")
)
# ---------------------------------------------------------------->>>>>
# load .bmat file
# ---------------------------------------------------------------->>>>>
label_panel = labels
ls_bmat = input_tables
ls_bmat_table = [pd.read_csv(path_bmat, sep='\t') for path_bmat in ls_bmat]
for index, bmat in enumerate(ls_bmat_table):
bmat['label'] = label_panel[index]
bmat.drop('chr_name', axis=1, inplace=True)
df_tmp = ls_bmat_table[0]
bmat_table = ls_bmat_table[0]
for df_bmat in ls_bmat_table[1:]:
suffixes = (f'_{df_tmp.iloc[0, -1]}', f'_{df_bmat.iloc[0, -1]}')
df_tmp = pd.merge(df_tmp, df_bmat, on='chr_index', how='outer', suffixes=suffixes)
ls_columns = ['chr_index']
for label in label_panel:
str_tmp = ' '.join(
['ref_base_', 'A_', 'G_', 'C_', 'T_', 'del_count_', 'insert_count_', 'ambiguous_count_',
'deletion_',
'insertion_', 'ambiguous_', 'mut_num_', 'label_ '])
ls_tmp = str_tmp.replace('_ ', '_{label} ').format(label=label).strip().split(' ')
ls_columns.extend(ls_tmp)
df_tmp.columns = ls_columns
# print(f'df_tmp = \n{df_tmp}')
# define ref_seq as the referencing sequence
if len(ls_bmat_table) == 1:
ref_seq = "".join(bmat_table['ref_base_' + label_panel[0]].tolist())
elif len(ls_bmat_table) > 1:
ref_seq = "".join(bmat_table.ref_base)
else:
raise ValueError('ref info gets wrong!')
df_bmat_all = df_tmp.copy()
# print(f'df_bmat_all = \n{df_bmat_all}')
# print(f'ref_seq = \n{ref_seq}')
# make alignment info
sgRNA_align = [""] * len(ref_seq)
sgRNA_align_insert = [""] * len(ref_seq)
possible_target_region_start = max(aln_start - region_extend_length, 0)
possible_target_region_end = min(aln_end + region_extend_length, ref_seq_length) # 266: 0~255 266 - 1
plot_region = (possible_target_region_start, possible_target_region_end)
df_bases_select = df_bmat_all.iloc[plot_region[0]: plot_region[1], :]
lm.logger.debug(
"df_bases_select:\n"
+ tabulate(df_bases_select, headers="keys", tablefmt="psql")
)
# make plot
# set color
# show indel
matplotlib.use('Agg')
np.set_printoptions(suppress=True)
indel_plot_state = show_indel
index_plot_state = show_index
box_border_plot_state = box_border
# set panel size
panel_box_width = 0.4
panel_box_heigth = 0.4
panel_space = 0.05
panel_box_space = box_space
# color part
base_colors = {"A": "#04E3E3", "T": "#F9B874", "C": "#B9E76B", "G": "#F53798", "N": "#AAAAAA", "-": "#AAAAAA"}
# make color breaks
color_break_num = 20
break_step = 1.0 / color_break_num
# mutation ratio of box lower than min_ratio will show as white
min_color_value = min_ratio
# mutation ratio of box higher than max_ratio will be capped
max_color_value = max_ratio
color_break = np.round(np.arange(min_color_value, max_color_value, break_step), 5)
# min_color: min color to plot heatmap with RGB format
# max_color: max color to plot heatmap with RGB format
if isinstance(min_color, tuple):
low_color = min_color
else:
raise BioatInvalidOptionError
if isinstance(max_color, tuple):
high_color = max_color
else:
raise BioatInvalidOptionError
try:
color_list = make_color_list(low_color, high_color, len(color_break) - 1, "Hex")
color_list = ["#FFFFFF"] + color_list
except:
lm.logger.debug(low_color, high_color)
lm.logger.debug(color_break)
# get plot info
total_box_count = plot_region[1] - plot_region[0]
# calculate base info and fix zero
ls_base_sum_count = []
ls_total_sum_count = []
for label in label_panel:
base_sum_count = df_bases_select[
["A_%s" % label, "G_%s" % label, "C_%s" % label, "T_%s" % label]].apply(
lambda x: x.sum(), axis=1)
# print base_sum_count
total_sum_count = df_bases_select[["A_%s" % label, "G_%s" % label, "C_%s" % label, "T_%s" % label,
"del_count_%s" % label, "insert_count_%s" % label]].apply(
lambda x: x.sum(), axis=1)
base_sum_count[base_sum_count == 0] = 1
base_sum_count = base_sum_count + 1
total_sum_count = total_sum_count + 1
ls_base_sum_count.append(base_sum_count)
ls_total_sum_count.append(total_sum_count)
# make plot size
plot_data_list = None
panel_space_coef = None
panel_height_coef = None
if count_ratio == 'all':
panel_height_coef = [0.5, 0.9, 0.9] + [0.5] * len(ls_bmat_table) * len(default_bases) * 2
else:
panel_height_coef = [0.5, 0.9, 0.9] + [0.5] * len(ls_bmat_table) * len(default_bases)
# box heights are controlled by the number of bmat tables
# panel_space_coef = [1, 1, 1] + [0.3] * 3 + [1, 0.3, 1] + [0.3] * 3 + [1, 0.3]
if count_ratio == 'all':
panel_space_coef = [1, 1, 1] + ([0.3] * (len(ls_bmat_table) - 1) + [4]) * len(default_bases) * 2
else:
panel_space_coef = [1, 1, 1] + ([0.3] * (len(ls_bmat_table) - 1) + [4]) * len(default_bases)
# fix the heatmap alignment error
# plot_heatmap_index = bmat_table_select.chr_index
plot_data_list = [
["Ref_index", df_bases_select.chr_index],
["Target_seq", np.array(target_seq_aln[plot_region[0]: plot_region[1]])],
["Ref_seq", df_bases_select['ref_base_%s' % label_panel[0]]]
]
if count_ratio == 'count' or count_ratio == 'all':
for to_base in default_bases:
if to_base == 'A':
for label in label_panel:
plot_data_list.append(
["{label}: to A".format(label=label),
np.array(df_bases_select["A_{label}".format(label=label)])]
)
if to_base == 'G':
for label in label_panel:
plot_data_list.append(
["{label}: to G".format(label=label),
np.array(df_bases_select["G_{label}".format(label=label)])]
)
if to_base == 'C':
for label in label_panel:
plot_data_list.append(
["{label}: to C".format(label=label),
np.array(df_bases_select["C_{label}".format(label=label)])]
)
if to_base == 'T':
for label in label_panel:
plot_data_list.append(
["{label}: to T".format(label=label),
np.array(df_bases_select["T_{label}".format(label=label)])]
)
if to_base == 'Del':
for label in label_panel:
plot_data_list.append(
["{label}: to Del".format(label=label),
np.array(df_bases_select["del_count_{label}".format(label=label)])]
)
if to_base == 'Ins':
for label in label_panel:
plot_data_list.append(
["{label}: to Ins".format(label=label),
np.array(df_bases_select["insert_count_{label}".format(label=label)])]
)
if count_ratio == 'ratio' or count_ratio == 'all':
for to_base in default_bases:
if to_base == 'A':
for index, label in enumerate(label_panel):
plot_data_list.append(
["{label}: to A(%)".format(label=label),
np.array(
df_bases_select["A_{label}".format(label=label)] / ls_base_sum_count[index])]
)
if to_base == 'G':
for index, label in enumerate(label_panel):
plot_data_list.append(
["{label}: to G(%)".format(label=label),
np.array(
df_bases_select["G_{label}".format(label=label)] / ls_base_sum_count[index])]
)
if to_base == 'C':
for index, label in enumerate(label_panel):
plot_data_list.append(
["{label}: to C(%)".format(label=label),
np.array(
df_bases_select["C_{label}".format(label=label)] / ls_base_sum_count[index])]
)
if to_base == 'T':
for index, label in enumerate(label_panel):
plot_data_list.append(
["{label}: to T(%)".format(label=label),
np.array(
df_bases_select["T_{label}".format(label=label)] / ls_base_sum_count[index])]
)
if to_base == 'Del':
for index, label in enumerate(label_panel):
plot_data_list.append(
["{label}: to Del(%)".format(label=label), np.array(
df_bases_select["del_count_{label}".format(label=label)] / ls_total_sum_count[
index])]
)
if to_base == 'Ins':
for index, label in enumerate(label_panel):
plot_data_list.append(
["{label}: to Ins(%)".format(label=label), np.array(
df_bases_select["insert_count_{label}".format(label=label)] / ls_total_sum_count[
index])]
)
# get box and space info
box_height_list = np.array(panel_height_coef) * panel_box_heigth
panel_space_list = np.array(panel_space_coef) * panel_space
# for i, j in enumerate(panel_space_coef):
# print i,j
# print panel_space
# calculate figure total width and height
figure_width = total_box_count * panel_box_width + (
total_box_count - 1) * panel_box_space + panel_box_width * 2
figure_height = sum(box_height_list) + sum(panel_space_list)
lm.logger.debug(f"figure_width = {figure_width}")
lm.logger.debug(f"figure_height = {figure_height}")
# make all box_x
box_x_vec = np.arange(0, figure_width + panel_box_width, panel_box_width + panel_box_space)
box_x_vec = box_x_vec[:(len(ref_seq) + 1)]
# make box border
if box_border_plot_state:
box_edgecolor = "#AAAAAA"
box_linestyle = "-"
box_linewidth = 2
else:
box_edgecolor = "#FFFFFF"
box_linestyle = "None"
box_linewidth = 0
# make box_y initialize
current_y = 0
df_matrix = pd.DataFrame([[''] * len(plot_data_list[0][1])] * len(panel_height_coef))
ls_row_name = []
for panel_index in range(len(panel_height_coef)):
ls_row_name.append(plot_data_list[panel_index][0])
df_matrix.iloc[panel_index, :] = df_matrix.iloc[panel_index, :].astype(np.str_)
for index, box_value in enumerate(plot_data_list[panel_index][1]):
row = panel_index
col = index
df_matrix.iloc[row, col] = box_value
# print(f'box_value = {box_value}, row = {row}, col = {col}')
# print(f'df_matrix.iloc[row, col] = {df_matrix.iloc[row, col]}')
# if box_value == '':
# raise ValueError
df_matrix.index = ls_row_name
# plot heatmap
if output_fig_heatmap:
# 预设参数
num_extend = region_extend_length
# rename columns of df_matrix
df_matrix.columns = range(1, df_matrix.shape[1] + 1)
dt_base = {'A': [], 'G': [], 'C': [], 'T': []}
for info in heatmap_mut_direction:
dt_base[info[0]] += info[1]
lm.logger.debug("[mut_direction] for heatmap: %s" % dt_base)
lm.logger.debug("[label_panel] for heatmap: %s" % label_panel)
lm.logger.debug(
"[region_extend_length] for heatmap: %s" % region_extend_length
)
lm.logger.debug("[block_ref] for Target-seq multiplot: %s" % block_ref)
ls_col_not_null = df_matrix.columns[-(df_matrix.loc['Target_seq', :] == '')].tolist()
ls_col_all = df_matrix.columns.tolist()
ls_plot_range = list(range(ls_col_not_null[0] - num_extend, ls_col_not_null[-1] + num_extend))
try:
df_matrix = df_matrix.loc[:, ls_plot_range]
# 有不在的就不为空,不为空则判断成立
# print 'testTTTTTTTTTTTTTTT'
except:
raise ValueError('The param [num_extend] is too large or <0, it must be a integer>=0')
ls_on_target = df_matrix.loc['Target_seq', :].tolist()
ls_ref = df_matrix.loc['Ref_seq', :].tolist()
ls_ratio = []
for label in label_panel:
df_sample = df_matrix[[label == i.split(': to')[0] and '(%)' not in i for i in df_matrix.index]]
df_sample.index = ['A', 'G', 'C', 'T', 'Ins', 'Del']
ls_A = df_sample.T['A'].isnull().tolist()
ls_G = df_sample.T['G'].isnull().tolist()
ls_C = df_sample.T['C'].isnull().tolist()
ls_T = df_sample.T['T'].isnull().tolist()
ls_bool = []
for i in range(len(ls_A)):
if ls_A[i] & ls_C[i] & ls_G[i] & ls_T[i]:
ls_bool.append(False)
else:
ls_bool.append(True)
ls_bl_select_df_sample = ls_bool
df_sample = df_sample.loc[:, ls_bl_select_df_sample]
df_sample = df_sample.loc[['A', 'G', 'C', 'T'], :].T.copy()
df_sample = df_sample.astype(int)
df_sample['sum'] = df_sample['A'] + df_sample['G'] + df_sample['C'] + df_sample['T']
df_sample['Ref_seq'] = np.array(ls_ref)[ls_bl_select_df_sample]
df_sample['To_base'] = df_sample['Ref_seq'].map(lambda x: dt_base[x])
df_sample['Mut_count'] = 0
for recode in df_sample.iterrows():
index_row = recode[0]
for to_base in recode[1]['To_base']:
df_sample.loc[index_row, 'Mut_count'] += recode[1][to_base]
df_sample['Mut_ratio'] = df_sample['Mut_count'] / df_sample['sum']
ls_sample_ratio = df_sample['Mut_ratio'].tolist()
ls_ratio.append(ls_sample_ratio)
df_ratio_all = pd.DataFrame(ls_ratio).T
df_ratio_all.columns = label_panel
try:
lm.logger.debug("Catch NA: {}".format(df_ratio_all.isna().sum().sum()))
lm.logger.debug(df_ratio_all.head())
# print(f'df_matrix.T = \n{df_matrix.T}')
df_ratio_all.index = df_matrix.T['Ref_index'][ls_bl_select_df_sample].map(float).map(int).tolist()
df_ratio_all['Target_seq'] = df_matrix.T['Target_seq'][ls_bl_select_df_sample].tolist()
df_ratio_all['Ref_seq'] = df_matrix.T['Ref_seq'][ls_bl_select_df_sample].tolist()
# print(f'df_ratio_all.T = \n{df_ratio_all.T}')
# exit()
except ValueError:
start_idx = df_matrix.T['Ref_index'][ls_bl_select_df_sample].map(float).map(int).tolist()[0]
end_idx = start_idx + len(df_ratio_all.index.tolist())
df_ratio_all.index = np.arange(start_idx, end_idx)
df_ratio_all['Target_seq'] = df_matrix.T['Target_seq'].tolist()
df_ratio_all['Ref_seq'] = df_matrix.T['Ref_seq'].tolist()
df_ratio_all['On-Target'] = ''
df_ratio_all['Reference'] = ''
df_ratio_all = df_ratio_all[['Target_seq', 'Ref_seq', 'On-Target', 'Reference'] + label_panel].T.copy()
df_onTarget_Ref = df_ratio_all.iloc[:2, :].fillna(' ')
lm.logger.debug(f"df_onTarget_Ref: \n{df_onTarget_Ref}")
# exit()
# print(df_onTarget_Ref == '')
df_onTarget_Ref_color = df_ratio_all.iloc[:2, :].fillna('').map(plot_agct)
lm.logger.debug(f"df_onTarget_Ref_color: \n{df_onTarget_Ref_color}")
# exit()
lm.logger.debug(f"df_ratio_all: \n{df_ratio_all}")
# exit()
df_plot = df_ratio_all.iloc[2:, :].copy()
lm.logger.debug(f"df_plot: \n{df_plot}")
# print(f'df_plot: \n{df_plot}')
# exit()
df_plot.iloc[2:, :] = df_plot.iloc[2:, :].astype(float) * 100
ls_max = []
for i in df_plot.iloc[2:, :].values.tolist():
ls_max.extend(i)
try:
ls_break = list(np.arange(0, max(ls_max), max(ls_max) / 100))
except ZeroDivisionError:
ls_break = [0, 0.003, 0.006, 0.009, 0.012, 0.015, 0.018, 0.022, 0.026, 0.030, 0.035, 0.040, 0.05, 0.06,
0.07, 0.08, 1.00, 2.00, 3.00]
ls_color_middle = make_color_list(
low_color_RGB=convert_hex_to_rgb('#87BDDB'),
high_color_RGB=convert_hex_to_rgb('#0A306A'),
length_out=80,
return_fmt="Hex")
ls_color_top = ['#0A306A'] * 20
ls_color_bottom = ['#EFEFEF'] * 5 + ['#DDEAF6'] * 5 + ['#A9CEE4'] * 5 + ['#87BDDB'] * 5
ls_color = ls_color_bottom + ls_color_middle + ls_color_top
lm.logger.debug(f"ls_color: {ls_color}")
# exit()
df_onTarget_Ref_tmp = df_onTarget_Ref.T
df_onTarget_Ref_tmp.loc[df_onTarget_Ref_tmp['Target_seq'].isnull(), 'Target_seq'] = pd.NA
df_onTarget_Ref = df_onTarget_Ref_tmp.T
df_plot_rec = df_plot.copy()
lm.logger.debug(
"df_plot_rec:\n"
+ tabulate(df_plot_rec, headers="keys", tablefmt="psql")
)
# heatmap colours: adjust them inside map_hex_for_matrix
def map_hex_for_matrix(x):
    """Map one heatmap value onto a hex colour from ``ls_color``.

    ``x`` is compared against the break points in ``ls_break`` in order; the
    first break strictly greater than ``x`` selects the colour one slot below
    it. When no break is greater than ``x`` the last break is used.

    Args:
        x: Numeric cell value (within this method: mutation ratios already
            multiplied by 100).

    Returns:
        str: Hex colour string taken from ``ls_color``.

    Raises:
        ValueError: If ``ls_break`` is empty (the -1 sentinel is not found).
    """
    # Sentinel: with an empty ``ls_break`` the lookup below fails loudly.
    upper = -1
    for upper in ls_break:
        if x < upper:
            break
    # Clamp at 0: a value below the first break would otherwise produce
    # index 0 - 1 == -1, which wraps around to the LAST colour in ``ls_color``
    # instead of the first one.
    return ls_color[max(ls_break.index(upper) - 1, 0)]
lm.logger.debug(f"df_plot_rec: \n{df_plot_rec}")
# print(f'df_plot_rec: \n{df_plot_rec}')
# exit()
df_plot_rec.iloc[2:, :] = df_plot_rec.iloc[2:, :].map(map_hex_for_matrix)
df_plot_rec_cmap = df_plot_rec.copy()
# print(df_plot_rec_cmap)
# exit()
# print(f'df_plot_rec = \n{df_plot_rec}')
# print(df_plot_rec.loc['On-Target', :])
# print(df_plot_rec.loc['On-Target', :].dtype)
# df_matrix = pd.DataFrame([[''] * len(plot_data_list[0][1])] * len(panel_height_coef))
# print(f"df_onTarget_Ref.loc['Target_seq', :] = \n{df_onTarget_Ref.loc['Target_seq', :]}")
# print('df_onTarget_Ref', df_onTarget_Ref.loc['Target_seq', :] == '')
# print(f"df_plot_rec.loc['On-Target', :] = \n{df_plot_rec.loc['On-Target', :]}")
# print('df_plot_rec', df_plot_rec.loc['On-Target', :] == '')
# exit()
df_plot_rec.loc['On-Target', :] = df_onTarget_Ref.loc['Target_seq', :]
# print(df_plot_rec.loc['On-Target', :])
# print(df_onTarget_Ref.loc['Target_seq', :])
df_plot_rec.loc['Reference', :] = df_onTarget_Ref.loc['Ref_seq', :]
df_plot_rec_cmap.loc['On-Target', :] = df_onTarget_Ref.loc['Target_seq', :].map(plot_agct)
df_plot_rec_cmap.loc['Reference', :] = df_onTarget_Ref.loc['Ref_seq', :].map(plot_agct)
figure_width_heatmap = df_plot_rec.shape[1]
figure_height_heatmap = max(df_plot_rec.shape[0], len(ls_break) * 0.8)
scale = 1.1 # 1.1
fig_heatmap = plt.figure(figsize=(figure_width_heatmap * scale, figure_height_heatmap * scale))
ax = fig_heatmap.add_subplot(111, aspect="equal")
plt.xlim([0, figure_width_heatmap + 5])
plt.ylim([-figure_height_heatmap - 0.5, 0])
plt.axis("off")
# export heatmap values
lm.logger.debug("export heatmap reference table...")
lm.logger.debug(f"df_plot_rec: \n{df_plot_rec.index}")
# exit()
for row in range(df_plot_rec.shape[0]):
row_name = df_plot_rec.index[row]
# add panel name
ax.text(
x=-1,
y=0.55 - row - 1.1,
s=row_name,
horizontalalignment='right',
verticalalignment='center',
fontsize=34 / 1.1 * scale,
fontname="DejaVu Sans",
alpha=1,
)
# add Target_seq and ref
if row_name in ['On-Target', 'Reference']:
for col in range(df_plot.shape[1]):
# plot Rectangle
site_x = col
site_y = - row + 0.05 - 1.05
ax.add_patch(matplotlib.patches.Rectangle(
(site_x, site_y),
width=1,
height=1,
linestyle='-',
fill=True,
facecolor=df_plot_rec_cmap.iloc[row, col],
edgecolor='#AAAAAA' if df_plot_rec_cmap.iloc[row, col] != '#FFFFFF' else '#FFFFFF',
linewidth=3.5
))
# plot text
site_x = col
site_y = -row - 0.55
ax.text(
x=site_x + 0.5,
y=site_y,
s=df_plot_rec.iloc[row, col],
horizontalalignment='center',
verticalalignment='center',
fontsize=34 / 1 * scale,
fontname="DejaVu Sans",
alpha=1
)
# add seq panels
else:
for col in range(df_plot.shape[1]):
# plot Rectangle
site_x = col
site_y = -row - 1.05
ax.add_patch(matplotlib.patches.Rectangle(
(site_x, site_y),
width=1,
height=1,
linestyle='-',
fill=True,
facecolor=df_plot_rec_cmap.iloc[row, col],
edgecolor='#FFFFFF',
linewidth=3
))
# add cbar
# start from this site
site_x = df_plot.shape[1] + 1
site_y = -figure_height_heatmap * 0.1 - 1
step_scale = 0.1
lm.logger.debug("[cbar_scale]: %s" % step_scale)
for color in ls_color:
site_y += step_scale
ax.add_patch(matplotlib.patches.Rectangle(
(site_x, site_y),
width=1,
height=step_scale,
fill=True,
facecolor=color,
edgecolor=color,
))
# add cbar text
# start from this site
site_x = df_plot.shape[1] + 1
site_y = -figure_height_heatmap * 0.1
for idx, label in enumerate(ls_break):
site_y += step_scale
if idx % 10 == 0:
ax.text(
x=site_x + 1.1,
y=site_y - 1,
s=round(label, 4),
horizontalalignment='left',
verticalalignment='center',
fontsize=34 / 1.5,
# fontsize=34 / 1.5 * step_scale,
fontname="DejaVu Sans",
alpha=1
)
fig_heatmap.savefig(output_fig_heatmap, bbox_inches='tight') # 减少边缘空白
# plt.show()
if output_table_heatmap:
if output_table_heatmap.endswith('.csv'):
df_plot_rec.to_csv(output_table_heatmap, sep=',')
elif output_table_heatmap.endswith('.tsv'):
df_plot_rec.to_csv(output_table_heatmap, sep='\t')
if output_fig_count_ratio:
# set new figure
fig = plt.figure(figsize=(figure_width * 1.1, figure_height * 1.1))
ax = fig.add_subplot(111, aspect="equal")
plt.xlim([0, figure_width])
plt.ylim([-figure_height, 0])
plt.axis("off")
text_list = []
patches = []
# print panel_height_coef
for panel_index in range(len(panel_height_coef)):
# print 'panel index:', panel_index
# panel name
panel_name = plot_data_list[panel_index][0]
panel_name_x = box_x_vec[0]
# print panel_name
# print panel_name_x
panel_name_y = current_y - box_height_list[panel_index] * 0.5
# print panel_name_y
text_list.append((panel_name_x, panel_name_y, panel_name, 10))
# plot panel box
# plot the Index row
if panel_name == "Ref_index":
# don't draw box, only add text
for index, box_value in enumerate(plot_data_list[panel_index][1]):
box_x = box_x_vec[index + 1]
text_list.append(
(box_x + panel_box_width * 0.5, current_y - box_height_list[panel_index] * 0.5,
str(box_value), 6))
# make next panel_y
current_y = current_y - (box_height_list[panel_index] + panel_space_list[panel_index])
# plot the Target_seq and Ref_seq rows
elif panel_name in ["Target_seq", "Ref_seq"]:
# if panel_name = Ref, form a new list to store plot info for checking the block_ref param
if panel_name == 'Ref_seq':
plot_data_list_ref = plot_data_list[panel_index][1]
for index, box_value in enumerate(plot_data_list[panel_index][1]):
# print(index, box_value)
if box_value == "":
box_fill = False
box_color = "#FFFFFF"
else:
if aln_direction == "Reverse Alignment":
if panel_name == "Ref_seq":
box_value = "".join([DNA_rev_cmp_dict.get(x) for x in box_value])
else:
### fix bug
for x in box_value:
box_value = DNA_rev_cmp_dict.get(x)
box_color = base_color_dict.get(box_value[0])
else:
box_color = base_color_dict.get(box_value[-1])
if not box_color:
box_fill = False
box_color = "#FFFFFF"
else:
box_fill = True
box_x = box_x_vec[index + 1]
patches.append(Rectangle(
xy=(box_x, current_y - box_height_list[panel_index]),
width=panel_box_width,
height=box_height_list[panel_index],
fill=box_fill,
alpha=1,
linestyle=box_linestyle,
linewidth=box_linewidth,
edgecolor=box_edgecolor,
facecolor=box_color)
)
# text
text_list.append(
(box_x + 0.5 * panel_box_width, current_y - 0.5 * box_height_list[panel_index],
str(box_value), 16))
# make next panel_y
current_y = current_y - (box_height_list[panel_index] + panel_space_list[panel_index])
# used to plot the sample rows
# counts are plotted here
elif (panel_name.split('to ')[-1] in default_bases):
# print 'panel name:', panel_name.split('to ')[-1].replace('(%)', '')
# print panel_name, 'check panel name'
# 改进了这里total 和 base的值的引用
if count_ratio == 'all':
ls_base_sum_count_all = ls_base_sum_count * 2 * len(default_bases)
ls_total_sum_count_all = ls_total_sum_count * 2 * len(default_bases)
total_sum_count = ls_total_sum_count_all[panel_index - 3]
base_sum_count = ls_base_sum_count_all[panel_index - 3]
else:
ls_base_sum_count_all = ls_base_sum_count * len(default_bases)
ls_total_sum_count_all = ls_total_sum_count * len(default_bases)
total_sum_count = ls_total_sum_count_all[panel_index - 3]
base_sum_count = ls_base_sum_count_all[panel_index - 3]
panel_name = panel_name.split('to ')[-1]
if panel_name in ["Del", "Ins"]:
box_ratio = plot_data_list[panel_index][1] / total_sum_count
else:
box_ratio = plot_data_list[panel_index][1] / base_sum_count
box_color_list = map_color(box_ratio, color_break, color_list)
ls_base_sum = base_sum_count.tolist()
for index, box_value in enumerate(plot_data_list[panel_index][1]):
# if count_ratio == 'all':
box_color = box_color_list[index]
# check block_ref state: count plot
if block_ref:
if plot_data_list_ref.iloc[index] == panel_name:
# if box_value >= ls_base_sum[index] * 0.8:
# box_fill = False
box_color = "#FFFFFF"
box_x = box_x_vec[index + 1]
patches.append(Rectangle(
xy=(box_x, current_y - box_height_list[panel_index]),
width=panel_box_width,
height=box_height_list[panel_index],
fill=True,
alpha=1,
linestyle=box_linestyle,
linewidth=box_linewidth,
edgecolor=box_edgecolor,
facecolor=box_color)
)
# text
text_list.append(
(box_x + 0.5 * panel_box_width, current_y - 0.5 * box_height_list[panel_index],
str(box_value), 6))
# print box_value
# make next panel_y
current_y = current_y - (box_height_list[panel_index] + panel_space_list[panel_index])
# ratios are plotted here (everything that is left is base(%))
else:
box_color_list = map_color(plot_data_list[panel_index][1], color_break, color_list)
# print box_color_list
panel_name = panel_name.split('to ')[-1].replace('(%)', '')
for index, box_value in enumerate(plot_data_list[panel_index][1]):
# ratio
box_color = box_color_list[index]
# check block_ref state: ratio plot
if block_ref:
if plot_data_list_ref.iloc[index] == panel_name:
# if box_value >= 0.85:
# box_fill = False
box_color = "#FFFFFF"
box_x = box_x_vec[index + 1]
patches.append(Rectangle(
xy=(box_x, current_y - box_height_list[panel_index]),
width=panel_box_width,
height=box_height_list[panel_index],
fill=True,
alpha=1,
linestyle=box_linestyle,
linewidth=box_linewidth,
edgecolor=box_edgecolor,
facecolor=box_color)
)
if '(%)' in plot_data_list[panel_index][0]:
# text
# ref
# print 'test', box_value
# if ARGS.block_ref == True:
# if box_value >= 0.99:
# box_value = 0
text_list.append(
(box_x + 0.5 * panel_box_width, current_y - 0.5 * box_height_list[panel_index],
round(box_value * 100, 4), 6))
else:
text_list.append(
(box_x + 0.5 * panel_box_width, current_y - 0.5 * box_height_list[panel_index],
box_value, 6))
# make next panel_y
if panel_index < len(panel_space_list):
current_y = current_y - (box_height_list[panel_index] + panel_space_list[panel_index])
# plot box
ax.add_collection(PatchCollection(patches, match_original=True))
# add text
for text_x, text_y, text_info, text_fontsize in text_list:
if " to " in str(text_info):
plt.text(
x=text_x + 0.3,
y=text_y,
s=text_info,
horizontalalignment='right',
verticalalignment='center',
fontsize=text_fontsize,
fontname="DejaVu Sans"
)
else:
plt.text(
x=text_x,
y=text_y,
s=text_info,
horizontalalignment='center',
verticalalignment='center',
fontsize=text_fontsize,
fontname="DejaVu Sans"
)
# output plot
if output_fig_fmt.upper == "PNG":
fig.savefig(fname=output_fig_count_ratio, bbox_inches='tight', dpi=output_fig_dpi, format="png")
else:
fig.savefig(fname=output_fig_count_ratio, bbox_inches='tight', dpi=output_fig_dpi, format="pdf")
if output_table_count_ratio:
if output_table_count_ratio.endswith('.csv'):
df_bases_select.to_csv(output_table_count_ratio, sep=',')
elif output_table_count_ratio.endswith('.tsv'):
df_bases_select.to_csv(output_table_count_ratio, sep='\t')
# Running this module directly as a script performs no action.
if __name__ == '__main__':
    pass