Merge branch 'main' into feat/support_ms

Former-commit-id: 00f5c9ee16
hoshi-hiyouga authored 2023-12-01 20:23:46 +08:00 · committed by GitHub
16 changed files with 121 additions and 62 deletions

src/llmtuner/extras/callbacks.py · View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
 from datetime import timedelta
 from transformers import TrainerCallback
+from transformers.modeling_utils import custom_object_save, unwrap_model
 from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
 from llmtuner.extras.constants import LOG_FILE_NAME
@@ -18,6 +19,16 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)

+def _save_model_with_valuehead(model: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
+    model.pretrained_model.config.save_pretrained(output_dir)
+    if model.pretrained_model.can_generate():
+        model.pretrained_model.generation_config.save_pretrained(output_dir)
+    if getattr(model, "is_peft_model", False):
+        model.pretrained_model.save_pretrained(output_dir)
+    elif getattr(model.pretrained_model, "_auto_class", None):  # must not be a PeftModel
+        custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
+
 class SavePeftModelCallback(TrainerCallback):

     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
@@ -25,25 +36,17 @@ class SavePeftModelCallback(TrainerCallback):
         Event called after a checkpoint save.
         """
         if args.should_save:
-            output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
-            model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
-            model.pretrained_model.config.save_pretrained(output_dir)
-            if model.pretrained_model.can_generate():
-                model.pretrained_model.generation_config.save_pretrained(output_dir)
-            if getattr(model, "is_peft_model", False):
-                model.pretrained_model.save_pretrained(output_dir)
+            _save_model_with_valuehead(
+                model=unwrap_model(kwargs.pop("model")),
+                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+            )

     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of training.
         """
         if args.should_save:
-            model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
-            model.pretrained_model.config.save_pretrained(args.output_dir)
-            if model.pretrained_model.can_generate():
-                model.pretrained_model.generation_config.save_pretrained(args.output_dir)
-            if getattr(model, "is_peft_model", False):
-                model.pretrained_model.save_pretrained(args.output_dir)
+            _save_model_with_valuehead(model=unwrap_model(kwargs.pop("model")), output_dir=args.output_dir)

 class LogCallback(TrainerCallback):
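
Taken together, these hunks replace two near-duplicate save blocks with the single _save_model_with_valuehead helper and pass the model through unwrap_model() so distributed wrappers (DDP, DeepSpeed) are peeled off before saving. A minimal usage sketch, assuming llmtuner is installed from this branch and trl is available; the "gpt2" checkpoint and the output path are illustrative only:

    import os
    from transformers.modeling_utils import unwrap_model
    from trl import AutoModelForCausalLMWithValueHead
    from llmtuner.extras.callbacks import _save_model_with_valuehead

    # Build a small value-head model; during training the Trainer hands
    # this to the callback via kwargs["model"].
    model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")

    # Mirror what on_save() now does: unwrap any wrapper, then save.
    _save_model_with_valuehead(
        model=unwrap_model(model),
        output_dir=os.path.join("outputs", "checkpoint-100"),
    )
    # is_peft_model is False here and gpt2 defines no custom _auto_class, so
    # only config.json and generation_config.json are written; a PEFT model
    # would additionally have its adapter weights saved.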

src/llmtuner/extras/misc.py · View File

@@ -69,11 +69,12 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:

 def get_current_device() -> str:
     import accelerate
-    dummy_accelerator = accelerate.Accelerator()
     if accelerate.utils.is_xpu_available():
-        return "xpu:{}".format(dummy_accelerator.local_process_index)
+        return "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif accelerate.utils.is_npu_available() or torch.cuda.is_available():
+        return os.environ.get("LOCAL_RANK", "0")
     else:
-        return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"
+        return "cpu"

 def get_logits_processor() -> "LogitsProcessorList":
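
The get_current_device() rewrite stops constructing a throwaway accelerate.Accelerator() just to read local_process_index and instead reads LOCAL_RANK from the environment, adding an Ascend NPU branch alongside CUDA. A small behavior sketch, assuming llmtuner is installed from this branch; the rank value is illustrative:

    import os
    from llmtuner.extras.misc import get_current_device

    # torchrun / accelerate launch set LOCAL_RANK per worker; simulate worker 1.
    os.environ["LOCAL_RANK"] = "1"

    # Returns "xpu:1" on Intel XPU, "1" when CUDA or an Ascend NPU is
    # available, and "cpu" otherwise.
    print(get_current_device())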