Spaces:
Runtime error
Runtime error
Upload RAG_class.py
Browse files- RAG_class.py +71 -0
RAG_class.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#import recurive textsplitter
|
2 |
+
from sentence_transformers import SentenceTransformer
|
3 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
4 |
+
import chromadb
|
5 |
+
import uuid
|
6 |
+
import os
|
7 |
+
|
8 |
+
class RAG_1177:
    """Retrieval store for scraped 1177.se texts.

    Chunks scraped documents, embeds them with a Swedish sentence-BERT
    model, and persists them in a Chroma collection together with their
    source URL so answers can cite where they came from.
    """

    def __init__(self, db_name="RAG_1177", db_path="RAG_1177_db",
                 chunk_size=2500, chunk_overlap=500,
                 url_list_path="all_urls_list.txt",
                 text_folder="scraped_texts/"):
        """Open (or create) the persistent collection and load the model.

        Every parameter defaults to the previously hard-coded value, so
        existing no-argument callers behave exactly as before.

        Args:
            db_name: Name of the Chroma collection.
            db_path: Directory for Chroma's persistent storage.
            chunk_size: Max characters per text chunk.
            chunk_overlap: Characters shared between adjacent chunks so
                context is not lost at chunk boundaries.
            url_list_path: Text file with one source URL per line.
            text_folder: Folder containing the scraped text files.
        """
        self.db_name = db_name
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
        )
        # Swedish sentence-BERT — the corpus (1177.se) is Swedish.
        self.model = SentenceTransformer('KBLab/sentence-bert-swedish-cased')
        self.client = chromadb.PersistentClient(path=db_path)
        self.db = self.client.get_or_create_collection(self.db_name)

        self.url_list_path = url_list_path
        self.text_folder = text_folder

    def chunk_text_file(self, file_name):
        """Read one scraped file and split it into overlapping chunks.

        Args:
            file_name: File name relative to ``self.text_folder``.

        Returns:
            List of chunk strings (the ``page_content`` of each split).
        """
        path = os.path.join(self.text_folder, file_name)
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        documents = self.text_splitter.create_documents([text])
        return [doc.page_content for doc in documents]

    def get_file_names(self, folder_path):
        """List files in ``folder_path`` in numeric document order.

        File names are assumed to end in ``-<int>.<ext>`` (e.g.
        ``page-12.txt``); sorting is by that integer, not lexicographic,
        so ``page-2`` comes before ``page-10``.

        Raises:
            ValueError: If a file name does not match the pattern.
        """
        names = os.listdir(folder_path)
        return sorted(names, key=lambda n: int(n.split('-')[-1].split('.')[0]))

    def get_embeddings(self, text):
        """Encode ``text`` (a string or list of strings) to embeddings.

        Returns plain Python lists because Chroma's API takes lists, not
        numpy arrays.
        """
        return self.model.encode(text).tolist()

    def get_url(self, url_index):
        """Return the URL on line ``url_index`` (0-based) of the URL file."""
        # Explicit encoding for consistency with chunk_text_file; without
        # it the platform default encoding is used (fragile on Windows).
        with open(self.url_list_path, 'r', encoding='utf-8') as f:
            urls = f.readlines()
        return urls[url_index].strip()

    def get_ids(self, num_ids):
        """Return ``num_ids`` fresh random UUID strings for Chroma ids."""
        return [str(uuid.uuid4()) for _ in range(num_ids)]

    def get_url_dict(self, url, integer):
        """Return ``integer`` copies of the metadata dict ``{'url': url}``.

        One metadata dict per chunk of a document, so every chunk carries
        its source URL.
        """
        return [{"url": url} for _ in range(integer)]

    def delete_collection(self):
        """Drop all stored data and start over with an empty collection.

        Bug fix: the collection handle is re-created afterwards — the old
        ``self.db`` pointed at the deleted collection, so any later
        ``insert``/``retrieve`` on this instance would fail.
        """
        self.client.delete_collection(self.db_name)
        self.db = self.client.get_or_create_collection(self.db_name)

    def retrieve(self, query, num_results):
        """Fetch the chunks nearest to ``query`` plus a source-URL footer.

        Args:
            query: Free-text query string.
            num_results: Number of nearest chunks to return.

        Returns:
            Tuple ``(docs, urls_text)`` where ``docs`` is a list of chunk
            strings and ``urls_text`` is a human-readable "Läs mer på:"
            numbered list of the (de-duplicated) source URLs.
        """
        query_emb = self.get_embeddings(query)
        result = self.db.query(query_embeddings=query_emb,
                               n_results=num_results,
                               include=['documents', 'metadatas'])
        metadatas = result['metadatas'][0]
        docs = result['documents'][0]

        # De-duplicate while preserving ranking order. The previous set()
        # made the numbering non-deterministic across runs (hash
        # randomization), so identical queries could list sources in a
        # different order.
        unique_urls = dict.fromkeys(item['url'] for item in metadatas)
        urls_text = "Läs mer på:\n"
        for i, url in enumerate(unique_urls, start=1):
            urls_text += f"{i}: {url}\n"

        return docs, urls_text

    def insert(self, docs, emb, urls, ids):
        """Add pre-embedded chunks with their URL metadata to the store.

        Args:
            docs: List of chunk strings.
            emb: Parallel list of embedding vectors (plain lists).
            urls: Parallel list of ``{'url': ...}`` metadata dicts.
            ids: Parallel list of unique id strings.
        """
        self.db.add(documents=docs, embeddings=emb, metadatas=urls, ids=ids)
        return