import gc
import os
import sys
import torch
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList

try:
    from transformers.utils import (
        is_torch_bf16_cpu_available,
        is_torch_bf16_gpu_available,
        is_torch_cuda_available,
        is_torch_npu_available
    )
    _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
    _is_fp16_available = torch.cuda.is_available()
    try:
        _is_bf16_available = torch.cuda.is_bf16_supported()
    except Exception:
        _is_bf16_available = False

if TYPE_CHECKING:
    from transformers import HfArgumentParser


class AverageMeter:
    r"""
    Computes and stores the average and current value.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    r"""
    Returns the number of trainable parameters and the number of all parameters in the model.
    """
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        # If using DeepSpeed ZeRO-3, the weights are initialized empty and numel() returns 0,
        # so fall back to the ds_numel attribute.
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        # Due to the design of 4-bit linear layers from bitsandbytes, multiply the number of parameters by 2.
        if param.__class__.__name__ == "Params4bit":
            num_params = num_params * 2

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    return trainable_params, all_param


def get_current_device() -> Union[int, str]:
    r"""
    Returns the device of the current process: an XPU device string on Intel XPU,
    the local process index on CUDA, or "cpu" otherwise.
    """
    import accelerate
    dummy_accelerator = accelerate.Accelerator()
    if accelerate.utils.is_xpu_available():
        return "xpu:{}".format(dummy_accelerator.local_process_index)
    else:
        return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"


def get_logits_processor() -> "LogitsProcessorList":
    r"""
    Gets a logits processor that removes NaN and Inf logits.
    """
    logits_processor = LogitsProcessorList()
    logits_processor.append(InfNanRemoveLogitsProcessor())
    return logits_processor


def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
    r"""
    Infers the optimal dtype according to the model_dtype and device compatibility.
    """
    if _is_bf16_available and model_dtype == torch.bfloat16:
        return torch.bfloat16
    elif _is_fp16_available:
        return torch.float16
    else:
        return torch.float32


def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any, ...]:
    r"""
    Parses arguments from a dict, from a single .yaml or .json file given on the command line,
    or from the command-line arguments themselves.
    """
    if args is not None:
        return parser.parse_dict(args)
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
    else:
        return parser.parse_args_into_dataclasses()


def torch_gc() -> None:
    r"""
    Collects GPU memory.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
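

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it exercises the helpers
    # above on a toy torch.nn.Linear model. The demo model and the bfloat16 preference
    # are illustrative assumptions only.
    demo_model = torch.nn.Linear(16, 4)

    trainable, total = count_parameters(demo_model)
    print("trainable params: {} || all params: {}".format(trainable, total))

    # Track a running average, e.g. of a training loss.
    meter = AverageMeter()
    meter.update(0.5, n=2)
    print("running average: {}".format(meter.avg))

    # Pick a compute dtype the current hardware supports, preferring bf16 when available.
    print("inferred dtype: {}".format(infer_optim_dtype(torch.bfloat16)))

    # Logits processor that strips NaN/Inf values during generation.
    print("logits processors: {}".format(get_logits_processor()))

    # Release cached GPU memory, if any.
    torch_gc()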