mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 14:22:51 +08:00)
parent ea2eb8e58c
commit 3e3652f4e0
@@ -28,7 +28,7 @@ class AverageMeter:
         self.avg = self.sum / self.count


-# Avoid runtime error in model.generate(do_sample=True).
+# Avoids runtime error in model.generate(do_sample=True).
 class InvalidScoreLogitsProcessor(LogitsProcessor):

     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
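Note: this hunk only shows the comment and the class header, not the body of InvalidScoreLogitsProcessor. For context, a minimal sketch of the NaN/Inf guard such a processor typically applies is shown below; the replacement values are illustrative assumptions, not the repository's exact code.

import torch
from transformers import LogitsProcessor


class InvalidScoreLogitsProcessor(LogitsProcessor):

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # If the model emitted NaN/Inf logits, fall back to a valid distribution so that
        # torch.multinomial inside model.generate(do_sample=True) cannot raise at runtime.
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 0] = 1.0  # arbitrary valid token id; the real code may choose differently
        return scores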
@@ -63,7 +63,7 @@ def print_trainable_params(model: torch.nn.Module) -> None:
 def prepare_model_for_training(
     model: PreTrainedModel,
     finetuning_type: str,
-    output_embedding_layer_name: Optional[str] = "lm_head",
+    output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
     layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
 ) -> PreTrainedModel:
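The only API change in this hunk is the keyword rename from output_embedding_layer_name to output_layer_name; the default "lm_head" and the other parameters are untouched. A hypothetical call site under the new signature (the model and finetuning_type values are placeholders):

model = prepare_model_for_training(
    model,                            # any PreTrainedModel
    finetuning_type="lora",           # placeholder value
    output_layer_name="lm_head",      # was: output_embedding_layer_name="lm_head"
    use_gradient_checkpointing=True,  # default shown for clarity
)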
@@ -83,19 +83,23 @@ def prepare_model_for_training(
         model.gradient_checkpointing_enable()
         model.config.use_cache = False # turn off when gradient checkpointing is enabled

-    if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
-        output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
-        input_dtype = output_embedding_layer.weight.dtype
+    if finetuning_type != "full" and hasattr(model, output_layer_name):
+        output_layer: torch.nn.Linear = getattr(model, output_layer_name)
+        input_dtype = output_layer.weight.dtype

         class CastOutputToFloat(torch.nn.Sequential):

             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return super().forward(x.to(input_dtype)).to(torch.float32)

-        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
+        new_output_layer = CastOutputToFloat(output_layer)
+        # adapt to LLaMA-2's pretraining_tp (actually LLaMA models can automatically do casting but BLOOM models cannot)
+        # (https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/llama/modeling_llama.py#L819)
+        setattr(new_output_layer, "weight", output_layer.weight)
+        setattr(model, output_layer_name, new_output_layer)
     return model


 def torch_gc() -> None:
     r"""
     Collects GPU memory.
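Why the extra setattr matters: torch.nn.Sequential does not expose the weight of the module it wraps, while the LLaMA-2 pretraining_tp code path linked in the new comment accesses lm_head.weight directly, so the commit re-attaches the original parameter to the CastOutputToFloat wrapper. A self-contained sketch of the effect, using a plain Linear layer as a stand-in (illustration only, not repository code):

import torch

lm_head = torch.nn.Linear(8, 4, bias=False)   # stand-in for a model's output layer
wrapped = torch.nn.Sequential(lm_head)        # analogous to CastOutputToFloat(output_layer)

print(hasattr(wrapped, "weight"))             # False: code reading lm_head.weight would break
setattr(wrapped, "weight", lm_head.weight)    # the line added in this commit
print(hasattr(wrapped, "weight"))             # True
print(wrapped.weight is lm_head.weight)       # True: the parameter is shared, not copied

Because the parameter is shared rather than copied, the cast wrapper stays transparent to any code that touches the weight tensor directly.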