Spaces:
Runtime error
Runtime error
stakelovelace
committed on
Commit
·
3b6b2b0
1
Parent(s):
ab60a3a
app.py
CHANGED
|
@@ -4,6 +4,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments,
|
|
| 4 |
import csv
|
| 5 |
import yaml
|
| 6 |
from datasets import Dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
import os
|
| 8 |
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
|
| 9 |
|
|
@@ -19,7 +24,8 @@ def load_data_and_config(data_path):
|
|
| 19 |
def generate_api_query(model, tokenizer, prompt, desired_output, api_name, base_url):
|
| 20 |
"""Generates an API query using a fine-tuned model."""
|
| 21 |
input_ids = tokenizer.encode(prompt + f" Write an API query to {api_name} to get {desired_output}", return_tensors="pt")
|
| 22 |
-
|
|
|
|
| 23 |
query = tokenizer.decode(output[0], skip_special_tokens=True)
|
| 24 |
return f"{base_url}/{query}"
|
| 25 |
|
|
@@ -58,19 +64,19 @@ def train_model(model, tokenizer, data):
|
|
| 58 |
# Optionally clear cache if using GPU or MPS
|
| 59 |
if torch.cuda.is_available():
|
| 60 |
torch.cuda.empty_cache()
|
| 61 |
-
elif torch.
|
| 62 |
torch.mps.empty_cache()
|
| 63 |
|
| 64 |
# Perform any remaining steps such as logging, saving, etc.
|
| 65 |
trainer.save_model()
|
| 66 |
|
| 67 |
-
|
| 68 |
-
# Load data
|
| 69 |
data = load_data_and_config("train2.csv")
|
| 70 |
|
| 71 |
# Load tokenizer and model
|
| 72 |
-
tokenizer = AutoTokenizer.from_pretrained("
|
| 73 |
-
model = AutoModelForCausalLM.from_pretrained("
|
| 74 |
|
| 75 |
# Train the model on your dataset
|
| 76 |
train_model(model, tokenizer, data)
|
|
@@ -81,5 +87,10 @@ if __name__ == "__main__":
|
|
| 81 |
|
| 82 |
# Example usage
|
| 83 |
prompt = "I need to retrieve the latest block on chain using a python script"
|
| 84 |
-
api_query = generate_api_query(model, tokenizer, prompt, "latest block on chain",
|
| 85 |
print(f"Generated code: {api_query}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import csv
|
| 5 |
import yaml
|
| 6 |
from datasets import Dataset
|
| 7 |
+
|
| 8 |
+
import tensorflow as tf
|
| 9 |
+
# Check TensorFlow GPU availability
|
| 10 |
+
print("GPUs Available: ", tf.config.list_physical_devices('GPU'))
|
| 11 |
+
|
| 12 |
import os
|
| 13 |
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
|
| 14 |
|
|
|
|
| 24 |
def generate_api_query(model, tokenizer, prompt, desired_output, api_name, base_url):
|
| 25 |
"""Generates an API query using a fine-tuned model."""
|
| 26 |
input_ids = tokenizer.encode(prompt + f" Write an API query to {api_name} to get {desired_output}", return_tensors="pt")
|
| 27 |
+
input_ids = input_ids.to(model.device) # Ensure input_ids are on the same device as the model
|
| 28 |
+
output = model.generate(input_ids, max_length=256, temperature=0.7, do_sample=True) # Enable sampling with temperature control
|
| 29 |
query = tokenizer.decode(output[0], skip_special_tokens=True)
|
| 30 |
return f"{base_url}/{query}"
|
| 31 |
|
|
|
|
| 64 |
# Optionally clear cache if using GPU or MPS
|
| 65 |
if torch.cuda.is_available():
|
| 66 |
torch.cuda.empty_cache()
|
| 67 |
+
elif torch.backends.mps.is_built():
|
| 68 |
torch.mps.empty_cache()
|
| 69 |
|
| 70 |
# Perform any remaining steps such as logging, saving, etc.
|
| 71 |
trainer.save_model()
|
| 72 |
|
| 73 |
+
def main(api_name, base_url):
|
| 74 |
+
# Load data
|
| 75 |
data = load_data_and_config("train2.csv")
|
| 76 |
|
| 77 |
# Load tokenizer and model
|
| 78 |
+
tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-small")
|
| 79 |
+
model = AutoModelForCausalLM.from_pretrained("thenlper/gte-small")
|
| 80 |
|
| 81 |
# Train the model on your dataset
|
| 82 |
train_model(model, tokenizer, data)
|
|
|
|
| 87 |
|
| 88 |
# Example usage
|
| 89 |
prompt = "I need to retrieve the latest block on chain using a python script"
|
| 90 |
+
api_query = generate_api_query(model, tokenizer, prompt, "latest block on chain", api_name, base_url)
|
| 91 |
print(f"Generated code: {api_query}")
|
| 92 |
+
|
| 93 |
+
if __name__ == "__main__":
|
| 94 |
+
api_name = "Koios"
|
| 95 |
+
base_url = "https://api.koios.rest"
|
| 96 |
+
main(api_name, base_url)
|