Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-07 03:12:13 +08:00)

Compare commits: 4f2f058d42 ... c4cf97d84d (2 commits)

Commits in this range:
- c4cf97d84d
- 05271756d2
@@ -66,6 +66,12 @@ EXPOSE 8000
ENV http_proxy=
ENV https_proxy=

# Set no_proxy environment variable
ENV no_proxy="localhost, 127.0.0.1, ::1"

# fix pydantic version
RUN pip install pydantic==2.10.6

# Reset pip config
RUN pip config unset global.index-url && \
    pip config unset global.extra-index-url
examples/deepspeed/ds_z3_fp8_config.json (new file, 45 lines)
@@ -0,0 +1,45 @@
{
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_clipping": "auto",
  "zero_allow_untested_optimizer": true,
  "zero_force_ds_cpu_optimizer": true,
  "fp16": {
    "enabled": false,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": false
    },
    "overlap_comm": false,
    "contiguous_gradients": true,
    "sub_group_size": 1000000000,
    "reduce_bucket_size": 12845056,
    "stage3_prefetch_bucket_size": 11560550,
    "stage3_param_persistence_threshold": 35840,
    "stage3_max_live_parameters": 1000000000,
    "stage3_max_reuse_distance": 1000000000,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "steps_per_print": 10000000,
  "gradient_accumulation_steps": "auto",
  "comms_config": {
    "verbose": false
  },
  "monitor_config": {
    "enabled": true,
    "tag": "DeepSpeedMonitor",
    "csv_monitor": {
      "enabled": false
    }
  }
}
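Note on the config above: precision is deliberately left to the trainer (fp16 disabled, bf16 "auto") because the FP8 conversion is applied by HuggingFace Accelerate, not by DeepSpeed; ZeRO-3 only handles sharding and optimizer offload. A small illustrative check, assuming the file is read from a repo checkout (not part of the diff):

# Illustrative sanity check, not part of the diff: the DeepSpeed config defers precision to
# the training arguments and relies on ZeRO-3 for sharding; FP8 itself comes from Accelerate.
import json

with open("examples/deepspeed/ds_z3_fp8_config.json") as f:
    ds_config = json.load(f)

assert ds_config["zero_optimization"]["stage"] == 3
assert ds_config["fp16"]["enabled"] is False
assert ds_config["bf16"]["enabled"] == "auto"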
examples/extras/fp8/llama3_fp8_deepspeed_sft.yaml (new file, 48 lines)
@@ -0,0 +1,48 @@
# FP8 training example with DeepSpeed ZeRO-3
# This config demonstrates FP8 mixed precision training using HuggingFace Accelerate
# with DeepSpeed providing memory optimization (not FP8 handling)

### Model configuration
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true

### Method configuration
stage: sft
do_train: true
finetuning_type: full

### Dataset configuration
dataset: identity
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### Output configuration
output_dir: saves/llama3-8b/fp8-deepspeed/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### Training configuration
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 5.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true

### FP8 configuration
fp8: true
fp8_backend: torchao  # Use TorchAO backend for FP8
fp8_enable_fsdp_float8_all_gather: false  # Not used with DeepSpeed

### DeepSpeed configuration
deepspeed: examples/deepspeed/ds_z3_fp8_config.json

### Logging configuration
report_to: wandb
run_name: llama3_fp8_deepspeed_sft
examples/extras/fp8/llama3_fp8_fsdp_sft.yaml (new file, 51 lines)
@@ -0,0 +1,51 @@
# FP8 training example with FSDP
# This config demonstrates FP8 mixed precision training using HuggingFace Accelerate
# with FSDP for distributed training and float8 all-gather optimization

### Model configuration
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true

### Method configuration
stage: sft
do_train: true
finetuning_type: full

### Dataset configuration
dataset: identity
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### Output configuration
output_dir: saves/llama3-8b/fp8-fsdp/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### Training configuration
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 5.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true

### FP8 configuration
fp8: true
fp8_backend: torchao  # Use TorchAO backend for FP8
fp8_enable_fsdp_float8_all_gather: true  # Enable FSDP2 float8 all-gather optimization

### FSDP configuration (using training arguments - no separate FSDP config file)
fsdp:
  - full_shard
  - auto_wrap
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer

### Logging configuration
report_to: wandb
run_name: llama3_fp8_fsdp_sft
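Both recipes are launched like any other LLaMA-Factory YAML. A minimal launch sketch, assuming a source checkout with the FP8 extras installed and Hopper-class GPUs; it simply shells out to the llamafactory-cli entry point, with FORCE_TORCHRUN=1 asking the CLI to launch via torchrun for multi-GPU runs:

# Minimal launch sketch, not part of the diff. Swap in llama3_fp8_fsdp_sft.yaml for the
# FSDP variant; FORCE_TORCHRUN=1 makes llamafactory-cli launch training through torchrun.
import os
import subprocess

env = dict(os.environ, FORCE_TORCHRUN="1")
subprocess.run(
    ["llamafactory-cli", "train", "examples/extras/fp8/llama3_fp8_deepspeed_sft.yaml"],
    env=env,
    check=True,
)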
setup.py (3 additions)
@@ -70,6 +70,9 @@ extra_require = {
    ],
    "openmind": ["openmind"],
    "swanlab": ["swanlab"],
    "fp8": ["torchao>=0.8.0", "accelerate>=1.10.0"],
    "fp8-te": ["transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
    "fp8-all": ["torchao>=0.8.0", "transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
    "dev": ["pre-commit", "ruff", "pytest", "build"],
}
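These extras keep the FP8 dependencies opt-in. A hedged install sketch from a source checkout, equivalent to running pip install -e ".[fp8]" on the command line; use the "fp8-te" or "fp8-all" extra when the Transformer Engine backend is wanted:

# Hedged install sketch, not part of the diff: pull in the torchao-based FP8 extra from a
# source checkout; replace "fp8" with "fp8-te" or "fp8-all" for Transformer Engine support.
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", ".[fp8]"])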
@@ -213,6 +213,23 @@ class QuantizationArguments:
        default=None,
        metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
    )
    fp8: bool = field(
        default=False,
        metadata={
            "help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
            "Requires PyTorch 2.7+ and Hopper architecture GPUs."
        },
    )
    fp8_backend: str = field(
        default="auto",
        metadata={
            "help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
        },
    )
    fp8_enable_fsdp_float8_all_gather: bool = field(
        default=False,
        metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
    )


@dataclass
@@ -131,6 +131,14 @@ def _verify_model_args(
        logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
        model_args.use_fast_tokenizer = False

    # Validate advanced training features
    if model_args.fp8 and model_args.quantization_bit is not None:
        raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")

    if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
        logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
        model_args.fp8 = True


def _check_extra_dependencies(
    model_args: "ModelArguments",
src/llamafactory/train/fp8_utils.py (new file, 171 lines)
@@ -0,0 +1,171 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Optional

from ..extras import logging


if TYPE_CHECKING:
    from ..hparams import ModelArguments

logger = logging.get_logger(__name__)


def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
    """Create AORecipeKwargs for FP8 training with HuggingFace Accelerate.

    Args:
        model_args: Model arguments containing FP8 configuration

    Returns:
        List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise
    """
    if not model_args.fp8:
        return []

    try:
        # Check if AORecipeKwargs is available (Accelerate 1.8.0+)
        from accelerate.utils import AORecipeKwargs

        backend = getattr(model_args, "fp8_backend", "auto")
        logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")

        # Create Float8LinearConfig if torchao backend is used
        config = None
        if backend == "torchao" or backend == "auto":
            from torchao.float8 import Float8LinearConfig

            # Use rowwise scaling for better performance (as recommended by torchao)
            # Configure alignment requirements for FP8 kernels
            config = Float8LinearConfig.from_recipe_name("rowwise")

            # Enable alignment for better kernel performance
            if hasattr(config, "enable_amax_init"):
                config.enable_amax_init = True
            if hasattr(config, "enable_pre_and_post_forward"):
                config.enable_pre_and_post_forward = True

        # Create module filter function to skip problematic layers
        # TorchAO FP8 requires dimensions divisible by 16 for optimal kernels
        def module_filter_func(module, layer_name):
            # Skip embedding and output layers for numerical stability
            skip_layers = ["embed", "lm_head", "output", "classifier"]
            if any(skip_name in layer_name.lower() for skip_name in skip_layers):
                return False

            # Only convert Linear layers
            if not (hasattr(module, "weight") and len(module.weight.shape) == 2):
                return False

            # Check dimension alignment for FP8 kernels
            weight = module.weight
            in_features, out_features = weight.shape[1], weight.shape[0]

            # Skip layers with dimensions not divisible by 16 to avoid kernel errors
            if in_features % 16 != 0 or out_features % 16 != 0:
                logger.debug(
                    f"Skipping layer {layer_name} with dimensions {out_features}x{in_features} (not divisible by 16)"
                )
                return False

            return True

        # Map FSDP all-gather setting if available (this affects the underlying implementation)
        if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
            logger.info_rank0("FSDP float8 all-gather optimization requested")

        return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
    except Exception as e:
        logger.info_rank0(f"Failed to create FP8 configuration: {e}")
        return []


def get_fp8_mixed_precision(model_args: "ModelArguments") -> Optional[str]:
    """Get the mixed precision setting for Accelerate when using FP8.

    Args:
        model_args: Model arguments containing FP8 configuration

    Returns:
        "fp8" if FP8 is enabled, None otherwise
    """
    return "fp8" if model_args.fp8 else None


def configure_fp8_environment(model_args: "ModelArguments") -> None:
    """Configure FP8 environment for HuggingFace Accelerate.

    FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
    DeepSpeed or FSDP is used for distributed training. This function sets up the environment
    variables and validates the FP8 configuration.

    Args:
        model_args: Model arguments containing FP8 configuration
    """
    import os

    if not model_args.fp8:
        return

    # Set mixed precision to fp8 for HuggingFace Accelerate
    os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
    logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")

    # Configure FP8 backend and options
    backend = getattr(model_args, "fp8_backend", "auto")
    if backend != "auto":
        os.environ["FP8_BACKEND"] = backend
        logger.info_rank0(f"Set FP8_BACKEND={backend}")

    # Create and validate FP8 recipe kwargs (for logging/debugging)
    fp8_kwargs = create_fp8_kwargs(model_args)
    logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items")

    # Enable FSDP float8 all-gather optimization if requested
    if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
        os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
        logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")

    logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")


def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
    """Verify that FP8 training is actually working after model preparation.

    Args:
        accelerator: The HuggingFace Accelerator instance
        model_args: Model arguments containing FP8 configuration
    """
    if not model_args.fp8:
        return

    # Check Accelerate's FP8 status
    fp8_enabled = getattr(accelerator, "fp8_enabled", False)
    fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")

    backend = getattr(model_args, "fp8_backend", "auto")
    if backend == "torchao" or backend == "auto":
        logger.info_rank0(
            "FP8 training enabled with TorchAO backend. For optimal performance, "
            "ensure model layer dimensions are mostly divisible by 16. "
            "If you encounter issues, try fp8_backend='te' with Transformer Engine."
        )
    else:
        logger.info_rank0(f"FP8 training enabled with {backend} backend.")

    logger.info_rank0(f"Accelerate FP8 status - enabled: {fp8_enabled}, backend: {fp8_backend_type}")

    if not fp8_enabled:
        logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")
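To show how these helpers are meant to compose, here is a standalone sketch (not how the Trainer wires them internally): the stand-in arguments object is an assumption for illustration, and accelerate>=1.8, torchao, and a Hopper-class GPU are needed for the FP8 path to actually activate.

# Standalone sketch, not part of the diff: feed the helpers above into an Accelerator.
# _DemoArgs is a stand-in for llamafactory.hparams.ModelArguments, used only for illustration.
from dataclasses import dataclass

from accelerate import Accelerator

from llamafactory.train.fp8_utils import create_fp8_kwargs, get_fp8_mixed_precision


@dataclass
class _DemoArgs:
    fp8: bool = True
    fp8_backend: str = "torchao"
    fp8_enable_fsdp_float8_all_gather: bool = False


model_args = _DemoArgs()
accelerator = Accelerator(
    mixed_precision=get_fp8_mixed_precision(model_args),  # returns "fp8" because fp8=True
    kwargs_handlers=create_fp8_kwargs(model_args),  # [AORecipeKwargs(...)], or [] on failure
)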
@@ -21,21 +21,29 @@ from typing_extensions import override

from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
+from ..fp8_utils import configure_fp8_environment, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
    from transformers import ProcessorMixin

-    from ...hparams import FinetuningArguments
+    from ...hparams import FinetuningArguments, ModelArguments


class CustomTrainer(Trainer):
    r"""Inherit Trainer for custom optimizer."""

    def __init__(
-        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
+        self,
+        finetuning_args: "FinetuningArguments",
+        processor: Optional["ProcessorMixin"],
+        model_args: Optional["ModelArguments"] = None,
+        **kwargs,
    ) -> None:
+        # Configure FP8 environment if enabled
+        if model_args is not None and model_args.fp8:
+            configure_fp8_environment(model_args)
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")

@@ -56,6 +64,10 @@ class CustomTrainer(Trainer):
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

+        # Verify FP8 status after trainer initialization (accelerator should be available)
+        if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
+            verify_fp8_status(self.accelerator, model_args)
+
    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
@@ -29,6 +29,7 @@ from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
+from ..fp8_utils import configure_fp8_environment, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

@@ -37,7 +38,7 @@ if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin
    from transformers.trainer import PredictionOutput

-    from ...hparams import FinetuningArguments
+    from ...hparams import FinetuningArguments, ModelArguments


logger = logging.get_logger(__name__)

@@ -50,9 +51,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
        self,
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
+        model_args: Optional["ModelArguments"] = None,
        gen_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> None:
+        # Configure FP8 environment if enabled
+        if model_args is not None and model_args.fp8:
+            configure_fp8_environment(model_args)
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")
        else:

@@ -83,6 +88,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):

            self.compute_loss_func = dft_loss_func

+        # Verify FP8 status after trainer initialization (accelerator should be available)
+        if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
+            verify_fp8_status(self.accelerator, model_args)
+
    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None: