Source code for bioat.target_seq

"""_summary_

author: Herman Huanan Zhao
email: hermanzhaozzzz@gmail.com
homepage: https://github.com/hermanzhaozzzz

_description_

example 1:
    bioat list
        <in shell>:
            $ bioat list
        <in python console>:
            >>> from bioat.cli import Cli
            >>> bioat = Cli()
            >>> bioat.list()
            >>> print(bioat.list())

example 2:
    _example_
"""

import gzip
import os.path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from Bio.Seq import Seq
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle
from tabulate import tabulate

from bioat.exceptions import (
    BioatFileFormatError,
    BioatInvalidInputError,
    BioatInvalidOptionError,
)
from bioat.lib.libalignment import instantiate_pairwise_aligner
from bioat.lib.libcolor import convert_hex_to_rgb, make_color_list, map_color
from bioat.lib.libcrispr import TARGET_SEQ_LIB, run_target_seq_align
from bioat.lib.libpandas import set_option
from bioat.logger import LoggerManager

lm = LoggerManager(mod_name="bioat.target_seq")

set_option(log_level='ERROR')


class TargetSeq:
    """Target Deep Sequencing toolbox.

    Groups the plotting commands that visualise base-level mutation tables
    produced by `bioat bam mpileup_to_table`.
    """

    # Runs once, when the class body is executed, to tag log records.
    lm.set_names(cls_name="TargetSeq")

    def __init__(self):
        pass

    def region_heatmap(
        self,
        input_table: str,
        output_fig: str,
        target_seq: str | None = None,  # sgRNA
        reference_seq: str | None = None,  # target locus
        input_table_header: bool = True,
        output_fig_fmt: str = 'pdf',
        output_fig_dpi: int = 100,
        show_indel: bool = True,
        show_index: bool = True,
        box_border: bool = False,
        box_space: float = 0.03,
        min_color: tuple = (250, 239, 230),
        max_color: tuple = (154, 104, 57),
        min_ratio: float = 0.001,
        max_ratio: float = 0.99,
        region_extend_length: int = 5,
        local_alignment_scoring_matrix: tuple = (5, -4, -24, -8),
        local_alignment_min_score: int = 15,
        PAM_priority_weight: float = 1.0,
        get_built_in_target_seq: bool = False,
        log_level: str = 'INFO'
    ):
        """Plot region mutation information from a `bioat bam mpileup_to_table` table.

        The target sequence is aligned against the reference sequence and its
        reverse complement; the better-scoring strand is kept. The aligned
        window (extended by `region_extend_length` on both sides) is then drawn
        as stacked panels: reference index, target sequence, reference
        sequence, per-base counts and per-base ratios, optionally with
        deletion/insertion panels.

        Args:
            input_table (str): Path to the table generated by
                `bioat bam mpileup_to_table`. It should contain base mutation
                information for a short genome region (no more than 1k nt).
            output_fig (str): Path to the output figure file.
            target_seq (str, optional): Target sequence to align against the
                reference sequence, or a key of the built-in target sequence
                library. Examples:
                    - 'GAGTCCGAGCAGAAGAAGAA^GGG^' for SpCas9-BE (PAM: ^GGG^).
                    - '^TTTA^GCCCCAATAATCCCCACATGTCA' for cpf1-BE (PAM: ^TTTA^).
                    - 'TGCTAGTAACCACGTTCTCCTGATCAAATATCACTCTCCTACTTACAGGA' for no PAM.
                When None, the reference sequence itself is used.
                Defaults to None.
            reference_seq (str, optional): Custom reference sequence that
                overrides the `ref_base` column of the table. Either a FASTA
                file path (plain or `.gz`; the first line is skipped as the
                header) or a DNA sequence. Defaults to None.
            input_table_header (bool, optional): Whether `input_table` has a
                header row. Defaults to True.
                NOTE(review): the method reads named columns (`ref_base`,
                `chr_index`, `A`, `G`, `C`, `T`, `del_count`, `insert_count`),
                so a headerless table presumably cannot work - confirm.
            output_fig_fmt (str, optional): Format of the output figure
                ("pdf" or "png"). Defaults to "pdf".
            output_fig_dpi (int, optional): DPI of the output figure.
                Defaults to 100.
            show_indel (bool, optional): Whether to add the Del/Ins count and
                ratio panels. Defaults to True.
            show_index (bool, optional): Accepted for interface compatibility;
                this method does not currently read it. Defaults to True.
            box_border (bool, optional): Whether to draw box borders.
                Defaults to False.
            box_space (float, optional): Horizontal space between two boxes.
                Defaults to 0.03.
            min_color (tuple, optional): RGB colour for the lowest ratio bin.
                Defaults to (250, 239, 230).
            max_color (tuple, optional): RGB colour for the highest ratio bin.
                Defaults to (154, 104, 57).
            min_ratio (float, optional): Lower bound of the colour scale.
                Defaults to 0.001.
            max_ratio (float, optional): Upper bound of the colour scale.
                Defaults to 0.99.
            region_extend_length (int, optional): Number of bases added on
                each side of the aligned window. Defaults to 5.
            local_alignment_scoring_matrix (tuple, optional): Alignment
                scoring parameters: (<align_match_score>,
                <align_mismatch_score>, <align_gap_open_score>,
                <align_gap_extension_score>). Defaults to (5, -4, -24, -8).
            local_alignment_min_score (int, optional): Minimum score that at
                least one strand must reach. Defaults to 15.
            PAM_priority_weight (float, optional): Weight stored with the
                parsed PAM. The PAM is parsed but is not passed to the aligner
                by this method. Defaults to 1.0.
            get_built_in_target_seq (bool, optional): When True, log the
                built-in target sequence library and exit with status 0.
                Defaults to False.
            log_level (str, optional): Logging level. One of 'CRITICAL',
                'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET'.
                Defaults to "INFO".

        Returns:
            None. The figure is written to `output_fig`.

        Raises:
            BioatFileFormatError: If the reference sequence still contains a
                '>' after the first FASTA line has been dropped (for example a
                multi-record FASTA file).
            SystemExit: Status 0 after listing the built-in target sequences;
                status 1 for an invalid `target_seq` type or when neither
                strand reaches `local_alignment_min_score`.

        Examples:
            Generate the mutation table (shell)::

                samtools mpileup test_sorted.bam --reference HK4-AOut-1.ref.upper.fa | gzip > test_sorted.mpileup.gz
                bioat bam mpileup_to_table test_sorted.mpileup.gz > test_sorted.mpileup.info.tsv

            Generate the region heatmap (shell)::

                bioat target_seq region_heatmap --input_table test_sorted.mpileup.info.tsv --output_fig test_sorted.mpileup.info.pdf
        """
        lm.set_names(func_name="region_heatmap")
        lm.set_log_level(log_level)

        if get_built_in_target_seq:
            lm.logger.info(
                "You can use <key> in built-in <target_seq> to represent your target_seq:\n"
                + "\t<key>\t<target_seq>\n"
                + "".join([f"\t{k}\t{v}\n" for k, v in TARGET_SEQ_LIB.items()])
            )
            lm.logger.info("exit because of the defination for get_built_in_target_seq")
            # SystemExit is raised directly: the `exit` helper is injected by
            # `site` and is not guaranteed to exist in every interpreter setup.
            raise SystemExit(0)

        # set sgRNA info
        lm.logger.debug(f"load target_seq = {target_seq}")
        if target_seq in TARGET_SEQ_LIB:
            lm.logger.debug(
                f"target_seq information is abtained from the <key>={target_seq}"
            )
            target_seq = TARGET_SEQ_LIB[target_seq]
            lm.logger.debug(f"target_seq is refered to {target_seq}")
        elif target_seq is None:
            pass
        elif isinstance(target_seq, str):
            lm.logger.debug(f"target_seq is refered to {target_seq}")
        else:
            lm.logger.error(
                "<target_seq> should be set as:"
                "DNA sequence / <key> in <get_built_in_target_seq> / None,"
                f"but your <target_seq> is {target_seq}"
            )
            raise SystemExit(1)

        # set alignment
        aligner = instantiate_pairwise_aligner(*local_alignment_scoring_matrix)

        # load mpileup.info.tsv
        if input_table_header:
            df_bases = pd.read_csv(input_table, sep="\t")
        else:
            df_bases = pd.read_csv(input_table, sep="\t", header=None)
        lm.logger.debug(
            "df_bases.head(10):\n"
            + tabulate(df_bases.head(10), headers="keys", tablefmt="psql")
        )

        # ref_seq & ref_seq_rc
        if reference_seq:
            # fa file or seq str to seq str
            if os.path.isfile(reference_seq):
                opener = gzip.open if reference_seq.endswith('.gz') else open
                # The context manager closes the handle even if reading fails.
                with opener(reference_seq, 'rt') as handle:
                    # Drop the first line (FASTA header) and join the rest.
                    reference_seq = ''.join(
                        [line.rstrip() for line in handle.readlines()[1:]]
                    )
            # A remaining '>' means more than one record (or a malformed raw
            # sequence); only a single sequence is supported.
            if reference_seq.count('>') > 0:
                raise BioatFileFormatError
            ref_seq = Seq(reference_seq)
            lm.logger.debug(f"load ref_seq from parameter:\n{ref_seq}")
        else:
            ref_seq = Seq("".join(df_bases.ref_base))
            lm.logger.debug(f"load ref_seq from df_bases.ref_base:\n{ref_seq}")

        lm.logger.debug(f"Seq object for ref_seq:\n{ref_seq}")
        ref_seq_rc = ref_seq.reverse_complement()
        lm.logger.debug(
            f"Seq object for ref_seq_rc (reverse_complement):\n{ref_seq_rc}"
        )

        # target_seq: split off an optional PAM written as ^PAM^ at either end
        if target_seq is None:
            target_seq = ref_seq
            PAM = None
        elif '^' in target_seq:
            if target_seq.find('^') == 0:
                # '^PAM^PROTOSPACER': PAM sits in front of the target
                PAM, target_seq = target_seq[1:].strip().split('^')
                target_seq = Seq(target_seq.upper())
                PAM = {'PAM': PAM, 'position': 0, 'weight': PAM_priority_weight}
            else:
                # 'PROTOSPACER^PAM^': PAM follows the target
                target_seq, PAM = target_seq[:-1].strip().split('^')
                target_seq = Seq(target_seq.upper())
                PAM = {'PAM': PAM, 'position': len(target_seq), 'weight': PAM_priority_weight}
        else:
            lm.logger.debug(
                f"no PAM sequence is defined from target_seq = {target_seq}"
            )
            target_seq = Seq(target_seq)
            PAM = None
        lm.logger.info(f"parse target_seq: target_seq={target_seq}, PAM={PAM}")

        # fwd alignment & rev alignment
        # Align against both strands to determine whether the target_seq
        # (sgRNA) was designed on the forward or the reverse strand.
        if not PAM:
            lm.logger.info("use PAMless mode to align")
        else:
            lm.logger.info("use PAM mode to align")
        # NOTE: both modes run the same alignment here; the parsed PAM is not
        # forwarded to run_target_seq_align by this method.
        fwd = run_target_seq_align(ref_seq, target_seq, aligner)
        lm.logger.debug(f"final_align_forward:\n{fwd}")
        rev = run_target_seq_align(ref_seq_rc, target_seq, aligner)
        lm.logger.debug(f"final_align_reverse:\n{rev}")

        def aligned_window(result):
            """Return (reference, aln_info, target) cut to the aligned span."""
            span = slice(result['ref_aln_start'], result['ref_aln_end'] + 1)
            alignment = result['alignment']
            return (
                alignment['reference_seq'][span],
                alignment['aln_info'][span],
                alignment['target_seq'][span],
            )

        # get fwd alignment info
        reference_seq, align_info, target_seq = aligned_window(fwd)
        aln_score = fwd['aln_score']
        lm.logger.info(
            f"Forward best alignment:\n"
            f"reference  : {reference_seq}\n"
            f"align_info : {align_info}\n"
            f"target_seq : {target_seq}\n"
            f"align_score: {aln_score}"
        )

        # get rev alignment info
        lm.logger.info(rev)
        reference_seq_rev, align_info_rev, target_seq_rev = aligned_window(rev)
        aln_score_rev = rev['aln_score']
        lm.logger.info(
            f"Reverse best alignment:\n"
            f"reference  : {reference_seq_rev}\n"
            f"align_info : {align_info_rev}\n"
            f"target_seq : {target_seq_rev}\n"
            f"align_score: {aln_score_rev}"
        )

        # define align direction and final align res
        if any(
            score >= local_alignment_min_score
            for score in (fwd['aln_score'], rev['aln_score'])
        ):
            if fwd['aln_score'] >= rev['aln_score']:
                aln_direction = "Forward Alignment"
                aln = fwd
            else:
                aln_direction = "Reverse Alignment"
                aln = rev
                reference_seq = reference_seq_rev
                target_seq = target_seq_rev
        else:
            lm.logger.critical("Alignment Error!")
            raise SystemExit(1)

        # mark aligned all bases in target_seq
        ref_seq_length = len(ref_seq)
        target_seq_aln = [""] * ref_seq_length
        # mark inserted bases in target_seq
        target_seq_aln_insert = [""] * ref_seq_length
        DNA_rev_cmp_dict = {
            "A": "T",
            "T": "A",
            "C": "G",
            "G": "C",
            "N": "N",
            "-": "-"
        }
        ref_gap_count = 0
        aln_start = aln['ref_aln_start']
        aln_end = aln['ref_aln_end']

        # update target_seq_aln
        lm.logger.debug(f"target_seq_aln (before update) = {target_seq_aln}")
        lm.logger.debug(
            f"target_seq_aln_insert (before update) = {target_seq_aln_insert}"
        )
        for idx, ref_base in enumerate(reference_seq):
            # reference  : GCCTCTGGAGAGGGAGGAGGG
            # align_info : |.|.|||..|..|||||.||-
            # target_seq : GGCACTGCGGCTGGAGGTGG-
            if ref_base != "-":
                # Shift left by the number of reference gaps seen so far so the
                # base lands on its reference coordinate.
                # NOTE(review): the previous code kept a buffer of inserted
                # bases but cleared it before it was ever read, so only the
                # aligned base itself is stored here - confirm whether the
                # inserted bases were meant to be prepended.
                target_seq_aln[aln_start + idx - ref_gap_count] = target_seq[idx]
            else:
                ref_gap_count += 1
                target_seq_aln_insert[aln_start + idx] = target_seq[idx]
            lm.logger.debug(
                "update target_seq_aln: show ''.join(target_seq_aln)\n"
                f"target_seq_aln = {''.join(target_seq_aln)}\n"
                f"target_seq     = {target_seq}"
            )
        lm.logger.debug("update target_seq_aln: Done")
        lm.logger.debug(f"target_seq_aln (after update) = {target_seq_aln}")
        lm.logger.debug(
            f"target_seq_aln_insert (after update) = {target_seq_aln_insert}"
        )

        # if reverse alignment
        if aln_direction == "Reverse Alignment":
            # illustrate reverse alignment as forward alignment
            target_seq_aln = target_seq_aln[::-1]
            target_seq_aln_insert = target_seq_aln_insert[::-1]
            lm.logger.debug("use reverse alignment")
            lm.logger.debug(f"target_seq_aln = {target_seq_aln}")
            lm.logger.debug(f"target_seq_aln_insert = {target_seq_aln_insert}")

        # set plot_region: aligned window plus the requested flanks, clamped to
        # the reference, e.g. aln_start=56, aln_end=76, extend=25 -> (31, 101)
        possible_target_region_start = max(aln_start - region_extend_length, 0)
        possible_target_region_end = min(aln_end + region_extend_length, ref_seq_length)
        plot_region = (possible_target_region_start, possible_target_region_end)
        df_bases_select = df_bases.iloc[plot_region[0]: plot_region[1], :]
        lm.logger.debug(
            "df_bases_select:\n"
            + tabulate(df_bases_select, headers="keys", tablefmt="psql")
        )

        # make plot
        matplotlib.use('Agg')  # file output only, no display needed
        np.set_printoptions(suppress=True)
        box_border_plot_state = box_border

        # set panel size
        panel_box_width = 0.4
        panel_box_heigth = 0.4
        panel_space = 0.05
        panel_box_space = box_space

        # color part
        base_colors = {"A": "#04E3E3", "T": "#F9B874", "C": "#B9E76B", "G": "#F53798", "N": "#AAAAAA", "-": "#AAAAAA"}

        # make color breaks between min_ratio and max_ratio
        color_break_num = 20
        break_step = 1.0 / color_break_num
        color_break = np.round(np.arange(min_ratio, max_ratio, break_step), 5)
        # gradient from min_color to max_color (RGB input, HEX output), with
        # white prepended as the first colour of the scale
        color_list = make_color_list(min_color, max_color, len(color_break) - 1, "HEX")
        color_list = ["#FFFFFF"] + color_list
        lm.logger.debug(f"color_list = {color_list}")

        # get plot info
        total_box_count = plot_region[1] - plot_region[0]

        # calculate base info and fix zero
        base_sum_count = df_bases_select.loc[:, ["A", "G", "C", "T"]].sum(axis=1).astype(int)
        lm.logger.debug(f"base_sum_count =\n{base_sum_count.values}")
        total_sum_count = df_bases_select[["A", "G", "C", "T", "del_count", "insert_count"]].sum(axis=1).astype(int)
        lm.logger.debug(f"total_sum_count =\n{total_sum_count.values}")
        # replace 0 by 1 so the ratio divisions below never divide by zero
        base_sum_count[base_sum_count == 0] = 1
        total_sum_count[total_sum_count == 0] = 1

        # make plot size
        lm.logger.debug(f"aln =\n{aln}")
        bases = ("A", "G", "C", "T")
        # Panel order: index, target, reference, counts, ratios.
        plot_data_list = [
            ["Ref_index", df_bases_select.chr_index],
            ["Target_seq", target_seq_aln[plot_region[0]: plot_region[1]]],
            ["Ref_seq", df_bases_select.ref_base],
        ]
        plot_data_list += [[base, np.array(df_bases_select[base])] for base in bases]
        if show_indel:
            plot_data_list += [
                ["Del", np.array(df_bases_select["del_count"])],
                ["Ins", np.array(df_bases_select["insert_count"])],
            ]
        plot_data_list += [
            [f"{base} %", list(df_bases_select[base] / base_sum_count)] for base in bases
        ]
        if show_indel:
            plot_data_list += [
                ["Del %", list(df_bases_select["del_count"] / total_sum_count)],
                ["Ins %", list(df_bases_select["insert_count"] / total_sum_count)],
            ]
        # The space list is one entry shorter than the panel list: nothing is
        # added below the last panel.
        if show_indel:
            panel_height_coef = [.5, .9, .9] + [.5] * 6 + [.5] * 6
            panel_space_coef = [1.] * 3 + [.3] * 3 + [1., .3, 1.] + [.3] * 3 + [1., .3]
        else:
            panel_height_coef = [.5, .9, .9] + [.5] * 4 + [.5] * 4
            panel_space_coef = [1.] * 3 + [0.3] * 3 + [1.] + [0.3] * 3

        # get box and space info
        box_height_list = np.array(panel_height_coef) * panel_box_heigth
        panel_space_list = np.array(panel_space_coef) * panel_space
        lm.logger.debug(f"box_height_list = {box_height_list}")
        lm.logger.debug(f"panel_space_list = {panel_space_list}")

        # calculate figure total width and height
        figure_width = total_box_count * panel_box_width + \
            (total_box_count - 1) * panel_box_space + panel_box_width * 2
        figure_height = sum(box_height_list) + sum(panel_space_list)
        lm.logger.debug(f"figure_width = {figure_width}")
        lm.logger.debug(f"figure_height = {figure_height}")

        # make all box_x: slot 0 holds the panel name, slot i + 1 holds column i
        box_x_vec = np.arange(0, figure_width + panel_box_width, panel_box_width + panel_box_space)
        box_x_vec = box_x_vec[:(len(ref_seq) + 1)]

        # make box border
        if box_border_plot_state:
            box_edgecolor = "#AAAAAA"
            box_linestyle = "-"
            box_linewidth = 2
        else:
            box_edgecolor = "#FFFFFF"
            box_linestyle = "None"
            box_linewidth = 0
        lm.logger.debug(f"box_border_plot_state = {box_border_plot_state}")
        lm.logger.debug(f"box_edgecolor = {box_edgecolor}")
        lm.logger.debug(f"box_linestyle = {box_linestyle}")
        lm.logger.debug(f"box_linewidth = {box_linewidth}")

        # make box_y initialize
        current_y = 0
        lm.logger.debug("start to plot")
        fig = plt.figure(figsize=(figure_width * 1.1, figure_height * 1.1))
        lm.logger.debug(
            f"set new figure, figsize = ({figure_width * 1.1}, {figure_height * 1.1})"
        )
        # keep matplotlib's own debug records out of the log
        plt.set_loglevel("info")
        ax = fig.add_subplot(111, aspect="equal")
        plt.xlim([0, figure_width])
        plt.ylim([-figure_height, 0])
        plt.axis("off")
        lm.logger.debug(f"plt.xlim([0, {figure_width}])")
        lm.logger.debug(f"plt.ylim([-{figure_height}, 0])")
        lm.logger.debug('plt.axis("off")')

        # make plot
        text_list = []
        patches = []

        def add_label(column, top_y, height, label, fontsize):
            """Queue a text centred in the box of `column` on the current row."""
            text_list.append(
                (
                    box_x_vec[column + 1] + 0.5 * panel_box_width,
                    top_y - 0.5 * height,
                    label,
                    fontsize
                )
            )

        def add_box(column, top_y, height, facecolor, fill):
            """Queue the rectangle of `column` on the current row."""
            patches.append(
                Rectangle(
                    xy=(box_x_vec[column + 1], top_y - height),
                    width=panel_box_width,
                    height=height,
                    fill=fill,
                    alpha=1,
                    linestyle=box_linestyle,
                    linewidth=box_linewidth,
                    edgecolor=box_edgecolor,
                    facecolor=facecolor
                )
            )

        for panel_index, (panel_name, panel_values) in enumerate(plot_data_list):
            row_height = box_height_list[panel_index]

            # panel name goes into the first x slot
            text_list.append(
                (box_x_vec[0], current_y - row_height * 0.5, panel_name, 10)
            )

            if panel_name == "Ref_index":
                # don't draw box, only add text
                for index, box_value in enumerate(panel_values):
                    add_label(index, current_y, row_height, str(box_value), 7)

            elif panel_name in ["Target_seq", "Ref_seq"]:
                for index, box_value in enumerate(panel_values):
                    if box_value == "":
                        # position not covered by the alignment: empty box
                        box_fill = False
                        box_color = "#FFFFFF"
                    else:
                        if "Reverse" in aln_direction:
                            if panel_name == "Ref_seq":
                                # Show the complementary base for reverse
                                # alignments; symbols missing from the table
                                # are kept as they are instead of failing.
                                box_value = "".join(
                                    [DNA_rev_cmp_dict.get(x, x) for x in box_value]
                                )
                            box_color = base_colors.get(box_value[0])
                        else:
                            box_color = base_colors.get(box_value[-1])
                        if not box_color:
                            # symbol without a colour: leave the box unfilled
                            box_fill = False
                            box_color = "#FFFFFF"
                        else:
                            box_fill = True
                    add_box(index, current_y, row_height, box_color, box_fill)
                    add_label(index, current_y, row_height, str(box_value), 16)

            elif panel_name in ["A", "G", "C", "T", "Del", "Ins"]:
                # count panels: colour by ratio, label with the raw count
                if panel_name in ["Del", "Ins"]:
                    box_ratio = panel_values / total_sum_count
                else:
                    box_ratio = panel_values / base_sum_count
                box_color_list = map_color(box_ratio, color_break, color_list)
                lm.logger.debug(f"get box_color_list for {panel_name}")
                for index, box_value in enumerate(panel_values):
                    add_box(index, current_y, row_height, box_color_list[index], True)
                    add_label(index, current_y, row_height, str(box_value), 6)

            else:
                # ratio panels: values are already ratios, label as percent
                box_color_list = map_color(panel_values, color_break, color_list)
                for index, box_value in enumerate(panel_values):
                    add_box(index, current_y, row_height, box_color_list[index], True)
                    add_label(index, current_y, row_height, round(box_value * 100, 4), 6)

            # make next panel_y (the last panel has no space entry)
            if panel_index < len(panel_space_list):
                current_y = current_y - (row_height + panel_space_list[panel_index])

        # plot box
        lm.logger.debug("plot rectangles")
        ax.add_collection(PatchCollection(patches, match_original=True))
        lm.logger.debug("plot text on each rectangle")
        for text_x, text_y, text_info, text_fontsize in text_list:
            plt.text(
                x=text_x,
                y=text_y,
                s=text_info,
                horizontalalignment='center',
                verticalalignment='center',
                fontsize=text_fontsize,
                fontname="Arial"
            )

        # output plot
        fig.savefig(fname=output_fig, bbox_inches='tight', dpi=output_fig_dpi, format=output_fig_fmt)
        # pyplot keeps every figure alive until it is closed; release it so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
