support val set in streaming mode

Former-commit-id: d86ea314a197fd821770d895e988c48d46679047
hiyouga 2023-08-09 23:00:26 +08:00
parent d01c1231ed
commit 6404167ab7
10 changed files with 58 additions and 50 deletions

View File

@@ -49,8 +49,8 @@ def create_app(chat_model: ChatModel) -> FastAPI:
     async def create_chat_completion(request: ChatCompletionRequest):
         if request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=400, detail="Invalid request")
-        query = request.messages[-1].content
+        query = request.messages[-1].content

         prev_messages = request.messages[:-1]
         if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
             prefix = prev_messages.pop(0).content

View File

@@ -104,9 +104,9 @@ def preprocess_dataset(
             if len(source_ids) > data_args.max_source_length:
                 source_ids = source_ids[:data_args.max_source_length]
             if len(accept_ids) > data_args.max_target_length:
-                accept_ids = accept_ids[:data_args.max_target_length - 1]
+                accept_ids = accept_ids[:data_args.max_target_length]
             if len(reject_ids) > data_args.max_target_length:
-                reject_ids = reject_ids[:data_args.max_target_length - 1]
+                reject_ids = reject_ids[:data_args.max_target_length]

             accept_ids = source_ids + accept_ids
             reject_ids = source_ids + reject_ids

@@ -166,8 +166,5 @@ def preprocess_dataset(
         **kwargs
     )

-    if data_args.streaming:
-        dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
-
     print_function(next(iter(dataset)))
     return dataset
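The shuffle removed here is not dropped; it moves into split_dataset (next file), where it can be seeded and applied around the train/val split. For context, shuffling a streaming dataset in the `datasets` library is approximate: shuffle() fills a buffer of buffer_size examples and samples from that reservoir rather than permuting the whole corpus. A minimal sketch, assuming any dataset loaded with streaming=True (the dataset name below is only an example):

    from datasets import load_dataset

    # Streaming yields an IterableDataset; shuffle() draws from a
    # fixed-size reservoir buffer instead of materializing the dataset.
    stream = load_dataset("allenai/c4", "en", split="train", streaming=True)
    shuffled = stream.shuffle(buffer_size=16384, seed=42)
    print(next(iter(shuffled)))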

View File

@@ -1,15 +1,29 @@
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, Union

 if TYPE_CHECKING:
-    from datasets import Dataset
+    from datasets import Dataset, IterableDataset
+    from transformers import TrainingArguments
+    from llmtuner.hparams import DataArguments


-def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]:
-    if do_train:
-        if dev_ratio > 1e-6: # Split the dataset
-            dataset = dataset.train_test_split(test_size=dev_ratio)
-            return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
+def split_dataset(
+    dataset: Union["Dataset", "IterableDataset"],
+    data_args: "DataArguments",
+    training_args: "TrainingArguments"
+) -> Dict[str, "Dataset"]:
+    if training_args.do_train:
+        if data_args.val_size > 1e-6: # Split the dataset
+            if data_args.streaming:
+                val_set = dataset.take(int(data_args.val_size))
+                train_set = dataset.skip(int(data_args.val_size))
+                dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+                return {"train_dataset": train_set, "eval_dataset": val_set}
+            else:
+                dataset = dataset.train_test_split(test_size=data_args.val_size, seed=training_args.seed)
+                return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
         else:
+            if data_args.streaming:
+                dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
             return {"train_dataset": dataset}
     else: # do_eval or do_predict
         return {"eval_dataset": dataset}

View File

@@ -97,7 +97,7 @@ class Template:
         sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
         encoded_pairs = []
         for turn_idx, (query, resp) in enumerate(history):
-            if turn_idx == 0:
+            if turn_idx == 0 and prefix:
                 prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + eos_ids + sep_ids
             else:
                 prefix_ids = sep_ids
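The added `and prefix` guard keeps an empty prefix from injecting stray eos_ids + sep_ids tokens before the first turn. A tiny hypothetical sketch of the fixed branching, with placeholder id lists standing in for real token ids:

    def first_turn_prefix_ids(prefix: str, prefix_ids: list, eos_ids: list, sep_ids: list) -> list:
        # Mirrors the hunk above: only a non-empty prefix contributes its
        # encoded ids (plus EOS and separator ids) on the first turn.
        if prefix:
            return prefix_ids + eos_ids + sep_ids
        return sep_ids

    print(first_turn_prefix_ids("", [1, 2], [3], [4]))               # [4]
    print(first_turn_prefix_ids("You are a bot.", [1, 2], [3], [4]))  # [1, 2, 3, 4]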

View File

@@ -90,9 +90,9 @@ class DataArguments:
         default=None,
         metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
     )
-    dev_ratio: Optional[float] = field(
+    val_size: Optional[float] = field(
         default=0,
-        metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
+        metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
     )

     def init_for_training(self): # support mixing multiple datasets
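The renamed val_size field selects different split paths downstream (see split_dataset above): a float in [0, 1) is a fraction for map-style datasets, while streaming mode expects an absolute example count. A hypothetical helper mirroring that branching:

    def interpret_val_size(val_size: float, streaming: bool) -> str:
        # Illustrative only: streaming needs an integer-valued count,
        # while map-style datasets accept a fraction in [0, 1).
        if streaming:
            return f"take()/skip() the first {int(val_size)} examples"
        return f"train_test_split(test_size={val_size})"

    print(interpret_val_size(0.1, streaming=False))
    print(interpret_val_size(1000, streaming=True))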

View File

@@ -67,33 +67,33 @@ def get_train_args(
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     data_args.init_for_training()

-    assert general_args.stage == "sft" or (not training_args.predict_with_generate), \
-        "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
+    if general_args.stage != "sft" and training_args.predict_with_generate:
+        raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")

-    assert not (training_args.do_train and training_args.predict_with_generate), \
-        "`predict_with_generate` cannot be set as True while training."
+    if training_args.do_train and training_args.predict_with_generate:
+        raise ValueError("`predict_with_generate` cannot be set as True while training.")

-    assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
-        "Please enable `predict_with_generate` to save model predictions."
+    if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
+        raise ValueError("Please enable `predict_with_generate` to save model predictions.")

-    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
-        "Quantization is only compatible with the LoRA method."
+    if training_args.max_steps == -1 and data_args.streaming:
+        raise ValueError("Please specify `max_steps` in streaming mode.")

-    assert not (training_args.max_steps == -1 and data_args.streaming), \
-        "Please specify `max_steps` in streaming mode."
+    if general_args.stage == "ppo" and data_args.streaming:
+        raise ValueError("Streaming mode does not suppport PPO training currently.")

-    assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
-        "Streaming mode does not support evaluation currently."
+    if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
+        raise ValueError("Streaming mode should have an integer val size.")

-    assert not (general_args.stage == "ppo" and data_args.streaming), \
-        "Streaming mode does not suppport PPO training currently."
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")

     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
-            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-        else:
-            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
-                "Quantized model only accepts a single checkpoint."
+            if len(model_args.checkpoint_dir) != 1:
+                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Quantized model only accepts a single checkpoint.")

     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")

@@ -113,10 +113,6 @@ def get_train_args(
         logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
         data_args.max_samples = None

-    if data_args.dev_ratio > 1e-6 and data_args.streaming:
-        logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
-        data_args.dev_ratio = 0
-
     training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning

     if model_args.quantization_bit is not None:

@@ -145,14 +141,14 @@ def get_infer_args(
 ) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
     model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)

-    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
-        "Quantization is only compatible with the LoRA method."
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")

     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
-            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-        else:
-            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
-                "Quantized model only accepts a single checkpoint."
+            if len(model_args.checkpoint_dir) != 1:
+                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Quantized model only accepts a single checkpoint.")

     return model_args, data_args, finetuning_args, generating_args
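One practical reason for converting these asserts into explicit raises: assert statements are compiled out when Python runs with the -O flag, so validation written as assert can silently vanish, whereas a ValueError is always raised and gives users a cleaner error than a bare AssertionError.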

View File

@@ -41,14 +41,15 @@ def run_ppo(
         max_grad_norm=training_args.max_grad_norm
     )

-    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
+    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
     total_train_batch_size = \
         training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
+    num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
     lr_scheduler = get_scheduler(
         training_args.lr_scheduler_type,
         optimizer=optimizer,
-        num_warmup_steps=training_args.warmup_steps,
-        num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
+        num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
+        num_training_steps=num_training_steps
     )

     # Initialize our Trainer
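For reference, TrainingArguments.get_warmup_steps() resolves warmup from either an explicit step count or warmup_ratio; a rough sketch of the equivalent logic (not the transformers source verbatim):

    import math

    def resolve_warmup(warmup_steps: int, warmup_ratio: float, num_training_steps: int) -> int:
        # Roughly what transformers' TrainingArguments.get_warmup_steps()
        # does: an explicit warmup_steps > 0 wins, otherwise apply the ratio.
        if warmup_steps > 0:
            return warmup_steps
        return math.ceil(num_training_steps * warmup_ratio)

    print(resolve_warmup(0, 0.1, num_training_steps=3000))    # 300
    print(resolve_warmup(500, 0.1, num_training_steps=3000))  # 500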

View File

@@ -38,7 +38,7 @@ def run_pt(
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+        **split_dataset(dataset, data_args, training_args)
     )

     # Training

View File

@@ -39,7 +39,7 @@ def run_rm(
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=compute_accuracy,
-        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+        **split_dataset(dataset, data_args, training_args)
     )

     # Training

View File

@@ -46,7 +46,7 @@ def run_sft(
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
-        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+        **split_dataset(dataset, data_args, training_args)
     )

     # Keyword arguments for `model.generate`