Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import sys | |
| sys.path.append("..") | |
| import json | |
| from abc import ABC, abstractmethod | |
| from pathlib import Path | |
| from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union | |
| import yaml | |
| from pydantic import Field | |
| from ....base import BotBase | |
| from .formatter import FStringFormatter, JinjiaFormatter | |
| from .parser import BaseOutputParser, DictParser, ListParser, StrParser | |
# Template-syntax name -> formatter instance used to validate templates of
# that syntax (see check_valid_template).
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = dict(
    [
        ("f-string", FStringFormatter()),
        ("jinja2", JinjiaFormatter()),
    ]
)

# Parser name -> output-parser class. Presumably used to resolve a parser
# from its name (e.g. when loading a saved config) — verify against callers.
_OUTPUT_PARSER = dict(
    StrParser=StrParser,
    ListParser=ListParser,
    DictParser=DictParser,
)
def check_valid_template(
    template: str, template_format: str, input_variables: List[str]
) -> None:
    """Check that a template string is valid for the given template format.

    Args:
        template: The raw template text.
        template_format: A key of ``DEFAULT_FORMATTER_MAPPING``
            (e.g. ``"f-string"`` or ``"jinja2"``).
        input_variables: Names of the variables the template is expected to use.

    Raises:
        ValueError: If ``template_format`` is not a supported format, or if the
            formatter's ``validate`` raises ``KeyError`` (mismatched or missing
            input parameters). The original ``KeyError`` is chained as the cause.
    """
    if template_format not in DEFAULT_FORMATTER_MAPPING:
        valid_formats = list(DEFAULT_FORMATTER_MAPPING)
        raise ValueError(
            f"Invalid template format. Got `{template_format}`;"
            f" should be one of {valid_formats}"
        )
    # The lookup cannot fail after the membership check above, so keep it out
    # of the try block: only KeyErrors raised by validation are translated.
    formatter = DEFAULT_FORMATTER_MAPPING[template_format]
    try:
        formatter.validate(template, input_variables)
    except KeyError as e:
        raise ValueError(
            "Invalid prompt schema; check for mismatched or missing input parameters. "
            + str(e)
        ) from e
| def _get_jinja2_variables_from_template(template: str) -> Set[str]: | |
| try: | |
| from jinja2 import Environment, meta | |
| except ImportError: | |
| raise ImportError( | |
| "jinja2 not installed, which is needed to use the jinja2_formatter. " | |
| "Please install it with `pip install jinja2`." | |
| ) | |
| env = Environment() | |
| ast = env.parse(template) | |
| variables = meta.find_undeclared_variables(ast) | |
| return variables | |
class BasePromptTemplate(BotBase, ABC):
    """Base class for all prompt templates, returning a prompt.

    Concrete subclasses must implement :meth:`format`. The alternate
    constructors :meth:`from_template` and :meth:`from_config` raise
    ``NotImplementedError`` here and are meant to be overridden.
    """

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""
    output_parser: Optional[BaseOutputParser] = None
    """How to parse the output of calling an LLM on this formatted prompt."""
    partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
        default_factory=dict
    )
    """Pre-bound variables: fixed strings, or zero-argument callables that are
    invoked every time the prompt is formatted."""

    class Config:
        """Configuration for this pydantic object."""

        # Reject unknown fields at construction time.
        extra = "forbid"
        # Allow field types that are not pydantic models (presumably needed
        # for the output-parser objects — confirm).
        arbitrary_types_allowed = True

    def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
        """Return a copy of this template with some variables pre-bound.

        Args:
            **kwargs: Variable values to bind, each either a string or a
                zero-argument callable returning a string.

        Returns:
            A new instance of the same class. The bound names are removed from
            ``input_variables`` and merged into ``partial_variables``; newer
            bindings override existing ones with the same name.
        """
        prompt_dict = self.__dict__.copy()
        # Keep the remaining variables in their declared order (deduplicated,
        # as the previous set-based implementation was); a plain set
        # difference would return them in arbitrary order.
        prompt_dict["input_variables"] = list(
            dict.fromkeys(
                name for name in self.input_variables if name not in kwargs
            )
        )
        prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
        return type(self)(**prompt_dict)

    def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
        """Resolve partial variables and overlay the caller's values.

        Callable partials are invoked on every call. Values passed in
        ``kwargs`` take precedence over partial variables with the same name.
        """
        partial_kwargs = {
            k: v if isinstance(v, str) else v()
            for k, v in self.partial_variables.items()
        }
        return {**partial_kwargs, **kwargs}

    @abstractmethod
    def format(self, **kwargs: Any) -> str:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:

        .. code-block:: python

            prompt.format(variable1="foo")
        """

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the prompt to a JSON or YAML file.

        Args:
            file_path: Destination file path. Its suffix selects the format:
                ``.json`` for JSON, ``.yaml`` or ``.yml`` for YAML. Missing
                parent directories are created.

        Raises:
            ValueError: If the prompt has partial variables, or if the file
                suffix is not a supported format.

        Example:

        .. code-block:: python

            prompt.save(file_path="path/prompt.yaml")
        """
        if self.partial_variables:
            # Partial values may be callables, which cannot be serialized.
            raise ValueError("Cannot save prompt with partial variables.")

        save_path = Path(file_path)
        suffix = save_path.suffix
        # Validate the format before touching the filesystem so a bad path
        # does not leave freshly created directories behind.
        if suffix not in (".json", ".yaml", ".yml"):
            raise ValueError(f"{save_path} must be json or yaml")

        prompt_dict = self.dict()
        # Serialize fully in memory first so a serialization error cannot
        # leave a truncated file on disk.
        if suffix == ".json":
            text = json.dumps(prompt_dict, indent=4)
        else:
            text = yaml.dump(prompt_dict, default_flow_style=False)

        save_path.parent.mkdir(parents=True, exist_ok=True)
        save_path.write_text(text, encoding="utf-8")

    @classmethod
    def from_template(cls, template: str, **kwargs: Any) -> BasePromptTemplate:
        """Create a prompt from a template.

        Subclasses that support construction from a template string must
        override this.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{cls.__name__} does not implement from_template()."
        )

    @classmethod
    def from_config(cls, config: Dict) -> BasePromptTemplate:
        """Create a prompt from config.

        Subclasses that support construction from a config dict must
        override this.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{cls.__name__} does not implement from_config()."
        )