Spaces:
Runtime error
Runtime error
"""Notion reader.""" | |
import logging | |
import os | |
from typing import Any, Dict, List, Optional | |
import requests # type: ignore | |
from gpt_index.readers.base import BaseReader | |
from gpt_index.readers.schema.base import Document | |
# Name of the environment variable read when no integration token is passed
# to the reader explicitly.
INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN"
# Notion REST endpoints. The `{block_id}` / `{database_id}` placeholders are
# filled in with `str.format` before each request.
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
SEARCH_URL = "https://api.notion.com/v1/search"

# TODO: Notion DB reader coming soon!
class NotionPageReader(BaseReader):
    """Notion Page reader.

    Reads a set of Notion pages through the Notion REST API and converts
    them to `Document` objects.

    Args:
        integration_token (str): Notion integration token. When omitted, the
            `NOTION_INTEGRATION_TOKEN` environment variable is used instead.

    Raises:
        ValueError: If no token is passed and the environment variable is
            not set.
    """

    def __init__(self, integration_token: Optional[str] = None) -> None:
        """Initialize with parameters."""
        if integration_token is None:
            integration_token = os.getenv(INTEGRATION_TOKEN_NAME)
            if integration_token is None:
                raise ValueError(
                    "Must specify `integration_token` or set environment "
                    "variable `NOTION_INTEGRATION_TOKEN`."
                )
        self.token = integration_token
        # Headers sent with every request made by this reader.
        self.headers = {
            "Authorization": "Bearer " + self.token,
            "Content-Type": "application/json",
            "Notion-Version": "2022-06-28",
        }

    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
        """Read the text of all children of a block, recursively.

        Args:
            block_id (str): Id of the block (or page) whose children are read.
            num_tabs (int): Nesting depth; every text segment found at this
                level is prefixed with this many tab characters.

        Returns:
            str: One entry per child block, joined with newlines. Each entry
                holds the block's rich-text `text` segments (one per line)
                followed by the text of its own children, read one level
                deeper.
        """
        result_lines_arr: List[str] = []
        # The URL always targets the same parent block; paging through its
        # children is done with the `start_cursor` query parameter, not by
        # substituting the cursor for the block id.
        block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id)
        prefix = "\t" * num_tabs
        next_cursor: Optional[str] = None

        while True:
            params: Dict[str, Any] = {}
            if next_cursor is not None:
                params["start_cursor"] = next_cursor
            res = requests.request(
                "GET", block_url, headers=self.headers, params=params
            )
            data = res.json()

            for result in data["results"]:
                # The payload of a block lives under a key named after its type.
                result_obj = result[result["type"]]

                cur_result_text_arr = []
                if "rich_text" in result_obj:
                    for rich_text in result_obj["rich_text"]:
                        # skip if doesn't have text object
                        if "text" in rich_text:
                            text = rich_text["text"]["content"]
                            cur_result_text_arr.append(prefix + text)

                if result["has_children"]:
                    children_text = self._read_block(
                        result["id"], num_tabs=num_tabs + 1
                    )
                    cur_result_text_arr.append(children_text)

                result_lines_arr.append("\n".join(cur_result_text_arr))

            next_cursor = data.get("next_cursor")
            if next_cursor is None:
                break

        return "\n".join(result_lines_arr)

    def read_page(self, page_id: str) -> str:
        """Read a page.

        Args:
            page_id (str): Id of the page to read.

        Returns:
            str: Text of the page's blocks, as produced by `_read_block`.
        """
        return self._read_block(page_id)

    def query_database(
        self, database_id: str, query_dict: Optional[Dict[str, Any]] = None
    ) -> List[str]:
        """Get all the pages from a Notion database.

        Args:
            database_id (str): Id of the database to query.
            query_dict (Optional[Dict[str, Any]]): Request body sent to the
                query endpoint. Defaults to an empty query. The dict passed
                in is not modified.

        Returns:
            List[str]: Ids of every result, across all response pages.
        """
        # Work on a copy so the pagination cursor never leaks into the
        # caller's dict (or into a shared default).
        payload: Dict[str, Any] = dict(query_dict) if query_dict else {}
        database_url = DATABASE_URL_TMPL.format(database_id=database_id)
        page_ids: List[str] = []

        while True:
            res = requests.post(database_url, headers=self.headers, json=payload)
            data = res.json()
            page_ids.extend(result["id"] for result in data["results"])

            next_cursor = data.get("next_cursor")
            if next_cursor is None:
                break
            payload["start_cursor"] = next_cursor

        return page_ids

    def search(self, query: str) -> List[str]:
        """Search Notion page given a text query.

        Args:
            query (str): Text to search for.

        Returns:
            List[str]: Ids of every search result, across all response pages.
        """
        page_ids: List[str] = []
        next_cursor: Optional[str] = None

        while True:
            query_dict: Dict[str, Any] = {"query": query}
            if next_cursor is not None:
                query_dict["start_cursor"] = next_cursor
            res = requests.post(SEARCH_URL, headers=self.headers, json=query_dict)
            data = res.json()
            page_ids.extend(result["id"] for result in data["results"])

            next_cursor = data["next_cursor"]
            if next_cursor is None:
                break

        return page_ids

    def load_data(
        self,
        page_ids: Optional[List[str]] = None,
        database_id: Optional[str] = None,
    ) -> List[Document]:
        """Load Notion pages as documents.

        Args:
            page_ids (Optional[List[str]]): List of page ids to load.
            database_id (Optional[str]): Id of a database whose pages should
                be loaded. When given, it takes precedence over `page_ids`.

        Returns:
            List[Document]: One document per page, holding the page text and
                the page id under `extra_info["page_id"]`.

        Raises:
            ValueError: If neither `page_ids` nor `database_id` is given.
        """
        if not page_ids and not database_id:
            raise ValueError("Must specify either `page_ids` or `database_id`.")

        if database_id is not None:
            # get all the pages in the database
            page_ids = self.query_database(database_id)

        return [
            Document(self.read_page(page_id), extra_info={"page_id": page_id})
            for page_id in page_ids or []
        ]
if __name__ == "__main__":
    # The root logger defaults to WARNING, which would silently drop the
    # INFO-level message below; enable INFO so the search results are shown.
    logging.basicConfig(level=logging.INFO)
    reader = NotionPageReader()
    logging.info(reader.search("What I"))