hassano94 commited on
Commit
c02fb77
·
verified ·
1 Parent(s): 9640899

Upload RAG_class.py

Browse files
Files changed (1) hide show
  1. RAG_class.py +71 -0
RAG_class.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #import recurive textsplitter
2
+ from sentence_transformers import SentenceTransformer
3
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
+ import chromadb
5
+ import uuid
6
+ import os
7
+
8
+ class RAG_1177:
9
+ def __init__(self):
10
+ self.db_name = "RAG_1177"
11
+
12
+ self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=2500,chunk_overlap=500,length_function=len)
13
+ self.model = SentenceTransformer('KBLab/sentence-bert-swedish-cased')
14
+ self.client = chromadb.PersistentClient(path="RAG_1177_db")
15
+ self.db = self.client.get_or_create_collection(self.db_name)
16
+
17
+ self.url_list_path = "all_urls_list.txt"
18
+ self.text_folder = "scraped_texts/"
19
+
20
+ def chunk_text_file(self, file_name):
21
+ file_name = self.text_folder + file_name
22
+ with open(file_name, 'r', encoding='utf-8') as f:
23
+ text = f.read()
24
+ chunks = self.text_splitter.create_documents([text])
25
+ #append chunks as elements in a list
26
+ chunks = [chunk.page_content for chunk in chunks]
27
+ return chunks
28
+
29
+ def get_file_names(self, folder_path):
30
+ doc_list = os.listdir(folder_path)
31
+ doc_list = sorted(doc_list, key=lambda x: int(x.split('-')[-1].split('.')[0]))
32
+ return doc_list
33
+
34
+ def get_embeddings(self, text):
35
+ embeddings = self.model.encode(text)
36
+ return (embeddings.tolist())
37
+
38
+ def get_url(self, url_index):
39
+ with open(self.url_list_path, 'r') as f:
40
+ urls = f.readlines()
41
+ return urls[url_index].strip()
42
+
43
+ def get_ids(self, num_ids):
44
+ ids = [str(uuid.uuid4()) for _ in range(num_ids)]
45
+ return ids
46
+
47
+ def get_url_dict(self, url, integer):
48
+ url_list = [{"url": url} for _ in range(integer)]
49
+ return url_list
50
+
51
+ def delete_collection(self):
52
+ self.client.delete_collection(self.db_name)
53
+ return
54
+
55
+ def retrieve(self, query, num_results):
56
+ query_emb = self.get_embeddings(query)
57
+ result = self.db.query(query_embeddings=query_emb, n_results=num_results, include=['documents', 'metadatas'])
58
+ result_urls = result['metadatas'][0]
59
+ result_docs = result['documents'][0]
60
+
61
+ url_list = set([item['url'] for item in result_urls])
62
+ result_urls = "Läs mer på:\n"
63
+ for i, url in enumerate(url_list, start=1):
64
+ result_urls += f"{i}: {url}\n"
65
+
66
+ return result_docs, result_urls
67
+
68
+ def insert(self,docs, emb, urls, ids):
69
+ self.db.add(documents=docs, embeddings=emb, metadatas=urls, ids=ids)
70
+ return
71
+