enable cutoff len

Former-commit-id: e9513d300c338dfcae98eee7d057bfd00da2da0e
commit e4a424cb6a (parent d8affd3967)
Author: hiyouga
Date: 2024-01-18 12:25:42 +08:00

6 changed files with 46 additions and 14 deletions

File 1 of 6

@@ -0,0 +1 @@
+0fe460c57c1a1dae70cdd72b185129c2aaae4e24

File 2 of 6

@@ -12,7 +12,7 @@ if TYPE_CHECKING:

 def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
-    outputs = {"prompt": [], "response": [], "system": [], "tool": []}
+    outputs = {"prompt": [], "response": [], "system": [], "tools": []}
     for i in range(len(examples[dataset_attr.prompt])):
         prompt = []
         if dataset_attr.history:
@@ -33,13 +33,13 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
         outputs["prompt"].append(prompt)
         outputs["response"].append(response)
         outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
-        outputs["tool"].append("")
+        outputs["tools"].append("")

     return outputs

 def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
-    outputs = {"prompt": [], "response": [], "system": [], "tool": []}
+    outputs = {"prompt": [], "response": [], "system": [], "tools": []}
     tag_mapping = {
         dataset_attr.user_tag: Role.USER,
         dataset_attr.assistant_tag: Role.ASSISTANT,
@@ -69,7 +69,7 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
         outputs["prompt"].append(prompt)
         outputs["response"].append(response)
         outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
-        outputs["tool"].append(examples[dataset_attr.tool][i] if dataset_attr.tool else "")
+        outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")

     return outputs

@@ -82,7 +82,7 @@ def align_dataset(
         prompt: [{"role": "user", "content": "..."}]
         response: [{"role": "assistant", "content": "..."}]
         system: "..."
-        tool: "..."
+        tools: "..."
     """
     if dataset_attr.formatting == "alpaca":
         convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
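Both converters now emit a "tools" column (previously "tool"), matching the schema in the align_dataset docstring. As a standalone illustration of that schema, here is a minimal sketch of aligning one alpaca-style record; align_alpaca_record is a made-up helper, not the project's actual code:

    from typing import Any, Dict

    def align_alpaca_record(instruction: str, query: str, output: str) -> Dict[str, Any]:
        # Follow the documented layout: prompt/response are role-tagged turns,
        # system and tools are plain strings ("" when the dataset has none).
        prompt_text = "\n".join(filter(None, [instruction, query]))
        return {
            "prompt": [{"role": "user", "content": prompt_text}],
            "response": [{"role": "assistant", "content": output}],
            "system": "",
            "tools": "",  # alpaca-style data carries no tool schema
        }

    print(align_alpaca_record("Translate to French:", "Hello", "Bonjour"))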

File 3 of 6

@@ -93,6 +93,9 @@ class ToolFormatter:
     def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
         try:
             tools = json.loads(content)
+            if not len(tools):
+                return [""]
+
             if self.type == "default":
                 return [self._default(tools)]
         except json.JSONDecodeError:
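The added guard makes ToolFormatter return an empty prompt when the tools payload parses to an empty JSON list, instead of passing [] on to the renderer. A stripped-down stand-in showing the effect (ToyToolFormatter is hypothetical; the real class formats a tool-usage prompt via self._default):

    import json
    from typing import Dict, List, Union

    class ToyToolFormatter:
        def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
            try:
                tools = json.loads(content)
                if not len(tools):  # the new guard: no tools -> empty string
                    return [""]
                return ["Available tools: " + ", ".join(t["name"] for t in tools)]
            except json.JSONDecodeError:
                return [""]

    fmt = ToyToolFormatter()
    print(fmt("[]"))                         # ['']  (hits the new guard)
    print(fmt('[{"name": "get_weather"}]'))  # ['Available tools: get_weather']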

File 4 of 6

@@ -29,7 +29,7 @@ class DatasetAttr:
     history: Optional[str] = None

     messages: Optional[str] = "conversations"
-    tool: Optional[str] = None
+    tools: Optional[str] = None

     role_tag: Optional[str] = "from"
     content_tag: Optional[str] = "value"
@@ -86,7 +86,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
         if dataset_attr.formatting == "alpaca":
             column_names = ["prompt", "query", "response", "history"]
         else:
-            column_names = ["messages", "tool"]
+            column_names = ["messages", "tools"]

         column_names += ["system"]
         for column_name in column_names:
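Since sharegpt-formatted datasets now look up a "tools" column, a dataset definition maps it under that key. A hypothetical mapping, written as a Python dict for illustration (dataset and file names are invented, not taken from the repo's dataset_info.json):

    dataset_info = {
        "example_tool_dataset": {         # hypothetical dataset name
            "file_name": "example.json",  # hypothetical data file
            "formatting": "sharegpt",
            "columns": {
                "messages": "conversations",  # conversation turns
                "tools": "tools",             # was "tool" before this commit
            },
        },
    }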

File 5 of 6

@@ -58,7 +58,7 @@ def preprocess_supervised_dataset(
         messages = examples["prompt"][i] + examples["response"][i]
         input_ids, labels = [], []
         for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
-            tokenizer, messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
+            tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
         )):
             if data_args.train_on_prompt:
                 source_mask = source_ids
@@ -97,7 +97,7 @@ def preprocess_packed_supervised_dataset(
         messages = examples["prompt"][i] + examples["response"][i]
         for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
-            tokenizer, messages, examples["system"][i], examples["tool"][i]
+            tokenizer, messages, examples["system"][i], examples["tools"][i]
         )):
             if data_args.train_on_prompt:
                 source_mask = source_ids
@@ -141,7 +141,7 @@ def preprocess_unsupervised_dataset(
         messages = examples["prompt"][i] + examples["response"][i]
         input_ids, labels = template.encode_oneturn(
-            tokenizer, messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
+            tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
         )

         if template.efficient_eos:
@@ -170,10 +170,10 @@ def preprocess_pairwise_dataset(
         rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
         prompt_ids, chosen_ids = template.encode_oneturn(
-            tokenizer, chosen_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
+            tokenizer, chosen_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
         )
         _, rejected_ids = template.encode_oneturn(
-            tokenizer, rejected_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
+            tokenizer, rejected_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
         )

         if template.efficient_eos:
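All four preprocessors now forward examples["tools"] to the template encoders. For context on what happens with the returned (source_ids, target_ids) pairs, here is a hedged sketch of the supervised label construction: prompt tokens are masked from the loss unless train_on_prompt is set. build_supervised_features is an illustrative stand-in, and IGNORE_INDEX = -100 follows the usual transformers convention:

    IGNORE_INDEX = -100  # positions with this label are excluded from the loss

    def build_supervised_features(encoded_pairs, train_on_prompt=False):
        input_ids, labels = [], []
        for source_ids, target_ids in encoded_pairs:
            # Mask the prompt side unless we explicitly train on it.
            source_mask = source_ids if train_on_prompt else [IGNORE_INDEX] * len(source_ids)
            input_ids += source_ids + target_ids
            labels += source_mask + target_ids
        return input_ids, labels

    # One user turn of 3 tokens and an assistant reply of 2 tokens:
    print(build_supervised_features([([1, 2, 3], [4, 5])]))
    # -> ([1, 2, 3, 4, 5], [-100, -100, -100, 4, 5])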

File 6 of 6

@@ -95,7 +95,21 @@ class Template:
             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

-        return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)]
+        # TODO: need to improve
+        encoded_pairs = []
+        total_length = 0
+        for i in range(0, len(encoded_messages), 2):
+            if total_length >= cutoff_len:
+                break
+
+            encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
+            total_length += len(encoded_messages[i])
+
+            encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
+            total_length += len(encoded_messages[i+1])
+            encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
+
+        return encoded_pairs

     def _convert_elements_to_ids(
         self,
@@ -161,7 +175,21 @@ class Llama2Template(Template):
             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

-        return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)]
+        # TODO: need to improve
+        encoded_pairs = []
+        total_length = 0
+        for i in range(0, len(encoded_messages), 2):
+            if total_length >= cutoff_len:
+                break
+
+            encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
+            total_length += len(encoded_messages[i])
+
+            encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
+            total_length += len(encoded_messages[i+1])
+            encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
+
+        return encoded_pairs

 templates: Dict[str, Template] = {}
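This hunk is what the commit title refers to: rather than returning every encoded pair untouched, the multiturn encoders now clip each (source, target) pair so the running token count respects cutoff_len, and max(1, ...) keeps at least one target token per retained turn. A standalone sketch of the same loop with a worked example:

    from typing import List, Tuple

    def truncate_pairs(encoded_messages: List[List[int]], cutoff_len: int) -> List[Tuple[List[int], List[int]]]:
        # Mirrors the loop added above, operating on already-tokenized turns
        # laid out as [source_0, target_0, source_1, target_1, ...].
        encoded_pairs = []
        total_length = 0
        for i in range(0, len(encoded_messages), 2):
            if total_length >= cutoff_len:
                break
            source = encoded_messages[i][:cutoff_len - total_length]
            total_length += len(source)
            target = encoded_messages[i + 1][:max(1, cutoff_len - total_length)]
            total_length += len(target)
            encoded_pairs.append((source, target))
        return encoded_pairs

    # Two turns of 4 + 4 tokens against cutoff_len=6: the first reply is
    # clipped to 2 tokens and the second turn is dropped entirely.
    print(truncate_pairs([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], 6))
    # -> [([1, 2, 3, 4], [5, 6])]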