disentangle model from tuner and rename modules

hiyouga committed 2023-11-15 16:29:09 +08:00
parent 2f02f688e1
commit 4736344eb1
57 changed files with 324 additions and 263 deletions


@@ -13,14 +13,13 @@ try:
         is_torch_npu_available
     )
     _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
-    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available
+    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
 except ImportError:
     _is_fp16_available = torch.cuda.is_available()
     _is_bf16_available = torch.cuda.is_bf16_supported()

 if TYPE_CHECKING:
     from transformers import HfArgumentParser
-    from transformers.modeling_utils import PreTrainedModel


 class AverageMeter:
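Note on the one-character fix above: in Python a bare function reference is always truthy, so the unparenthesized `is_torch_bf16_cpu_available` made `_is_bf16_available` evaluate to true on every machine. A minimal standalone sketch of the failure mode (the stub below is illustrative, not the real transformers helper):

def is_torch_bf16_cpu_available() -> bool:
    return False  # pretend this CPU has no bf16 support

broken = False or is_torch_bf16_cpu_available    # the function object itself: always truthy
fixed = False or is_torch_bf16_cpu_available()   # the actual check: False here

print(bool(broken), bool(fixed))  # True False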
@@ -65,6 +64,15 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param


+def get_logits_processor() -> "LogitsProcessorList":
+    r"""
+    Gets logits processor that removes NaN and Inf logits.
+    """
+    logits_processor = LogitsProcessorList()
+    logits_processor.append(InfNanRemoveLogitsProcessor())
+    return logits_processor
+
+
 def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
     r"""
     Infers the optimal dtype according to the model_dtype and device compatibility.
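The hunk above moves `get_logits_processor` up in the file; its output plugs straight into `generate`. A hedged usage sketch (checkpoint and prompt are illustrative, not part of this commit):

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    InfNanRemoveLogitsProcessor,
    LogitsProcessorList,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Equivalent to get_logits_processor(): strips NaN/Inf logits during decoding.
processors = LogitsProcessorList([InfNanRemoveLogitsProcessor()])
inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, logits_processor=processors, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))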
@@ -77,25 +85,6 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
         return torch.float32


-def get_logits_processor() -> "LogitsProcessorList":
-    r"""
-    Gets logits processor that removes NaN and Inf logits.
-    """
-    logits_processor = LogitsProcessorList()
-    logits_processor.append(InfNanRemoveLogitsProcessor())
-    return logits_processor
-
-
-def torch_gc() -> None:
-    r"""
-    Collects GPU memory.
-    """
-    gc.collect()
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
-
-
 def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     if args is not None:
         return parser.parse_dict(args)
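Both deletions above are relocations, not removals: `get_logits_processor` now lives earlier in this file, and `torch_gc` is re-added after `parse_args` in the next hunk. What `torch_gc` buys you, as a self-contained sketch (the tensor size is arbitrary):

import gc
import torch

big = torch.zeros(4096, 4096, device="cuda" if torch.cuda.is_available() else "cpu")
del big  # the Python object is gone, but CUDA's caching allocator keeps the blocks

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # hand cached blocks back to the driver
    torch.cuda.ipc_collect()  # also reclaim memory held by dead IPC handles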
@@ -107,26 +96,11 @@ def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     return parser.parse_args_into_dataclasses()


-def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
+def torch_gc() -> None:
     r"""
-    Dispatches a pre-trained model to GPUs with balanced memory.
-    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
+    Collects GPU memory.
     """
-    if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
-        return model
-
-    if torch.cuda.device_count() > 1:
-        from accelerate import dispatch_model
-        from accelerate.utils import infer_auto_device_map, get_balanced_memory
-
-        if model._no_split_modules is None:
-            raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
-
-        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
-        max_memory = get_balanced_memory(model, **kwargs)
-        # Make sure tied weights are tied before creating the device map.
-        model.tie_weights()
-        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
-        return dispatch_model(model, device_map)
-    else:
-        return model.cuda()
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
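`dispatch_model` is removed here, presumably relocated as part of the disentangling named in the commit title (balanced multi-GPU placement is a model concern, not a generic utility). A minimal standalone sketch of the same accelerate pattern it used (the checkpoint name is illustrative):

import torch
from accelerate import dispatch_model
from accelerate.utils import get_balanced_memory, infer_auto_device_map
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
if torch.cuda.device_count() > 1:
    kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
    max_memory = get_balanced_memory(model, **kwargs)
    model.tie_weights()  # tie shared weights before inferring the device map
    device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
    model = dispatch_model(model, device_map)
elif torch.cuda.is_available():
    model = model.cuda()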