diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py
index b6d71fcf..41f1f416 100644
--- a/src/llmtuner/chat/stream_chat.py
+++ b/src/llmtuner/chat/stream_chat.py
@@ -14,7 +14,6 @@ class ChatModel:
         model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
         self.model = dispatch_model(self.model)
-        self.model = self.model.eval() # enable evaluation mode
         self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
         self.system_prompt = data_args.system_prompt
 
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index b924919c..f72cd2eb 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -175,6 +175,7 @@ def load_model_and_tokenizer(
     # Initialize adapters
     model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
+    model = model.train() if is_trainable else model.eval()
 
     # Prepare model with valuehead for RLHF
     if stage == "rm" or stage == "ppo":
diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py
index 00fd5e41..c21ec2ff 100644
--- a/src/llmtuner/tuner/ppo/trainer.py
+++ b/src/llmtuner/tuner/ppo/trainer.py
@@ -99,6 +99,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             # Cast to inference mode
             unwrapped_model.gradient_checkpointing_disable()
             unwrapped_model.config.use_cache = True
+            unwrapped_model, layer_norm_params = cast_layernorm_dtype(unwrapped_model, self.compute_dtype)
+            self.model.eval()
 
             # Get inputs
             queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
@@ -108,6 +110,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             # Cast to training mode
             unwrapped_model.gradient_checkpointing_enable()
             unwrapped_model.config.use_cache = False
+            unwrapped_model, _ = cast_layernorm_dtype(unwrapped_model, self.compute_dtype, layer_norm_params)
+            self.model.train()
 
             # Run PPO step
             stats = self.step(queries, responses, rewards)
@@ -157,10 +161,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         if length_sampler is not None:
             generation_kwargs["max_new_tokens"] = length_sampler()
 
-        self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
-        self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)
 
         # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
         # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
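
For context, below is a minimal sketch of the mode-switching pattern the trainer changes consolidate: LayerNorm weights are kept in float32 during training for numerical stability, cast to the compute dtype only while the policy generates rollouts in eval mode, and then restored before the PPO update. The `TinyPolicy` module and the simplified `cast_layernorm_dtype` here are illustrative stand-ins, not the repository's code.

# Minimal sketch (assumptions: toy policy, simplified casting helper).
import torch
import torch.nn as nn


def cast_layernorm_dtype(model, compute_dtype, layer_norm_params=None):
    # Cast 1-D "norm" parameters to compute_dtype, or restore a saved float32 backup.
    backup = {}
    for name, param in model.named_parameters():
        if param.ndim == 1 and "norm" in name:
            if layer_norm_params is not None:
                param.data = layer_norm_params[name]        # restore float32 copy
            else:
                backup[name] = param.data.detach().clone()  # keep float32 copy
                param.data = param.data.to(compute_dtype)
    return model, backup


class TinyPolicy(nn.Module):  # hypothetical stand-in for the wrapped policy model
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)

    def forward(self, x):
        return self.norm(self.proj(x))


model = TinyPolicy()
compute_dtype = torch.float16

# Cast to inference mode: half-precision LayerNorms, dropout disabled via eval().
model, layer_norm_params = cast_layernorm_dtype(model, compute_dtype)
model.eval()
# ... generate rollouts here (queries, responses) ...

# Cast back to training mode before the PPO optimization step.
model, _ = cast_layernorm_dtype(model, compute_dtype, layer_norm_params)
model.train()
# ... run the PPO step on the collected rollouts ...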