[misc] upgrade format to py39 (#7256)

This commit is contained in:
hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions

View File

@@ -20,7 +20,7 @@ import os
import sys
import warnings
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Optional
import torch
from accelerate.utils import DistributedDataParallelKwargs
@@ -62,9 +62,7 @@ logger = logging.get_logger(__name__)
class CustomPPOTrainer(PPOTrainer, Trainer):
r"""
Inherits PPOTrainer.
"""
r"""Inherit PPOTrainer."""
def __init__(
self,
@@ -72,7 +70,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]],
callbacks: Optional[list["TrainerCallback"]],
model: "AutoModelForCausalLMWithValueHead",
reward_model: Optional["AutoModelForCausalLMWithValueHead"],
ref_model: Optional["AutoModelForCausalLMWithValueHead"],
@@ -187,9 +185,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.add_callback(BAdamCallback)
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
r"""Implement training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer."""
if resume_from_checkpoint is not None:
raise ValueError("`resume_from_checkpoint` will be supported in the future version.")
@@ -221,9 +217,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
logger.info_rank0(f" Num Epochs = {num_train_epochs:,}")
logger.info_rank0(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
logger.info_rank0(
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
total_train_batch_size
)
f" Total train batch size (w. parallel, buffer, distributed & accumulation) = {total_train_batch_size:,}"
)
logger.info_rank0(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
logger.info_rank0(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
@@ -339,21 +333,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
return lr_scheduler
@torch.no_grad()
def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:
r"""
Generates model's responses given queries.
"""
def get_inputs(self, batch: dict[str, "torch.Tensor"]) -> tuple[list["torch.Tensor"], list["torch.Tensor"]]:
r"""Generate model's responses given queries."""
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
for k, v in batch.items():
batch[k] = v[:, start_index:]
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
if self.model_args.upcast_layernorm:
layernorm_params = dump_layernorm(unwrapped_model)
generate_output: "torch.Tensor" = unwrapped_model.generate(
generate_output: torch.Tensor = unwrapped_model.generate(
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
)
if self.model_args.upcast_layernorm:
@@ -381,11 +373,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
@torch.no_grad()
def get_rewards(
self,
queries: List["torch.Tensor"],
responses: List["torch.Tensor"],
) -> List["torch.Tensor"]:
r"""
Computes scores using given reward model.
queries: list["torch.Tensor"],
responses: list["torch.Tensor"],
) -> list["torch.Tensor"]:
r"""Compute scores using given reward model.
Both inputs and outputs are put on CPU.
"""
@@ -394,8 +385,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False)
return get_rewards_from_server(self.reward_model, messages)
batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
batch: dict[str, torch.Tensor] = self.prepare_model_inputs(queries, responses)
unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="reward")
@@ -404,7 +395,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
reward_model = self.reward_model
with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16
values: "torch.Tensor" = reward_model(**batch, return_dict=True, use_cache=False)[-1]
values: torch.Tensor = reward_model(**batch, return_dict=True, use_cache=False)[-1]
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default")
@@ -419,12 +410,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
model: "AutoModelForCausalLMWithValueHead",
queries: "torch.Tensor",
responses: "torch.Tensor",
model_inputs: Dict[str, Any],
model_inputs: dict[str, Any],
return_logits: bool = False,
response_masks: Optional["torch.Tensor"] = None,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
r"""
Calculates model outputs in multiple batches.
) -> tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
r"""Calculate model outputs in multiple batches.
Subclass and override to inject custom behavior.
"""
@@ -483,8 +473,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
@override
def save_model(self, output_dir: Optional[str] = None) -> None:
r"""
Saves model checkpoint.
r"""Save model checkpoint.
Subclass and override to inject custom behavior.
"""
@@ -508,5 +497,5 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.model.save_checkpoint(output_dir)
elif self.args.should_save:
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
self._save(output_dir, state_dict=unwrapped_model.state_dict())