import os, sys
import json  # needed below to write ./tmp/args.json before the main imports run

# install required packages
os.system('pip install plotly')      # install plotly
os.system('pip install matplotlib')  # install matplotlib
os.system('pip install dgl==1.0.2+cu116 -f https://data.dgl.ai/wheels/cu116/repo.html')
os.environ["DGLBACKEND"] = "pytorch"
print('Modules installed')
# ๊ธฐ๋ณธ args ์„ค์ •
if not os.path.exists('./tmp'):
os.makedirs('./tmp')
if not os.path.exists('./tmp/args.json'):
default_args = {
'checkpoint': None,
'dump_trb': False,
'dump_args': True,
'save_best_plddt': True,
'T': 25,
'strand_bias': 0.0,
'loop_bias': 0.0,
'helix_bias': 0.0,
'd_t1d': 24,
'potentials': None,
'potential_scale': None,
'aa_composition': None
}
with open('./tmp/args.json', 'w') as f:
json.dump(default_args, f)
# ์ฒดํฌํฌ์ธํŠธ ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ
if not os.path.exists('./SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt'):
print('Downloading model weights 1')
os.system('wget http://files.ipd.uw.edu/pub/sequence_diffusion/checkpoints/SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt')
print('Successfully Downloaded')
if not os.path.exists('./SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt'):
print('Downloading model weights 2')
os.system('wget http://files.ipd.uw.edu/pub/sequence_diffusion/checkpoints/SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt')
print('Successfully Downloaded')
from openai import OpenAI
import gradio as gr
import json
from datasets import load_dataset
import plotly.graph_objects as go
import numpy as np
import torch  # already required by the sampler; used directly for the PSSM heatmap in get_pssm()
import re     # used to parse parameters out of chat responses
import py3Dmol
from io import StringIO
import secrets
import copy
import matplotlib.pyplot as plt
from Bio import SeqIO, Align  # Biopython, already relied on by msa_to_pssm() below
from utils.sampler import HuggingFace_sampler
from utils.parsers_inference import parse_pdb
from model.util import writepdb
from utils.inpainting_util import *
# args ๋กœ๋“œ
with open('./tmp/args.json', 'r') as f:
args = json.load(f)
plt.rcParams.update({'font.size': 13})
# manually set checkpoint to load
args['checkpoint'] = None
args['dump_trb'] = False
args['dump_args'] = True
args['save_best_plddt'] = True
args['T'] = 25
args['strand_bias'] = 0.0
args['loop_bias'] = 0.0
args['helix_bias'] = 0.0
# Hugging Face token
ACCESS_TOKEN = os.getenv("HF_TOKEN")
if not ACCESS_TOKEN:
raise ValueError("HF_TOKEN not found in environment variables")
# OpenAI-compatible client pointed at the Hugging Face Inference endpoint
client = OpenAI(
base_url="https://api-inference.huggingface.co/v1/",
api_key=ACCESS_TOKEN,
)
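# respond() below streams chat completions from this endpoint; the same model name
# (CohereForAI/c4ai-command-r-plus-08-2024) is reused by analyze_prompt().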
# ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ ๋ฐ ๊ตฌ์กฐ ํ™•์ธ
try:
ds = load_dataset("lamm-mit/protein_secondary_structure_from_PDB",
token=ACCESS_TOKEN)
print("Dataset structure:", ds)
print("First entry example:", next(iter(ds['train'])))
except Exception as e:
print(f"Dataset loading error: {str(e)}")
raise
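# Chat-driven design pipeline used further down:
#   analyze_prompt()          -> free-text analysis of the user request (LLM call)
#   search_protein_data()     -> keyword match against the secondary-structure dataset
#   extract_parameters()      -> turn analysis + references into sampler parameters
#   protein_diffusion_model() -> run the sequence-diffusion sampler and yield results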
def respond(
message,
history,
system_message,
max_tokens,
temperature,
top_p,
):
    messages = [{"role": "system", "content": system_message}]
    # history comes from a type='messages' Chatbot, i.e. a list of {"role", "content"} dicts
    for msg in history:
        if isinstance(msg, dict):
            messages.append({"role": msg["role"], "content": msg["content"]})
        else:  # fall back to the older (user, assistant) tuple format
            messages.append({"role": "user", "content": msg[0]})
            if msg[1]:
                messages.append({"role": "assistant", "content": msg[1]})
    messages.append({"role": "user", "content": message})
try:
response = ""
for chunk in client.chat.completions.create(
model="CohereForAI/c4ai-command-r-plus-08-2024",
max_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
messages=messages,
):
if hasattr(chunk.choices[0].delta, 'content'):
token = chunk.choices[0].delta.content
if token is not None:
response += token
yield [{"role": "user", "content": message},
{"role": "assistant", "content": response}]
        return  # generator: the streamed yields above already delivered the final state
    except Exception as e:
        print(f"Error in respond: {str(e)}")
        yield [{"role": "user", "content": message},
               {"role": "assistant", "content": f"오류가 발생했습니다: {str(e)}"}]
def analyze_prompt(message):
"""LLM์„ ์‚ฌ์šฉํ•˜์—ฌ ํ”„๋กฌํ”„ํŠธ ๋ถ„์„"""
try:
analysis_prompt = f"""
๋‹ค์Œ ์š”์ฒญ์„ ๋ถ„์„ํ•˜์—ฌ ๋‹จ๋ฐฑ์งˆ ์„ค๊ณ„์— ํ•„์š”ํ•œ ์ฃผ์š” ํŠน์„ฑ์„ ์ถ”์ถœํ•˜์„ธ์š”:
์š”์ฒญ: {message}
๋‹ค์Œ ํ•ญ๋ชฉ๋“ค์„ ๋ถ„์„ํ•ด์ฃผ์„ธ์š”:
1. ์ฃผ์š” ๊ธฐ๋Šฅ (์˜ˆ: ์น˜๋ฃŒ, ๊ฒฐํ•ฉ, ์ด‰๋งค ๋“ฑ)
2. ๋ชฉํ‘œ ํ™˜๊ฒฝ (์˜ˆ: ์„ธํฌ๋ง‰, ์ˆ˜์šฉ์„ฑ, ๋“ฑ)
3. ํ•„์š”ํ•œ ๊ตฌ์กฐ์  ํŠน์ง•
4. ํฌ๊ธฐ ๋ฐ ๋ณต์žก๋„ ์š”๊ตฌ์‚ฌํ•ญ
"""
response = client.chat.completions.create(
model="CohereForAI/c4ai-command-r-plus-08-2024",
messages=[{"role": "user", "content": analysis_prompt}],
temperature=0.7
)
return response.choices[0].message.content
except Exception as e:
print(f"ํ”„๋กฌํ”„ํŠธ ๋ถ„์„ ์ค‘ ์˜ค๋ฅ˜: {str(e)}")
return None
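# analyze_prompt() returns free-form (Korean) text; the helpers below only match the
# specific keywords they were written against, so the prompt template and the keyword
# tables in extract_parameters()/extract_keywords() have to stay in sync.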
def search_protein_data(analysis, dataset):
"""๋ถ„์„ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๋ฐ์ดํ„ฐ์…‹์—์„œ ์œ ์‚ฌํ•œ ๊ตฌ์กฐ ๊ฒ€์ƒ‰"""
try:
# ํ‚ค์›Œ๋“œ ์ถ”์ถœ
keywords = extract_keywords(analysis)
print("Extracted keywords:", keywords)
# ๋ฐ์ดํ„ฐ์…‹ ๊ตฌ์กฐ ํ™•์ธ
if not dataset or 'train' not in dataset:
print("Invalid dataset structure")
return []
# ์œ ์‚ฌ๋„ ์ ์ˆ˜ ๊ณ„์‚ฐ
scored_entries = []
for entry in dataset['train']:
try:
score = calculate_similarity(keywords, entry)
scored_entries.append((score, entry))
except Exception as e:
print(f"Error processing entry: {str(e)}")
continue
# ๊ฒฐ๊ณผ ์ •๋ ฌ ๋ฐ ๋ฐ˜ํ™˜
scored_entries.sort(reverse=True)
return scored_entries[:3]
except Exception as e:
print(f"๋ฐ์ดํ„ฐ ๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜: {str(e)}")
return []
def extract_parameters(analysis, similar_structures):
"""๋ถ„์„ ๊ฒฐ๊ณผ์™€ ์œ ์‚ฌ ๊ตฌ์กฐ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ์ƒ์„ฑ ํŒŒ๋ผ๋ฏธํ„ฐ ๊ฒฐ์ •"""
try:
# ๊ธฐ๋ณธ ํŒŒ๋ผ๋ฏธํ„ฐ ํ…œํ”Œ๋ฆฟ
params = {
'sequence_length': 100,
'helix_bias': 0.02,
'strand_bias': 0.02,
'loop_bias': 0.1,
'hydrophobic_target_score': 0
}
# ๋ถ„์„ ๊ฒฐ๊ณผ์—์„œ ๊ตฌ์กฐ์  ์š”๊ตฌ์‚ฌํ•ญ ํŒŒ์•…
if "๋ง‰ ํˆฌ๊ณผ" in analysis or "์†Œ์ˆ˜์„ฑ" in analysis:
params['hydrophobic_target_score'] = -2
params['helix_bias'] = 0.03
elif "์ˆ˜์šฉ์„ฑ" in analysis or "๊ฐ€์šฉ์„ฑ" in analysis:
params['hydrophobic_target_score'] = 2
params['loop_bias'] = 0.15
# ์œ ์‚ฌ ๊ตฌ์กฐ๋“ค์˜ ํŠน์„ฑ ๋ฐ˜์˜
if similar_structures:
avg_length = sum(len(s[1]['sequence']) for s in similar_structures) / len(similar_structures)
params['sequence_length'] = int(avg_length)
# ๊ตฌ์กฐ์  ํŠน์„ฑ ๋ถ„์„ ๋ฐ ๋ฐ˜์˜
for _, structure in similar_structures:
if 'secondary_structure' in structure:
helix_ratio = structure['secondary_structure'].count('H') / len(structure['secondary_structure'])
sheet_ratio = structure['secondary_structure'].count('E') / len(structure['secondary_structure'])
params['helix_bias'] = max(0.01, min(0.05, helix_ratio))
params['strand_bias'] = max(0.01, min(0.05, sheet_ratio))
return params
except Exception as e:
print(f"ํŒŒ๋ผ๋ฏธํ„ฐ ์ถ”์ถœ ์ค‘ ์˜ค๋ฅ˜: {str(e)}")
return None
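# process_chat() ties the pipeline together for the chat box: analyse the request,
# look up reference structures, derive parameters, run one generation, and cache the
# outputs in the module-level current_protein_result.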
def process_chat(message, history):
try:
if any(keyword in message.lower() for keyword in ['protein', 'generate', '๋‹จ๋ฐฑ์งˆ', '์ƒ์„ฑ', '์น˜๋ฃŒ']):
            # 1. analyse the prompt with the LLM
analysis = analyze_prompt(message)
if not analysis:
return history + [
{"role": "user", "content": message},
{"role": "assistant", "content": "์š”์ฒญ ๋ถ„์„์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค."}
]
            # 2. search for similar reference structures
similar_structures = search_protein_data(analysis, ds)
if not similar_structures:
return history + [
{"role": "user", "content": message},
{"role": "assistant", "content": "์ ํ•ฉํ•œ ์ฐธ์กฐ ๊ตฌ์กฐ๋ฅผ ์ฐพ์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค."}
]
            # 3. decide generation parameters
params = extract_parameters(analysis, similar_structures)
if not params:
return history + [
{"role": "user", "content": message},
{"role": "assistant", "content": "ํŒŒ๋ผ๋ฏธํ„ฐ ์„ค์ •์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค."}
]
            # 4. generate the protein
try:
protein_result = protein_diffusion_model(
sequence=None,
seq_len=params['sequence_length'],
helix_bias=params['helix_bias'],
strand_bias=params['strand_bias'],
loop_bias=params['loop_bias'],
secondary_structure=None,
aa_bias=None,
aa_bias_potential=None,
num_steps="25",
noise="normal",
hydrophobic_target_score=str(params['hydrophobic_target_score']),
hydrophobic_potential="2",
contigs=None,
pssm=None,
seq_mask=None,
str_mask=None,
rewrite_pdb=None
)
                # drain the generator; the last yielded tuple is the final design
                final_result = None
                for final_result in protein_result:
                    pass
                output_seq, output_pdb, structure_view, plddt_plot = final_result
                # 5. build the explanation shown to the user
explanation = f"""
์š”์ฒญํ•˜์‹  ๊ธฐ๋Šฅ์— ๋งž๋Š” ๋‹จ๋ฐฑ์งˆ์„ ์ƒ์„ฑํ–ˆ์Šต๋‹ˆ๋‹ค:
๋ถ„์„๋œ ์š”๊ตฌ์‚ฌํ•ญ:
{analysis}
์„ค๊ณ„๋œ ๊ตฌ์กฐ์  ํŠน์ง•:
- ๊ธธ์ด: {params['sequence_length']} ์•„๋ฏธ๋…ธ์‚ฐ
- ์•ŒํŒŒ ํ—ฌ๋ฆญ์Šค ๋น„์œจ: {params['helix_bias']*100:.1f}%
- ๋ฒ ํƒ€ ์‹œํŠธ ๋น„์œจ: {params['strand_bias']*100:.1f}%
- ๋ฃจํ”„ ๊ตฌ์กฐ ๋น„์œจ: {params['loop_bias']*100:.1f}%
- ์†Œ์ˆ˜์„ฑ ์ ์ˆ˜: {params['hydrophobic_target_score']}
์ฐธ์กฐ๋œ ์œ ์‚ฌ ๊ตฌ์กฐ: {len(similar_structures)}๊ฐœ
์ƒ์„ฑ๋œ ๋‹จ๋ฐฑ์งˆ์˜ 3D ๊ตฌ์กฐ์™€ ์‹œํ€€์Šค๋ฅผ ํ™•์ธํ•˜์‹ค ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
"""
                # 6. cache the result for the display helpers
global current_protein_result
current_protein_result = {
'sequence': output_seq,
'pdb': output_pdb,
'structure_view': structure_view,
'plddt_plot': plddt_plot,
'params': params
}
return history + [
{"role": "user", "content": message},
{"role": "assistant", "content": explanation}
]
except Exception as e:
return history + [
{"role": "user", "content": message},
{"role": "assistant", "content": f"๋‹จ๋ฐฑ์งˆ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"}
]
else:
return history + [
{"role": "user", "content": message},
{"role": "assistant", "content": "๋‹จ๋ฐฑ์งˆ ์ƒ์„ฑ ๊ด€๋ จ ํ‚ค์›Œ๋“œ๋ฅผ ํฌํ•จํ•ด์ฃผ์„ธ์š”."}
]
except Exception as e:
return history + [
{"role": "user", "content": message},
{"role": "assistant", "content": f"์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"}
]
def generate_protein(params):
# ๊ธฐ์กด protein_diffusion_model ํ•จ์ˆ˜ ํ˜ธ์ถœ
result = protein_diffusion_model(
sequence=None,
seq_len=params['sequence_length'],
helix_bias=params['helix_bias'],
strand_bias=params['strand_bias'],
loop_bias=params['loop_bias'],
secondary_structure=None,
aa_bias=None,
aa_bias_potential=None,
num_steps="25",
noise="normal",
hydrophobic_target_score=str(params['hydrophobic_target_score']),
hydrophobic_potential="2",
contigs=None,
pssm=None,
seq_mask=None,
str_mask=None,
rewrite_pdb=None
)
return result
def generate_explanation(result, params):
explanation = f"""
์ƒ์„ฑ๋œ ๋‹จ๋ฐฑ์งˆ ๋ถ„์„:
- ๊ธธ์ด: {params['sequence_length']} ์•„๋ฏธ๋…ธ์‚ฐ
- ๊ตฌ์กฐ์  ํŠน์ง•:
* ์•ŒํŒŒ ๋‚˜์„  ๋น„์œจ: {params['helix_bias']*100}%
* ๋ฒ ํƒ€ ์‹œํŠธ ๋น„์œจ: {params['strand_bias']*100}%
* ๋ฃจํ”„ ๊ตฌ์กฐ ๋น„์œจ: {params['loop_bias']*100}%
- ํŠน์ˆ˜ ๊ธฐ๋Šฅ: {result.get('special_features', '์—†์Œ')}
"""
return explanation
# ์ฒดํฌํฌ์ธํŠธ ํŒŒ์ผ ๊ฒฝ๋กœ๋ฅผ ์ ˆ๋Œ€ ๊ฒฝ๋กœ๋กœ ์ˆ˜์ •
def protein_diffusion_model(sequence, seq_len, helix_bias, strand_bias, loop_bias,
secondary_structure, aa_bias, aa_bias_potential,
num_steps, noise, hydrophobic_target_score, hydrophobic_potential,
contigs, pssm, seq_mask, str_mask, rewrite_pdb):
dssp_checkpoint = './SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt'
og_checkpoint = './SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt'
# ์ฒดํฌํฌ์ธํŠธ ํŒŒ์ผ ์กด์žฌ ํ™•์ธ
if not os.path.exists(dssp_checkpoint):
raise FileNotFoundError(f"DSSP checkpoint file not found at: {dssp_checkpoint}")
if not os.path.exists(og_checkpoint):
raise FileNotFoundError(f"OG checkpoint file not found at: {og_checkpoint}")
model_args = copy.deepcopy(args)
# make sampler
S = HuggingFace_sampler(args=model_args)
# get random prefix
S.out_prefix = './tmp/'+secrets.token_hex(nbytes=10).upper()
# set args
S.args['checkpoint'] = None
S.args['dump_trb'] = False
S.args['dump_args'] = True
S.args['save_best_plddt'] = True
S.args['T'] = 20
S.args['strand_bias'] = 0.0
S.args['loop_bias'] = 0.0
S.args['helix_bias'] = 0.0
S.args['potentials'] = None
S.args['potential_scale'] = None
S.args['aa_composition'] = None
# get sequence if entered and make sure all chars are valid
alt_aa_dict = {'B':['D','N'],'J':['I','L'],'U':['C'],'Z':['E','Q'],'O':['K']}
if sequence not in ['',None]:
L = len(sequence)
aa_seq = []
for aa in sequence.upper():
if aa in alt_aa_dict.keys():
aa_seq.append(np.random.choice(alt_aa_dict[aa]))
else:
aa_seq.append(aa)
S.args['sequence'] = aa_seq
elif contigs not in ['',None]:
S.args['contigs'] = [contigs]
else:
S.args['contigs'] = [f'{seq_len}']
L = int(seq_len)
print('DEBUG: ',rewrite_pdb)
if rewrite_pdb not in ['',None]:
S.args['pdb'] = rewrite_pdb.name
if seq_mask not in ['',None]:
S.args['inpaint_seq'] = [seq_mask]
if str_mask not in ['',None]:
S.args['inpaint_str'] = [str_mask]
if secondary_structure in ['',None]:
secondary_structure = None
else:
secondary_structure = ''.join(['E' if x == 'S' else x for x in secondary_structure])
        if L < len(secondary_structure):
            # truncate to the target length L (len(sequence) fails when no sequence was given)
            secondary_structure = secondary_structure[:L]
elif L == len(secondary_structure):
pass
else:
dseq = L - len(secondary_structure)
secondary_structure += secondary_structure[-1]*dseq
# potentials
potential_list = []
potential_bias_list = []
if aa_bias not in ['',None]:
potential_list.append('aa_bias')
S.args['aa_composition'] = aa_bias
if aa_bias_potential in ['',None]:
aa_bias_potential = 3
potential_bias_list.append(str(aa_bias_potential))
'''
if target_charge not in ['',None]:
potential_list.append('charge')
if charge_potential in ['',None]:
charge_potential = 1
potential_bias_list.append(str(charge_potential))
S.args['target_charge'] = float(target_charge)
if target_ph in ['',None]:
target_ph = 7.4
S.args['target_pH'] = float(target_ph)
'''
if hydrophobic_target_score not in ['',None]:
potential_list.append('hydrophobic')
S.args['hydrophobic_score'] = float(hydrophobic_target_score)
if hydrophobic_potential in ['',None]:
hydrophobic_potential = 3
potential_bias_list.append(str(hydrophobic_potential))
if pssm not in ['',None]:
potential_list.append('PSSM')
potential_bias_list.append('5')
S.args['PSSM'] = pssm.name
if len(potential_list) > 0:
S.args['potentials'] = ','.join(potential_list)
S.args['potential_scale'] = ','.join(potential_bias_list)
# normalise secondary_structure bias from range 0-0.3
S.args['secondary_structure'] = secondary_structure
S.args['helix_bias'] = helix_bias
S.args['strand_bias'] = strand_bias
S.args['loop_bias'] = loop_bias
# set T
if num_steps in ['',None]:
S.args['T'] = 20
else:
S.args['T'] = int(num_steps)
# noise
if 'normal' in noise:
S.args['sample_distribution'] = noise
S.args['sample_distribution_gmm_means'] = [0]
S.args['sample_distribution_gmm_variances'] = [1]
elif 'gmm2' in noise:
S.args['sample_distribution'] = noise
S.args['sample_distribution_gmm_means'] = [-1,1]
S.args['sample_distribution_gmm_variances'] = [1,1]
elif 'gmm3' in noise:
S.args['sample_distribution'] = noise
S.args['sample_distribution_gmm_means'] = [-1,0,1]
S.args['sample_distribution_gmm_variances'] = [1,1,1]
if secondary_structure not in ['',None] or helix_bias+strand_bias+loop_bias > 0:
S.args['checkpoint'] = dssp_checkpoint
S.args['d_t1d'] = 29
print('using dssp checkpoint')
else:
S.args['checkpoint'] = og_checkpoint
S.args['d_t1d'] = 24
print('using og checkpoint')
for k,v in S.args.items():
print(f"{k} --> {v}")
# init S
S.model_init()
S.diffuser_init()
S.setup()
# sampling loop
plddt_data = []
for j in range(S.max_t):
print(f'on step {j}')
output_seq, output_pdb, plddt = S.take_step_get_outputs(j)
plddt_data.append(plddt)
yield output_seq, output_pdb, display_pdb(output_pdb), get_plddt_plot(plddt_data, S.max_t)
output_seq, output_pdb, plddt = S.get_outputs()
return output_seq, output_pdb, display_pdb(output_pdb), get_plddt_plot(plddt_data, S.max_t)
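# NOTE: protein_diffusion_model() is a generator. It yields
# (sequence, pdb_path, html_view, plddt_figure) once per diffusion step, so callers
# either stream the yields (as the Gradio handlers do) or drain it and keep the last tuple.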
def get_plddt_plot(plddt_data, max_t):
x = [i+1 for i in range(len(plddt_data))]
fig, ax = plt.subplots(figsize=(15,6))
ax.plot(x,plddt_data,color='#661dbf', linewidth=3,marker='o')
ax.set_xticks([i+1 for i in range(max_t)])
ax.set_yticks([(i+1)/10 for i in range(10)])
ax.set_ylim([0,1])
ax.set_ylabel('model confidence (plddt)')
ax.set_xlabel('diffusion steps (t)')
return fig
def display_pdb(path_to_pdb):
    '''Render a PDB file with py3Dmol and return it as an embeddable iframe.'''
    with open(path_to_pdb, "r") as f:
        pdb = f.read()
view = py3Dmol.view(width=500, height=500)
view.addModel(pdb, "pdb")
view.setStyle({'model': -1}, {"cartoon": {'colorscheme':{'prop':'b','gradient':'roygb','min':0,'max':1}}})#'linear', 'min': 0, 'max': 1, 'colors': ["#ff9ef0","#a903fc",]}}})
view.zoomTo()
output = view._make_html().replace("'", '"')
print(view._make_html())
x = f"""<!DOCTYPE html><html></center> {output} </center></html>""" # do not use ' in this input
return f"""<iframe height="500px" width="100%" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
'''
return f"""<iframe style="width: 100%; height:700px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
'''
def get_motif_preview(pdb_id, contigs):
try:
input_pdb = fetch_pdb(pdb_id=pdb_id.lower() if pdb_id else None)
if input_pdb is None:
return gr.HTML("PDB ID๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”"), None
parse = parse_pdb(input_pdb)
output_name = input_pdb
pdb = open(output_name, "r").read()
view = py3Dmol.view(width=500, height=500)
view.addModel(pdb, "pdb")
if contigs in ['',0]:
contigs = ['0']
else:
contigs = [contigs]
print('DEBUG: ',contigs)
pdb_map = get_mappings(ContigMap(parse,contigs))
print('DEBUG: ',pdb_map)
print('DEBUG: ',pdb_map['con_ref_idx0'])
roi = [x[1]-1 for x in pdb_map['con_ref_pdb_idx']]
colormap = {0:'#D3D3D3', 1:'#F74CFF'}
colors = {i+1: colormap[1] if i in roi else colormap[0] for i in range(parse['xyz'].shape[0])}
view.setStyle({"cartoon": {"colorscheme": {"prop": "resi", "map": colors}}})
view.zoomTo()
output = view._make_html().replace("'", '"')
print(view._make_html())
x = f"""<!DOCTYPE html><html></center> {output} </center></html>""" # do not use ' in this input
return f"""<iframe height="500px" width="100%" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>""", output_name
except Exception as e:
return gr.HTML(f"์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"), None
def fetch_pdb(pdb_id=None):
if pdb_id is None or pdb_id == "":
return None
else:
os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_id}.pdb")
return f"{pdb_id}.pdb"
# MSA AND PSSM GUIDANCE
def save_pssm(file_upload):
filename = file_upload.name
orig_name = file_upload.orig_name
if filename.split('.')[-1] in ['fasta', 'a3m']:
return msa_to_pssm(file_upload)
return filename
def msa_to_pssm(msa_file):
# Define the lookup table for converting amino acids to indices
aa_to_index = {'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4, 'Q': 5, 'E': 6, 'G': 7, 'H': 8, 'I': 9, 'L': 10,
'K': 11, 'M': 12, 'F': 13, 'P': 14, 'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19, 'X': 20, '-': 21}
# Open the FASTA file and read the sequences
records = list(SeqIO.parse(msa_file.name, "fasta"))
    assert len(records) >= 1, "MSA must contain at least one protein sequence."
first_seq = str(records[0].seq)
aligned_seqs = [first_seq]
# print(aligned_seqs)
# Perform sequence alignment using the Needleman-Wunsch algorithm
aligner = Align.PairwiseAligner()
aligner.open_gap_score = -0.7
aligner.extend_gap_score = -0.3
for record in records[1:]:
alignment = aligner.align(first_seq, str(record.seq))[0]
alignment = alignment.format().split("\n")
al1 = alignment[0]
al2 = alignment[2]
al1_fin = ""
al2_fin = ""
percent_gap = al2.count('-')/ len(al2)
if percent_gap > 0.4:
continue
for i in range(len(al1)):
if al1[i] != '-':
al1_fin += al1[i]
al2_fin += al2[i]
aligned_seqs.append(str(al2_fin))
# Get the length of the aligned sequences
aligned_seq_length = len(first_seq)
# Initialize the position scoring matrix
matrix = np.zeros((22, aligned_seq_length))
# Iterate through the aligned sequences and count the amino acids at each position
for seq in aligned_seqs:
#print(seq)
for i in range(aligned_seq_length):
if i == len(seq):
break
amino_acid = seq[i]
if amino_acid.upper() not in aa_to_index.keys():
continue
else:
aa_index = aa_to_index[amino_acid.upper()]
matrix[aa_index, i] += 1
# Normalize the counts to get the frequency of each amino acid at each position
matrix /= len(aligned_seqs)
print(len(aligned_seqs))
matrix[20:,]=0
outdir = ".".join(msa_file.name.split('.')[:-1]) + ".csv"
np.savetxt(outdir, matrix[:21,:].T, delimiter=",")
return outdir
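# The CSV written above has shape (alignment_length x 21): per-position amino-acid
# frequencies transposed from the internal (22 x L) count matrix; rows 20-21 (X and gap)
# are zeroed first and only the first 21 rows are saved. get_pssm() below reloads it.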
def get_pssm(fasta_msa, input_pssm):
try:
if input_pssm is not None:
outdir = input_pssm.name
elif fasta_msa is not None:
outdir = save_pssm(fasta_msa)
else:
return gr.Plot(label="ํŒŒ์ผ์„ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”"), None
pssm = np.loadtxt(outdir, delimiter=",", dtype=float)
fig, ax = plt.subplots(figsize=(15,6))
plt.imshow(torch.permute(torch.tensor(pssm),(1,0)))
return fig, outdir
except Exception as e:
return gr.Plot(label=f"์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"), None
# ํžˆ์–ด๋กœ ๋Šฅ๋ ฅ์น˜ ๊ณ„์‚ฐ ํ•จ์ˆ˜ ์ถ”๊ฐ€
def calculate_hero_stats(helix_bias, strand_bias, loop_bias, hydrophobic_score):
stats = {
        'strength': strand_bias * 20,    # based on beta-sheet content
        'flexibility': helix_bias * 20,  # based on alpha-helix content
        'speed': loop_bias * 5,          # based on loop content
'defense': abs(hydrophobic_score) if hydrophobic_score else 0
}
return stats
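# With the slider ranges used in the UI (strand/helix <= 0.05, loop <= 0.20) the first
# three stats land in [0, 1], matching the radar chart's radial range; 'defense' is the
# absolute hydrophobic score and can exceed that range.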
def toggle_seq_input(choice):
if choice == "์ž๋™ ์„ค๊ณ„":
return gr.update(visible=True), gr.update(visible=False)
else: # "์ง์ ‘ ์ž…๋ ฅ"
return gr.update(visible=False), gr.update(visible=True)
def toggle_secondary_structure(choice):
if choice == "์Šฌ๋ผ์ด๋”๋กœ ์„ค์ •":
return (
gr.update(visible=True), # helix_bias
gr.update(visible=True), # strand_bias
gr.update(visible=True), # loop_bias
gr.update(visible=False) # secondary_structure
)
else: # "์ง์ ‘ ์ž…๋ ฅ"
return (
gr.update(visible=False), # helix_bias
gr.update(visible=False), # strand_bias
gr.update(visible=False), # loop_bias
gr.update(visible=True) # secondary_structure
)
def create_radar_chart(stats):
# ๋ ˆ์ด๋” ์ฐจํŠธ ์ƒ์„ฑ ๋กœ์ง
categories = list(stats.keys())
values = list(stats.values())
fig = go.Figure(data=go.Scatterpolar(
r=values,
theta=categories,
fill='toself'
))
fig.update_layout(
polar=dict(
radialaxis=dict(
visible=True,
range=[0, 1]
)),
showlegend=False
)
return fig
def generate_hero_description(name, stats, abilities):
# ํžˆ์–ด๋กœ ์„ค๋ช… ์ƒ์„ฑ ๋กœ์ง
description = f"""
ํžˆ์–ด๋กœ ์ด๋ฆ„: {name}
์ฃผ์š” ๋Šฅ๋ ฅ:
- ๊ทผ๋ ฅ: {'โ˜…' * int(stats['strength'] * 5)}
- ์œ ์—ฐ์„ฑ: {'โ˜…' * int(stats['flexibility'] * 5)}
- ์Šคํ”ผ๋“œ: {'โ˜…' * int(stats['speed'] * 5)}
- ๋ฐฉ์–ด๋ ฅ: {'โ˜…' * int(stats['defense'] * 5)}
ํŠน์ˆ˜ ๋Šฅ๋ ฅ: {', '.join(abilities)}
"""
return description
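# combined_generation() backs all four "hero" buttons: it maps the game-style sliders
# onto diffusion parameters (see the per-argument comments below), drains the sampler
# generator, and returns the radar chart, description and structural outputs.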
def combined_generation(name, strength, flexibility, speed, defense, size, abilities,
sequence, seq_len, helix_bias, strand_bias, loop_bias,
secondary_structure, aa_bias, aa_bias_potential,
num_steps, noise, hydrophobic_target_score, hydrophobic_potential,
contigs, pssm, seq_mask, str_mask, rewrite_pdb):
try:
        # run protein_diffusion_model
        generator = protein_diffusion_model(
            sequence=None,
            seq_len=size,                            # hero size -> sequence length
            helix_bias=flexibility,                  # hero flexibility -> helix bias
            strand_bias=strength,                    # hero strength -> strand bias
            loop_bias=speed,                         # hero speed -> loop bias
            secondary_structure=None,
            aa_bias=None,
            aa_bias_potential=None,
            num_steps="25",
            noise="normal",
            hydrophobic_target_score=str(-defense),  # hero defense -> hydrophobic target score
            hydrophobic_potential="2",
            contigs=None,
            pssm=None,
            seq_mask=None,
            str_mask=None,
            rewrite_pdb=None
        )
# ๋งˆ์ง€๋ง‰ ๊ฒฐ๊ณผ ๊ฐ€์ ธ์˜ค๊ธฐ
final_result = None
for result in generator:
final_result = result
if final_result is None:
raise Exception("์ƒ์„ฑ ๊ฒฐ๊ณผ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค")
output_seq, output_pdb, structure_view, plddt_plot = final_result
# ํžˆ์–ด๋กœ ๋Šฅ๋ ฅ์น˜ ๊ณ„์‚ฐ
stats = calculate_hero_stats(flexibility, strength, speed, defense)
# ๋ชจ๋“  ๊ฒฐ๊ณผ ๋ฐ˜ํ™˜
return (
create_radar_chart(stats), # ๋Šฅ๋ ฅ์น˜ ์ฐจํŠธ
generate_hero_description(name, stats, abilities), # ํžˆ์–ด๋กœ ์„ค๋ช…
output_seq, # ๋‹จ๋ฐฑ์งˆ ์„œ์—ด
output_pdb, # PDB ํŒŒ์ผ
structure_view, # 3D ๊ตฌ์กฐ
plddt_plot # ์‹ ๋ขฐ๋„ ์ฐจํŠธ
)
except Exception as e:
print(f"Error in combined_generation: {str(e)}")
return (
None,
f"์—๋Ÿฌ: {str(e)}",
None,
None,
gr.HTML("์—๋Ÿฌ๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค"),
None
)
def extract_parameters_from_chat(chat_response):
"""์ฑ—๋ด‡ ์‘๋‹ต์—์„œ ํŒŒ๋ผ๋ฏธํ„ฐ ์ถ”์ถœ"""
try:
params = {
'sequence_length': 100,
'helix_bias': 0.02,
'strand_bias': 0.02,
'loop_bias': 0.1,
'hydrophobic_target_score': 0
}
# ์‘๋‹ต ํ…์ŠคํŠธ์—์„œ ๊ฐ’ ์ถ”์ถœ
if "๊ธธ์ด:" in chat_response:
length_match = re.search(r'๊ธธ์ด: (\d+)', chat_response)
if length_match:
params['sequence_length'] = int(length_match.group(1))
if "์•ŒํŒŒ ํ—ฌ๋ฆญ์Šค ๋น„์œจ:" in chat_response:
helix_match = re.search(r'์•ŒํŒŒ ํ—ฌ๋ฆญ์Šค ๋น„์œจ: ([\d.]+)', chat_response)
if helix_match:
params['helix_bias'] = float(helix_match.group(1)) / 100
if "๋ฒ ํƒ€ ์‹œํŠธ ๋น„์œจ:" in chat_response:
strand_match = re.search(r'๋ฒ ํƒ€ ์‹œํŠธ ๋น„์œจ: ([\d.]+)', chat_response)
if strand_match:
params['strand_bias'] = float(strand_match.group(1)) / 100
if "๋ฃจํ”„ ๊ตฌ์กฐ ๋น„์œจ:" in chat_response:
loop_match = re.search(r'๋ฃจํ”„ ๊ตฌ์กฐ ๋น„์œจ: ([\d.]+)', chat_response)
if loop_match:
params['loop_bias'] = float(loop_match.group(1)) / 100
if "์†Œ์ˆ˜์„ฑ ์ ์ˆ˜:" in chat_response:
hydro_match = re.search(r'์†Œ์ˆ˜์„ฑ ์ ์ˆ˜: ([-\d.]+)', chat_response)
if hydro_match:
params['hydrophobic_target_score'] = float(hydro_match.group(1))
return params
except Exception as e:
print(f"ํŒŒ๋ผ๋ฏธํ„ฐ ์ถ”์ถœ ์ค‘ ์˜ค๋ฅ˜: {str(e)}")
return None
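# The regexes above mirror the wording of the explanation template built in
# process_chat_and_generate(); if that template changes, these patterns must change too.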
def update_protein_display(chat_response):
if "์ƒ์„ฑ๋œ ๋‹จ๋ฐฑ์งˆ ๋ถ„์„" in chat_response:
params = extract_parameters_from_chat(chat_response)
if params:
result = generate_protein(params)
stats = calculate_hero_stats(
helix_bias=params['helix_bias'],
strand_bias=params['strand_bias'],
loop_bias=params['loop_bias'],
hydrophobic_score=params['hydrophobic_target_score']
)
return {
hero_stats: create_radar_chart(stats),
hero_description: chat_response,
output_seq: result[0],
output_pdb: result[1],
output_viewer: display_pdb(result[1]),
plddt_plot: result[3]
}
return None
def process_chat_and_generate(message, history):
try:
        # 1. analyse the prompt and derive generation parameters
analysis = analyze_prompt(message)
similar_structures = search_protein_data(analysis, ds)
params = extract_parameters(analysis, similar_structures)
        # 2. run the diffusion sampler
generator = protein_diffusion_model(
sequence=None,
seq_len=params['sequence_length'],
helix_bias=params['helix_bias'],
strand_bias=params['strand_bias'],
loop_bias=params['loop_bias'],
secondary_structure=None,
aa_bias=None,
aa_bias_potential=None,
num_steps="25", # 25๋‹จ๊ณ„ ์„ค์ •
noise="normal",
hydrophobic_target_score=str(params['hydrophobic_target_score']),
hydrophobic_potential="2",
contigs=None,
pssm=None,
seq_mask=None,
str_mask=None,
rewrite_pdb=None
)
        # 3. stream intermediate results while keeping the final one
        final_result = None
        steps_done = 0
        for steps_done, result in enumerate(generator, start=1):
            final_result = result
            # intermediate progress update
            yield (
                history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": f"단백질 생성 중... {steps_done}단계 완료"}
                ],
                create_radar_chart(calculate_hero_stats(
                    params['helix_bias'],
                    params['strand_bias'],
                    params['loop_bias'],
                    params['hydrophobic_target_score']
                )),
                f"단백질 생성 진행 중... {steps_done}단계 완료",
                result[0],  # output_seq
                result[1],  # output_pdb
                result[2],  # structure_view
                result[3]   # plddt_plot
            )
if final_result is None:
raise Exception("์ƒ์„ฑ ๊ฒฐ๊ณผ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค")
output_seq, output_pdb, structure_view, plddt_plot = final_result
        # 4. hero stats
stats = calculate_hero_stats(
helix_bias=params['helix_bias'],
strand_bias=params['strand_bias'],
loop_bias=params['loop_bias'],
hydrophobic_score=params['hydrophobic_target_score']
)
        # 5. final explanation
explanation = f"""
์š”์ฒญํ•˜์‹  ๊ธฐ๋Šฅ์— ๋งž๋Š” ๋‹จ๋ฐฑ์งˆ์„ ์ƒ์„ฑํ–ˆ์Šต๋‹ˆ๋‹ค:
๋ถ„์„๋œ ์š”๊ตฌ์‚ฌํ•ญ:
{analysis}
์„ค๊ณ„๋œ ๊ตฌ์กฐ์  ํŠน์ง•:
- ๊ธธ์ด: {params['sequence_length']} ์•„๋ฏธ๋…ธ์‚ฐ
- ์•ŒํŒŒ ํ—ฌ๋ฆญ์Šค ๋น„์œจ: {params['helix_bias']*100:.1f}%
- ๋ฒ ํƒ€ ์‹œํŠธ ๋น„์œจ: {params['strand_bias']*100:.1f}%
- ๋ฃจํ”„ ๊ตฌ์กฐ ๋น„์œจ: {params['loop_bias']*100:.1f}%
- ์†Œ์ˆ˜์„ฑ ์ ์ˆ˜: {params['hydrophobic_target_score']}
์ƒ์„ฑ์ด ์™„๋ฃŒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ์ด {len(plddt_plot.get_lines()) if plddt_plot else 0}๋‹จ๊ณ„์˜ ์ตœ์ ํ™”๋ฅผ ์ˆ˜ํ–‰ํ–ˆ์Šต๋‹ˆ๋‹ค.
"""
        # 6. yield the final result (this function is a generator, so a plain return would drop it)
        yield (
history + [
{"role": "user", "content": message},
{"role": "assistant", "content": explanation}
],
create_radar_chart(stats),
explanation,
output_seq,
output_pdb,
structure_view,
plddt_plot
)
    except Exception as e:
        print(f"Error in process_chat_and_generate: {str(e)}")
        yield (
            history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": f"오류가 발생했습니다: {str(e)}"}
            ],
            None, None, None, None, None, None
        )
def extract_keywords(analysis):
"""๋ถ„์„ ํ…์ŠคํŠธ์—์„œ ํ‚ค์›Œ๋“œ ์ถ”์ถœ"""
try:
# ๊ธฐ๋ณธ ํ‚ค์›Œ๋“œ ์ถ”์ถœ
keywords = []
# ์ฃผ์š” ๊ธฐ๋Šฅ ํ‚ค์›Œ๋“œ
if "์น˜๋ฃŒ" in analysis: keywords.extend(["therapeutic", "binding"])
if "๊ฒฐํ•ฉ" in analysis: keywords.extend(["binding", "interaction"])
if "์ด‰๋งค" in analysis: keywords.extend(["enzyme", "catalytic"])
# ํ™˜๊ฒฝ ํ‚ค์›Œ๋“œ
if "๋ง‰" in analysis: keywords.extend(["membrane", "transmembrane"])
if "์ˆ˜์šฉ์„ฑ" in analysis: keywords.extend(["soluble", "hydrophilic"])
if "์†Œ์ˆ˜์„ฑ" in analysis: keywords.extend(["hydrophobic"])
# ๊ตฌ์กฐ ํ‚ค์›Œ๋“œ
if "์•ŒํŒŒ" in analysis or "๋‚˜์„ " in analysis: keywords.append("helix")
if "๋ฒ ํƒ€" in analysis or "์‹œํŠธ" in analysis: keywords.append("sheet")
if "๋ฃจํ”„" in analysis: keywords.append("loop")
        return list(set(keywords))  # deduplicate
except Exception as e:
print(f"ํ‚ค์›Œ๋“œ ์ถ”์ถœ ์ค‘ ์˜ค๋ฅ˜: {str(e)}")
return []
def calculate_similarity(keywords, entry):
"""ํ‚ค์›Œ๋“œ์™€ ๋ฐ์ดํ„ฐ์…‹ ํ•ญ๋ชฉ ๊ฐ„์˜ ์œ ์‚ฌ๋„ ๊ณ„์‚ฐ"""
try:
score = 0
# ๋ฐ์ดํ„ฐ์…‹ ๊ตฌ์กฐ ํ™•์ธ ๋ฐ ์•ˆ์ „ํ•œ ์ ‘๊ทผ
sequence = entry.get('sequence', '').lower() if isinstance(entry, dict) else str(entry).lower()
# ๋ฐ์ดํ„ฐ์…‹ ๊ตฌ์กฐ ๋””๋ฒ„๊น…
print("Entry structure:", type(entry))
print("Entry content:", entry)
for keyword in keywords:
# ์•ˆ์ „ํ•œ ์ ‘๊ทผ์„ ์œ„ํ•œ ์ˆ˜์ •
description = entry.get('description', '') if isinstance(entry, dict) else ''
if keyword in description.lower():
score += 2
if keyword in sequence:
score += 1
if isinstance(entry, dict) and 'secondary_structure' in entry:
sec_structure = entry['secondary_structure']
if keyword in ['helix'] and 'H' in sec_structure:
score += 1
if keyword in ['sheet'] and 'E' in sec_structure:
score += 1
if keyword in ['loop'] and 'L' in sec_structure:
score += 1
return score
except Exception as e:
print(f"์œ ์‚ฌ๋„ ๊ณ„์‚ฐ ์ค‘ ์ƒ์„ธ ์˜ค๋ฅ˜: {str(e)}")
print("Entry:", entry)
return 0
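# Scoring above: +2 for a keyword hit in the description, +1 for a hit in the sequence,
# and +1 when a structural keyword (helix/sheet/loop) is backed by the corresponding
# H/E/L code in the entry's secondary_structure string.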
def download_checkpoint_files():
    """Download the required checkpoint files if they are missing."""
    try:
        import requests
        # checkpoint paths and URLs (the same files fetched with wget at start-up)
        dssp_checkpoint = './SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt'
        og_checkpoint = './SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt'
        dssp_url = "http://files.ipd.uw.edu/pub/sequence_diffusion/checkpoints/SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt"
        og_url = "http://files.ipd.uw.edu/pub/sequence_diffusion/checkpoints/SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt"
        # download the DSSP checkpoint
        if not os.path.exists(dssp_checkpoint):
            print("Downloading DSSP checkpoint...")
            response = requests.get(dssp_url)
            with open(dssp_checkpoint, 'wb') as f:
                f.write(response.content)
        # download the OG checkpoint
        if not os.path.exists(og_checkpoint):
            print("Downloading OG checkpoint...")
            response = requests.get(og_url)
            with open(og_checkpoint, 'wb') as f:
                f.write(response.content)
        print("Checkpoint files downloaded successfully")
    except Exception as e:
        print(f"Error downloading checkpoint files: {str(e)}")
        raise
# ์‹œ์ž‘ ์‹œ ์ฒดํฌํฌ์ธํŠธ ํŒŒ์ผ ํ™•์ธ ๋ฐ ๋‹ค์šด๋กœ๋“œ
try:
download_checkpoint_files()
except Exception as e:
print(f"Warning: Could not download checkpoint files: {str(e)}")
with gr.Blocks(theme='ParityError/Interstellar') as demo:
with gr.Row():
with gr.Column(scale=1):
# ์ฑ—๋ด‡ ์ธํ„ฐํŽ˜์ด์Šค
gr.Markdown("# ๐Ÿค– AI ๋‹จ๋ฐฑ์งˆ ์„ค๊ณ„ ๋„์šฐ๋ฏธ")
# ์—ฌ๊ธฐ๋ฅผ ์ˆ˜์ •
chatbot = gr.Chatbot(
height=600,
                type='messages'  # use message dicts
)
with gr.Row():
msg = gr.Textbox(
label="๋ฉ”์‹œ์ง€๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”",
placeholder="์˜ˆ: COVID-19๋ฅผ ์น˜๋ฃŒํ•  ์ˆ˜ ์žˆ๋Š” ๋‹จ๋ฐฑ์งˆ์„ ์ƒ์„ฑํ•ด์ฃผ์„ธ์š”",
lines=2,
scale=4
)
submit_btn = gr.Button("์ „์†ก", variant="primary", scale=1)
clear = gr.Button("๋Œ€ํ™” ๋‚ด์šฉ ์ง€์šฐ๊ธฐ")
with gr.Accordion("์ฑ„ํŒ… ์„ค์ •", open=False):
system_message = gr.Textbox(
value="๋‹น์‹ ์€ ๋‹จ๋ฐฑ์งˆ ์„ค๊ณ„๋ฅผ ๋„์™€์ฃผ๋Š” ์ „๋ฌธ๊ฐ€์ž…๋‹ˆ๋‹ค.",
label="์‹œ์Šคํ…œ ๋ฉ”์‹œ์ง€"
)
max_tokens = gr.Slider(
minimum=1,
maximum=2048,
value=512,
step=1,
label="์ตœ๋Œ€ ํ† ํฐ ์ˆ˜"
)
temperature = gr.Slider(
minimum=0.1,
maximum=4.0,
value=0.7,
step=0.1,
label="Temperature"
)
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-P"
)
# ํƒญ ์ธํ„ฐํŽ˜์ด์Šค
with gr.Tabs():
with gr.TabItem("๐Ÿฆธโ€โ™‚๏ธ ํžˆ์–ด๋กœ ๋””์ž์ธ"):
gr.Markdown("""
### โœจ ๋‹น์‹ ๋งŒ์˜ ํŠน๋ณ„ํ•œ ํžˆ์–ด๋กœ๋ฅผ ๋งŒ๋“ค์–ด๋ณด์„ธ์š”!
๊ฐ ๋Šฅ๋ ฅ์น˜๋ฅผ ์กฐ์ ˆํ•˜๋ฉด ํžˆ์–ด๋กœ์˜ DNA๊ฐ€ ์ž๋™์œผ๋กœ ์„ค๊ณ„๋ฉ๋‹ˆ๋‹ค.
""")
# ํžˆ์–ด๋กœ ๊ธฐ๋ณธ ์ •๋ณด
hero_name = gr.Textbox(
label="ํžˆ์–ด๋กœ ์ด๋ฆ„",
placeholder="๋‹น์‹ ์˜ ํžˆ์–ด๋กœ ์ด๋ฆ„์„ ์ง€์–ด์ฃผ์„ธ์š”!",
info="ํžˆ์–ด๋กœ์˜ ์ •์ฒด์„ฑ์„ ๋‚˜ํƒ€๋‚ด๋Š” ์ด๋ฆ„์„ ์ž…๋ ฅํ•˜์„ธ์š”"
)
# ๋Šฅ๋ ฅ์น˜ ์„ค์ •
gr.Markdown("### ๐Ÿ’ช ํžˆ์–ด๋กœ ๋Šฅ๋ ฅ์น˜ ์„ค์ •")
with gr.Row():
strength = gr.Slider(
minimum=0.0, maximum=0.05,
label="๐Ÿ’ช ์ดˆ๊ฐ•๋ ฅ(๊ทผ๋ ฅ)",
value=0.02,
info="๋‹จ๋‹จํ•œ ๋ฒ ํƒ€์‹œํŠธ ๊ตฌ์กฐ๋กœ ๊ฐ•๋ ฅํ•œ ํž˜์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค"
)
flexibility = gr.Slider(
minimum=0.0, maximum=0.05,
label="๐Ÿคธโ€โ™‚๏ธ ์œ ์—ฐ์„ฑ",
value=0.02,
info="๋‚˜์„ ํ˜• ์•ŒํŒŒํ—ฌ๋ฆญ์Šค ๊ตฌ์กฐ๋กœ ์œ ์—ฐํ•œ ์›€์ง์ž„์„ ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•ฉ๋‹ˆ๋‹ค"
)
with gr.Row():
speed = gr.Slider(
minimum=0.0, maximum=0.20,
label="โšก ์Šคํ”ผ๋“œ",
value=0.1,
info="๋ฃจํ”„ ๊ตฌ์กฐ๋กœ ๋น ๋ฅธ ์›€์ง์ž„์„ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค"
)
defense = gr.Slider(
minimum=-10, maximum=10,
label="๐Ÿ›ก๏ธ ๋ฐฉ์–ด๋ ฅ",
value=0,
info="์Œ์ˆ˜: ์ˆ˜์ค‘ ํ™œ๋™์— ํŠนํ™”, ์–‘์ˆ˜: ์ง€์ƒ ํ™œ๋™์— ํŠนํ™”"
)
# ํžˆ์–ด๋กœ ํฌ๊ธฐ ์„ค์ •
hero_size = gr.Slider(
minimum=50, maximum=200,
label="๐Ÿ“ ํžˆ์–ด๋กœ ํฌ๊ธฐ",
value=100,
info="ํžˆ์–ด๋กœ์˜ ์ „์ฒด์ ์ธ ํฌ๊ธฐ๋ฅผ ๊ฒฐ์ •ํ•ฉ๋‹ˆ๋‹ค"
)
# ํŠน์ˆ˜ ๋Šฅ๋ ฅ ์„ค์ •
with gr.Accordion("๐ŸŒŸ ํŠน์ˆ˜ ๋Šฅ๋ ฅ", open=False):
gr.Markdown("""
ํŠน์ˆ˜ ๋Šฅ๋ ฅ์„ ์„ ํƒํ•˜๋ฉด ํžˆ์–ด๋กœ์˜ DNA์— ํŠน๋ณ„ํ•œ ๊ตฌ์กฐ๊ฐ€ ์ถ”๊ฐ€๋ฉ๋‹ˆ๋‹ค.
- ์ž๊ฐ€ ํšŒ๋ณต: ๋‹จ๋ฐฑ์งˆ ๊ตฌ์กฐ ๋ณต๊ตฌ ๋Šฅ๋ ฅ ๊ฐ•ํ™”
- ์›๊ฑฐ๋ฆฌ ๊ณต๊ฒฉ: ํŠน์ˆ˜ํ•œ ๊ตฌ์กฐ์  ๋Œ์ถœ๋ถ€ ํ˜•์„ฑ
- ๋ฐฉ์–ด๋ง‰ ์ƒ์„ฑ: ์•ˆ์ •์ ์ธ ๋ณดํ˜ธ์ธต ๊ตฌ์กฐ ์ƒ์„ฑ
""")
special_ability = gr.CheckboxGroup(
choices=["์ž๊ฐ€ ํšŒ๋ณต", "์›๊ฑฐ๋ฆฌ ๊ณต๊ฒฉ", "๋ฐฉ์–ด๋ง‰ ์ƒ์„ฑ"],
label="ํŠน์ˆ˜ ๋Šฅ๋ ฅ ์„ ํƒ"
)
# ์ƒ์„ฑ ๋ฒ„ํŠผ
create_btn = gr.Button("๐Ÿงฌ ํžˆ์–ด๋กœ ์ƒ์„ฑ!", variant="primary", scale=2)
with gr.TabItem("๐Ÿงฌ ํžˆ์–ด๋กœ DNA ์„ค๊ณ„"):
gr.Markdown("""
### ๐Ÿงช ํžˆ์–ด๋กœ DNA ๊ณ ๊ธ‰ ์„ค์ •
ํžˆ์–ด๋กœ์˜ ์œ ์ „์ž ๊ตฌ์กฐ๋ฅผ ๋” ์„ธ๋ฐ€ํ•˜๊ฒŒ ์กฐ์ •ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
""")
seq_opt = gr.Radio(
["์ž๋™ ์„ค๊ณ„", "์ง์ ‘ ์ž…๋ ฅ"],
label="DNA ์„ค๊ณ„ ๋ฐฉ์‹",
value="์ž๋™ ์„ค๊ณ„"
)
sequence = gr.Textbox(
label="DNA ์‹œํ€€์Šค",
lines=1,
placeholder='์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ์•„๋ฏธ๋…ธ์‚ฐ: A,C,D,E,F,G,H,I,K,L,M,N,P,Q,R,S,T,V,W,Y (X๋Š” ๋ฌด์ž‘์œ„)',
visible=False
)
seq_len = gr.Slider(
minimum=5.0, maximum=250.0,
label="DNA ๊ธธ์ด",
value=100,
visible=True
)
with gr.Accordion(label='๐Ÿฆด ๊ณจ๊ฒฉ ๊ตฌ์กฐ ์„ค์ •', open=True):
gr.Markdown("""
ํžˆ์–ด๋กœ์˜ ๊ธฐ๋ณธ ๊ณจ๊ฒฉ ๊ตฌ์กฐ๋ฅผ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค.
- ๋‚˜์„ ํ˜• ๊ตฌ์กฐ: ์œ ์—ฐํ•˜๊ณ  ํƒ„๋ ฅ์žˆ๋Š” ์›€์ง์ž„
- ๋ณ‘ํ’ํ˜• ๊ตฌ์กฐ: ๋‹จ๋‹จํ•˜๊ณ  ๊ฐ•๋ ฅํ•œ ํž˜
- ๊ณ ๋ฆฌํ˜• ๊ตฌ์กฐ: ๋น ๋ฅด๊ณ  ๋ฏผ์ฒฉํ•œ ์›€์ง์ž„
""")
sec_str_opt = gr.Radio(
["์Šฌ๋ผ์ด๋”๋กœ ์„ค์ •", "์ง์ ‘ ์ž…๋ ฅ"],
label="๊ณจ๊ฒฉ ๊ตฌ์กฐ ์„ค์ • ๋ฐฉ์‹",
value="์Šฌ๋ผ์ด๋”๋กœ ์„ค์ •"
)
secondary_structure = gr.Textbox(
label="๊ณจ๊ฒฉ ๊ตฌ์กฐ",
lines=1,
placeholder='H:๋‚˜์„ ํ˜•, S:๋ณ‘ํ’ํ˜•, L:๊ณ ๋ฆฌํ˜•, X:์ž๋™์„ค์ •',
visible=False
)
with gr.Column():
helix_bias = gr.Slider(
minimum=0.0, maximum=0.05,
label="๋‚˜์„ ํ˜• ๊ตฌ์กฐ ๋น„์œจ",
visible=True
)
strand_bias = gr.Slider(
minimum=0.0, maximum=0.05,
label="๋ณ‘ํ’ํ˜• ๊ตฌ์กฐ ๋น„์œจ",
visible=True
)
loop_bias = gr.Slider(
minimum=0.0, maximum=0.20,
label="๊ณ ๋ฆฌํ˜• ๊ตฌ์กฐ ๋น„์œจ",
visible=True
)
with gr.Accordion(label='๐Ÿงฌ DNA ๊ตฌ์„ฑ ์„ค์ •', open=False):
gr.Markdown("""
ํŠน์ • ์•„๋ฏธ๋…ธ์‚ฐ์˜ ๋น„์œจ์„ ์กฐ์ ˆํ•˜์—ฌ ํžˆ์–ด๋กœ์˜ ํŠน์„ฑ์„ ๊ฐ•ํ™”ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
์˜ˆ์‹œ: W0.2,E0.1 (ํŠธ๋ฆฝํ† ํŒ 20%, ๊ธ€๋ฃจํƒ์‚ฐ 10%)
""")
with gr.Row():
aa_bias = gr.Textbox(
label="์•„๋ฏธ๋…ธ์‚ฐ ๋น„์œจ",
lines=1,
placeholder='์˜ˆ์‹œ: W0.2,E0.1'
)
aa_bias_potential = gr.Textbox(
label="๊ฐ•ํ™” ์ •๋„",
lines=1,
placeholder='1.0-5.0 ์‚ฌ์ด ๊ฐ’ ์ž…๋ ฅ'
)
with gr.Accordion(label='๐ŸŒ ํ™˜๊ฒฝ ์ ์‘๋ ฅ ์„ค์ •', open=False):
gr.Markdown("""
ํžˆ์–ด๋กœ์˜ ํ™˜๊ฒฝ ์ ์‘๋ ฅ์„ ์กฐ์ ˆํ•ฉ๋‹ˆ๋‹ค.
์Œ์ˆ˜: ์ˆ˜์ค‘ ํ™œ๋™์— ํŠนํ™”, ์–‘์ˆ˜: ์ง€์ƒ ํ™œ๋™์— ํŠนํ™”
""")
with gr.Row():
hydrophobic_target_score = gr.Textbox(
label="ํ™˜๊ฒฝ ์ ์‘ ์ ์ˆ˜",
lines=1,
placeholder='์˜ˆ์‹œ: -5 (์ˆ˜์ค‘ ํ™œ๋™์— ํŠนํ™”)'
)
hydrophobic_potential = gr.Textbox(
label="์ ์‘๋ ฅ ๊ฐ•ํ™” ์ •๋„",
lines=1,
placeholder='1.0-2.0 ์‚ฌ์ด ๊ฐ’ ์ž…๋ ฅ'
)
with gr.Accordion(label='โš™๏ธ ๊ณ ๊ธ‰ ์„ค์ •', open=False):
gr.Markdown("""
DNA ์ƒ์„ฑ ๊ณผ์ •์˜ ์„ธ๋ถ€ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์กฐ์ •ํ•ฉ๋‹ˆ๋‹ค.
""")
with gr.Row():
num_steps = gr.Textbox(
label="์ƒ์„ฑ ๋‹จ๊ณ„",
lines=1,
placeholder='25 ์ดํ•˜ ๊ถŒ์žฅ'
)
noise = gr.Dropdown(
['normal','gmm2 [-1,1]','gmm3 [-1,0,1]'],
label='๋…ธ์ด์ฆˆ ํƒ€์ž…',
value='normal'
)
design_btn = gr.Button("๐Ÿงฌ DNA ์„ค๊ณ„ ์ƒ์„ฑ!", variant="primary", scale=2)
with gr.TabItem("๐Ÿงช ํžˆ์–ด๋กœ ์œ ์ „์ž ๊ฐ•ํ™”"):
gr.Markdown("""
### โšก ๊ธฐ์กด ํžˆ์–ด๋กœ์˜ DNA ํ™œ์šฉ
๊ฐ•๋ ฅํ•œ ํžˆ์–ด๋กœ์˜ DNA ์ผ๋ถ€๋ฅผ ์ƒˆ๋กœ์šด ํžˆ์–ด๋กœ์—๊ฒŒ ์ด์‹ํ•ฉ๋‹ˆ๋‹ค.
""")
gr.Markdown("๊ณต๊ฐœ๋œ ํžˆ์–ด๋กœ DNA ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค์—์„œ ์ฝ”๋“œ๋ฅผ ์ฐพ์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค")
pdb_id_code = gr.Textbox(
label="ํžˆ์–ด๋กœ DNA ์ฝ”๋“œ",
lines=1,
placeholder='๊ธฐ์กด ํžˆ์–ด๋กœ์˜ DNA ์ฝ”๋“œ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š” (์˜ˆ: 1DPX)'
)
gr.Markdown("์ด์‹ํ•˜๊ณ  ์‹ถ์€ DNA ์˜์—ญ์„ ์„ ํƒํ•˜๊ณ  ์ƒˆ๋กœ์šด DNA๋ฅผ ์ถ”๊ฐ€ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค")
contigs = gr.Textbox(
label="์ด์‹ํ•  DNA ์˜์—ญ",
lines=1,
placeholder='์˜ˆ์‹œ: 15,A3-10,20-30'
)
with gr.Row():
seq_mask = gr.Textbox(
label='๋Šฅ๋ ฅ ์žฌ์„ค๊ณ„',
lines=1,
placeholder='์„ ํƒํ•œ ์˜์—ญ์˜ ๋Šฅ๋ ฅ์„ ์ƒˆ๋กญ๊ฒŒ ๋””์ž์ธ'
)
str_mask = gr.Textbox(
label='๊ตฌ์กฐ ์žฌ์„ค๊ณ„',
lines=1,
placeholder='์„ ํƒํ•œ ์˜์—ญ์˜ ๊ตฌ์กฐ๋ฅผ ์ƒˆ๋กญ๊ฒŒ ๋””์ž์ธ'
)
preview_viewer = gr.HTML()
rewrite_pdb = gr.File(label='ํžˆ์–ด๋กœ DNA ํŒŒ์ผ')
preview_btn = gr.Button("๐Ÿ” ๋ฏธ๋ฆฌ๋ณด๊ธฐ", variant="secondary")
enhance_btn = gr.Button("โšก ๊ฐ•ํ™”๋œ ํžˆ์–ด๋กœ ์ƒ์„ฑ!", variant="primary", scale=2)
with gr.TabItem("๐Ÿ‘‘ ํžˆ์–ด๋กœ ๊ฐ€๋ฌธ"):
gr.Markdown("""
### ๐Ÿฐ ์œ„๋Œ€ํ•œ ํžˆ์–ด๋กœ ๊ฐ€๋ฌธ์˜ ์œ ์‚ฐ
๊ฐ•๋ ฅํ•œ ํžˆ์–ด๋กœ ๊ฐ€๋ฌธ์˜ ํŠน์„ฑ์„ ๊ณ„์Šนํ•˜์—ฌ ์ƒˆ๋กœ์šด ํžˆ์–ด๋กœ๋ฅผ ๋งŒ๋“ญ๋‹ˆ๋‹ค.
""")
with gr.Row():
with gr.Column():
gr.Markdown("ํžˆ์–ด๋กœ ๊ฐ€๋ฌธ์˜ DNA ์ •๋ณด๊ฐ€ ๋‹ด๊ธด ํŒŒ์ผ์„ ์—…๋กœ๋“œํ•˜์„ธ์š”")
fasta_msa = gr.File(label='๊ฐ€๋ฌธ DNA ๋ฐ์ดํ„ฐ')
with gr.Column():
gr.Markdown("์ด๋ฏธ ๋ถ„์„๋œ ๊ฐ€๋ฌธ ํŠน์„ฑ ๋ฐ์ดํ„ฐ๊ฐ€ ์žˆ๋‹ค๋ฉด ์—…๋กœ๋“œํ•˜์„ธ์š”")
input_pssm = gr.File(label='๊ฐ€๋ฌธ ํŠน์„ฑ ๋ฐ์ดํ„ฐ')
pssm = gr.File(label='๋ถ„์„๋œ ๊ฐ€๋ฌธ ํŠน์„ฑ')
pssm_view = gr.Plot(label='๊ฐ€๋ฌธ ํŠน์„ฑ ๋ถ„์„ ๊ฒฐ๊ณผ')
pssm_gen_btn = gr.Button("โœจ ๊ฐ€๋ฌธ ํŠน์„ฑ ๋ถ„์„", variant="secondary")
inherit_btn = gr.Button("๐Ÿ‘‘ ๊ฐ€๋ฌธ์˜ ํž˜ ๊ณ„์Šน!", variant="primary", scale=2)
# ์˜ค๋ฅธ์ชฝ ์—ด: ๊ฒฐ๊ณผ ํ‘œ์‹œ
with gr.Column(scale=1):
gr.Markdown("## ๐Ÿฆธโ€โ™‚๏ธ ํžˆ์–ด๋กœ ํ”„๋กœํ•„")
hero_stats = gr.Plot(label="๋Šฅ๋ ฅ์น˜ ๋ถ„์„")
hero_description = gr.Textbox(label="ํžˆ์–ด๋กœ ํŠน์„ฑ", lines=3)
gr.Markdown("## ๐Ÿงฌ ํžˆ์–ด๋กœ DNA ๋ถ„์„ ๊ฒฐ๊ณผ")
gr.Markdown("#### โšก DNA ์•ˆ์ •์„ฑ ์ ์ˆ˜")
plddt_plot = gr.Plot(label='์•ˆ์ •์„ฑ ๋ถ„์„')
gr.Markdown("#### ๐Ÿ“ DNA ์‹œํ€€์Šค")
output_seq = gr.Textbox(label="DNA ์„œ์—ด")
gr.Markdown("#### ๐Ÿ’พ DNA ๋ฐ์ดํ„ฐ")
output_pdb = gr.File(label="DNA ํŒŒ์ผ")
gr.Markdown("#### ๐Ÿ”ฌ DNA ๊ตฌ์กฐ")
output_viewer = gr.HTML()
# ์ด๋ฒคํŠธ ์—ฐ๊ฒฐ
# ์ฑ—๋ด‡ ์ด๋ฒคํŠธ
msg.submit(process_chat, [msg, chatbot], [chatbot])
clear.click(lambda: None, None, chatbot, queue=False)
seq_opt.change(
fn=toggle_seq_input,
inputs=[seq_opt],
outputs=[seq_len, sequence],
queue=False
)
sec_str_opt.change(
fn=toggle_secondary_structure,
inputs=[sec_str_opt],
outputs=[helix_bias, strand_bias, loop_bias, secondary_structure],
queue=False
)
preview_btn.click(
get_motif_preview,
inputs=[pdb_id_code, contigs],
outputs=[preview_viewer, rewrite_pdb]
)
pssm_gen_btn.click(
get_pssm,
inputs=[fasta_msa, input_pssm],
outputs=[pssm_view, pssm]
)
# ์ฑ—๋ด‡ ๊ธฐ๋ฐ˜ ๋‹จ๋ฐฑ์งˆ ์ƒ์„ฑ ๊ฒฐ๊ณผ ์—…๋ฐ์ดํŠธ
def update_protein_display(chat_response):
if "์ƒ์„ฑ๋œ ๋‹จ๋ฐฑ์งˆ ๋ถ„์„" in chat_response:
params = extract_parameters_from_chat(chat_response)
result = generate_protein(params)
return {
                hero_stats: create_radar_chart(calculate_hero_stats(
                    params['helix_bias'],
                    params['strand_bias'],
                    params['loop_bias'],
                    params['hydrophobic_target_score']
                )),
hero_description: chat_response,
output_seq: result[0],
output_pdb: result[1],
output_viewer: display_pdb(result[1]),
plddt_plot: result[3]
}
return None
# ๊ฐ ์ƒ์„ฑ ๋ฒ„ํŠผ ์ด๋ฒคํŠธ ์—ฐ๊ฒฐ
for btn in [create_btn, design_btn, enhance_btn, inherit_btn]:
btn.click(
combined_generation,
inputs=[
hero_name, strength, flexibility, speed, defense, hero_size, special_ability,
sequence, seq_len, helix_bias, strand_bias, loop_bias,
secondary_structure, aa_bias, aa_bias_potential,
num_steps, noise, hydrophobic_target_score, hydrophobic_potential,
contigs, pssm, seq_mask, str_mask, rewrite_pdb
],
outputs=[
hero_stats,
hero_description,
output_seq,
output_pdb,
output_viewer,
plddt_plot
]
)
# ์ด๋ฒคํŠธ ํ•ธ๋“ค๋Ÿฌ ์—ฐ๊ฒฐ
msg.submit(
fn=process_chat_and_generate,
inputs=[msg, chatbot],
outputs=[
chatbot,
hero_stats,
hero_description,
output_seq,
output_pdb,
output_viewer,
plddt_plot
]
)
submit_btn.click(
fn=process_chat_and_generate,
inputs=[msg, chatbot],
outputs=[
chatbot,
hero_stats,
hero_description,
output_seq,
output_pdb,
output_viewer,
plddt_plot
]
)
# ์ฑ„ํŒ… ๋‚ด์šฉ ์ง€์šฐ๊ธฐ
clear.click(
lambda: (None, None, None, None, None, None, None),
None,
[chatbot, hero_stats, hero_description, output_seq, output_pdb, output_viewer, plddt_plot],
queue=False
)
# ์ฑ—๋ด‡ ์‘๋‹ต์— ๋”ฐ๋ฅธ ๊ฒฐ๊ณผ ์—…๋ฐ์ดํŠธ
msg.submit(
update_protein_display,
inputs=[chatbot],
outputs=[hero_stats, hero_description, output_seq, output_pdb, output_viewer, plddt_plot]
)
submit_btn.click(respond,
[msg, chatbot, system_message, max_tokens, temperature, top_p],
[chatbot])
msg.submit(respond,
[msg, chatbot, system_message, max_tokens, temperature, top_p],
[chatbot])
clear.click(lambda: None, None, chatbot, queue=False)
# ์ด๋ฒคํŠธ ํ•ธ๋“ค๋Ÿฌ ์ˆ˜์ •
msg.submit(
fn=process_chat_and_generate,
inputs=[msg, chatbot],
outputs=[
chatbot,
hero_stats,
hero_description,
output_seq,
output_pdb,
output_viewer,
plddt_plot
],
show_progress=True
).then(
lambda: None, # ์ง„ํ–‰ ์ƒํƒœ ์ดˆ๊ธฐํ™”
None,
None
)
# ์‹คํ–‰
demo.queue()
demo.launch(debug=True)