frankjosh committed · verified
Commit 347e3cb · 1 Parent(s): ea9711e

Upload repository_recommender.py

This Streamlit app recommends Python repositories to data scientists based on their project ideas.
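At its core, the uploaded script embeds repository docstrings/summaries and the user's project description with the CodeT5 encoder, then ranks repositories by cosine similarity. A minimal, self-contained sketch of that ranking step is below; the two-row DataFrame is illustrative only, since the actual app loads a CSV and wraps this logic in caching and a Streamlit UI.

    # Sketch of the ranking logic in repository_recommender.py (illustrative data).
    import pandas as pd
    import torch
    from sklearn.metrics.pairwise import cosine_similarity
    from transformers import AutoTokenizer, AutoModel

    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-small")
    model = AutoModel.from_pretrained("Salesforce/codet5-small").eval()

    @torch.no_grad()
    def embed(text):
        # Mean-pool the encoder's last hidden state into a single vector
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        return model.encoder(**inputs).last_hidden_state.mean(dim=1).squeeze().numpy()

    # Hypothetical repository table; the real script builds 'text' from docstring + summary
    repos = pd.DataFrame({
        "repo": ["churn-ml", "web-scraper"],
        "text": ["customer churn prediction with scikit-learn",
                 "scrape product pages with requests and BeautifulSoup"],
    })
    query_vec = embed("machine learning project for customer churn prediction")
    repos["similarity"] = [cosine_similarity([query_vec], [embed(t)])[0][0] for t in repos["text"]]
    print(repos.sort_values("similarity", ascending=False))

In the app itself, this flow is wrapped in Streamlit caching and UI code and is launched with `streamlit run repository_recommender.py`.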

Files changed (1)
  1. repository_recommender.py +437 -0
repository_recommender.py ADDED
@@ -0,0 +1,437 @@
# -*- coding: utf-8 -*-
"""repository_recommender.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1qv09N8Vtcw5vr5NqCSfZonFeh1SQmVW5
"""

# Dependencies (install before running): pip install pyarrow pandas numpy streamlit gdown torch transformers

import warnings
warnings.filterwarnings('ignore')

import streamlit as st
import pandas as pd
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel
import torch
import gdown
from pathlib import Path
from datetime import datetime
import json

# Initialize session state for history and feedback
if 'search_history' not in st.session_state:
    st.session_state.search_history = []
if 'feedback_data' not in st.session_state:
    st.session_state.feedback_data = {}

# Model Loading Optimization
class ModelManager:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @st.cache_resource
    def load_model_and_tokenizer(self):
        """Optimized model loading with device placement"""
        model_name = "Salesforce/codet5-small"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name).to(self.device)
        model.eval()  # Set model to evaluation mode
        return tokenizer, model

    def get_model_and_tokenizer(self):
        if self.model is None or self.tokenizer is None:
            self.tokenizer, self.model = self.load_model_and_tokenizer()
        return self.tokenizer, self.model

    @torch.no_grad()  # Disable gradient computation
    def generate_embedding(self, text, max_length=512):
        """Optimized embedding generation"""
        tokenizer, model = self.get_model_and_tokenizer()
        inputs = tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        ).to(self.device)

        outputs = model.encoder(**inputs)
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
        return embedding

# Data Management
class DataManager:
    @st.cache_resource
    def load_dataset():
        """Load and prepare dataset"""
        Path("data").mkdir(exist_ok=True)
        dataset_path = "/content/drive/MyDrive/practice_ml/filtered_dataset.csv"

        if not Path(dataset_path).exists():
            with st.spinner('Downloading dataset... This might take a few minutes...'):
                url = "/content/drive/MyDrive/practice_ml"
                gdown.download(url, dataset_path, quiet=False)

        data = pd.read_csv(dataset_path)
        data['text'] = data['docstring'].fillna('') + ' ' + data['summary'].fillna('')
        return data

    @st.cache_data
    def compute_embeddings(_data, _model_manager):
        """Compute embeddings in batches"""
        embeddings = []
        batch_size = 32

        # st.progress() is not a context manager: create the bar once, update it per batch
        progress_bar = st.progress(0)
        for i in range(0, len(_data), batch_size):
            batch = _data['text'].iloc[i:i+batch_size]
            batch_embeddings = [_model_manager.generate_embedding(text) for text in batch]
            embeddings.extend(batch_embeddings)
            progress_bar.progress(min((i + batch_size) / len(_data), 1.0))

        return embeddings

# History and Feedback Management
def add_to_history(query, recommendations):
    """Add search to history"""
    history_entry = {
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'query': query,
        'recommendations': recommendations[['repo', 'path', 'url', 'similarity']].to_dict('records')
    }
    st.session_state.search_history.insert(0, history_entry)

    # Keep only last 10 searches
    if len(st.session_state.search_history) > 10:
        st.session_state.search_history.pop()

def save_feedback(repo_id, feedback_type):
    """Save user feedback"""
    if repo_id not in st.session_state.feedback_data:
        st.session_state.feedback_data[repo_id] = {'likes': 0, 'dislikes': 0}

    if feedback_type == 'like':
        st.session_state.feedback_data[repo_id]['likes'] += 1
    else:
        st.session_state.feedback_data[repo_id]['dislikes'] += 1

def get_recommendations(query, data, model_manager, top_n=5):
    """Get repository recommendations"""
    query_embedding = model_manager.generate_embedding(query)
    similarities = data['embedding'].apply(
        lambda x: cosine_similarity([query_embedding], [x])[0][0]
    )
    recommendations = data.assign(similarity=similarities)\
        .sort_values(by='similarity', ascending=False)\
        .head(top_n)
    return recommendations

# Streamlit UI
def main():
    st.title("Repository Recommender System 🚀")

    # Sidebar with history
    with st.sidebar:
        st.header("Search History 📜")
        if st.session_state.search_history:
            for entry in st.session_state.search_history:
                with st.expander(f"🔍 {entry['timestamp']}", expanded=False):
                    st.write(f"Query: {entry['query']}")
                    for rec in entry['recommendations'][:3]:  # Show top 3
                        st.write(f"- {rec['repo']} ({rec['similarity']:.2%})")
        else:
            st.info("No search history yet")

    # Main interface
    st.markdown("""
    **Welcome to the Enhanced Repo_Recommender!**

    Enter your project description to get personalized repository recommendations.
    New features:
    - 📜 Search history (check sidebar)
    - 👍 Repository feedback
    - ⚡ Optimized performance
    """)

    # Initialize managers
    model_manager = ModelManager()
    data = DataManager.load_dataset()

    # Compute embeddings if not already done
    if 'embedding' not in data.columns:
        data['embedding'] = DataManager.compute_embeddings(data, model_manager)

    # User input
    user_query = st.text_area(
        "Describe your project:",
        height=150,
        placeholder="Example: I need a machine learning project for customer churn prediction..."
    )

    # Get recommendations
    if st.button("Get Recommendations", type="primary"):
        if user_query.strip():
            with st.spinner("Finding relevant repositories..."):
                recommendations = get_recommendations(user_query, data, model_manager)
                add_to_history(user_query, recommendations)

            # Display recommendations
            st.markdown("### 🎯 Top Recommendations")
            for idx, row in recommendations.iterrows():
                with st.expander(f"Repository {idx + 1}: {row['repo']}", expanded=True):
                    cols = st.columns([2, 1])
                    with cols[0]:
                        st.markdown(f"**Path:** `{row['path']}`")
                        st.markdown(f"**Summary:** {row['summary']}")
                        st.markdown(f"**URL:** [View Repository]({row['url']})")
                    with cols[1]:
                        st.metric("Similarity", f"{row['similarity']:.2%}")

                    # Feedback buttons
                    feedback_cols = st.columns(2)
                    repo_id = f"{row['repo']}_{row['path']}"

                    with feedback_cols[0]:
                        if st.button("👍", key=f"like_{repo_id}"):
                            save_feedback(repo_id, 'like')
                            st.success("Thanks for your feedback!")

                    with feedback_cols[1]:
                        if st.button("👎", key=f"dislike_{repo_id}"):
                            save_feedback(repo_id, 'dislike')
                            st.success("Thanks for your feedback!")

                    # Show feedback stats
                    if repo_id in st.session_state.feedback_data:
                        stats = st.session_state.feedback_data[repo_id]
                        st.write(f"Likes: {stats['likes']} | Dislikes: {stats['dislikes']}")

                    # Documentation (shown inline; Streamlit does not allow nested expanders)
                    if row['docstring']:
                        st.markdown("**Documentation:**")
                        st.markdown(row['docstring'])
        else:
            st.warning("Please enter a project description.")

    # Footer
    st.markdown("---")
    st.markdown("Made with 🤖 using CodeT5 and Streamlit")

if __name__ == "__main__":
    main()

import warnings
warnings.filterwarnings('ignore')

import streamlit as st
import pandas as pd
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel
import torch
import gdown
from pathlib import Path
from datetime import datetime

# Initialize session state
if 'search_history' not in st.session_state:
    st.session_state.search_history = []
if 'feedback_data' not in st.session_state:
    st.session_state.feedback_data = {}

# Model Loading Optimization
@st.cache_resource
def load_model_and_tokenizer():
    """Optimized model loading with device placement"""
    model_name = "Salesforce/codet5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()  # Set model to evaluation mode
    return tokenizer, model, device

@st.cache_resource
def load_dataset():
    """Load and prepare dataset"""
    Path("data").mkdir(exist_ok=True)
    dataset_path = "/content/drive/MyDrive/practice_ml/filtered_dataset.csv"

    if not Path(dataset_path).exists():
        with st.spinner('Downloading dataset... This might take a few minutes...'):
            url = "/content/drive/MyDrive/practice_ml"
            gdown.download(url, dataset_path, quiet=False)

    data = pd.read_csv(dataset_path)
    data['text'] = data['docstring'].fillna('') + ' ' + data['summary'].fillna('')
    return data

@st.cache_data
def generate_embedding(_tokenizer, _model, _device, text, max_length=512):
    """Generate embedding for a single text"""
    with torch.no_grad():
        inputs = _tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        ).to(_device)

        outputs = _model.encoder(**inputs)
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
    return embedding

@st.cache_data
def compute_embeddings(_data, _tokenizer, _model, _device):
    """Compute embeddings in batches"""
    embeddings = []
    batch_size = 32
    texts = _data['text'].tolist()

    # st.progress() is not a context manager: use an empty placeholder and update it per batch
    progress_container = st.empty()
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        batch_embeddings = [
            generate_embedding(_tokenizer, _model, _device, text)
            for text in batch
        ]
        embeddings.extend(batch_embeddings)
        progress_container.progress(min((i + batch_size) / len(texts), 1.0))

    return embeddings

def add_to_history(query, recommendations):
    """Add search to history"""
    history_entry = {
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'query': query,
        'recommendations': recommendations[['repo', 'path', 'url', 'similarity']].to_dict('records')
    }
    st.session_state.search_history.insert(0, history_entry)
    if len(st.session_state.search_history) > 10:
        st.session_state.search_history.pop()

def save_feedback(repo_id, feedback_type):
    """Save user feedback"""
    if repo_id not in st.session_state.feedback_data:
        st.session_state.feedback_data[repo_id] = {'likes': 0, 'dislikes': 0}

    if feedback_type == 'like':
        st.session_state.feedback_data[repo_id]['likes'] += 1
    else:
        st.session_state.feedback_data[repo_id]['dislikes'] += 1

def get_recommendations(query, data, tokenizer, model, device, top_n=5):
    """Get repository recommendations"""
    query_embedding = generate_embedding(tokenizer, model, device, query)

    similarities = []
    for emb in data['embedding']:
        sim = cosine_similarity([query_embedding], [emb])[0][0]
        similarities.append(sim)

    recommendations = data.assign(similarity=similarities)\
        .sort_values(by='similarity', ascending=False)\
        .head(top_n)
    return recommendations

def main():
    st.title("Repository Recommender System 🚀")

    # Sidebar with history
    with st.sidebar:
        st.header("Search History 📜")
        if st.session_state.search_history:
            for entry in st.session_state.search_history:
                with st.expander(f"🔍 {entry['timestamp']}", expanded=False):
                    st.write(f"Query: {entry['query']}")
                    for rec in entry['recommendations'][:3]:
                        st.write(f"- {rec['repo']} ({rec['similarity']:.2%})")
        else:
            st.info("No search history yet")

    st.markdown("""
    **Welcome to the Enhanced Repo_Recommender!**

    Enter your project description to get personalized repository recommendations.
    New features:
    - 📜 Search history (check sidebar)
    - 👍 Repository feedback
    - ⚡ Optimized performance
    """)

    # Load resources
    with st.spinner("Loading model and data..."):
        tokenizer, model, device = load_model_and_tokenizer()
        data = load_dataset()

    # Compute embeddings if not already done
    if 'embedding' not in data.columns:
        data['embedding'] = compute_embeddings(data, tokenizer, model, device)

    # User input
    user_query = st.text_area(
        "Describe your project:",
        height=150,
        placeholder="Example: I need a machine learning project for customer churn prediction..."
    )

    # Get recommendations
    if st.button("Get Recommendations", type="primary"):
        if user_query.strip():
            with st.spinner("Finding relevant repositories..."):
                recommendations = get_recommendations(
                    user_query, data, tokenizer, model, device
                )
                add_to_history(user_query, recommendations)

            # Display recommendations
            st.markdown("### 🎯 Top Recommendations")
            for idx, row in recommendations.iterrows():
                with st.expander(f"Repository {idx + 1}: {row['repo']}", expanded=True):
                    cols = st.columns([2, 1])
                    with cols[0]:
                        st.markdown(f"**Path:** `{row['path']}`")
                        st.markdown(f"**Summary:** {row['summary']}")
                        st.markdown(f"**URL:** [View Repository]({row['url']})")
                    with cols[1]:
                        st.metric("Similarity", f"{row['similarity']:.2%}")

                    # Feedback buttons
                    feedback_cols = st.columns(2)
                    repo_id = f"{row['repo']}_{row['path']}"

                    with feedback_cols[0]:
                        if st.button("👍", key=f"like_{repo_id}"):
                            save_feedback(repo_id, 'like')
                            st.success("Thanks for your feedback!")

                    with feedback_cols[1]:
                        if st.button("👎", key=f"dislike_{repo_id}"):
                            save_feedback(repo_id, 'dislike')
                            st.success("Thanks for your feedback!")

                    # Show feedback stats
                    if repo_id in st.session_state.feedback_data:
                        stats = st.session_state.feedback_data[repo_id]
                        st.write(f"Likes: {stats['likes']} | Dislikes: {stats['dislikes']}")

                    # Documentation (shown inline; Streamlit does not allow nested expanders)
                    if row['docstring']:
                        st.markdown("**Documentation:**")
                        st.markdown(row['docstring'])
        else:
            st.warning("Please enter a project description.")

    # Footer
    st.markdown("---")
    st.markdown("Made with 🤖 using CodeT5 and Streamlit")

if __name__ == "__main__":
    main()