From 6404167ab76070cddae10552b80ea49d1d10eabd Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 9 Aug 2023 23:00:26 +0800
Subject: [PATCH] support val set in streaming mode

Former-commit-id: d86ea314a197fd821770d895e988c48d46679047
---
 src/llmtuner/api/app.py            |  2 +-
 src/llmtuner/dsets/preprocess.py   |  7 ++--
 src/llmtuner/dsets/utils.py        | 28 ++++++++++++----
 src/llmtuner/extras/template.py    |  2 +-
 src/llmtuner/hparams/data_args.py  |  4 +--
 src/llmtuner/tuner/core/parser.py  | 52 ++++++++++++++----------------
 src/llmtuner/tuner/ppo/workflow.py |  7 ++--
 src/llmtuner/tuner/pt/workflow.py  |  2 +-
 src/llmtuner/tuner/rm/workflow.py  |  2 +-
 src/llmtuner/tuner/sft/workflow.py |  2 +-
 10 files changed, 58 insertions(+), 50 deletions(-)

diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py
index 4fc5fc43..47b7661f 100644
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -49,8 +49,8 @@ def create_app(chat_model: ChatModel) -> FastAPI:
     async def create_chat_completion(request: ChatCompletionRequest):
         if request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=400, detail="Invalid request")
-        query = request.messages[-1].content
 
+        query = request.messages[-1].content
         prev_messages = request.messages[:-1]
         if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
             prefix = prev_messages.pop(0).content
diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py
index d2150dbc..534d77b5 100644
--- a/src/llmtuner/dsets/preprocess.py
+++ b/src/llmtuner/dsets/preprocess.py
@@ -104,9 +104,9 @@ def preprocess_dataset(
         if len(source_ids) > data_args.max_source_length:
             source_ids = source_ids[:data_args.max_source_length]
         if len(accept_ids) > data_args.max_target_length:
-            accept_ids = accept_ids[:data_args.max_target_length - 1]
+            accept_ids = accept_ids[:data_args.max_target_length]
         if len(reject_ids) > data_args.max_target_length:
-            reject_ids = reject_ids[:data_args.max_target_length - 1]
+            reject_ids = reject_ids[:data_args.max_target_length]
         accept_ids = source_ids + accept_ids
         reject_ids = source_ids + reject_ids
 
@@ -166,8 +166,5 @@ def preprocess_dataset(
             **kwargs
         )
 
-        if data_args.streaming:
-            dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
-
         print_function(next(iter(dataset)))
         return dataset
diff --git a/src/llmtuner/dsets/utils.py b/src/llmtuner/dsets/utils.py
index 31c48222..e1093a95 100644
--- a/src/llmtuner/dsets/utils.py
+++ b/src/llmtuner/dsets/utils.py
@@ -1,15 +1,29 @@
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, Union
 
 if TYPE_CHECKING:
-    from datasets import Dataset
+    from datasets import Dataset, IterableDataset
+    from transformers import TrainingArguments
+    from llmtuner.hparams import DataArguments
 
 
-def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]:
-    if do_train:
-        if dev_ratio > 1e-6: # Split the dataset
-            dataset = dataset.train_test_split(test_size=dev_ratio)
-            return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
+def split_dataset(
+    dataset: Union["Dataset", "IterableDataset"],
+    data_args: "DataArguments",
+    training_args: "TrainingArguments"
+) -> Dict[str, Union["Dataset", "IterableDataset"]]:
+    if training_args.do_train:
+        if data_args.val_size > 1e-6: # Split the dataset
+            if data_args.streaming:
+                val_set = dataset.take(int(data_args.val_size))
+                train_set = dataset.skip(int(data_args.val_size))
+                train_set = train_set.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+                return {"train_dataset": train_set, "eval_dataset": val_set}
"eval_dataset": val_set} + else: + dataset = dataset.train_test_split(test_size=data_args.val_size, seed=training_args.seed) + return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]} else: + if data_args.streaming: + dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) return {"train_dataset": dataset} else: # do_eval or do_predict return {"eval_dataset": dataset} diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 91595751..413333ac 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -97,7 +97,7 @@ class Template: sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) encoded_pairs = [] for turn_idx, (query, resp) in enumerate(history): - if turn_idx == 0: + if turn_idx == 0 and prefix: prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + eos_ids + sep_ids else: prefix_ids = sep_ids diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index 60945b60..de470ae2 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -90,9 +90,9 @@ class DataArguments: default=None, metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."} ) - dev_ratio: Optional[float] = field( + val_size: Optional[float] = field( default=0, - metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."} + metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."} ) def init_for_training(self): # support mixing multiple datasets diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py index d872afcc..692f9b13 100644 --- a/src/llmtuner/tuner/core/parser.py +++ b/src/llmtuner/tuner/core/parser.py @@ -67,33 +67,33 @@ def get_train_args( # Check arguments (do not check finetuning_args since it may be loaded from checkpoints) data_args.init_for_training() - assert general_args.stage == "sft" or (not training_args.predict_with_generate), \ - "`predict_with_generate` cannot be set as True at PT, RM and PPO stages." + if general_args.stage != "sft" and training_args.predict_with_generate: + raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.") - assert not (training_args.do_train and training_args.predict_with_generate), \ - "`predict_with_generate` cannot be set as True while training." + if training_args.do_train and training_args.predict_with_generate: + raise ValueError("`predict_with_generate` cannot be set as True while training.") - assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \ - "Please enable `predict_with_generate` to save model predictions." + if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: + raise ValueError("Please enable `predict_with_generate` to save model predictions.") - assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \ - "Quantization is only compatible with the LoRA method." + if training_args.max_steps == -1 and data_args.streaming: + raise ValueError("Please specify `max_steps` in streaming mode.") - assert not (training_args.max_steps == -1 and data_args.streaming), \ - "Please specify `max_steps` in streaming mode." 
+    if general_args.stage == "ppo" and data_args.streaming:
+        raise ValueError("Streaming mode does not support PPO training currently.")
 
-    assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
-        "Streaming mode does not support evaluation currently."
+    if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
+        raise ValueError("Streaming mode requires an integer `val_size`.")
 
-    assert not (general_args.stage == "ppo" and data_args.streaming), \
-        "Streaming mode does not suppport PPO training currently."
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
 
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
-            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-        else:
-            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
-                "Quantized model only accepts a single checkpoint."
+            if len(model_args.checkpoint_dir) != 1:
+                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Quantized model only accepts a single checkpoint.")
 
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@@ -113,10 +113,6 @@
         logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
         data_args.max_samples = None
 
-    if data_args.dev_ratio > 1e-6 and data_args.streaming:
-        logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
-        data_args.dev_ratio = 0
-
     training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
 
     if model_args.quantization_bit is not None:
@@ -145,14 +141,14 @@ def get_infer_args(
 ) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
     model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
 
-    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
-        "Quantization is only compatible with the LoRA method."
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
 
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
-            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-        else:
-            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
-                "Quantized model only accepts a single checkpoint."
+            if len(model_args.checkpoint_dir) != 1:
+                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Quantized model only accepts a single checkpoint.")
 
     return model_args, data_args, finetuning_args, generating_args
diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py
index 0ca8cbd4..aa372671 100644
--- a/src/llmtuner/tuner/ppo/workflow.py
+++ b/src/llmtuner/tuner/ppo/workflow.py
@@ -41,14 +41,15 @@ def run_ppo(
         max_grad_norm=training_args.max_grad_norm
     )
 
-    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
+    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
     total_train_batch_size = \
         training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
+    num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
     lr_scheduler = get_scheduler(
         training_args.lr_scheduler_type,
         optimizer=optimizer,
-        num_warmup_steps=training_args.warmup_steps,
-        num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
+        num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
+        num_training_steps=num_training_steps
     )
 
     # Initialize our Trainer
diff --git a/src/llmtuner/tuner/pt/workflow.py b/src/llmtuner/tuner/pt/workflow.py
index 2a9f8279..b4ea148b 100644
--- a/src/llmtuner/tuner/pt/workflow.py
+++ b/src/llmtuner/tuner/pt/workflow.py
@@ -38,7 +38,7 @@ def run_pt(
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+        **split_dataset(dataset, data_args, training_args)
     )
 
     # Training
diff --git a/src/llmtuner/tuner/rm/workflow.py b/src/llmtuner/tuner/rm/workflow.py
index 19527ce8..b19a13e6 100644
--- a/src/llmtuner/tuner/rm/workflow.py
+++ b/src/llmtuner/tuner/rm/workflow.py
@@ -39,7 +39,7 @@ def run_rm(
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=compute_accuracy,
-        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+        **split_dataset(dataset, data_args, training_args)
     )
 
     # Training
diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py
index 69d200f3..a5cd2cd3 100644
--- a/src/llmtuner/tuner/sft/workflow.py
+++ b/src/llmtuner/tuner/sft/workflow.py
@@ -46,7 +46,7 @@ def run_sft(
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
-        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+        **split_dataset(dataset, data_args, training_args)
     )
 
     # Keyword arguments for `model.generate`
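
Reviewer note (illustrative, not part of the patch): below is a minimal sketch of
the take/skip split that the new `split_dataset` performs in streaming mode, using
only the public `datasets` API. The dataset name, `val_size`, buffer size, and seed
are assumptions chosen for the example, not values from this patch.

    from datasets import load_dataset

    # Stream a dataset without materializing it ("imdb" is just an example).
    stream = load_dataset("imdb", split="train", streaming=True)

    val_size = 100  # in streaming mode this is an absolute example count, not a ratio

    val_set = stream.take(val_size)    # first `val_size` examples become the eval set
    train_set = stream.skip(val_size)  # the remainder of the stream is used for training
    train_set = train_set.shuffle(buffer_size=16384, seed=42)  # approximate buffered shuffle

    print(next(iter(val_set)))  # peek at one validation example

Because an IterableDataset has no known length, the validation split must be an
absolute count taken off the head of the stream, which is why the parser change
above rejects a fractional `val_size` when `streaming` is enabled.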