File size: 12,003 Bytes
569cdb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 |
import base64
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional
import requests
from erniebot_agent.utils.exception import BaizhongError
from erniebot_agent.utils.logging import logger
from tqdm import tqdm
from .document import Document
class BaizhongSearch:
    """
    Client for a Baizhong search project: project/schema management, document indexing and search.

    Every call is an HTTP POST against `base_url`. A non-200 response, or a response body whose
    `errCode` is non-zero, is raised as a BaizhongError.
    """

    def __init__(
        self,
        base_url: str,
        project_name: Optional[str] = None,
        remark: Optional[str] = None,
        project_id: Optional[int] = None,
        max_seq_length: int = 512,
    ) -> None:
        """
        Attach to an existing project, or create a new project together with its schema.

        Args:
            base_url (str): Root URL of the Baizhong service; endpoint paths are appended to it.
            project_name (Optional[str]): Name of a project to create. Only used when `project_id`
                is None.
            remark (Optional[str]): Remark sent along with a newly created project.
            project_id (Optional[int]): ID of an existing project. Takes precedence over
                `project_name`.
            max_seq_length (int): Sent as the schema's `paraSize` when a schema is created or
                updated (defaults to 512).
        Raises:
            BaizhongError: If neither `project_id` nor `project_name` is given, or if creating the
                project or its schema fails.
        """
        self.base_url = base_url
        self.max_seq_length = max_seq_length
        if project_id is not None:
            logger.info(f"Loading existing project with `project_id={project_id}`")
            self.project_id = project_id
        elif project_name is not None:
            logger.info("Creating new project and schema")
            # Raw response of the project-creation call; only set when a project is created here.
            self.index = self.create_project(project_name, remark)
            logger.info("Project creation succeeded")
            self.project_id = self.index["result"]["projectId"]
            self.create_schema()
            logger.info("Schema creation succeeded")
        else:
            raise BaizhongError("You must provide either a `project_name` or a `project_id`.")

    @staticmethod
    def _post(url: str, json_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        POST `json_data` as JSON to `url` and return the decoded response body.

        A staticmethod so that the `delete_project` classmethod can share it with instance methods.

        Args:
            url (str): Full endpoint URL.
            json_data (Dict[str, Any]): Request payload.
        Returns:
            Dict[str, Any]: The parsed JSON response body.
        Raises:
            BaizhongError: If the HTTP status is not 200 (`error_code` is the status code), or if
                the body's `errCode` is non-zero (`error_code` is that `errCode`).
        """
        res = requests.post(url, json=json_data)
        if res.status_code != 200:
            raise BaizhongError(message=f"request error: {res.text}", error_code=res.status_code)
        result = res.json()
        if result["errCode"] != 0:
            raise BaizhongError(message=result["errMsg"], error_code=result["errCode"])
        return result

    def _schema_json(self) -> Dict[str, Any]:
        """Build the schema definition sent by both `create_schema` and `update_schema`."""
        return {
            "paraSize": self.max_seq_length,
            "dataSegmentationMod": "neisou",
            "storeType": "ElasticSearch",
            "properties": {
                "title": {"type": "text", "shortindex": True},
                "content_se": {"type": "text", "longindex": True},
            },
        }

    def create_project(self, project_name: str, remark: Optional[str] = None):
        """
        Create a project using the Baizhong API.

        Args:
            project_name (str): Name of the new project.
            remark (Optional[str]): Optional remark stored with the project.
        Returns:
            dict: The API response; the new project's ID is at `["result"]["projectId"]`.
        Raises:
            BaizhongError: If the API request fails, this exception is raised with details about the error.
        """
        json_data = {
            "projectName": project_name,
            "remark": remark,
        }
        return self._post(f"{self.base_url}/baizhong/web-api/v2/project/add", json_data)

    def create_schema(self):
        """
        Create the schema for this project using the Baizhong API.

        Returns:
            dict: A dictionary containing information about the created schema.
        Raises:
            BaizhongError: If the API request fails, this exception is raised with details about the error.
        """
        json_data = {"projectId": self.project_id, "schemaJson": self._schema_json()}
        return self._post(f"{self.base_url}/baizhong/web-api/v2/project-schema/create", json_data)

    def update_schema(
        self,
    ):
        """
        Update the schema for this project using the Baizhong API.

        The schema sent is the same one `create_schema` uses, rebuilt from the current
        `max_seq_length`.

        Returns:
            dict: A dictionary containing information about the updated schema.
        Raises:
            BaizhongError: If the API request fails, this exception is raised with details about the error.
        """
        json_data = {"projectId": self.project_id, "schemaJson": self._schema_json()}
        return self._post(f"{self.base_url}/baizhong/web-api/v2/project-schema/update", json_data)

    def search(self, query: str, top_k: int = 10, filters: Optional[Dict[str, Any]] = None):
        """
        Perform a search using the Baizhong common search API.

        Args:
            query (str): The search query.
            top_k (int, optional): The number of top results to retrieve (default is 10).
            filters (Optional[Dict[str, Any]], optional): Field/value pairs applied as a `match`
                filter on the search (default is None).
        Returns:
            List[Dict[str, Any]]: One entry per hit, obtained by base64-decoding the hit's
            `_source.doc` field as UTF-8 and parsing it as JSON.
        Raises:
            BaizhongError: If the API request fails, this exception is raised with details about the error.
        """
        payload: Dict[str, Any] = {
            "query": query,
            "projectId": self.project_id,
            "size": top_k,
        }
        if filters is not None:
            payload["filterConditions"] = {"bool": {"filter": {"match": filters}}}
        result = self._post(f"{self.base_url}/baizhong/common-search/v2/search", payload)
        return [
            json.loads(base64.b64decode(hit["_source"]["doc"]).decode("utf-8"))
            for hit in result["hits"]
        ]

    def add_documents(self, documents: List[Document], batch_size: int = 1, thread_count: int = 1):
        """
        Add documents to the Baizhong system in batches, uploading batches concurrently.

        Args:
            documents (List[Document]): Documents to add. Document instances are converted with
                `to_dict()`; any other item is sent unchanged (assumed to already be a dict in the
                Baizhong payload shape — TODO confirm with callers).
            batch_size (int, optional): Number of documents per request (defaults to 1). Must be >= 1.
            thread_count (int, optional): Number of worker threads used to send batches
                concurrently (defaults to 1).
        Returns:
            List[dict]: The API response of each batch, in batch order. An empty input returns [].
        Raises:
            ValueError: If `batch_size` is less than 1.
            BaizhongError: If any batch request fails; the first failing batch's error is raised.
        """
        if batch_size < 1:
            # A non-positive step would make range() either raise or silently yield no batches.
            raise ValueError(f"`batch_size` must be a positive integer, got {batch_size}.")
        if not documents:
            return []
        list_dicts = [doc.to_dict() if isinstance(doc, Document) else doc for doc in documents]
        batches = [list_dicts[i : i + batch_size] for i in range(0, len(list_dicts), batch_size)]
        with ThreadPoolExecutor(max_workers=thread_count) as executor:
            # Wrap the result iterator so the progress bar advances as batches finish uploading.
            return list(tqdm(executor.map(self._add_documents, batches), total=len(batches)))

    def get_document_by_id(self, doc_id):
        """
        Retrieve a document from the Baizhong system by its ID.

        Args:
            doc_id: The ID of the document to retrieve.
        Returns:
            dict: A dictionary containing information about the retrieved document.
        Raises:
            BaizhongError: If the API request fails, this exception is raised with details about the error.
        """
        json_data = {"projectId": self.project_id, "followIndexFlag": True, "dataBody": [doc_id]}
        return self._post(f"{self.base_url}/baizhong/data-api/v2/flush/get", json_data)

    def delete_documents(
        self,
        ids: Optional[List[str]] = None,
    ):
        """
        Delete documents from the Baizhong system.

        Args:
            ids (Optional[List[str]], optional): IDs of the documents to delete. Omitting it
                (deleting all documents) is not implemented yet.
        Returns:
            dict: A dictionary containing information about the deletion process.
        Raises:
            NotImplementedError: If `ids` is None (deleting all documents is not implemented yet).
            BaizhongError: If the API request fails, this exception is raised with details about the error.
        """
        if ids is None:
            # TODO: delete all documents
            raise NotImplementedError
        json_data: Dict[str, Any] = {"projectId": self.project_id, "followIndexFlag": True, "dataBody": ids}
        return self._post(f"{self.base_url}/baizhong/data-api/v2/flush/delete", json_data)

    def _add_documents(self, documents: List[Dict[str, Any]]):
        """
        Send one batch of already-serialized documents to the Baizhong system.

        Args:
            documents (List[Dict[str, Any]]): A list of dictionaries representing documents to be added.
        Returns:
            dict: A dictionary containing information about the document addition process.
        Raises:
            BaizhongError: If the API request fails, this exception is raised with details about the error.
        """
        # TODO(wugaosheng): retry 3 times on request errors
        json_data = {"projectId": self.project_id, "followIndexFlag": True, "dataBody": documents}
        return self._post(f"{self.base_url}/baizhong/data-api/v2/flush/add", json_data)

    @classmethod
    def delete_project(cls, project_id: int, base_url: Optional[str] = None):
        """
        Class method to delete a project using the Baizhong API.

        Args:
            project_id (int): The ID of the project to be deleted.
            base_url (Optional[str]): Root URL of the Baizhong service. A classmethod cannot see
                the per-instance `base_url` set in `__init__`, so this must be passed unless a
                `base_url` attribute has been set on the class itself.
        Returns:
            dict: A dictionary containing information about the deletion process.
        Raises:
            BaizhongError: If no base URL is available, or if the API request fails.
        """
        root = base_url if base_url is not None else getattr(cls, "base_url", None)
        if root is None:
            raise BaizhongError("`base_url` is required to delete a project.")
        json_data = {"projectId": project_id}
        return cls._post(f"{root}/baizhong/web-api/v2/project/delete", json_data)
|