"""_summary_.
author: Herman Huanan Zhao
email: hermanzhaozzzz@gmail.com
homepage: https://github.com/hermanzhaozzzz
Target deep sequencing toolbox: region mutation heatmaps built from `bioat bam mpileup_to_table` tables.
example 1:
bioat list
<in shell>:
$ bioat list
<in python console>:
>>> from bioat.cli import Cli
>>> bioat = Cli()
>>> bioat.list()
>>> print(bioat.list())
example 2:
_example_
"""
import gzip
import os.path
import sys
import numpy as np
import pandas as pd
from Bio.Seq import Seq
from tabulate import tabulate
from bioat.exceptions import (
BioatFileFormatError,
BioatInvalidInputError,
BioatInvalidOptionError,
)
from bioat.lib.libalignment import instantiate_pairwise_aligner
from bioat.lib.libcolor import convert_hex_to_rgb, make_color_list, map_color
from bioat.lib.libcrispr import TARGET_SEQ_LIB, run_target_seq_align
from bioat.lib.libpandas import set_option
from bioat.logger import LoggerManager
# Module-wide logger; each public method re-targets it with lm.set_names(...)
# and lm.set_level(...) before logging.
lm = LoggerManager(mod_name="bioat.target_seq")
# Import-time side effect: configures bioat's pandas helper with log_level="ERROR".
# NOTE(review): presumably this quiets pandas-related logging/options process-wide;
# confirm against bioat.lib.libpandas.set_option.
set_option(log_level="ERROR")
def _load_matplotlib_plotting():
    """Lazily import the matplotlib pieces used for plotting.

    The non-interactive "Agg" backend is selected *before* ``pyplot`` is
    imported, so figures can be rendered without a display. Importing here
    (instead of at module top) keeps matplotlib off the import path of
    commands that never plot.

    Returns:
        tuple: ``(matplotlib, matplotlib.pyplot, PatchCollection, Rectangle)``.
    """
    import matplotlib

    matplotlib.use("Agg")

    from matplotlib import pyplot
    from matplotlib.collections import PatchCollection
    from matplotlib.patches import Rectangle

    return matplotlib, pyplot, PatchCollection, Rectangle
[docs]
class TargetSeq:
"""Target Deep Sequencing toolbox."""
lm.set_names(cls_name="TargetSeq")
def __init__(self):
    """Create a toolbox instance.

    No per-instance state is kept: the methods of this class take all of
    their options as arguments, so construction intentionally does nothing.
    """
