whackthejacker committed
Commit fa6dfc8 · verified · 1 Parent(s): caf7335

Create models/codet5.py

Files changed (1)
  1. models/codet5.py +26 -0
models/codet5.py ADDED
@@ -0,0 +1,26 @@
+ import torch
+ from transformers import RobertaTokenizer, T5ForConditionalGeneration
+
+ class CodeT5:
+     def __init__(self):
+         self.tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base')
+         self.model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base')
+
+     def analyze(self, repo_data, github_api):
+         if isinstance(repo_data, str):  # github_api returned an error message instead of a file listing
+             return repo_data
+         optimization_results = []
+         for file in repo_data:
+             if file["type"] == "file" and file["name"].endswith((".py", ".js", ".java", ".c", ".cpp")):
+                 content = github_api.get_file_content(file["download_url"])
+                 if isinstance(content, str) and content.startswith("Error"):  # error fetching file content
+                     optimization_results.append(f"{file['name']}: {content}")
+                     continue
+                 try:
+                     inputs = self.tokenizer.encode(content, return_tensors="pt", max_length=512, truncation=True)
+                     outputs = self.model.generate(inputs, max_length=256)
+                     decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+                     optimization_results.append(f"{file['name']}: {decoded_output}")
+                 except Exception as e:
+                     optimization_results.append(f"{file['name']}: Error analyzing - {e}")
+         return "\n".join(optimization_results)
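
For reference, a minimal usage sketch (not part of this commit): it assumes a hypothetical GitHub client exposing get_file_content(download_url), matching the interface analyze() expects, and a repo_data listing shaped like a GitHub contents API response. The stub class and URL below are placeholders for illustration only.

    from models.codet5 import CodeT5

    class StubGitHubAPI:
        # Hypothetical stand-in for the project's real GitHub client;
        # only the single method analyze() calls is implemented here.
        def get_file_content(self, download_url):
            return "def add(a, b):\n    return a + b\n"

    # repo_data mirrors the shape of a GitHub "contents" listing.
    repo_data = [
        {"type": "file", "name": "example.py", "download_url": "https://example.com/example.py"},
    ]

    analyzer = CodeT5()  # downloads Salesforce/codet5-base on first use
    print(analyzer.analyze(repo_data, StubGitHubAPI()))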