dify / api /factories /variable_factory.py
CatPtain's picture
Upload 697 files
20f348c verified
from collections.abc import Mapping, Sequence
from typing import Any, cast
from uuid import uuid4
from configs import dify_config
from core.file import File
from core.variables.exc import VariableError
from core.variables.segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArraySegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
StringSegment,
)
from core.variables.types import SegmentType
from core.variables.variables import (
ArrayAnyVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
class InvalidSelectorError(ValueError):
pass
class UnsupportedSegmentTypeError(Exception):
pass
# Define the constant
SEGMENT_TO_VARIABLE_MAP = {
StringSegment: StringVariable,
IntegerSegment: IntegerVariable,
FloatSegment: FloatVariable,
ObjectSegment: ObjectVariable,
FileSegment: FileVariable,
ArrayStringSegment: ArrayStringVariable,
ArrayNumberSegment: ArrayNumberVariable,
ArrayObjectSegment: ArrayObjectVariable,
ArrayFileSegment: ArrayFileVariable,
ArrayAnySegment: ArrayAnyVariable,
NoneSegment: NoneVariable,
}
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
"""
This factory function is used to create the environment variable or the conversation variable,
not support the File type.
"""
if (value_type := mapping.get("value_type")) is None:
raise VariableError("missing value type")
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
# FIXME: using Any here, fix it later
result: Any
match value_type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
case SegmentType.SECRET:
result = SecretVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, int):
result = IntegerVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, float):
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f"invalid number value {value}")
case SegmentType.OBJECT if isinstance(value, dict):
result = ObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_STRING if isinstance(value, list):
result = ArrayStringVariable.model_validate(mapping)
case SegmentType.ARRAY_NUMBER if isinstance(value, list):
result = ArrayNumberVariable.model_validate(mapping)
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
result = ArrayObjectVariable.model_validate(mapping)
case _:
raise VariableError(f"not supported value type {value_type}")
if result.size > dify_config.MAX_VARIABLE_SIZE:
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
if not result.selector:
result = result.model_copy(update={"selector": selector})
return cast(Variable, result)
def build_segment(value: Any, /) -> Segment:
if value is None:
return NoneSegment()
if isinstance(value, str):
return StringSegment(value=value)
if isinstance(value, int):
return IntegerSegment(value=value)
if isinstance(value, float):
return FloatSegment(value=value)
if isinstance(value, dict):
return ObjectSegment(value=value)
if isinstance(value, File):
return FileSegment(value=value)
if isinstance(value, list):
items = [build_segment(item) for item in value]
types = {item.value_type for item in items}
if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items):
return ArrayAnySegment(value=value)
match types.pop():
case SegmentType.STRING:
return ArrayStringSegment(value=value)
case SegmentType.NUMBER:
return ArrayNumberSegment(value=value)
case SegmentType.OBJECT:
return ArrayObjectSegment(value=value)
case SegmentType.FILE:
return ArrayFileSegment(value=value)
case SegmentType.NONE:
return ArrayAnySegment(value=value)
case _:
raise ValueError(f"not supported value {value}")
raise ValueError(f"not supported value {value}")
def segment_to_variable(
*,
segment: Segment,
selector: Sequence[str],
id: str | None = None,
name: str | None = None,
description: str = "",
) -> Variable:
if isinstance(segment, Variable):
return segment
name = name or selector[-1]
id = id or str(uuid4())
segment_type = type(segment)
if segment_type not in SEGMENT_TO_VARIABLE_MAP:
raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
return cast(
Variable,
variable_class(
id=id,
name=name,
description=description,
value=segment.value,
selector=selector,
),
)