[docs]
def region_heatmap(
    self,
    input_table: str,
    output_fig: str,
    target_seq: str | None = None,  # sgRNA
    reference_seq: str | None = None,  # target locus
    input_table_header: bool = True,
    output_fig_fmt: str = "pdf",
    output_fig_dpi: int = 100,
    show_indel: bool = True,
    show_index: bool = True,
    box_border: bool = False,
    box_space: float = 0.03,
    min_color: tuple = (250, 239, 230),
    max_color: tuple = (154, 104, 57),
    min_ratio: float = 0.001,
    max_ratio: float = 0.99,
    region_extend_length: int = 5,
    local_alignment_scoring_matrix: tuple = (5, -4, -24, -8),
    local_alignment_min_score: int = 15,
    PAM_priority_weight: float = 1.0,
    get_built_in_target_seq: bool = False,
    log_level: str = "INFO",
):
    """Plot region mutation information using a table generated by `bioat bam mpileup_to_table`.

    The target sequence (e.g. an sgRNA) is aligned to both strands of the
    reference. The better-scoring strand decides the orientation. A heatmap
    is then drawn for the aligned region plus ``region_extend_length`` flanking
    positions on each side. It shows per-position base counts, indel counts
    (optional) and their ratios.

    Args:
        input_table (str): Path to the table generated by `bioat bam mpileup_to_table`.
            It should cover a short genome region (no more than 1k nt) and
            provide the columns ``chr_index``, ``ref_base``, ``A``, ``G``, ``C``,
            ``T``, ``del_count`` and ``insert_count``.
        output_fig (str): Path to the output figure file.
        target_seq (str, optional): Target sequence to align against the reference sequence,
            or a <key> of the built-in target sequences. Examples:
                - 'GAGTCCGAGCAGAAGAAGAA^GGG^' for SpCas9-BE (PAM: ^GGG^).
                - '^TTTA^GCCCCAATAATCCCCACATGTCA' for cpf1-BE (PAM: ^TTTA^).
                - 'TGCTAGTAACCACGTTCTCCTGATCAAATATCACTCTCCTACTTACAGGA' for no PAM.
            If None, the whole reference is used as target. Defaults to None.
        reference_seq (str, optional): Custom reference sequence to overwrite the one in `input_table`.
            Can be a single-record FASTA file (optionally .gz) or a DNA sequence. Defaults to None.
        input_table_header (bool, optional): Whether `input_table` contains a header. Defaults to True.
        output_fig_fmt (str, optional): Format of the output figure ("pdf" or "png"). Defaults to "pdf".
        output_fig_dpi (int, optional): DPI for the output figure. Defaults to 100.
        show_indel (bool, optional): Whether to show Del/Ins count and ratio rows. Defaults to True.
        show_index (bool, optional): Whether to show the reference index row. Defaults to True.
        box_border (bool, optional): Whether to draw grey box borders. Defaults to False.
        box_space (float, optional): Horizontal space between two boxes. Defaults to 0.03.
        min_color (tuple, optional): Minimum heatmap color in RGB format. Defaults to (250, 239, 230).
        max_color (tuple, optional): Maximum heatmap color in RGB format. Defaults to (154, 104, 57).
        min_ratio (float, optional): Mutation ratio below `min_ratio` will be shown as white. Defaults to 0.001.
        max_ratio (float, optional): Upper end of the color breaks; higher ratios are capped. Defaults to 0.99.
        region_extend_length (int, optional): Number of positions to extend on either side of the
            aligned region. Defaults to 5.
        local_alignment_scoring_matrix (tuple, optional): Alignment scoring parameters as a tuple:
            (<align_match_score>, <align_mismatch_score>, <align_gap_open_score>, <align_gap_extension_score>).
            Defaults to (5, -4, -24, -8).
        local_alignment_min_score (int, optional): Minimum alignment score for a valid alignment. If neither
            strand reaches it, the program exits with status 1. Defaults to 15.
        PAM_priority_weight (float, optional): Weight multiplier for PAM alignment scores. It is currently only
            recorded in the parsed PAM info, because PAM-aware alignment is not enabled. Defaults to 1.0.
        get_built_in_target_seq (bool, optional): If True, log the built-in target sequences and exit(0).
            Defaults to False.
        log_level (str, optional): Logging level. One of 'CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG',
            'NOTSET'. Defaults to "INFO".

    Returns:
        None

    Raises:
        BioatFileFormatError: If `reference_seq` still contains '>' after the FASTA header is removed
            (e.g. a multi-record FASTA).
        SystemExit: With status 0 for `get_built_in_target_seq`, or status 1 for an invalid `target_seq`
            or a failed alignment.

    Examples (shell):
        $ # Generate mutation table
        $ samtools mpileup test_sorted.bam --reference HK4-AOut-1.ref.upper.fa | gzip > test_sorted.mpileup.gz
        $ bioat bam mpileup_to_table test_sorted.mpileup.gz > test_sorted.mpileup.info.tsv
        $ # Generate region heatmap
        $ bioat target_seq region_heatmap --input_table test_sorted.mpileup.info.tsv --output_fig test_sorted.mpileup.info.pdf
    """
    lm.set_names(func_name="region_heatmap")
    lm.set_level(log_level)
    _mpl, plt, PatchCollection, Rectangle = _load_matplotlib_plotting()

    if get_built_in_target_seq:
        lm.logger.info(
            "You can use <key> in built-in <target_seq> to represent your target_seq:\n"
            + "\t<key>\t<target_seq>\n"
            + "".join([f"\t{k}\t{v}\n" for k, v in TARGET_SEQ_LIB.items()]),
        )
        lm.logger.info("exit because of the defination for get_built_in_target_seq")
        sys.exit(0)

    # ---- resolve target_seq (sgRNA): built-in <key>, raw sequence, or None ----
    lm.logger.debug(f"load target_seq = {target_seq}")
    if target_seq is None:
        pass
    elif not isinstance(target_seq, str):
        # checked before the dict lookup so unhashable input reaches this error
        lm.logger.error(
            "<target_seq> should be set as:"
            "DNA sequence / <key> in <get_built_in_target_seq> / None,"
            f"but your <target_seq> is {target_seq}",
        )
        sys.exit(1)
    elif target_seq in TARGET_SEQ_LIB:
        lm.logger.debug(
            f"target_seq information is abtained from the <key>={target_seq}",
        )
        target_seq = TARGET_SEQ_LIB[target_seq]
        lm.logger.debug(f"target_seq is refered to {target_seq}")
    else:
        lm.logger.debug(f"target_seq is refered to {target_seq}")

    aligner = instantiate_pairwise_aligner(*local_alignment_scoring_matrix)

    # ---- load mpileup.info.tsv ----
    df_bases = pd.read_csv(
        input_table, sep="\t", header="infer" if input_table_header else None
    )
    lm.logger.debug(
        "df_bases.head(10):\n"
        + tabulate(df_bases.head(10), headers="keys", tablefmt="psql"),
    )

    # ---- ref_seq & ref_seq_rc ----
    if reference_seq:
        if os.path.isfile(reference_seq):
            opener = gzip.open if reference_seq.endswith(".gz") else open
            # `with` closes the handle even if reading fails
            with opener(reference_seq, "rt") as fasta:
                # drop the header line and join the wrapped sequence lines
                reference_seq = "".join(
                    line.rstrip() for line in fasta.readlines()[1:]
                )
        # any remaining '>' means a multi-record FASTA (or malformed input)
        if reference_seq.count(">") > 0:
            raise BioatFileFormatError
        ref_seq = Seq(reference_seq)
        lm.logger.debug(f"load ref_seq from parameter:\n{ref_seq}")
    else:
        ref_seq = Seq("".join(df_bases.ref_base))
        lm.logger.debug(f"load ref_seq from df_bases.ref_base:\n{ref_seq}")
    lm.logger.debug(f"Seq object for ref_seq:\n{ref_seq}")
    ref_seq_rc = ref_seq.reverse_complement()
    lm.logger.debug(
        f"Seq object for ref_seq_rc (reverse_complement):\n{ref_seq_rc}",
    )

    # ---- parse target_seq: '^PAM^protospacer', 'protospacer^PAM^' or plain ----
    if target_seq is None:
        target_seq = ref_seq
        PAM = None
    elif "^" in target_seq:
        if target_seq.startswith("^"):
            PAM, target_seq = target_seq[1:].strip().split("^")
            target_seq = Seq(target_seq.upper())
            PAM = {"PAM": PAM, "position": 0, "weight": PAM_priority_weight}
        else:
            target_seq, PAM = target_seq[:-1].strip().split("^")
            target_seq = Seq(target_seq.upper())
            PAM = {
                "PAM": PAM,
                "position": len(target_seq),
                "weight": PAM_priority_weight,
            }
    else:
        lm.logger.debug(
            f"no PAM sequence is defined from target_seq = {target_seq}",
        )
        # upper-case like the PAM branches so lower-case input aligns identically
        target_seq = Seq(target_seq.upper())
        PAM = None
    lm.logger.info(f"parse target_seq: target_seq={target_seq}, PAM={PAM}")

    # ---- align to both strands to find the strand target_seq was designed on ----
    # PAM-aware alignment (run_target_seq_align(..., PAM=PAM)) is not enabled yet,
    # so both modes run the same alignment; PAM only changes the log message.
    lm.logger.info("use PAM mode to align" if PAM else "use PAMless mode to align")
    fwd = run_target_seq_align(ref_seq, target_seq, aligner)
    lm.logger.debug(f"final_align_forward:\n{fwd}")
    rev = run_target_seq_align(ref_seq_rc, target_seq, aligner)
    lm.logger.debug(f"final_align_reverse:\n{rev}")

    def _aligned_span(res):
        # trim the alignment strings to [ref_aln_start, ref_aln_end] (inclusive end)
        span = slice(res["ref_aln_start"], res["ref_aln_end"] + 1)
        alignment = res["alignment"]
        return (
            alignment["reference_seq"][span],
            alignment["aln_info"][span],
            alignment["target_seq"][span],
        )

    reference_seq, align_info, target_seq = _aligned_span(fwd)
    lm.logger.info(
        f"Forward best alignment:\n"
        f"reference : {reference_seq}\n"
        f"align_info : {align_info}\n"
        f"target_seq : {target_seq}\n"
        f"align_score: {fwd['aln_score']}",
    )
    # the raw result dict is verbose; keep it out of INFO output
    lm.logger.debug(rev)
    reference_seq_rev, align_info_rev, target_seq_rev = _aligned_span(rev)
    lm.logger.info(
        f"Reverse best alignment:\n"
        f"reference : {reference_seq_rev}\n"
        f"align_info : {align_info_rev}\n"
        f"target_seq : {target_seq_rev}\n"
        f"align_score: {rev['aln_score']}",
    )

    # ---- choose the alignment direction ----
    if not any(
        score >= local_alignment_min_score
        for score in (fwd["aln_score"], rev["aln_score"])
    ):
        lm.logger.critical("Alignment Error!")
        sys.exit(1)
    if fwd["aln_score"] >= rev["aln_score"]:
        aln_direction = "Forward Alignment"
        aln = fwd
    else:
        aln_direction = "Reverse Alignment"
        aln = rev
        reference_seq, target_seq = reference_seq_rev, target_seq_rev
    is_reverse = aln_direction == "Reverse Alignment"
    lm.logger.debug(f"aln_direction = {aln_direction}")

    # ---- project the aligned target bases onto reference positions ----
    ref_seq_length = len(ref_seq)
    # target_seq_aln[i]: target base aligned to reference position i ("" = none)
    target_seq_aln = [""] * ref_seq_length
    # target_seq_aln_insert[i]: target base inserted where the reference has a gap
    # (only logged below; it is not drawn in the figure)
    target_seq_aln_insert = [""] * ref_seq_length
    DNA_rev_cmp_dict = {
        "A": "T",
        "T": "A",
        "C": "G",
        "G": "C",
        "N": "N",
        "-": "-",
    }
    aln_start = aln["ref_aln_start"]
    aln_end = aln["ref_aln_end"]
    lm.logger.debug(f"target_seq_aln (before update) = {target_seq_aln}")
    lm.logger.debug(
        f"target_seq_aln_insert (before update) = {target_seq_aln_insert}",
    )
    # reference : GCCTCTGGAGAGGGAGGAGGG
    # align_info : |.|.|||..|..|||||.||-
    # target_seq : GGCACTGCGGCTGGAGGTGG-
    ref_gap_count = 0
    for idx, ref_base in enumerate(reference_seq):
        target_base = target_seq[idx]
        if ref_base != "-":
            # subtract the reference gaps seen so far to get the reference position
            target_seq_aln[aln_start + idx - ref_gap_count] = target_base
        else:
            ref_gap_count += 1
            target_seq_aln_insert[aln_start + idx] = target_base
        lm.logger.debug(
            "update target_seq_aln: show ''.join(target_seq_aln)\n"
            f"target_seq_aln = {''.join(target_seq_aln)}\n"
            f"target_seq = {target_seq}",
        )
    lm.logger.debug("update target_seq_aln: Done")
    lm.logger.debug(f"target_seq_aln (after update) = {target_seq_aln}")
    lm.logger.debug(
        f"target_seq_aln_insert (after update) = {target_seq_aln_insert}",
    )

    if is_reverse:
        # reverse so that index i refers to forward-strand position i again
        target_seq_aln = target_seq_aln[::-1]
        target_seq_aln_insert = target_seq_aln_insert[::-1]
        lm.logger.debug("use reverse alignment")
        lm.logger.debug(f"target_seq_aln = {target_seq_aln}")
        lm.logger.debug(f"target_seq_aln_insert = {target_seq_aln_insert}")
        # aln_start/aln_end index ref_seq_rc; mirror them onto the forward strand
        aligned_first = ref_seq_length - 1 - aln_end
        aligned_last = ref_seq_length - 1 - aln_start
    else:
        aligned_first, aligned_last = aln_start, aln_end

    # ---- plot region (forward-strand coordinates) ----
    # aligned_last is inclusive while iloc's stop is exclusive, hence the + 1
    plot_region = (
        max(aligned_first - region_extend_length, 0),
        min(aligned_last + 1 + region_extend_length, ref_seq_length),
    )
    lm.logger.debug(f"plot_region = {plot_region}")
    df_bases_select = df_bases.iloc[plot_region[0] : plot_region[1], :]
    lm.logger.debug(
        "df_bases_select:\n"
        + tabulate(df_bases_select, headers="keys", tablefmt="psql"),
    )

    # ---- layout & colors ----
    np.set_printoptions(suppress=True)
    panel_box_width = 0.4
    panel_box_height = 0.4
    panel_space = 0.05
    base_colors = {
        "A": "#04E3E3",
        "T": "#F9B874",
        "C": "#B9E76B",
        "G": "#F53798",
        "N": "#AAAAAA",
        "-": "#AAAAAA",
    }
    color_break_num = 20
    break_step = 1.0 / color_break_num
    # breaks run from min_ratio up to (excluding) max_ratio; the leading white
    # entry is presumably used by map_color for ratios below min_ratio
    color_break = np.round(np.arange(min_ratio, max_ratio, break_step), 5)
    color_list = [
        "#FFFFFF",
        *make_color_list(min_color, max_color, len(color_break) - 1, "HEX"),
    ]
    lm.logger.debug(f"color_list = {color_list}")

    total_box_count = plot_region[1] - plot_region[0]
    base_sum_count = (
        df_bases_select.loc[:, ["A", "G", "C", "T"]].sum(axis=1).astype(int)
    )
    lm.logger.debug(f"base_sum_count =\n{base_sum_count.values}")
    total_sum_count = (
        df_bases_select[["A", "G", "C", "T", "del_count", "insert_count"]]
        .sum(axis=1)
        .astype(int)
    )
    lm.logger.debug(f"total_sum_count =\n{total_sum_count.values}")
    # avoid division by zero at positions without coverage
    base_sum_count[base_sum_count == 0] = 1
    total_sum_count[total_sum_count == 0] = 1
    lm.logger.debug(f"aln =\n{aln}")

    # ---- panel rows: index, sequences, counts, then ratios ----
    bases = ["A", "G", "C", "T"]
    plot_data_list = [
        ["Ref_index", df_bases_select.chr_index],
        ["Target_seq", target_seq_aln[plot_region[0] : plot_region[1]]],
        ["Ref_seq", df_bases_select.ref_base],
    ]
    plot_data_list += [[b, np.array(df_bases_select[b])] for b in bases]
    if show_indel:
        plot_data_list += [
            ["Del", np.array(df_bases_select["del_count"])],
            ["Ins", np.array(df_bases_select["insert_count"])],
        ]
    plot_data_list += [
        [f"{b} %", list(df_bases_select[b] / base_sum_count)] for b in bases
    ]
    if show_indel:
        plot_data_list += [
            ["Del %", list(df_bases_select["del_count"] / total_sum_count)],
            ["Ins %", list(df_bases_select["insert_count"] / total_sum_count)],
        ]
        panel_height_coef = [0.5, 0.9, 0.9] + [0.5] * 6 + [0.5] * 6
        # one entry per gap between rows (one fewer than the number of rows)
        panel_space_coef = (
            [1.0] * 3 + [0.3] * 3 + [1.0, 0.3, 1.0] + [0.3] * 3 + [1.0, 0.3]
        )
    else:
        panel_height_coef = [0.5, 0.9, 0.9] + [0.5] * 4 + [0.5] * 4
        panel_space_coef = [1.0] * 3 + [0.3] * 3 + [1.0] + [0.3] * 3
    if not show_index:
        # "Ref_index" is always the first row; drop it together with its sizes
        plot_data_list = plot_data_list[1:]
        panel_height_coef = panel_height_coef[1:]
        panel_space_coef = panel_space_coef[1:]

    box_height_list = np.array(panel_height_coef) * panel_box_height
    panel_space_list = np.array(panel_space_coef) * panel_space
    lm.logger.debug(f"box_height_list = {box_height_list}")
    lm.logger.debug(f"panel_space_list = {panel_space_list}")

    figure_width = (
        total_box_count * panel_box_width
        + (total_box_count - 1) * box_space
        + panel_box_width * 2
    )
    figure_height = sum(box_height_list) + sum(panel_space_list)
    lm.logger.debug(f"figure_width = {figure_width}")
    lm.logger.debug(f"figure_height = {figure_height}")
    # x of every column; column 0 holds the row labels
    box_x_vec = np.arange(
        0, figure_width + panel_box_width, panel_box_width + box_space
    )
    box_x_vec = box_x_vec[: (len(ref_seq) + 1)]

    if box_border:
        box_edgecolor, box_linestyle, box_linewidth = "#AAAAAA", "-", 2
    else:
        box_edgecolor, box_linestyle, box_linewidth = "#FFFFFF", "None", 0
    lm.logger.debug(f"box_border_plot_state = {box_border}")
    lm.logger.debug(f"box_edgecolor = {box_edgecolor}")
    lm.logger.debug(f"box_linestyle = {box_linestyle}")
    lm.logger.debug(f"box_linewidth = {box_linewidth}")

    # ---- draw ----
    lm.logger.debug("start to plot")
    fig = plt.figure(figsize=(figure_width * 1.1, figure_height * 1.1))
    lm.logger.debug(
        f"set new figure, figsize = ({figure_width * 1.1}, {figure_height * 1.1})",
    )
    plt.set_loglevel("info")
    ax = fig.add_subplot(111, aspect="equal")
    plt.xlim([0, figure_width])
    plt.ylim([-figure_height, 0])
    plt.axis("off")
    lm.logger.debug(f"plt.xlim([0, {figure_width}])")
    lm.logger.debug(f"plt.ylim([-{figure_height}, 0])")
    lm.logger.debug('plt.axis("off")')

    text_list = []
    patches = []

    def _add_box(x, y_top, height, facecolor, fill=True):
        # one heatmap cell whose top edge sits at y_top
        patches.append(
            Rectangle(
                xy=(x, y_top - height),
                width=panel_box_width,
                height=height,
                fill=fill,
                alpha=1,
                linestyle=box_linestyle,
                linewidth=box_linewidth,
                edgecolor=box_edgecolor,
                facecolor=facecolor,
            ),
        )

    current_y = 0
    for panel_index, (panel_name, panel_values) in enumerate(plot_data_list):
        box_height = box_height_list[panel_index]
        row_center_y = current_y - box_height * 0.5
        text_list.append((box_x_vec[0], row_center_y, panel_name, 10))

        if panel_name == "Ref_index":
            # text only, no boxes
            for index, box_value in enumerate(panel_values):
                box_x = box_x_vec[index + 1]
                text_list.append(
                    (box_x + panel_box_width * 0.5, row_center_y, str(box_value), 7),
                )
        elif panel_name in ("Target_seq", "Ref_seq"):
            for index, box_value in enumerate(panel_values):
                if box_value == "":
                    box_fill = False
                    box_color = "#FFFFFF"
                else:
                    if is_reverse:
                        if panel_name == "Ref_seq":
                            # show the reference on the strand the target matched
                            box_value = "".join(
                                [DNA_rev_cmp_dict.get(x) for x in box_value]
                            )
                        box_color = base_colors.get(box_value[0])
                    else:
                        box_color = base_colors.get(box_value[-1])
                    if not box_color:
                        box_fill = False
                        box_color = "#FFFFFF"
                    else:
                        box_fill = True
                box_x = box_x_vec[index + 1]
                _add_box(box_x, current_y, box_height, box_color, fill=box_fill)
                text_list.append(
                    (box_x + 0.5 * panel_box_width, row_center_y, str(box_value), 16),
                )
        elif panel_name in ("A", "G", "C", "T", "Del", "Ins"):
            # counts: color by ratio, label with the raw count
            if panel_name in ("Del", "Ins"):
                box_ratio = panel_values / total_sum_count
            else:
                box_ratio = panel_values / base_sum_count
            box_color_list = map_color(box_ratio, color_break, color_list)
            lm.logger.debug(f"get box_color_list for {panel_name}")
            for index, box_value in enumerate(panel_values):
                box_x = box_x_vec[index + 1]
                _add_box(box_x, current_y, box_height, box_color_list[index])
                text_list.append(
                    (box_x + 0.5 * panel_box_width, row_center_y, str(box_value), 6),
                )
        else:
            # ratio rows ("A %", ...): color by value, label as a percentage
            box_color_list = map_color(panel_values, color_break, color_list)
            for index, box_value in enumerate(panel_values):
                box_x = box_x_vec[index + 1]
                _add_box(box_x, current_y, box_height, box_color_list[index])
                text_list.append(
                    (
                        box_x + 0.5 * panel_box_width,
                        row_center_y,
                        round(box_value * 100, 4),
                        6,
                    ),
                )
        # move down to the next row; the last row has no trailing space entry
        if panel_index < len(panel_space_list):
            current_y -= box_height + panel_space_list[panel_index]

    lm.logger.debug("plot rectangles")
    ax.add_collection(PatchCollection(patches, match_original=True))
    lm.logger.debug("plot text on each rectangle")
    for text_x, text_y, text_info, text_fontsize in text_list:
        plt.text(
            x=text_x,
            y=text_y,
            s=text_info,
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=text_fontsize,
            fontname="Arial",
        )

    try:
        fig.savefig(
            fname=output_fig,
            bbox_inches="tight",
            dpi=output_fig_dpi,
            format=output_fig_fmt,
        )
    finally:
        # pyplot keeps figures alive until closed; release it even if saving fails
        plt.close(fig)