[docs] def region_heatmap_compare( self, input_tables: str, labels: str | None = None, target_seq: str | None = None, # sgRNA reference_seq: str | None = None, # target locus output_fig_heatmap: str | None = None, output_fig_count_ratio: str | None = None, output_table_heatmap: str | None = None, output_table_count_ratio: str | None = None, output_fig_fmt: str = "pdf", input_table_header: bool = True, to_base: tuple = ("A", "G", "C", "T", "Ins", "Del"), heatmap_mut_direction: tuple = ("CT", "GA"), count_ratio="all", region_extend_length: int = 5, output_fig_dpi: int = 100, show_indel: bool = True, show_index: bool = True, block_ref: bool = True, box_border: bool = False, box_space: float = 0.03, min_color: tuple = (250, 239, 230), max_color: tuple = (154, 104, 57), min_ratio: float = 0.001, max_ratio: float = 0.99, local_alignment_scoring_matrix: tuple = (5, -4, -24, -8), local_alignment_min_score: int = 15, PAM_priority_weight: float = 1.0, get_built_in_target_seq: bool = False, log_level: str = "INFO", ): """Plot region mutation information for multiple conditions. This function generates a comparison of mutation information across multiple conditions using tables generated by `bioat bam mpileup_to_table`. Args: input_tables (str): Paths to input tables generated by `bioat bam mpileup_to_table`, separated by commas. Each table should contain mutation information for a short genomic region (≤1k nt). labels (str): Labels for the panels, separated by commas. target_seq (str, optional): Sequence to align against the reference sequence in `mpileup.table`. Examples: - 'GAGTCCGAGCAGAAGAAGAA^GGG^' for SpCas9-BE (PAM: ^GGG^). - '^TTTA^GCCCCAATAATCCCCACATGTCA' for cpf1-BE (PAM: ^TTTA^). - 'TGCTAGTAACCACGTTCTCCTGATCAAATATCACTCTCCTACTTACAGGA' for no PAM. Defaults to None. reference_seq (str, optional): Custom reference sequence to overwrite the one in `mpileup.table`. Can be a FASTA file or a DNA sequence. Defaults to None. 
output_fig_heatmap (str, optional): Path to the heatmap output figure. Defaults to None. output_fig_count_ratio (str, optional): Path to the count/ratio output figure. Defaults to None. output_table_heatmap (str, optional): Path to the heatmap output table. Defaults to None. output_table_count_ratio (str, optional): Path to the count/ratio output table. Defaults to None. output_fig_fmt (str, optional): Format of the output figures, either "pdf" or "png". Defaults to "pdf". input_table_header (bool, optional): Whether the input tables have headers. Defaults to True. to_base (str, optional): Reference bases to convert to, separated by commas. Defaults to "A,G,C,T,Ins,Del". heatmap_mut_direction (str, optional): Mutation directions to plot, specified as [from Base][to Base], separated by commas. Defaults to "CT,GA". count_ratio (str, optional): Type of plot to generate: "count", "ratio", or "all". Defaults to "all". region_extend_length (int, optional): Number of base pairs to extend on both sides of the region. Defaults to 0. output_fig_dpi (int, optional): DPI for the output figures. Defaults to 300. show_indel (bool, optional): Whether to show indel information in the output figures. Defaults to True. show_index (bool, optional): Whether to display index information in the output figures. Defaults to True. block_ref (bool, optional): Whether to hide colors for reference sites in the output figures. Defaults to True. box_border (bool, optional): Whether to display box borders in the output figures. Defaults to True. box_space (int, optional): Space size between two boxes in the heatmap. Defaults to 1. min_color (tuple, optional): Minimum color for the heatmap in RGB format. Defaults to (255, 255, 255). max_color (tuple, optional): Maximum color for the heatmap in RGB format. Defaults to (0, 0, 0). min_ratio (float, optional): Mutation ratios below this value will appear white. Defaults to 0.0. 
max_ratio (float, optional): Mutation ratios above this value will be capped. Defaults to 1.0. local_alignment_scoring_matrix (tuple, optional): Alignment scoring parameters as a tuple: (<align_match_score>, <align_mismatch_score>, <align_gap_open_score>, <align_gap_extension_score>). Defaults to None. local_alignment_min_score (int, optional): Minimum alignment score to consider as a valid alignment. Defaults to 0. PAM_priority_weight (float, optional): Weight multiplier for PAM alignment scores. Defaults to 1.0. get_built_in_target_seq (bool, optional): Set to True to return built-in target sequence information. Defaults to False. log_level (str, optional): Logging level. One of 'CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET'. Defaults to "INFO". Returns: None Examples: >>> # Generate mutation tables >>> samtools mpileup test_sorted.bam --reference HK4-AOut-1.ref.upper.fa | gzip > test_sorted.mpileup.gz >>> bioat bam mpileup_to_table test_sorted.mpileup.gz > test_sorted.mpileup.info1.tsv >>> bioat bam mpileup_to_table test_sorted.mpileup.gz > test_sorted.mpileup.info2.tsv >>> bioat bam mpileup_to_table test_sorted.mpileup.gz > test_sorted.mpileup.info3.tsv >>> # Generate region heatmap comparison >>> bioat target_seq region_heatmap_compare --input_tables test_sorted.mpileup.info1.tsv,test_sorted.mpileup.info2.tsv,test_sorted.mpileup.info3.tsv --labels condition1,condition2,condition3 --target_seq HEK4 """ lm.set_names(func_name="region_heatmap_compare") lm.set_level(log_level) base_color_dict = {"A": "#04E3E3", "T": "#F9B874", "C": "#B9E76B", "G": "#F53798", "N": "#DDEAF6"} def plot_agct(x): if x == '': # target_seq sgRNA左右没align的位置给空值 return '#FFFFFF' elif x in ['A', 'G', 'C', 'T', 'N']: return base_color_dict[x] else: return '#AAAAAA' # check whether return built-in information if get_built_in_target_seq: lm.logger.info( "You can use <key> in built-in <target_seq> to represent your target_seq:\n" + "\t<key>\t<target_seq>\n" + "".join([f"\t{k}\t{v}\n" 
for k, v in TARGET_SEQ_LIB.items()]) ) lm.logger.info("exit because of the defination for get_built_in_target_seq") exit(0) else: # set target_seq(sgRNA) info lm.logger.debug(f"load target_seq = {target_seq}") if target_seq in TARGET_SEQ_LIB: lm.logger.debug( f"target_seq information is abtained from the <key>={target_seq}" ) target_seq = TARGET_SEQ_LIB[target_seq] lm.logger.debug(f"target_seq is refered to {target_seq}") elif target_seq is None: pass elif isinstance(target_seq, str): lm.logger.debug(f"target_seq is refered to {target_seq}") else: lm.logger.error( "<target_seq> should be set as:" "DNA sequence / <key> in <get_built_in_target_seq> / None," f"but your <target_seq> is {target_seq}" ) exit(1) # check if nothing to output outputs = (output_fig_heatmap, output_fig_count_ratio, output_table_heatmap, output_table_count_ratio) if not [i for i in outputs if i is not None]: lm.logger.error( "outputs should be set as free combination of (output_fig_heatmap, output_fig_count_ratio, " "output_table_heatmap, output_table_count_ratio), but none of outputs was defined" ) exit(1) # check parameter: to_base # check to_base: tuple for A G C T Ins Del or combine like A G T Ins default_bases = ('A', 'G', 'C', 'T', 'Ins', 'Del') lm.logger.debug("check to_base parameter") for base in to_base: if base not in default_bases: lm.logger.error("check to_base parameter") lm.logger.error(f"{base} not in {default_bases}") raise BioatInvalidOptionError lm.logger.debug(f"to_base was defined as {to_base}") # check parameter: count_ratio default_count_ratio_choices = ('count', 'ratio', 'all') lm.logger.debug("check count_ratio parameter") if count_ratio not in default_count_ratio_choices: lm.logger.error("check count_ratio parameter") lm.logger.error(f"{count_ratio} not in {default_count_ratio_choices}") raise BioatInvalidOptionError lm.logger.debug(f"count_ratio was defined as {count_ratio}") # parse parameter: input_tables if isinstance(input_tables, tuple): input_tables = 
[i.strip() for i in input_tables] elif isinstance(input_tables, str): input_tables = input_tables.replace(' ', '').replace('\t', '').strip().split(',') else: raise BioatInvalidInputError lm.logger.debug(f"parse tables... {input_tables}") # parse parameter: labels, if labels is None, labels value will be input_tables names if isinstance(labels, tuple): labels = [i.strip() for i in labels] elif isinstance(labels, str): labels = labels.replace(' ', '').replace('\t', '').strip().split(',') if labels else input_tables else: raise BioatInvalidInputError lm.logger.debug(f"parse labels... {labels}") # set alignment aligner = instantiate_pairwise_aligner(*local_alignment_scoring_matrix) # load & merge bmat files & set treatment names ls = [] for input_table, label in zip(input_tables, labels): df = pd.read_csv(input_table, sep="\t") if input_table_header \ else pd.read_csv(input_table, sep="\t", header=None) df['label'] = label ls.append(df) df_bases = pd.concat(ls) lm.logger.debug( "df_bases.head(10):\n" + tabulate(df_bases.head(8), headers="keys", tablefmt="psql") ) # parse parameter: reference_seq, if reference_seq is None, reference_seq value will be derived from df_bases if reference_seq: # fa file or seq str to seq str if os.path.isfile(reference_seq): f = open(reference_seq, 'rt') if not reference_seq.endswith('.gz') else gzip.open(reference_seq, 'rt') reference_seq = ''.join([i.rstrip() for i in f.readlines()[1:]]) if reference_seq.count('>') > 0: raise BioatFileFormatError ref_seq = Seq(reference_seq) lm.logger.debug(f"load ref_seq from parameter:\n{ref_seq}") else: ref_seq = Seq("".join(df_bases.query('label==@labels[0]').ref_base)) lm.logger.debug(f"load ref_seq from df_bases.ref_base:\n{ref_seq}") lm.logger.debug(f"Seq object for ref_seq:\n{ref_seq}") ref_seq_rc = ref_seq.reverse_complement() lm.logger.debug( f"Seq object for ref_seq_rc (reverse_complement):\n{ref_seq_rc}" ) # parse parameter: target_seq if target_seq is None: target_seq = ref_seq PAM = None 
elif '^' in target_seq: if target_seq.find('^') == 0: PAM, target_seq = target_seq[1:].strip().split('^') target_seq = Seq(target_seq.upper()) PAM = {'PAM': PAM, 'position': 0, 'weight': PAM_priority_weight} else: target_seq, PAM = target_seq[:-1].strip().split('^') target_seq = Seq(target_seq.upper()) PAM = {'PAM': PAM, 'position': len(target_seq), 'weight': PAM_priority_weight} else: lm.logger.debug( f"no PAM sequence is defined from target_seq = {target_seq}" ) target_seq = Seq(target_seq) PAM = None lm.logger.info(f"parse target_seq: target_seq={target_seq}, PAM={PAM}") # fwd alignment & rev alignment # 为了判断target_seq (sgRNA) 是设计在fwd还是rev链上 if not PAM: lm.logger.info("use PAMless mode to align") fwd = run_target_seq_align(ref_seq, target_seq, aligner) lm.logger.debug(f"final_align_forward:\n{fwd}") rev = run_target_seq_align(ref_seq_rc, target_seq, aligner) lm.logger.debug(f"final_align_reverse:\n{rev}") else: lm.logger.info("use PAM mode to align") fwd = run_target_seq_align(ref_seq, target_seq, aligner, PAM=PAM) lm.logger.debug(f"final_align_forward:\n{fwd}") rev = run_target_seq_align(ref_seq_rc, target_seq, aligner, PAM=PAM) lm.logger.debug(f"final_align_reverse:\n{rev}") # get fwd alignment info reference_seq = fwd['alignment']['reference_seq'][fwd['ref_aln_start']: fwd['ref_aln_end'] + 1] align_info = fwd['alignment']['aln_info'][fwd['ref_aln_start']: fwd['ref_aln_end'] + 1] target_seq = fwd['alignment']['target_seq'][fwd['ref_aln_start']: fwd['ref_aln_end'] + 1] aln_score = fwd['aln_score'] lm.logger.info( f"Forward best alignment:\n" f"reference : {reference_seq}\n" f"align_info : {align_info}\n" f"target_seq : {target_seq}\n" f"align_score: {aln_score}" ) # get rev alignment info reference_seq_rev = rev['alignment']['reference_seq'][rev['ref_aln_start']: rev['ref_aln_end'] + 1] align_info_rev = rev['alignment']['aln_info'][rev['ref_aln_start']: rev['ref_aln_end'] + 1] target_seq_rev = rev['alignment']['target_seq'][rev['ref_aln_start']: 
rev['ref_aln_end'] + 1] aln_score_rev = rev['aln_score'] lm.logger.info( f"Reverse best alignment:\n" f"reference : {reference_seq_rev}\n" f"align_info : {align_info_rev}\n" f"target_seq : {target_seq_rev}\n" f"align_score: {aln_score_rev}" ) # define align direction and final align res # make alignment info if any([x >= local_alignment_min_score for x in (fwd['aln_score'], rev['aln_score'])]): if fwd['aln_score'] >= rev['aln_score']: aln_direction = "Forward Alignment" aln = fwd else: aln_direction = "Reverse Alignment" aln = rev reference_seq = reference_seq_rev align_info = align_info_rev target_seq = target_seq_rev aln_score = aln_score_rev else: lm.logger.critical("Alignment Error!") exit(1) # mark aligned all bases in target_seq ref_seq_length = len(ref_seq) target_seq_aln = [""] * ref_seq_length # mark inserted bases in target_seq target_seq_aln_insert = [""] * ref_seq_length DNA_rev_cmp_dict = { "A": "T", "T": "A", "C": "G", "G": "C", "N": "N", "-": "-" } reference_gap_count = reference_seq.count("-") target_seq_gap_count = target_seq.count("-") ref_gap_count = 0 ref_del_str = "" aln_start = aln['ref_aln_start'] aln_end = aln['ref_aln_end'] # update target_seq_aln lm.logger.debug(f"target_seq_aln (before update) = {target_seq_aln}") lm.logger.debug( f"target_seq_aln_insert (before update) = {target_seq_aln_insert}" ) for idx, ref_base in enumerate(reference_seq): # reference : GCCTCTGGAGAGGGAGGAGGG # align_info : |.|.|||..|..|||||.||- # target_seq : GGCACTGCGGCTGGAGGTGG- if ref_base != "-": ref_gap_count += 0 ref_del_str = "" target_seq_aln[aln_start + idx - ref_gap_count] = ref_del_str + target_seq[idx] else: ref_gap_count += 1 ref_del_str += target_seq[idx] # for continuous gap target_seq_aln_insert[aln_start + idx] = target_seq[idx] lm.logger.debug( "update target_seq_aln: show ''.join(target_seq_aln)\n" f"target_seq_aln = {''.join(target_seq_aln)}\n" f"target_seq = {target_seq}" ) lm.logger.debug("update target_seq_aln: Done") 
lm.logger.debug(f"target_seq_aln (after update) = {target_seq_aln}") lm.logger.debug( f"target_seq_aln_insert (after update) = {target_seq_aln_insert}" ) # if reverse alignment if aln_direction == "Reverse Alignment": # illustrate reverse alignment as forward alignment target_seq_aln = target_seq_aln[::-1] target_seq_aln_insert = target_seq_aln_insert[::-1] lm.logger.debug("use reverse alignment") lm.logger.debug(f"target_seq_aln = {target_seq_aln}") lm.logger.debug(f"target_seq_aln_insert = {target_seq_aln_insert}") possible_target_region_start = max(aln_start - region_extend_length, 0) possible_target_region_end = min(aln_end + region_extend_length, ref_seq_length) # 266: 0~255 266 - 1 plot_region = (possible_target_region_start, possible_target_region_end) df_bases_select = df_bases.iloc[plot_region[0]: plot_region[1], :] lm.logger.debug( "df_bases_select:\n" + tabulate(df_bases_select, headers="keys", tablefmt="psql") ) # ---------------------------------------------------------------->>>>> # load .bmat file # ---------------------------------------------------------------->>>>> label_panel = labels ls_bmat = input_tables ls_bmat_table = [pd.read_csv(path_bmat, sep='\t') for path_bmat in ls_bmat] for index, bmat in enumerate(ls_bmat_table): bmat['label'] = label_panel[index] bmat.drop('chr_name', axis=1, inplace=True) df_tmp = ls_bmat_table[0] bmat_table = ls_bmat_table[0] for df_bmat in ls_bmat_table[1:]: suffixes = (f'_{df_tmp.iloc[0, -1]}', f'_{df_bmat.iloc[0, -1]}') df_tmp = pd.merge(df_tmp, df_bmat, on='chr_index', how='outer', suffixes=suffixes) ls_columns = ['chr_index'] for label in label_panel: str_tmp = ' '.join( ['ref_base_', 'A_', 'G_', 'C_', 'T_', 'del_count_', 'insert_count_', 'ambiguous_count_', 'deletion_', 'insertion_', 'ambiguous_', 'mut_num_', 'label_ ']) ls_tmp = str_tmp.replace('_ ', '_{label} ').format(label=label).strip().split(' ') ls_columns.extend(ls_tmp) df_tmp.columns = ls_columns # print(f'df_tmp = \n{df_tmp}') # define ref_seq as 
the referencing sequence if len(ls_bmat_table) == 1: ref_seq = "".join(bmat_table['ref_base_' + label_panel[0]].tolist()) elif len(ls_bmat_table) > 1: ref_seq = "".join(bmat_table.ref_base) else: raise ValueError('ref info gets wrong!') df_bmat_all = df_tmp.copy() # print(f'df_bmat_all = \n{df_bmat_all}') # print(f'ref_seq = \n{ref_seq}') # make alignment info sgRNA_align = [""] * len(ref_seq) sgRNA_align_insert = [""] * len(ref_seq) possible_target_region_start = max(aln_start - region_extend_length, 0) possible_target_region_end = min(aln_end + region_extend_length, ref_seq_length) # 266: 0~255 266 - 1 plot_region = (possible_target_region_start, possible_target_region_end) df_bases_select = df_bmat_all.iloc[plot_region[0]: plot_region[1], :] lm.logger.debug( "df_bases_select:\n" + tabulate(df_bases_select, headers="keys", tablefmt="psql") ) # make plot # set color # show indel matplotlib.use('Agg') np.set_printoptions(suppress=True) indel_plot_state = show_indel index_plot_state = show_index box_border_plot_state = box_border # set panel size panel_box_width = 0.4 panel_box_heigth = 0.4 panel_space = 0.05 panel_box_space = box_space # color part base_colors = {"A": "#04E3E3", "T": "#F9B874", "C": "#B9E76B", "G": "#F53798", "N": "#AAAAAA", "-": "#AAAAAA"} # make color breaks color_break_num = 20 break_step = 1.0 / color_break_num # mutation ratio of box lower than min_ratio will show as white min_color_value = min_ratio # mutation ratio of box lower than max_ratio will show as white max_color_value = max_ratio color_break = np.round(np.arange(min_color_value, max_color_value, break_step), 5) # min_color: min color to plot heatmap with RGB format # max_color: max color to plot heatmap with RGB format if isinstance(min_color, tuple): low_color = min_color else: raise BioatInvalidOptionError if isinstance(max_color, tuple): high_color = max_color else: raise BioatInvalidOptionError try: color_list = make_color_list(low_color, high_color, len(color_break) - 1, "Hex") 
color_list = ["#FFFFFF"] + color_list except: lm.logger.debug(low_color, high_color) lm.logger.debug(color_break) # get plot info total_box_count = plot_region[1] - plot_region[0] # calculate base info and fix zero ls_base_sum_count = [] ls_total_sum_count = [] for label in label_panel: base_sum_count = df_bases_select[ ["A_%s" % label, "G_%s" % label, "C_%s" % label, "T_%s" % label]].apply( lambda x: x.sum(), axis=1) # print base_sum_count total_sum_count = df_bases_select[["A_%s" % label, "G_%s" % label, "C_%s" % label, "T_%s" % label, "del_count_%s" % label, "insert_count_%s" % label]].apply( lambda x: x.sum(), axis=1) base_sum_count[base_sum_count == 0] = 1 base_sum_count = base_sum_count + 1 total_sum_count = total_sum_count + 1 ls_base_sum_count.append(base_sum_count) ls_total_sum_count.append(total_sum_count) # make plot size plot_data_list = None panel_space_coef = None panel_height_coef = None if count_ratio == 'all': panel_height_coef = [0.5, 0.9, 0.9] + [0.5] * len(ls_bmat_table) * len(default_bases) * 2 else: panel_height_coef = [0.5, 0.9, 0.9] + [0.5] * len(ls_bmat_table) * len(default_bases) # 根据bmat表格个数来控制方格高度 # panel_space_coef = [1, 1, 1] + [0.3] * 3 + [1, 0.3, 1] + [0.3] * 3 + [1, 0.3] if count_ratio == 'all': panel_space_coef = [1, 1, 1] + ([0.3] * (len(ls_bmat_table) - 1) + [4]) * len(default_bases) * 2 else: panel_space_coef = [1, 1, 1] + ([0.3] * (len(ls_bmat_table) - 1) + [4]) * len(default_bases) # 更正heatmap align错误 # plot_heatmap_index = bmat_table_select.chr_index plot_data_list = [ ["Ref_index", df_bases_select.chr_index], ["Target_seq", np.array(target_seq_aln[plot_region[0]: plot_region[1]])], ["Ref_seq", df_bases_select['ref_base_%s' % label_panel[0]]] ] if count_ratio == 'count' or count_ratio == 'all': for to_base in default_bases: if to_base == 'A': for label in label_panel: plot_data_list.append( ["{label}: to A".format(label=label), np.array(df_bases_select["A_{label}".format(label=label)])] ) if to_base == 'G': for label in 
label_panel: plot_data_list.append( ["{label}: to G".format(label=label), np.array(df_bases_select["G_{label}".format(label=label)])] ) if to_base == 'C': for label in label_panel: plot_data_list.append( ["{label}: to C".format(label=label), np.array(df_bases_select["C_{label}".format(label=label)])] ) if to_base == 'T': for label in label_panel: plot_data_list.append( ["{label}: to T".format(label=label), np.array(df_bases_select["T_{label}".format(label=label)])] ) if to_base == 'Del': for label in label_panel: plot_data_list.append( ["{label}: to Del".format(label=label), np.array(df_bases_select["del_count_{label}".format(label=label)])] ) if to_base == 'Ins': for label in label_panel: plot_data_list.append( ["{label}: to Ins".format(label=label), np.array(df_bases_select["insert_count_{label}".format(label=label)])] ) if count_ratio == 'ratio' or count_ratio == 'all': for to_base in default_bases: if to_base == 'A': for index, label in enumerate(label_panel): plot_data_list.append( ["{label}: to A(%)".format(label=label), np.array( df_bases_select["A_{label}".format(label=label)] / ls_base_sum_count[index])] ) if to_base == 'G': for index, label in enumerate(label_panel): plot_data_list.append( ["{label}: to G(%)".format(label=label), np.array( df_bases_select["G_{label}".format(label=label)] / ls_base_sum_count[index])] ) if to_base == 'C': for index, label in enumerate(label_panel): plot_data_list.append( ["{label}: to C(%)".format(label=label), np.array( df_bases_select["C_{label}".format(label=label)] / ls_base_sum_count[index])] ) if to_base == 'T': for index, label in enumerate(label_panel): plot_data_list.append( ["{label}: to T(%)".format(label=label), np.array( df_bases_select["T_{label}".format(label=label)] / ls_base_sum_count[index])] ) if to_base == 'Del': for index, label in enumerate(label_panel): plot_data_list.append( ["{label}: to Del(%)".format(label=label), np.array( df_bases_select["del_count_{label}".format(label=label)] / 
ls_total_sum_count[ index])] ) if to_base == 'Ins': for index, label in enumerate(label_panel): plot_data_list.append( ["{label}: to Ins(%)".format(label=label), np.array( df_bases_select["insert_count_{label}".format(label=label)] / ls_total_sum_count[ index])] ) # get box and space info box_height_list = np.array(panel_height_coef) * panel_box_heigth panel_space_list = np.array(panel_space_coef) * panel_space # for i, j in enumerate(panel_space_coef): # print i,j # print panel_space # calculate figure total width and height figure_width = total_box_count * panel_box_width + ( total_box_count - 1) * panel_box_space + panel_box_width * 2 figure_height = sum(box_height_list) + sum(panel_space_list) lm.logger.debug(f"figure_width = {figure_width}") lm.logger.debug(f"figure_height = {figure_height}") # make all box_x box_x_vec = np.arange(0, figure_width + panel_box_width, panel_box_width + panel_box_space) box_x_vec = box_x_vec[:(len(ref_seq) + 1)] # make box border if box_border_plot_state: box_edgecolor = "#AAAAAA" box_linestyle = "-" box_linewidth = 2 else: box_edgecolor = "#FFFFFF" box_linestyle = "None" box_linewidth = 0 # make box_y initialize current_y = 0 df_matrix = pd.DataFrame([[''] * len(plot_data_list[0][1])] * len(panel_height_coef)) ls_row_name = [] for panel_index in range(len(panel_height_coef)): ls_row_name.append(plot_data_list[panel_index][0]) df_matrix.iloc[panel_index, :] = df_matrix.iloc[panel_index, :].astype(np.str_) for index, box_value in enumerate(plot_data_list[panel_index][1]): row = panel_index col = index df_matrix.iloc[row, col] = box_value # print(f'box_value = {box_value}, row = {row}, col = {col}') # print(f'df_matrix.iloc[row, col] = {df_matrix.iloc[row, col]}') # if box_value == '': # raise ValueError df_matrix.index = ls_row_name # plot heatmap if output_fig_heatmap: # 预设参数 num_extend = region_extend_length # rename columns of df_matrix df_matrix.columns = range(1, df_matrix.shape[1] + 1) dt_base = {'A': [], 'G': [], 'C': [], 
'T': []} for info in heatmap_mut_direction: dt_base[info[0]] += info[1] lm.logger.debug("[mut_direction] for heatmap: %s" % dt_base) lm.logger.debug("[label_panel] for heatmap: %s" % label_panel) lm.logger.debug( "[region_extend_length] for heatmap: %s" % region_extend_length ) lm.logger.debug("[block_ref] for Target-seq multiplot: %s" % block_ref) ls_col_not_null = df_matrix.columns[-(df_matrix.loc['Target_seq', :] == '')].tolist() ls_col_all = df_matrix.columns.tolist() ls_plot_range = list(range(ls_col_not_null[0] - num_extend, ls_col_not_null[-1] + num_extend)) try: df_matrix = df_matrix.loc[:, ls_plot_range] # 有不在的就不为空,不为空则判断成立 # print 'testTTTTTTTTTTTTTTT' except: raise ValueError('The param [num_extend] is too large or <0, it must be a integer>=0') ls_on_target = df_matrix.loc['Target_seq', :].tolist() ls_ref = df_matrix.loc['Ref_seq', :].tolist() ls_ratio = [] for label in label_panel: df_sample = df_matrix[[label == i.split(': to')[0] and '(%)' not in i for i in df_matrix.index]] df_sample.index = ['A', 'G', 'C', 'T', 'Ins', 'Del'] ls_A = df_sample.T['A'].isnull().tolist() ls_G = df_sample.T['G'].isnull().tolist() ls_C = df_sample.T['C'].isnull().tolist() ls_T = df_sample.T['T'].isnull().tolist() ls_bool = [] for i in range(len(ls_A)): if ls_A[i] & ls_C[i] & ls_G[i] & ls_T[i]: ls_bool.append(False) else: ls_bool.append(True) ls_bl_select_df_sample = ls_bool df_sample = df_sample.loc[:, ls_bl_select_df_sample] df_sample = df_sample.loc[['A', 'G', 'C', 'T'], :].T.copy() df_sample = df_sample.astype(int) df_sample['sum'] = df_sample['A'] + df_sample['G'] + df_sample['C'] + df_sample['T'] df_sample['Ref_seq'] = np.array(ls_ref)[ls_bl_select_df_sample] df_sample['To_base'] = df_sample['Ref_seq'].map(lambda x: dt_base[x]) df_sample['Mut_count'] = 0 for recode in df_sample.iterrows(): index_row = recode[0] for to_base in recode[1]['To_base']: df_sample.loc[index_row, 'Mut_count'] += recode[1][to_base] df_sample['Mut_ratio'] = df_sample['Mut_count'] / 
df_sample['sum'] ls_sample_ratio = df_sample['Mut_ratio'].tolist() ls_ratio.append(ls_sample_ratio) df_ratio_all = pd.DataFrame(ls_ratio).T df_ratio_all.columns = label_panel try: lm.logger.debug("Catch NA: {}".format(df_ratio_all.isna().sum().sum())) lm.logger.debug(df_ratio_all.head()) # print(f'df_matrix.T = \n{df_matrix.T}') df_ratio_all.index = df_matrix.T['Ref_index'][ls_bl_select_df_sample].map(float).map(int).tolist() df_ratio_all['Target_seq'] = df_matrix.T['Target_seq'][ls_bl_select_df_sample].tolist() df_ratio_all['Ref_seq'] = df_matrix.T['Ref_seq'][ls_bl_select_df_sample].tolist() # print(f'df_ratio_all.T = \n{df_ratio_all.T}') # exit() except ValueError: start_idx = df_matrix.T['Ref_index'][ls_bl_select_df_sample].map(float).map(int).tolist()[0] end_idx = start_idx + len(df_ratio_all.index.tolist()) df_ratio_all.index = np.arange(start_idx, end_idx) df_ratio_all['Target_seq'] = df_matrix.T['Target_seq'].tolist() df_ratio_all['Ref_seq'] = df_matrix.T['Ref_seq'].tolist() df_ratio_all['On-Target'] = '' df_ratio_all['Reference'] = '' df_ratio_all = df_ratio_all[['Target_seq', 'Ref_seq', 'On-Target', 'Reference'] + label_panel].T.copy() df_onTarget_Ref = df_ratio_all.iloc[:2, :].fillna(' ') lm.logger.debug(f"df_onTarget_Ref: \n{df_onTarget_Ref}") # exit() # print(df_onTarget_Ref == '') df_onTarget_Ref_color = df_ratio_all.iloc[:2, :].fillna('').map(plot_agct) lm.logger.debug(f"df_onTarget_Ref_color: \n{df_onTarget_Ref_color}") # exit() lm.logger.debug(f"df_ratio_all: \n{df_ratio_all}") # exit() df_plot = df_ratio_all.iloc[2:, :].copy() lm.logger.debug(f"df_plot: \n{df_plot}") # print(f'df_plot: \n{df_plot}') # exit() df_plot.iloc[2:, :] = df_plot.iloc[2:, :].astype(float) * 100 ls_max = [] for i in df_plot.iloc[2:, :].values.tolist(): ls_max.extend(i) try: ls_break = list(np.arange(0, max(ls_max), max(ls_max) / 100)) except ZeroDivisionError: ls_break = [0, 0.003, 0.006, 0.009, 0.012, 0.015, 0.018, 0.022, 0.026, 0.030, 0.035, 0.040, 0.05, 0.06, 0.07, 0.08, 
1.00, 2.00, 3.00] ls_color_middle = make_color_list( low_color_RGB=convert_hex_to_rgb('#87BDDB'), high_color_RGB=convert_hex_to_rgb('#0A306A'), length_out=80, return_fmt="Hex") ls_color_top = ['#0A306A'] * 20 ls_color_bottom = ['#EFEFEF'] * 5 + ['#DDEAF6'] * 5 + ['#A9CEE4'] * 5 + ['#87BDDB'] * 5 ls_color = ls_color_bottom + ls_color_middle + ls_color_top lm.logger.debug(f"ls_color: {ls_color}") # exit() df_onTarget_Ref_tmp = df_onTarget_Ref.T df_onTarget_Ref_tmp.loc[df_onTarget_Ref_tmp['Target_seq'].isnull(), 'Target_seq'] = pd.NA df_onTarget_Ref = df_onTarget_Ref_tmp.T df_plot_rec = df_plot.copy() lm.logger.debug( "df_plot_rec:\n" + tabulate(df_plot_rec, headers="keys", tablefmt="psql") ) # heatmap颜色,去map_hex_for_matrix函数中调整 def map_hex_for_matrix(x): # 判断value范围并返回颜色的Hex值 value = -1 for value in ls_break: if x < value: break else: continue return ls_color[ls_break.index(value) - 1] lm.logger.debug(f"df_plot_rec: \n{df_plot_rec}") # print(f'df_plot_rec: \n{df_plot_rec}') # exit() df_plot_rec.iloc[2:, :] = df_plot_rec.iloc[2:, :].map(map_hex_for_matrix) df_plot_rec_cmap = df_plot_rec.copy() # print(df_plot_rec_cmap) # exit() # print(f'df_plot_rec = \n{df_plot_rec}') # print(df_plot_rec.loc['On-Target', :]) # print(df_plot_rec.loc['On-Target', :].dtype) # df_matrix = pd.DataFrame([[''] * len(plot_data_list[0][1])] * len(panel_height_coef)) # print(f"df_onTarget_Ref.loc['Target_seq', :] = \n{df_onTarget_Ref.loc['Target_seq', :]}") # print('df_onTarget_Ref', df_onTarget_Ref.loc['Target_seq', :] == '') # print(f"df_plot_rec.loc['On-Target', :] = \n{df_plot_rec.loc['On-Target', :]}") # print('df_plot_rec', df_plot_rec.loc['On-Target', :] == '') # exit() df_plot_rec.loc['On-Target', :] = df_onTarget_Ref.loc['Target_seq', :] # print(df_plot_rec.loc['On-Target', :]) # print(df_onTarget_Ref.loc['Target_seq', :]) df_plot_rec.loc['Reference', :] = df_onTarget_Ref.loc['Ref_seq', :] df_plot_rec_cmap.loc['On-Target', :] = df_onTarget_Ref.loc['Target_seq', :].map(plot_agct) 
df_plot_rec_cmap.loc['Reference', :] = df_onTarget_Ref.loc['Ref_seq', :].map(plot_agct) figure_width_heatmap = df_plot_rec.shape[1] figure_height_heatmap = max(df_plot_rec.shape[0], len(ls_break) * 0.8) scale = 1.1 # 1.1 fig_heatmap = plt.figure(figsize=(figure_width_heatmap * scale, figure_height_heatmap * scale)) ax = fig_heatmap.add_subplot(111, aspect="equal") plt.xlim([0, figure_width_heatmap + 5]) plt.ylim([-figure_height_heatmap - 0.5, 0]) plt.axis("off") # export heatmap values lm.logger.debug("export heatmap reference table...") lm.logger.debug(f"df_plot_rec: \n{df_plot_rec.index}") # exit() for row in range(df_plot_rec.shape[0]): row_name = df_plot_rec.index[row] # add panel name ax.text( x=-1, y=0.55 - row - 1.1, s=row_name, horizontalalignment='right', verticalalignment='center', fontsize=34 / 1.1 * scale, fontname="DejaVu Sans", alpha=1, ) # add Target_seq and ref if row_name in ['On-Target', 'Reference']: for col in range(df_plot.shape[1]): # plot Rectangle site_x = col site_y = - row + 0.05 - 1.05 ax.add_patch(matplotlib.patches.Rectangle( (site_x, site_y), width=1, height=1, linestyle='-', fill=True, facecolor=df_plot_rec_cmap.iloc[row, col], edgecolor='#AAAAAA' if df_plot_rec_cmap.iloc[row, col] != '#FFFFFF' else '#FFFFFF', linewidth=3.5 )) # plot text site_x = col site_y = -row - 0.55 ax.text( x=site_x + 0.5, y=site_y, s=df_plot_rec.iloc[row, col], horizontalalignment='center', verticalalignment='center', fontsize=34 / 1 * scale, fontname="DejaVu Sans", alpha=1 ) # add seq panels else: for col in range(df_plot.shape[1]): # plot Rectangle site_x = col site_y = -row - 1.05 ax.add_patch(matplotlib.patches.Rectangle( (site_x, site_y), width=1, height=1, linestyle='-', fill=True, facecolor=df_plot_rec_cmap.iloc[row, col], edgecolor='#FFFFFF', linewidth=3 )) # add cbar # start from this site site_x = df_plot.shape[1] + 1 site_y = -figure_height_heatmap * 0.1 - 1 step_scale = 0.1 lm.logger.debug("[cbar_scale]: %s" % step_scale) for color in ls_color: 
site_y += step_scale ax.add_patch(matplotlib.patches.Rectangle( (site_x, site_y), width=1, height=step_scale, fill=True, facecolor=color, edgecolor=color, )) # add cbar text # start from this site site_x = df_plot.shape[1] + 1 site_y = -figure_height_heatmap * 0.1 for idx, label in enumerate(ls_break): site_y += step_scale if idx % 10 == 0: ax.text( x=site_x + 1.1, y=site_y - 1, s=round(label, 4), horizontalalignment='left', verticalalignment='center', fontsize=34 / 1.5, # fontsize=34 / 1.5 * step_scale, fontname="DejaVu Sans", alpha=1 ) fig_heatmap.savefig(output_fig_heatmap, bbox_inches='tight') # 减少边缘空白 # plt.show() if output_table_heatmap: if output_table_heatmap.endswith('.csv'): df_plot_rec.to_csv(output_table_heatmap, sep=',') elif output_table_heatmap.endswith('.tsv'): df_plot_rec.to_csv(output_table_heatmap, sep='\t') if output_fig_count_ratio: # set new figure fig = plt.figure(figsize=(figure_width * 1.1, figure_height * 1.1)) ax = fig.add_subplot(111, aspect="equal") plt.xlim([0, figure_width]) plt.ylim([-figure_height, 0]) plt.axis("off") text_list = [] patches = [] # print panel_height_coef for panel_index in range(len(panel_height_coef)): # print 'panel index:', panel_index # panel name panel_name = plot_data_list[panel_index][0] panel_name_x = box_x_vec[0] # print panel_name # print panel_name_x panel_name_y = current_y - box_height_list[panel_index] * 0.5 # print panel_name_y text_list.append((panel_name_x, panel_name_y, panel_name, 10)) # plot panel box # plot Index行 if panel_name == "Ref_index": # don't draw box, only add text for index, box_value in enumerate(plot_data_list[panel_index][1]): box_x = box_x_vec[index + 1] text_list.append( (box_x + panel_box_width * 0.5, current_y - box_height_list[panel_index] * 0.5, str(box_value), 6)) # make next panel_y current_y = current_y - (box_height_list[panel_index] + panel_space_list[panel_index]) # plot Target_seq和Ref_seq行 elif panel_name in ["Target_seq", "Ref_seq"]: # if panel_name = Ref, form a new 
list to store plot info for checking the block_ref param if panel_name == 'Ref_seq': plot_data_list_ref = plot_data_list[panel_index][1] for index, box_value in enumerate(plot_data_list[panel_index][1]): # print(index, box_value) if box_value == "": box_fill = False box_color = "#FFFFFF" else: if aln_direction == "Reverse Alignment": if panel_name == "Ref_seq": box_value = "".join([DNA_rev_cmp_dict.get(x) for x in box_value]) else: ### fix bug for x in box_value: box_value = DNA_rev_cmp_dict.get(x) box_color = base_color_dict.get(box_value[0]) else: box_color = base_color_dict.get(box_value[-1]) if not box_color: box_fill = False box_color = "#FFFFFF" else: box_fill = True box_x = box_x_vec[index + 1] patches.append(Rectangle( xy=(box_x, current_y - box_height_list[panel_index]), width=panel_box_width, height=box_height_list[panel_index], fill=box_fill, alpha=1, linestyle=box_linestyle, linewidth=box_linewidth, edgecolor=box_edgecolor, facecolor=box_color) ) # text text_list.append( (box_x + 0.5 * panel_box_width, current_y - 0.5 * box_height_list[panel_index], str(box_value), 16)) # make next panel_y current_y = current_y - (box_height_list[panel_index] + panel_space_list[panel_index]) # 用来plot Sample行的 # 这里plot counts elif (panel_name.split('to ')[-1] in default_bases): # print 'panel name:', panel_name.split('to ')[-1].replace('(%)', '') # print panel_name, 'check panel name' # 改进了这里total 和 base的值的引用 if count_ratio == 'all': ls_base_sum_count_all = ls_base_sum_count * 2 * len(default_bases) ls_total_sum_count_all = ls_total_sum_count * 2 * len(default_bases) total_sum_count = ls_total_sum_count_all[panel_index - 3] base_sum_count = ls_base_sum_count_all[panel_index - 3] else: ls_base_sum_count_all = ls_base_sum_count * len(default_bases) ls_total_sum_count_all = ls_total_sum_count * len(default_bases) total_sum_count = ls_total_sum_count_all[panel_index - 3] base_sum_count = ls_base_sum_count_all[panel_index - 3] panel_name = panel_name.split('to ')[-1] if 
panel_name in ["Del", "Ins"]: box_ratio = plot_data_list[panel_index][1] / total_sum_count else: box_ratio = plot_data_list[panel_index][1] / base_sum_count box_color_list = map_color(box_ratio, color_break, color_list) ls_base_sum = base_sum_count.tolist() for index, box_value in enumerate(plot_data_list[panel_index][1]): # if count_ratio == 'all': box_color = box_color_list[index] # check block_ref state: count plot if block_ref: if plot_data_list_ref.iloc[index] == panel_name: # if box_value >= ls_base_sum[index] * 0.8: # box_fill = False box_color = "#FFFFFF" box_x = box_x_vec[index + 1] patches.append(Rectangle( xy=(box_x, current_y - box_height_list[panel_index]), width=panel_box_width, height=box_height_list[panel_index], fill=True, alpha=1, linestyle=box_linestyle, linewidth=box_linewidth, edgecolor=box_edgecolor, facecolor=box_color) ) # text text_list.append( (box_x + 0.5 * panel_box_width, current_y - 0.5 * box_height_list[panel_index], str(box_value), 6)) # print box_value # make next panel_y current_y = current_y - (box_height_list[panel_index] + panel_space_list[panel_index]) # 这里plot ratio(剩下的都是base(%)) else: box_color_list = map_color(plot_data_list[panel_index][1], color_break, color_list) # print box_color_list panel_name = panel_name.split('to ')[-1].replace('(%)', '') for index, box_value in enumerate(plot_data_list[panel_index][1]): # ratio box_color = box_color_list[index] # check block_ref state: ratio plot if block_ref: if plot_data_list_ref.iloc[index] == panel_name: # if box_value >= 0.85: # box_fill = False box_color = "#FFFFFF" box_x = box_x_vec[index + 1] patches.append(Rectangle( xy=(box_x, current_y - box_height_list[panel_index]), width=panel_box_width, height=box_height_list[panel_index], fill=True, alpha=1, linestyle=box_linestyle, linewidth=box_linewidth, edgecolor=box_edgecolor, facecolor=box_color) ) if '(%)' in plot_data_list[panel_index][0]: # text # ref # print 'test', box_value # if ARGS.block_ref == True: # if box_value >= 
0.99: # box_value = 0 text_list.append( (box_x + 0.5 * panel_box_width, current_y - 0.5 * box_height_list[panel_index], round(box_value * 100, 4), 6)) else: text_list.append( (box_x + 0.5 * panel_box_width, current_y - 0.5 * box_height_list[panel_index], box_value, 6)) # make next panel_y if panel_index < len(panel_space_list): current_y = current_y - (box_height_list[panel_index] + panel_space_list[panel_index]) # plot box ax.add_collection(PatchCollection(patches, match_original=True)) # add text for text_x, text_y, text_info, text_fontsize in text_list: if " to " in str(text_info): plt.text( x=text_x + 0.3, y=text_y, s=text_info, horizontalalignment='right', verticalalignment='center', fontsize=text_fontsize, fontname="DejaVu Sans" ) else: plt.text( x=text_x, y=text_y, s=text_info, horizontalalignment='center', verticalalignment='center', fontsize=text_fontsize, fontname="DejaVu Sans" ) # output plot if output_fig_fmt.upper == "PNG": fig.savefig(fname=output_fig_count_ratio, bbox_inches='tight', dpi=output_fig_dpi, format="png") else: fig.savefig(fname=output_fig_count_ratio, bbox_inches='tight', dpi=output_fig_dpi, format="pdf") if output_table_count_ratio: if output_table_count_ratio.endswith('.csv'): df_bases_select.to_csv(output_table_count_ratio, sep=',') elif output_table_count_ratio.endswith('.tsv'): df_bases_select.to_csv(output_table_count_ratio, sep='\t')
# Script entry guard: intentionally a no-op, so executing this module directly
# as a script performs no additional work beyond the module-level definitions.
if __name__ == '__main__': pass