|
import pandas as pd |
|
import numpy as np |
|
import pickle |
|
from sklearn.manifold import TSNE |
|
import matplotlib.font_manager as fm |
|
from matplotlib.font_manager import FontProperties |
|
import matplotlib.pyplot as plt |
|
import matplotlib.gridspec as gridspec |
|
import matplotlib.patches as patches |
|
import seaborn as sns |
|
import umap |
|
import os |
|
|
|
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark |
|
import fuson_plm.benchmarking.embedding_exploration.config as config |
|
from fuson_plm.utils.visualizing import set_font |
|
from fuson_plm.utils.constants import TCGA_CODES, FODB_CODES, VALID_AAS, DELIMITERS |
|
from fuson_plm.utils.logging import get_local_time, open_logfile, log_update, print_configpy |
|
|
|
|
|
def get_dimred_embeddings(embeddings, dimred_type="umap"):
    """
    Project embeddings down to two dimensions.

    Parameters:
    - embeddings: sequence of equal-length embedding vectors (stacked with np.array
      by the chosen method).
    - dimred_type (str): "umap" or "tsne".

    Returns:
    - np.ndarray with one row per input embedding and two columns.

    Raises:
    - ValueError: if dimred_type is not a supported method (instead of silently
      returning None, which callers would not notice until much later).
    """
    if dimred_type == "umap":
        return get_umap_embeddings(embeddings)
    if dimred_type == "tsne":
        return get_tsne_embeddings(embeddings)
    raise ValueError(
        f"Unsupported dimred_type {dimred_type!r}; expected 'umap' or 'tsne'"
    )
|
|
|
def get_tsne_embeddings(embeddings, perplexity=5, random_state=42):
    """
    Reduce embeddings to 2D with scikit-learn's t-SNE.

    Parameters:
    - embeddings: sequence of equal-length vectors; stacked with np.array.
    - perplexity (float): t-SNE perplexity. Defaults to 5, the value previously
      hard-coded. scikit-learn requires it to be smaller than the number of samples.
    - random_state (int): seed for a reproducible layout (default 42).

    Returns:
    - np.ndarray of shape (n_samples, 2).
    """
    embeddings = np.array(embeddings)
    tsne = TSNE(n_components=2, random_state=random_state, perplexity=perplexity)
    return tsne.fit_transform(embeddings)
|
|
|
def get_umap_embeddings(embeddings, n_neighbors=15, metric='euclidean', random_state=42):
    """
    Reduce embeddings to 2D with UMAP.

    Parameters:
    - embeddings: sequence of equal-length vectors; stacked with np.array.
    - n_neighbors (int): UMAP neighborhood size (default 15, previously hard-coded).
    - metric (str): distance metric passed to umap.UMAP (default 'euclidean').
    - random_state (int): seed for a reproducible layout (default 42).

    Returns:
    - np.ndarray of shape (n_samples, 2).
    """
    embeddings = np.array(embeddings)
    umap_model = umap.UMAP(
        n_components=2,
        random_state=random_state,
        n_neighbors=n_neighbors,
        metric=metric,
    )
    return umap_model.fit_transform(embeddings)
|
|
|
def plot_half_filled_circle(ax, x, y, left_color, right_color, size=100):
    """
    Plots a circle filled in halves with specified colors.

    The two halves are matplotlib Wedge patches drawn in data coordinates, so
    the marker's apparent roundness depends on the axes' aspect ratio.

    Parameters:
    - ax: Matplotlib axis to draw on.
    - x, y: Coordinates of the marker (data units).
    - left_color: Color of the left half (angles 90 to 270 degrees).
    - right_color: Color of the right half (angles 270 to 90 degrees).
    - size: Size of the marker; the radius in data units is sqrt(size) / 100.
    """
    radius = (size ** 0.5) / 100

    # Use facecolor (not color): when `color` is given together with an edge
    # color, matplotlib lets `color` override the edge color (and warns), so the
    # intended black outline was never drawn.
    left_half = patches.Wedge(
        (x, y), radius, 90, 270, facecolor=left_color, edgecolor="black"
    )
    right_half = patches.Wedge(
        (x, y), radius, 270, 90, facecolor=right_color, edgecolor="black"
    )

    ax.add_patch(left_half)
    ax.add_patch(right_half)
|
|
|
def plot_umap_scatter_tftf_kk(df, filename="umap.png", dimred_type=None):
    """
    Plots a 2D scatterplot of reduced coordinates, colored by fusion type.

    Only for TF::TF and Kinase::Kinase fusions.

    Parameters:
    - df (pd.DataFrame): must contain a 'fusion_type' column ("TF::TF" or
      "Kinase::Kinase") and the coordinate columns '<dimred_type>1' and
      '<dimred_type>2' (e.g. 'umap1'/'umap2').
    - filename (str): where to save the figure (300 dpi).
    - dimred_type (str | None): "umap" or "tsne". When None it is inferred from
      df's columns: 'umap' if 'umap1' is present, else 'tsne' if 'tsne1' is.
      Previously 'umap1'/'umap2' were hard-coded, so t-SNE data (which only has
      'tsne1'/'tsne2') raised KeyError.

    Raises:
    - KeyError: if a fusion_type is not one of the two supported values, or the
      coordinate columns are missing.
    """
    set_font()

    if dimred_type is None:
        if "umap1" in df.columns:
            dimred_type = "umap"
        elif "tsne1" in df.columns:
            dimred_type = "tsne"
        else:
            dimred_type = "umap"  # fall through to the original KeyError
    x_col, y_col = f"{dimred_type}1", f"{dimred_type}2"
    axis_name = {"umap": "UMAP", "tsne": "t-SNE"}.get(dimred_type, dimred_type)

    colors = {
        "TF": "pink",
        "Kinase": "orange",
    }
    marker_colors = {
        "TF::TF": colors["TF"],
        "Kinase::Kinase": colors["Kinase"],
    }

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.set_xlim(df[x_col].min() - 1, df[x_col].max() + 1)
    ax.set_ylim(df[y_col].min() - 1, df[y_col].max() + 1)

    # One vectorized scatter call instead of one call per row; indexing the
    # dict keeps the KeyError for unexpected fusion types.
    point_colors = [marker_colors[fusion_type] for fusion_type in df["fusion_type"]]
    ax.scatter(
        df[x_col], df[y_col],
        color=point_colors, s=15, edgecolors="black", linewidth=0.5,
    )

    legend_elements = [
        patches.Patch(facecolor=marker_colors[label], edgecolor="black", label=label)
        for label in ("TF::TF", "Kinase::Kinase")
    ]
    ax.legend(handles=legend_elements, title="Fusion Type", fontsize=16, title_fontsize=16)

    ax.set_xlabel(f"{axis_name} 1", fontsize=20)
    ax.set_ylabel(f"{axis_name} 2", fontsize=20)
    ax.set_title("FusOn-pLM-embedded Transcription Factor and Kinase Fusions", fontsize=20)
    fig.tight_layout()

    fig.savefig(filename, dpi=300)
    plt.show()
    # This may be called repeatedly (once per embedding file and dimred type);
    # close the figure so open figures do not accumulate.
    plt.close(fig)
