Henry65 committed
Commit a4a0af8 · 1 Parent(s): 2955c59

Update RepoPipeline.py

Files changed (1)
  1. RepoPipeline.py +153 -26
RepoPipeline.py CHANGED
@@ -10,6 +10,11 @@ from tqdm.auto import tqdm
 
 
 def extract_code_and_docs(text: str):
+    """
+    Extract code snippets and docstrings from a Python source file.
+    :param text: contents of a Python file.
+    :return: a set of code snippets and a set of docstrings.
+    """
     code_set = set()
     docs_set = set()
     root = ast.parse(text)
@@ -28,7 +33,33 @@ def extract_code_and_docs(text: str):
     return code_set, docs_set
 
 
+def extract_requirements(lines):
+    """
+    Extract library names from the lines of a requirements file.
+    :param lines: lines of a requirements.txt file.
+    :return: a set of requirement library names.
+    """
+    requirements_set = set()
+    for line in lines:
+        try:
+            if line != "\n":
+                if " == " in line:
+                    splitLine = line.split(" == ")
+                else:
+                    splitLine = line.split("==")
+                requirements_set.add(splitLine[0].strip())
+        except Exception:
+            pass
+    return requirements_set
+
+
 def get_metadata(repo_name, headers=None):
+    """
+    Fetch repository metadata from the GitHub API.
+    :param repo_name: repository name ("owner/repo").
+    :param headers: request headers.
+    :return: the response JSON.
+    """
     api_url = f"https://api.github.com/repos/{repo_name}"
     tqdm.write(f"[+] Getting metadata for {repo_name}")
     try:
@@ -41,9 +72,15 @@ def get_metadata(repo_name, headers=None):
 
 
 def extract_information(repos, headers=None):
+    """
+    Extract the information of each repository.
+    :param repos: repository names.
+    :param headers: request headers.
+    :return: a list with one information dict per repository.
+    """
     extracted_infos = []
     for repo_name in tqdm(repos, disable=len(repos) <= 1):
-        # Get metadata
+        # 1. Extract metadata.
         metadata = get_metadata(repo_name, headers=headers)
         repo_info = {
             "name": repo_name,
@@ -60,7 +97,7 @@ def extract_information(repos, headers=None):
         if metadata.get("license"):
             repo_info["license"] = metadata["license"]["spdx_id"]
 
-        # Download repo tarball bytes
+        # Download the repository tarball bytes.
         download_url = f"https://api.github.com/repos/{repo_name}/tarball"
         tqdm.write(f"[+] Downloading {repo_name}")
         try:
@@ -70,24 +107,50 @@ def extract_information(repos, headers=None):
             tqdm.write(f"[-] Failed to download {repo_name}: {e}")
             continue
 
-        # Extract python files and parse them
+        # Extract repository files and parse them
        tqdm.write(f"[+] Extracting {repo_name} info")
         with tarfile.open(fileobj=response.raw, mode="r|gz") as tar:
             for member in tar:
-                if (member.name.endswith(".py") and member.isfile()) is False:
-                    continue
-                try:
-                    file_content = tar.extractfile(member).read().decode("utf-8")
-                    code_set, docs_set = extract_code_and_docs(file_content)
-
-                    repo_info["codes"].update(code_set)
-                    repo_info["docs"].update(docs_set)
-                except UnicodeDecodeError as e:
-                    tqdm.write(
-                        f"[-] UnicodeDecodeError in {member.name}, skipping: \n{e}"
-                    )
-                except SyntaxError as e:
-                    tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
+                # 2. Extract codes and docs.
+                if member.name.endswith(".py") and member.isfile():
+                    try:
+                        file_content = tar.extractfile(member).read().decode("utf-8")
+                        # Extract code snippets and docstrings.
+                        code_set, docs_set = extract_code_and_docs(file_content)
+                        repo_info["codes"].update(code_set)
+                        repo_info["docs"].update(docs_set)
+                    except UnicodeDecodeError as e:
+                        tqdm.write(
+                            f"[-] UnicodeDecodeError in {member.name}, skipping: \n{e}"
+                        )
+                    except SyntaxError as e:
+                        tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
+                # 3. Extract the readme (tarball member names carry the repository directory prefix).
+                elif member.name.endswith(("README.md", "README.rst")) and member.isfile():
+                    try:
+                        file_content = tar.extractfile(member).read().decode("utf-8")
+                        # Keep the whole readme as a single entry.
+                        readme_set = {file_content}
+                        repo_info["readmes"].update(readme_set)
+                    except UnicodeDecodeError as e:
+                        tqdm.write(
+                            f"[-] UnicodeDecodeError in {member.name}, skipping: \n{e}"
+                        )
+                    except SyntaxError as e:
+                        tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
+                # 4. Extract requirements.
+                elif member.name.endswith("requirements.txt") and member.isfile():
+                    try:
+                        lines = tar.extractfile(member).read().decode("utf-8").splitlines(keepends=True)
+                        # Extract the requirement library names.
+                        requirements_set = extract_requirements(lines)
+                        repo_info["requirements"].update(requirements_set)
+                    except UnicodeDecodeError as e:
+                        tqdm.write(
+                            f"[-] UnicodeDecodeError in {member.name}, skipping: \n{e}"
+                        )
+                    except SyntaxError as e:
+                        tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
 
         extracted_infos.append(repo_info)
 
@@ -95,11 +158,20 @@ def extract_information(repos, headers=None):
 
 
 class RepoPipeline(Pipeline):
+    """
+    A custom pipeline that generates a series of embeddings for a repository.
+    """
 
     def __init__(self, github_token=None, *args, **kwargs):
+        """
+        Initialize the pipeline.
+        :param github_token: GitHub access token.
+        :param args: positional arguments forwarded to Pipeline.
+        :param kwargs: keyword arguments forwarded to Pipeline.
+        """
         super().__init__(*args, **kwargs)
 
-        # Github token
+        # Set the GitHub token.
         self.github_token = github_token
         if self.github_token:
             print("[+] GitHub token set!")
@@ -111,36 +183,56 @@ class RepoPipeline(Pipeline):
             )
 
     def _sanitize_parameters(self, **pipeline_parameters):
+        """
+        Split the pipeline parameters by stage.
+        :param pipeline_parameters: the pipeline parameters.
+        :return: the parameters of each stage.
+        """
+        # Parameters of the "preprocess" stage.
         preprocess_parameters = {}
         if "github_token" in pipeline_parameters:
             preprocess_parameters["github_token"] = pipeline_parameters["github_token"]
 
+        # Parameters of the "forward" stage.
         forward_parameters = {}
         if "max_length" in pipeline_parameters:
             forward_parameters["max_length"] = pipeline_parameters["max_length"]
 
+        # Parameters of the "postprocess" stage.
         postprocess_parameters = {}
         return preprocess_parameters, forward_parameters, postprocess_parameters
 
     def preprocess(self, input_: Any, github_token=None) -> List:
-        # Making input to list format
+        """
+        The "preprocess" stage.
+        :param input_: one repository name or a list of repository names.
+        :param github_token: GitHub access token.
+        :return: a list with the extracted information of each repository.
+        """
+        # Normalize the input to a list.
         if isinstance(input_, str):
             input_ = [input_]
 
-        # Building token
+        # Build the request headers.
         headers = {"Accept": "application/vnd.github+json"}
         token = github_token or self.github_token
         if token:
             headers["Authorization"] = f"Bearer {token}"
 
-        # Getting repositories' information: input_ means series of repositories
+        # Get the repositories' information: input_ is a series of repositories (possibly just one).
         extracted_infos = extract_information(input_, headers=headers)
-
         return extracted_infos
 
     def encode(self, text, max_length):
+        """
+        Encode a text into an embedding with UniXcoder.
+        :param text: the text.
+        :param max_length: the maximum token length.
+        :return: the embedding of the text.
+        """
         assert max_length < 1024
 
+        # Tokenize the text.
         tokenizer = self.tokenizer
         tokens = (
             [tokenizer.cls_token, "<encoder-only>", tokenizer.sep_token]
@@ -149,20 +241,36 @@ class RepoPipeline(Pipeline):
         )
         tokens_id = tokenizer.convert_tokens_to_ids(tokens)
         source_ids = torch.tensor([tokens_id]).to(self.device)
-
         token_embeddings = self.model(source_ids)[0]
+
+        # Mean-pool the token embeddings to get the text embedding.
         sentence_embeddings = token_embeddings.mean(dim=1)
 
         return sentence_embeddings
 
     def generate_embeddings(self, text_sets, max_length):
+        """
+        Generate the embeddings of a set of texts.
+        :param text_sets: the set of texts.
+        :param max_length: the maximum token length.
+        :return: the embeddings of the text set.
+        """
         assert max_length < 1024
+
+        # Concatenate the embeddings of the texts along the first dimension.
         return torch.zeros((1, 768), device=self.device) \
             if text_sets is None or len(text_sets) == 0 \
             else torch.cat([self.encode(text, max_length) for text in text_sets], dim=0)
 
     def _forward(self, extracted_infos: List, max_length=512) -> List:
+        """
+        The "forward" stage.
+        :param extracted_infos: the extracted information of the repositories.
+        :param max_length: the maximum token length.
+        :return: the outputs of this pipeline.
+        """
         model_outputs = []
+        # The number of repositories.
         num_repos = len(extracted_infos)
         with tqdm(total=num_repos) as progress_bar:
             # For each repository
@@ -194,18 +302,37 @@ class RepoPipeline(Pipeline):
                 info["requirement_embeddings"] = requirement_embeddings.cpu().numpy()
                 info["mean_requirement_embedding"] = torch.mean(requirement_embeddings, dim=0).cpu().numpy()
 
-                # Requirement embeddings
+                # Readme embeddings
                 tqdm.write(f"[*] Generating readme embeddings for {repo_name}")
                 readme_embeddings = self.generate_embeddings(repo_info["readmes"], max_length)
                 info["readme_embeddings"] = readme_embeddings.cpu().numpy()
                 info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0).cpu().numpy()
 
+                # Repo-level mean embedding
+                info["mean_repo_embedding"] = torch.cat([
+                    info["mean_code_embedding"],
+                    info["mean_doc_embedding"],
+                    info["mean_requirement_embedding"],
+                    info["mean_readme_embedding"]
+                ], dim=1)
+
+                # TODO Remove test
+                info["code_embeddings_shape"] = info["code_embeddings"].shape
+                info["doc_embeddings_shape"] = info["doc_embeddings"].shape
+                info["requirement_embeddings_shape"] = info["requirement_embeddings"].shape
+                info["readme_embeddings_shape"] = info["readme_embeddings"].shape
+                info["mean_repo_embedding_shape"] = info["mean_repo_embedding"].shape
+
                 progress_bar.update(1)
                 model_outputs.append(info)
 
         return model_outputs
 
     def postprocess(self, model_outputs: List, **postprocess_parameters: Dict) -> List:
+        """
+        The "postprocess" stage.
+        :param model_outputs: the outputs of this pipeline.
+        :param postprocess_parameters: the parameters of the "postprocess" stage.
+        :return: the model outputs.
+        """
         return model_outputs
-
-
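
For reference, a minimal sketch of how the updated pipeline might be invoked. The task name "repo-embedding", the registration call, and the microsoft/unixcoder-base checkpoint are illustrative assumptions (the "<encoder-only>" prefix in encode() suggests a UniXcoder-style encoder), not part of this commit:

# Hypothetical usage sketch -- the task name, checkpoint and registration below are assumptions.
from transformers import AutoModel, AutoTokenizer, pipeline
from transformers.pipelines import PIPELINE_REGISTRY

from RepoPipeline import RepoPipeline

# Register the custom pipeline under an assumed task name.
PIPELINE_REGISTRY.register_pipeline(
    "repo-embedding",
    pipeline_class=RepoPipeline,
    pt_model=AutoModel,
)

# Assumed UniXcoder checkpoint; encode() prepends the "<encoder-only>" mode token.
checkpoint = "microsoft/unixcoder-base"
repo_pipeline = pipeline(
    "repo-embedding",
    model=AutoModel.from_pretrained(checkpoint),
    tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    github_token=None,  # an optional token raises the GitHub API rate limit
)

# One repository name or a list of names; max_length is forwarded to _forward().
outputs = repo_pipeline("huggingface/transformers", max_length=512)
print(outputs[0]["name"], outputs[0]["mean_code_embedding"].shape)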