|
import type {} from '@redux-devtools/extension'; |
|
import { humanId } from 'human-id'; |
|
import differenceWith from 'lodash/differenceWith'; |
|
import intersectionWith from 'lodash/intersectionWith'; |
|
import lodashSet from 'lodash/set'; |
|
import { |
|
Connection, |
|
Edge, |
|
EdgeChange, |
|
Node, |
|
NodeChange, |
|
OnConnect, |
|
OnEdgesChange, |
|
OnNodesChange, |
|
OnSelectionChangeFunc, |
|
OnSelectionChangeParams, |
|
addEdge, |
|
applyEdgeChanges, |
|
applyNodeChanges, |
|
} from 'reactflow'; |
|
import { create } from 'zustand'; |
|
import { devtools } from 'zustand/middleware'; |
|
import { immer } from 'zustand/middleware/immer'; |
|
import { Operator, SwitchElseTo } from './constant'; |
|
import { NodeData } from './interface'; |
|
import { |
|
duplicateNodeForm, |
|
generateNodeNamesWithIncreasingIndex, |
|
getNodeDragHandle, |
|
getOperatorIndex, |
|
isEdgeEqual, |
|
} from './utils'; |
|
|
|
/**
 * Shape of the graph-editor store: React Flow nodes/edges, selection state,
 * and every action that reads or mutates them.
 */
export type RFState = {
  // ----- state -----------------------------------------------------------

  /** Every node currently on the canvas. */
  nodes: Node<NodeData>[];
  /** Every edge currently on the canvas. */
  edges: Edge[];
  /** Ids of the nodes in the current React Flow selection. */
  selectedNodeIds: string[];
  /** Ids of the edges in the current React Flow selection. */
  selectedEdgeIds: string[];
  /** Id of the most recently clicked node (the store initialises it to `''`). */
  clickedNodeId: string;

  // ----- React Flow event handlers ---------------------------------------

  onNodesChange: OnNodesChange;
  onEdgesChange: OnEdgesChange;
  onConnect: OnConnect;
  onSelectionChange: OnSelectionChangeFunc;
  /** Records the id of the node the user clicked. */
  setClickedNodeId: (id?: string) => void;

  // ----- nodes -------------------------------------------------------------

  /** Replaces the whole node list. */
  setNodes: (nodes: Node[]) => void;
  /** Appends a single node. */
  addNode: (node: Node) => void;
  /** Looks a node up by id; `undefined` when not found. */
  getNode: (id?: string | null) => Node<NodeData> | undefined;
  /**
   * Returns the first node whose `data.label` equals the given operator.
   * NOTE(review): despite the name, this matches the operator label, not `data.name`.
   */
  findNodeByName: (operatorName: Operator) => Node | undefined;
  /** Returns `data.label` (the operator type) of the node with this id. */
  getOperatorTypeFromId: (id?: string | null) => string | undefined;
  /** Adds a copy of node `id`, offset by 50px, with a fresh id and a name derived from `name`. */
  duplicateNode: (id: string, name: string) => void;
  /** Removes a node together with every edge that starts or ends at it. */
  deleteNodeById: (id: string) => void;
  /** Sets `data.name` of the node with this id. */
  updateNodeName: (id: string, name: string) => void;
  /** Builds a node name from `name` via `generateNodeNamesWithIncreasingIndex` (presumably made unique with an index). */
  generateNodeName: (name: string) => string;

  // ----- node forms --------------------------------------------------------

  /**
   * Merges `values` into the node's form, or writes `values` at `path`
   * when a path is given, and returns the resulting node list.
   */
  updateNodeForm: (
    nodeId: string,
    values: any,
    path?: (string | number)[],
  ) => Node[];
  /** Writes `value` into `form[field]` of the node in place, without calling `set`. */
  updateMutableNodeFormItem: (id: string, field: string, value: any) => void;
  /** Mirrors a new connection's target into the source node's form. */
  updateFormDataOnConnect: (connection: Connection) => void;
  /**
   * Writes `target` into a Switch node's form: the `SwitchElseTo` field for
   * the else handle, otherwise `conditions[i].to` for the indexed handle.
   */
  updateSwitchFormData: (
    source: string,
    sourceHandle?: string | null,
    target?: string | null,
  ) => void;

  // ----- edges -------------------------------------------------------------

  /** Replaces the whole edge list. */
  setEdges: (edges: Edge[]) => void;
  /**
   * Replaces the outgoing edges of `nodeId` with `edges` when the two sets
   * differ (compared by source, target and sourceHandle).
   */
  setEdgesByNodeId: (nodeId: string, edges: Edge[]) => void;
  /** Adds an edge for `connection` and updates the source node's form. */
  addEdge: (connection: Connection) => void;
  /** Looks an edge up by id; `undefined` when not found. */
  getEdge: (id: string) => Edge | undefined;
  /**
   * For Categorize/Relevant/Switch sources, deletes an existing edge that
   * leaves the same handle towards a different target.
   */
  deletePreviousEdgeOfClassificationNode: (connection: Connection) => void;
  /** Removes every currently selected edge. */
  deleteEdge: () => void;
  /** Removes one edge and clears the form field that mirrored its target. */
  deleteEdgeById: (id: string) => void;
  /** Removes every edge matching both `source` and `sourceHandle`. */
  deleteEdgeBySourceAndSourceHandle: (connection: Partial<Connection>) => void;
};
|
|
|
|
|
/**
 * Operators whose outgoing handles each lead to a single target. The chosen
 * target is mirrored into the node's form, and connecting a handle that is
 * already wired replaces the previous edge.
 */
