Mirror of https://github.com/hiyouga/LLaMA-Factory.git
parent 397f6bb615
commit 8154b4bdf6
@@ -81,9 +81,16 @@ def load_model_and_tokenizer(
     if add_valuehead:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
         patcher.patch_valuehead_model(model)
-        vhead_params = load_valuehead_params(model_args)
+
+        if model_args.adapter_name_or_path is not None:
+            vhead_path = model_args.adapter_name_or_path[-1]
+        else:
+            vhead_path = model_args.model_name_or_path
+
+        vhead_params = load_valuehead_params(vhead_path, model_args)
         if vhead_params is not None:
             model.load_state_dict(vhead_params, strict=False)
+            logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))

     if not is_trainable:
         model.requires_grad_(False) # fix all model params
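The hunk above moves checkpoint-path resolution out of `load_valuehead_params` and into the caller: when one or more adapters are given, the value head is looked up next to the last adapter checkpoint, otherwise next to the base model. A minimal sketch of that selection rule (the helper name `resolve_vhead_path` is illustrative, not part of LLaMA-Factory):

```python
from typing import List, Optional


def resolve_vhead_path(adapter_name_or_path: Optional[List[str]], model_name_or_path: str) -> str:
    # Hypothetical helper mirroring the logic added above: the value head is
    # expected to live alongside the last adapter checkpoint if adapters are
    # given, otherwise alongside the base model weights.
    if adapter_name_or_path is not None:
        return adapter_name_or_path[-1]
    return model_name_or_path


# Example: with stacked adapters, the value head is looked up in the last one.
assert resolve_vhead_path(["sft-lora", "ppo-lora"], "meta-llama/Llama-2-7b-hf") == "ppo-lora"
assert resolve_vhead_path(None, "meta-llama/Llama-2-7b-hf") == "meta-llama/Llama-2-7b-hf"
```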
@@ -85,34 +85,21 @@ def get_modelcard_args(
     }


-def load_valuehead_params(model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
+def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.

     Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
     """
-    if model_args.adapter_name_or_path is not None:
-        path_or_repo_id = model_args.adapter_name_or_path[-1]
-    else:
-        path_or_repo_id = model_args.model_name_or_path
-
     kwargs = {
         "path_or_repo_id": path_or_repo_id,
         "cache_dir": model_args.cache_dir,
         "token": model_args.hf_hub_token
     }

-    try:
-        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
-        logger.info("Loaded valuehead from {}".format(path_or_repo_id))
-        return torch.load(vhead_file, map_location="cpu")
-    except Exception as err:
-        logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
-
     try:
         from safetensors import safe_open
         vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
-        logger.info("Loaded valuehead from {}".format(path_or_repo_id))
         with safe_open(vhead_file, framework="pt", device="cpu") as f:
             return {
                 "v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
@@ -121,6 +108,12 @@ def load_valuehead_params(model_args: "ModelArguments") -> Dict[str, torch.Tenso
     except Exception as err:
         logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))

+    try:
+        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
+        return torch.load(vhead_file, map_location="cpu")
+    except Exception as err:
+        logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
+
     logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
     return None

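Taken together, the two hunks above change `load_valuehead_params` so that it accepts the already-resolved `path_or_repo_id` and tries the safetensors checkpoint first, falling back to the legacy `pytorch_model.bin`, and finally returning `None`. A simplified, self-contained sketch of that loading order (not the repo function itself; it omits logging and the `ModelArguments` wrapper):

```python
from typing import Dict, Optional

import torch
from safetensors import safe_open
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, cached_file


def load_valuehead_sketch(path_or_repo_id: str, cache_dir: Optional[str] = None,
                          token: Optional[str] = None) -> Optional[Dict[str, torch.Tensor]]:
    # Sketch of the post-change loading order: prefer model.safetensors,
    # then fall back to pytorch_model.bin, else return None.
    kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": cache_dir, "token": token}

    try:  # 1) safetensors checkpoint
        vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
        with safe_open(vhead_file, framework="pt", device="cpu") as f:
            return {
                "v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
                "v_head.summary.bias": f.get_tensor("v_head.summary.bias"),
            }
    except Exception:
        pass

    try:  # 2) legacy .bin checkpoint
        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
        return torch.load(vhead_file, map_location="cpu")
    except Exception:
        return None
```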
@@ -206,6 +206,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if self.finetuning_args.upcast_layernorm:
             layernorm_params = dump_layernorm(self.model)

+        if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
+            start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
+            for k, v in batch.items():
+                batch[k] = v[:, start_index:]
+
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         generate_output: torch.Tensor = unwrapped_model.generate(
             generation_config=self.generation_config,
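The new branch trims left padding from the prompt before generation when the per-device batch holds a single sequence (the LLaMA-2 PPO case with gradient accumulation > 1), so the model never sees a run of leading pad tokens. A toy, self-contained illustration of that trim (a pad token id of 0 is assumed):

```python
import torch

pad_token_id = 0
batch = {
    "input_ids": torch.tensor([[0, 0, 0, 5, 6, 7]]),
    "attention_mask": torch.tensor([[0, 0, 0, 1, 1, 1]]),
}

if batch["input_ids"].size(0) == 1:
    # Index of the first real (non-pad) token in the single sequence.
    start_index = (batch["input_ids"][0] != pad_token_id).nonzero()[0].item()
    for k, v in batch.items():
        batch[k] = v[:, start_index:]

print(batch["input_ids"])       # tensor([[5, 6, 7]])
print(batch["attention_mask"])  # tensor([[1, 1, 1]])
```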
@@ -220,7 +225,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
         queries, responses = [], []
         for i in range(len(query)):
-            query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
+            query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
             response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()

             if len(response_index) == 0:
@@ -228,7 +233,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             else:
                 response_length = response_index[-1].item() + 1

-            queries.append(query[i, query_length:]) # remove padding from left
+            queries.append(query[i, query_start_index:]) # remove padding from left
             responses.append(response[i, :response_length]) # remove padding from right

         return queries, responses
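The rename from `query_length` to `query_start_index` reflects what the value actually is: the index of the first non-pad token in the left-padded query, not a length. A toy illustration of the query/response trimming (a pad token id of 0 is assumed; the branch for an all-padding response is elided in the hunk above and is approximated here with a minimum length of 1):

```python
import torch

pad_token_id = 0
query = torch.tensor([[0, 0, 11, 12, 13]])      # left-padded prompt
response = torch.tensor([[21, 22, 23, 0, 0]])   # right-padded generation

for i in range(len(query)):
    # First non-pad position in the query marks where the real prompt starts.
    query_start_index = (query[i] != pad_token_id).nonzero()[0].item()
    response_index = (response[i] != pad_token_id).nonzero()
    # Cut after the last non-pad token; keep at least one token otherwise.
    response_length = 1 if len(response_index) == 0 else response_index[-1].item() + 1

    print(query[i, query_start_index:])   # tensor([11, 12, 13])
    print(response[i, :response_length])  # tensor([21, 22, 23])
```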