Source code for bioat.target_seq

"""_summary_.

author: Herman Huanan Zhao
email: hermanzhaozzzz@gmail.com
homepage: https://github.com/hermanzhaozzzz

_description_

example 1:
    bioat list
        <in shell>:
            $ bioat list
        <in python console>:
            >>> from bioat.cli import Cli
            >>> bioat = Cli()
            >>> bioat.list()
            >>> print(bioat.list())

example 2:
    _example_
"""

import gzip
import os.path
import sys

import numpy as np
import pandas as pd
from Bio.Seq import Seq
from tabulate import tabulate

from bioat.exceptions import (
    BioatFileFormatError,
    BioatInvalidInputError,
    BioatInvalidOptionError,
)
from bioat.lib.libalignment import instantiate_pairwise_aligner
from bioat.lib.libcolor import convert_hex_to_rgb, make_color_list, map_color
from bioat.lib.libcrispr import TARGET_SEQ_LIB, run_target_seq_align
from bioat.lib.libpandas import set_option
from bioat.logger import LoggerManager

# Module-wide logger manager, registered under this module's dotted name.
# The class and methods below call lm.set_names / lm.set_level / lm.logger on it.
lm = LoggerManager(mod_name="bioat.target_seq")

# NOTE(review): presumably applies bioat's pandas option presets with the given
# log level -- confirm against bioat.lib.libpandas.set_option.
set_option(log_level="ERROR")


def _load_matplotlib_plotting():
    """Import matplotlib lazily and hand back the plotting objects used here.

    The import is deferred to call time so that importing this module does not
    pull in matplotlib. The "Agg" backend is selected *before* pyplot is
    imported.

    Returns:
        tuple: ``(matplotlib, matplotlib.pyplot, PatchCollection, Rectangle)``.
    """
    import matplotlib

    # Backend selection has to happen ahead of the pyplot import below.
    matplotlib.use("Agg")

    from matplotlib import pyplot
    from matplotlib.collections import PatchCollection as patch_collection_cls
    from matplotlib.patches import Rectangle as rectangle_cls

    plotting_api = (matplotlib, pyplot, patch_collection_cls, rectangle_cls)
    return plotting_api


class TargetSeq:
    """Target Deep Sequencing toolbox."""

    # Runs once, when the class body executes: registers the class name with
    # the module-level logger manager.
    lm.set_names(cls_name="TargetSeq")

    def __init__(self):
        """Create the toolbox; the constructor sets no instance attributes."""
        pass
