mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
modify some style
Former-commit-id: ece78a6d6af0673795824b2f95c266c042532eb3
This commit is contained in:
parent
26b760842b
commit
5ef2b8bdda
@ -6,6 +6,7 @@ from ..extras.constants import IGNORE_INDEX
|
|||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from .utils import Role
|
from .utils import Role
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
from transformers.tokenization_utils import AutoProcessor, PreTrainedTokenizer
|
from transformers.tokenization_utils import AutoProcessor, PreTrainedTokenizer
|
||||||
@ -34,7 +35,7 @@ def preprocess_pretrain_dataset(
|
|||||||
block_size = data_args.cutoff_len
|
block_size = data_args.cutoff_len
|
||||||
total_length = (total_length // block_size) * block_size
|
total_length = (total_length // block_size) * block_size
|
||||||
result = {
|
result = {
|
||||||
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
||||||
for k, t in concatenated_examples.items()
|
for k, t in concatenated_examples.items()
|
||||||
}
|
}
|
||||||
if data_args.template == "gemma":
|
if data_args.template == "gemma":
|
||||||
@ -132,10 +133,10 @@ def preprocess_packed_supervised_dataset(
|
|||||||
total_length = (total_length // block_size) * block_size
|
total_length = (total_length // block_size) * block_size
|
||||||
# split by chunks of cutoff_len
|
# split by chunks of cutoff_len
|
||||||
for i in range(0, total_length, block_size):
|
for i in range(0, total_length, block_size):
|
||||||
if not all(label == IGNORE_INDEX for label in labels[i: i + block_size]):
|
if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
|
||||||
model_inputs["input_ids"].append(input_ids[i: i + block_size])
|
model_inputs["input_ids"].append(input_ids[i : i + block_size])
|
||||||
model_inputs["attention_mask"].append([1] * block_size)
|
model_inputs["attention_mask"].append([1] * block_size)
|
||||||
model_inputs["labels"].append(labels[i: i + block_size])
|
model_inputs["labels"].append(labels[i : i + block_size])
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user