|
|
|
def plot_umap_scatter_half_filled(df, filename="umap.png"):
    """
    Plots a 2D UMAP scatterplot where each fusion is a half-filled circle.

    The left half is colored by the head category and the right half by the
    tail category (TF = pink, Kinase = orange, Other = grey).

    Parameters:
    - df (pd.DataFrame): DataFrame containing 'umap1', 'umap2' and 'fusion_type'
      columns; 'fusion_type' is "<head>::<tail>" with each part one of
      "TF", "Kinase" or "Other".
    - filename (str): where to save the figure (300 dpi).

    Raises:
    - KeyError: if a fusion_type is not one of the nine supported combinations.
    """
    colors = {
        "TF": "pink",
        "Kinase": "orange",
        "Other": "grey",
    }

    # Every "<head>::<tail>" combination of the three categories.
    marker_colors = {
        f"{head}::{tail}": {"left": colors[head], "right": colors[tail]}
        for head in colors
        for tail in colors
    }

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.set_xlim(df["umap1"].min() - 1, df["umap1"].max() + 1)
    ax.set_ylim(df["umap2"].min() - 1, df["umap2"].max() + 1)

    for x, y, fusion_type in zip(df["umap1"], df["umap2"], df["fusion_type"]):
        halves = marker_colors[fusion_type]
        plot_half_filled_circle(ax, x, y, halves["left"], halves["right"], size=100)

    legend_elements = [
        patches.Patch(facecolor=color, edgecolor="black", label=label)
        for label, color in colors.items()
    ]
    ax.legend(handles=legend_elements, title="Type")

    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")
    ax.set_title("UMAP Scatter Plot")
    fig.tight_layout()

    fig.savefig(filename, dpi=300)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
|
|
|
def get_gene_type(gene, d):
    """
    Classify a gene as 'Kinase', 'TF' or 'Other' using an annotation dict.

    Parameters:
    - gene: gene symbol to look up.
    - d (dict): maps gene symbol -> annotation; the values 'kinase' and 'tf'
      are recognized.

    Returns:
    - 'Kinase' if d[gene] == 'kinase', 'TF' if d[gene] == 'tf', otherwise
      'Other'. This includes genes that are missing from d and genes with any
      other annotation; the latter previously fell through and returned None,
      which breaks string concatenation of the types downstream.
    """
    annotation = d.get(gene) if gene in d else None
    if annotation == 'kinase':
        return 'Kinase'
    if annotation == 'tf':
        return 'TF'
    return 'Other'
|
|
|
def get_tf_and_kinase_fusions_dataset():
    """
    Build the table of TF::TF and Kinase::Kinase fusions to embed and plot.

    Reads gene annotations ('Gene' -> 'Kinase or TF') from
    data/salokas_2020_tableS3.csv and fusions from
    ../../data/blast/fuson_ht_db.csv. Each fusion gets the label
    "<head type>::<tail type>" (via get_gene_type), and only TF::TF and
    Kinase::Kinase fusions are kept.

    Returns:
    - pd.DataFrame with columns 'sequence' (from 'aa_seq'), 'ID' (from
      'fusiongenes'), 'fusion_type' and 'type' (always "fusion"). TF::TF rows
      come first, then Kinase::Kinase rows.
    """
    tf_kinase_parts = pd.read_csv("data/salokas_2020_tableS3.csv")
    print(tf_kinase_parts)
    ht_tf_kinase_dict = dict(zip(tf_kinase_parts['Gene'], tf_kinase_parts['Kinase or TF']))

    fuson_ht_db = pd.read_csv("../../data/blast/fuson_ht_db.csv")
    # n=1: split only on the first "::" so a name containing more than one
    # separator cannot yield extra columns and break this 2-column assignment.
    fuson_ht_db[['hg', 'tg']] = fuson_ht_db['fusiongenes'].str.split("::", n=1, expand=True)

    fuson_ht_db['hg_type'] = fuson_ht_db['hg'].apply(lambda gene: get_gene_type(gene, ht_tf_kinase_dict))
    fuson_ht_db['tg_type'] = fuson_ht_db['tg'].apply(lambda gene: get_gene_type(gene, ht_tf_kinase_dict))
    fuson_ht_db['fusion_type'] = fuson_ht_db['hg_type'] + '::' + fuson_ht_db['tg_type']
    fuson_ht_db['type'] = 'fusion'

    categories = ["TF::TF", "Kinase::Kinase"]
    print(categories)

    # Concatenate per category (rather than one isin() filter) so the rows stay
    # grouped in `categories` order.
    plot_df = pd.concat(
        [fuson_ht_db.loc[fuson_ht_db['fusion_type'] == cat] for cat in categories],
        axis=0,
    ).reset_index(drop=True)

    print(plot_df['fusion_type'].value_counts())

    plot_df = plot_df[['aa_seq', 'fusiongenes', 'fusion_type', 'type']].rename(
        columns={'aa_seq': 'sequence', 'fusiongenes': 'ID'}
    )

    return plot_df
