mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 19:30:36 +08:00
disentangle model from tuner and rename modules
This commit is contained in:
@@ -13,14 +13,13 @@ try:
|
||||
is_torch_npu_available
|
||||
)
|
||||
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
||||
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available
|
||||
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
|
||||
except ImportError:
|
||||
_is_fp16_available = torch.cuda.is_available()
|
||||
_is_bf16_available = torch.cuda.is_bf16_supported()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import HfArgumentParser
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
@@ -65,6 +64,15 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
return trainable_params, all_param
|
||||
|
||||
|
||||
def get_logits_processor() -> "LogitsProcessorList":
|
||||
r"""
|
||||
Gets logits processor that removes NaN and Inf logits.
|
||||
"""
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
return logits_processor
|
||||
|
||||
|
||||
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
|
||||
r"""
|
||||
Infers the optimal dtype according to the model_dtype and device compatibility.
|
||||
@@ -77,25 +85,6 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
|
||||
return torch.float32
|
||||
|
||||
|
||||
def get_logits_processor() -> "LogitsProcessorList":
|
||||
r"""
|
||||
Gets logits processor that removes NaN and Inf logits.
|
||||
"""
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
return logits_processor
|
||||
|
||||
|
||||
def torch_gc() -> None:
|
||||
r"""
|
||||
Collects GPU memory.
|
||||
"""
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||
if args is not None:
|
||||
return parser.parse_dict(args)
|
||||
@@ -107,26 +96,11 @@ def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None
|
||||
return parser.parse_args_into_dataclasses()
|
||||
|
||||
|
||||
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
def torch_gc() -> None:
|
||||
r"""
|
||||
Dispatches a pre-trained model to GPUs with balanced memory.
|
||||
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
|
||||
Collects GPU memory.
|
||||
"""
|
||||
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
|
||||
return model
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
from accelerate import dispatch_model
|
||||
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
||||
|
||||
if model._no_split_modules is None:
|
||||
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
|
||||
|
||||
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
|
||||
max_memory = get_balanced_memory(model, **kwargs)
|
||||
# Make sure tied weights are tied before creating the device map.
|
||||
model.tie_weights()
|
||||
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
|
||||
return dispatch_model(model, device_map)
|
||||
else:
|
||||
return model.cuda()
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
Reference in New Issue
Block a user