aiqcamp commited on
Commit
99e133c
·
verified ·
1 Parent(s): 8ff8fc7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -15
app.py CHANGED
@@ -1,11 +1,8 @@
1
  import os,sys
2
- from openai import OpenAI
3
- import gradio as gr
4
- import json # json 모듈 추가
5
 
6
  # install required packages
7
- os.system('pip install -q plotly')
8
- os.system('pip install -q matplotlib')
9
  os.system('pip install dgl==1.0.2+cu116 -f https://data.dgl.ai/wheels/cu116/repo.html')
10
  os.environ["DGLBACKEND"] = "pytorch"
11
  print('Modules installed')
@@ -32,11 +29,20 @@ if not os.path.exists('./tmp/args.json'):
32
  with open('./tmp/args.json', 'w') as f:
33
  json.dump(default_args, f)
34
 
35
- # args 로드
36
- with open('./tmp/args.json', 'r') as f:
37
- args = json.load(f)
 
 
 
 
 
 
 
38
 
39
- # 필수 라이브러리 임포트
 
 
40
  from datasets import load_dataset
41
  import plotly.graph_objects as go
42
  import numpy as np
@@ -52,10 +58,21 @@ from model.util import writepdb
52
  from utils.inpainting_util import *
53
  import os
54
 
55
- # 현재 스크립트의 디렉토리를 기준으로 체크포인트 파일 경로 설정
56
- current_dir = os.path.dirname(os.path.abspath(__file__))
57
- dssp_checkpoint = os.path.join(current_dir, 'SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt')
58
- og_checkpoint = os.path.join(current_dir, 'SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt')
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  # Hugging Face 토큰 설정
61
  ACCESS_TOKEN = os.getenv("HF_TOKEN")
@@ -351,12 +368,15 @@ def generate_explanation(result, params):
351
  return explanation
352
 
353
  # 체크포인트 파일 경로를 절대 경로로 수정
354
-
355
-
356
  def protein_diffusion_model(sequence, seq_len, helix_bias, strand_bias, loop_bias,
357
  secondary_structure, aa_bias, aa_bias_potential,
358
  num_steps, noise, hydrophobic_target_score, hydrophobic_potential,
359
  contigs, pssm, seq_mask, str_mask, rewrite_pdb):
 
 
 
 
 
360
 
361
  # 체크포인트 파일 존재 확인
362
  if not os.path.exists(dssp_checkpoint):
 
1
  import os,sys
 
 
 
2
 
3
  # install required packages
4
+ os.system('pip install plotly') # plotly 설치
5
+ os.system('pip install matplotlib') # matplotlib 설치
6
  os.system('pip install dgl==1.0.2+cu116 -f https://data.dgl.ai/wheels/cu116/repo.html')
7
  os.environ["DGLBACKEND"] = "pytorch"
8
  print('Modules installed')
 
29
  with open('./tmp/args.json', 'w') as f:
30
  json.dump(default_args, f)
31
 
32
+ # 체크포인트 파일 다운로드
33
+ if not os.path.exists('./SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt'):
34
+ print('Downloading model weights 1')
35
+ os.system('wget http://files.ipd.uw.edu/pub/sequence_diffusion/checkpoints/SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt')
36
+ print('Successfully Downloaded')
37
+
38
+ if not os.path.exists('./SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt'):
39
+ print('Downloading model weights 2')
40
+ os.system('wget http://files.ipd.uw.edu/pub/sequence_diffusion/checkpoints/SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt')
41
+ print('Successfully Downloaded')
42
 
43
+ from openai import OpenAI
44
+ import gradio as gr
45
+ import json # json 모듈 추가
46
  from datasets import load_dataset
47
  import plotly.graph_objects as go
48
  import numpy as np
 
58
  from utils.inpainting_util import *
59
  import os
60
 
61
+ # args 로드
62
+ with open('./tmp/args.json', 'r') as f:
63
+ args = json.load(f)
64
+
65
+ plt.rcParams.update({'font.size': 13})
66
+
67
+ # manually set checkpoint to load
68
+ args['checkpoint'] = None
69
+ args['dump_trb'] = False
70
+ args['dump_args'] = True
71
+ args['save_best_plddt'] = True
72
+ args['T'] = 25
73
+ args['strand_bias'] = 0.0
74
+ args['loop_bias'] = 0.0
75
+ args['helix_bias'] = 0.0
76
 
77
  # Hugging Face 토큰 설정
78
  ACCESS_TOKEN = os.getenv("HF_TOKEN")
 
368
  return explanation
369
 
370
  # 체크포인트 파일 경로를 절대 경로로 수정
 
 
371
  def protein_diffusion_model(sequence, seq_len, helix_bias, strand_bias, loop_bias,
372
  secondary_structure, aa_bias, aa_bias_potential,
373
  num_steps, noise, hydrophobic_target_score, hydrophobic_potential,
374
  contigs, pssm, seq_mask, str_mask, rewrite_pdb):
375
+
376
+
377
+ dssp_checkpoint = './SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt'
378
+ og_checkpoint = './SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt'
379
+
380
 
381
  # 체크포인트 파일 존재 확인
382
  if not os.path.exists(dssp_checkpoint):