|
|
|
def make_tf_and_kinase_fusions_plot(seqs_with_embeddings, savedir = '', dimred_type='umap'):
    """
    Save the source data and the scatterplot for the TF::TF / Kinase::Kinase fusions.

    Parameters:
    - seqs_with_embeddings (pd.DataFrame): must contain '<dimred_type>1',
      '<dimred_type>2', 'sequence', 'fusion_type' and 'ID' columns. It is not modified.
    - savedir (str): output directory, created if missing. An empty string means
      the current directory; previously it made os.makedirs('') raise.
    - dimred_type (str): prefix of the coordinate columns, e.g. "umap" or "tsne".

    Writes '<savedir>/<dimred_type>_tf_and_kinase_fusions_source_data.csv' (with a
    'seq_id' column looked up from ../../data/fuson_db.csv) and
    '<savedir>/<dimred_type>_tf_and_kinase_fusions_visualization.png'.
    """
    fuson_db = pd.read_csv("../../data/fuson_db.csv")
    seq_id_dict = dict(zip(fuson_db['aa_seq'], fuson_db['seq_id']))

    # .copy() so the new 'seq_id' column goes on an independent frame, not on a
    # slice of the caller's DataFrame (avoids SettingWithCopyWarning).
    data = seqs_with_embeddings[
        [f'{dimred_type}1', f'{dimred_type}2', 'sequence', 'fusion_type', 'ID']
    ].copy()
    data['seq_id'] = data['sequence'].map(seq_id_dict)

    tfkinase_save_dir = f"{savedir}" or "."
    os.makedirs(tfkinase_save_dir, exist_ok=True)
    data.to_csv(f"{tfkinase_save_dir}/{dimred_type}_tf_and_kinase_fusions_source_data.csv", index=False)
    plot_umap_scatter_tftf_kk(
        data,
        filename=f"{tfkinase_save_dir}/{dimred_type}_tf_and_kinase_fusions_visualization.png",
    )
|
|
|
def tf_and_kinase_fusions_plot(dimred_types, output_dir):
    """
    Makes the embeddings, THEN calls the plot, for the TF::TF and Kinase::Kinase fusions.

    Steps:
    1. Build the fusion table (get_tf_and_kinase_fusions_dataset) and save it to
       data/tf_and_kinase_fusions.csv.
    2. Embed its sequences with embed_dataset_for_benchmark (FusOn-pLM checkpoint(s)
       from config.FUSON_PLM_CKPT, average=True).
    3. For each embedding file and each requested dimred type, reduce to 2D and save
       the plot and source data under
       <output_dir>/<dimred_type>_plots/<path below 'embeddings/'>/tf_and_kinase.

    Parameters:
    - dimred_types (list[str]): reduction methods, e.g. ["umap", "tsne"].
    - output_dir (str): root results directory.

    Raises:
    - Exception: if an embeddings pickle cannot be read (original error chained).
    """
    plot_df = get_tf_and_kinase_fusions_dataset()
    plot_df.to_csv("data/tf_and_kinase_fusions.csv", index=False)

    input_fname = 'tf_and_kinase'
    all_embedding_paths = embed_dataset_for_benchmark(
        fuson_ckpts=config.FUSON_PLM_CKPT,
        input_data_path='data/tf_and_kinase_fusions.csv', input_fname=input_fname,
        average=True, seq_col='sequence',
        benchmark_fusonplm=True,
        benchmark_esm=False,
        benchmark_fo_puncta_ml=False,
        overwrite=config.PERMISSION_TO_OVERWRITE)

    log_update("\nEmbedding sequences")

    for embedding_path, _details in all_embedding_paths.items():
        log_update(f"\tBenchmarking embeddings at: {embedding_path}")
        try:
            with open(embedding_path, "rb") as f:
                # NOTE: pickle is only safe on trusted, locally generated files.
                embeddings = pickle.load(f)
        except Exception as err:
            # Not a bare except: KeyboardInterrupt/SystemExit propagate untouched,
            # and the underlying error is kept as __cause__.
            raise Exception(f"Cannot read embeddings from {embedding_path}") from err

        # The pickled dict maps sequence -> embedding.
        seqs_with_embeddings = pd.DataFrame(list(embeddings.items()), columns=['sequence', 'embedding'])
        seqs_with_embeddings = pd.merge(seqs_with_embeddings, plot_df, on='sequence', how='inner')

        for dimred_type in dimred_types:
            dimred_embeddings = get_dimred_embeddings(
                seqs_with_embeddings['embedding'].tolist(), dimred_type=dimred_type
            )
            coords = pd.DataFrame(dimred_embeddings, columns=[f'{dimred_type}1', f'{dimred_type}2'])
            # Index-aligned assignment; both frames should carry a fresh RangeIndex
            # (merge output and a newly built frame), so rows line up positionally.
            seqs_with_embeddings[[f'{dimred_type}1', f'{dimred_type}2']] = coords

            # Mirror the embedding file's directory layout below 'embeddings/'.
            intermediate = '/'.join(embedding_path.split('embeddings/')[1].split('/')[0:-1])
            cur_output_dir = f"{output_dir}/{dimred_type}_plots/{intermediate}/{input_fname}"
            os.makedirs(cur_output_dir, exist_ok=True)
            make_tf_and_kinase_fusions_plot(seqs_with_embeddings, savedir=cur_output_dir, dimred_type=dimred_type)
