import gc
import torch
from typing import TYPE_CHECKING, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList

try:
    from transformers.utils import (
        is_torch_bf16_cpu_available,
        is_torch_bf16_gpu_available,
        is_torch_cuda_available,
        is_torch_npu_available
    )
    _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
    # Fall back to plain CUDA checks on older transformers versions that lack these utils.
    _is_fp16_available = torch.cuda.is_available()
    _is_bf16_available = torch.cuda.is_bf16_supported()

if TYPE_CHECKING:
    from transformers.modeling_utils import PreTrainedModel


class AverageMeter:
    r"""
    Computes and stores the average and current value.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    r"""
    Returns the number of trainable parameters and number of all parameters in the model.
    """
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        # if using DS Zero 3 and the weights are initialized empty
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
        if param.__class__.__name__ == "Params4bit":
            num_params = num_params * 2

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    return trainable_params, all_param


def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
    r"""
    Infers the optimal dtype according to the model_dtype and device compatibility.
    """
    if _is_bf16_available and model_dtype == torch.bfloat16:
        return torch.bfloat16
    elif _is_fp16_available:
        return torch.float16
    else:
        return torch.float32


def get_logits_processor() -> LogitsProcessorList:
    r"""
    Gets logits processor that removes NaN and Inf logits.
    """
    logits_processor = LogitsProcessorList()
    logits_processor.append(InfNanRemoveLogitsProcessor())
    return logits_processor


def torch_gc() -> None:
    r"""
    Collects GPU memory.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
    r"""
    Dispatches a pre-trained model to GPUs with balanced memory.
    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
    """
    if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):  # do nothing
        return model

    if torch.cuda.device_count() > 1:
        from accelerate import dispatch_model
        from accelerate.utils import infer_auto_device_map, get_balanced_memory

        if model._no_split_modules is None:
            raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")

        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
        max_memory = get_balanced_memory(model, **kwargs)
        # Make sure tied weights are tied before creating the device map.
        model.tie_weights()
        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
        return dispatch_model(model, device_map)
    else:
        return model.cuda()
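

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API):
# exercises the helpers above on a tiny stand-in model. The `torch.nn.Linear`
# module used here is an assumption for demonstration purposes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dummy_model = torch.nn.Linear(16, 4)

    # Count trainable vs. total parameters of the dummy model.
    trainable_params, all_params = count_parameters(dummy_model)
    print("trainable params: {} || all params: {}".format(trainable_params, all_params))

    # Infer a compute dtype that the current device actually supports.
    compute_dtype = infer_optim_dtype(model_dtype=torch.bfloat16)
    print("inferred dtype: {}".format(compute_dtype))

    # Build the NaN/Inf-removing logits processor, e.g. to pass to `model.generate(...)`.
    logits_processor = get_logits_processor()
    print("logits processors: {}".format([p.__class__.__name__ for p in logits_processor]))

    # Release cached GPU memory if a CUDA device is present.
    torch_gc()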