mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 04:32:50 +08:00
Fix slow model initialization in bfloat16 dtype.
Former-commit-id: d032daa4bd598dd0d71b43eb68a614de77a699a6
This commit is contained in:
parent
ce79528bb1
commit
3eaabe12aa
@@ -67,10 +67,10 @@ def load_model_and_tokenizer(
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
config=config,
|
||||
torch_dtype=model_args.compute_dtype,
|
||||
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
|
||||
**config_kwargs
|
||||
)
|
||||
model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
|
||||
patcher.patch_model(model)
|
||||
register_autoclass(config, model, tokenizer)
|
||||
if not is_deepspeed_zero3_enabled():
|
||||
@@ -95,7 +95,6 @@ def load_model_and_tokenizer(
|
||||
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False) # fix all model params
|
||||
model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
|
||||
model.eval()
|
||||
else:
|
||||
model.train()
|
||||
|
Loading…
x
Reference in New Issue
Block a user