fix mixtral inference #1821

Former-commit-id: f86857bd9ef456e77ad79a584f1fa08a129e5270
This commit is contained in:
hiyouga 2023-12-20 15:11:15 +08:00
parent e06b9c4fa1
commit a862ce636f

View File

@ -1,5 +1,6 @@
import math import math
import torch import torch
import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from transformers.utils import cached_file from transformers.utils import cached_file
@ -20,7 +21,7 @@ logger = get_logger(__name__)
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel": def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
r""" r"""
Dispatches a pre-trained model to GPUs with balanced memory. Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803 Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
""" """
if getattr(model, "quantization_method", None): # already set on current device if getattr(model, "quantization_method", None): # already set on current device
return model return model
@ -32,12 +33,15 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
if model._no_split_modules is None: if model._no_split_modules is None:
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.") raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules} kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
max_memory = get_balanced_memory(model, **kwargs) max_memory = get_balanced_memory(model, **kwargs)
# Make sure tied weights are tied before creating the device map. # Make sure tied weights are tied before creating the device map.
model.tie_weights() model.tie_weights()
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs) device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
return dispatch_model(model, device_map) device_map_kwargs = {"device_map": device_map}
if "skip_keys" in inspect.signature(dispatch_model).parameters:
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
return dispatch_model(model, **device_map_kwargs)
else: else:
return model.cuda() return model.cuda()