diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 4371a0f4..d9801fc1 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -81,9 +81,16 @@ def load_model_and_tokenizer(
     if add_valuehead:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
         patcher.patch_valuehead_model(model)
-        vhead_params = load_valuehead_params(model_args)
+
+        if model_args.adapter_name_or_path is not None:
+            vhead_path = model_args.adapter_name_or_path[-1]
+        else:
+            vhead_path = model_args.model_name_or_path
+
+        vhead_params = load_valuehead_params(vhead_path, model_args)
         if vhead_params is not None:
             model.load_state_dict(vhead_params, strict=False)
+            logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
 
     if not is_trainable:
         model.requires_grad_(False)  # fix all model params
diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index 1b84f31d..e8aa164d 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -85,34 +85,21 @@ def get_modelcard_args(
     }
 
 
-def load_valuehead_params(model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
+def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.
 
     Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
     """
-    if model_args.adapter_name_or_path is not None:
-        path_or_repo_id = model_args.adapter_name_or_path[-1]
-    else:
-        path_or_repo_id = model_args.model_name_or_path
-
     kwargs = {
         "path_or_repo_id": path_or_repo_id,
         "cache_dir": model_args.cache_dir,
         "token": model_args.hf_hub_token
     }
 
-    try:
-        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
-        logger.info("Loaded valuehead from {}".format(path_or_repo_id))
-        return torch.load(vhead_file, map_location="cpu")
-    except Exception as err:
-        logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
-
     try:
         from safetensors import safe_open
         vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
-        logger.info("Loaded valuehead from {}".format(path_or_repo_id))
         with safe_open(vhead_file, framework="pt", device="cpu") as f:
             return {
                 "v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
@@ -121,6 +108,12 @@ def load_valuehead_params(model_args: "ModelArguments") -> Dict[str, torch.Tenso
     except Exception as err:
         logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
 
+    try:
+        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
+        return torch.load(vhead_file, map_location="cpu")
+    except Exception as err:
+        logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
+
     logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
     return None
 
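Taken together, the two hunks above move checkpoint-path resolution into the caller and flip the file preference to safetensors-first. A minimal standalone sketch of the resulting flow, assuming `cached_file`, `WEIGHTS_NAME` ("pytorch_model.bin"), and `SAFE_WEIGHTS_NAME` ("model.safetensors") come from `transformers.utils` as in the surrounding module; the wrapper name and its flat argument list are hypothetical stand-ins for `ModelArguments`:

```python
import torch
from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME, cached_file


def load_valuehead_sketch(model_name_or_path, adapter_name_or_path=None,
                          cache_dir=None, token=None):
    # Resolve the checkpoint location the same way loader.py now does:
    # the last adapter takes precedence over the base model.
    if adapter_name_or_path:
        path_or_repo_id = adapter_name_or_path[-1]
    else:
        path_or_repo_id = model_name_or_path

    kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": cache_dir, "token": token}

    try:
        # Preferred: the safetensors checkpoint, read lazily tensor-by-tensor.
        from safetensors import safe_open

        vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
        with safe_open(vhead_file, framework="pt", device="cpu") as f:
            return {
                "v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
                "v_head.summary.bias": f.get_tensor("v_head.summary.bias"),
            }
    except Exception:
        pass

    try:
        # Fallback: the legacy pytorch_model.bin written by older checkpoints.
        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
        return torch.load(vhead_file, map_location="cpu")
    except Exception:
        return None  # no value head found at this path
```

Preferring safetensors matches the format newer `transformers` releases save by default, while the `torch.load` branch keeps older value-head checkpoints loadable.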
diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py
index 40129840..0c6af1d2 100644
--- a/src/llmtuner/train/ppo/trainer.py
+++ b/src/llmtuner/train/ppo/trainer.py
@@ -206,6 +206,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if self.finetuning_args.upcast_layernorm:
             layernorm_params = dump_layernorm(self.model)
 
+        if batch["input_ids"].size(0) == 1:  # handle llama2 ppo with gradient accumulation > 1
+            start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
+            for k, v in batch.items():
+                batch[k] = v[:, start_index:]
+
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         generate_output: torch.Tensor = unwrapped_model.generate(
             generation_config=self.generation_config,
@@ -220,7 +225,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
         queries, responses = [], []
         for i in range(len(query)):
-            query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
+            query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
             response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
 
             if len(response_index) == 0:
@@ -228,7 +233,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             else:
                 response_length = response_index[-1].item() + 1
 
-            queries.append(query[i, query_length:])  # remove padding from left
+            queries.append(query[i, query_start_index:])  # remove padding from left
             responses.append(response[i, :response_length])  # remove padding from right
 
         return queries, responses
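The trainer.py guard only fires for micro-batches of size 1, as produced when gradient accumulation splits the PPO batch: a single shared `start_index` would truncate real tokens if sequences with different pad lengths shared a batch. A self-contained demonstration of the same slice; `pad_token_id = 0` and the toy tensors are assumptions made purely for this demo:

```python
import torch

# Toy left-padded batch of size 1; pad_token_id=0 is assumed for the demo.
pad_token_id = 0
batch = {
    "input_ids": torch.tensor([[0, 0, 0, 5, 6, 7]]),
    "attention_mask": torch.tensor([[0, 0, 0, 1, 1, 1]]),
}

if batch["input_ids"].size(0) == 1:  # only safe when the batch holds one sequence
    # Index of the first non-pad token; everything before it is left padding.
    start_index = (batch["input_ids"][0] != pad_token_id).nonzero()[0].item()
    for k, v in batch.items():
        batch[k] = v[:, start_index:]  # trim the shared pad prefix from every tensor

print(batch["input_ids"])       # tensor([[5, 6, 7]])
print(batch["attention_mask"])  # tensor([[1, 1, 1]])
```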