diff --git a/README.md b/README.md index 30d1b9f9..d7e3fbca 100644 --- a/README.md +++ b/README.md @@ -204,7 +204,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ [23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [examples](examples/README.md) for usage. -[23/07/31] We supported **dataset streaming**. Try `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode. Use `dataset_shards` to enable parallel preprocessing with streaming. +[23/07/31] We supported **dataset streaming**. Try `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode. [23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details. @@ -412,7 +412,7 @@ huggingface-cli login | CUDA | 11.6 | 12.2 | | deepspeed | 0.10.0 | 0.16.4 | | bitsandbytes | 0.39.0 | 0.43.1 | -| vllm | 0.4.3 | 0.7.3 | +| vllm | 0.4.3 | 0.8.2 | | flash-attn | 2.3.0 | 2.7.2 | ### Hardware Requirement diff --git a/README_zh.md b/README_zh.md index d6643906..5c64dab7 100644 --- a/README_zh.md +++ b/README_zh.md @@ -206,7 +206,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc [23/08/11] 我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。详细用法请参照 [examples](examples/README_zh.md)。 -[23/07/31] 我们支持了**数据流式加载**。请使用 `streaming: true` 和 `max_steps: 10000` 参数来流式加载数据集。 用 `dataset_shards` 来开启多进程加载。 +[23/07/31] 我们支持了**数据流式加载**。请使用 `streaming: true` 和 `max_steps: 10000` 参数来流式加载数据集。 [23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft))。 @@ -414,7 +414,7 @@ huggingface-cli login | CUDA | 11.6 | 12.2 | | deepspeed | 0.10.0 | 0.16.4 | | bitsandbytes | 0.39.0 | 0.43.1 | -| vllm | 0.4.3 | 0.7.3 | +| vllm | 0.4.3 | 0.8.2 | | flash-attn | 2.3.0 | 2.7.2 | ### 硬件依赖 diff --git a/examples/accelerate/fsdp_config.yaml b/examples/accelerate/fsdp_config.yaml index 6fb09a95..09d2f5d7 100644 --- a/examples/accelerate/fsdp_config.yaml +++ b/examples/accelerate/fsdp_config.yaml @@ -7,16 +7,16 @@ fsdp_config: fsdp_backward_prefetch: BACKWARD_PRE fsdp_forward_prefetch: false fsdp_cpu_ram_efficient_loading: true - fsdp_offload_params: true # offload may affect training speed + fsdp_offload_params: false fsdp_sharding_strategy: FULL_SHARD fsdp_state_dict_type: FULL_STATE_DICT fsdp_sync_module_states: true fsdp_use_orig_params: true machine_rank: 0 main_training_function: main -mixed_precision: bf16 # or fp16 -num_machines: 1 # the number of nodes -num_processes: 2 # the number of GPUs in all nodes +mixed_precision: bf16 # or fp16 +num_machines: 1 # the number of nodes +num_processes: 2 # the number of GPUs in all nodes rdzv_backend: static same_network: true tpu_env: [] diff --git a/examples/accelerate/fsdp_config_offload.yaml b/examples/accelerate/fsdp_config_offload.yaml new file mode 100644 index 00000000..a55e652e --- /dev/null +++ b/examples/accelerate/fsdp_config_offload.yaml @@ -0,0 +1,25 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +fsdp_config: + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: BACKWARD_PRE + fsdp_forward_prefetch: false + fsdp_cpu_ram_efficient_loading: true + 
fsdp_offload_params: true # offload may affect training speed + fsdp_sharding_strategy: FULL_SHARD + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sync_module_states: true + fsdp_use_orig_params: true +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 # or fp16 +num_machines: 1 # the number of nodes +num_processes: 2 # the number of GPUs in all nodes +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/scripts/vllm_infer.py b/scripts/vllm_infer.py index 24334911..dceb1d31 100644 --- a/scripts/vllm_infer.py +++ b/scripts/vllm_infer.py @@ -56,7 +56,7 @@ def vllm_infer( Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo """ - check_version("vllm>=0.4.3,<=0.7.3") + check_version("vllm>=0.4.3,<=0.8.2") if pipeline_parallel_size > get_device_count(): raise ValueError("Pipeline parallel size should be smaller than the number of gpus.") diff --git a/setup.py b/setup.py index 9053d483..6cfe66a0 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,7 @@ extra_require = { "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"], "awq": ["autoawq"], "aqlm": ["aqlm[gpu]>=1.1.0"], - "vllm": ["vllm>=0.4.3,<=0.8.1"], + "vllm": ["vllm>=0.4.3,<=0.8.2"], "sglang": ["sglang[srt]>=0.4.4", "transformers==4.48.3"], "galore": ["galore-torch"], "apollo": ["apollo-torch"], diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 78fa1192..9bfc55e3 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -101,12 +101,10 @@ def _load_single_dataset( split=dataset_attr.split, cache_dir=cache_dir, token=model_args.ms_hub_token, - use_streaming=data_args.streaming and not data_args.dataset_shards, # only set to True when user specified streaming but do not want dataset to be sharded + use_streaming=data_args.streaming, ) if isinstance(dataset, MsDataset): dataset = dataset.to_hf_dataset() - if data_args.streaming and data_args.dataset_shards: - dataset = dataset.to_iterable_dataset(num_shards=data_args.dataset_shards) elif dataset_attr.load_from == "om_hub": check_version("openmind>=0.8.0", mandatory=True) @@ -135,10 +133,10 @@ def _load_single_dataset( token=model_args.hf_hub_token, num_proc=data_args.preprocessing_num_workers, trust_remote_code=model_args.trust_remote_code, - streaming=data_args.streaming and not data_args.dataset_shards, + streaming=data_args.streaming and dataset_attr.load_from != "file", ) - if data_args.streaming and data_args.dataset_shards: - dataset = dataset.to_iterable_dataset(num_shards=data_args.dataset_shards) + if data_args.streaming and dataset_attr.load_from == "file": + dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers) if dataset_attr.num_samples is not None and not data_args.streaming: target_num = dataset_attr.num_samples diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index b6928636..b9180db2 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -1186,6 +1186,9 @@ class Qwen2OmniPlugin(BasePlugin): messages = deepcopy(messages) if self.expand_mm_tokens: mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + else: + mm_inputs = {} + num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0 use_audio_in_video = getattr(processor, "use_audio_in_video", False) @@ -1193,18 +1196,22 @@ class Qwen2OmniPlugin(BasePlugin): if "feature_attention_mask" in mm_inputs: 
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1 audio_lengths = (input_lengths - 2) // 2 + 1 + if mm_inputs.get("image_grid_thw", None) is not None: image_grid_thw = mm_inputs["image_grid_thw"] merge_length = processor.omni_processor.merge_size**2 + if mm_inputs.get("video_grid_thw", None) is not None: video_grid_thw = mm_inputs["video_grid_thw"] merge_length = processor.omni_processor.merge_size**2 if use_audio_in_video: - assert audio_lengths is not None, "audio_lengths should be exist when use_audio_in_video is `True`" - assert mm_inputs.get("video_grid_thw", None) is not None, ( - "video_grid_thw should be exist when use_audio_in_video is `True`" - ) + if audio_lengths is None: + raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.") + + if not mm_inputs.get("video_grid_thw", None): + raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.") + positions_list = [] for i, message in enumerate(messages): # get multimodal index when use_audio positions = [] @@ -1216,6 +1223,7 @@ class Qwen2OmniPlugin(BasePlugin): break positions.append((pos, special_token)) start = pos + len(special_token) + positions_list.append(positions.sort(key=lambda x: x[0])) for message in messages: @@ -1278,6 +1286,7 @@ class Qwen2OmniPlugin(BasePlugin): content = content.replace(AUDIO_PLACEHOLDER, "", 1) num_audio_tokens += 1 num_video_tokens += 1 + message["content"] = content if len(audios) != num_audio_tokens: diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index b87829ea..22fa490a 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -17,6 +17,7 @@ import gc import os +import socket from typing import TYPE_CHECKING, Any, Literal, Union import torch @@ -278,10 +279,16 @@ def use_ray() -> bool: def find_available_port() -> int: """Find an available port on the local machine.""" - import socket - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.bind(("", 0)) port = sock.getsockname()[1] sock.close() return port + + +def fix_proxy(ipv6_enabled: bool) -> None: + """Fix proxy settings for gradio ui.""" + os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0" + if ipv6_enabled: + for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(name, None) diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 3a51142b..3a66b2c0 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -83,10 +83,6 @@ class DataArguments: default=None, metadata={"help": "The number of processes to use for the pre-processing."}, ) - dataset_shards: Optional[int] = field( - default=None, - metadata={"help": "The number of shards to split the dataset into. Only used in streaming mode. This should be set to the same as dataloader_num_workers. 
Not setting this while streaming data will cause the dataset to be non-sharded and thus only can be processed using one worker."}, - ) max_samples: Optional[int] = field( default=None, metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}, diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 66bea4f3..f23ba89f 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -135,7 +135,7 @@ def _check_extra_dependencies( check_version("mixture-of-depth>=1.1.6", mandatory=True) if model_args.infer_backend == EngineName.VLLM: - check_version("vllm>=0.4.3,<=0.8.1") + check_version("vllm>=0.4.3,<=0.8.2") check_version("vllm", mandatory=True) elif model_args.infer_backend == EngineName.SGLANG: check_version("sglang>=0.4.4") diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py index 7d3ed389..bbc0056c 100644 --- a/src/llamafactory/model/adapter.py +++ b/src/llamafactory/model/adapter.py @@ -18,7 +18,6 @@ from typing import TYPE_CHECKING import torch from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model from transformers.integrations import is_deepspeed_zero3_enabled -from transformers.modeling_utils import is_fsdp_enabled from ..extras import logging from .model_utils.misc import find_all_linear_modules, find_expanded_modules @@ -277,14 +276,14 @@ def init_adapter( # cast trainable parameters to float32 if: # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora) - # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32) + # 2. is_trainable and not pure_bf16 and not badam and not zero3 (zero3 already in fp32) cast_trainable_params_to_fp32 = False if not is_trainable: pass elif finetuning_args.pure_bf16 or finetuning_args.use_badam: logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.") - elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()): - logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.") + elif model_args.quantization_bit is None and is_deepspeed_zero3_enabled(): + logger.info_rank0("DeepSpeed ZeRO3 detected, remaining trainable params in float32.") else: logger.info_rank0("Upcasting trainable params to float32.") cast_trainable_params_to_fp32 = True diff --git a/src/llamafactory/model/model_utils/kv_cache.py b/src/llamafactory/model/model_utils/kv_cache.py new file mode 100644 index 00000000..cd2c119f --- /dev/null +++ b/src/llamafactory/model/model_utils/kv_cache.py @@ -0,0 +1,44 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING + +from ...extras import logging + + +logger = logging.get_logger(__name__) + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from ...hparams import ModelArguments + + +def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: + if not is_trainable: + setattr(config, "use_cache", model_args.use_cache) + if hasattr(config, "text_config"): + setattr(config.text_config, "use_cache", model_args.use_cache) + + if model_args.use_cache: + logger.info_rank0("KV cache is enabled for faster generation.") + else: + logger.info_rank0("KV cache is disabled.") + else: + setattr(config, "use_cache", False) + if hasattr(config, "text_config"): + setattr(config.text_config, "use_cache", False) + + logger.info_rank0("KV cache is disabled during training.") diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index a1142a30..6b690f40 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -27,6 +27,7 @@ from ..extras.packages import is_transformers_version_greater_than from .model_utils.attention import configure_attn_implementation, print_attn_implementation from .model_utils.checkpointing import prepare_model_for_training from .model_utils.embedding import resize_embedding_layer +from .model_utils.kv_cache import configure_kv_cache from .model_utils.longlora import configure_longlora from .model_utils.moe import add_z3_leaf_module, configure_moe from .model_utils.packing import configure_packing @@ -102,23 +103,13 @@ def patch_config( configure_moe(config, model_args, is_trainable) configure_visual_model(config) configure_packing(model_args, is_trainable) - - if model_args.use_cache and not is_trainable: - setattr(config, "use_cache", True) - logger.info_rank0("Using KV cache for faster generation.") - - if config.architectures[0] == "Gemma3ForConditionalGeneration" and not model_args.use_cache: - text_config = config.text_config - setattr(text_config, "use_cache", False) + configure_kv_cache(config, model_args, is_trainable) if getattr(config, "model_type", None) == "qwen": setattr(config, "use_flash_attn", model_args.flash_attn == "fa2") for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]: setattr(config, dtype_name, model_args.compute_dtype == dtype) - if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2": - setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn - if getattr(config, "model_type", None) == "minicpmo": setattr(config, "init_audio", True) setattr(config, "init_tts", False) diff --git a/src/llamafactory/webui/interface.py b/src/llamafactory/webui/interface.py index d8932ac0..691a88a3 100644 --- a/src/llamafactory/webui/interface.py +++ b/src/llamafactory/webui/interface.py @@ -14,10 +14,8 @@ import os import platform -import httpx - -from ..extras.misc import is_env_enabled +from ..extras.misc import fix_proxy, is_env_enabled from ..extras.packages import is_gradio_available from .common import save_config from .components import ( @@ -74,8 +72,9 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks": def create_web_demo() -> "gr.Blocks": engine = Engine(pure_chat=True) + hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0] - with gr.Blocks(title="Web Demo", css=CSS) as demo: + with gr.Blocks(title=f"LLaMA Factory Web Demo ({hostname})", css=CSS) 
as demo: lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], scale=1) engine.manager.add_elems("top", dict(lang=lang)) @@ -90,30 +89,18 @@ def create_web_demo() -> "gr.Blocks": def run_web_ui() -> None: - os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0" gradio_ipv6 = is_env_enabled("GRADIO_IPV6") gradio_share = is_env_enabled("GRADIO_SHARE") server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0") - httpx.HTTPCORE_OPTS = {"trust_env": False} - - try: - demo = create_ui().queue() - demo.launch( - share=gradio_share, - server_name=server_name, - inbrowser=True, - prevent_thread_lock=False, - show_error=True, - quiet=True, - favicon_path=None - ) - except Exception as e: - print(f"Error launching web UI: {str(e)}") - raise + print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860") + fix_proxy(ipv6_enabled=gradio_ipv6) + create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True) def run_web_demo() -> None: gradio_ipv6 = is_env_enabled("GRADIO_IPV6") gradio_share = is_env_enabled("GRADIO_SHARE") server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0") + print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860") + fix_proxy(ipv6_enabled=gradio_ipv6) create_web_demo().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True) diff --git a/src/webui.py b/src/webui.py index 088f2365..f13d2f26 100644 --- a/src/webui.py +++ b/src/webui.py @@ -14,7 +14,7 @@ import os -from llamafactory.extras.misc import is_env_enabled +from llamafactory.extras.misc import fix_proxy, is_env_enabled from llamafactory.webui.interface import create_ui @@ -22,6 +22,8 @@ def main(): gradio_ipv6 = is_env_enabled("GRADIO_IPV6") gradio_share = is_env_enabled("GRADIO_SHARE") server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0") + print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860") + fix_proxy(ipv6_enabled=gradio_ipv6) create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
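
The loader.py hunk above drops the user-facing `dataset_shards` option: when `streaming: true` is set for a local file dataset, the file is now loaded eagerly and converted with `to_iterable_dataset(num_shards=training_args.dataloader_num_workers)`, so the shard count automatically matches the number of DataLoader workers. A minimal sketch of why that sharding matters, outside the patch; the JSON path, worker count, and batch size below are illustrative, not values used by the repository.

    from datasets import load_dataset
    from torch.utils.data import DataLoader

    num_workers = 4  # illustrative; the patch uses training_args.dataloader_num_workers

    # Load the local file eagerly, then convert it into a sharded IterableDataset.
    dataset = load_dataset("json", data_files="data/alpaca_en_demo.json", split="train")
    iterable_dataset = dataset.to_iterable_dataset(num_shards=num_workers)

    # With num_shards >= num_workers, each DataLoader worker streams its own subset
    # of shards; a single-shard iterable dataset would leave the extra workers idle,
    # which is the behavior the removed dataset_shards flag used to guard against.
    loader = DataLoader(iterable_dataset, batch_size=8, num_workers=num_workers)
    for batch in loader:
        ...  # preprocessing / training step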