Fix PPO runtime error: pass finetuning_args to patch_config and skip device_map setup during the PPO stage

Former-commit-id: cdb7f82869b07d9d5d31b7b2aaf6b033bd00e32e
This commit is contained in:
stephen 2024-03-08 11:48:26 +08:00
parent b268215a0e
commit 495b858606
2 changed files with 4 additions and 2 deletions

View File

@ -61,7 +61,7 @@ def load_model(
"""
init_kwargs = _get_init_kwargs(model_args)
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
patch_config(config, tokenizer, model_args,finetuning_args, init_kwargs, is_trainable)
model = None
if is_trainable and model_args.use_unsloth:

View File

@ -264,6 +264,7 @@ def patch_config(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
init_kwargs: Dict[str, Any],
is_trainable: bool,
) -> None:
@ -288,7 +289,8 @@ def patch_config(
if not is_deepspeed_zero3_enabled():
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
if "device_map" not in init_kwargs: # quant models cannot use auto device map
init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
if finetuning_args.stage not in ["ppo"]: #ppo stage should not set device map
init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
def patch_model(