Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-15 11:20:35 +08:00)
fix ppo runtime error: do not set device_map when finetuning_args.stage is "ppo"
@@ -264,6 +264,7 @@ def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
     init_kwargs: Dict[str, Any],
     is_trainable: bool,
 ) -> None:
@@ -288,7 +289,8 @@ def patch_config(
     if not is_deepspeed_zero3_enabled():
         init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
         if "device_map" not in init_kwargs:  # quant models cannot use auto device map
-            init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
+            if finetuning_args.stage not in ["ppo"]:  # ppo stage should not set device map
+                init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}


 def patch_model(
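To illustrate the effect of the fix, here is a minimal, self-contained sketch of the patched control flow. It omits the is_deepspeed_zero3_enabled() guard, and SimpleModelArgs, SimpleFinetuningArgs, resolve_init_kwargs, and the stubbed get_current_device are hypothetical stand-ins introduced only to make the snippet runnable; the real code uses LLaMA-Factory's own ModelArguments, FinetuningArguments, and device helper.

from dataclasses import dataclass
from typing import Any, Dict, Optional, Union


@dataclass
class SimpleModelArgs:  # hypothetical stand-in for ModelArguments
    device_map: Optional[Union[str, Dict[str, Any]]] = None
    low_cpu_mem_usage: bool = True


@dataclass
class SimpleFinetuningArgs:  # hypothetical stand-in for FinetuningArguments
    stage: str = "sft"


def get_current_device() -> str:
    # Placeholder for the real helper, which returns the local accelerator device.
    return "cuda:0"


def resolve_init_kwargs(
    model_args: SimpleModelArgs,
    finetuning_args: SimpleFinetuningArgs,
    init_kwargs: Dict[str, Any],
) -> Dict[str, Any]:
    # Mirrors the patched branch: device_map is only filled in when the stage
    # is not "ppo", matching the in-diff comment that the ppo stage should not
    # set a device map.
    init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
    if "device_map" not in init_kwargs and finetuning_args.stage not in ["ppo"]:
        init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
    return init_kwargs


# Usage: the PPO stage leaves device_map unset; other stages get a concrete map.
print(resolve_init_kwargs(SimpleModelArgs(), SimpleFinetuningArgs(stage="ppo"), {}))
print(resolve_init_kwargs(SimpleModelArgs(), SimpleFinetuningArgs(stage="sft"), {}))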