ka1kuk commited on
Commit
348a0a6
·
verified ·
1 Parent(s): 25c22bd

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +48 -60
main.py CHANGED
@@ -1,68 +1,56 @@
1
- from fastapi import FastAPI, HTTPException, Request
2
- from rdkit import Chem
3
- from rdkit.Chem import Descriptors
4
- from rdkit.Chem import rdMolDescriptors
5
  import requests
6
- from fastapi.middleware.cors import CORSMiddleware
 
 
7
 
8
  app = FastAPI()
9
 
10
- app.add_middleware(
11
- CORSMiddleware,
12
- allow_credentials=True,
13
- allow_methods=["*"],
14
- allow_headers=["*"],
15
- )
16
-
17
- def name_to_smiles(name: str) -> str:
18
- url = f"https://cactus.nci.nih.gov/chemical/structure/{name}/smiles"
19
- response = requests.get(url)
20
- if response.status_code == 200:
21
- return response.text
22
- return None
23
-
24
- def get_molecule_info(mol):
25
- mol_weight = Descriptors.MolWt(mol)
26
- num_atoms = mol.GetNumAtoms()
27
- num_bonds = mol.GetNumBonds()
28
- mol_formula = rdMolDescriptors.CalcMolFormula(mol)
29
- tpsa = Descriptors.TPSA(mol)
30
- mol_logp = Descriptors.MolLogP(mol)
31
- num_rotatable_bonds = Descriptors.NumRotatableBonds(mol)
32
-
33
- return {
34
- 'molecular_weight': mol_weight,
35
- 'number_of_atoms': num_atoms,
36
- 'number_of_bonds': num_bonds,
37
- 'molecular_formula': mol_formula,
38
- 'tpsa': tpsa,
39
- 'logP': mol_logp,
40
- 'number_of_rotatable_bonds': num_rotatable_bonds
41
  }
42
-
43
- @app.get("/")
44
- def home():
45
- return "Hello! This is ChemDB."
46
-
47
-
48
- @app.get("/molecule_info/{name}")
49
- async def read_molecule_info(name: str, request: Request):
50
- if not name:
51
- raise HTTPException(status_code=400, detail="No molecule name provided")
52
-
53
- smiles = name_to_smiles(name)
54
- if not smiles:
55
- raise HTTPException(status_code=400, detail="Could not fetch SMILES string for provided name")
56
-
57
- mol = Chem.MolFromSmiles(smiles)
58
-
59
- if not mol:
60
- raise HTTPException(status_code=400, detail="Molecule not recognized")
61
-
62
- info = get_molecule_info(mol)
63
-
64
- return info
65
 
66
  if __name__ == "__main__":
67
  import uvicorn
68
- uvicorn.run(app, host="0.0.0.0", port=3000)
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.responses import StreamingResponse
 
 
3
  import requests
4
+ from urllib.parse import quote
5
+ import asyncio
6
+ from io import BytesIO
7
 
8
  app = FastAPI()
9
 
10
+ def create_job(prompt, model, sampler, seed, neg):
11
+ if model is None:
12
+ model = 'Realistic_Vision_V5.0.safetensors [614d1063]'
13
+ if sampler is None:
14
+ sampler = 'DPM++ 2M Karras'
15
+ if seed is None:
16
+ seed = '-1'
17
+ if neg is None:
18
+ neg = "(long list of negative prompts removed for brevity)"
19
+ url = 'https://api.prodia.com/generate'
20
+ params = {
21
+ 'new': 'true',
22
+ 'prompt': quote(prompt),
23
+ 'model': model,
24
+ 'negative_prompt': neg,
25
+ 'steps': '100',
26
+ 'cfg': '9.5',
27
+ 'seed': seed,
28
+ 'sampler': sampler,
29
+ 'upscale': 'True',
30
+ 'aspect_ratio': 'square'
 
 
 
 
 
 
 
 
 
 
31
  }
32
+ response = requests.get(url, params=params)
33
+ response.raise_for_status()
34
+ return response.json()['job']
35
+
36
+ @app.get("/generate_image")
37
+ async def generate_image(prompt: str, model: str = None, sampler: str = None, seed: str = None, neg: str = None):
38
+ job_id = create_job(prompt, model, sampler, seed, neg)
39
+ url = f'https://api.prodia.com/job/{job_id}'
40
+ headers = {'accept': '*/*'}
41
+
42
+ while True:
43
+ response = requests.get(url=url, headers=headers)
44
+ response.raise_for_status()
45
+ job_response = response.json()
46
+ if job_response['status'] == 'succeeded':
47
+ image_url = f'https://images.prodia.xyz/{job_id}.png'
48
+ image_response = requests.get(image_url)
49
+ image_response.raise_for_status()
50
+ image = BytesIO(image_response.content)
51
+ return StreamingResponse(image, media_type='image/png')
52
+ await asyncio.sleep(2) # Add a delay to prevent excessive requests
 
 
53
 
54
  if __name__ == "__main__":
55
  import uvicorn
56
+ uvicorn.run(app, host="0.0.0.0", port=8000, debug=True)