Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-01 11:12:50 +08:00)
fix layer norm name in PPO
Former-commit-id: e3aaef7d4a37e4aa388a9158c382db8239843a5e
parent: b0e9a673be
commit: 6ab22a0181
@@ -41,7 +41,7 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def
 
 def cast_layernorm_dtype(
         model: AutoModelForCausalLMWithValueHead,
-        layer_norm_names: List[str] = ["layernorm"], # for chatglm setting
+        layer_norm_names: List[str] = ["norm", "ln_f"], # for LLaMA and BLOOM setting
         layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
 ) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
 
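For context: cast_layernorm_dtype matches layer-norm parameters by substring against layer_norm_names, so the defaults must use the names LLaMA ("norm") and BLOOM ("ln_f") actually give those modules rather than ChatGLM's "layernorm", which is what this commit fixes. Below is a minimal, hedged sketch of how such a function can work. Only the signature comes from the diff; the body, the cast direction (upcasting matching parameters for numerical stability during PPO updates), and the trl import are assumptions for illustration, not the commit's verbatim implementation.

from typing import Dict, List, Optional, Tuple

import torch
from trl import AutoModelForCausalLMWithValueHead


def cast_layernorm_dtype(
        model: AutoModelForCausalLMWithValueHead,
        layer_norm_names: List[str] = ["norm", "ln_f"], # for LLaMA and BLOOM setting
        layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
    # Sketch only: collect layer-norm weights, cast them, and return the
    # saved originals so a later call can restore them via layer_norm_params.
    saved_params: Dict[str, torch.Tensor] = {}
    for name, param in model.named_parameters():
        # Layer-norm weights are 1-D; match them by name substring.
        if param.ndim == 1 and any(ln_name in name for ln_name in layer_norm_names):
            if layer_norm_params is not None:
                # Restore previously saved values.
                param.data = layer_norm_params[name]
            else:
                # Save current values, then cast (assumed: upcast to float32).
                saved_params[name] = param.data.detach().clone()
                param.data = param.data.to(torch.float32)
    return model, saved_params

A plausible call pattern is to cast before a PPO optimization step and restore afterwards by passing the returned dict back in as layer_norm_params.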