const anchoredOperators: readonly Operator[] = [
  Operator.Categorize,
  Operator.Relevant,
  Operator.Switch,
];

/**
 * Shallow-copies every existing object/array container along `path` inside
 * `root`, covering all segments except the last. A subsequent
 * `lodashSet(root, path, …)` then writes only into fresh copies. It never
 * writes into nested objects still referenced by the previous store state.
 */
const cloneContainersAlongPath = (
  root: Record<string, unknown>,
  path: (string | number)[],
): void => {
  let cursor: Record<string, unknown> = root;
  for (const key of path.slice(0, -1)) {
    const child = cursor[key];
    if (typeof child !== 'object' || child === null) {
      // Missing or primitive: lodashSet will create a brand-new container here.
      return;
    }
    const copy: object = Array.isArray(child) ? [...child] : { ...child };
    cursor[key] = copy;
    cursor = copy as Record<string, unknown>;
  }
};

/**
 * Mirrors the target of the edge leaving `source` through `sourceHandle`
 * into the source node's form. Pass `undefined` as `target` to clear it.
 * Only Relevant, Categorize and Switch nodes keep edge targets in their
 * form; other operators are left untouched.
 */
const syncAnchoredFormField = (
  state: RFState,
  source: string,
  sourceHandle: string | null | undefined,
  target: string | null | undefined,
): void => {
  const { getOperatorTypeFromId, updateNodeForm, updateSwitchFormData } =
    state;
  switch (getOperatorTypeFromId(source)) {
    case Operator.Relevant:
      // Guarded so a handle-less connection cannot write a literal "null" key.
      if (sourceHandle) {
        updateNodeForm(source, { [sourceHandle]: target });
      }
      break;
    case Operator.Categorize:
      if (sourceHandle) {
        updateNodeForm(source, target, [
          'category_description',
          sourceHandle,
          'to',
        ]);
      }
      break;
    case Operator.Switch:
      updateSwitchFormData(source, sourceHandle, target);
      break;
    default:
      break;
  }
};

