modify code structure

Author: hiyouga
Date: 2023-08-02 23:17:36 +08:00
Parent: 1d8a1878ea
Commit: 08f180e788
25 changed files with 188 additions and 145 deletions


@@ -11,7 +11,6 @@ from trl.core import LengthSampler
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.tuner.core.trainer import PeftTrainer
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
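Note: AverageMeter (imported above from llmtuner.extras.misc) backs the loss_meter and reward_meter used in the training loop below. A minimal sketch of such a running-average meter, which may differ from the actual llmtuner implementation:

```python
class AverageMeter:
    """Minimal sketch of a running-average meter; illustrative only."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        # weight each update by n so avg is a per-sample average
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
```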
@@ -90,14 +89,13 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
reward_meter = AverageMeter()
self.log_callback.on_train_begin(self.args, self.state, self.control)
-for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False):
+for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
batch = next(dataiter)
steps_trained += 1
# Cast to inference mode
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
unwrapped_model.eval()
# Get inputs
queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
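Note: the loop switches the unwrapped model into inference mode before the rollout (disable gradient checkpointing, enable the KV cache, call eval()) and mirrors the switch back to training mode afterwards. A hypothetical context manager, not part of llmtuner, that captures the same toggle:

```python
from contextlib import contextmanager

@contextmanager
def generation_mode(model):
    """Illustrative helper: flip a HF causal LM into inference mode and back."""
    model.gradient_checkpointing_disable()
    model.config.use_cache = True       # KV cache speeds up generation
    model.eval()
    try:
        yield model
    finally:
        model.gradient_checkpointing_enable()
        model.config.use_cache = False  # the cache is unused with checkpointing
        model.train()
```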
@@ -106,21 +104,23 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
# Cast to training mode
unwrapped_model.gradient_checkpointing_enable()
unwrapped_model.config.use_cache = False
unwrapped_model.train()
# Run PPO step
stats = self.step(queries, responses, rewards)
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
-if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
+self.state.global_step += 1
+self.log_callback.on_step_end(self.args, self.state, self.control)
+if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
logs = dict(
loss=round(loss_meter.avg, 4),
reward=round(reward_meter.avg, 4),
learning_rate=stats["ppo/learning_rate"],
epoch=round(step / len_dataloader, 2)
)
-print(logs)
+tqdm.write(str(logs))
logs["step"] = step
self.state.log_history.append(logs)
self.log_callback.on_log(self.args, self.state, self.control)
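Note: logging flows through a transformers-style callback protocol (on_train_begin, on_step_end, on_log, on_train_end) driven by the trainer's state and control objects. As a hedged illustration of the consumer side, a minimal TrainerCallback that reads the entry just appended to state.log_history; llmtuner's own LogCallback may differ:

```python
from transformers import TrainerCallback

class PrintLogCallback(TrainerCallback):
    """Illustrative consumer of state.log_history entries."""

    def on_log(self, args, state, control, **kwargs):
        if state.log_history:
            entry = dict(state.log_history[-1])
            step = entry.pop("step", None)
            print(f"step {step}: {entry}")
```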
@@ -137,10 +137,12 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
dataiter = iter(self.dataloader)
steps_trained = 0
self.log_callback.on_train_end(self.args, self.state, self.control)
@torch.no_grad()
def get_inputs(
self,
-inputs: Dict[str, torch.Tensor],
+batch: Dict[str, torch.Tensor],
length_sampler: Optional[Callable] = None,
**generation_kwargs
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
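Note: get_inputs takes an optional length_sampler (a trl.core.LengthSampler, per the import shown in the first hunk) plus generation kwargs. A hedged usage sketch of how such a sampler typically drives max_new_tokens; the bounds and sampling settings below are assumptions, not values from this repo:

```python
from trl.core import LengthSampler

length_sampler = LengthSampler(16, 128)   # each call returns a random length between the bounds
gen_kwargs = {"do_sample": True, "top_k": 0, "top_p": 1.0}  # assumed settings

# Typical TRL pattern: sample a fresh response length for every rollout.
if length_sampler is not None:
    gen_kwargs["max_new_tokens"] = length_sampler()
```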
@@ -152,7 +154,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
self.model, layer_norm_params = cast_layernorm_dtype(self.model)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-response: torch.Tensor = unwrapped_model.generate(**inputs, **generation_kwargs)
+response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
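Note: cast_layernorm_dtype is called in a save-and-restore pair around generate(), first recording the LayerNorm dtypes and later restoring them. Its body is not part of this diff; the following is only a hedged sketch of that pattern (the parameter-name heuristic and the fp16 target are assumptions), not the llmtuner implementation:

```python
import torch

def _cast_layernorms(model, saved=None, target_dtype=torch.float16):
    """Illustrative save-and-restore of LayerNorm parameter dtypes."""
    recorded = {}
    for name, param in model.named_parameters():
        if "norm" not in name.lower():     # assumed way to spot LayerNorm weights
            continue
        if saved is None:
            recorded[name] = param.dtype               # remember original dtype
            param.data = param.data.to(target_dtype)   # cast for generation
        elif name in saved:
            param.data = param.data.to(saved[name])    # restore recorded dtype
    return model, recorded

# model, layer_norm_params = _cast_layernorms(model)        # before generate()
# model, _ = _cast_layernorms(model, layer_norm_params)     # after generate()
```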
@@ -161,7 +163,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
unwrapped_model.pretrained_model.generation_config._from_model_config = False
queries, responses = [], []
-query, response = inputs["input_ids"].detach().cpu(), response[:, inputs["input_ids"].size(-1):].detach().cpu()
+query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
for i in range(len(query)):
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
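Note: the indexing above implies left-padded queries and right-padded responses, so the first non-pad index trims the query and one past the last non-pad index trims the response. A small worked example (pad_token_id = 0 assumed):

```python
import torch

pad_token_id = 0
query = torch.tensor([0, 0, 11, 12, 13])     # left-padded prompt
response = torch.tensor([21, 22, 23, 0, 0])  # right-padded generation

query_length = (query != pad_token_id).nonzero()[0].item()               # first real token -> 2
response_length = ((response != pad_token_id).nonzero()[-1] + 1).item()  # past last real token -> 3

print(query[query_length:])        # tensor([11, 12, 13])
print(response[:response_length])  # tensor([21, 22, 23])
```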
@@ -181,11 +183,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
Computes scores using given reward model.
"""
replace_model(unwrapped_model, target="reward")
-_, _, values = self.model(
-    **self.prepare_model_inputs(queries, responses),
-    output_hidden_states=True,
-    return_dict=True
-)
+batch = self.prepare_model_inputs(queries, responses)
+_, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
replace_model(unwrapped_model, target="default")
return rewards
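Note: with the "reward" weights swapped in by replace_model, the value head emits one scalar per token, and the score at the final position is used as the sequence reward, cast to fp32 on CPU. A shape-level sketch (shapes assumed for illustration):

```python
import torch

batch_size, seq_len = 4, 32
values = torch.randn(batch_size, seq_len)   # stand-in for the value-head output
rewards = [score for score in values[:, -1].float().detach().cpu()]

print(len(rewards), rewards[0].shape)  # 4 torch.Size([])  -> one fp32 scalar per sample
```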