|
|
|
def _label_fusion_parts(df):
    """Return a copy of df with a 'Type' column, plus (fusion, head, tail) gene triples.

    Rows whose 'gene' contains "::" are fusions. Genes named as a fusion's head
    or tail get 'h_embeddings' / 't_embeddings'. The tail label is applied last,
    so it wins when a gene is both a head and a tail.
    """
    df = df.copy()
    is_fusion = df["gene"].str.contains("::")
    parts = df.loc[is_fusion, "gene"].str.split("::", expand=True)
    heads = parts[0].tolist()
    tails = parts[1].tolist()

    df["Type"] = ""
    df.loc[is_fusion, "Type"] = "fusion_embeddings"
    df.loc[df["gene"].isin(heads), "Type"] = "h_embeddings"
    df.loc[df["gene"].isin(tails), "Type"] = "t_embeddings"

    triples = list(zip(df.loc[is_fusion, "gene"].tolist(), heads, tails))
    return df, triples


def _favorites_axis_window(df, dimred_type):
    """Return (x_min, x_max, y_min, y_max, x_ticks, y_ticks) shared by all panels."""
    if dimred_type == 'umap':
        # Fixed window hard-coded for the UMAP favorites figure (unchanged).
        x_min, x_max, y_min, y_max = 11, 16, 10, 22
        x_step = y_step = 1
    else:
        # Other projections (e.g. t-SNE) live on a different scale, where the
        # fixed UMAP window would clip points; derive an integer window from data.
        x_vals, y_vals = df[f'{dimred_type}1'], df[f'{dimred_type}2']
        x_min, x_max = int(np.floor(x_vals.min())) - 1, int(np.ceil(x_vals.max())) + 1
        y_min, y_max = int(np.floor(y_vals.min())) - 1, int(np.ceil(y_vals.max())) + 1
        # Keep roughly ten ticks or fewer per axis.
        x_step = max(1, int(np.ceil((x_max - x_min) / 10)))
        y_step = max(1, int(np.ceil((y_max - y_min) / 10)))
    x_ticks = np.arange(x_min, x_max + 1, x_step)
    y_ticks = np.arange(y_min, y_max + 1, y_step)
    return x_min, x_max, y_min, y_max, x_ticks, y_ticks