const useGraphStore = create<RFState>()(
  devtools(
    immer((set, get) => ({
      nodes: [] as Node[],
      edges: [] as Edge[],
      selectedNodeIds: [] as string[],
      selectedEdgeIds: [] as string[],
      clickedNodeId: '',

      onNodesChange: (changes: NodeChange[]) => {
        set({ nodes: applyNodeChanges(changes, get().nodes) });
      },

      onEdgesChange: (changes: EdgeChange[]) => {
        set({ edges: applyEdgeChanges(changes, get().edges) });
      },

      // React Flow's connect handler does exactly what the store's addEdge does.
      onConnect: (connection: Connection) => {
        get().addEdge(connection);
      },

      onSelectionChange: ({ nodes, edges }: OnSelectionChangeParams) => {
        set({
          selectedEdgeIds: edges.map((x) => x.id),
          selectedNodeIds: nodes.map((x) => x.id),
        });
      },

      setNodes: (nodes: Node[]) => {
        set({ nodes });
      },

      setEdges: (edges: Edge[]) => {
        set({ edges });
      },

      /**
       * Replaces the outgoing edges of `nodeId` with `currentDownstreamEdges`
       * when the two sets differ by (source, target, sourceHandle). Edges
       * present in both sets keep their existing objects.
       */
      setEdgesByNodeId: (nodeId: string, currentDownstreamEdges: Edge[]) => {
        const { edges, setEdges } = get();
        const previousDownstreamEdges = edges.filter(
          (x) => x.source === nodeId,
        );

        const isSameLink = (a: Edge, b: Edge) =>
          a.source === b.source &&
          a.target === b.target &&
          a.sourceHandle === b.sourceHandle;
        const containsLink = (list: Edge[], edge: Edge) =>
          list.some((x) => isSameLink(x, edge));

        const isDifferent =
          previousDownstreamEdges.length !== currentDownstreamEdges.length ||
          !previousDownstreamEdges.every((x) =>
            containsLink(currentDownstreamEdges, x),
          ) ||
          !currentDownstreamEdges.every((x) =>
            containsLink(previousDownstreamEdges, x),
          );
        if (!isDifferent) {
          return;
        }

        const irrelevantEdges = edges.filter((x) => x.source !== nodeId);
        const intersectionDownstreamEdges = intersectionWith(
          previousDownstreamEdges,
          currentDownstreamEdges,
          isEdgeEqual,
        );
        const selfAddedDownstreamEdges = differenceWith(
          currentDownstreamEdges,
          intersectionDownstreamEdges,
          isEdgeEqual,
        );

        setEdges([
          ...irrelevantEdges,
          ...intersectionDownstreamEdges,
          ...selfAddedDownstreamEdges,
        ]);
      },

      addNode: (node: Node) => {
        set({ nodes: get().nodes.concat(node) });
      },

      getNode: (id?: string | null) => {
        return get().nodes.find((x) => x.id === id);
      },

      getOperatorTypeFromId: (id?: string | null) => {
        return get().getNode(id)?.data?.label;
      },

      /**
       * Adds the edge, drops any edge it replaces on an anchored handle, then
       * mirrors the new target into the source form. The order matters:
       * deleting the old edge clears the form field before it is re-set.
       */
      addEdge: (connection: Connection) => {
        set({ edges: addEdge(connection, get().edges) });
        get().deletePreviousEdgeOfClassificationNode(connection);
        get().updateFormDataOnConnect(connection);
      },

      getEdge: (id: string) => {
        return get().edges.find((x) => x.id === id);
      },

      updateFormDataOnConnect: (connection: Connection) => {
        const { source, sourceHandle, target } = connection;
        if (source) {
          syncAnchoredFormField(get(), source, sourceHandle, target);
        }
      },

      /**
       * For anchored operators, each source handle may have only one
       * outgoing edge. Removes the edge that previously left the same handle
       * towards a different target.
       */
      deletePreviousEdgeOfClassificationNode: (connection: Connection) => {
        const { edges, getOperatorTypeFromId, deleteEdgeById } = get();
        const operatorType = getOperatorTypeFromId(connection.source);
        if (!anchoredOperators.some((x) => x === operatorType)) {
          return;
        }

        const previousEdge = edges.find(
          (x) =>
            x.source === connection.source &&
            x.sourceHandle === connection.sourceHandle &&
            x.target !== connection.target,
        );
        if (previousEdge) {
          deleteEdgeById(previousEdge.id);
        }
      },

      /** Adds a copy of node `id` offset by 50px; no-op if the node is unknown. */
      duplicateNode: (id: string, name: string) => {
        const { getNode, addNode, generateNodeName } = get();
        const node = getNode(id);
        if (!node) {
          return;
        }

        addNode({
          ...node,
          data: {
            ...duplicateNodeForm(node.data),
            name: generateNodeName(name),
          },
          selected: false,
          dragging: false,
          id: `${node.data?.label}:${humanId()}`,
          position: {
            x: (node.position?.x || 0) + 50,
            y: (node.position?.y || 0) + 50,
          },
          dragHandle: getNodeDragHandle(node.data?.label),
        });
      },

      /** Removes every currently selected edge. */
      deleteEdge: () => {
        const { edges, selectedEdgeIds } = get();
        const selected = new Set(selectedEdgeIds);
        set({ edges: edges.filter((edge) => !selected.has(edge.id)) });
      },

      /** Clears the form field mirroring this edge's target, then removes the edge. */
      deleteEdgeById: (id: string) => {
        const currentEdge = get().edges.find((x) => x.id === id);
        if (!currentEdge) {
          return;
        }

        syncAnchoredFormField(
          get(),
          currentEdge.source,
          currentEdge.sourceHandle,
          undefined,
        );
        set({ edges: get().edges.filter((edge) => edge.id !== id) });
      },

      deleteEdgeBySourceAndSourceHandle: ({
        source,
        sourceHandle,
      }: Partial<Connection>) => {
        set({
          edges: get().edges.filter(
            (edge) =>
              edge.source !== source || edge.sourceHandle !== sourceHandle,
          ),
        });
      },

      // NOTE(review): upstream anchored forms that pointed at this node keep
      // their stale `to` values; deleteEdgeById would clear them — confirm
      // whether that is intended.
      deleteNodeById: (id: string) => {
        const { nodes, edges } = get();
        set({
          nodes: nodes.filter((node) => node.id !== id),
          edges: edges.filter(
            (edge) => edge.source !== id && edge.target !== id,
          ),
        });
      },

      findNodeByName: (name: Operator) => {
        return get().nodes.find((x) => x.data.label === name);
      },

      /**
       * Immutably updates a node's form: merges `values` when `path` is
       * empty, otherwise writes `values` at `path`. Returns the new node list.
       */
      updateNodeForm: (
        nodeId: string,
        values: any,
        path: (string | number)[] = [],
      ) => {
        const nextNodes = get().nodes.map((node) => {
          if (node.id !== nodeId) {
            return node;
          }

          const nextForm: Record<string, unknown> = { ...node.data.form };
          if (path.length === 0) {
            Object.assign(nextForm, values);
          } else {
            // Copy nested containers first so lodashSet never mutates objects
            // shared with the previous state.
            cloneContainersAlongPath(nextForm, path);
            lodashSet(nextForm, path, values);
          }

          // `as any` kept from the original: the form's declared type lives in NodeData.
          return {
            ...node,
            data: { ...node.data, form: nextForm },
          } as any;
        });

        set({ nodes: nextNodes });
        return nextNodes;
      },

      updateSwitchFormData: (source, sourceHandle, target) => {
        const { updateNodeForm } = get();
        if (!sourceHandle) {
          return;
        }
        if (sourceHandle === SwitchElseTo) {
          updateNodeForm(source, target, [SwitchElseTo]);
          return;
        }
        const operatorIndex = getOperatorIndex(sourceHandle);
        if (operatorIndex) {
          // Handle indices are 1-based; `conditions` is 0-based.
          updateNodeForm(source, target, [
            'conditions',
            Number(operatorIndex) - 1,
            'to',
          ]);
        }
      },

      /**
       * Writes `value` into `form[field]` of the node IN PLACE, without
       * calling `set` (so subscribers are not notified).
       */
      updateMutableNodeFormItem: (id: string, field: string, value: any) => {
        const { nodes } = get();
        const idx = nodes.findIndex((x) => x.id === id);
        // findIndex returns -1 when missing; index 0 is a valid node.
        if (idx !== -1) {
          lodashSet(nodes, [idx, 'data', 'form', field], value);
        }
      },

      /** Renames a node by replacing it with a new object, so memoized views update. */
      updateNodeName: (id, name) => {
        if (!id) {
          return;
        }
        set({
          nodes: get().nodes.map((node) =>
            node.id === id ? { ...node, data: { ...node.data, name } } : node,
          ),
        });
      },

      // Falls back to '' (the initial value) so the field stays a string.
      setClickedNodeId: (id?: string) => {
        set({ clickedNodeId: id ?? '' });
      },

      generateNodeName: (name: string) => {
        return generateNodeNamesWithIncreasingIndex(name, get().nodes);
      },
    })),
    { name: 'graph' },
  ),
);

export default useGraphStore;
|
|