Former-commit-id: 9bec3c98a22c91b1c28fda757db51eb780291641
hiyouga 2024-03-20 17:59:45 +08:00
parent cf149bf43c
commit 8717e98200
12 changed files with 104 additions and 48 deletions

View File

@@ -486,7 +486,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 #### Use Huggingface Accelerate
 ```bash
-accelerate launch --config_file config.yaml src/train_bash.py # arguments (same as above)
+accelerate launch --config_file config.yaml src/train_bash.py \
+    --ddp_timeout 180000000 \
+    ... # arguments (same as above)
 ```
 <details><summary>Example config.yaml for LoRA training</summary>
@@ -519,8 +521,8 @@ use_cpu: false
 ```bash
 deepspeed --num_gpus 8 src/train_bash.py \
     --deepspeed ds_config.json \
     --ddp_timeout 180000000 \
     ... # arguments (same as above)
 ```

View File

@@ -485,7 +485,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 #### Use Huggingface Accelerate
 ```bash
-accelerate launch --config_file config.yaml src/train_bash.py # arguments (same as above)
+accelerate launch --config_file config.yaml src/train_bash.py \
+    --ddp_timeout 180000000 \
+    ... # arguments (same as above)
 ```
 <details><summary>Example config.yaml for LoRA training with Accelerate</summary>
@@ -519,9 +521,8 @@ use_cpu: false
 ```bash
 deepspeed --num_gpus 8 src/train_bash.py \
     --deepspeed ds_config.json \
     --ddp_timeout 180000000 \
     ... # arguments (same as above)
 ```
 <details><summary>Example ds_config.json for full-parameter training with DeepSpeed ZeRO-2</summary>

View File

@@ -8,11 +8,14 @@ from trl import DPOTrainer
 from trl.trainer.utils import disable_dropout_in_model
 from ...extras.constants import IGNORE_INDEX
+from ..utils import create_custom_optimzer
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
+    from ...hparams import FinetuningArguments
 class CustomDPOTrainer(DPOTrainer):
     def __init__(
@@ -21,6 +24,7 @@ class CustomDPOTrainer(DPOTrainer):
         loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
         ftx_gamma: float,
         model: Union["PreTrainedModel", torch.nn.Module],
+        finetuning_args: "FinetuningArguments",
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
         disable_dropout: bool = True,
         **kwargs,
@@ -30,6 +34,7 @@ class CustomDPOTrainer(DPOTrainer):
             if ref_model is not None:
                 disable_dropout_in_model(ref_model)
+        self.finetuning_args = finetuning_args
         self.reference_free = False
         self.use_dpo_data_collator = True  # hack to avoid warning
         self.generate_during_eval = False  # disable at evaluation
@@ -61,6 +66,13 @@ class CustomDPOTrainer(DPOTrainer):
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
+        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+        if self.optimizer is None:
+            self.create_optimizer()
+
+        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
     def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
         r"""
         Computes supervised cross-entropy loss of given labels under the given logits.
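The `create_optimizer_and_scheduler` override above is the pattern this commit repeats in every trainer below. As a reading aid, here is a minimal, self-contained sketch of that fallback flow; `ToyTrainer` and `toy_create_custom_optimizer` are illustrative stand-ins, not repo code, and only the control flow mirrors the diff.

```python
# Minimal sketch of the optimizer-fallback pattern added in this commit.
# The real trainers subclass transformers.Trainer and call create_custom_optimzer;
# everything named Toy* here is hypothetical.
from typing import Optional

import torch


def toy_create_custom_optimizer(
    model: torch.nn.Module, use_custom: bool, num_training_steps: int
) -> Optional[torch.optim.Optimizer]:
    # Returns None when no custom optimizer (GaLore / LoRA+) is requested,
    # mirroring the contract the diff relies on.
    if use_custom:
        return torch.optim.SGD(model.parameters(), lr=0.1)
    return None


class ToyTrainer:
    def __init__(self, model: torch.nn.Module, use_custom: bool) -> None:
        self.model, self.use_custom = model, use_custom
        self.optimizer: Optional[torch.optim.Optimizer] = None

    def create_optimizer(self) -> None:
        # Default path: transformers.Trainer would build its stock optimizer here.
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=5e-5)

    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
        self.optimizer = toy_create_custom_optimizer(self.model, self.use_custom, num_training_steps)
        if self.optimizer is None:  # fall back to the stock optimizer
            self.create_optimizer()
        # scheduler creation elided in this sketch


trainer = ToyTrainer(torch.nn.Linear(4, 2), use_custom=False)
trainer.create_optimizer_and_scheduler(num_training_steps=100)
print(type(trainer.optimizer).__name__)  # AdamW: no custom optimizer was requested
```

The key contract is that the custom-optimizer factory returns `None` when neither GaLore nor LoRA+ is configured, so the stock `transformers` optimizer path still runs.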

View File

@@ -7,7 +7,7 @@ from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
 from ...model import load_model, load_tokenizer
-from ..utils import create_custom_optimzer, create_modelcard_and_push, create_ref_model
+from ..utils import create_modelcard_and_push, create_ref_model
 from .collator import DPODataCollatorWithPadding
 from .trainer import CustomDPOTrainer
@@ -44,18 +44,17 @@ def run_dpo(
     training_args.remove_unused_columns = False  # important for pairwise dataset
     # Initialize our Trainer
-    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
     trainer = CustomDPOTrainer(
         beta=finetuning_args.dpo_beta,
         loss_type=finetuning_args.dpo_loss,
         ftx_gamma=finetuning_args.dpo_ftx,
+        finetuning_args=finetuning_args,
         model=model,
         ref_model=ref_model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        optimizers=(optimizer, None),
         **split_dataset(dataset, data_args, training_args),
     )

View File

@@ -64,16 +64,16 @@ def run_ppo(
     )
     # Create optimizer and scheduler
-    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
-    if optimizer is None:
-        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
     if training_args.max_steps > 0:
         num_training_steps = training_args.max_steps
     else:
         total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
         num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
+    optimizer = create_custom_optimzer(model, training_args, finetuning_args, num_training_steps)
+    if optimizer is None:
+        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
     lr_scheduler = get_scheduler(
         training_args.lr_scheduler_type,
         optimizer=optimizer,
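For a concrete sense of the `num_training_steps` value that now feeds `create_custom_optimzer` in this hunk, here is a worked example of the step-count computation above with hypothetical sizes (none of these numbers come from the commit):

```python
import math

# Hypothetical values, for illustration only.
dataset_len = 10_000
backward_batch_size = 8        # per_device_train_batch_size * gradient_accumulation_steps
ppo_buffer_size = 1
world_size = 4
num_train_epochs = 3.0
max_steps = 0                  # <= 0 means "derive the step count from the dataset"

if max_steps > 0:
    num_training_steps = max_steps
else:
    total_train_batch_size = backward_batch_size * ppo_buffer_size * world_size  # 32
    num_training_steps = num_train_epochs * math.ceil(dataset_len / total_train_batch_size)

print(num_training_steps)  # 939.0 -> 3 epochs * ceil(10000 / 32) = 3 * 313
```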

View File

@@ -0,0 +1,30 @@
+from typing import TYPE_CHECKING
+
+from transformers import Trainer
+
+from ...extras.logging import get_logger
+from ..utils import create_custom_optimzer
+
+
+if TYPE_CHECKING:
+    from ...hparams import FinetuningArguments
+
+
+logger = get_logger(__name__)
+
+
+class CustomTrainer(Trainer):
+    r"""
+    Inherits Trainer for custom optimizer.
+    """
+
+    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.finetuning_args = finetuning_args
+
+    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
+        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+        if self.optimizer is None:
+            self.create_optimizer()
+
+        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)

View File

@@ -3,12 +3,13 @@
 import math
 from typing import TYPE_CHECKING, List, Optional
-from transformers import DataCollatorForLanguageModeling, Trainer
+from transformers import DataCollatorForLanguageModeling
 from ...data import get_dataset, split_dataset
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ..utils import create_custom_optimzer, create_modelcard_and_push
+from ..utils import create_modelcard_and_push
+from .trainer import CustomTrainer
 if TYPE_CHECKING:
@@ -30,14 +31,13 @@ def run_pt(
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
     # Initialize our Trainer
-    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
-    trainer = Trainer(
+    trainer = CustomTrainer(
         model=model,
         args=training_args,
+        finetuning_args=finetuning_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        optimizers=(optimizer, None),
         **split_dataset(dataset, data_args, training_args),
     )

View File

@@ -6,25 +6,36 @@ import torch
 from transformers import Trainer
 from ...extras.logging import get_logger
+from ..utils import create_custom_optimzer
 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
     from transformers.trainer import PredictionOutput
+    from ...hparams import FinetuningArguments
 logger = get_logger(__name__)
 class PairwiseTrainer(Trainer):
     r"""
-    Inherits PeftTrainer to compute pairwise loss.
+    Inherits Trainer to compute pairwise loss.
     """
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.finetuning_args = finetuning_args
         self.can_return_loss = True  # override property to return eval_loss
+
+    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
+        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+        if self.optimizer is None:
+            self.create_optimizer()
+
+        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
     def compute_loss(
         self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:

View File

@@ -7,7 +7,7 @@ from ...extras.callbacks import FixValueHeadModelCallback
 from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ..utils import create_custom_optimzer, create_modelcard_and_push
+from ..utils import create_modelcard_and_push
 from .collator import PairwiseDataCollatorWithPadding
 from .metric import compute_accuracy
 from .trainer import PairwiseTrainer
@@ -35,14 +35,13 @@ def run_rm(
     training_args.remove_unused_columns = False  # important for pairwise dataset
     # Initialize our Trainer
-    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
     trainer = PairwiseTrainer(
         model=model,
         args=training_args,
+        finetuning_args=finetuning_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks + [FixValueHeadModelCallback()],
-        optimizers=(optimizer, None),
         compute_metrics=compute_accuracy,
         **split_dataset(dataset, data_args, training_args),
     )

View File

@@ -4,28 +4,41 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
-import torch.nn as nn
 from transformers import Seq2SeqTrainer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
+from ..utils import create_custom_optimzer
 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput
+    from ...hparams import FinetuningArguments
 logger = get_logger(__name__)
 class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     r"""
-    Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
+    Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
     """
+
+    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.finetuning_args = finetuning_args
+
+    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
+        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+        if self.optimizer is None:
+            self.create_optimizer()
+
+        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
     def prediction_step(
         self,
-        model: nn.Module,
+        model: "torch.nn.Module",
         inputs: Dict[str, Union[torch.Tensor, Any]],
         prediction_loss_only: bool,
         ignore_keys: Optional[List[str]] = None,

View File

@@ -9,10 +9,9 @@ from ...extras.constants import IGNORE_INDEX
 from ...extras.misc import get_logits_processor
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ...train.sft.metric import ComputeMetrics
-from ...train.sft.trainer import CustomSeq2SeqTrainer
-from ...train.utils import create_modelcard_and_push
-from ..utils import create_custom_optimzer
+from ..utils import create_modelcard_and_push
+from .metric import ComputeMetrics
+from .trainer import CustomSeq2SeqTrainer
 if TYPE_CHECKING:
@@ -50,14 +49,13 @@ def run_sft(
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
     # Initialize our Trainer
-    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
     trainer = CustomSeq2SeqTrainer(
         model=model,
         args=training_args,
+        finetuning_args=finetuning_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        optimizers=(optimizer, None),
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
         **split_dataset(dataset, data_args, training_args),
     )

View File

@@ -1,4 +1,3 @@
-import math
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
 import torch
@@ -19,7 +18,6 @@ if is_galore_available():
 if TYPE_CHECKING:
-    from datasets import Dataset, IterableDataset
     from transformers import Seq2SeqTrainingArguments
     from transformers.modeling_utils import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
@@ -156,9 +154,9 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
 def _create_galore_optimizer(
     model: "PreTrainedModel",
-    dataset: Union["Dataset", "IterableDataset"],
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
+    max_steps: int,
 ) -> "torch.optim.Optimizer":
     require_version("galore_torch", "To fix: pip install galore-torch")
@@ -209,12 +207,6 @@ def _create_galore_optimizer(
         if training_args.gradient_accumulation_steps != 1:
             raise ValueError("Per-layer GaLore does not support gradient accumulation.")
-        if training_args.max_steps > 0:
-            num_training_steps = training_args.max_steps
-        else:
-            total_train_batch_size = training_args.per_device_train_batch_size * training_args.world_size
-            num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
         optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
         for param in nodecay_params:
             param_groups = [dict(params=[param])]
@@ -231,8 +223,8 @@ def _create_galore_optimizer(
             scheduler_dict[param] = get_scheduler(
                 training_args.lr_scheduler_type,
                 optimizer=optimizer_dict[param],
-                num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2,
-                num_training_steps=num_training_steps * 2,
+                num_warmup_steps=training_args.get_warmup_steps(max_steps) * 2,
+                num_training_steps=max_steps * 2,
             )
         def optimizer_hook(param: "torch.Tensor"):
@@ -259,7 +251,6 @@ def _create_galore_optimizer(
 def _create_loraplus_optimizer(
     model: "PreTrainedModel",
-    dataset: Union["Dataset", "IterableDataset"],
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
@@ -302,12 +293,12 @@ def _create_loraplus_optimizer(
 def create_custom_optimzer(
     model: "PreTrainedModel",
-    dataset: Union["Dataset", "IterableDataset"],
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
+    max_steps: int,
 ) -> Optional["torch.optim.Optimizer"]:
     if finetuning_args.use_galore:
-        return _create_galore_optimizer(model, dataset, training_args, finetuning_args)
+        return _create_galore_optimizer(model, training_args, finetuning_args, max_steps)
     if finetuning_args.loraplus_lr_ratio is not None:
-        return _create_loraplus_optimizer(model, dataset, training_args, finetuning_args)
+        return _create_loraplus_optimizer(model, training_args, finetuning_args)
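Taken together, the interface change is that `create_custom_optimzer` no longer receives the dataset (it previously derived the step count from `len(dataset)` itself); callers now pass the step count directly, either from `Trainer`'s `num_training_steps` or, in the PPO workflow, from the manual computation shown earlier. Below is a hedged sketch of the resulting dispatch, with stub builders and a stub `FinetuningArgs` standing in for the real GaLore/LoRA+ implementations and hyperparameter class:

```python
# Sketch of the post-change dispatch; FinetuningArgs and the _build_* stubs are
# illustrative stand-ins, not the repo's actual implementations.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FinetuningArgs:  # stand-in for ...hparams.FinetuningArguments
    use_galore: bool = False
    loraplus_lr_ratio: Optional[float] = None


def _build_galore_optimizer(model, training_args, finetuning_args, max_steps):
    # The real code builds galore_torch optimizers and, in per-layer mode,
    # one scheduler per parameter sized from `max_steps`.
    return torch.optim.AdamW(model.parameters(), lr=1e-4)


def _build_loraplus_optimizer(model, training_args, finetuning_args):
    return torch.optim.AdamW(model.parameters(), lr=1e-4)


def create_custom_optimizer_sketch(
    model: torch.nn.Module,
    training_args: object,
    finetuning_args: FinetuningArgs,
    max_steps: int,  # new required argument; the dataset is no longer needed
) -> Optional[torch.optim.Optimizer]:
    if finetuning_args.use_galore:
        return _build_galore_optimizer(model, training_args, finetuning_args, max_steps)
    if finetuning_args.loraplus_lr_ratio is not None:
        return _build_loraplus_optimizer(model, training_args, finetuning_args)
    return None  # tell the caller to fall back to the default optimizer


print(create_custom_optimizer_sketch(torch.nn.Linear(2, 2), None, FinetuningArgs(), max_steps=100))  # None
```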