From c5ec4eaef548505fc0911c9e6986ea3d23d65e01 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 27 Jun 2023 23:54:24 +0800
Subject: [PATCH] tiny fix

Former-commit-id: 450910c1db969533c5268022cb064cbc2c9cb7e6
---
 README.md                  |  1 +
 src/export_model.py        |  1 +
 src/utils/data_collator.py |  8 +++++++-
 src/utils/seq2seq.py       | 10 ++++++----
 4 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index be4d0fed..3f97aa79 100644
--- a/README.md
+++ b/README.md
@@ -208,6 +208,7 @@ accelerate launch src/train_XX.py # arguments (same as above)
 compute_environment: LOCAL_MACHINE
 deepspeed_config:
   gradient_accumulation_steps: 4
+  gradient_clipping: 0.5
   offload_optimizer_device: none
   offload_param_device: none
   zero3_init_flag: false
diff --git a/src/export_model.py b/src/export_model.py
index 71985180..e36d5c82 100644
--- a/src/export_model.py
+++ b/src/export_model.py
@@ -13,6 +13,7 @@ def main():
     model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
     tokenizer.save_pretrained(training_args.output_dir)
     print("model and tokenizer have been saved at:", training_args.output_dir)
+    print("Remember to copy the *.py files from the original directory.")


 if __name__ == "__main__":
diff --git a/src/utils/data_collator.py b/src/utils/data_collator.py
index dbfc34f0..4db1a305 100644
--- a/src/utils/data_collator.py
+++ b/src/utils/data_collator.py
@@ -26,8 +26,10 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
         """
         batch_size, seq_length = input_ids.size()
         attention_mask = torch.ones((batch_size, seq_length), device=device)
+
         for i, seq in enumerate(input_ids):
             attention_mask[i, :(seq != self.tokenizer.pad_token_id).nonzero()[0].item()] = 0 # padding
+
         attention_mask = attention_mask.bool()
         return attention_mask

@@ -49,7 +51,11 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
             labels = [torch.tensor(feature["labels"]).flip(0) for feature in features]
             input_ids = input_ids + labels # pad them to the same length

-        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id).flip(-1)
+        input_ids = torch.nn.utils.rnn.pad_sequence(
+            input_ids,
+            batch_first=True,
+            padding_value=self.tokenizer.pad_token_id
+        ).flip(-1)

         batch = {}

diff --git a/src/utils/seq2seq.py b/src/utils/seq2seq.py
index 6c772a25..267310bc 100644
--- a/src/utils/seq2seq.py
+++ b/src/utils/seq2seq.py
@@ -35,8 +35,9 @@ class ComputeMetrics:
         score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
         for pred, label in zip(preds, labels):
-            pred = pred[len(label) - np.sum(label == IGNORE_INDEX) : len(pred) - np.sum(pred == IGNORE_INDEX)] # remove prompts
-            label = label[:len(label) - np.sum(label == IGNORE_INDEX)]
+            pred_pad_len, label_pad_len = np.sum(pred == IGNORE_INDEX), np.sum(label == IGNORE_INDEX)
+            pred = pred[len(label) - label_pad_len : len(pred) - pred_pad_len] # remove prompts
+            label = label[:len(label) - label_pad_len]

             hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
             reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))

@@ -79,8 +80,9 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         with open(output_prediction_file, "w", encoding="utf-8") as writer:
             res: List[str] = []
             for pred, label in zip(predict_results.predictions, predict_results.label_ids):
-                pred = pred[len(label) - np.sum(label == IGNORE_INDEX) : len(pred) - np.sum(pred == IGNORE_INDEX)] # remove prompts
-                label = label[:len(label) - np.sum(label == IGNORE_INDEX)]
+                pred_pad_len, label_pad_len = np.sum(pred == IGNORE_INDEX), np.sum(label == IGNORE_INDEX)
+                pred = pred[len(label) - label_pad_len : len(pred) - pred_pad_len] # remove prompts
+                label = label[:len(label) - label_pad_len]
                 pred = self.tokenizer.decode(pred, skip_special_tokens=True)
                 label = self.tokenizer.decode(label, skip_special_tokens=True)
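Note on the data_collator.py hunks above: the collator left-pads by flipping each sequence, right-padding to a common length, flipping back, and then masking out the leading pad tokens. The following standalone sketch is not taken from the repository; pad_token_id = 0 and the toy sequences are assumed purely for illustration.

import torch

# Assumed placeholder values, not real token ids.
pad_token_id = 0
sequences = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

# Flip, right-pad to a common length, then flip back so the padding lands on the left.
input_ids = torch.nn.utils.rnn.pad_sequence(
    [seq.flip(0) for seq in sequences],
    batch_first=True,
    padding_value=pad_token_id
).flip(-1)
# tensor([[5, 6, 7],
#         [0, 8, 9]])

# Mask out every position before the first non-pad token in each row.
attention_mask = torch.ones(input_ids.size(), dtype=torch.long)
for i, seq in enumerate(input_ids):
    attention_mask[i, :(seq != pad_token_id).nonzero()[0].item()] = 0
attention_mask = attention_mask.bool()
# tensor([[ True,  True,  True],
#         [False,  True,  True]])

Left-padding keeps every row's last real token at the right edge of the batch, so generated tokens continue directly after the prompt for all rows.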
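Note on the seq2seq.py hunks above: they are a pure refactor; the IGNORE_INDEX pad counts are now computed once per pair instead of inline inside the slice bounds. A toy equivalence check follows (array values are arbitrary placeholders, not real token ids).

import numpy as np

IGNORE_INDEX = -100

pred = np.array([11, 12, 13, 14, 15, IGNORE_INDEX, IGNORE_INDEX])
label = np.array([21, 22, 23, IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX])

# Old expression: pad counts recomputed inline inside the slice bounds.
old_pred = pred[len(label) - np.sum(label == IGNORE_INDEX) : len(pred) - np.sum(pred == IGNORE_INDEX)]

# New expression: pad counts computed once and reused.
pred_pad_len, label_pad_len = np.sum(pred == IGNORE_INDEX), np.sum(label == IGNORE_INDEX)
new_pred = pred[len(label) - label_pad_len : len(pred) - pred_pad_len]

assert np.array_equal(old_pred, new_pred)  # both yield [14, 15]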