frankjosh committed on
Commit
ad91929
·
verified ·
1 Parent(s): 178a171

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -17,13 +17,13 @@ def load_model():
17
  tokenizer = AutoTokenizer.from_pretrained(model_name)
18
  # Check if GPU is available
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
- model = AutoModel.from_pretrained(model_name).to("cuda")
21
- return tokenizer, model
22
 
23
- def generate_embedding(text, tokenizer, model):
24
  """Generate embeddings for a given text."""
25
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
26
- inputs = {k: v.to("cuda") for k, v in inputs.items()}
27
  with torch.no_grad():
28
  outputs = model.encoder(**inputs)
29
  return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
@@ -53,14 +53,14 @@ def main():
53
  st.write("Find Python repositories to learn production-level coding practices.")
54
 
55
  # Load resources
56
- tokenizer, model = load_model()
57
  data = load_data()
58
 
59
  # Input user query
60
  user_query = st.text_input("Describe your project or learning goal:",
61
  "I am working on a project to recommend music using pandas and numpy.")
62
  if user_query:
63
- query_embedding = generate_embedding(user_query, tokenizer, model)
64
 
65
  # Compute similarity
66
  data['similarity'] = data['embedding'].apply(
 
17
  tokenizer = AutoTokenizer.from_pretrained(model_name)
18
  # Check if GPU is available
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ model = AutoModel.from_pretrained(model_name).to(device)
21
+ return tokenizer, model, device
22
 
23
+ def generate_embedding(text, tokenizer, model, device):
24
  """Generate embeddings for a given text."""
25
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
26
+ inputs = {k: v.to(device) for k, v in inputs.items()}
27
  with torch.no_grad():
28
  outputs = model.encoder(**inputs)
29
  return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
 
53
  st.write("Find Python repositories to learn production-level coding practices.")
54
 
55
  # Load resources
56
+ tokenizer, model, device = load_model()
57
  data = load_data()
58
 
59
  # Input user query
60
  user_query = st.text_input("Describe your project or learning goal:",
61
  "I am working on a project to recommend music using pandas and numpy.")
62
  if user_query:
63
+ query_embedding = generate_embedding(user_query, tokenizer, model, device)
64
 
65
  # Compute similarity
66
  data['similarity'] = data['embedding'].apply(