Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-05 05:02:50 +08:00)
fix rm server
Former-commit-id: 55021097d565536a68113ee33af31beaff38334e
parent 0a78375650
commit 1a86cc3078
@@ -87,11 +87,11 @@ def load_model_and_tokenizer(
         model = AutoModelForCausalLM.from_pretrained(
             model_args.model_name_or_path,
             config=config,
+            torch_dtype=model_args.compute_dtype,
             low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
             **config_kwargs
         )
 
-    model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
     patch_model(model, tokenizer, model_args, is_trainable)
     register_autoclass(config, model, tokenizer)
 
@@ -113,6 +113,7 @@ def load_model_and_tokenizer(
 
     if not is_trainable:
         model.requires_grad_(False)
+        model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
         model.eval()
     else:
         model.train()
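The first two hunks rework the dtype handling in load_model_and_tokenizer: torch_dtype is now passed to from_pretrained, so weights are loaded directly in the compute dtype, and the explicit .to() cast moves into the frozen (not is_trainable) branch, where it is skipped for quantized models that cannot be cast. A minimal sketch of the resulting behavior, assuming compute_dtype is torch.bfloat16 and using an illustrative checkpoint name:

import torch
from transformers import AutoModelForCausalLM

compute_dtype = torch.bfloat16

# Weights come off disk already in compute_dtype.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative checkpoint
    torch_dtype=compute_dtype,
    low_cpu_mem_usage=True,
)

is_trainable = False
if not is_trainable:
    model.requires_grad_(False)
    # Quantized models carry a `quantization_method` attribute and must not be
    # cast, so the cast applies only to full-precision inference models.
    if not getattr(model, "quantization_method", None):
        model = model.to(compute_dtype)
    model.eval()
else:
    model.train()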
@@ -276,5 +276,6 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
 
     ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
     setattr(model, "_keys_to_ignore_on_save", ignore_modules)
+    setattr(model, "_no_split_modules", getattr(model.pretrained_model, "_no_split_modules", None))
     setattr(model, "tie_weights", MethodType(tie_weights, model))
     setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
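This hunk copies _no_split_modules from the wrapped transformer onto the value-head wrapper. trl's AutoModelForCausalLMWithValueHead holds the real model under pretrained_model but is not itself a PreTrainedModel, so it lacks the class attributes that device-map dispatch consults; presumably this is what broke the reward-model server path the commit message refers to. A small sketch of the idea, assuming trl is installed and using an illustrative checkpoint:

from trl import AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")  # illustrative

# The wrapper itself defines no `_no_split_modules`; forward the inner model's
# value so code that inspects the wrapper sees the same dispatch hints.
setattr(model, "_no_split_modules", getattr(model.pretrained_model, "_no_split_modules", None))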
@@ -27,7 +27,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
     from accelerate import dispatch_model
     from accelerate.utils import infer_auto_device_map, get_balanced_memory
 
-    if model._no_split_modules is None:
+    if getattr(model, "_no_split_modules", None) is None:
         raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
 
     kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
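The last hunk hardens the guard in dispatch_model: a direct read of model._no_split_modules raises AttributeError on objects that never define the attribute (such as an unpatched value-head wrapper), whereas getattr with a default falls through to the intended, descriptive ValueError. A minimal illustration with a hypothetical wrapper class:

class Wrapper:  # hypothetical object that never defines the attribute
    pass

m = Wrapper()

# Old guard: `m._no_split_modules` would raise AttributeError here.
# New guard: reaches the explicit, descriptive error instead.
if getattr(m, "_no_split_modules", None) is None:
    raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")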