Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-11-04 09:52:14 +08:00

[model] fix kv cache (#7564)

This commit is contained in:
parent a13b1bb49a
commit 2bfcad2394

@@ -204,7 +204,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

 [23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [examples](examples/README.md) for usage.

-[23/07/31] We supported **dataset streaming**. Try `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode. Use `dataset_shards` to enable parallel preprocessing with streaming.
+[23/07/31] We supported **dataset streaming**. Try `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode.

 [23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.
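The streaming how-to above maps onto the Hugging Face `datasets` API. A minimal sketch, assuming the `datasets` library and an Alpaca-style JSON file (the path is illustrative):

```python
from datasets import load_dataset

# streaming=True yields examples lazily instead of materializing the full dataset,
# which is why max_steps (rather than epochs) bounds the length of training.
stream = load_dataset("json", data_files="data/alpaca_en_demo.json", split="train", streaming=True)
for i, example in enumerate(stream):
    print(sorted(example.keys()))
    if i >= 2:
        break
```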

@@ -412,7 +412,7 @@ huggingface-cli login

 | CUDA         | 11.6    | 12.2      |
 | deepspeed    | 0.10.0  | 0.16.4    |
 | bitsandbytes | 0.39.0  | 0.43.1    |
-| vllm         | 0.4.3   | 0.7.3     |
+| vllm         | 0.4.3   | 0.8.2     |
 | flash-attn   | 2.3.0   | 2.7.2     |

 ### Hardware Requirement

@@ -206,7 +206,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc

 [23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [examples](examples/README_zh.md) for usage.

-[23/07/31] We supported **dataset streaming**. Use the `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode. Use `dataset_shards` to enable multi-process loading.
+[23/07/31] We supported **dataset streaming**. Use the `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode.

 [23/07/29] We released two instruction-tuned 13B models on Hugging Face. See our Hugging Face repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.

@@ -414,7 +414,7 @@ huggingface-cli login

 | CUDA         | 11.6    | 12.2      |
 | deepspeed    | 0.10.0  | 0.16.4    |
 | bitsandbytes | 0.39.0  | 0.43.1    |
-| vllm         | 0.4.3   | 0.7.3     |
+| vllm         | 0.4.3   | 0.8.2     |
 | flash-attn   | 2.3.0   | 2.7.2     |

 ### Hardware Requirement

@@ -7,16 +7,16 @@ fsdp_config:
   fsdp_backward_prefetch: BACKWARD_PRE
   fsdp_forward_prefetch: false
   fsdp_cpu_ram_efficient_loading: true
-  fsdp_offload_params: true # offload may affect training speed
+  fsdp_offload_params: false
   fsdp_sharding_strategy: FULL_SHARD
   fsdp_state_dict_type: FULL_STATE_DICT
   fsdp_sync_module_states: true
   fsdp_use_orig_params: true
 machine_rank: 0
 main_training_function: main
-mixed_precision: bf16 # or fp16
-num_machines: 1 # the number of nodes
-num_processes: 2 # the number of GPUs in all nodes
+mixed_precision: bf16  # or fp16
+num_machines: 1  # the number of nodes
+num_processes: 2  # the number of GPUs in all nodes
 rdzv_backend: static
 same_network: true
 tpu_env: []

examples/accelerate/fsdp_config_offload.yaml (new file, 25 lines)
@@ -0,0 +1,25 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: FSDP
+downcast_bf16: 'no'
+fsdp_config:
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_backward_prefetch: BACKWARD_PRE
+  fsdp_forward_prefetch: false
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_offload_params: true  # offload may affect training speed
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_sync_module_states: true
+  fsdp_use_orig_params: true
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16  # or fp16
+num_machines: 1  # the number of nodes
+num_processes: 2  # the number of GPUs in all nodes
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
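A quick way to sanity-check the new offload variant from Python, assuming PyYAML is installed and the repository layout above:

```python
import yaml

# Load the offload variant and confirm it differs from fsdp_config.yaml
# in fsdp_offload_params (parameters are moved to CPU to save GPU memory).
with open("examples/accelerate/fsdp_config_offload.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["fsdp_config"]["fsdp_offload_params"])  # True
print(cfg["num_processes"])  # 2
```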

@@ -56,7 +56,7 @@ def vllm_infer(

     Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
     """
-    check_version("vllm>=0.4.3,<=0.7.3")
+    check_version("vllm>=0.4.3,<=0.8.2")
     if pipeline_parallel_size > get_device_count():
         raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
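The bumped upper bound can be reasoned about with the `packaging` library. This is only a sketch of what a range requirement like `vllm>=0.4.3,<=0.8.2` accepts, not LLaMA-Factory's own `check_version` implementation:

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

spec = SpecifierSet(">=0.4.3,<=0.8.2")
print(Version("0.8.2") in spec)  # True: the newly allowed upper bound
print(Version("0.8.3") in spec)  # False: newer releases still fail the check
```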

setup.py
@@ -53,7 +53,7 @@ extra_require = {
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<=0.8.1"],
+    "vllm": ["vllm>=0.4.3,<=0.8.2"],
     "sglang": ["sglang[srt]>=0.4.4", "transformers==4.48.3"],
     "galore": ["galore-torch"],
     "apollo": ["apollo-torch"],

@@ -101,12 +101,10 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=cache_dir,
             token=model_args.ms_hub_token,
-            use_streaming=data_args.streaming and not data_args.dataset_shards,  # only set to True when user specified streaming but do not want dataset to be sharded
+            use_streaming=data_args.streaming,
         )
         if isinstance(dataset, MsDataset):
             dataset = dataset.to_hf_dataset()
-        if data_args.streaming and data_args.dataset_shards:
-            dataset = dataset.to_iterable_dataset(num_shards=data_args.dataset_shards)

     elif dataset_attr.load_from == "om_hub":
         check_version("openmind>=0.8.0", mandatory=True)

@@ -135,10 +133,10 @@ def _load_single_dataset(
             token=model_args.hf_hub_token,
             num_proc=data_args.preprocessing_num_workers,
             trust_remote_code=model_args.trust_remote_code,
-            streaming=data_args.streaming and not data_args.dataset_shards,
+            streaming=data_args.streaming and dataset_attr.load_from != "file",
         )
-        if data_args.streaming and data_args.dataset_shards:
-            dataset = dataset.to_iterable_dataset(num_shards=data_args.dataset_shards)
+        if data_args.streaming and dataset_attr.load_from == "file":
+            dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)

     if dataset_attr.num_samples is not None and not data_args.streaming:
         target_num = dataset_attr.num_samples
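A rough sketch of the new behavior for local files: load the dataset eagerly, then convert it to a sharded iterable dataset so that several dataloader workers can stream it in parallel. The file path and shard count are illustrative, assuming the `datasets` library:

```python
from datasets import load_dataset

# Local files are loaded eagerly, then converted into a sharded IterableDataset;
# matching num_shards to dataloader_num_workers lets each worker stream its own shards.
dataset = load_dataset("json", data_files="data/alpaca_en_demo.json", split="train")
iterable_dataset = dataset.to_iterable_dataset(num_shards=4)
print(iterable_dataset.n_shards)  # 4
```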

@@ -1186,6 +1186,9 @@ class Qwen2OmniPlugin(BasePlugin):
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         else:
             mm_inputs = {}

+        num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
+        use_audio_in_video = getattr(processor, "use_audio_in_video", False)

@@ -1193,18 +1196,22 @@ class Qwen2OmniPlugin(BasePlugin):
         if "feature_attention_mask" in mm_inputs:
             input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
             audio_lengths = (input_lengths - 2) // 2 + 1

         if mm_inputs.get("image_grid_thw", None) is not None:
             image_grid_thw = mm_inputs["image_grid_thw"]
             merge_length = processor.omni_processor.merge_size**2

         if mm_inputs.get("video_grid_thw", None) is not None:
             video_grid_thw = mm_inputs["video_grid_thw"]
             merge_length = processor.omni_processor.merge_size**2

         if use_audio_in_video:
-            assert audio_lengths is not None, "audio_lengths should be exist when use_audio_in_video is `True`"
-            assert mm_inputs.get("video_grid_thw", None) is not None, (
-                "video_grid_thw should be exist when use_audio_in_video is `True`"
-            )
+            if audio_lengths is None:
+                raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
+
+            if not mm_inputs.get("video_grid_thw", None):
+                raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
+
             positions_list = []
             for i, message in enumerate(messages):  # get multimodal index when use_audio
                 positions = []

@@ -1216,6 +1223,7 @@ class Qwen2OmniPlugin(BasePlugin):
                             break
                         positions.append((pos, special_token))
                         start = pos + len(special_token)

                 positions_list.append(positions.sort(key=lambda x: x[0]))

         for message in messages:

@@ -1278,6 +1286,7 @@ class Qwen2OmniPlugin(BasePlugin):
                     content = content.replace(AUDIO_PLACEHOLDER, "", 1)
                     num_audio_tokens += 1
+                    num_video_tokens += 1

             message["content"] = content

         if len(audios) != num_audio_tokens:
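The switch from `assert` to explicit raises matters because assertions are stripped when Python runs with `-O`. A standalone illustration (not the plugin code itself):

```python
def check_inputs(audio_lengths, video_grid_thw):
    # An assert here would vanish under `python -O`; raising ValueError keeps
    # the validation active in optimized runs as well.
    if audio_lengths is None:
        raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
    if not video_grid_thw:
        raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")


check_inputs(audio_lengths=[12, 8], video_grid_thw=[(1, 6, 8)])  # passes silently
```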

@@ -17,6 +17,7 @@

 import gc
 import os
+import socket
 from typing import TYPE_CHECKING, Any, Literal, Union

 import torch

@@ -278,10 +279,16 @@ def use_ray() -> bool:

 def find_available_port() -> int:
     """Find an available port on the local machine."""
-    import socket
-
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     sock.bind(("", 0))
     port = sock.getsockname()[1]
     sock.close()
     return port
+
+
+def fix_proxy(ipv6_enabled: bool) -> None:
+    """Fix proxy settings for gradio ui."""
+    os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
+    if ipv6_enabled:
+        for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
+            os.environ.pop(name, None)
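A hypothetical usage sketch of the two helpers above; the `MASTER_PORT` use is an assumption about a typical caller and is not part of this commit:

```python
import os

from llamafactory.extras.misc import find_available_port, fix_proxy

os.environ["MASTER_PORT"] = str(find_available_port())  # grab a free port for rendezvous
fix_proxy(ipv6_enabled=False)  # keep localhost traffic off any configured proxy
print(os.environ["no_proxy"])  # localhost,127.0.0.1,0.0.0.0
```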

@@ -83,10 +83,6 @@ class DataArguments:
         default=None,
         metadata={"help": "The number of processes to use for the pre-processing."},
     )
-    dataset_shards: Optional[int] = field(
-        default=None,
-        metadata={"help": "The number of shards to split the dataset into. Only used in streaming mode. This should be set to the same as dataloader_num_workers. Not setting this while streaming data will cause the dataset to be non-sharded and thus only can be processed using one worker."},
-    )
     max_samples: Optional[int] = field(
         default=None,
         metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},

@@ -135,7 +135,7 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)

     if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.8.1")
+        check_version("vllm>=0.4.3,<=0.8.2")
         check_version("vllm", mandatory=True)
     elif model_args.infer_backend == EngineName.SGLANG:
         check_version("sglang>=0.4.4")

@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING
 import torch
 from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.modeling_utils import is_fsdp_enabled

 from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules

@@ -277,14 +276,14 @@ def init_adapter(

     # cast trainable parameters to float32 if:
     # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
-    # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
+    # 2. is_trainable and not pure_bf16 and not badam and not zero3 (zero3 already in fp32)
     cast_trainable_params_to_fp32 = False
     if not is_trainable:
         pass
     elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
         logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
-    elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
-        logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
+    elif model_args.quantization_bit is None and is_deepspeed_zero3_enabled():
+        logger.info_rank0("DeepSpeed ZeRO3 detected, remaining trainable params in float32.")
     else:
         logger.info_rank0("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True
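A condensed sketch of the branch above, with plain flags standing in for the finetuning and model argument dataclasses (names are illustrative):

```python
def should_upcast_to_fp32(is_trainable, pure_bf16, use_badam, quantization_bit, zero3_enabled):
    """Mirrors the branch above: upcast trainable params unless something else owns precision."""
    if not is_trainable:
        return False
    if pure_bf16 or use_badam:
        return False  # keep trainable params in half precision
    if quantization_bit is None and zero3_enabled:
        return False  # DeepSpeed ZeRO-3 already keeps master weights in float32
    return True


print(should_upcast_to_fp32(True, False, False, None, False))  # True -> cast to float32
print(should_upcast_to_fp32(True, False, False, None, True))   # False -> ZeRO-3 handles it
```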

src/llamafactory/model/model_utils/kv_cache.py (new file, 44 lines)
@@ -0,0 +1,44 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...extras import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+
+    from ...hparams import ModelArguments
+
+
+def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable:
+        setattr(config, "use_cache", model_args.use_cache)
+        if hasattr(config, "text_config"):
+            setattr(config.text_config, "use_cache", model_args.use_cache)
+
+        if model_args.use_cache:
+            logger.info_rank0("KV cache is enabled for faster generation.")
+        else:
+            logger.info_rank0("KV cache is disabled.")
+    else:
+        setattr(config, "use_cache", False)
+        if hasattr(config, "text_config"):
+            setattr(config.text_config, "use_cache", False)
+
+        logger.info_rank0("KV cache is disabled during training.")

@@ -27,6 +27,7 @@ from ..extras.packages import is_transformers_version_greater_than
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
 from .model_utils.embedding import resize_embedding_layer
+from .model_utils.kv_cache import configure_kv_cache
 from .model_utils.longlora import configure_longlora
 from .model_utils.moe import add_z3_leaf_module, configure_moe
 from .model_utils.packing import configure_packing

@@ -102,23 +103,13 @@ def patch_config(
     configure_moe(config, model_args, is_trainable)
     configure_visual_model(config)
     configure_packing(model_args, is_trainable)

-    if model_args.use_cache and not is_trainable:
-        setattr(config, "use_cache", True)
-        logger.info_rank0("Using KV cache for faster generation.")
-
-    if config.architectures[0] == "Gemma3ForConditionalGeneration" and not model_args.use_cache:
-        text_config = config.text_config
-        setattr(text_config, "use_cache", False)
+    configure_kv_cache(config, model_args, is_trainable)

     if getattr(config, "model_type", None) == "qwen":
         setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
         for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
             setattr(config, dtype_name, model_args.compute_dtype == dtype)

-    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
-        setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
-
     if getattr(config, "model_type", None) == "minicpmo":
         setattr(config, "init_audio", True)
         setattr(config, "init_tts", False)

@@ -14,10 +14,8 @@

 import os
 import platform
-import httpx

-from ..extras.misc import is_env_enabled
+from ..extras.misc import fix_proxy, is_env_enabled
 from ..extras.packages import is_gradio_available
 from .common import save_config
 from .components import (

@@ -74,8 +72,9 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":

 def create_web_demo() -> "gr.Blocks":
     engine = Engine(pure_chat=True)
+    hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]

-    with gr.Blocks(title="Web Demo", css=CSS) as demo:
+    with gr.Blocks(title=f"LLaMA Factory Web Demo ({hostname})", css=CSS) as demo:
         lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], scale=1)
         engine.manager.add_elems("top", dict(lang=lang))

@@ -90,30 +89,18 @@ def create_web_demo() -> "gr.Blocks":

 def run_web_ui() -> None:
-    os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
     gradio_ipv6 = is_env_enabled("GRADIO_IPV6")
     gradio_share = is_env_enabled("GRADIO_SHARE")
     server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
-    httpx.HTTPCORE_OPTS = {"trust_env": False}
-
-    try:
-        demo = create_ui().queue()
-        demo.launch(
-            share=gradio_share,
-            server_name=server_name,
-            inbrowser=True,
-            prevent_thread_lock=False,
-            show_error=True,
-            quiet=True,
-            favicon_path=None
-        )
-    except Exception as e:
-        print(f"Error launching web UI: {str(e)}")
-        raise
+    print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
+    fix_proxy(ipv6_enabled=gradio_ipv6)
+    create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)


 def run_web_demo() -> None:
     gradio_ipv6 = is_env_enabled("GRADIO_IPV6")
     gradio_share = is_env_enabled("GRADIO_SHARE")
     server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
+    print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
+    fix_proxy(ipv6_enabled=gradio_ipv6)
     create_web_demo().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)

@@ -14,7 +14,7 @@

 import os

-from llamafactory.extras.misc import is_env_enabled
+from llamafactory.extras.misc import fix_proxy, is_env_enabled
 from llamafactory.webui.interface import create_ui

@@ -22,6 +22,8 @@ def main():
     gradio_ipv6 = is_env_enabled("GRADIO_IPV6")
     gradio_share = is_env_enabled("GRADIO_SHARE")
     server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
+    print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
+    fix_proxy(ipv6_enabled=gradio_ipv6)
     create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
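A hypothetical launch sketch using the entry points touched above; the environment values are assumptions, not defaults taken from this commit:

```python
import os

from llamafactory.webui.interface import run_web_ui

os.environ["GRADIO_SHARE"] = "0"              # no public share link
os.environ["GRADIO_SERVER_NAME"] = "0.0.0.0"  # listen on all interfaces
run_web_ui()  # prints the visit hint, fixes proxy settings, then launches Gradio
```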