mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
add BOS token in pre-training
Former-commit-id: d668f8b501c367276ef4be372f2eb1753a1b7e86
This commit is contained in:
parent 3419396945
commit dd1e7ed3cf
@@ -430,15 +430,16 @@ def preprocess_data(
             yield dialog

     def preprocess_pretrain_dataset(examples):
-        # build grouped texts with format `X1 X2 X3 ...` (without [BOS] and [EOS])
+        # build grouped texts with format `[BOS] X1 X2 X3 ...` (without [EOS])
         text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
         concatenated_ids = list(chain(*text_ids))
         total_length = len(concatenated_ids)
+        block_size = data_args.max_source_length - 1
         # we drop the small remainder, and if the total_length < block_size, we exclude this batch
-        total_length = (total_length // data_args.max_source_length) * data_args.max_source_length
+        total_length = (total_length // block_size) * block_size
         # split by chunks of max_source_length
-        result = [concatenated_ids[i: i + data_args.max_source_length] for i in
-                  range(0, total_length, data_args.max_source_length)]
+        result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
+                  for i in range(0, total_length, block_size)]
         return {
             "input_ids": result,
             "labels": result.copy()
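
For context on the change: the new code reserves one position per block for the BOS token. It packs the corpus into chunks of max_source_length - 1 token IDs and prepends tokenizer.bos_token_id to each chunk, so every training block is still exactly max_source_length tokens long. Below is a minimal, self-contained sketch of that packing logic; the group_pretrain_blocks helper, the toy token IDs, and the BOS_TOKEN_ID / MAX_SOURCE_LENGTH constants are illustrative assumptions, not objects from the repository.

    from itertools import chain

    # Hypothetical stand-ins for the real `tokenizer` and `data_args` (illustration only).
    BOS_TOKEN_ID = 1
    MAX_SOURCE_LENGTH = 8

    def group_pretrain_blocks(text_ids, bos_token_id=BOS_TOKEN_ID,
                              max_source_length=MAX_SOURCE_LENGTH):
        """Pack tokenized texts into fixed-size blocks of the form `[BOS] X1 X2 ...`."""
        concatenated_ids = list(chain(*text_ids))
        block_size = max_source_length - 1  # reserve one slot per block for BOS
        # drop the small remainder; if total_length < block_size, no block is produced
        total_length = (len(concatenated_ids) // block_size) * block_size
        return [[bos_token_id] + concatenated_ids[i: i + block_size]
                for i in range(0, total_length, block_size)]

    # Two toy "documents" of token IDs; 11 tokens total, so exactly one 7-token chunk survives.
    blocks = group_pretrain_blocks([[10, 11, 12, 13], [14, 15, 16, 17, 18, 19, 20]])
    assert all(len(block) == MAX_SOURCE_LENGTH for block in blocks)
    print(blocks)  # [[1, 10, 11, 12, 13, 14, 15, 16]]

Reserving the slot up front is what keeps the prepended BOS from pushing blocks past the model's maximum input length, which naively adding BOS to full-size max_source_length chunks would have done.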