mirror of https://github.com/hiyouga/LLaMA-Factory.git
allow non-packing pretraining
@@ -21,8 +21,11 @@ logger = get_logger(__name__)
 def preprocess_pretrain_dataset(
     examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
 ) -> Dict[str, List[List[int]]]:
-    # build grouped texts with format `X1 X2 X3 ...`
+    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
     text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
+    if not data_args.packing:
+        return tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
+
     tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
     concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
     total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
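For orientation, here is a minimal, self-contained sketch of what the two branches above do. It uses a hypothetical character-level tokenizer in place of the real PreTrainedTokenizer: the non-packing path keeps one (truncated) sequence per example, while the packing path concatenates all token ids and regroups them into fixed cutoff_len blocks, following the standard group_texts pattern that the remainder of preprocess_pretrain_dataset (not shown in this hunk) implements.

from itertools import chain
from typing import List

def fake_tokenize(text: str) -> List[int]:
    # Hypothetical tokenizer: one token id per character.
    return [ord(ch) for ch in text]

def preprocess(texts: List[str], packing: bool, cutoff_len: int) -> List[List[int]]:
    if not packing:
        # Non-packing: each example stays a separate sequence, truncated to cutoff_len.
        return [fake_tokenize(t)[:cutoff_len] for t in texts]
    # Packing: concatenate every example, then split into cutoff_len-sized blocks,
    # dropping the ragged tail, as in the group_texts pattern.
    ids = list(chain.from_iterable(fake_tokenize(t) for t in texts))
    total = (len(ids) // cutoff_len) * cutoff_len
    return [ids[i : i + cutoff_len] for i in range(0, total, cutoff_len)]

texts = ["hello world", "short", "another document"]
print(preprocess(texts, packing=False, cutoff_len=8))  # three sequences, each at most 8 ids
print(preprocess(texts, packing=True, cutoff_len=8))   # fixed 8-id blocks from the concatenation

One caveat worth knowing: Hugging Face tokenizers only truncate when truncation=True is passed alongside max_length, so whether the early return in the hunk above actually truncates may depend on the tokenizer defaults in use.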
@@ -245,7 +248,7 @@ def get_preprocess_and_print_func(
         preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not training_args.predict_with_generate:
-        if data_args.sft_packing:
+        if data_args.packing:
             preprocess_func = partial(
                 preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
             )
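The second hunk folds the SFT-only sft_packing flag into the single packing flag used by the pretraining path, so one option controls packing for both stages; this suggests configurations that previously set sft_packing should now set packing instead. Below is a small sketch of the dispatch pattern, with placeholder handlers standing in for the real LLaMA-Factory preprocessing functions:

from functools import partial

def pretrain_handler(examples, cutoff_len):
    return f"pretrain(cutoff_len={cutoff_len})"

def packed_sft_handler(examples, cutoff_len):
    return f"packed_sft(cutoff_len={cutoff_len})"

def plain_sft_handler(examples, cutoff_len):
    return f"plain_sft(cutoff_len={cutoff_len})"

def get_preprocess_func(stage: str, packing: bool, cutoff_len: int):
    # One packing flag decides between the packed and plain SFT handlers,
    # mirroring the data_args.packing check in the hunk above.
    if stage == "pt":
        return partial(pretrain_handler, cutoff_len=cutoff_len)
    if stage == "sft" and packing:
        return partial(packed_sft_handler, cutoff_len=cutoff_len)
    return partial(plain_sft_handler, cutoff_len=cutoff_len)

func = get_preprocess_func(stage="sft", packing=True, cutoff_len=1024)
print(func(examples=None))  # packed_sft(cutoff_len=1024)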