support LongLoRA

Former-commit-id: 0832ed37e7947d699f17375648a52f80752c2b6b
Author: hiyouga
Date: 2023-09-27 21:55:50 +08:00
parent 889a24ccfa
commit d6f5a3cae9
8 changed files with 313 additions and 329 deletions


@@ -22,6 +22,9 @@ def preprocess_dataset(
     column_names = list(next(iter(dataset)).keys())
     template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
 
+    if template.efficient_eos and data_args.sft_packing:
+        raise ValueError("Current template is incompatible with packing.")
+
     def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
         for i in range(len(examples["prompt"])):
             query, response = examples["prompt"][i], examples["response"][i]
@@ -96,6 +99,28 @@ def preprocess_dataset(
         return model_inputs
 
+    def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
+        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
+        # we do not mask the inputs in packed training.
+        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+        input_ids, labels = [], []
+        for query, response, history, system in construct_example(examples):
+            for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
+                input_ids += source_ids + target_ids
+                labels += source_ids + target_ids  # TODO: try masking source_ids here
+
+        total_length = len(input_ids)
+        block_size = data_args.cutoff_len
+        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
+        total_length = (total_length // block_size) * block_size
+        # split by chunks of cutoff_len
+        for i in range(0, total_length, block_size):
+            model_inputs["input_ids"].append(input_ids[i: i + block_size])
model_inputs["attention_mask"].append([1] * len(block_size))
model_inputs["labels"].append(labels[i: i + block_size])
return model_inputs
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
@@ -166,19 +191,19 @@ def preprocess_dataset(
     if stage == "pt":
         dataset = dataset.filter(lambda example: example["prompt"])
-        preprocess_function = preprocess_pretrain_dataset
+        preprocess_func = preprocess_pretrain_dataset
         print_function = print_unsupervised_dataset_example
     elif stage == "sft" and not training_args.predict_with_generate:
         dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
-        preprocess_function = preprocess_supervised_dataset
+        preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
         print_function = print_supervised_dataset_example
     elif stage == "rm":
         dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
-        preprocess_function = preprocess_pairwise_dataset
+        preprocess_func = preprocess_pairwise_dataset
         print_function = print_pairwise_dataset_example
     else:
         dataset = dataset.filter(lambda example: example["prompt"])
-        preprocess_function = preprocess_unsupervised_dataset
+        preprocess_func = preprocess_unsupervised_dataset
         print_function = print_unsupervised_dataset_example
 
     with training_args.main_process_first(desc="dataset map pre-processing"):
@@ -191,7 +216,7 @@ def preprocess_dataset(
             )
 
         dataset = dataset.map(
-            preprocess_function,
+            preprocess_func,
             batched=True,
            remove_columns=column_names,
             **kwargs
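To illustrate the dispatch above: with packing enabled, the "sft" branch now selects the packed preprocessor. The DataArgs dataclass below is only a stand-in for the project's data arguments and carries just the sft_packing flag read in the diff.

from dataclasses import dataclass

@dataclass
class DataArgs:  # illustrative stand-in, not the project's DataArguments
    sft_packing: bool = False

def select_sft_preprocessor(data_args: DataArgs) -> str:
    # mirrors the "sft" branch: the flag switches to the packed variant
    return ("preprocess_packed_supervised_dataset"
            if data_args.sft_packing
            else "preprocess_supervised_dataset")

print(select_sft_preprocessor(DataArgs(sft_packing=True)))
# prints: preprocess_packed_supervised_dataset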