def region_heatmap(
    self,
    input_table: str,
    output_fig: str,
    target_seq: str | None = None,  # sgRNA
    reference_seq: str | None = None,  # target locus
    input_table_header: bool = True,
    output_fig_fmt: str = "pdf",
    output_fig_dpi: int = 100,
    show_indel: bool = True,
    show_index: bool = True,
    box_border: bool = False,
    box_space: float = 0.03,
    min_color: tuple = (250, 239, 230),
    max_color: tuple = (154, 104, 57),
    min_ratio: float = 0.001,
    max_ratio: float = 0.99,
    region_extend_length: int = 5,
    local_alignment_scoring_matrix: tuple = (5, -4, -24, -8),
    local_alignment_min_score: int = 15,
    PAM_priority_weight: float = 1.0,
    get_built_in_target_seq: bool = False,
    log_level: str = "INFO",
):
    """Plot region mutation information using a table generated by `bioat bam mpileup_to_table`.

    The target sequence is aligned against the reference sequence and against
    its reverse complement; the better-scoring strand is kept. A window around
    the aligned span is then drawn as stacked panels: reference index, target
    sequence, reference sequence, per-base counts and per-base ratios (plus
    deletion/insertion panels when ``show_indel`` is set).

    Args:
        input_table (str): Path to the table generated by
            `bioat bam mpileup_to_table`. This table should contain base
            mutation information for a short genome region (no more than 1k nt).
            The columns read here are ``chr_index``, ``ref_base``, ``A``, ``G``,
            ``C``, ``T``, ``del_count`` and ``insert_count``.
        output_fig (str): Path to the output figure file.
        target_seq (str, optional): Target sequence to align against the
            reference sequence, or a key of the built-in target sequence
            library. A PAM is marked with a pair of ``^``. Examples:
            - 'GAGTCCGAGCAGAAGAAGAA^GGG^' for SpCas9-BE (PAM: ^GGG^).
            - '^TTTA^GCCCCAATAATCCCCACATGTCA' for cpf1-BE (PAM: ^TTTA^).
            - 'TGCTAGTAACCACGTTCTCCTGATCAAATATCACTCTCCTACTTACAGGA' for no PAM.
            When None, the reference sequence itself is used. Defaults to None.
        reference_seq (str, optional): Custom reference sequence to overwrite
            the one in the table. Can be a single-record FASTA file (plain or
            ``.gz``) or a DNA sequence. Defaults to None.
        input_table_header (bool, optional): Whether `input_table` contains a
            header row. Defaults to True.
        output_fig_fmt (str, optional): Format of the output figure ("pdf" or
            "png"). Defaults to "pdf".
        output_fig_dpi (int, optional): DPI for the output figure.
            Defaults to 100.
        show_indel (bool, optional): Whether to add the deletion/insertion
            count and ratio panels. Defaults to True.
        show_index (bool, optional): Whether to display index information.
            Accepted for interface compatibility; this method does not
            currently read it. Defaults to True.
        box_border (bool, optional): Whether to draw box borders.
            Defaults to False.
        box_space (float, optional): Horizontal space between two boxes.
            Defaults to 0.03.
        min_color (tuple, optional): RGB color at the low end of the heatmap.
            Defaults to (250, 239, 230).
        max_color (tuple, optional): RGB color at the high end of the heatmap.
            Defaults to (154, 104, 57).
        min_ratio (float, optional): Lower bound of the color breaks; ratios
            below it are intended to be drawn white. Defaults to 0.001.
        max_ratio (float, optional): Upper bound of the color breaks; ratios
            above it are intended to be capped. Defaults to 0.99.
        region_extend_length (int, optional): Number of positions added around
            the aligned span when choosing the plotted window. Defaults to 5.
        local_alignment_scoring_matrix (tuple, optional): Alignment scoring
            parameters: (<align_match_score>, <align_mismatch_score>,
            <align_gap_open_score>, <align_gap_extension_score>).
            Defaults to (5, -4, -24, -8).
        local_alignment_min_score (int, optional): At least one strand must
            reach this alignment score, otherwise the program exits.
            Defaults to 15.
        PAM_priority_weight (float, optional): Weight stored in the parsed PAM
            record. This method does not forward the PAM record to the aligner.
            Defaults to 1.0.
        get_built_in_target_seq (bool, optional): Set to True to log the
            built-in target sequence library and exit. Defaults to False.
        log_level (str, optional): Logging level. One of 'CRITICAL', 'ERROR',
            'WARNING', 'INFO', 'DEBUG', 'NOTSET'. Defaults to "INFO".

    Returns:
        None

    Raises:
        BioatFileFormatError: If the reference FASTA file contains ">" after
            its first line (i.e. more than one record).
        SystemExit: Exit code 0 after listing the built-in targets; exit code 1
            when `target_seq` has an unsupported type or when neither strand
            reaches `local_alignment_min_score`.

    Examples:
        Generate the mutation table (shell)::

            samtools mpileup test_sorted.bam --reference HK4-AOut-1.ref.upper.fa | gzip > test_sorted.mpileup.gz
            bioat bam mpileup_to_table test_sorted.mpileup.gz > test_sorted.mpileup.info.tsv

        Generate the region heatmap (shell)::

            bioat target_seq region_heatmap --input_table test_sorted.mpileup.info.tsv --output_fig test_sorted.mpileup.info.pdf
    """
    lm.set_names(func_name="region_heatmap")
    lm.set_level(log_level)
    # The loader also returns the top-level matplotlib module; it is not needed here.
    _, plt, PatchCollection, Rectangle = _load_matplotlib_plotting()

    if get_built_in_target_seq:
        lm.logger.info(
            "You can use <key> in built-in <target_seq> to represent your target_seq:\n"
            + "\t<key>\t<target_seq>\n"
            + "".join([f"\t{k}\t{v}\n" for k, v in TARGET_SEQ_LIB.items()]),
        )
        lm.logger.info("exit because of the defination for get_built_in_target_seq")
        sys.exit(0)

    # ---- resolve target_seq (sgRNA): library key, None, or literal sequence ----
    lm.logger.debug(f"load target_seq = {target_seq}")
    if target_seq in TARGET_SEQ_LIB:
        lm.logger.debug(
            f"target_seq information is abtained from the <key>={target_seq}",
        )
        target_seq = TARGET_SEQ_LIB[target_seq]
        lm.logger.debug(f"target_seq is refered to {target_seq}")
    elif target_seq is None:
        pass
    elif isinstance(target_seq, str):
        lm.logger.debug(f"target_seq is refered to {target_seq}")
    else:
        lm.logger.error(
            "<target_seq> should be set as:"
            "DNA sequence / <key> in <get_built_in_target_seq> / None,"
            f"but your <target_seq> is {target_seq}",
        )
        sys.exit(1)

    # ---- aligner ----
    aligner = instantiate_pairwise_aligner(*local_alignment_scoring_matrix)

    # ---- load mpileup.info.tsv ----
    # NOTE(review): with header=None the columns are integers, while the code
    # below reads named columns (ref_base, chr_index, A, ...) -- confirm that
    # headerless input is really supported.
    if input_table_header:
        df_bases = pd.read_csv(input_table, sep="\t")
    else:
        df_bases = pd.read_csv(input_table, sep="\t", header=None)
    lm.logger.debug(
        "df_bases.head(10):\n"
        + tabulate(df_bases.head(10), headers="keys", tablefmt="psql"),
    )

    # ---- ref_seq & ref_seq_rc ----
    if reference_seq:
        # A FASTA path is turned into a plain sequence string; anything else is
        # taken as the sequence itself.
        if os.path.isfile(reference_seq):
            opener = gzip.open if reference_seq.endswith(".gz") else open
            # Context manager so the handle is closed (it used to be leaked).
            with opener(reference_seq, "rt") as fasta_handle:
                # Skip the header line and join the remaining lines.
                reference_seq = "".join(
                    [line.rstrip() for line in fasta_handle.readlines()[1:]]
                )
            # A second ">" means a multi-record FASTA, which is not supported.
            if reference_seq.count(">") > 0:
                raise BioatFileFormatError
        ref_seq = Seq(reference_seq)
        lm.logger.debug(f"load ref_seq from parameter:\n{ref_seq}")
    else:
        ref_seq = Seq("".join(df_bases.ref_base))
        lm.logger.debug(f"load ref_seq from df_bases.ref_base:\n{ref_seq}")
    lm.logger.debug(f"Seq object for ref_seq:\n{ref_seq}")
    ref_seq_rc = ref_seq.reverse_complement()
    lm.logger.debug(
        f"Seq object for ref_seq_rc (reverse_complement):\n{ref_seq_rc}",
    )

    # ---- parse target_seq and its optional ^PAM^ marker ----
    if target_seq is None:
        target_seq = ref_seq
        PAM = None
    elif "^" in target_seq:
        if target_seq.find("^") == 0:
            # '^PAM^protospacer': PAM sits at the 5' end.
            PAM, target_seq = target_seq[1:].strip().split("^")
            target_seq = Seq(target_seq.upper())
            PAM = {"PAM": PAM, "position": 0, "weight": PAM_priority_weight}
        else:
            # 'protospacer^PAM^': PAM sits at the 3' end.
            target_seq, PAM = target_seq[:-1].strip().split("^")
            target_seq = Seq(target_seq.upper())
            PAM = {
                "PAM": PAM,
                "position": len(target_seq),
                "weight": PAM_priority_weight,
            }
    else:
        lm.logger.debug(
            f"no PAM sequence is defined from target_seq = {target_seq}",
        )
        target_seq = Seq(target_seq)
        PAM = None
    lm.logger.info(f"parse target_seq: target_seq={target_seq}, PAM={PAM}")

    # ---- fwd alignment & rev alignment ----
    # Both strands are aligned to decide whether target_seq (sgRNA) was
    # designed on the forward or on the reverse strand.
    # NOTE(review): the PAM record is not forwarded to run_target_seq_align
    # here (region_heatmap_compare passes PAM=PAM); the PAM-aware calls were
    # disabled in this method, so both modes run the same alignment.
    lm.logger.info("use PAM mode to align" if PAM else "use PAMless mode to align")
    fwd = run_target_seq_align(ref_seq, target_seq, aligner)
    lm.logger.debug(f"final_align_forward:\n{fwd}")
    rev = run_target_seq_align(ref_seq_rc, target_seq, aligner)
    lm.logger.debug(f"final_align_reverse:\n{rev}")

    # forward alignment info
    fwd_span = slice(fwd["ref_aln_start"], fwd["ref_aln_end"] + 1)
    reference_seq = fwd["alignment"]["reference_seq"][fwd_span]
    align_info = fwd["alignment"]["aln_info"][fwd_span]
    target_seq = fwd["alignment"]["target_seq"][fwd_span]
    aln_score = fwd["aln_score"]
    lm.logger.info(
        f"Forward best alignment:\n"
        f"reference : {reference_seq}\n"
        f"align_info : {align_info}\n"
        f"target_seq : {target_seq}\n"
        f"align_score: {aln_score}",
    )

    # reverse alignment info
    lm.logger.info(rev)
    rev_span = slice(rev["ref_aln_start"], rev["ref_aln_end"] + 1)
    reference_seq_rev = rev["alignment"]["reference_seq"][rev_span]
    align_info_rev = rev["alignment"]["aln_info"][rev_span]
    target_seq_rev = rev["alignment"]["target_seq"][rev_span]
    aln_score_rev = rev["aln_score"]
    lm.logger.info(
        f"Reverse best alignment:\n"
        f"reference : {reference_seq_rev}\n"
        f"align_info : {align_info_rev}\n"
        f"target_seq : {target_seq_rev}\n"
        f"align_score: {aln_score_rev}",
    )

    # ---- pick the strand: at least one score must reach the threshold ----
    if any(
        x >= local_alignment_min_score
        for x in (fwd["aln_score"], rev["aln_score"])
    ):
        if fwd["aln_score"] >= rev["aln_score"]:
            aln_direction = "Forward Alignment"
            aln = fwd
        else:
            aln_direction = "Reverse Alignment"
            aln = rev
            reference_seq = reference_seq_rev
            target_seq = target_seq_rev
    else:
        lm.logger.critical("Alignment Error!")
        sys.exit(1)

    # ---- project the aligned target onto reference coordinates ----
    ref_seq_length = len(ref_seq)
    # aligned target base per reference position ("" where nothing aligned)
    target_seq_aln = [""] * ref_seq_length
    # target bases that fall into reference gaps (insertions)
    target_seq_aln_insert = [""] * ref_seq_length
    DNA_rev_cmp_dict = {
        "A": "T",
        "T": "A",
        "C": "G",
        "G": "C",
        "N": "N",
        "-": "-",
    }
    ref_gap_count = 0
    ref_del_str = ""
    aln_start = aln["ref_aln_start"]
    aln_end = aln["ref_aln_end"]

    lm.logger.debug(f"target_seq_aln (before update) = {target_seq_aln}")
    lm.logger.debug(
        f"target_seq_aln_insert (before update) = {target_seq_aln_insert}",
    )
    for idx, ref_base in enumerate(reference_seq):
        # reference  : GCCTCTGGAGAGGGAGGAGGG
        # align_info : |.|.|||..|..|||||.||-
        # target_seq : GGCACTGCGGCTGGAGGTGG-
        if ref_base != "-":
            # NOTE(review): ref_del_str is cleared *before* it is used, so the
            # bases collected while the reference had a gap are never prefixed
            # here -- looks unintended, behavior kept as-is; confirm.
            ref_del_str = ""
            target_seq_aln[aln_start + idx - ref_gap_count] = (
                ref_del_str + target_seq[idx]
            )
        else:
            ref_gap_count += 1
            ref_del_str += target_seq[idx]  # for continuous gap
            target_seq_aln_insert[aln_start + idx] = target_seq[idx]
        lm.logger.debug(
            "update target_seq_aln: show ''.join(target_seq_aln)\n"
            f"target_seq_aln = {''.join(target_seq_aln)}\n"
            f"target_seq = {target_seq}",
        )
    lm.logger.debug("update target_seq_aln: Done")
    lm.logger.debug(f"target_seq_aln (after update) = {target_seq_aln}")
    lm.logger.debug(
        f"target_seq_aln_insert (after update) = {target_seq_aln_insert}",
    )

    if aln_direction == "Reverse Alignment":
        # illustrate reverse alignment as forward alignment
        target_seq_aln = target_seq_aln[::-1]
        target_seq_aln_insert = target_seq_aln_insert[::-1]
        lm.logger.debug("use reverse alignment")
        lm.logger.debug(f"target_seq_aln = {target_seq_aln}")
        lm.logger.debug(f"target_seq_aln_insert = {target_seq_aln_insert}")

    # ---- plotted window: aligned span extended on both sides, clipped ----
    # NOTE(review): for a reverse alignment aln_start/aln_end presumably refer
    # to the reverse-complemented reference, while the table rows and the
    # flipped target_seq_aln are in forward order -- confirm the window is
    # the intended one on the reverse strand.
    possible_target_region_start = max(aln_start - region_extend_length, 0)
    possible_target_region_end = min(
        aln_end + region_extend_length, ref_seq_length
    )
    plot_region = (possible_target_region_start, possible_target_region_end)
    df_bases_select = df_bases.iloc[plot_region[0] : plot_region[1], :]
    lm.logger.debug(
        "df_bases_select:\n"
        + tabulate(df_bases_select, headers="keys", tablefmt="psql"),
    )

    # ---- plot settings ----
    np.set_printoptions(suppress=True)
    indel_plot_state = show_indel
    box_border_plot_state = box_border
    # panel geometry
    panel_box_width = 0.4
    panel_box_heigth = 0.4
    panel_space = 0.05
    panel_box_space = box_space
    # colors of the sequence panels
    base_colors = {
        "A": "#04E3E3",
        "T": "#F9B874",
        "C": "#B9E76B",
        "G": "#F53798",
        "N": "#AAAAAA",
        "-": "#AAAAAA",
    }
    # color breaks between min_ratio and max_ratio
    color_break_num = 20
    break_step = 1.0 / color_break_num
    min_color_value = min_ratio
    max_color_value = max_ratio
    color_break = np.round(
        np.arange(min_color_value, max_color_value, break_step), 5
    )
    # gradient from min_color to max_color (RGB in, HEX out); white is
    # prepended as the first color of the list
    color_list = make_color_list(min_color, max_color, len(color_break) - 1, "HEX")
    color_list = ["#FFFFFF", *color_list]
    lm.logger.debug(f"color_list = {color_list}")

    # ---- per-position totals used as ratio denominators ----
    total_box_count = plot_region[1] - plot_region[0]
    base_sum_count = (
        df_bases_select.loc[:, ["A", "G", "C", "T"]].sum(axis=1).astype(int)
    )
    lm.logger.debug(f"base_sum_count =\n{base_sum_count.values}")
    total_sum_count = (
        df_bases_select[["A", "G", "C", "T", "del_count", "insert_count"]]
        .sum(axis=1)
        .astype(int)
    )
    lm.logger.debug(f"total_sum_count =\n{total_sum_count.values}")
    # avoid division by zero at uncovered positions
    base_sum_count[base_sum_count == 0] = 1
    total_sum_count[total_sum_count == 0] = 1

    # ---- panel layout and data ----
    lm.logger.debug(f"aln =\n{aln}")
    if indel_plot_state:
        panel_height_coef = [0.5, 0.9, 0.9] + [0.5] * 6 + [0.5] * 6
        panel_space_coef = (
            [1.0] * 3 + [0.3] * 3 + [1.0, 0.3, 1.0] + [0.3] * 3 + [1.0, 0.3]
        )
        plot_data_list = [
            ["Ref_index", df_bases_select.chr_index],
            ["Target_seq", target_seq_aln[plot_region[0] : plot_region[1]]],
            ["Ref_seq", df_bases_select.ref_base],
            ["A", np.array(df_bases_select["A"])],
            ["G", np.array(df_bases_select["G"])],
            ["C", np.array(df_bases_select["C"])],
            ["T", np.array(df_bases_select["T"])],
            ["Del", np.array(df_bases_select["del_count"])],
            ["Ins", np.array(df_bases_select["insert_count"])],
            ["A %", list(df_bases_select["A"] / base_sum_count)],
            ["G %", list(df_bases_select["G"] / base_sum_count)],
            ["C %", list(df_bases_select["C"] / base_sum_count)],
            ["T %", list(df_bases_select["T"] / base_sum_count)],
            ["Del %", list(df_bases_select["del_count"] / total_sum_count)],
            ["Ins %", list(df_bases_select["insert_count"] / total_sum_count)],
        ]
    else:
        panel_height_coef = [0.5, 0.9, 0.9] + [0.5] * 4 + [0.5] * 4
        panel_space_coef = [1.0] * 3 + [0.3] * 3 + [1.0] + [0.3] * 3
        plot_data_list = [
            ["Ref_index", df_bases_select.chr_index],
            ["Target_seq", target_seq_aln[plot_region[0] : plot_region[1]]],
            ["Ref_seq", df_bases_select.ref_base],
            ["A", np.array(df_bases_select["A"])],
            ["G", np.array(df_bases_select["G"])],
            ["C", np.array(df_bases_select["C"])],
            ["T", np.array(df_bases_select["T"])],
            ["A %", list(df_bases_select["A"] / base_sum_count)],
            ["G %", list(df_bases_select["G"] / base_sum_count)],
            ["C %", list(df_bases_select["C"] / base_sum_count)],
            ["T %", list(df_bases_select["T"] / base_sum_count)],
        ]

    # heights per panel and gaps between panels (one gap fewer than panels)
    box_height_list = np.array(panel_height_coef) * panel_box_heigth
    panel_space_list = np.array(panel_space_coef) * panel_space
    lm.logger.debug(f"box_height_list = {box_height_list}")
    lm.logger.debug(f"panel_space_list = {panel_space_list}")

    # figure size
    figure_width = (
        total_box_count * panel_box_width
        + (total_box_count - 1) * panel_box_space
        + panel_box_width * 2
    )
    figure_height = sum(box_height_list) + sum(panel_space_list)
    lm.logger.debug(f"figure_width = {figure_width}")
    lm.logger.debug(f"figure_height = {figure_height}")

    # x of every column; column 0 carries the panel names
    box_x_vec = np.arange(
        0, figure_width + panel_box_width, panel_box_width + panel_box_space
    )
    box_x_vec = box_x_vec[: (len(ref_seq) + 1)]

    # box border style
    if box_border_plot_state:
        box_edgecolor = "#AAAAAA"
        box_linestyle = "-"
        box_linewidth = 2
    else:
        box_edgecolor = "#FFFFFF"
        box_linestyle = "None"
        box_linewidth = 0
    lm.logger.debug(f"box_border_plot_state = {box_border_plot_state}")
    lm.logger.debug(f"box_edgecolor = {box_edgecolor}")
    lm.logger.debug(f"box_linestyle = {box_linestyle}")
    lm.logger.debug(f"box_linewidth = {box_linewidth}")

    # ---- figure ----
    current_y = 0  # top edge of the panel being drawn; decreases downwards
    lm.logger.debug("start to plot")
    fig = plt.figure(figsize=(figure_width * 1.1, figure_height * 1.1))
    lm.logger.debug(
        f"set new figure, figsize = ({figure_width * 1.1}, {figure_height * 1.1})",
    )
    # keep matplotlib's own debug records out of the log
    plt.set_loglevel("info")
    ax = fig.add_subplot(111, aspect="equal")
    plt.xlim([0, figure_width])
    plt.ylim([-figure_height, 0])
    plt.axis("off")
    lm.logger.debug(f"plt.xlim([0, {figure_width}])")
    lm.logger.debug(f"plt.ylim([-{figure_height}, 0])")
    lm.logger.debug('plt.axis("off")')

    text_list = []  # (x, y, text, fontsize)
    patches = []

    def _append_box(box_x, panel_top, panel_height, fill, facecolor):
        """Queue one rectangle whose top-left corner is (box_x, panel_top)."""
        patches.append(
            Rectangle(
                xy=(box_x, panel_top - panel_height),
                width=panel_box_width,
                height=panel_height,
                fill=fill,
                alpha=1,
                linestyle=box_linestyle,
                linewidth=box_linewidth,
                edgecolor=box_edgecolor,
                facecolor=facecolor,
            ),
        )

    for panel_index, (panel_name, panel_values) in enumerate(plot_data_list):
        panel_height = box_height_list[panel_index]
        text_y = current_y - panel_height * 0.5
        # panel name in the first column
        text_list.append((box_x_vec[0], text_y, panel_name, 10))

        if panel_name == "Ref_index":
            # no boxes, only the index text
            for index, box_value in enumerate(panel_values):
                box_x = box_x_vec[index + 1]
                text_list.append(
                    (box_x + panel_box_width * 0.5, text_y, str(box_value), 7),
                )
            current_y = current_y - (panel_height + panel_space_list[panel_index])

        elif panel_name in ["Target_seq", "Ref_seq"]:
            for index, box_value in enumerate(panel_values):
                if box_value == "":
                    # position not covered by the aligned target
                    box_fill = False
                    box_color = "#FFFFFF"
                else:
                    if "Reverse" in aln_direction:
                        if panel_name == "Ref_seq":
                            # show the complementary strand of the reference
                            box_value = "".join(
                                [DNA_rev_cmp_dict.get(x) for x in box_value]
                            )
                        box_color = base_colors.get(box_value[0])
                    else:
                        box_color = base_colors.get(box_value[-1])
                    if not box_color:
                        # base without a configured color
                        box_fill = False
                        box_color = "#FFFFFF"
                    else:
                        box_fill = True
                box_x = box_x_vec[index + 1]
                _append_box(box_x, current_y, panel_height, box_fill, box_color)
                text_list.append(
                    (box_x + 0.5 * panel_box_width, text_y, str(box_value), 16),
                )
            current_y = current_y - (panel_height + panel_space_list[panel_index])

        elif panel_name in ["A", "G", "C", "T", "Del", "Ins"]:
            # count panels: colored by ratio, labeled with the raw count
            if panel_name in ["Del", "Ins"]:
                box_ratio = panel_values / total_sum_count
            else:
                box_ratio = panel_values / base_sum_count
            box_color_list = map_color(box_ratio, color_break, color_list)
            lm.logger.debug(f"get box_color_list for {panel_name}")
            for index, box_value in enumerate(panel_values):
                box_x = box_x_vec[index + 1]
                _append_box(
                    box_x, current_y, panel_height, True, box_color_list[index]
                )
                text_list.append(
                    (box_x + 0.5 * panel_box_width, text_y, str(box_value), 6),
                )
            current_y = current_y - (panel_height + panel_space_list[panel_index])

        else:
            # ratio panels: colored by ratio, labeled in percent
            box_color_list = map_color(panel_values, color_break, color_list)
            for index, box_value in enumerate(panel_values):
                box_x = box_x_vec[index + 1]
                _append_box(
                    box_x, current_y, panel_height, True, box_color_list[index]
                )
                text_list.append(
                    (
                        box_x + 0.5 * panel_box_width,
                        text_y,
                        round(box_value * 100, 4),
                        6,
                    ),
                )
            # the last panel has no gap entry after it
            if panel_index < len(panel_space_list):
                current_y = current_y - (
                    panel_height + panel_space_list[panel_index]
                )

    # ---- draw and save ----
    lm.logger.debug("plot rectangles")
    ax.add_collection(PatchCollection(patches, match_original=True))
    lm.logger.debug("plot text on each rectangle")
    for text_x, text_y, text_info, text_fontsize in text_list:
        plt.text(
            x=text_x,
            y=text_y,
            s=text_info,
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=text_fontsize,
            fontname="Arial",
        )
    fig.savefig(
        fname=output_fig,
        bbox_inches="tight",
        dpi=output_fig_dpi,
        format=output_fig_fmt,
    )
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
[docs] def region_heatmap_compare( self, input_tables: str, labels: str | None = None, target_seq: str | None = None, # sgRNA reference_seq: str | None = None, # target locus output_fig_heatmap: str | None = None, output_fig_count_ratio: str | None = None, output_table_heatmap: str | None = None, output_table_count_ratio: str | None = None, output_fig_fmt: str = "pdf", input_table_header: bool = True, to_base: tuple = ("A", "G", "C", "T", "Ins", "Del"), heatmap_mut_direction: tuple = ("CT", "GA"), count_ratio="all", region_extend_length: int = 5, output_fig_dpi: int = 100, show_indel: bool = True, show_index: bool = True, block_ref: bool = True, box_border: bool = False, box_space: float = 0.03, min_color: tuple = (250, 239, 230), max_color: tuple = (154, 104, 57), min_ratio: float = 0.001, max_ratio: float = 0.99, local_alignment_scoring_matrix: tuple = (5, -4, -24, -8), local_alignment_min_score: int = 15, PAM_priority_weight: float = 1.0, get_built_in_target_seq: bool = False, log_level: str = "INFO", ): """Plot region mutation information for multiple conditions. This function generates a comparison of mutation information across multiple conditions using tables generated by `bioat bam mpileup_to_table`. Args: input_tables (str): Paths to input tables generated by `bioat bam mpileup_to_table`, separated by commas. Each table should contain mutation information for a short genomic region (≤1k nt). labels (str): Labels for the panels, separated by commas. target_seq (str, optional): Sequence to align against the reference sequence in `mpileup.table`. Examples: - 'GAGTCCGAGCAGAAGAAGAA^GGG^' for SpCas9-BE (PAM: ^GGG^). - '^TTTA^GCCCCAATAATCCCCACATGTCA' for cpf1-BE (PAM: ^TTTA^). - 'TGCTAGTAACCACGTTCTCCTGATCAAATATCACTCTCCTACTTACAGGA' for no PAM. Defaults to None. reference_seq (str, optional): Custom reference sequence to overwrite the one in `mpileup.table`. Can be a FASTA file or a DNA sequence. Defaults to None. 
output_fig_heatmap (str, optional): Path to the heatmap output figure. Defaults to None. output_fig_count_ratio (str, optional): Path to the count/ratio output figure. Defaults to None. output_table_heatmap (str, optional): Path to the heatmap output table. Defaults to None. output_table_count_ratio (str, optional): Path to the count/ratio output table. Defaults to None. output_fig_fmt (str, optional): Format of the output figures, either "pdf" or "png". Defaults to "pdf". input_table_header (bool, optional): Whether the input tables have headers. Defaults to True. to_base (str, optional): Reference bases to convert to, separated by commas. Defaults to "A,G,C,T,Ins,Del". heatmap_mut_direction (str, optional): Mutation directions to plot, specified as [from Base][to Base], separated by commas. Defaults to "CT,GA". count_ratio (str, optional): Type of plot to generate: "count", "ratio", or "all". Defaults to "all". region_extend_length (int, optional): Number of base pairs to extend on both sides of the region. Defaults to 0. output_fig_dpi (int, optional): DPI for the output figures. Defaults to 300. show_indel (bool, optional): Whether to show indel information in the output figures. Defaults to True. show_index (bool, optional): Whether to display index information in the output figures. Defaults to True. block_ref (bool, optional): Whether to hide colors for reference sites in the output figures. Defaults to True. box_border (bool, optional): Whether to display box borders in the output figures. Defaults to True. box_space (int, optional): Space size between two boxes in the heatmap. Defaults to 1. min_color (tuple, optional): Minimum color for the heatmap in RGB format. Defaults to (255, 255, 255). max_color (tuple, optional): Maximum color for the heatmap in RGB format. Defaults to (0, 0, 0). min_ratio (float, optional): Mutation ratios below this value will appear white. Defaults to 0.0. 
max_ratio (float, optional): Mutation ratios above this value will be capped. Defaults to 1.0. local_alignment_scoring_matrix (tuple, optional): Alignment scoring parameters as a tuple: (<align_match_score>, <align_mismatch_score>, <align_gap_open_score>, <align_gap_extension_score>). Defaults to None. local_alignment_min_score (int, optional): Minimum alignment score to consider as a valid alignment. Defaults to 0. PAM_priority_weight (float, optional): Weight multiplier for PAM alignment scores. Defaults to 1.0. get_built_in_target_seq (bool, optional): Set to True to return built-in target sequence information. Defaults to False. log_level (str, optional): Logging level. One of 'CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET'. Defaults to "INFO". Returns: None Examples: >>> # Generate mutation tables >>> samtools mpileup test_sorted.bam --reference HK4-AOut-1.ref.upper.fa | gzip > test_sorted.mpileup.gz >>> bioat bam mpileup_to_table test_sorted.mpileup.gz > test_sorted.mpileup.info1.tsv >>> bioat bam mpileup_to_table test_sorted.mpileup.gz > test_sorted.mpileup.info2.tsv >>> bioat bam mpileup_to_table test_sorted.mpileup.gz > test_sorted.mpileup.info3.tsv >>> # Generate region heatmap comparison >>> bioat target_seq region_heatmap_compare --input_tables test_sorted.mpileup.info1.tsv,test_sorted.mpileup.info2.tsv,test_sorted.mpileup.info3.tsv --labels condition1,condition2,condition3 --target_seq HEK4 """ lm.set_names(func_name="region_heatmap_compare") lm.set_level(log_level) mpl, plt, PatchCollection, Rectangle = _load_matplotlib_plotting() base_color_dict = { "A": "#04E3E3", "T": "#F9B874", "C": "#B9E76B", "G": "#F53798", "N": "#DDEAF6", } def plot_agct(x): if x == "": # target_seq sgRNA左右没align的位置给空值 return "#FFFFFF" if x in ["A", "G", "C", "T", "N"]: return base_color_dict[x] return "#AAAAAA" # check whether return built-in information if get_built_in_target_seq: lm.logger.info( "You can use <key> in built-in <target_seq> to represent your 
target_seq:\n" + "\t<key>\t<target_seq>\n" + "".join([f"\t{k}\t{v}\n" for k, v in TARGET_SEQ_LIB.items()]), ) lm.logger.info("exit because of the defination for get_built_in_target_seq") sys.exit(0) else: # set target_seq(sgRNA) info lm.logger.debug(f"load target_seq = {target_seq}") if target_seq in TARGET_SEQ_LIB: lm.logger.debug( f"target_seq information is abtained from the <key>={target_seq}", ) target_seq = TARGET_SEQ_LIB[target_seq] lm.logger.debug(f"target_seq is refered to {target_seq}") elif target_seq is None: pass elif isinstance(target_seq, str): lm.logger.debug(f"target_seq is refered to {target_seq}") else: lm.logger.error( "<target_seq> should be set as:" "DNA sequence / <key> in <get_built_in_target_seq> / None," f"but your <target_seq> is {target_seq}", ) sys.exit(1) # check if nothing to output outputs = ( output_fig_heatmap, output_fig_count_ratio, output_table_heatmap, output_table_count_ratio, ) if not [i for i in outputs if i is not None]: lm.logger.error( "outputs should be set as free combination of (output_fig_heatmap, output_fig_count_ratio, " "output_table_heatmap, output_table_count_ratio), but none of outputs was defined", ) sys.exit(1) # check parameter: to_base # check to_base: tuple for A G C T Ins Del or combine like A G T Ins default_bases = ("A", "G", "C", "T", "Ins", "Del") lm.logger.debug("check to_base parameter") for base in to_base: if base not in default_bases: lm.logger.error("check to_base parameter") lm.logger.error(f"{base} not in {default_bases}") raise BioatInvalidOptionError lm.logger.debug(f"to_base was defined as {to_base}") # check parameter: count_ratio default_count_ratio_choices = ("count", "ratio", "all") lm.logger.debug("check count_ratio parameter") if count_ratio not in default_count_ratio_choices: lm.logger.error("check count_ratio parameter") lm.logger.error(f"{count_ratio} not in {default_count_ratio_choices}") raise BioatInvalidOptionError lm.logger.debug(f"count_ratio was defined as {count_ratio}") # 
parse parameter: input_tables if isinstance(input_tables, tuple): input_tables = [str(i).strip() for i in input_tables] elif isinstance(input_tables, str): input_tables = ( input_tables.replace(" ", "").replace("\t", "").strip().split(",") ) else: raise BioatInvalidInputError lm.logger.debug(f"parse tables... {input_tables}") # parse parameter: labels, if labels is None, labels value will be input_tables names if isinstance(labels, tuple): labels = [str(i).strip() for i in labels] elif isinstance(labels, str): labels = ( labels.replace(" ", "").replace("\t", "").strip().split(",") if labels else input_tables ) else: raise BioatInvalidInputError lm.logger.debug(f"parse labels... {labels}") # set alignment aligner = instantiate_pairwise_aligner(*local_alignment_scoring_matrix) # load & merge bmat files & set treatment names ls = [] for input_table, label in zip(input_tables, labels, strict=False): df = ( pd.read_csv(input_table, sep="\t") if input_table_header else pd.read_csv(input_table, sep="\t", header=None) ) df["label"] = label ls.append(df) df_bases = pd.concat(ls) lm.logger.debug( "df_bases.head(10):\n" + tabulate(df_bases.head(8), headers="keys", tablefmt="psql"), ) # parse parameter: reference_seq, if reference_seq is None, reference_seq value will be derived from df_bases if reference_seq: # fa file or seq str to seq str if os.path.isfile(reference_seq): f = ( open(reference_seq) if not reference_seq.endswith(".gz") else gzip.open(reference_seq, "rt") ) reference_seq = "".join([i.rstrip() for i in f.readlines()[1:]]) if reference_seq.count(">") > 0: raise BioatFileFormatError ref_seq = Seq(reference_seq) lm.logger.debug(f"load ref_seq from parameter:\n{ref_seq}") else: ref_seq = Seq("".join(df_bases.query("label==@labels[0]").ref_base)) lm.logger.debug(f"load ref_seq from df_bases.ref_base:\n{ref_seq}") lm.logger.debug(f"Seq object for ref_seq:\n{ref_seq}") ref_seq_rc = ref_seq.reverse_complement() lm.logger.debug( f"Seq object for ref_seq_rc 
(reverse_complement):\n{ref_seq_rc}", ) # parse parameter: target_seq if target_seq is None: target_seq = ref_seq PAM = None elif "^" in target_seq: if target_seq.find("^") == 0: PAM, target_seq = target_seq[1:].strip().split("^") target_seq = Seq(target_seq.upper()) PAM = {"PAM": PAM, "position": 0, "weight": PAM_priority_weight} else: target_seq, PAM = target_seq[:-1].strip().split("^") target_seq = Seq(target_seq.upper()) PAM = { "PAM": PAM, "position": len(target_seq), "weight": PAM_priority_weight, } else: lm.logger.debug( f"no PAM sequence is defined from target_seq = {target_seq}", ) target_seq = Seq(target_seq) PAM = None lm.logger.info(f"parse target_seq: target_seq={target_seq}, PAM={PAM}") # fwd alignment & rev alignment # 为了判断target_seq (sgRNA) 是设计在fwd还是rev链上 if not PAM: lm.logger.info("use PAMless mode to align") fwd = run_target_seq_align(ref_seq, target_seq, aligner) lm.logger.debug(f"final_align_forward:\n{fwd}") rev = run_target_seq_align(ref_seq_rc, target_seq, aligner) lm.logger.debug(f"final_align_reverse:\n{rev}") else: lm.logger.info("use PAM mode to align") fwd = run_target_seq_align(ref_seq, target_seq, aligner, PAM=PAM) lm.logger.debug(f"final_align_forward:\n{fwd}") rev = run_target_seq_align(ref_seq_rc, target_seq, aligner, PAM=PAM) lm.logger.debug(f"final_align_reverse:\n{rev}") # get fwd alignment info reference_seq = fwd["alignment"]["reference_seq"][ fwd["ref_aln_start"] : fwd["ref_aln_end"] + 1 ] align_info = fwd["alignment"]["aln_info"][ fwd["ref_aln_start"] : fwd["ref_aln_end"] + 1 ] target_seq = fwd["alignment"]["target_seq"][ fwd["ref_aln_start"] : fwd["ref_aln_end"] + 1 ] aln_score = fwd["aln_score"] lm.logger.info( f"Forward best alignment:\n" f"reference : {reference_seq}\n" f"align_info : {align_info}\n" f"target_seq : {target_seq}\n" f"align_score: {aln_score}", ) # get rev alignment info reference_seq_rev = rev["alignment"]["reference_seq"][ rev["ref_aln_start"] : rev["ref_aln_end"] + 1 ] align_info_rev = 
rev["alignment"]["aln_info"][ rev["ref_aln_start"] : rev["ref_aln_end"] + 1 ] target_seq_rev = rev["alignment"]["target_seq"][ rev["ref_aln_start"] : rev["ref_aln_end"] + 1 ] aln_score_rev = rev["aln_score"] lm.logger.info( f"Reverse best alignment:\n" f"reference : {reference_seq_rev}\n" f"align_info : {align_info_rev}\n" f"target_seq : {target_seq_rev}\n" f"align_score: {aln_score_rev}", ) # define align direction and final align res # make alignment info if any( x >= local_alignment_min_score for x in (fwd["aln_score"], rev["aln_score"]) ): if fwd["aln_score"] >= rev["aln_score"]: aln_direction = "Forward Alignment" aln = fwd else: aln_direction = "Reverse Alignment" aln = rev reference_seq = reference_seq_rev align_info = align_info_rev target_seq = target_seq_rev aln_score = aln_score_rev else: lm.logger.critical("Alignment Error!") sys.exit(1) # mark aligned all bases in target_seq ref_seq_length = len(ref_seq) target_seq_aln = [""] * ref_seq_length # mark inserted bases in target_seq target_seq_aln_insert = [""] * ref_seq_length DNA_rev_cmp_dict = { "A": "T", "T": "A", "C": "G", "G": "C", "N": "N", "-": "-", } reference_seq.count("-") target_seq.count("-") ref_gap_count = 0 ref_del_str = "" aln_start = aln["ref_aln_start"] aln_end = aln["ref_aln_end"] # update target_seq_aln lm.logger.debug(f"target_seq_aln (before update) = {target_seq_aln}") lm.logger.debug( f"target_seq_aln_insert (before update) = {target_seq_aln_insert}", ) for idx, ref_base in enumerate(reference_seq): # reference : GCCTCTGGAGAGGGAGGAGGG # align_info : |.|.|||..|..|||||.||- # target_seq : GGCACTGCGGCTGGAGGTGG- if ref_base != "-": ref_gap_count += 0 ref_del_str = "" target_seq_aln[aln_start + idx - ref_gap_count] = ( ref_del_str + target_seq[idx] ) else: ref_gap_count += 1 ref_del_str += target_seq[idx] # for continuous gap target_seq_aln_insert[aln_start + idx] = target_seq[idx] lm.logger.debug( "update target_seq_aln: show ''.join(target_seq_aln)\n" f"target_seq_aln = 
{''.join(target_seq_aln)}\n" f"target_seq = {target_seq}", ) lm.logger.debug("update target_seq_aln: Done") lm.logger.debug(f"target_seq_aln (after update) = {target_seq_aln}") lm.logger.debug( f"target_seq_aln_insert (after update) = {target_seq_aln_insert}", ) # if reverse alignment if aln_direction == "Reverse Alignment": # illustrate reverse alignment as forward alignment target_seq_aln = target_seq_aln[::-1] target_seq_aln_insert = target_seq_aln_insert[::-1] lm.logger.debug("use reverse alignment") lm.logger.debug(f"target_seq_aln = {target_seq_aln}") lm.logger.debug(f"target_seq_aln_insert = {target_seq_aln_insert}") possible_target_region_start = max(aln_start - region_extend_length, 0) possible_target_region_end = min( aln_end + region_extend_length, ref_seq_length ) # 266: 0~255 266 - 1 plot_region = (possible_target_region_start, possible_target_region_end) df_bases_select = df_bases.iloc[plot_region[0] : plot_region[1], :] lm.logger.debug( "df_bases_select:\n" + tabulate(df_bases_select, headers="keys", tablefmt="psql"), ) # ---------------------------------------------------------------->>>>> # load .bmat file # ---------------------------------------------------------------->>>>> label_panel = labels ls_bmat = input_tables ls_bmat_table = [] for label, path_bmat in zip(label_panel, ls_bmat, strict=False): # 读取文件 if input_table_header: df = pd.read_csv(path_bmat, sep="\t") else: df = pd.read_csv(path_bmat, sep="\t", header=None) # 删除不必要的列 df = df.drop(columns=["chr_name"], errors="ignore") # 添加标签列(如果后面需要) df["label"] = label # 重命名除 chr_index 之外的列,避免 merge 时产生 _x/_y rename_map = {c: f"{c}_{label}" for c in df.columns if c != "chr_index"} df = df.rename(columns=rename_map) # 写回 ls_bmat_table.append(df) # 合并所有表,仅 chr_index 对齐 df_tmp = ls_bmat_table[0].copy() for df in ls_bmat_table[1:]: df_tmp = df_tmp.merge(df, on="chr_index", how="outer") df_bmat_all = df_tmp.copy() # ref_seq:用第一份表构建 ref_seq = 
"".join(ls_bmat_table[0][f"ref_base_{label_panel[0]}"].astype(str).fillna("").tolist()) possible_target_region_start = max(aln_start - region_extend_length, 0) possible_target_region_end = min( aln_end + region_extend_length, ref_seq_length ) # 266: 0~255 266 - 1 plot_region = (possible_target_region_start, possible_target_region_end) df_bases_select = df_bmat_all.iloc[plot_region[0] : plot_region[1], :] lm.logger.debug( "df_bases_select:\n" + tabulate(df_bases_select, headers="keys", tablefmt="psql"), ) # make plot # set color # show indel np.set_printoptions(suppress=True) box_border_plot_state = box_border # set panel size panel_box_width = 0.4 panel_box_heigth = 0.4 panel_space = 0.05 panel_box_space = box_space # color part # make color breaks color_break_num = 20 break_step = 1.0 / color_break_num # mutation ratio of box lower than min_ratio will show as white min_color_value = min_ratio # mutation ratio of box lower than max_ratio will show as white max_color_value = max_ratio color_break = np.round( np.arange(min_color_value, max_color_value, break_step), 5 ) # min_color: min color to plot heatmap with RGB format # max_color: max color to plot heatmap with RGB format if isinstance(min_color, tuple): low_color = min_color else: raise BioatInvalidOptionError if isinstance(max_color, tuple): high_color = max_color else: raise BioatInvalidOptionError try: color_list = make_color_list( low_color, high_color, len(color_break) - 1, "Hex" ) color_list = ["#FFFFFF", *color_list] except: lm.logger.debug(low_color, high_color) lm.logger.debug(color_break) # get plot info total_box_count = plot_region[1] - plot_region[0] # calculate base info and fix zero ls_base_sum_count = [] ls_total_sum_count = [] for label in label_panel: base_sum_count = df_bases_select[ [f"A_{label}", f"G_{label}", f"C_{label}", f"T_{label}"] ].apply(lambda x: x.sum(), axis=1) # print base_sum_count total_sum_count = df_bases_select[ [ f"A_{label}", f"G_{label}", f"C_{label}", f"T_{label}", 
f"del_count_{label}", f"insert_count_{label}", ] ].apply(lambda x: x.sum(), axis=1) base_sum_count[base_sum_count == 0] = 1 base_sum_count = base_sum_count + 1 total_sum_count = total_sum_count + 1 ls_base_sum_count.append(base_sum_count) ls_total_sum_count.append(total_sum_count) # make plot size plot_data_list = None panel_space_coef = None panel_height_coef = None if count_ratio == "all": panel_height_coef = [0.5, 0.9, 0.9] + [0.5] * len(ls_bmat_table) * len( default_bases ) * 2 else: panel_height_coef = [0.5, 0.9, 0.9] + [0.5] * len(ls_bmat_table) * len( default_bases ) # 根据bmat表格个数来控制方格高度 # panel_space_coef = [1, 1, 1] + [0.3] * 3 + [1, 0.3, 1] + [0.3] * 3 + [1, 0.3] if count_ratio == "all": panel_space_coef = [1, 1, 1] + ( [0.3] * (len(ls_bmat_table) - 1) + [4] ) * len(default_bases) * 2 else: panel_space_coef = [1, 1, 1] + ( [0.3] * (len(ls_bmat_table) - 1) + [4] ) * len(default_bases) # 更正heatmap align错误 # plot_heatmap_index = bmat_table_select.chr_index plot_data_list = [ ["Ref_index", df_bases_select.chr_index], ["Target_seq", np.array(target_seq_aln[plot_region[0] : plot_region[1]])], ["Ref_seq", df_bases_select[f"ref_base_{label_panel[0]}"]], ] if count_ratio in {"count", "all"}: for to_base in default_bases: if to_base == "A": for label in label_panel: plot_data_list.append( [f"{label}: to A", np.array(df_bases_select[f"A_{label}"])], ) if to_base == "G": for label in label_panel: plot_data_list.append( [f"{label}: to G", np.array(df_bases_select[f"G_{label}"])], ) if to_base == "C": for label in label_panel: plot_data_list.append( [f"{label}: to C", np.array(df_bases_select[f"C_{label}"])], ) if to_base == "T": for label in label_panel: plot_data_list.append( [f"{label}: to T", np.array(df_bases_select[f"T_{label}"])], ) if to_base == "Del": for label in label_panel: plot_data_list.append( [ f"{label}: to Del", np.array(df_bases_select[f"del_count_{label}"]), ], ) if to_base == "Ins": for label in label_panel: plot_data_list.append( [ f"{label}: to 
Ins", np.array(df_bases_select[f"insert_count_{label}"]), ], ) if count_ratio in {"ratio", "all"}: for to_base in default_bases: if to_base == "A": for index, label in enumerate(label_panel): plot_data_list.append( [ f"{label}: to A(%)", np.array( df_bases_select[f"A_{label}"] / ls_base_sum_count[index] ), ], ) if to_base == "G": for index, label in enumerate(label_panel): plot_data_list.append( [ f"{label}: to G(%)", np.array( df_bases_select[f"G_{label}"] / ls_base_sum_count[index] ), ], ) if to_base == "C": for index, label in enumerate(label_panel): plot_data_list.append( [ f"{label}: to C(%)", np.array( df_bases_select[f"C_{label}"] / ls_base_sum_count[index] ), ], ) if to_base == "T": for index, label in enumerate(label_panel): plot_data_list.append( [ f"{label}: to T(%)", np.array( df_bases_select[f"T_{label}"] / ls_base_sum_count[index] ), ], ) if to_base == "Del": for index, label in enumerate(label_panel): plot_data_list.append( [ f"{label}: to Del(%)", np.array( df_bases_select[f"del_count_{label}"] / ls_total_sum_count[index] ), ], ) if to_base == "Ins": for index, label in enumerate(label_panel): plot_data_list.append( [ f"{label}: to Ins(%)", np.array( df_bases_select[f"insert_count_{label}"] / ls_total_sum_count[index] ), ], ) # get box and space info box_height_list = np.array(panel_height_coef) * panel_box_heigth panel_space_list = np.array(panel_space_coef) * panel_space # for i, j in enumerate(panel_space_coef): # print i,j # print panel_space # calculate figure total width and height figure_width = ( total_box_count * panel_box_width + (total_box_count - 1) * panel_box_space + panel_box_width * 2 ) figure_height = sum(box_height_list) + sum(panel_space_list) lm.logger.debug(f"figure_width = {figure_width}") lm.logger.debug(f"figure_height = {figure_height}") # make all box_x box_x_vec = np.arange( 0, figure_width + panel_box_width, panel_box_width + panel_box_space ) box_x_vec = box_x_vec[: (len(ref_seq) + 1)] # make box border if 
box_border_plot_state: box_edgecolor = "#AAAAAA" box_linestyle = "-" box_linewidth = 2 else: box_edgecolor = "#FFFFFF" box_linestyle = "None" box_linewidth = 0 # make box_y initialize current_y = 0 df_matrix = pd.DataFrame( [[""] * len(plot_data_list[0][1])] * len(panel_height_coef) ) ls_row_name = [] for panel_index in range(len(panel_height_coef)): ls_row_name.append(plot_data_list[panel_index][0]) df_matrix.iloc[panel_index, :] = df_matrix.iloc[panel_index, :].astype( np.str_ ) for index, box_value in enumerate(plot_data_list[panel_index][1]): row = panel_index col = index df_matrix.iloc[row, col] = box_value # print(f'box_value = {box_value}, row = {row}, col = {col}') # print(f'df_matrix.iloc[row, col] = {df_matrix.iloc[row, col]}') # if box_value == '': # raise ValueError df_matrix.index = ls_row_name # plot heatmap if output_fig_heatmap: # 预设参数 num_extend = region_extend_length # rename columns of df_matrix df_matrix.columns = range(1, df_matrix.shape[1] + 1) dt_base = {"A": [], "G": [], "C": [], "T": []} for info in heatmap_mut_direction: dt_base[info[0]] += info[1] lm.logger.debug(f"[mut_direction] for heatmap: {dt_base}") lm.logger.debug(f"[label_panel] for heatmap: {label_panel}") lm.logger.debug( f"[region_extend_length] for heatmap: {region_extend_length}", ) lm.logger.debug(f"[block_ref] for Target-seq multiplot: {block_ref}") ls_col_not_null = df_matrix.columns[ -(df_matrix.loc["Target_seq", :] == "") ].tolist() df_matrix.columns.tolist() ls_plot_range = list( range(ls_col_not_null[0] - num_extend, ls_col_not_null[-1] + num_extend) ) try: df_matrix = df_matrix.loc[:, ls_plot_range] # 有不在的就不为空,不为空则判断成立 # print 'testTTTTTTTTTTTTTTT' except: msg = "The param [num_extend] is too large or <0, it must be a integer>=0" raise ValueError( msg ) df_matrix.loc["Target_seq", :].tolist() ls_ref = df_matrix.loc["Ref_seq", :].tolist() ls_ratio = [] for label in label_panel: df_sample = df_matrix[ [ label == i.split(": to")[0] and "(%)" not in i for i in 
df_matrix.index ] ] df_sample.index = ["A", "G", "C", "T", "Ins", "Del"] ls_A = df_sample.T["A"].isnull().tolist() ls_G = df_sample.T["G"].isnull().tolist() ls_C = df_sample.T["C"].isnull().tolist() ls_T = df_sample.T["T"].isnull().tolist() ls_bool = [] for i in range(len(ls_A)): if ls_A[i] & ls_C[i] & ls_G[i] & ls_T[i]: ls_bool.append(False) else: ls_bool.append(True) ls_bl_select_df_sample = ls_bool df_sample = df_sample.loc[:, ls_bl_select_df_sample] df_sample = df_sample.loc[["A", "G", "C", "T"], :].T.copy() df_sample = df_sample.astype(int) df_sample["sum"] = ( df_sample["A"] + df_sample["G"] + df_sample["C"] + df_sample["T"] ) df_sample["Ref_seq"] = np.array(ls_ref)[ls_bl_select_df_sample] df_sample["To_base"] = df_sample["Ref_seq"].map(lambda x: dt_base[x]) df_sample["Mut_count"] = 0 for recode in df_sample.iterrows(): index_row = recode[0] for to_base in recode[1]["To_base"]: df_sample.loc[index_row, "Mut_count"] += recode[1][to_base] df_sample["Mut_ratio"] = df_sample["Mut_count"] / df_sample["sum"] ls_sample_ratio = df_sample["Mut_ratio"].tolist() ls_ratio.append(ls_sample_ratio) df_ratio_all = pd.DataFrame(ls_ratio).T df_ratio_all.columns = label_panel try: lm.logger.debug(f"Catch NA: {df_ratio_all.isna().sum().sum()}") lm.logger.debug(df_ratio_all.head()) # print(f'df_matrix.T = \n{df_matrix.T}') df_ratio_all.index = ( df_matrix.T["Ref_index"][ls_bl_select_df_sample] .map(float) .map(int) .tolist() ) df_ratio_all["Target_seq"] = df_matrix.T["Target_seq"][ ls_bl_select_df_sample ].tolist() df_ratio_all["Ref_seq"] = df_matrix.T["Ref_seq"][ ls_bl_select_df_sample ].tolist() # print(f'df_ratio_all.T = \n{df_ratio_all.T}') # exit() except ValueError: start_idx = ( df_matrix.T["Ref_index"][ls_bl_select_df_sample] .map(float) .map(int) .tolist()[0] ) end_idx = start_idx + len(df_ratio_all.index.tolist()) df_ratio_all.index = np.arange(start_idx, end_idx) df_ratio_all["Target_seq"] = df_matrix.T["Target_seq"].tolist() df_ratio_all["Ref_seq"] = 
df_matrix.T["Ref_seq"].tolist() df_ratio_all["On-Target"] = "" df_ratio_all["Reference"] = "" df_ratio_all = df_ratio_all[ ["Target_seq", "Ref_seq", "On-Target", "Reference", *label_panel] ].T.copy() df_onTarget_Ref = df_ratio_all.iloc[:2, :].fillna(" ") lm.logger.debug(f"df_onTarget_Ref: \n{df_onTarget_Ref}") # exit() # print(df_onTarget_Ref == '') df_onTarget_Ref_color = df_ratio_all.iloc[:2, :].fillna("").map(plot_agct) lm.logger.debug(f"df_onTarget_Ref_color: \n{df_onTarget_Ref_color}") # exit() lm.logger.debug(f"df_ratio_all: \n{df_ratio_all}") # exit() df_plot = df_ratio_all.iloc[2:, :].copy() lm.logger.debug(f"df_plot: \n{df_plot}") # print(f'df_plot: \n{df_plot}') # exit() df_plot.iloc[2:, :] = df_plot.iloc[2:, :].astype(float) * 100 ls_max = [] for i in df_plot.iloc[2:, :].values.tolist(): ls_max.extend(i) try: ls_break = list(np.arange(0, max(ls_max), max(ls_max) / 100)) except ZeroDivisionError: ls_break = [ 0, 0.003, 0.006, 0.009, 0.012, 0.015, 0.018, 0.022, 0.026, 0.030, 0.035, 0.040, 0.05, 0.06, 0.07, 0.08, 1.00, 2.00, 3.00, ] ls_color_middle = make_color_list( low_color_RGB=convert_hex_to_rgb("#87BDDB"), high_color_RGB=convert_hex_to_rgb("#0A306A"), length_out=80, return_fmt="Hex", ) ls_color_top = ["#0A306A"] * 20 ls_color_bottom = ( ["#EFEFEF"] * 5 + ["#DDEAF6"] * 5 + ["#A9CEE4"] * 5 + ["#87BDDB"] * 5 ) ls_color = ls_color_bottom + ls_color_middle + ls_color_top lm.logger.debug(f"ls_color: {ls_color}") # exit() df_onTarget_Ref_tmp = df_onTarget_Ref.T df_onTarget_Ref_tmp.loc[ df_onTarget_Ref_tmp["Target_seq"].isnull(), "Target_seq" ] = pd.NA df_onTarget_Ref = df_onTarget_Ref_tmp.T df_plot_rec = df_plot.copy() lm.logger.debug( "df_plot_rec:\n" + tabulate(df_plot_rec, headers="keys", tablefmt="psql"), ) # heatmap颜色,去map_hex_for_matrix函数中调整 def map_hex_for_matrix(x): # 判断value范围并返回颜色的Hex值 value = -1 for value in ls_break: if x < value: break continue return ls_color[ls_break.index(value) - 1] lm.logger.debug(f"df_plot_rec: \n{df_plot_rec}") # 
print(f'df_plot_rec: \n{df_plot_rec}') # exit() df_plot_rec_cmap = df_plot_rec.copy() df_plot_rec_cmap.iloc[2:, :] = df_plot_rec.iloc[2:, :].map(map_hex_for_matrix) # print(df_plot_rec_cmap) # exit() # print(f'df_plot_rec = \n{df_plot_rec}') # print(df_plot_rec.loc['On-Target', :]) # print(df_plot_rec.loc['On-Target', :].dtype) # df_matrix = pd.DataFrame([[''] * len(plot_data_list[0][1])] * len(panel_height_coef)) # print(f"df_onTarget_Ref.loc['Target_seq', :] = \n{df_onTarget_Ref.loc['Target_seq', :]}") # print('df_onTarget_Ref', df_onTarget_Ref.loc['Target_seq', :] == '') # print(f"df_plot_rec.loc['On-Target', :] = \n{df_plot_rec.loc['On-Target', :]}") # print('df_plot_rec', df_plot_rec.loc['On-Target', :] == '') # exit() df_plot_rec.loc["On-Target", :] = df_onTarget_Ref.loc["Target_seq", :] # print(df_plot_rec.loc['On-Target', :]) # print(df_onTarget_Ref.loc['Target_seq', :]) df_plot_rec.loc["Reference", :] = df_onTarget_Ref.loc["Ref_seq", :] df_plot_rec_cmap.loc["On-Target", :] = df_onTarget_Ref.loc[ "Target_seq", : ].map(plot_agct) df_plot_rec_cmap.loc["Reference", :] = df_onTarget_Ref.loc[ "Ref_seq", : ].map(plot_agct) figure_width_heatmap = df_plot_rec.shape[1] figure_height_heatmap = max(df_plot_rec.shape[0], len(ls_break) * 0.8) scale = 1.1 # 1.1 fig_heatmap = plt.figure( figsize=(figure_width_heatmap * scale, figure_height_heatmap * scale) ) ax = fig_heatmap.add_subplot(111, aspect="equal") plt.xlim([0, figure_width_heatmap + 5]) plt.ylim([-figure_height_heatmap - 0.5, 0]) plt.axis("off") # export heatmap values lm.logger.debug("export heatmap reference table...") lm.logger.debug(f"df_plot_rec: \n{df_plot_rec.index}") # exit() for row in range(df_plot_rec.shape[0]): row_name = df_plot_rec.index[row] # add panel name ax.text( x=-1, y=0.55 - row - 1.1, s=row_name, horizontalalignment="right", verticalalignment="center", fontsize=34 / 1.1 * scale, fontname="DejaVu Sans", alpha=1, ) # add Target_seq and ref if row_name in ["On-Target", "Reference"]: for col 
in range(df_plot.shape[1]): # plot Rectangle site_x = col site_y = -row + 0.05 - 1.05 ax.add_patch( mpl.patches.Rectangle( (site_x, site_y), width=1, height=1, linestyle="-", fill=True, facecolor=df_plot_rec_cmap.iloc[row, col], edgecolor="#AAAAAA" if df_plot_rec_cmap.iloc[row, col] != "#FFFFFF" else "#FFFFFF", linewidth=3.5, ) ) # plot text site_x = col site_y = -row - 0.55 ax.text( x=site_x + 0.5, y=site_y, s=df_plot_rec.iloc[row, col], horizontalalignment="center", verticalalignment="center", fontsize=34 / 1 * scale, fontname="DejaVu Sans", alpha=1, ) # add seq panels else: for col in range(df_plot.shape[1]): # plot Rectangle site_x = col site_y = -row - 1.05 ax.add_patch( mpl.patches.Rectangle( (site_x, site_y), width=1, height=1, linestyle="-", fill=True, facecolor=df_plot_rec_cmap.iloc[row, col], edgecolor="#FFFFFF", linewidth=3, ) ) # add cbar # start from this site site_x = df_plot.shape[1] + 1 site_y = -figure_height_heatmap * 0.1 - 1 step_scale = 0.1 lm.logger.debug(f"[cbar_scale]: {step_scale}") for color in ls_color: site_y += step_scale ax.add_patch( mpl.patches.Rectangle( (site_x, site_y), width=1, height=step_scale, fill=True, facecolor=color, edgecolor=color, ) ) # add cbar text # start from this site site_x = df_plot.shape[1] + 1 site_y = -figure_height_heatmap * 0.1 for idx, label in enumerate(ls_break): site_y += step_scale if idx % 10 == 0: ax.text( x=site_x + 1.1, y=site_y - 1, s=round(label, 4), horizontalalignment="left", verticalalignment="center", fontsize=34 / 1.5, # fontsize=34 / 1.5 * step_scale, fontname="DejaVu Sans", alpha=1, ) fig_heatmap.savefig(output_fig_heatmap, bbox_inches="tight") # 减少边缘空白 # plt.show() if output_table_heatmap: if output_table_heatmap.endswith(".csv"): df_plot_rec.to_csv(output_table_heatmap, sep=",") elif output_table_heatmap.endswith(".tsv"): df_plot_rec.to_csv(output_table_heatmap, sep="\t") if output_fig_count_ratio: # set new figure fig = plt.figure(figsize=(figure_width * 1.1, figure_height * 1.1)) ax = 
fig.add_subplot(111, aspect="equal") plt.xlim([0, figure_width]) plt.ylim([-figure_height, 0]) plt.axis("off") text_list = [] patches = [] # print panel_height_coef for panel_index in range(len(panel_height_coef)): # print 'panel index:', panel_index # panel name panel_name = plot_data_list[panel_index][0] panel_name_x = box_x_vec[0] # print panel_name # print panel_name_x panel_name_y = current_y - box_height_list[panel_index] * 0.5 # print panel_name_y text_list.append((panel_name_x, panel_name_y, panel_name, 10)) # plot panel box # plot Index行 if panel_name == "Ref_index": # don't draw box, only add text for index, box_value in enumerate(plot_data_list[panel_index][1]): box_x = box_x_vec[index + 1] text_list.append( ( box_x + panel_box_width * 0.5, current_y - box_height_list[panel_index] * 0.5, str(box_value), 6, ) ) # make next panel_y current_y = current_y - ( box_height_list[panel_index] + panel_space_list[panel_index] ) # plot Target_seq和Ref_seq行 elif panel_name in ["Target_seq", "Ref_seq"]: # if panel_name = Ref, form a new list to store plot info for checking the block_ref param if panel_name == "Ref_seq": plot_data_list_ref = plot_data_list[panel_index][1] for index, box_value in enumerate(plot_data_list[panel_index][1]): # print(index, box_value) if box_value == "": box_fill = False box_color = "#FFFFFF" else: if aln_direction == "Reverse Alignment": if panel_name == "Ref_seq": box_value = "".join( [DNA_rev_cmp_dict.get(x) for x in box_value] ) else: ### fix bug for x in box_value: box_value = DNA_rev_cmp_dict.get(x) box_color = base_color_dict.get(box_value[0]) else: box_color = base_color_dict.get(box_value[-1]) if not box_color: box_fill = False box_color = "#FFFFFF" else: box_fill = True box_x = box_x_vec[index + 1] patches.append( Rectangle( xy=(box_x, current_y - box_height_list[panel_index]), width=panel_box_width, height=box_height_list[panel_index], fill=box_fill, alpha=1, linestyle=box_linestyle, linewidth=box_linewidth, 
edgecolor=box_edgecolor, facecolor=box_color, ), ) # text text_list.append( ( box_x + 0.5 * panel_box_width, current_y - 0.5 * box_height_list[panel_index], str(box_value), 16, ) ) # make next panel_y current_y = current_y - ( box_height_list[panel_index] + panel_space_list[panel_index] ) # 用来plot Sample行的 # 这里plot counts elif panel_name.split("to ")[-1] in default_bases: # print 'panel name:', panel_name.split('to ')[-1].replace('(%)', '') # print panel_name, 'check panel name' # 改进了这里total 和 base的值的引用 if count_ratio == "all": ls_base_sum_count_all = ( ls_base_sum_count * 2 * len(default_bases) ) ls_total_sum_count_all = ( ls_total_sum_count * 2 * len(default_bases) ) total_sum_count = ls_total_sum_count_all[panel_index - 3] base_sum_count = ls_base_sum_count_all[panel_index - 3] else: ls_base_sum_count_all = ls_base_sum_count * len(default_bases) ls_total_sum_count_all = ls_total_sum_count * len(default_bases) total_sum_count = ls_total_sum_count_all[panel_index - 3] base_sum_count = ls_base_sum_count_all[panel_index - 3] panel_name = panel_name.split("to ")[-1] if panel_name in ["Del", "Ins"]: box_ratio = plot_data_list[panel_index][1] / total_sum_count else: box_ratio = plot_data_list[panel_index][1] / base_sum_count box_color_list = map_color(box_ratio, color_break, color_list) base_sum_count.tolist() for index, box_value in enumerate(plot_data_list[panel_index][1]): # if count_ratio == 'all': box_color = box_color_list[index] # check block_ref state: count plot if block_ref: if plot_data_list_ref.iloc[index] == panel_name: # if box_value >= ls_base_sum[index] * 0.8: # box_fill = False box_color = "#FFFFFF" box_x = box_x_vec[index + 1] patches.append( Rectangle( xy=(box_x, current_y - box_height_list[panel_index]), width=panel_box_width, height=box_height_list[panel_index], fill=True, alpha=1, linestyle=box_linestyle, linewidth=box_linewidth, edgecolor=box_edgecolor, facecolor=box_color, ), ) # text text_list.append( ( box_x + 0.5 * panel_box_width, current_y 
- 0.5 * box_height_list[panel_index], str(box_value), 6, ) ) # print box_value # make next panel_y current_y = current_y - ( box_height_list[panel_index] + panel_space_list[panel_index] ) # 这里plot ratio(剩下的都是base(%)) else: box_color_list = map_color( plot_data_list[panel_index][1], color_break, color_list ) # print box_color_list panel_name = panel_name.split("to ")[-1].replace("(%)", "") for index, box_value in enumerate(plot_data_list[panel_index][1]): # ratio box_color = box_color_list[index] # check block_ref state: ratio plot if block_ref: if plot_data_list_ref.iloc[index] == panel_name: # if box_value >= 0.85: # box_fill = False box_color = "#FFFFFF" box_x = box_x_vec[index + 1] patches.append( Rectangle( xy=(box_x, current_y - box_height_list[panel_index]), width=panel_box_width, height=box_height_list[panel_index], fill=True, alpha=1, linestyle=box_linestyle, linewidth=box_linewidth, edgecolor=box_edgecolor, facecolor=box_color, ), ) if "(%)" in plot_data_list[panel_index][0]: # text # ref # print 'test', box_value # if ARGS.block_ref == True: # if box_value >= 0.99: # box_value = 0 text_list.append( ( box_x + 0.5 * panel_box_width, current_y - 0.5 * box_height_list[panel_index], round(box_value * 100, 4), 6, ) ) else: text_list.append( ( box_x + 0.5 * panel_box_width, current_y - 0.5 * box_height_list[panel_index], box_value, 6, ) ) # make next panel_y if panel_index < len(panel_space_list): current_y = current_y - ( box_height_list[panel_index] + panel_space_list[panel_index] ) # plot box ax.add_collection(PatchCollection(patches, match_original=True)) # add text for text_x, text_y, text_info, text_fontsize in text_list: if " to " in str(text_info): plt.text( x=text_x + 0.3, y=text_y, s=text_info, horizontalalignment="right", verticalalignment="center", fontsize=text_fontsize, fontname="DejaVu Sans", ) else: plt.text( x=text_x, y=text_y, s=text_info, horizontalalignment="center", verticalalignment="center", fontsize=text_fontsize, fontname="DejaVu Sans", 
) # output plot if output_fig_fmt.upper == "PNG": fig.savefig( fname=output_fig_count_ratio, bbox_inches="tight", dpi=output_fig_dpi, format="png", ) else: fig.savefig( fname=output_fig_count_ratio, bbox_inches="tight", dpi=output_fig_dpi, format="pdf", ) if output_table_count_ratio: if output_table_count_ratio.endswith(".csv"): df_bases_select.to_csv(output_table_count_ratio, sep=",") elif output_table_count_ratio.endswith(".tsv"): df_bases_select.to_csv(output_table_count_ratio, sep="\t")
# Script entry-point guard: the body is a bare `pass`, so executing this module
# directly performs no additional work beyond the module-level code above.
if __name__ == "__main__": pass