mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-03 20:22:49 +08:00
parent
7cc0721028
commit
f6b2bcfa16
@ -41,7 +41,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
# Make sure tied weights are tied before creating the device map.
|
||||
model.tie_weights()
|
||||
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
|
||||
device_map_kwargs = {"device_map": device_map}
|
||||
device_map_kwargs = {"device_map": device_map, "offload_dir": "offload"}
|
||||
if "skip_keys" in inspect.signature(dispatch_model).parameters:
|
||||
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
|
||||
return dispatch_model(model, **device_map_kwargs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user