SOAPAssistV00 / gpt_index /token_counter /mock_chain_wrapper.py
AbeerTrial's picture
Duplicate from AbeerTrial/SOAPAssist
35b22df
"""Mock chain wrapper."""
from typing import Any, Dict, Optional
from langchain.llms.base import BaseLLM
from gpt_index.constants import NUM_OUTPUTS
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.prompts.base import Prompt
from gpt_index.prompts.prompt_type import PromptType
from gpt_index.token_counter.utils import (
mock_extract_keywords_response,
mock_extract_kg_triplets_response,
)
from gpt_index.utils import globals_helper
# TODO: consolidate with unit tests in tests/mock_utils/mock_predict.py
def _mock_summary_predict(max_tokens: int, prompt_args: Dict) -> str:
    """Build a fake summary consisting of the word "summary" repeated.

    The repetition count is the smaller of ``max_tokens`` and the number of
    tokens the global tokenizer produces for ``prompt_args["context_str"]``,
    so the mock summary never has more words than its input has tokens.
    """
    context_tokens = globals_helper.tokenizer(prompt_args["context_str"])
    word_count = min(len(context_tokens), max_tokens)
    words = ["summary"] * word_count
    return " ".join(words)
def _mock_insert_predict() -> str:
"""Mock insert predict."""
return "ANSWER: 1"
def _mock_query_select() -> str:
"""Mock query select."""
return "ANSWER: 1"
def _mock_query_select_multiple(num_chunks: int) -> str:
"""Mock query select."""
nums_str = ", ".join([str(i) for i in range(num_chunks)])
return f"ANSWER: {nums_str}"
def _mock_answer(max_tokens: int, prompt_args: Dict) -> str:
    """Build a fake answer consisting of the word "answer" repeated.

    The repetition count is ``min(n_context_tokens, max_tokens)``, where
    ``n_context_tokens`` is the token count of ``prompt_args["context_str"]``
    under the global tokenizer.
    """
    n_context_tokens = len(globals_helper.tokenizer(prompt_args["context_str"]))
    repeat = min(n_context_tokens, max_tokens)
    return " ".join("answer" for _ in range(repeat))
def _mock_refine(max_tokens: int, prompt: Prompt, prompt_args: Dict) -> str:
    """Build a fake refined answer consisting of the word "answer" repeated.

    The repetition count is capped by ``max_tokens`` and by the combined token
    count of ``prompt_args["context_msg"]`` and the existing answer.

    The existing answer is taken from ``prompt_args`` when supplied there,
    otherwise from ``prompt.partial_dict`` (a ``KeyError`` propagates if it is
    in neither).
    """
    # The existing answer may have been bound on the prompt as a partial
    # rather than passed with the call-time arguments.
    if "existing_answer" in prompt_args:
        existing_answer = prompt_args["existing_answer"]
    else:
        existing_answer = prompt.partial_dict["existing_answer"]

    context_len = len(globals_helper.tokenizer(prompt_args["context_msg"]))
    existing_len = len(globals_helper.tokenizer(existing_answer))
    budget = context_len + existing_len

    return " ".join(["answer"] * min(budget, max_tokens))
def _mock_keyword_extract(prompt_args: Dict) -> str:
    """Produce a mock keyword-extraction response for ``prompt_args["text"]``."""
    source_text = prompt_args["text"]
    return mock_extract_keywords_response(source_text)
def _mock_query_keyword_extract(prompt_args: Dict) -> str:
    """Produce a mock keyword-extraction response for ``prompt_args["question"]``."""
    question_text = prompt_args["question"]
    return mock_extract_keywords_response(question_text)
def _mock_knowledge_graph_triplet_extract(prompt_args: Dict, max_triplets: int) -> str:
    """Produce a mock knowledge-graph triplet response for ``prompt_args["text"]``.

    ``max_triplets`` is forwarded unchanged to the mock response helper.
    """
    source_text = prompt_args["text"]
    response = mock_extract_kg_triplets_response(
        source_text,
        max_triplets=max_triplets,
    )
    return response
class MockLLMPredictor(LLMPredictor):
    """Mock LLM Predictor.

    ``_predict`` never queries a language model; it returns deterministic,
    locally generated text chosen by the prompt's ``PromptType``.
    """

    def __init__(
        self, max_tokens: int = NUM_OUTPUTS, llm: Optional[BaseLLM] = None
    ) -> None:
        """Initialize params.

        Args:
            max_tokens: Upper bound on the number of words in the mock
                summary, answer and refine responses.
            llm: Optional LLM forwarded to ``LLMPredictor.__init__``. The
                ``_predict`` method below does not use it.
        """
        # NOTE: the parent initializer IS called here (with ``llm``, which may
        # be None); whatever it sets up is defined in LLMPredictor.
        super().__init__(llm)
        self.max_tokens = max_tokens
        # Explicitly (re)set after the parent call. These presumably mirror
        # token-usage counters kept by LLMPredictor; confirm against the parent.
        self._total_tokens_used = 0
        # NOTE(review): the purpose of ``flag`` is not visible in this file.
        self.flag = True
        self._last_token_usage = None

    def _predict(self, prompt: Prompt, **prompt_args: Any) -> str:
        """Return a mock response selected by ``prompt.prompt_type``.

        Args:
            prompt: Prompt whose ``prompt_type`` picks the mock generator.
                Its ``partial_dict`` is also read for refine and
                knowledge-triplet prompts.
            **prompt_args: Call-time prompt variables passed to the generators.

        Returns:
            The mock response string.

        Raises:
            ValueError: If the prompt type has no mock implementation.
        """
        # ``prompt_type`` holds a PromptType, not a string. Plain ``==``
        # comparisons are kept on purpose, rather than a dict dispatch, so the
        # matching semantics stay exactly as before.
        prompt_type = prompt.prompt_type
        if prompt_type == PromptType.SUMMARY:
            return _mock_summary_predict(self.max_tokens, prompt_args)
        elif prompt_type == PromptType.TREE_INSERT:
            return _mock_insert_predict()
        elif prompt_type == PromptType.TREE_SELECT:
            return _mock_query_select()
        elif prompt_type == PromptType.TREE_SELECT_MULTIPLE:
            return _mock_query_select_multiple(prompt_args["num_chunks"])
        elif prompt_type == PromptType.REFINE:
            return _mock_refine(self.max_tokens, prompt, prompt_args)
        elif prompt_type == PromptType.QUESTION_ANSWER:
            return _mock_answer(self.max_tokens, prompt_args)
        elif prompt_type == PromptType.KEYWORD_EXTRACT:
            return _mock_keyword_extract(prompt_args)
        elif prompt_type == PromptType.QUERY_KEYWORD_EXTRACT:
            return _mock_query_keyword_extract(prompt_args)
        elif prompt_type == PromptType.KNOWLEDGE_TRIPLET_EXTRACT:
            # Falls back to 2 triplets when the prompt has no partial value.
            return _mock_knowledge_graph_triplet_extract(
                prompt_args, prompt.partial_dict.get("max_knowledge_triplets", 2)
            )
        else:
            # Name the offending type so the failure is actionable.
            raise ValueError(f"Invalid prompt type: {prompt_type!r}.")