Upload dialog_operators.py with huggingface_hub
Browse files- dialog_operators.py +88 -0
dialog_operators.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dialog Serializers.

Dialog serializers convert structured dialog data into text that can be
fed to the model.

A dialog is a list of turns; each turn is a dict holding the "user"
utterance and the "system" response (an empty string when there is no
response):

dialog = [
    {"user": "hello", "system": "hi"},
    {"user": "kkk", "system": ""},
    {"user": "kkk", "system": ""},
]
"""
|
| 14 |
+
from typing import Any, Dict, List, Optional
|
| 15 |
+
|
| 16 |
+
from .formats import SystemFormat
|
| 17 |
+
from .operators import InstanceFieldOperator
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SerializeDialog(InstanceFieldOperator):
    """Serialize a structured dialog into a single text string for a model.

    The dialog is a list of turns, each a dict with a ``"user"`` and a
    ``"system"`` entry. Every turn is rendered with a per-turn template
    derived from a format's ``demo_format``; the rendered turns are
    concatenated, optionally preceded by a context string taken from the
    instance.

    NOTE(review): the input/output fields (``field`` / ``to_field``) are not
    declared in this class; they are presumably provided by
    ``InstanceFieldOperator`` -- confirm against the base class.

    Attributes:
        format: Format whose ``demo_format`` is used as the per-turn
            template. When ``None``, the format is read from
            ``instance["recipe_metadata"]["format"]``.
        last_response_to_field: When set, the final turn is rendered only up
            to (and including) the user utterance, and the final turn's
            ``"system"`` value is stored in this field of the instance.
        context_field: When set, the value of this instance field is
            prepended to the serialized dialog.
        context_seperator: String inserted between the context and the
            dialog. Only used when ``context_field`` is set.
    """

    format: Optional[SystemFormat] = None
    last_response_to_field: Optional[str] = None
    context_field: Optional[str] = None
    context_seperator: str = " "

    def standartize_format(self, demo_format: str) -> str:
        """Rewrite a demo format string to use the dialog placeholders.

        ``{source}`` becomes ``{user}``, ``{target}`` becomes ``{system}``
        and ``{target_prefix}`` is removed.
        """
        turn_format = demo_format.replace("{source}", "{user}")
        turn_format = turn_format.replace("{target}", "{system}")
        return turn_format.replace("{target_prefix}", "")

    def slice_first_turn(self, turn_format: str) -> str:
        """Drop everything that precedes the ``{user}`` placeholder.

        Raises:
            ValueError: If ``{user}`` does not occur in ``turn_format``.
        """
        return turn_format[turn_format.index("{user}") :]

    def slice_last_turn(self, turn_format: str) -> str:
        """Drop everything that follows the ``{system}`` placeholder.

        Raises:
            ValueError: If ``{system}`` does not occur in ``turn_format``.
        """
        return turn_format[: turn_format.index("{system}") + len("{system}")]

    def slice_last_reponse(self, turn_format: str) -> str:
        """Drop everything that follows the ``{user}`` placeholder.

        Raises:
            ValueError: If ``{user}`` does not occur in ``turn_format``.
        """
        return turn_format[: turn_format.index("{user}") + len("{user}")]

    def get_turn_format(self, turn_format: str, step: int, length: int) -> str:
        """Return the template to use for the turn at position ``step``.

        Args:
            turn_format: The general per-turn template.
            step: Zero-based index of the turn being rendered.
            length: Total number of turns in the dialog.

        Returns:
            The template, trimmed at the dialog boundaries: the first turn
            starts at ``{user}``; the last turn ends at ``{system}``, or at
            ``{user}`` when ``last_response_to_field`` is set.
        """
        if step == 0:
            turn_format = self.slice_first_turn(turn_format)
        if step == length - 1:
            turn_format = self.slice_last_turn(turn_format)
            if self.last_response_to_field is not None:
                # The last response is exported to a separate field, so it
                # is cut out of the serialized text.
                turn_format = self.slice_last_reponse(turn_format)
        return turn_format

    def get_general_turn_format(self, instance: Dict[str, Any]) -> str:
        """Return the per-turn template for ``instance``.

        Uses ``self.format`` when it is set, otherwise the format found at
        ``instance["recipe_metadata"]["format"]``; its ``demo_format`` is
        converted with ``standartize_format``.
        """
        if self.format is not None:
            general_format = self.format
        else:
            general_format = instance["recipe_metadata"]["format"]
        return self.standartize_format(general_format.demo_format)

    def process_instance_value(
        self, structred_dialog: List[Dict[str, str]], instance: Dict[str, Any]
    ) -> str:
        """Serialize ``structred_dialog`` into one string.

        Args:
            structred_dialog: The dialog turns; each turn supplies the
                values for the placeholders of the turn template.
            instance: The instance being processed. Read for the context
                field and, if needed, the recipe format; written to when
                ``last_response_to_field`` is set.

        Returns:
            The optional context prefix followed by all rendered turns.
            For an empty dialog this is just the context prefix (or ``""``).
        """
        context_prefix = (
            ""
            if self.context_field is None
            else instance[self.context_field] + self.context_seperator
        )
        general_turn_format = self.get_general_turn_format(instance)

        num_turns = len(structred_dialog)
        # Collect the pieces and join once instead of growing a string.
        pieces = [context_prefix]
        last_turn = None
        for step, turn in enumerate(structred_dialog):
            turn_format = self.get_turn_format(general_turn_format, step, num_turns)
            pieces.append(turn_format.format(**turn))
            last_turn = turn

        # An empty dialog has no last response to export; skip the write
        # instead of failing on an unbound loop variable.
        if self.last_response_to_field is not None and last_turn is not None:
            instance[self.last_response_to_field] = last_turn["system"]
        return "".join(pieces)
|