2 Commits

Author SHA1 Message Date
Username_Full
92fa3df4c4 [trainer] add dpo/kto fsdp fsdp2 support (#10127) 2026-02-04 23:27:12 +08:00
Hertz
8bedfafa4e [model] support MiniCPM-o-4.5 (#10163)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-02-04 23:21:27 +08:00
9 changed files with 38 additions and 20 deletions

View File

@@ -308,7 +308,7 @@ Read technical notes:
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video | | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 | | [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
| [MiniCPM 4](https://huggingface.co/openbmb) | 0.5B/8B | cpm4 | | [MiniCPM 4](https://huggingface.co/openbmb) | 0.5B/8B | cpm4 |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v | | [MiniCPM-o/MiniCPM-V 4.5](https://huggingface.co/openbmb) | 8B/9B | minicpm_o/minicpm_v |
| [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 | | [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 |
| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 | | [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |

View File

@@ -310,7 +310,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video | | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 | | [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
| [MiniCPM 4](https://huggingface.co/openbmb) | 0.5B/8B | cpm4 | | [MiniCPM 4](https://huggingface.co/openbmb) | 0.5B/8B | cpm4 |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v | | [MiniCPM-o/MiniCPM-V 4.5](https://huggingface.co/openbmb) | 8B/9B | minicpm_o/minicpm_v |
| [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 | | [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 |
| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 | | [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |

View File

@@ -13,14 +13,14 @@
# limitations under the License. # limitations under the License.
import time import time
from enum import Enum, unique from enum import StrEnum, unique
from typing import Any, Literal from typing import Any, Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@unique @unique
class Role(str, Enum): class Role(StrEnum):
USER = "user" USER = "user"
ASSISTANT = "assistant" ASSISTANT = "assistant"
SYSTEM = "system" SYSTEM = "system"
@@ -29,7 +29,7 @@ class Role(str, Enum):
@unique @unique
class Finish(str, Enum): class Finish(StrEnum):
STOP = "stop" STOP = "stop"
LENGTH = "length" LENGTH = "length"
TOOL = "tool_calls" TOOL = "tool_calls"

View File

@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import json import json
from enum import Enum, unique from enum import StrEnum, unique
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union
import fsspec import fsspec
@@ -35,7 +35,7 @@ SLOTS = list[Union[str, set[str], dict[str, str]]]
@unique @unique
class Role(str, Enum): class Role(StrEnum):
USER = "user" USER = "user"
ASSISTANT = "assistant" ASSISTANT = "assistant"
SYSTEM = "system" SYSTEM = "system"

View File

@@ -14,7 +14,7 @@
import os import os
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from enum import Enum, unique from enum import StrEnum, unique
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
@@ -110,7 +110,7 @@ V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors" V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
class AttentionFunction(str, Enum): class AttentionFunction(StrEnum):
AUTO = "auto" AUTO = "auto"
DISABLED = "disabled" DISABLED = "disabled"
SDPA = "sdpa" SDPA = "sdpa"
@@ -118,21 +118,21 @@ class AttentionFunction(str, Enum):
FA3 = "fa3" FA3 = "fa3"
class EngineName(str, Enum): class EngineName(StrEnum):
HF = "huggingface" HF = "huggingface"
VLLM = "vllm" VLLM = "vllm"
SGLANG = "sglang" SGLANG = "sglang"
KT = "ktransformers" KT = "ktransformers"
class DownloadSource(str, Enum): class DownloadSource(StrEnum):
DEFAULT = "hf" DEFAULT = "hf"
MODELSCOPE = "ms" MODELSCOPE = "ms"
OPENMIND = "om" OPENMIND = "om"
@unique @unique
class QuantizationMethod(str, Enum): class QuantizationMethod(StrEnum):
r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.""" r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
BNB = "bnb" BNB = "bnb"
@@ -146,7 +146,7 @@ class QuantizationMethod(str, Enum):
FP8 = "fp8" FP8 = "fp8"
class RopeScaling(str, Enum): class RopeScaling(StrEnum):
LINEAR = "linear" LINEAR = "linear"
DYNAMIC = "dynamic" DYNAMIC = "dynamic"
YARN = "yarn" YARN = "yarn"
@@ -1840,6 +1840,10 @@ register_model_group(
DownloadSource.DEFAULT: "openbmb/MiniCPM-o-2_6", DownloadSource.DEFAULT: "openbmb/MiniCPM-o-2_6",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-o-2_6", DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-o-2_6",
}, },
"MiniCPM-o-4_5": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-o-4_5",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-o-4_5",
},
}, },
template="minicpm_o", template="minicpm_o",
multimodal=True, multimodal=True,

View File

@@ -25,8 +25,8 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from transformers import Trainer from transformers import Trainer
from trl import DPOTrainer from trl import DPOTrainer
from trl.models.utils import prepare_deepspeed, prepare_fsdp
from trl.trainer import disable_dropout_in_model from trl.trainer import disable_dropout_in_model
from trl.trainer.utils import prepare_deepspeed
from typing_extensions import override from typing_extensions import override
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
@@ -97,6 +97,13 @@ class CustomDPOTrainer(DPOTrainer):
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False) getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device ): # quantized models are already set on the correct device
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
elif self.is_fsdp_enabled:
if self.accelerator.is_fsdp2:
from accelerate.utils.fsdp_utils import fsdp2_prepare_model
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
else:
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
else: else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval() self.ref_model.eval()

View File

@@ -24,8 +24,8 @@ from typing import TYPE_CHECKING, Literal, Optional, Union
import torch import torch
from transformers import Trainer from transformers import Trainer
from trl import KTOTrainer from trl import KTOTrainer
from trl.models.utils import prepare_deepspeed, prepare_fsdp
from trl.trainer import disable_dropout_in_model from trl.trainer import disable_dropout_in_model
from trl.trainer.utils import prepare_deepspeed
from typing_extensions import override from typing_extensions import override
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
@@ -99,6 +99,13 @@ class CustomKTOTrainer(KTOTrainer):
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False) getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device ): # quantized models are already set on the correct device
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
elif self.is_fsdp_enabled:
if self.accelerator.is_fsdp2:
from accelerate.utils.fsdp_utils import fsdp2_prepare_model
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
else:
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
else: else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval() self.ref_model.eval()

View File

@@ -27,7 +27,7 @@ Including:
import os import os
from collections.abc import Callable from collections.abc import Callable
from contextlib import contextmanager from contextlib import contextmanager
from enum import Enum, unique from enum import StrEnum, unique
from functools import lru_cache, wraps from functools import lru_cache, wraps
from typing import Optional from typing import Optional
@@ -39,7 +39,7 @@ from ..utils.types import ProcessGroup, Tensor, TensorLike
@unique @unique
class DeviceType(str, Enum): class DeviceType(StrEnum):
CPU = "cpu" CPU = "cpu"
CUDA = "cuda" CUDA = "cuda"
META = "meta" META = "meta"
@@ -49,7 +49,7 @@ class DeviceType(str, Enum):
@unique @unique
class ReduceOp(str, Enum): class ReduceOp(StrEnum):
SUM = "sum" SUM = "sum"
MEAN = "mean" MEAN = "mean"
MAX = "max" MAX = "max"

View File

@@ -28,7 +28,7 @@ And data parallelism types:
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
from enum import Enum from enum import StrEnum
from typing import Any, Optional from typing import Any, Optional
from torch.distributed import barrier, destroy_process_group, init_process_group from torch.distributed import barrier, destroy_process_group, init_process_group
@@ -42,7 +42,7 @@ from . import helper
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
class Dim(str, Enum): class Dim(StrEnum):
"""Dimension names.""" """Dimension names."""
MP_REPLICATE = "mp_replicate" MP_REPLICATE = "mp_replicate"