Henry65 commited on
Commit
134c3f4
·
1 Parent(s): 3690859

Update RepoPipeline.py

Browse files
Files changed (1) hide show
  1. RepoPipeline.py +8 -8
RepoPipeline.py CHANGED
@@ -179,26 +179,26 @@ class RepoPipeline(Pipeline):
179
  # Code embeddings
180
  tqdm.write(f"[*] Generating code embeddings for {repo_name}")
181
  code_embeddings = self.generate_embeddings(repo_info["codes"], max_length)
182
- info["code_embeddings"] = code_embeddings.numpy()
183
- info["mean_code_embedding"] = torch.mean(code_embeddings, dim=0).numpy()
184
 
185
  # Doc embeddings
186
  tqdm.write(f"[*] Generating doc embeddings for {repo_name}")
187
  doc_embeddings = self.generate_embeddings(repo_info["docs"], max_length)
188
- info["doc_embeddings"] = doc_embeddings.numpy()
189
- info["mean_doc_embedding"] = torch.mean(doc_embeddings, dim=0).numpy()
190
 
191
  # Requirement embeddings
192
  tqdm.write(f"[*] Generating requirement embeddings for {repo_name}")
193
  requirement_embeddings = self.generate_embeddings(repo_info["requirements"], max_length)
194
- info["requirement_embeddings"] = requirement_embeddings.numpy()
195
- info["mean_requirement_embedding"] = torch.mean(requirement_embeddings, dim=0).numpy()
196
 
197
  # Requirement embeddings
198
  tqdm.write(f"[*] Generating readme embeddings for {repo_name}")
199
  readme_embeddings = self.generate_embeddings(repo_info["readmes"], max_length)
200
- info["readme_embeddings"] = readme_embeddings.numpy()
201
- info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0).numpy()
202
 
203
  progress_bar.update(1)
204
  model_outputs.append(info)
 
179
  # Code embeddings
180
  tqdm.write(f"[*] Generating code embeddings for {repo_name}")
181
  code_embeddings = self.generate_embeddings(repo_info["codes"], max_length)
182
+ info["code_embeddings"] = code_embeddings.cpu().numpy()
183
+ info["mean_code_embedding"] = torch.mean(code_embeddings, dim=0).cpu().numpy()
184
 
185
  # Doc embeddings
186
  tqdm.write(f"[*] Generating doc embeddings for {repo_name}")
187
  doc_embeddings = self.generate_embeddings(repo_info["docs"], max_length)
188
+ info["doc_embeddings"] = doc_embeddings.cpu().numpy()
189
+ info["mean_doc_embedding"] = torch.mean(doc_embeddings, dim=0).cpu().numpy()
190
 
191
  # Requirement embeddings
192
  tqdm.write(f"[*] Generating requirement embeddings for {repo_name}")
193
  requirement_embeddings = self.generate_embeddings(repo_info["requirements"], max_length)
194
+ info["requirement_embeddings"] = requirement_embeddings.cpu().numpy()
195
+ info["mean_requirement_embedding"] = torch.mean(requirement_embeddings, dim=0).cpu().numpy()
196
 
197
  # Requirement embeddings
198
  tqdm.write(f"[*] Generating readme embeddings for {repo_name}")
199
  readme_embeddings = self.generate_embeddings(repo_info["readmes"], max_length)
200
+ info["readme_embeddings"] = readme_embeddings.cpu().numpy()
201
+ info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0).cpu().numpy()
202
 
203
  progress_bar.update(1)
204
  model_outputs.append(info)