from langchain_community.document_loaders import DirectoryLoader, JSONLoader, UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, MarkdownTextSplitter, MarkdownHeaderTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from pathlib import Path
from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn.functional as F
import numpy as np
import config as cfg


class LocalRAG:
    def __init__(self,
                 rag_top_k=3,
                 doc_dir="rag/kb/BIGOLIVE及公司介绍/",  # loaded by default; in role-play mode a character-specific KB can be chosen
                 vector_db_path="rag/vector_db/",
                 embed_model=cfg.DEFAULT_EMBEDDING_MODEL
                 ):
        self.rag_top_k = rag_top_k
        self.doc_dir = doc_dir  # document directory of the local knowledge base
        self.vector_db_path = vector_db_path  # storage path of the vector database
        self.embed_model = embed_model
        self.build_vector_db()

    def build_vector_db(self):
        # Load documents (PDF, TXT, DOCX are supported)
        if isinstance(self.doc_dir, list):
            general_docs = []
            json_docs = []
            md_docs = []
            for doc_dir in self.doc_dir:
                # Generic files (txt, etc.); the character class is a rough filter meant to skip
                # .json/.md files, which are handled separately below
                loader = DirectoryLoader(doc_dir, glob="**/*.[!json!md]*")  # "**/[!.]*"
                tmp_docs = loader.load()
                general_docs.extend(tmp_docs)
                # Handle JSON files separately
                for json_file in Path(doc_dir).rglob("*.json"):
                    loader = JSONLoader(
                        file_path=str(json_file),
                        jq_schema=".[] | {spk: .spk, text: .text}",
                        text_content=False)
                    data = loader.load()
                    for iidx in range(len(data)):
                        # un-escape \uXXXX sequences so the stored text is human-readable
                        data[iidx].page_content = bytes(data[iidx].page_content, "utf-8").decode("unicode_escape")
                    json_docs.extend(data)
                # Handle Markdown files separately
                headers_to_split_on = [
                    ("#", "Header 1"),
                    ("##", "Header 2"),
                    ("###", "Header 3"),
                ]
                for md_file in Path(doc_dir).rglob("*.md"):
                    with open(md_file, 'r') as f:
                        content = f.read()
                    # Split the markdown content by its headers
                    markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
                    md_header_splits = markdown_splitter.split_text(content)
                    md_docs.extend(md_header_splits)
                    # loader = UnstructuredMarkdownLoader(md_file, mode="elements")
                    # data = loader.load()
                    # docs.extend(data)
        else:
            loader = DirectoryLoader(self.doc_dir, glob="**/*.*")
            general_docs = loader.load()
            json_docs, md_docs = [], []
        # Split text into chunks
        if len(general_docs) > 0:
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=500,
                chunk_overlap=50
            )
            # keep the separately loaded JSON/Markdown documents alongside the split chunks
            chunks = text_splitter.split_documents(general_docs) + json_docs + md_docs
        else:
            chunks = json_docs + md_docs
        # Embed the chunks and build the FAISS index
        embeddings = HuggingFaceEmbeddings(model_name=self.embed_model)
        self.vector_db = FAISS.from_documents(chunks, embeddings)
        self.vector_db.save_local(self.vector_db_path)

    def reload_knowledge_base(self, target_doc_dir):
        self.doc_dir = target_doc_dir
        self.build_vector_db()

    # def reset(self):
    #     self.vector_db = None
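

# Illustrative sketch (not part of the original module): one way to query the FAISS index
# built by LocalRAG above. `query_local_rag` is a hypothetical helper name;
# similarity_search_with_score is the standard langchain FAISS API and returns
# (Document, score) pairs, where the score is a distance for the default index (lower = closer).
def query_local_rag(rag: "LocalRAG", query: str):
    # retrieve the top-k most similar chunks for a user query
    results = rag.vector_db.similarity_search_with_score(query, k=rag.rag_top_k)
    for doc, score in results:
        print(f"score={score:.4f}  text={doc.page_content[:80]}")
    return results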


class LocalRAG_new:
    def __init__(self,
                 rag_top_k=3,
                 doc_dir="rag/kb/BIGOLIVE及公司介绍/",  # loaded by default; in role-play mode a character-specific KB can be chosen
                 vector_db_path="rag/vector_db/",
                 embed_model_path="princeton-nlp/sup-simcse-bert-large-uncased",
                 device=torch.device('cuda:2')):
        self.rag_top_k = rag_top_k
        self.doc_dir = doc_dir  # document directory of the local knowledge base
        # knowledge-base name used for the saved index file; doc_dir may be a single path or a list of paths
        if isinstance(doc_dir, list):
            self.kb_name = '_'.join(Path(d).name for d in doc_dir)
        else:
            self.kb_name = Path(doc_dir).name
        self.embed_model_name = Path(embed_model_path).name
        self.vector_db_path = vector_db_path  # storage path of the vector database
        self.embed_model = embed_model_path
        self.device = device
        # Load tokenizer and embedding model
        self.tokenizer = AutoTokenizer.from_pretrained(self.embed_model)
        self.embed_model = AutoModel.from_pretrained(self.embed_model).to(device)
        self.vector_db = None
        self._vector_db = None
        self.build_vector_db()

    class VectorDB:
        def __init__(self, rag):
            self._data = rag._vector_db
            self.rag = rag

        def similarity_search(self, query, k):
            # Possible query preprocessing could go here (none for now)
            # query = input_optimize(query)
            # Embed the query and compare it against the stored chunk embeddings
            # note: the k argument is kept for API parity, but rag_top_k controls how many results are returned
            with torch.inference_mode():
                query_token = self.rag.tokenizer(query, padding=True, truncation=False, return_tensors="pt").to(self.rag.device)
                query_embed = self.rag.embed_model(**query_token)['last_hidden_state'].mean(dim=1)
                sim_query = F.cosine_similarity(query_embed.repeat(len(self._data['embeds']), 1), self._data['embeds'], dim=1, eps=1e-8)
                max_ids_query = torch.argsort(sim_query, descending=True)[:self.rag.rag_top_k].cpu().detach().numpy()
                return list(zip(np.array(self._data['chunks'])[max_ids_query], sim_query[max_ids_query]))

    def build_vector_db(self):
        # Load documents (PDF, TXT, DOCX are supported)
        if isinstance(self.doc_dir, list):
            docs = []
            for doc_dir in self.doc_dir:
                # rough glob filter meant to skip .json/.md files (Markdown is handled separately below)
                loader = DirectoryLoader(doc_dir, glob="**/*.[!json!md]*")  # "**/[!.]*"
                tmp_docs = loader.load()
                docs.extend(tmp_docs)
                # # Handle JSON files separately
                # for json_file in Path(doc_dir).rglob("*.json"):
                #     loader = JSONLoader(
                #         file_path=str(json_file),
                #         jq_schema='.messages[].content',
                #         text_content=False)
                #     data = loader.load()
                # Handle Markdown files separately
                headers_to_split_on = [
                    ("#", "Header 1"),
                    ("##", "Header 2"),
                    ("###", "Header 3"),
                ]
                for md_file in Path(doc_dir).rglob("*.md"):
                    with open(md_file, 'r') as f:
                        content = f.read()
                    # Split the markdown content by its headers
                    markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
                    md_header_splits = markdown_splitter.split_text(content)
                    docs.extend(md_header_splits)
                    # loader = UnstructuredMarkdownLoader(md_file, mode="elements")
                    # data = loader.load()
                    # docs.extend(data)
        else:
            loader = DirectoryLoader(self.doc_dir, glob="**/*.*")
            docs = loader.load()
        # Split text into chunks
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50
        )
        chunks = text_splitter.split_documents(docs)
        # Embed every chunk with mean pooling over the last hidden states
        with torch.inference_mode():
            chunk_and_embed = []
            for chunk in chunks:
                # note: truncation is disabled, so very long chunks may exceed the model's max sequence length
                chunk_token = self.tokenizer(chunk.page_content, padding=True, truncation=False, return_tensors="pt").to(self.device)
                chunk_embed = self.embed_model(**chunk_token)['last_hidden_state'].mean(dim=1)
                chunk_and_embed.append((chunk, chunk_embed))
        all_chunks, all_embeds = list(zip(*chunk_and_embed))
        all_chunks, all_embeds = list(all_chunks), list(all_embeds)
        all_embeds = torch.cat(all_embeds, dim=0)
        self._vector_db = {'chunks': all_chunks, 'embeds': all_embeds}
        self.vector_db = self.VectorDB(self)
        # note: this pickles the whole VectorDB wrapper, which keeps a reference back to the RAG instance
        torch.save(self.vector_db, str(Path(self.vector_db_path) / f'{self.kb_name}_{self.embed_model_name}.pt'))

    def reload_knowledge_base(self, target_doc_dir):
        self.doc_dir = target_doc_dir
        self.build_vector_db()

    # def reset(self):
    #     self.vector_db = None
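

# Illustrative sketch (not in the original code): reloading the index that
# LocalRAG_new.build_vector_db persists with torch.save. The file name pattern mirrors the
# one used above; depending on the torch version, weights_only=False may be required because
# a full Python object (the VectorDB wrapper, which also references the embedding model)
# was pickled. `load_saved_vector_db` is a hypothetical helper name.
def load_saved_vector_db(vector_db_path, kb_name, embed_model_name):
    db_file = Path(vector_db_path) / f'{kb_name}_{embed_model_name}.pt'
    vector_db = torch.load(str(db_file), weights_only=False)
    return vector_db  # supports vector_db.similarity_search(query, k)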


class CosPlayer:
    def __init__(self, description_file):
        self.update(description_file)

    def update(self, description_file):
        self.description_file = description_file
        with open(description_file, 'r') as f:
            all_lines = f.readlines()
        self.core_setting = ''.join(all_lines)
        self.characters_dir = Path(description_file).parent
        self.prologue_file = self.description_file.replace('/characters/', '/prologues/')
        if not Path(self.prologue_file).exists():
            self.prologue_file = None

    def get_all_characters(self):
        return [str(i) for i in list(self.characters_dir.rglob('*.txt'))]

    def get_core_setting(self):
        return self.core_setting

    def get_prologue(self):
        if self.prologue_file:
            with open(self.prologue_file, 'r') as f:
                all_lines = f.readlines()
            return ''.join(all_lines)
        else:
            return None
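

# Illustrative sketch (not in the original code): switching the knowledge base when a
# character is selected in role-play mode, as the doc_dir comments above suggest.
# `switch_character` and its arguments are hypothetical names.
def switch_character(rag: "LocalRAG", player: "CosPlayer", description_file: str, character_kb_dir: str):
    player.update(description_file)              # load the character's core setting / prologue
    rag.reload_knowledge_base(character_kb_dir)  # rebuild the vector index from that character's documents
    return player.get_core_setting(), player.get_prologue()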


if __name__ == "__main__":
    rag = LocalRAG()
    # # rag.build_vector_db()
    # doc_dir = "rag/debug"
    # loader = DirectoryLoader(doc_dir, glob="**/*.*")
    # docs = loader.load()
    # # Split text into chunks
    # text_splitter = RecursiveCharacterTextSplitter(
    #     chunk_size=500,
    #     chunk_overlap=50
    # )
    # chunks = text_splitter.split_documents(docs)
    # pass