mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-17 20:30:36 +08:00
fix ppo args
@@ -201,7 +201,9 @@ def get_train_args(
     )
 
     # postprocess model_args
-    model_args.compute_dtype = torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
+    model_args.compute_dtype = (
+        torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else torch.float32)
+    )
     model_args.model_max_length = data_args.cutoff_len
 
     # Log on each process the small summary:
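The fix replaces the trailing None fallback with torch.float32, so model_args.compute_dtype always resolves to a concrete dtype even when neither bf16 nor fp16 is requested; presumably the None value is what broke the PPO arguments the commit message refers to. Below is a minimal sketch of the resulting fallback chain, using a hypothetical resolve_compute_dtype helper rather than LLaMA-Factory's actual code:

import torch

def resolve_compute_dtype(bf16: bool, fp16: bool) -> torch.dtype:
    # Same fallback chain as the patched line: bf16 -> bfloat16, fp16 -> float16,
    # otherwise a concrete torch.float32 instead of the old None.
    if bf16:
        return torch.bfloat16
    if fp16:
        return torch.float16
    return torch.float32

# With neither precision flag set, callers now always get a real dtype.
assert resolve_compute_dtype(bf16=False, fp16=False) is torch.float32
assert resolve_compute_dtype(bf16=True, fp16=False) is torch.bfloat16
assert resolve_compute_dtype(bf16=False, fp16=True) is torch.float16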