frankjosh committed on
Commit e688e26 · verified · 1 Parent(s): 856c3dc

Update app.py

Files changed (1)
  1. app.py +282 -143

app.py CHANGED
@@ -7,10 +7,11 @@ import numpy as np
 from sklearn.metrics.pairwise import cosine_similarity
 from transformers import AutoTokenizer, AutoModel
 import torch
-from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
 from datasets import load_dataset
 from datetime import datetime
 from typing import List, Dict, Any
+from torch.utils.data import DataLoader, Dataset
 from functools import partial
 
 # Configure GPU if available
@@ -19,187 +20,325 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 # Initialize session state
 if 'history' not in st.session_state:
     st.session_state.history = []
-
 if 'feedback' not in st.session_state:
     st.session_state.feedback = {}
 
-# Define subset size and batch size for optimization
-SUBSET_SIZE = 500  # Subset for faster precomputation
-BATCH_SIZE = 8  # Smaller batch size to reduce memory overhead
+# Define subset size
+SUBSET_SIZE = 1000  # Starting with 1000 items for quick testing
 
-@st.cache_resource
-def load_model_and_tokenizer_with_progress():
-    """
-    Load the pre-trained model and tokenizer using Hugging Face Transformers
-    with a progress bar for better user experience.
+class TextDataset(Dataset):
+    def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
+        self.texts = texts
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+
+    def __len__(self):
+        return len(self.texts)
+
+    def __getitem__(self, idx):
+        return self.tokenizer(
+            self.texts[idx],
+            padding='max_length',
+            truncation=True,
+            max_length=self.max_length,
+            return_tensors="pt"
+        )
+
+def generate_case_study(row: Dict[str, Any]) -> str:
+    """Generate a detailed case study for a repository using available metadata"""
+    # Extract relevant information from the row
+    summary = row.get('summary', '').strip()
+    docstring = row.get('docstring', '').strip()
+    repo_name = row.get('repo', '').strip()
+
+    # Generate a more detailed overview using available information
+    overview = summary if summary else "This repository provides a software implementation"
+    if docstring:
+        # Extract the first paragraph of the docstring for additional context
+        first_para = docstring.split('\n\n')[0].strip()
+        overview = f"{overview}. {first_para}"
+
+    # Analyze the repository path to infer technology stack
+    path_components = row.get('path', '').lower().split('/')
+    tech_stack = []
+
+    # Common technology indicators in paths
+    if any('python' in comp for comp in path_components):
+        tech_stack.append("Python")
+    if any('tensorflow' in comp or 'tf' in comp for comp in path_components):
+        tech_stack.append("TensorFlow")
+    if any('pytorch' in comp for comp in path_components):
+        tech_stack.append("PyTorch")
+    if any('react' in comp for comp in path_components):
+        tech_stack.append("React")
+
+    tech_stack_str = ", ".join(tech_stack) if tech_stack else "various technologies"
+
+    case_study = f"""
+    ### Overview
+    {overview}
+
+    ### Technical Implementation
+    This project is built using {tech_stack_str}. The implementation focuses on providing a robust and maintainable solution for {summary.lower() if summary else 'the specified requirements'}.
+
+    ### Key Features
+    - Primary functionality: {summary if summary else 'Implementation of core project requirements'}
+    - Complete documentation and code examples
+    - Well-structured implementation following best practices
+    - Modular design for easy integration and customization
+
+    ### Use Cases
+    This repository is particularly valuable for:
+    - Developers implementing similar functionality in their projects
+    - Teams looking for reference implementations and best practices
+    - Projects requiring similar technical capabilities
+    - Learning and educational purposes in related technical domains
+
+    ### Integration Considerations
+    The repository can be integrated into existing projects, with consideration for:
+    - Compatibility with existing technology stacks
+    - Required dependencies and prerequisites
+    - Potential customization needs
+    - Performance and scalability requirements
     """
-    progress_bar = st.progress(0)
-    status_text = st.empty()
+    return case_study
+
+def display_recommendations(recommendations: pd.DataFrame):
+    """Display recommendations in a list format with all details"""
+    st.markdown("### 🎯 Top Recommendations")
 
+    # Create a list of recommendations
+    for idx, row in recommendations.iterrows():
+        with st.container():
+            # Header with repository name and match score
+            col1, col2 = st.columns([3, 1])
+            with col1:
+                st.markdown(f"### {idx + 1}. {row['repo']}")
+            with col2:
+                st.metric("Match Score", f"{row['similarity']:.2%}")
+
+            # Repository details
+            st.markdown(f"**URL:** [View Repository]({row['url']})")
+            st.markdown(f"**Path:** `{row['path']}`")
+
+            # Feedback buttons
+            col1, col2, col3 = st.columns([1, 1, 4])
+            with col1:
+                if st.button("👍", key=f"like_{idx}"):
+                    st.session_state.feedback[row['repo']] = st.session_state.feedback.get(row['repo'], {'likes': 0, 'dislikes': 0})
+                    st.session_state.feedback[row['repo']]['likes'] += 1
+                    st.success("Thanks for your feedback!")
+            with col2:
+                if st.button("👎", key=f"dislike_{idx}"):
+                    st.session_state.feedback[row['repo']] = st.session_state.feedback.get(row['repo'], {'likes': 0, 'dislikes': 0})
+                    st.session_state.feedback[row['repo']]['dislikes'] += 1
+                    st.success("Thanks for your feedback!")
+
+            # Documentation and case study in tabs
+            tab1, tab2 = st.tabs(["📚 Documentation", "📑 Case Study"])
+            with tab1:
+                if row['docstring']:
+                    st.markdown(row['docstring'])
+                else:
+                    st.info("No documentation available")
+
+            with tab2:
+                st.markdown(generate_case_study(row))
+
+            st.markdown("---")
+
+@st.cache_resource
+def load_data_and_model():
+    """Load the dataset and model with optimized memory usage"""
     try:
-        progress_bar.progress(10)
-        status_text.text("Loading tokenizer...")
+        # Load dataset
+        dataset = load_dataset("frankjosh/filtered_dataset")
+        data = pd.DataFrame(dataset['train'])
+
+        # Take a random subset
+        data = data.sample(n=min(SUBSET_SIZE, len(data)), random_state=42).reset_index(drop=True)
+
+        # Combine text fields
+        data['text'] = data['docstring'].fillna('') + ' ' + data['summary'].fillna('')
+
+        # Load model and tokenizer
         model_name = "Salesforce/codet5-small"
         tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModel.from_pretrained(model_name)
 
-        progress_bar.progress(50)
-        status_text.text("Loading model...")
-        model = AutoModel.from_pretrained(model_name).to(device)
-        model.eval()
+        if torch.cuda.is_available():
+            model = model.to(device)
 
-        progress_bar.progress(100)
-        status_text.text("Model loaded successfully!")
-
-    finally:
-        progress_bar.empty()
-        status_text.empty()
+        model.eval()
+        return data, tokenizer, model
 
-    return tokenizer, model
+    except Exception as e:
+        st.error(f"Error in initialization: {str(e)}")
+        st.stop()
 
-@st.cache_resource
-def load_data():
-    """
-    Load and sample the dataset from Hugging Face.
-    Ensures the 'text' column is created for embedding precomputation.
-    """
-    dataset = load_dataset("frankjosh/filtered_dataset")
-    data = pd.DataFrame(dataset['train'])
+def collate_fn(batch, pad_token_id):
+    max_length = max(inputs['input_ids'].shape[1] for inputs in batch)
+    input_ids = []
+    attention_mask = []
 
-    # Take a random subset of data
-    data = data.sample(n=min(SUBSET_SIZE, len(data)), random_state=42).reset_index(drop=True)
+    for inputs in batch:
+        input_ids.append(torch.nn.functional.pad(
+            inputs['input_ids'].squeeze(),
+            (0, max_length - inputs['input_ids'].shape[1]),
+            value=pad_token_id
+        ))
+        attention_mask.append(torch.nn.functional.pad(
+            inputs['attention_mask'].squeeze(),
+            (0, max_length - inputs['attention_mask'].shape[1]),
+            value=0
+        ))
 
-    # Create a 'text' column by combining relevant fields
-    data['text'] = data['docstring'].fillna('') + ' ' + data['summary'].fillna('')
-    return data
+    return {
+        'input_ids': torch.stack(input_ids),
+        'attention_mask': torch.stack(attention_mask)
+    }
 
-@st.cache_resource
-def precompute_embeddings(data: pd.DataFrame, _tokenizer, _model, batch_size=BATCH_SIZE):
-    """
-    Precompute embeddings for repository metadata to optimize query performance.
-    The tokenizer and model are excluded from caching as they are unhashable.
-    """
-    class TextDataset(Dataset):
-        def __init__(self, texts: List[str], tokenizer, max_length=512):
-            self.texts = texts
-            self.tokenizer = tokenizer
-            self.max_length = max_length
-
-        def __len__(self):
-            return len(self.texts)
-
-        def __getitem__(self, idx):
-            return self.tokenizer(
-                self.texts[idx],
-                padding='max_length',
-                truncation=True,
-                max_length=self.max_length,
-                return_tensors="pt"
-            )
-
-    def collate_fn(batch, pad_token_id):
-        max_length = max(inputs['input_ids'].shape[1] for inputs in batch)
-        input_ids, attention_mask = [], []
-        for inputs in batch:
-            input_ids.append(torch.nn.functional.pad(
-                inputs['input_ids'].squeeze(),
-                (0, max_length - inputs['input_ids'].shape[1]),
-                value=pad_token_id
-            ))
-            attention_mask.append(torch.nn.functional.pad(
-                inputs['attention_mask'].squeeze(),
-                (0, max_length - inputs['attention_mask'].shape[1]),
-                value=0
-            ))
-        return {
-            'input_ids': torch.stack(input_ids),
-            'attention_mask': torch.stack(attention_mask)
-        }
-
-    def generate_embeddings_batch(model, batch, device):
-        with torch.no_grad():
-            batch = {k: v.to(device) for k, v in batch.items()}
-            outputs = model.encoder(**batch)
-            return outputs.last_hidden_state.mean(dim=1).cpu().numpy()
-
-    dataset = TextDataset(data['text'].tolist(), _tokenizer)
+def generate_embeddings_batch(model, batch, device):
+    """Generate embeddings for a batch of inputs"""
+    with torch.no_grad():
+        batch = {k: v.to(device) for k, v in batch.items()}
+        outputs = model.encoder(**batch)
+        embeddings = outputs.last_hidden_state.mean(dim=1)
+        return embeddings.cpu().numpy()
+
+def precompute_embeddings(data: pd.DataFrame, model, tokenizer, batch_size: int = 16):
+    """Precompute embeddings with batching and progress tracking"""
+    dataset = TextDataset(data['text'].tolist(), tokenizer)
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=False,
-        collate_fn=partial(collate_fn, pad_token_id=_tokenizer.pad_token_id)
+        dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        collate_fn=partial(collate_fn, pad_token_id=tokenizer.pad_token_id),
+        num_workers=2,
+        pin_memory=True
     )
-
+
     embeddings = []
-    progress_bar = st.progress(0)  # Progress bar for embedding computation
+    total_batches = len(dataloader)
+
+    # Create a progress bar
+    progress_bar = st.progress(0)
+    status_text = st.empty()
+
+    start_time = datetime.now()
+
     for i, batch in enumerate(dataloader):
-        batch_embeddings = generate_embeddings_batch(_model, batch, device)
+        # Generate embeddings for batch
+        batch_embeddings = generate_embeddings_batch(model, batch, device)
         embeddings.extend(batch_embeddings)
-        progress_bar.progress((i + 1) / len(dataloader))
-
+
+        # Update progress
+        progress = (i + 1) / total_batches
+        progress_bar.progress(progress)
+
+        # Calculate and display ETA
+        elapsed_time = (datetime.now() - start_time).total_seconds()
+        eta = (elapsed_time / (i + 1)) * (total_batches - (i + 1))
+        status_text.text(f"Processing batch {i+1}/{total_batches}. ETA: {int(eta)} seconds")
+
     progress_bar.empty()
+    status_text.empty()
+
+    # Add embeddings to dataframe
     data['embedding'] = embeddings
     return data
 
 @torch.no_grad()
 def generate_query_embedding(model, tokenizer, query: str) -> np.ndarray:
-    """
-    Generate embedding for a user query using the pre-trained model.
-    """
+    """Generate embedding for a single query"""
     inputs = tokenizer(
-        query, return_tensors="pt", padding=True,
-        truncation=True, max_length=512
+        query,
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=512
     ).to(device)
+
     outputs = model.encoder(**inputs)
-    return outputs.last_hidden_state.mean(dim=1).cpu().numpy()
+    embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
+    return embedding.squeeze()
 
-def find_similar_repos(query_embedding: np.ndarray, data: pd.DataFrame, top_n=5) -> pd.DataFrame:
-    """
-    Compute cosine similarity and return the top N most similar repositories.
-    """
-    # Reshape query_embedding to 2D
-    query_embedding = query_embedding.reshape(1, -1)
-
-    # Convert data['embedding'] to a 2D array
-    embeddings = np.vstack(data['embedding'].values)
-
-    # Compute cosine similarity
-    similarities = cosine_similarity(query_embedding, embeddings)[0]
-
-    # Add similarity scores to the DataFrame
+def find_similar_repos(query_embedding: np.ndarray, data: pd.DataFrame, top_n: int = 5) -> pd.DataFrame:
+    """Find similar repositories using vectorized operations"""
+    similarities = cosine_similarity([query_embedding], np.stack(data['embedding'].values))[0]
     data['similarity'] = similarities
-
     return data.nlargest(top_n, 'similarity')
 
-def display_recommendations(recommendations: pd.DataFrame):
-    """
-    Display the recommended repositories in the Streamlit app interface.
-    """
-    st.markdown("### 🎯 Top Recommendations")
-    for idx, row in recommendations.iterrows():
-        st.markdown(f"### {idx + 1}. {row['repo']}")
-        st.metric("Match Score", f"{row['similarity']:.2%}")
-        st.markdown(f"[View Repository]({row['url']})")
+# Load resources
+data, tokenizer, model = load_data_and_model()
 
-# Main workflow
-st.title("Repository Recommender System 🚀")
-st.caption("Find repositories based on your project description.")
+# Add info about subset size
+st.info(f"Running with a subset of {SUBSET_SIZE} repositories for testing purposes.")
 
-# Load resources with progress bar
-tokenizer, model = load_model_and_tokenizer_with_progress()
+# Precompute embeddings for the subset
+data = precompute_embeddings(data, model, tokenizer)
 
-# Load data and precompute embeddings
-data = load_data()
-data = precompute_embeddings(data, tokenizer, model)
+# Main App Interface
+st.title("Repository Recommender System 🚀")
+st.caption("Testing Version - Running on subset of data")
 
-# User input
+# Main interface
 user_query = st.text_area(
-    "Describe your project:", height=150,
-    placeholder="Example: A machine learning project for customer churn prediction..."
+    "Describe your project:",
+    height=150,
+    placeholder="Example: I need a machine learning project for customer churn prediction..."
 )
 
-if st.button("🔍 Search Repositories"):
-    if user_query.strip():
-        with st.spinner("Finding relevant repositories..."):
-            query_embedding = generate_query_embedding(model, tokenizer, user_query)
-            recommendations = find_similar_repos(query_embedding, data)
-            display_recommendations(recommendations)
+# Search button and filters
+col1, col2 = st.columns([2, 1])
+with col1:
+    search_button = st.button("🔍 Search Repositories", type="primary")
+with col2:
+    top_n = st.selectbox("Number of results:", [3, 5, 10], index=1)
+
+if search_button and user_query.strip():
+    with st.spinner("Finding relevant repositories..."):
+        # Generate query embedding and get recommendations
+        query_embedding = generate_query_embedding(model, tokenizer, user_query)
+        recommendations = find_similar_repos(query_embedding, data, top_n)
+
+        # Save to history
+        st.session_state.history.append({
+            'query': user_query,
+            'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+            'results': recommendations['repo'].tolist()
+        })
+
+        # Display recommendations using the new function
+        display_recommendations(recommendations)
+
+# Sidebar for History and Stats
+with st.sidebar:
+    st.header("📊 Search History")
+    if st.session_state.history:
+        for idx, item in enumerate(reversed(st.session_state.history[-5:])):
+            st.markdown(f"**Search {len(st.session_state.history)-idx}**")
+            st.markdown(f"Query: _{item['query'][:30]}..._")
+            st.caption(f"Time: {item['timestamp']}")
+            st.caption(f"Results: {len(item['results'])} repositories")
+            if st.button("Rerun this search", key=f"rerun_{idx}"):
+                st.session_state.rerun_query = item['query']
+            st.markdown("---")
     else:
-        st.error("Please provide a project description.")
+        st.write("No search history yet")
+
+    st.header("📈 Usage Statistics")
+    st.write(f"Total Searches: {len(st.session_state.history)}")
+    if st.session_state.feedback:
+        feedback_df = pd.DataFrame(st.session_state.feedback).T
+        feedback_df['Total'] = feedback_df['likes'] + feedback_df['dislikes']
+        st.bar_chart(feedback_df[['likes', 'dislikes']])
+
+# Footer
+st.markdown("---")
+st.markdown(
+    """
+    Made with 🤖 using CodeT5 and Streamlit |
+
+    """
+)
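For readers skimming the diff: both the repository texts and the user query are embedded by mean-pooling the CodeT5 encoder's hidden states, and results are ranked by cosine similarity. Below is a minimal standalone sketch of that flow, using the same model and pooling as the code above; it runs outside Streamlit, and the two-entry corpus is invented purely for illustration.

```python
import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-small")
model = AutoModel.from_pretrained("Salesforce/codet5-small").to(device)
model.eval()

@torch.no_grad()
def embed(texts):
    # Tokenize with padding/truncation, then mean-pool the encoder's last
    # hidden states -- the same pooling used by generate_embeddings_batch
    # and generate_query_embedding in the diff above.
    inputs = tokenizer(texts, return_tensors="pt", padding=True,
                       truncation=True, max_length=512).to(device)
    outputs = model.encoder(**inputs)
    return outputs.last_hidden_state.mean(dim=1).cpu().numpy()

# Toy corpus standing in for the precomputed repository embeddings.
corpus = ["Predict customer churn with scikit-learn",
          "Parse and validate YAML configuration files"]
scores = cosine_similarity(embed(["customer churn prediction"]),
                           embed(corpus))[0]
for score, text in sorted(zip(scores, corpus), reverse=True):
    print(f"{score:.3f}  {text}")
```

Note that `find_similar_repos` wraps the 1-D query vector in a list before calling `cosine_similarity`, which is why `generate_query_embedding` squeezes away the batch dimension first; the sketch simply keeps the batch dimension instead.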