[docs]
def region_heatmap_compare(
self,
input_tables: str,
labels: str | None = None,
target_seq: str | None = None, # sgRNA
reference_seq: str | None = None, # target locus
output_fig_heatmap: str | None = None,
output_fig_count_ratio: str | None = None,
output_table_heatmap: str | None = None,
output_table_count_ratio: str | None = None,
output_fig_fmt: str = "pdf",
input_table_header: bool = True,
to_base: tuple = ("A", "G", "C", "T", "Ins", "Del"),
heatmap_mut_direction: tuple = ("CT", "GA"),
count_ratio="all",
region_extend_length: int = 5,
output_fig_dpi: int = 100,
show_indel: bool = True,
show_index: bool = True,
block_ref: bool = True,
box_border: bool = False,
box_space: float = 0.03,
min_color: tuple = (250, 239, 230),
max_color: tuple = (154, 104, 57),
min_ratio: float = 0.001,
max_ratio: float = 0.99,
local_alignment_scoring_matrix: tuple = (5, -4, -24, -8),
local_alignment_min_score: int = 15,
PAM_priority_weight: float = 1.0,
get_built_in_target_seq: bool = False,
log_level: str = "INFO",
):
"""Plot region mutation information for multiple conditions.
This function generates a comparison of mutation information across multiple conditions
using tables generated by `bioat bam mpileup_to_table`.
Args:
input_tables (str):
Paths to input tables generated by `bioat bam mpileup_to_table`, separated by commas.
Each table should contain mutation information for a short genomic region (≤1k nt).
labels (str):
Labels for the panels, separated by commas.
target_seq (str, optional):
Sequence to align against the reference sequence in `mpileup.table`. Examples:
- 'GAGTCCGAGCAGAAGAAGAA^GGG^' for SpCas9-BE (PAM: ^GGG^).
- '^TTTA^GCCCCAATAATCCCCACATGTCA' for cpf1-BE (PAM: ^TTTA^).
- 'TGCTAGTAACCACGTTCTCCTGATCAAATATCACTCTCCTACTTACAGGA' for no PAM.
Defaults to None.
reference_seq (str, optional):
Custom reference sequence to overwrite the one in `mpileup.table`.
Can be a FASTA file or a DNA sequence. Defaults to None.
output_fig_heatmap (str, optional):
Path to the heatmap output figure. Defaults to None.
output_fig_count_ratio (str, optional):
Path to the count/ratio output figure. Defaults to None.
output_table_heatmap (str, optional):
Path to the heatmap output table. Defaults to None.
output_table_count_ratio (str, optional):
Path to the count/ratio output table. Defaults to None.
output_fig_fmt (str, optional):
Format of the output figures, either "pdf" or "png". Defaults to "pdf".
input_table_header (bool, optional):
Whether the input tables have headers. Defaults to True.
to_base (str, optional):
Reference bases to convert to, separated by commas. Defaults to "A,G,C,T,Ins,Del".
heatmap_mut_direction (str, optional):
Mutation directions to plot, specified as [from Base][to Base], separated by commas.
Defaults to "CT,GA".
count_ratio (str, optional):
Type of plot to generate: "count", "ratio", or "all". Defaults to "all".
region_extend_length (int, optional):
Number of base pairs to extend on both sides of the region. Defaults to 0.
output_fig_dpi (int, optional):
DPI for the output figures. Defaults to 300.
show_indel (bool, optional):
Whether to show indel information in the output figures. Defaults to True.
show_index (bool, optional):
Whether to display index information in the output figures. Defaults to True.
block_ref (bool, optional):
Whether to hide colors for reference sites in the output figures. Defaults to True.
box_border (bool, optional):
Whether to display box borders in the output figures. Defaults to True.
box_space (int, optional):
Space size between two boxes in the heatmap. Defaults to 1.
min_color (tuple, optional):
Minimum color for the heatmap in RGB format. Defaults to (255, 255, 255).
max_color (tuple, optional):
Maximum color for the heatmap in RGB format. Defaults to (0, 0, 0).
min_ratio (float, optional):
Mutation ratios below this value will appear white. Defaults to 0.0.
max_ratio (float, optional):
Mutation ratios above this value will be capped. Defaults to 1.0.
local_alignment_scoring_matrix (tuple, optional):
Alignment scoring parameters as a tuple:
(<align_match_score>, <align_mismatch_score>, <align_gap_open_score>, <align_gap_extension_score>).
Defaults to None.
local_alignment_min_score (int, optional):
Minimum alignment score to consider as a valid alignment. Defaults to 0.
PAM_priority_weight (float, optional):
Weight multiplier for PAM alignment scores. Defaults to 1.0.
get_built_in_target_seq (bool, optional):
Set to True to return built-in target sequence information. Defaults to False.
log_level (str, optional):
Logging level. One of 'CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET'. Defaults to "INFO".
Returns:
None
Examples:
>>> # Generate mutation tables
>>> samtools mpileup test_sorted.bam --reference HK4-AOut-1.ref.upper.fa | gzip > test_sorted.mpileup.gz
>>> bioat bam mpileup_to_table test_sorted.mpileup.gz > test_sorted.mpileup.info1.tsv
>>> bioat bam mpileup_to_table test_sorted.mpileup.gz > test_sorted.mpileup.info2.tsv
>>> bioat bam mpileup_to_table test_sorted.mpileup.gz > test_sorted.mpileup.info3.tsv
>>> # Generate region heatmap comparison
>>> bioat target_seq region_heatmap_compare --input_tables test_sorted.mpileup.info1.tsv,test_sorted.mpileup.info2.tsv,test_sorted.mpileup.info3.tsv --labels condition1,condition2,condition3 --target_seq HEK4
"""
lm.set_names(func_name="region_heatmap_compare")
lm.set_level(log_level)
mpl, plt, PatchCollection, Rectangle = _load_matplotlib_plotting()
base_color_dict = {
"A": "#04E3E3",
"T": "#F9B874",
"C": "#B9E76B",
"G": "#F53798",
"N": "#DDEAF6",
}
def plot_agct(x):
if x == "": # target_seq sgRNA左右没align的位置给空值
return "#FFFFFF"
if x in ["A", "G", "C", "T", "N"]:
return base_color_dict[x]
return "#AAAAAA"
# check whether return built-in information
if get_built_in_target_seq:
lm.logger.info(
"You can use <key> in built-in <target_seq> to represent your target_seq:\n"
+ "\t<key>\t<target_seq>\n"
+ "".join([f"\t{k}\t{v}\n" for k, v in TARGET_SEQ_LIB.items()]),
)
lm.logger.info("exit because of the defination for get_built_in_target_seq")
sys.exit(0)
else:
# set target_seq(sgRNA) info
lm.logger.debug(f"load target_seq = {target_seq}")
if target_seq in TARGET_SEQ_LIB:
lm.logger.debug(
f"target_seq information is abtained from the <key>={target_seq}",
)
target_seq = TARGET_SEQ_LIB[target_seq]
lm.logger.debug(f"target_seq is refered to {target_seq}")
elif target_seq is None:
pass
elif isinstance(target_seq, str):
lm.logger.debug(f"target_seq is refered to {target_seq}")
else:
lm.logger.error(
"<target_seq> should be set as:"
"DNA sequence / <key> in <get_built_in_target_seq> / None,"
f"but your <target_seq> is {target_seq}",
)
sys.exit(1)
# check if nothing to output
outputs = (
output_fig_heatmap,
output_fig_count_ratio,
output_table_heatmap,
output_table_count_ratio,
)
if not [i for i in outputs if i is not None]:
lm.logger.error(
"outputs should be set as free combination of (output_fig_heatmap, output_fig_count_ratio, "
"output_table_heatmap, output_table_count_ratio), but none of outputs was defined",
)
sys.exit(1)
# check parameter: to_base
# check to_base: tuple for A G C T Ins Del or combine like A G T Ins
default_bases = ("A", "G", "C", "T", "Ins", "Del")
lm.logger.debug("check to_base parameter")
for base in to_base:
if base not in default_bases:
lm.logger.error("check to_base parameter")
lm.logger.error(f"{base} not in {default_bases}")
raise BioatInvalidOptionError
lm.logger.debug(f"to_base was defined as {to_base}")
# check parameter: count_ratio
default_count_ratio_choices = ("count", "ratio", "all")
lm.logger.debug("check count_ratio parameter")
if count_ratio not in default_count_ratio_choices:
lm.logger.error("check count_ratio parameter")
lm.logger.error(f"{count_ratio} not in {default_count_ratio_choices}")
raise BioatInvalidOptionError
lm.logger.debug(f"count_ratio was defined as {count_ratio}")
# parse parameter: input_tables
if isinstance(input_tables, tuple):
input_tables = [str(i).strip() for i in input_tables]
elif isinstance(input_tables, str):
input_tables = (
input_tables.replace(" ", "").replace("\t", "").strip().split(",")
)
else:
raise BioatInvalidInputError
lm.logger.debug(f"parse tables... {input_tables}")
# parse parameter: labels, if labels is None, labels value will be input_tables names
if isinstance(labels, tuple):
labels = [str(i).strip() for i in labels]
elif isinstance(labels, str):
labels = (
labels.replace(" ", "").replace("\t", "").strip().split(",")
if labels
else input_tables
)
else:
raise BioatInvalidInputError
lm.logger.debug(f"parse labels... {labels}")
# set alignment
aligner = instantiate_pairwise_aligner(*local_alignment_scoring_matrix)
# load & merge bmat files & set treatment names
ls = []
for input_table, label in zip(input_tables, labels, strict=False):
df = (
pd.read_csv(input_table, sep="\t")
if input_table_header
else pd.read_csv(input_table, sep="\t", header=None)
)
df["label"] = label
ls.append(df)
df_bases = pd.concat(ls)
lm.logger.debug(
"df_bases.head(10):\n"
+ tabulate(df_bases.head(8), headers="keys", tablefmt="psql"),
)
# parse parameter: reference_seq, if reference_seq is None, reference_seq value will be derived from df_bases
if reference_seq:
# fa file or seq str to seq str
if os.path.isfile(reference_seq):
f = (
open(reference_seq)
if not reference_seq.endswith(".gz")
else gzip.open(reference_seq, "rt")
)
reference_seq = "".join([i.rstrip() for i in f.readlines()[1:]])
if reference_seq.count(">") > 0:
raise BioatFileFormatError
ref_seq = Seq(reference_seq)
lm.logger.debug(f"load ref_seq from parameter:\n{ref_seq}")
else:
ref_seq = Seq("".join(df_bases.query("label==@labels[0]").ref_base))
lm.logger.debug(f"load ref_seq from df_bases.ref_base:\n{ref_seq}")
lm.logger.debug(f"Seq object for ref_seq:\n{ref_seq}")
ref_seq_rc = ref_seq.reverse_complement()
lm.logger.debug(
f"Seq object for ref_seq_rc (reverse_complement):\n{ref_seq_rc}",
)
# parse parameter: target_seq
if target_seq is None:
target_seq = ref_seq
PAM = None
elif "^" in target_seq:
if target_seq.find("^") == 0:
PAM, target_seq = target_seq[1:].strip().split("^")
target_seq = Seq(target_seq.upper())
PAM = {"PAM": PAM, "position": 0, "weight": PAM_priority_weight}
else:
target_seq, PAM = target_seq[:-1].strip().split("^")
target_seq = Seq(target_seq.upper())
PAM = {
"PAM": PAM,
"position": len(target_seq),
"weight": PAM_priority_weight,
}
else:
lm.logger.debug(
f"no PAM sequence is defined from target_seq = {target_seq}",
)
target_seq = Seq(target_seq)
PAM = None
lm.logger.info(f"parse target_seq: target_seq={target_seq}, PAM={PAM}")
# fwd alignment & rev alignment
# 为了判断target_seq (sgRNA) 是设计在fwd还是rev链上
if not PAM:
lm.logger.info("use PAMless mode to align")
fwd = run_target_seq_align(ref_seq, target_seq, aligner)
lm.logger.debug(f"final_align_forward:\n{fwd}")
rev = run_target_seq_align(ref_seq_rc, target_seq, aligner)
lm.logger.debug(f"final_align_reverse:\n{rev}")
else:
lm.logger.info("use PAM mode to align")
fwd = run_target_seq_align(ref_seq, target_seq, aligner, PAM=PAM)
lm.logger.debug(f"final_align_forward:\n{fwd}")
rev = run_target_seq_align(ref_seq_rc, target_seq, aligner, PAM=PAM)
lm.logger.debug(f"final_align_reverse:\n{rev}")
# get fwd alignment info
reference_seq = fwd["alignment"]["reference_seq"][
fwd["ref_aln_start"] : fwd["ref_aln_end"] + 1
]
align_info = fwd["alignment"]["aln_info"][
fwd["ref_aln_start"] : fwd["ref_aln_end"] + 1
]
target_seq = fwd["alignment"]["target_seq"][
fwd["ref_aln_start"] : fwd["ref_aln_end"] + 1
]
aln_score = fwd["aln_score"]
lm.logger.info(
f"Forward best alignment:\n"
f"reference : {reference_seq}\n"
f"align_info : {align_info}\n"
f"target_seq : {target_seq}\n"
f"align_score: {aln_score}",
)
# get rev alignment info
reference_seq_rev = rev["alignment"]["reference_seq"][
rev["ref_aln_start"] : rev["ref_aln_end"] + 1
]
align_info_rev = rev["alignment"]["aln_info"][
rev["ref_aln_start"] : rev["ref_aln_end"] + 1
]
target_seq_rev = rev["alignment"]["target_seq"][
rev["ref_aln_start"] : rev["ref_aln_end"] + 1
]
aln_score_rev = rev["aln_score"]
lm.logger.info(
f"Reverse best alignment:\n"
f"reference : {reference_seq_rev}\n"
f"align_info : {align_info_rev}\n"
f"target_seq : {target_seq_rev}\n"
f"align_score: {aln_score_rev}",
)
# define align direction and final align res
# make alignment info
if any(
x >= local_alignment_min_score
for x in (fwd["aln_score"], rev["aln_score"])
):
if fwd["aln_score"] >= rev["aln_score"]:
aln_direction = "Forward Alignment"
aln = fwd
else:
aln_direction = "Reverse Alignment"
aln = rev
reference_seq = reference_seq_rev
align_info = align_info_rev
target_seq = target_seq_rev
aln_score = aln_score_rev
else:
lm.logger.critical("Alignment Error!")
sys.exit(1)
# mark aligned all bases in target_seq
ref_seq_length = len(ref_seq)
target_seq_aln = [""] * ref_seq_length
# mark inserted bases in target_seq
target_seq_aln_insert = [""] * ref_seq_length
DNA_rev_cmp_dict = {
"A": "T",
"T": "A",
"C": "G",
"G": "C",
"N": "N",
"-": "-",
}
reference_seq.count("-")
target_seq.count("-")
ref_gap_count = 0
ref_del_str = ""
aln_start = aln["ref_aln_start"]
aln_end = aln["ref_aln_end"]
# update target_seq_aln
lm.logger.debug(f"target_seq_aln (before update) = {target_seq_aln}")
lm.logger.debug(
f"target_seq_aln_insert (before update) = {target_seq_aln_insert}",
)
for idx, ref_base in enumerate(reference_seq):
# reference : GCCTCTGGAGAGGGAGGAGGG
# align_info : |.|.|||..|..|||||.||-
# target_seq : GGCACTGCGGCTGGAGGTGG-
if ref_base != "-":
ref_gap_count += 0
ref_del_str = ""
target_seq_aln[aln_start + idx - ref_gap_count] = (
ref_del_str + target_seq[idx]
)
else:
ref_gap_count += 1
ref_del_str += target_seq[idx] # for continuous gap
target_seq_aln_insert[aln_start + idx] = target_seq[idx]
lm.logger.debug(
"update target_seq_aln: show ''.join(target_seq_aln)\n"
f"target_seq_aln = {''.join(target_seq_aln)}\n"
f"target_seq = {target_seq}",
)
lm.logger.debug("update target_seq_aln: Done")
lm.logger.debug(f"target_seq_aln (after update) = {target_seq_aln}")
lm.logger.debug(
f"target_seq_aln_insert (after update) = {target_seq_aln_insert}",
)
# if reverse alignment
if aln_direction == "Reverse Alignment":
# illustrate reverse alignment as forward alignment
target_seq_aln = target_seq_aln[::-1]
target_seq_aln_insert = target_seq_aln_insert[::-1]
lm.logger.debug("use reverse alignment")
lm.logger.debug(f"target_seq_aln = {target_seq_aln}")
lm.logger.debug(f"target_seq_aln_insert = {target_seq_aln_insert}")
possible_target_region_start = max(aln_start - region_extend_length, 0)
possible_target_region_end = min(
aln_end + region_extend_length, ref_seq_length
) # 266: 0~255 266 - 1
plot_region = (possible_target_region_start, possible_target_region_end)
df_bases_select = df_bases.iloc[plot_region[0] : plot_region[1], :]
lm.logger.debug(
"df_bases_select:\n"
+ tabulate(df_bases_select, headers="keys", tablefmt="psql"),
)
# ---------------------------------------------------------------->>>>>
# load .bmat file
# ---------------------------------------------------------------->>>>>
label_panel = labels
ls_bmat = input_tables
ls_bmat_table = []
for label, path_bmat in zip(label_panel, ls_bmat, strict=False):
# 读取文件
if input_table_header:
df = pd.read_csv(path_bmat, sep="\t")
else:
df = pd.read_csv(path_bmat, sep="\t", header=None)
# 删除不必要的列
df = df.drop(columns=["chr_name"], errors="ignore")
# 添加标签列(如果后面需要)
df["label"] = label
# 重命名除 chr_index 之外的列,避免 merge 时产生 _x/_y
rename_map = {c: f"{c}_{label}" for c in df.columns if c != "chr_index"}
df = df.rename(columns=rename_map)
# 写回
ls_bmat_table.append(df)
# 合并所有表,仅 chr_index 对齐
df_tmp = ls_bmat_table[0].copy()
for df in ls_bmat_table[1:]:
df_tmp = df_tmp.merge(df, on="chr_index", how="outer")
df_bmat_all = df_tmp.copy()
# ref_seq:用第一份表构建
ref_seq = "".join(ls_bmat_table[0][f"ref_base_{label_panel[0]}"].astype(str).fillna("").tolist())
possible_target_region_start = max(aln_start - region_extend_length, 0)
possible_target_region_end = min(
aln_end + region_extend_length, ref_seq_length
) # 266: 0~255 266 - 1
plot_region = (possible_target_region_start, possible_target_region_end)
df_bases_select = df_bmat_all.iloc[plot_region[0] : plot_region[1], :]
lm.logger.debug(
"df_bases_select:\n"
+ tabulate(df_bases_select, headers="keys", tablefmt="psql"),
)
# make plot
# set color
# show indel
np.set_printoptions(suppress=True)
box_border_plot_state = box_border
# set panel size
panel_box_width = 0.4
panel_box_heigth = 0.4
panel_space = 0.05
panel_box_space = box_space
# color part
# make color breaks
color_break_num = 20
break_step = 1.0 / color_break_num
# mutation ratio of box lower than min_ratio will show as white
min_color_value = min_ratio
# mutation ratio of box lower than max_ratio will show as white
max_color_value = max_ratio
color_break = np.round(
np.arange(min_color_value, max_color_value, break_step), 5
)
# min_color: min color to plot heatmap with RGB format
# max_color: max color to plot heatmap with RGB format
if isinstance(min_color, tuple):
low_color = min_color
else:
raise BioatInvalidOptionError
if isinstance(max_color, tuple):
high_color = max_color
else:
raise BioatInvalidOptionError
try:
color_list = make_color_list(
low_color, high_color, len(color_break) - 1, "Hex"
)
color_list = ["#FFFFFF", *color_list]
except:
lm.logger.debug(low_color, high_color)
lm.logger.debug(color_break)
# get plot info
total_box_count = plot_region[1] - plot_region[0]
# calculate base info and fix zero
ls_base_sum_count = []
ls_total_sum_count = []
for label in label_panel:
base_sum_count = df_bases_select[
[f"A_{label}", f"G_{label}", f"C_{label}", f"T_{label}"]
].apply(lambda x: x.sum(), axis=1)
# print base_sum_count
total_sum_count = df_bases_select[
[
f"A_{label}",
f"G_{label}",
f"C_{label}",
f"T_{label}",
f"del_count_{label}",
f"insert_count_{label}",
]
].apply(lambda x: x.sum(), axis=1)
base_sum_count[base_sum_count == 0] = 1
base_sum_count = base_sum_count + 1
total_sum_count = total_sum_count + 1
ls_base_sum_count.append(base_sum_count)
ls_total_sum_count.append(total_sum_count)
# make plot size
plot_data_list = None
panel_space_coef = None
panel_height_coef = None
if count_ratio == "all":
panel_height_coef = [0.5, 0.9, 0.9] + [0.5] * len(ls_bmat_table) * len(
default_bases
) * 2
else:
panel_height_coef = [0.5, 0.9, 0.9] + [0.5] * len(ls_bmat_table) * len(
default_bases
)
# 根据bmat表格个数来控制方格高度
# panel_space_coef = [1, 1, 1] + [0.3] * 3 + [1, 0.3, 1] + [0.3] * 3 + [1, 0.3]
if count_ratio == "all":
panel_space_coef = [1, 1, 1] + (
[0.3] * (len(ls_bmat_table) - 1) + [4]
) * len(default_bases) * 2
else:
panel_space_coef = [1, 1, 1] + (
[0.3] * (len(ls_bmat_table) - 1) + [4]
) * len(default_bases)
# 更正heatmap align错误
# plot_heatmap_index = bmat_table_select.chr_index
plot_data_list = [
["Ref_index", df_bases_select.chr_index],
["Target_seq", np.array(target_seq_aln[plot_region[0] : plot_region[1]])],
["Ref_seq", df_bases_select[f"ref_base_{label_panel[0]}"]],
]
if count_ratio in {"count", "all"}:
for to_base in default_bases:
if to_base == "A":
for label in label_panel:
plot_data_list.append(
[f"{label}: to A", np.array(df_bases_select[f"A_{label}"])],
)
if to_base == "G":
for label in label_panel:
plot_data_list.append(
[f"{label}: to G", np.array(df_bases_select[f"G_{label}"])],
)
if to_base == "C":
for label in label_panel:
plot_data_list.append(
[f"{label}: to C", np.array(df_bases_select[f"C_{label}"])],
)
if to_base == "T":
for label in label_panel:
plot_data_list.append(
[f"{label}: to T", np.array(df_bases_select[f"T_{label}"])],
)
if to_base == "Del":
for label in label_panel:
plot_data_list.append(
[
f"{label}: to Del",
np.array(df_bases_select[f"del_count_{label}"]),
],
)
if to_base == "Ins":
for label in label_panel:
plot_data_list.append(
[
f"{label}: to Ins",
np.array(df_bases_select[f"insert_count_{label}"]),
],
)
if count_ratio in {"ratio", "all"}:
for to_base in default_bases:
if to_base == "A":
for index, label in enumerate(label_panel):
plot_data_list.append(
[
f"{label}: to A(%)",
np.array(
df_bases_select[f"A_{label}"]
/ ls_base_sum_count[index]
),
],
)
if to_base == "G":
for index, label in enumerate(label_panel):
plot_data_list.append(
[
f"{label}: to G(%)",
np.array(
df_bases_select[f"G_{label}"]
/ ls_base_sum_count[index]
),
],
)
if to_base == "C":
for index, label in enumerate(label_panel):
plot_data_list.append(
[
f"{label}: to C(%)",
np.array(
df_bases_select[f"C_{label}"]
/ ls_base_sum_count[index]
),
],
)
if to_base == "T":
for index, label in enumerate(label_panel):
plot_data_list.append(
[
f"{label}: to T(%)",
np.array(
df_bases_select[f"T_{label}"]
/ ls_base_sum_count[index]
),
],
)
if to_base == "Del":
for index, label in enumerate(label_panel):
plot_data_list.append(
[
f"{label}: to Del(%)",
np.array(
df_bases_select[f"del_count_{label}"]
/ ls_total_sum_count[index]
),
],
)
if to_base == "Ins":
for index, label in enumerate(label_panel):
plot_data_list.append(
[
f"{label}: to Ins(%)",
np.array(
df_bases_select[f"insert_count_{label}"]
/ ls_total_sum_count[index]
),
],
)
# get box and space info
box_height_list = np.array(panel_height_coef) * panel_box_heigth
panel_space_list = np.array(panel_space_coef) * panel_space
# for i, j in enumerate(panel_space_coef):
# print i,j
# print panel_space
# calculate figure total width and height
figure_width = (
total_box_count * panel_box_width
+ (total_box_count - 1) * panel_box_space
+ panel_box_width * 2
)
figure_height = sum(box_height_list) + sum(panel_space_list)
lm.logger.debug(f"figure_width = {figure_width}")
lm.logger.debug(f"figure_height = {figure_height}")
# make all box_x
box_x_vec = np.arange(
0, figure_width + panel_box_width, panel_box_width + panel_box_space
)
box_x_vec = box_x_vec[: (len(ref_seq) + 1)]
# make box border
if box_border_plot_state:
box_edgecolor = "#AAAAAA"
box_linestyle = "-"
box_linewidth = 2
else:
box_edgecolor = "#FFFFFF"
box_linestyle = "None"
box_linewidth = 0
# make box_y initialize
current_y = 0
df_matrix = pd.DataFrame(
[[""] * len(plot_data_list[0][1])] * len(panel_height_coef)
)
ls_row_name = []
for panel_index in range(len(panel_height_coef)):
ls_row_name.append(plot_data_list[panel_index][0])
df_matrix.iloc[panel_index, :] = df_matrix.iloc[panel_index, :].astype(
np.str_
)
for index, box_value in enumerate(plot_data_list[panel_index][1]):
row = panel_index
col = index
df_matrix.iloc[row, col] = box_value
# print(f'box_value = {box_value}, row = {row}, col = {col}')
# print(f'df_matrix.iloc[row, col] = {df_matrix.iloc[row, col]}')
# if box_value == '':
# raise ValueError
df_matrix.index = ls_row_name
# plot heatmap
if output_fig_heatmap:
# 预设参数
num_extend = region_extend_length
# rename columns of df_matrix
df_matrix.columns = range(1, df_matrix.shape[1] + 1)
dt_base = {"A": [], "G": [], "C": [], "T": []}
for info in heatmap_mut_direction:
dt_base[info[0]] += info[1]
lm.logger.debug(f"[mut_direction] for heatmap: {dt_base}")
lm.logger.debug(f"[label_panel] for heatmap: {label_panel}")
lm.logger.debug(
f"[region_extend_length] for heatmap: {region_extend_length}",
)
lm.logger.debug(f"[block_ref] for Target-seq multiplot: {block_ref}")
ls_col_not_null = df_matrix.columns[
-(df_matrix.loc["Target_seq", :] == "")
].tolist()
df_matrix.columns.tolist()
ls_plot_range = list(
range(ls_col_not_null[0] - num_extend, ls_col_not_null[-1] + num_extend)
)
try:
df_matrix = df_matrix.loc[:, ls_plot_range]
# 有不在的就不为空,不为空则判断成立
# print 'testTTTTTTTTTTTTTTT'
except:
msg = "The param [num_extend] is too large or <0, it must be a integer>=0"
raise ValueError(
msg
)
df_matrix.loc["Target_seq", :].tolist()
ls_ref = df_matrix.loc["Ref_seq", :].tolist()
ls_ratio = []
for label in label_panel:
df_sample = df_matrix[
[
label == i.split(": to")[0] and "(%)" not in i
for i in df_matrix.index
]
]
df_sample.index = ["A", "G", "C", "T", "Ins", "Del"]
ls_A = df_sample.T["A"].isnull().tolist()
ls_G = df_sample.T["G"].isnull().tolist()
ls_C = df_sample.T["C"].isnull().tolist()
ls_T = df_sample.T["T"].isnull().tolist()
ls_bool = []
for i in range(len(ls_A)):
if ls_A[i] & ls_C[i] & ls_G[i] & ls_T[i]:
ls_bool.append(False)
else:
ls_bool.append(True)
ls_bl_select_df_sample = ls_bool
df_sample = df_sample.loc[:, ls_bl_select_df_sample]
df_sample = df_sample.loc[["A", "G", "C", "T"], :].T.copy()
df_sample = df_sample.astype(int)
df_sample["sum"] = (
df_sample["A"] + df_sample["G"] + df_sample["C"] + df_sample["T"]
)
df_sample["Ref_seq"] = np.array(ls_ref)[ls_bl_select_df_sample]
df_sample["To_base"] = df_sample["Ref_seq"].map(lambda x: dt_base[x])
df_sample["Mut_count"] = 0
for recode in df_sample.iterrows():
index_row = recode[0]
for to_base in recode[1]["To_base"]:
df_sample.loc[index_row, "Mut_count"] += recode[1][to_base]
df_sample["Mut_ratio"] = df_sample["Mut_count"] / df_sample["sum"]
ls_sample_ratio = df_sample["Mut_ratio"].tolist()
ls_ratio.append(ls_sample_ratio)
df_ratio_all = pd.DataFrame(ls_ratio).T
df_ratio_all.columns = label_panel
try:
lm.logger.debug(f"Catch NA: {df_ratio_all.isna().sum().sum()}")
lm.logger.debug(df_ratio_all.head())
# print(f'df_matrix.T = \n{df_matrix.T}')
df_ratio_all.index = (
df_matrix.T["Ref_index"][ls_bl_select_df_sample]
.map(float)
.map(int)
.tolist()
)
df_ratio_all["Target_seq"] = df_matrix.T["Target_seq"][
ls_bl_select_df_sample
].tolist()
df_ratio_all["Ref_seq"] = df_matrix.T["Ref_seq"][
ls_bl_select_df_sample
].tolist()
# print(f'df_ratio_all.T = \n{df_ratio_all.T}')
# exit()
except ValueError:
start_idx = (
df_matrix.T["Ref_index"][ls_bl_select_df_sample]
.map(float)
.map(int)
.tolist()[0]
)
end_idx = start_idx + len(df_ratio_all.index.tolist())
df_ratio_all.index = np.arange(start_idx, end_idx)
df_ratio_all["Target_seq"] = df_matrix.T["Target_seq"].tolist()
df_ratio_all["Ref_seq"] = df_matrix.T["Ref_seq"].tolist()
df_ratio_all["On-Target"] = ""
df_ratio_all["Reference"] = ""
df_ratio_all = df_ratio_all[
["Target_seq", "Ref_seq", "On-Target", "Reference", *label_panel]
].T.copy()
df_onTarget_Ref = df_ratio_all.iloc[:2, :].fillna(" ")
lm.logger.debug(f"df_onTarget_Ref: \n{df_onTarget_Ref}")
# exit()
# print(df_onTarget_Ref == '')
df_onTarget_Ref_color = df_ratio_all.iloc[:2, :].fillna("").map(plot_agct)
lm.logger.debug(f"df_onTarget_Ref_color: \n{df_onTarget_Ref_color}")
# exit()
lm.logger.debug(f"df_ratio_all: \n{df_ratio_all}")
# exit()
df_plot = df_ratio_all.iloc[2:, :].copy()
lm.logger.debug(f"df_plot: \n{df_plot}")
# print(f'df_plot: \n{df_plot}')
# exit()
df_plot.iloc[2:, :] = df_plot.iloc[2:, :].astype(float) * 100
ls_max = []
for i in df_plot.iloc[2:, :].values.tolist():
ls_max.extend(i)
try:
ls_break = list(np.arange(0, max(ls_max), max(ls_max) / 100))
except ZeroDivisionError:
ls_break = [
0,
0.003,
0.006,
0.009,
0.012,
0.015,
0.018,
0.022,
0.026,
0.030,
0.035,
0.040,
0.05,
0.06,
0.07,
0.08,
1.00,
2.00,
3.00,
]
ls_color_middle = make_color_list(
low_color_RGB=convert_hex_to_rgb("#87BDDB"),
high_color_RGB=convert_hex_to_rgb("#0A306A"),
length_out=80,
return_fmt="Hex",
)
ls_color_top = ["#0A306A"] * 20
ls_color_bottom = (
["#EFEFEF"] * 5 + ["#DDEAF6"] * 5 + ["#A9CEE4"] * 5 + ["#87BDDB"] * 5
)
ls_color = ls_color_bottom + ls_color_middle + ls_color_top
lm.logger.debug(f"ls_color: {ls_color}")
# exit()
df_onTarget_Ref_tmp = df_onTarget_Ref.T
df_onTarget_Ref_tmp.loc[
df_onTarget_Ref_tmp["Target_seq"].isnull(), "Target_seq"
] = pd.NA
df_onTarget_Ref = df_onTarget_Ref_tmp.T
df_plot_rec = df_plot.copy()
lm.logger.debug(
"df_plot_rec:\n"
+ tabulate(df_plot_rec, headers="keys", tablefmt="psql"),
)
# heatmap颜色,去map_hex_for_matrix函数中调整
def map_hex_for_matrix(x):
# 判断value范围并返回颜色的Hex值
value = -1
for value in ls_break:
if x < value:
break
continue
return ls_color[ls_break.index(value) - 1]
lm.logger.debug(f"df_plot_rec: \n{df_plot_rec}")
# print(f'df_plot_rec: \n{df_plot_rec}')
# exit()
df_plot_rec_cmap = df_plot_rec.copy()
df_plot_rec_cmap.iloc[2:, :] = df_plot_rec.iloc[2:, :].map(map_hex_for_matrix)
# print(df_plot_rec_cmap)
# exit()
# print(f'df_plot_rec = \n{df_plot_rec}')
# print(df_plot_rec.loc['On-Target', :])
# print(df_plot_rec.loc['On-Target', :].dtype)
# df_matrix = pd.DataFrame([[''] * len(plot_data_list[0][1])] * len(panel_height_coef))
# print(f"df_onTarget_Ref.loc['Target_seq', :] = \n{df_onTarget_Ref.loc['Target_seq', :]}")
# print('df_onTarget_Ref', df_onTarget_Ref.loc['Target_seq', :] == '')
# print(f"df_plot_rec.loc['On-Target', :] = \n{df_plot_rec.loc['On-Target', :]}")
# print('df_plot_rec', df_plot_rec.loc['On-Target', :] == '')
# exit()
df_plot_rec.loc["On-Target", :] = df_onTarget_Ref.loc["Target_seq", :]
# print(df_plot_rec.loc['On-Target', :])
# print(df_onTarget_Ref.loc['Target_seq', :])
df_plot_rec.loc["Reference", :] = df_onTarget_Ref.loc["Ref_seq", :]
df_plot_rec_cmap.loc["On-Target", :] = df_onTarget_Ref.loc[
"Target_seq", :
].map(plot_agct)
df_plot_rec_cmap.loc["Reference", :] = df_onTarget_Ref.loc[
"Ref_seq", :
].map(plot_agct)
figure_width_heatmap = df_plot_rec.shape[1]
figure_height_heatmap = max(df_plot_rec.shape[0], len(ls_break) * 0.8)
scale = 1.1 # 1.1
fig_heatmap = plt.figure(
figsize=(figure_width_heatmap * scale, figure_height_heatmap * scale)
)
ax = fig_heatmap.add_subplot(111, aspect="equal")
plt.xlim([0, figure_width_heatmap + 5])
plt.ylim([-figure_height_heatmap - 0.5, 0])
plt.axis("off")
# export heatmap values
lm.logger.debug("export heatmap reference table...")
lm.logger.debug(f"df_plot_rec: \n{df_plot_rec.index}")
# exit()
for row in range(df_plot_rec.shape[0]):
row_name = df_plot_rec.index[row]
# add panel name
ax.text(
x=-1,
y=0.55 - row - 1.1,
s=row_name,
horizontalalignment="right",
verticalalignment="center",
fontsize=34 / 1.1 * scale,
fontname="DejaVu Sans",
alpha=1,
)
# add Target_seq and ref
if row_name in ["On-Target", "Reference"]:
for col in range(df_plot.shape[1]):
# plot Rectangle
site_x = col
site_y = -row + 0.05 - 1.05
ax.add_patch(
mpl.patches.Rectangle(
(site_x, site_y),
width=1,
height=1,
linestyle="-",
fill=True,
facecolor=df_plot_rec_cmap.iloc[row, col],
edgecolor="#AAAAAA"
if df_plot_rec_cmap.iloc[row, col] != "#FFFFFF"
else "#FFFFFF",
linewidth=3.5,
)
)
# plot text
site_x = col
site_y = -row - 0.55
ax.text(
x=site_x + 0.5,
y=site_y,
s=df_plot_rec.iloc[row, col],
horizontalalignment="center",
verticalalignment="center",
fontsize=34 / 1 * scale,
fontname="DejaVu Sans",
alpha=1,
)
# add seq panels
else:
for col in range(df_plot.shape[1]):
# plot Rectangle
site_x = col
site_y = -row - 1.05
ax.add_patch(
mpl.patches.Rectangle(
(site_x, site_y),
width=1,
height=1,
linestyle="-",
fill=True,
facecolor=df_plot_rec_cmap.iloc[row, col],
edgecolor="#FFFFFF",
linewidth=3,
)
)
# add cbar
# start from this site
site_x = df_plot.shape[1] + 1
site_y = -figure_height_heatmap * 0.1 - 1
step_scale = 0.1
lm.logger.debug(f"[cbar_scale]: {step_scale}")
for color in ls_color:
site_y += step_scale
ax.add_patch(
mpl.patches.Rectangle(
(site_x, site_y),
width=1,
height=step_scale,
fill=True,
facecolor=color,
edgecolor=color,
)
)
# add cbar text
# start from this site
site_x = df_plot.shape[1] + 1
site_y = -figure_height_heatmap * 0.1
for idx, label in enumerate(ls_break):
site_y += step_scale
if idx % 10 == 0:
ax.text(
x=site_x + 1.1,
y=site_y - 1,
s=round(label, 4),
horizontalalignment="left",
verticalalignment="center",
fontsize=34 / 1.5,
# fontsize=34 / 1.5 * step_scale,
fontname="DejaVu Sans",
alpha=1,
)
fig_heatmap.savefig(output_fig_heatmap, bbox_inches="tight") # 减少边缘空白
# plt.show()
if output_table_heatmap:
if output_table_heatmap.endswith(".csv"):
df_plot_rec.to_csv(output_table_heatmap, sep=",")
elif output_table_heatmap.endswith(".tsv"):
df_plot_rec.to_csv(output_table_heatmap, sep="\t")
if output_fig_count_ratio:
# set new figure
fig = plt.figure(figsize=(figure_width * 1.1, figure_height * 1.1))
ax = fig.add_subplot(111, aspect="equal")
plt.xlim([0, figure_width])
plt.ylim([-figure_height, 0])
plt.axis("off")
text_list = []
patches = []
# print panel_height_coef
for panel_index in range(len(panel_height_coef)):
# print 'panel index:', panel_index
# panel name
panel_name = plot_data_list[panel_index][0]
panel_name_x = box_x_vec[0]
# print panel_name
# print panel_name_x
panel_name_y = current_y - box_height_list[panel_index] * 0.5
# print panel_name_y
text_list.append((panel_name_x, panel_name_y, panel_name, 10))
# plot panel box
# plot Index行
if panel_name == "Ref_index":
# don't draw box, only add text
for index, box_value in enumerate(plot_data_list[panel_index][1]):
box_x = box_x_vec[index + 1]
text_list.append(
(
box_x + panel_box_width * 0.5,
current_y - box_height_list[panel_index] * 0.5,
str(box_value),
6,
)
)
# make next panel_y
current_y = current_y - (
box_height_list[panel_index] + panel_space_list[panel_index]
)
# plot Target_seq和Ref_seq行
elif panel_name in ["Target_seq", "Ref_seq"]:
# if panel_name = Ref, form a new list to store plot info for checking the block_ref param
if panel_name == "Ref_seq":
plot_data_list_ref = plot_data_list[panel_index][1]
for index, box_value in enumerate(plot_data_list[panel_index][1]):
# print(index, box_value)
if box_value == "":
box_fill = False
box_color = "#FFFFFF"
else:
if aln_direction == "Reverse Alignment":
if panel_name == "Ref_seq":
box_value = "".join(
[DNA_rev_cmp_dict.get(x) for x in box_value]
)
else:
### fix bug
for x in box_value:
box_value = DNA_rev_cmp_dict.get(x)
box_color = base_color_dict.get(box_value[0])
else:
box_color = base_color_dict.get(box_value[-1])
if not box_color:
box_fill = False
box_color = "#FFFFFF"
else:
box_fill = True
box_x = box_x_vec[index + 1]
patches.append(
Rectangle(
xy=(box_x, current_y - box_height_list[panel_index]),
width=panel_box_width,
height=box_height_list[panel_index],
fill=box_fill,
alpha=1,
linestyle=box_linestyle,
linewidth=box_linewidth,
edgecolor=box_edgecolor,
facecolor=box_color,
),
)
# text
text_list.append(
(
box_x + 0.5 * panel_box_width,
current_y - 0.5 * box_height_list[panel_index],
str(box_value),
16,
)
)
# make next panel_y
current_y = current_y - (
box_height_list[panel_index] + panel_space_list[panel_index]
)
# 用来plot Sample行的
# 这里plot counts
elif panel_name.split("to ")[-1] in default_bases:
# print 'panel name:', panel_name.split('to ')[-1].replace('(%)', '')
# print panel_name, 'check panel name'
# 改进了这里total 和 base的值的引用
if count_ratio == "all":
ls_base_sum_count_all = (
ls_base_sum_count * 2 * len(default_bases)
)
ls_total_sum_count_all = (
ls_total_sum_count * 2 * len(default_bases)
)
total_sum_count = ls_total_sum_count_all[panel_index - 3]
base_sum_count = ls_base_sum_count_all[panel_index - 3]
else:
ls_base_sum_count_all = ls_base_sum_count * len(default_bases)
ls_total_sum_count_all = ls_total_sum_count * len(default_bases)
total_sum_count = ls_total_sum_count_all[panel_index - 3]
base_sum_count = ls_base_sum_count_all[panel_index - 3]
panel_name = panel_name.split("to ")[-1]
if panel_name in ["Del", "Ins"]:
box_ratio = plot_data_list[panel_index][1] / total_sum_count
else:
box_ratio = plot_data_list[panel_index][1] / base_sum_count
box_color_list = map_color(box_ratio, color_break, color_list)
base_sum_count.tolist()
for index, box_value in enumerate(plot_data_list[panel_index][1]):
# if count_ratio == 'all':
box_color = box_color_list[index]
# check block_ref state: count plot
if block_ref:
if plot_data_list_ref.iloc[index] == panel_name:
# if box_value >= ls_base_sum[index] * 0.8:
# box_fill = False
box_color = "#FFFFFF"
box_x = box_x_vec[index + 1]
patches.append(
Rectangle(
xy=(box_x, current_y - box_height_list[panel_index]),
width=panel_box_width,
height=box_height_list[panel_index],
fill=True,
alpha=1,
linestyle=box_linestyle,
linewidth=box_linewidth,
edgecolor=box_edgecolor,
facecolor=box_color,
),
)
# text
text_list.append(
(
box_x + 0.5 * panel_box_width,
current_y - 0.5 * box_height_list[panel_index],
str(box_value),
6,
)
)
# print box_value
# make next panel_y
current_y = current_y - (
box_height_list[panel_index] + panel_space_list[panel_index]
)
# 这里plot ratio(剩下的都是base(%))
else:
box_color_list = map_color(
plot_data_list[panel_index][1], color_break, color_list
)
# print box_color_list
panel_name = panel_name.split("to ")[-1].replace("(%)", "")
for index, box_value in enumerate(plot_data_list[panel_index][1]):
# ratio
box_color = box_color_list[index]
# check block_ref state: ratio plot
if block_ref:
if plot_data_list_ref.iloc[index] == panel_name:
# if box_value >= 0.85:
# box_fill = False
box_color = "#FFFFFF"
box_x = box_x_vec[index + 1]
patches.append(
Rectangle(
xy=(box_x, current_y - box_height_list[panel_index]),
width=panel_box_width,
height=box_height_list[panel_index],
fill=True,
alpha=1,
linestyle=box_linestyle,
linewidth=box_linewidth,
edgecolor=box_edgecolor,
facecolor=box_color,
),
)
if "(%)" in plot_data_list[panel_index][0]:
# text
# ref
# print 'test', box_value
# if ARGS.block_ref == True:
# if box_value >= 0.99:
# box_value = 0
text_list.append(
(
box_x + 0.5 * panel_box_width,
current_y - 0.5 * box_height_list[panel_index],
round(box_value * 100, 4),
6,
)
)
else:
text_list.append(
(
box_x + 0.5 * panel_box_width,
current_y - 0.5 * box_height_list[panel_index],
box_value,
6,
)
)
# make next panel_y
if panel_index < len(panel_space_list):
current_y = current_y - (
box_height_list[panel_index] + panel_space_list[panel_index]
)
# plot box
ax.add_collection(PatchCollection(patches, match_original=True))
# add text
for text_x, text_y, text_info, text_fontsize in text_list:
if " to " in str(text_info):
plt.text(
x=text_x + 0.3,
y=text_y,
s=text_info,
horizontalalignment="right",
verticalalignment="center",
fontsize=text_fontsize,
fontname="DejaVu Sans",
)
else:
plt.text(
x=text_x,
y=text_y,
s=text_info,
horizontalalignment="center",
verticalalignment="center",
fontsize=text_fontsize,
fontname="DejaVu Sans",
)
# output plot
if output_fig_fmt.upper == "PNG":
fig.savefig(
fname=output_fig_count_ratio,
bbox_inches="tight",
dpi=output_fig_dpi,
format="png",
)
else:
fig.savefig(
fname=output_fig_count_ratio,
bbox_inches="tight",
dpi=output_fig_dpi,
format="pdf",
)
if output_table_count_ratio:
if output_table_count_ratio.endswith(".csv"):
df_bases_select.to_csv(output_table_count_ratio, sep=",")
elif output_table_count_ratio.endswith(".tsv"):
df_bases_select.to_csv(output_table_count_ratio, sep="\t")
if __name__ == "__main__":
    # Running this module directly is intentionally a no-op; the
    # TargetSeq tools are meant to be imported (or driven through the
    # `bioat` command line shown in the module docstring).
    ...