device problems
Browse files- Chroma.py +5 -1
- LangChain.py +6 -1
Chroma.py
CHANGED
@@ -4,6 +4,7 @@ from langchain_huggingface import HuggingFaceEmbeddings
|
|
4 |
from langchain_chroma import Chroma
|
5 |
import os
|
6 |
import shutil
|
|
|
7 |
|
8 |
|
9 |
def create_db(
|
@@ -14,10 +15,13 @@ def create_db(
|
|
14 |
MODEL_NAME="Alibaba-NLP/gte-multilingual-base",
|
15 |
CHROMA_PATH="./chromadb/",
|
16 |
):
|
|
|
|
|
|
|
17 |
# setup embeddings
|
18 |
embeddings = HuggingFaceEmbeddings(
|
19 |
model_name=MODEL_NAME,
|
20 |
-
model_kwargs={"device":
|
21 |
encode_kwargs={"normalize_embeddings": True},
|
22 |
)
|
23 |
|
|
|
4 |
from langchain_chroma import Chroma
|
5 |
import os
|
6 |
import shutil
|
7 |
+
import torch
|
8 |
|
9 |
|
10 |
def create_db(
|
|
|
15 |
MODEL_NAME="Alibaba-NLP/gte-multilingual-base",
|
16 |
CHROMA_PATH="./chromadb/",
|
17 |
):
|
18 |
+
# Check if CUDA is available
|
19 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
+
|
21 |
# setup embeddings
|
22 |
embeddings = HuggingFaceEmbeddings(
|
23 |
model_name=MODEL_NAME,
|
24 |
+
model_kwargs={"device": device, "trust_remote_code": True},
|
25 |
encode_kwargs={"normalize_embeddings": True},
|
26 |
)
|
27 |
|
LangChain.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
# chat
|
2 |
from QWEN import ChatQWEN
|
3 |
from langchain_core.prompts import ChatPromptTemplate
|
@@ -8,11 +10,14 @@ from langchain_chroma import Chroma
|
|
8 |
|
9 |
|
10 |
def load_db(CHROMA_PATH="chromadb/", MODEL_NAME="Alibaba-NLP/gte-multilingual-base"):
|
|
|
|
|
|
|
11 |
# setup embeddings
|
12 |
embeddings = HuggingFaceEmbeddings(
|
13 |
model_name=MODEL_NAME,
|
14 |
model_kwargs={
|
15 |
-
"device":
|
16 |
"trust_remote_code": True,
|
17 |
},
|
18 |
encode_kwargs={"normalize_embeddings": True},
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
# chat
|
4 |
from QWEN import ChatQWEN
|
5 |
from langchain_core.prompts import ChatPromptTemplate
|
|
|
10 |
|
11 |
|
12 |
def load_db(CHROMA_PATH="chromadb/", MODEL_NAME="Alibaba-NLP/gte-multilingual-base"):
|
13 |
+
# Check if CUDA is available
|
14 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
+
|
16 |
# setup embeddings
|
17 |
embeddings = HuggingFaceEmbeddings(
|
18 |
model_name=MODEL_NAME,
|
19 |
model_kwargs={
|
20 |
+
"device": device,
|
21 |
"trust_remote_code": True,
|
22 |
},
|
23 |
encode_kwargs={"normalize_embeddings": True},
|