Mirror of https://github.com/hiyouga/LLaMA-Factory.git
add max_memory for gptq #1923
Former-commit-id: c4a3977ad7278b59a9b3dadcb446bb4c99da5c9d
parent 82a79e9fdf
commit c1233ab65f
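
Note: this commit threads the host's per-device memory budget into the GPTQ export path and switches device detection to transformers' own availability helpers. For orientation only, a minimal sketch of what accelerate.utils.get_max_memory() reports (the values are machine-dependent and the printed example is made up):

    # Sketch: inspect the memory budget that the new code forwards to the loader.
    from accelerate.utils import get_max_memory

    max_memory = get_max_memory()
    # Typically a dict keyed by GPU index plus "cpu", with byte budgets as values,
    # e.g. {0: 25431064576, "cpu": 67108864000} (illustrative numbers only).
    for device, budget in max_memory.items():
        print(device, budget)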
@@ -63,8 +63,8 @@ def get_dataset(
 
         if dataset_attr.load_from == "ms_hub":
             try:
-                from modelscope import MsDataset # type: ignore
-                from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
+                from modelscope import MsDataset
+                from modelscope.utils.config_ds import MS_DATASETS_CACHE
 
                 cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
                 dataset = MsDataset.load(
@@ -75,7 +75,7 @@ def get_dataset(
                     split=data_args.split,
                     cache_dir=cache_dir,
                     token=model_args.ms_hub_token,
-                    use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+                    use_streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
                 ).to_hf_dataset()
             except ImportError:
                 raise ImportError("Please install modelscope via `pip install modelscope -U`")
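
Note: the hunk above only drops the "# type: ignore" hints and a trailing comma; the ModelScope loading path itself is unchanged. For reference, a standalone sketch of the same pattern (the dataset name is a placeholder, not one used by the project):

    # Sketch of the ModelScope loading pattern shown above; "some/dataset" is a placeholder.
    from modelscope import MsDataset
    from modelscope.utils.config_ds import MS_DATASETS_CACHE

    ms_dataset = MsDataset.load(
        "some/dataset",
        split="train",
        cache_dir=MS_DATASETS_CACHE,
        use_streaming=False
    )
    hf_dataset = ms_dataset.to_hf_dataset()  # convert to a Hugging Face dataset for the rest of the pipeline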
@@ -3,25 +3,22 @@ import os
 import torch
 from typing import TYPE_CHECKING, Tuple
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
 
-try:
-    from transformers.utils import (
-        is_torch_bf16_cpu_available,
-        is_torch_bf16_gpu_available,
-        is_torch_cuda_available,
-        is_torch_npu_available
-    )
-
-    _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
-    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
-except ImportError:
-    _is_fp16_available = torch.cuda.is_available()
-    try:
-        _is_bf16_available = torch.cuda.is_bf16_supported()
-    except:
-        _is_bf16_available = False
+from transformers.utils import (
+    is_torch_bf16_cpu_available,
+    is_torch_bf16_gpu_available,
+    is_torch_cuda_available,
+    is_torch_npu_available,
+    is_torch_xpu_available
+)
+
+_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
+try:
+    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
+except:
+    _is_bf16_available = False
 
 
 if TYPE_CHECKING:
-    from transformers import HfArgumentParser
     from llmtuner.hparams import ModelArguments
 
 
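Note: with the try/except removed, the availability flags above are computed unconditionally at import time; only the bf16 probe keeps a guard, since it can raise on some torch builds. A sketch of how such flags are typically consumed to choose a compute dtype (the helper below is hypothetical, not part of the project):

    import torch

    # Hypothetical helper illustrating how module-level availability flags are used.
    def pick_compute_dtype(bf16_available: bool, fp16_available: bool) -> torch.dtype:
        if bf16_available:       # prefer bf16 where the hardware supports it
            return torch.bfloat16
        if fp16_available:       # fall back to fp16 on CUDA/NPU devices
            return torch.float16
        return torch.float32     # CPU-only fallback
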
@@ -68,12 +65,14 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
 
 
 def get_current_device() -> torch.device:
-    import accelerate
-    if accelerate.utils.is_xpu_available():
+    r"""
+    Gets the current available device.
+    """
+    if is_torch_xpu_available():
         device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
-    elif accelerate.utils.is_npu_available():
+    elif is_torch_npu_available():
         device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
-    elif torch.cuda.is_available():
+    elif is_torch_cuda_available():
         device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
     else:
         device = "cpu"
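
Note: the rewritten get_current_device() keys off LOCAL_RANK, so in distributed runs each worker resolves to its own accelerator. A small usage sketch, assuming the import path used elsewhere in this commit:

    import os
    import torch
    from llmtuner.extras.misc import get_current_device  # path matches the import added in a later hunk

    os.environ.setdefault("LOCAL_RANK", "0")   # torchrun sets this in real distributed runs
    device = get_current_device()              # resolves to xpu:0 / npu:0 / cuda:0 / cpu per the branches above
    tensor = torch.zeros(1, device=device)     # tensors and models can be placed on it directly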
@@ -117,7 +116,7 @@ def try_download_model_from_ms(model_args: "ModelArguments") -> None:
         return
 
     try:
-        from modelscope import snapshot_download # type: ignore
+        from modelscope import snapshot_download
         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
         model_args.model_name_or_path = snapshot_download(
             model_args.model_name_or_path,
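
Note: as above, only the "# type: ignore" hint changes here; the surrounding code redirects model_args.model_name_or_path to a locally downloaded ModelScope snapshot. A sketch with a placeholder model id:

    # Sketch of the snapshot-download pattern; "some/model" is a placeholder id.
    from modelscope import snapshot_download

    local_path = snapshot_download("some/model", revision="master")
    # The caller then points model_args.model_name_or_path at local_path so that
    # the rest of the pipeline loads the model from the local snapshot.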
@@ -76,6 +76,7 @@ def configure_quantization(
     if finetuning_args.export_quantization_bit is not None: # gptq
         require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
         require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+        from accelerate.utils import get_max_memory
 
         if getattr(config, "model_type", None) == "chatglm":
             raise ValueError("ChatGLM model is not supported.")
@@ -86,6 +87,7 @@ def configure_quantization(
             dataset=get_quantization_dataset(tokenizer, model_args, finetuning_args)
         )
         config_kwargs["device_map"] = "auto"
+        config_kwargs["max_memory"] = get_max_memory()
         logger.info("Quantizing model to {} bit.".format(finetuning_args.export_quantization_bit))
 
 
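Note: this is the core of the change. Alongside device_map="auto", the loader now passes an explicit per-device memory cap, so GPTQ export quantization can spread the model across the available devices instead of overflowing one of them. A hedged end-to-end sketch of how kwargs like these are consumed when loading with a GPTQ config (model name, bit width and calibration dataset are placeholders, not project defaults):

    # Hedged sketch: feeding device_map/max_memory style kwargs into a GPTQ quantized load.
    from accelerate.utils import get_max_memory
    from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

    tokenizer = AutoTokenizer.from_pretrained("some/model")             # placeholder model id
    gptq_config = GPTQConfig(bits=4, tokenizer=tokenizer, dataset="c4")

    model = AutoModelForCausalLM.from_pretrained(
        "some/model",
        quantization_config=gptq_config,
        device_map="auto",            # let accelerate place layers automatically
        max_memory=get_max_memory(),  # cap each device at its reported budget
    )
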
@@ -8,6 +8,7 @@ from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 
 from llmtuner.extras.constants import LAYERNORM_NAMES
 from llmtuner.extras.logging import get_logger
+from llmtuner.extras.misc import get_current_device
 from llmtuner.hparams import ModelArguments, FinetuningArguments
 
 if TYPE_CHECKING:
@@ -20,7 +21,7 @@ logger = get_logger(__name__)
 
 def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
     r"""
-    Dispatches a pre-trained model to GPUs with balanced memory.
+    Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available.
     Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
     """
     if getattr(model, "quantization_method", None): # already set on current device
@@ -43,7 +44,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
         device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
         return dispatch_model(model, **device_map_kwargs)
     else:
-        return model.cuda()
+        return model.to(device=get_current_device())
 
 
 def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
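
Note: the fallback now respects whichever accelerator get_current_device() reports instead of hard-coding CUDA. The multi-GPU branch above relies on accelerate; a sketch of the overall decision using accelerate's public helpers (the exact kwargs sit outside the hunk shown, so treat them as assumptions):

    import torch
    from accelerate import dispatch_model as accel_dispatch_model
    from accelerate import infer_auto_device_map
    from accelerate.utils import get_balanced_memory
    from llmtuner.extras.misc import get_current_device  # import added by this commit

    def place_model(model):
        # Assumed sketch of the dispatch logic surrounding the hunk above.
        if torch.cuda.device_count() > 1:
            max_memory = get_balanced_memory(model)                    # balance layers across GPUs
            device_map = infer_auto_device_map(model, max_memory=max_memory)
            return accel_dispatch_model(model, device_map=device_map)
        return model.to(device=get_current_device())                   # single device, NPU/XPU or CPU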