Commit 8717e98200 (parent cf149bf43c)
Former-commit-id: 9bec3c98a22c91b1c28fda757db51eb780291641
Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-09-18 19:12:49 +08:00
@@ -486,7 +486,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 #### Use Huggingface Accelerate
 
 ```bash
-accelerate launch --config_file config.yaml src/train_bash.py # arguments (same as above)
+accelerate launch --config_file config.yaml src/train_bash.py \
+    --ddp_timeout 180000000 \
+    ... # arguments (same as above)
 ```
 
 <details><summary>Example config.yaml for LoRA training</summary>
@@ -519,8 +521,8 @@ use_cpu: false
 
 ```bash
 deepspeed --num_gpus 8 src/train_bash.py \
     --deepspeed ds_config.json \
     --ddp_timeout 180000000 \
     ... # arguments (same as above)
 ```
 
@@ -485,7 +485,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 #### Use Huggingface Accelerate
 
 ```bash
-accelerate launch --config_file config.yaml src/train_bash.py # arguments (same as above)
+accelerate launch --config_file config.yaml src/train_bash.py \
+    --ddp_timeout 180000000 \
+    ... # arguments (same as above)
 ```
 
 <details><summary>Example config.yaml for LoRA training with Accelerate</summary>
@@ -519,9 +521,8 @@ use_cpu: false
 ```bash
 deepspeed --num_gpus 8 src/train_bash.py \
     --deepspeed ds_config.json \
     --ddp_timeout 180000000 \
     ... # arguments (same as above)
-
 ```
 
 <details><summary>Example ds_config.json for full-parameter training with DeepSpeed ZeRO-2</summary>
@@ -8,11 +8,14 @@ from trl import DPOTrainer
 from trl.trainer.utils import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
+from ..utils import create_custom_optimzer
 
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
 
+    from ...hparams import FinetuningArguments
+
 
 class CustomDPOTrainer(DPOTrainer):
     def __init__(
@@ -21,6 +24,7 @@ class CustomDPOTrainer(DPOTrainer):
         loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
         ftx_gamma: float,
         model: Union["PreTrainedModel", torch.nn.Module],
+        finetuning_args: "FinetuningArguments",
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
         disable_dropout: bool = True,
         **kwargs,
@@ -30,6 +34,7 @@ class CustomDPOTrainer(DPOTrainer):
             if ref_model is not None:
                 disable_dropout_in_model(ref_model)
 
+        self.finetuning_args = finetuning_args
         self.reference_free = False
         self.use_dpo_data_collator = True  # hack to avoid warning
         self.generate_during_eval = False  # disable at evaluation
@@ -61,6 +66,13 @@ class CustomDPOTrainer(DPOTrainer):
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
+    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
+        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+        if self.optimizer is None:
+            self.create_optimizer()
+
+        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
+
     def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
         r"""
         Computes supervised cross-entropy loss of given labels under the given logits.
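
The trainer changes in this commit all follow the same pattern: `transformers.Trainer.train()` calls `create_optimizer_and_scheduler()` once the total number of training steps is known, so overriding that hook lets each trainer try the custom GaLore / LoRA+ optimizer first and fall back to the stock one. A minimal sketch of the pattern, with `build_custom_optimizer` as a hypothetical stand-in for the repo's `create_custom_optimzer`:

```python
from typing import Optional

import torch
from transformers import Trainer


def build_custom_optimizer(model, args, finetuning_args, max_steps: int) -> Optional[torch.optim.Optimizer]:
    # Hypothetical stand-in for llmtuner's `create_custom_optimzer`:
    # it would return a GaLore / LoRA+ optimizer when requested, otherwise None.
    return None


class CustomOptimizerTrainer(Trainer):
    """Illustrative subclass (not from the repo) showing the override pattern."""

    def __init__(self, finetuning_args, **kwargs) -> None:
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args

    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
        # Try the custom optimizer first; fall back to Trainer's default optimizer.
        self.optimizer = build_custom_optimizer(self.model, self.args, self.finetuning_args, num_training_steps)
        if self.optimizer is None:
            self.create_optimizer()
        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
```
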
@@ -7,7 +7,7 @@ from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
 from ...model import load_model, load_tokenizer
-from ..utils import create_custom_optimzer, create_modelcard_and_push, create_ref_model
+from ..utils import create_modelcard_and_push, create_ref_model
 from .collator import DPODataCollatorWithPadding
 from .trainer import CustomDPOTrainer
 
@@ -44,18 +44,17 @@ def run_dpo(
     training_args.remove_unused_columns = False  # important for pairwise dataset
 
     # Initialize our Trainer
-    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
     trainer = CustomDPOTrainer(
         beta=finetuning_args.dpo_beta,
         loss_type=finetuning_args.dpo_loss,
         ftx_gamma=finetuning_args.dpo_ftx,
+        finetuning_args=finetuning_args,
         model=model,
         ref_model=ref_model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        optimizers=(optimizer, None),
         **split_dataset(dataset, data_args, training_args),
     )
 
@@ -64,16 +64,16 @@ def run_ppo(
     )
 
     # Create optimizer and scheduler
-    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
-    if optimizer is None:
-        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
-
     if training_args.max_steps > 0:
         num_training_steps = training_args.max_steps
     else:
         total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
         num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
 
+    optimizer = create_custom_optimzer(model, training_args, finetuning_args, num_training_steps)
+    if optimizer is None:
+        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
+
     lr_scheduler = get_scheduler(
         training_args.lr_scheduler_type,
         optimizer=optimizer,
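
The reordering in this hunk exists because the custom optimizer now needs the step budget up front. A toy computation of `num_training_steps`, mirroring the variables above but with made-up numbers:

```python
import math

# Illustrative values only.
max_steps = 0                  # training_args.max_steps; 0 means "derive from epochs"
num_train_epochs = 3.0         # training_args.num_train_epochs
backward_batch_size = 8        # per-device batch size * gradient accumulation steps
ppo_buffer_size = 1            # finetuning_args.ppo_buffer_size
world_size = 2                 # number of distributed processes
dataset_len = 10_000

if max_steps > 0:
    num_training_steps = max_steps
else:
    total_train_batch_size = backward_batch_size * ppo_buffer_size * world_size  # 16
    num_training_steps = num_train_epochs * math.ceil(dataset_len / total_train_batch_size)  # 3.0 * 625

print(num_training_steps)  # 1875.0 -- this value is then handed to create_custom_optimzer as max_steps
```
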

src/llmtuner/train/pt/trainer.py (new file, 30 lines added)
@@ -0,0 +1,30 @@
+from typing import TYPE_CHECKING
+
+from transformers import Trainer
+
+from ...extras.logging import get_logger
+from ..utils import create_custom_optimzer
+
+
+if TYPE_CHECKING:
+    from ...hparams import FinetuningArguments
+
+
+logger = get_logger(__name__)
+
+
+class CustomTrainer(Trainer):
+    r"""
+    Inherits Trainer for custom optimizer.
+    """
+
+    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.finetuning_args = finetuning_args
+
+    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
+        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+        if self.optimizer is None:
+            self.create_optimizer()
+
+        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
@@ -3,12 +3,13 @@
 import math
 from typing import TYPE_CHECKING, List, Optional
 
-from transformers import DataCollatorForLanguageModeling, Trainer
+from transformers import DataCollatorForLanguageModeling
 
 from ...data import get_dataset, split_dataset
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ..utils import create_custom_optimzer, create_modelcard_and_push
+from ..utils import create_modelcard_and_push
+from .trainer import CustomTrainer
 
 
 if TYPE_CHECKING:
@@ -30,14 +31,13 @@ def run_pt(
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
     # Initialize our Trainer
-    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
-    trainer = Trainer(
+    trainer = CustomTrainer(
         model=model,
         args=training_args,
+        finetuning_args=finetuning_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        optimizers=(optimizer, None),
         **split_dataset(dataset, data_args, training_args),
     )
 
@@ -6,25 +6,36 @@ import torch
 from transformers import Trainer
 
 from ...extras.logging import get_logger
+from ..utils import create_custom_optimzer
 
 
 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
     from transformers.trainer import PredictionOutput
 
+    from ...hparams import FinetuningArguments
+
 
 logger = get_logger(__name__)
 
 
 class PairwiseTrainer(Trainer):
     r"""
-    Inherits PeftTrainer to compute pairwise loss.
+    Inherits Trainer to compute pairwise loss.
     """
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.finetuning_args = finetuning_args
         self.can_return_loss = True  # override property to return eval_loss
 
+    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
+        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+        if self.optimizer is None:
+            self.create_optimizer()
+
+        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
+
     def compute_loss(
         self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
@@ -7,7 +7,7 @@ from ...extras.callbacks import FixValueHeadModelCallback
 from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ..utils import create_custom_optimzer, create_modelcard_and_push
+from ..utils import create_modelcard_and_push
 from .collator import PairwiseDataCollatorWithPadding
 from .metric import compute_accuracy
 from .trainer import PairwiseTrainer
@@ -35,14 +35,13 @@ def run_rm(
     training_args.remove_unused_columns = False  # important for pairwise dataset
 
     # Initialize our Trainer
-    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
     trainer = PairwiseTrainer(
         model=model,
         args=training_args,
+        finetuning_args=finetuning_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks + [FixValueHeadModelCallback()],
-        optimizers=(optimizer, None),
         compute_metrics=compute_accuracy,
         **split_dataset(dataset, data_args, training_args),
     )
@@ -4,28 +4,41 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
-import torch.nn as nn
 from transformers import Seq2SeqTrainer
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
+from ..utils import create_custom_optimzer
 
 
 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput
 
+    from ...hparams import FinetuningArguments
+
 
 logger = get_logger(__name__)
 
 
 class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     r"""
-    Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
+    Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
     """
 
+    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.finetuning_args = finetuning_args
+
+    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
+        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+        if self.optimizer is None:
+            self.create_optimizer()
+
+        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
+
     def prediction_step(
         self,
-        model: nn.Module,
+        model: "torch.nn.Module",
         inputs: Dict[str, Union[torch.Tensor, Any]],
         prediction_loss_only: bool,
         ignore_keys: Optional[List[str]] = None,
@@ -9,10 +9,9 @@ from ...extras.constants import IGNORE_INDEX
 from ...extras.misc import get_logits_processor
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ...train.sft.metric import ComputeMetrics
-from ...train.sft.trainer import CustomSeq2SeqTrainer
-from ...train.utils import create_modelcard_and_push
-from ..utils import create_custom_optimzer
+from ..utils import create_modelcard_and_push
+from .metric import ComputeMetrics
+from .trainer import CustomSeq2SeqTrainer
 
 
 if TYPE_CHECKING:
@@ -50,14 +49,13 @@ def run_sft(
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
 
     # Initialize our Trainer
-    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
     trainer = CustomSeq2SeqTrainer(
         model=model,
         args=training_args,
+        finetuning_args=finetuning_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        optimizers=(optimizer, None),
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
         **split_dataset(dataset, data_args, training_args),
     )
@@ -1,4 +1,3 @@
-import math
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
 
 import torch
@@ -19,7 +18,6 @@ if is_galore_available():
 
 
 if TYPE_CHECKING:
-    from datasets import Dataset, IterableDataset
     from transformers import Seq2SeqTrainingArguments
     from transformers.modeling_utils import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
@@ -156,9 +154,9 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
 
 def _create_galore_optimizer(
     model: "PreTrainedModel",
-    dataset: Union["Dataset", "IterableDataset"],
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
+    max_steps: int,
 ) -> "torch.optim.Optimizer":
     require_version("galore_torch", "To fix: pip install galore-torch")
 
@@ -209,12 +207,6 @@ def _create_galore_optimizer(
         if training_args.gradient_accumulation_steps != 1:
             raise ValueError("Per-layer GaLore does not support gradient accumulation.")
 
-        if training_args.max_steps > 0:
-            num_training_steps = training_args.max_steps
-        else:
-            total_train_batch_size = training_args.per_device_train_batch_size * training_args.world_size
-            num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
-
         optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
         for param in nodecay_params:
             param_groups = [dict(params=[param])]
@@ -231,8 +223,8 @@ def _create_galore_optimizer(
             scheduler_dict[param] = get_scheduler(
                 training_args.lr_scheduler_type,
                 optimizer=optimizer_dict[param],
-                num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2,
-                num_training_steps=num_training_steps * 2,
+                num_warmup_steps=training_args.get_warmup_steps(max_steps) * 2,
+                num_training_steps=max_steps * 2,
             )
 
         def optimizer_hook(param: "torch.Tensor"):
@@ -259,7 +251,6 @@ def _create_galore_optimizer(
 
 def _create_loraplus_optimizer(
     model: "PreTrainedModel",
-    dataset: Union["Dataset", "IterableDataset"],
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
@@ -302,12 +293,12 @@ def _create_loraplus_optimizer(
 
 def create_custom_optimzer(
     model: "PreTrainedModel",
-    dataset: Union["Dataset", "IterableDataset"],
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
+    max_steps: int,
 ) -> Optional["torch.optim.Optimizer"]:
     if finetuning_args.use_galore:
-        return _create_galore_optimizer(model, dataset, training_args, finetuning_args)
+        return _create_galore_optimizer(model, training_args, finetuning_args, max_steps)
 
     if finetuning_args.loraplus_lr_ratio is not None:
-        return _create_loraplus_optimizer(model, dataset, training_args, finetuning_args)
+        return _create_loraplus_optimizer(model, training_args, finetuning_args)
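
Summing up the helper's new shape: the dataset argument is gone, the precomputed step budget is passed as `max_steps`, and `None` is still returned when neither GaLore nor LoRA+ is enabled so callers can fall back to the default optimizer. A hedged, self-contained sketch of that dispatch (`DummyFinetuningArgs` and the return strings are illustrative only, not the repo's API):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DummyFinetuningArgs:
    # Illustrative stand-in for the two FinetuningArguments fields the dispatch checks.
    use_galore: bool = False
    loraplus_lr_ratio: Optional[float] = None


def create_custom_optimzer_sketch(model, training_args, finetuning_args, max_steps: int):
    if finetuning_args.use_galore:
        return f"galore optimizer sized for {max_steps} steps"   # would call _create_galore_optimizer(...)
    if finetuning_args.loraplus_lr_ratio is not None:
        return "loraplus optimizer"                              # would call _create_loraplus_optimizer(...)
    return None  # caller falls back to the default optimizer


print(create_custom_optimzer_sketch(None, None, DummyFinetuningArgs(use_galore=True), max_steps=1875))
```
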