mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-15 16:18:10 +08:00
fix aquila template, repair sft packing mechanism
Former-commit-id: 8c82cfa5dd4bec957426b5bf176d242c77552ab0
This commit is contained in:
parent
6d0d46c7fb
commit
bd8ea09479
@@ -22,9 +22,6 @@ def preprocess_dataset(
|
|||||||
column_names = list(next(iter(dataset)).keys())
|
column_names = list(next(iter(dataset)).keys())
|
||||||
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
|
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
|
||||||
|
|
||||||
if template is not None and template.efficient_eos and data_args.sft_packing:
|
|
||||||
raise ValueError("Current template is incompatible with packing.")
|
|
||||||
|
|
||||||
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
|
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
|
||||||
for i in range(len(examples["prompt"])):
|
for i in range(len(examples["prompt"])):
|
||||||
query, response = examples["prompt"][i], examples["response"][i]
|
query, response = examples["prompt"][i], examples["response"][i]
|
||||||
@@ -105,9 +102,19 @@ def preprocess_dataset(
|
|||||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||||
input_ids, labels = [], []
|
input_ids, labels = [], []
|
||||||
for query, response, history, system in construct_example(examples):
|
for query, response, history, system in construct_example(examples):
|
||||||
for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
|
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
||||||
|
tokenizer, query, response, history, system
|
||||||
|
)):
|
||||||
|
if turn_idx != 0 and template.efficient_eos:
|
||||||
|
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||||
|
else:
|
||||||
|
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||||
input_ids += source_ids + target_ids
|
input_ids += source_ids + target_ids
|
||||||
labels += source_ids + target_ids # TODO: try masking source_ids here
|
labels += source_mask + target_ids
|
||||||
|
|
||||||
|
if template.efficient_eos:
|
||||||
|
input_ids += [tokenizer.eos_token_id]
|
||||||
|
labels += [tokenizer.eos_token_id]
|
||||||
|
|
||||||
total_length = len(input_ids)
|
total_length = len(input_ids)
|
||||||
block_size = data_args.cutoff_len
|
block_size = data_args.cutoff_len
|
||||||
|
@@ -423,7 +423,7 @@ register_template(
|
|||||||
|
|
||||||
|
|
||||||
r"""
|
r"""
|
||||||
Supports: https://huggingface.co/qhduan/aquilachat-7b
|
Supports: https://huggingface.co/BAAI/AquilaChat-7B
|
||||||
"""
|
"""
|
||||||
register_template(
|
register_template(
|
||||||
name="aquila",
|
name="aquila",
|
||||||
@@ -439,7 +439,11 @@ register_template(
|
|||||||
),
|
),
|
||||||
sep=[
|
sep=[
|
||||||
"###"
|
"###"
|
||||||
]
|
],
|
||||||
|
stop_words=[
|
||||||
|
"</s>"
|
||||||
|
],
|
||||||
|
efficient_eos=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user