merligus commited on
Commit
c3bbde5
·
1 Parent(s): 4373a09

device problems

Browse files
Files changed (2) hide show
  1. Chroma.py +5 -1
  2. LangChain.py +6 -1
Chroma.py CHANGED
@@ -4,6 +4,7 @@ from langchain_huggingface import HuggingFaceEmbeddings
4
  from langchain_chroma import Chroma
5
  import os
6
  import shutil
 
7
 
8
 
9
  def create_db(
@@ -14,10 +15,13 @@ def create_db(
14
  MODEL_NAME="Alibaba-NLP/gte-multilingual-base",
15
  CHROMA_PATH="./chromadb/",
16
  ):
 
 
 
17
  # setup embeddings
18
  embeddings = HuggingFaceEmbeddings(
19
  model_name=MODEL_NAME,
20
- model_kwargs={"device": "cuda", "trust_remote_code": True},
21
  encode_kwargs={"normalize_embeddings": True},
22
  )
23
 
 
4
  from langchain_chroma import Chroma
5
  import os
6
  import shutil
7
+ import torch
8
 
9
 
10
  def create_db(
 
15
  MODEL_NAME="Alibaba-NLP/gte-multilingual-base",
16
  CHROMA_PATH="./chromadb/",
17
  ):
18
+ # Check if CUDA is available
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
  # setup embeddings
22
  embeddings = HuggingFaceEmbeddings(
23
  model_name=MODEL_NAME,
24
+ model_kwargs={"device": device, "trust_remote_code": True},
25
  encode_kwargs={"normalize_embeddings": True},
26
  )
27
 
LangChain.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  # chat
2
  from QWEN import ChatQWEN
3
  from langchain_core.prompts import ChatPromptTemplate
@@ -8,11 +10,14 @@ from langchain_chroma import Chroma
8
 
9
 
10
  def load_db(CHROMA_PATH="chromadb/", MODEL_NAME="Alibaba-NLP/gte-multilingual-base"):
 
 
 
11
  # setup embeddings
12
  embeddings = HuggingFaceEmbeddings(
13
  model_name=MODEL_NAME,
14
  model_kwargs={
15
- "device": "cuda",
16
  "trust_remote_code": True,
17
  },
18
  encode_kwargs={"normalize_embeddings": True},
 
1
+ import torch
2
+
3
  # chat
4
  from QWEN import ChatQWEN
5
  from langchain_core.prompts import ChatPromptTemplate
 
10
 
11
 
12
  def load_db(CHROMA_PATH="chromadb/", MODEL_NAME="Alibaba-NLP/gte-multilingual-base"):
13
+ # Check if CUDA is available
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
  # setup embeddings
17
  embeddings = HuggingFaceEmbeddings(
18
  model_name=MODEL_NAME,
19
  model_kwargs={
20
+ "device": device,
21
  "trust_remote_code": True,
22
  },
23
  encode_kwargs={"normalize_embeddings": True},