def make_fusion_v_parts_favorites_plot(seqs_with_embeddings, savedir = None, dimred_type='umap'):
    """
    Make plots showing that PAX3::FOXO1, EWS::FLI1, SS18::SSX1, EML4::ALK are embedded distinctly from their heads and tails.

    Joins the embeddings with data/top_genes.csv on 'sequence'. It then draws a
    2x3 grid with one panel per fusion in that table, showing the fusion and its
    head and tail genes. Panels beyond the number of fusions are hidden.

    Parameters:
    - seqs_with_embeddings (pd.DataFrame): must contain 'sequence' and the
      '<dimred_type>1' / '<dimred_type>2' coordinate columns. It is not modified.
    - savedir (str): directory for the outputs.
    - dimred_type (str): "umap" or "tsne"; selects the coordinate columns, axis
      labels and axis window.

    Writes '<savedir>/<dimred_type>_favorites_visualization.png' and
    '<savedir>/<dimred_type>_favorites_source_data.csv'. The CSV holds the
    coordinates actually plotted; previously it always wrote umap1/umap2.
    """
    set_font()

    data = pd.read_csv("data/top_genes.csv")
    seqs_with_embeddings = pd.merge(seqs_with_embeddings, data, on="sequence")
    seqs_with_embeddings, panel_genes = _label_fusion_parts(seqs_with_embeddings)

    colors = {
        'fusion_embeddings': '#cf9dfa',
        'h_embeddings': '#eb8888',
        't_embeddings': '#5fa3e3',
    }
    markers = {
        'fusion_embeddings': 'o',
        'h_embeddings': '^',
        't_embeddings': 'v'
    }
    label_map = {
        'fusion_embeddings': 'Fusion',
        'h_embeddings': 'Head',
        't_embeddings': 'Tail',
    }
    label_transform = {
        'tsne': 't-SNE',
        'umap': 'UMAP'
    }
    axis_name = label_transform.get(dimred_type, dimred_type)
    x_col, y_col = f'{dimred_type}1', f'{dimred_type}2'
    x_min, x_max, y_min, y_max, x_ticks, y_ticks = _favorites_axis_window(seqs_with_embeddings, dimred_type)

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()

    for i, ax in enumerate(axes):
        if i >= len(panel_genes):
            # Fewer fusions than panels: hide the spare axes instead of raising.
            ax.set_visible(False)
            continue
        fgene_name, hgene, tgene = panel_genes[i]

        panel_data = seqs_with_embeddings[seqs_with_embeddings['gene'].isin([fgene_name, hgene, tgene])]
        for emb_type in panel_data['Type'].unique():
            subset = panel_data[panel_data['Type'] == emb_type]
            ax.scatter(subset[x_col], subset[y_col], label=label_map[emb_type],
                       color=colors[emb_type], marker=markers[emb_type], s=120, zorder=3)

        ax.set_title(f'{fgene_name}', fontsize=44)
        ax.set_xlabel(f'{axis_name} 1', fontsize=44)
        ax.set_ylabel(f'{axis_name} 2', fontsize=44)
        ax.grid(True, which='both', linestyle='--', linewidth=0.5, color='gray', zorder=1)

        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_xticks(x_ticks)
        ax.set_yticks(y_ticks)
        ax.set_xticklabels(ax.get_xticks(), rotation=45, ha='right')
        ax.tick_params(axis='both', labelsize=24)

        if i == 0:
            legend = ax.legend(fontsize=20, markerscale=2, loc='best')
            for text in legend.get_texts():
                text.set_fontsize(24)

    fig.tight_layout()
    # Save BEFORE showing: with an interactive backend the figure can be closed
    # once plt.show() returns, so saving afterwards may write a blank image.
    fig.savefig(f'{savedir}/{dimred_type}_favorites_visualization.png', dpi=300)
    plt.show()
    plt.close(fig)

    fuson_db = pd.read_csv("../../data/fuson_db.csv")
    seq_to_id_dict = dict(zip(fuson_db['aa_seq'], fuson_db['seq_id']))
    seqs_with_embeddings['seq_id'] = seqs_with_embeddings['sequence'].map(seq_to_id_dict)
    seqs_with_embeddings[[x_col, y_col, 'sequence', 'Type', 'gene', 'id', 'seq_id']].to_csv(
        f"{savedir}/{dimred_type}_favorites_source_data.csv", index=False
    )
