Update RepoPipeline.py
Browse files
- RepoPipeline.py +8 -8
RepoPipeline.py
CHANGED
@@ -179,26 +179,26 @@ class RepoPipeline(Pipeline):
|
|
179 |
# Code embeddings
|
180 |
tqdm.write(f"[*] Generating code embeddings for {repo_name}")
|
181 |
code_embeddings = self.generate_embeddings(repo_info["codes"], max_length)
|
182 |
-
info["code_embeddings"] = code_embeddings.numpy()
|
183 |
-
info["mean_code_embedding"] = torch.mean(code_embeddings, dim=0).numpy()
|
184 |
|
185 |
# Doc embeddings
|
186 |
tqdm.write(f"[*] Generating doc embeddings for {repo_name}")
|
187 |
doc_embeddings = self.generate_embeddings(repo_info["docs"], max_length)
|
188 |
-
info["doc_embeddings"] = doc_embeddings.numpy()
|
189 |
-
info["mean_doc_embedding"] = torch.mean(doc_embeddings, dim=0).numpy()
|
190 |
|
191 |
# Requirement embeddings
|
192 |
tqdm.write(f"[*] Generating requirement embeddings for {repo_name}")
|
193 |
requirement_embeddings = self.generate_embeddings(repo_info["requirements"], max_length)
|
194 |
-
info["requirement_embeddings"] = requirement_embeddings.numpy()
|
195 |
-
info["mean_requirement_embedding"] = torch.mean(requirement_embeddings, dim=0).numpy()
|
196 |
|
197 |
# Readme embeddings
|
198 |
tqdm.write(f"[*] Generating readme embeddings for {repo_name}")
|
199 |
readme_embeddings = self.generate_embeddings(repo_info["readmes"], max_length)
|
200 |
-
info["readme_embeddings"] = readme_embeddings.numpy()
|
201 |
-
info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0).numpy()
|
202 |
|
203 |
progress_bar.update(1)
|
204 |
model_outputs.append(info)
|
|
|
179 |
# Code embeddings
|
180 |
tqdm.write(f"[*] Generating code embeddings for {repo_name}")
|
181 |
code_embeddings = self.generate_embeddings(repo_info["codes"], max_length)
|
182 |
+
info["code_embeddings"] = code_embeddings.cpu().numpy()
|
183 |
+
info["mean_code_embedding"] = torch.mean(code_embeddings, dim=0).cpu().numpy()
|
184 |
|
185 |
# Doc embeddings
|
186 |
tqdm.write(f"[*] Generating doc embeddings for {repo_name}")
|
187 |
doc_embeddings = self.generate_embeddings(repo_info["docs"], max_length)
|
188 |
+
info["doc_embeddings"] = doc_embeddings.cpu().numpy()
|
189 |
+
info["mean_doc_embedding"] = torch.mean(doc_embeddings, dim=0).cpu().numpy()
|
190 |
|
191 |
# Requirement embeddings
|
192 |
tqdm.write(f"[*] Generating requirement embeddings for {repo_name}")
|
193 |
requirement_embeddings = self.generate_embeddings(repo_info["requirements"], max_length)
|
194 |
+
info["requirement_embeddings"] = requirement_embeddings.cpu().numpy()
|
195 |
+
info["mean_requirement_embedding"] = torch.mean(requirement_embeddings, dim=0).cpu().numpy()
|
196 |
|
197 |
# Readme embeddings
|
198 |
tqdm.write(f"[*] Generating readme embeddings for {repo_name}")
|
199 |
readme_embeddings = self.generate_embeddings(repo_info["readmes"], max_length)
|
200 |
+
info["readme_embeddings"] = readme_embeddings.cpu().numpy()
|
201 |
+
info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0).cpu().numpy()
|
202 |
|
203 |
progress_bar.update(1)
|
204 |
model_outputs.append(info)
|