|
|
|
def fusion_v_parts_favorites(dimred_types, output_dir):
    """
    Makes the embeddings, THEN calls the plot. only on the four favorites

    Embeds the sequences in data/top_genes.csv with embed_dataset_for_benchmark
    (FusOn-pLM checkpoint(s) from config.FUSON_PLM_CKPT, average=True). Then, for
    each embedding file and each requested dimred type, it reduces to 2D and
    writes the favorites figure and source data under
    <output_dir>/<dimred_type>_plots/<path below 'embeddings/'>/favorites.

    Parameters:
    - dimred_types (list[str]): reduction methods, e.g. ["umap", "tsne"].
    - output_dir (str): root results directory.

    Raises:
    - Exception: if an embeddings pickle cannot be read (original error chained).
    """
    input_fname = 'favorites'
    all_embedding_paths = embed_dataset_for_benchmark(
        fuson_ckpts=config.FUSON_PLM_CKPT,
        input_data_path='data/top_genes.csv', input_fname=input_fname,
        average=True, seq_col='sequence',
        benchmark_fusonplm=True,
        benchmark_esm=False,
        benchmark_fo_puncta_ml=False,
        overwrite=config.PERMISSION_TO_OVERWRITE)

    log_update("\nEmbedding sequences")

    for embedding_path, _details in all_embedding_paths.items():
        log_update(f"\tBenchmarking embeddings at: {embedding_path}")
        try:
            with open(embedding_path, "rb") as f:
                # NOTE: pickle is only safe on trusted, locally generated files.
                embeddings = pickle.load(f)
        except Exception as err:
            # Not a bare except: KeyboardInterrupt/SystemExit propagate untouched,
            # and the underlying error is kept as __cause__.
            raise Exception(f"Cannot read embeddings from {embedding_path}") from err

        # The pickled dict maps sequence -> embedding.
        seqs_with_embeddings = pd.DataFrame(list(embeddings.items()), columns=['sequence', 'embedding'])

        for dimred_type in dimred_types:
            dimred_embeddings = get_dimred_embeddings(
                seqs_with_embeddings['embedding'].tolist(), dimred_type=dimred_type
            )
            coords = pd.DataFrame(dimred_embeddings, columns=[f'{dimred_type}1', f'{dimred_type}2'])
            # Both frames carry a fresh RangeIndex, so this index-aligned
            # assignment lines rows up positionally.
            seqs_with_embeddings[[f'{dimred_type}1', f'{dimred_type}2']] = coords

            # Mirror the embedding file's directory layout below 'embeddings/'.
            intermediate = '/'.join(embedding_path.split('embeddings/')[1].split('/')[0:-1])
            cur_output_dir = f"{output_dir}/{dimred_type}_plots/{intermediate}/{input_fname}"
            os.makedirs(cur_output_dir, exist_ok=True)
            make_fusion_v_parts_favorites_plot(seqs_with_embeddings, savedir=cur_output_dir, dimred_type=dimred_type)
|
|
|
def main():
    """
    Entry point: create a timestamped results directory and run both explorations.

    Everything is written under results/<local time>/. There is one
    <method>_plots/ subdirectory per projection enabled in config (PLOT_UMAP,
    PLOT_TSNE), and the run is logged to embedding_exploration_log.txt there.
    """
    os.makedirs('results', exist_ok=True)
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir, exist_ok=True)

    method_flags = (("umap", config.PLOT_UMAP), ("tsne", config.PLOT_TSNE))
    dimred_types = [method for method, enabled in method_flags if enabled]
    for method in dimred_types:
        os.makedirs(f"{output_dir}/{method}_plots", exist_ok=True)

    log_path = f'{output_dir}/embedding_exploration_log.txt'
    with open_logfile(log_path):
        print_configpy(config)
        fusion_v_parts_favorites(dimred_types, output_dir)
        tf_and_kinase_fusions_plot(dimred_types, output_dir)


if __name__ == "__main__":
    main()