mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-28 01:30:36 +08:00

Compare commits: 10a446e373 ... 45f0437a14

2 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 45f0437a14 |  |
|  | d4e120423d |  |
@@ -14,9 +14,13 @@
 from typing import TYPE_CHECKING, Union

+import torch
+from torch import nn
+from torch.nn import functional as F
 from transformers.integrations import is_deepspeed_zero3_enabled

 from ...extras.misc import check_version
+from ...extras.packages import is_transformers_version_greater_than


 if TYPE_CHECKING:
@@ -25,6 +29,9 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


+if is_transformers_version_greater_than("4.57.0"):
+    from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe
+
+
 def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list[Union["nn.Module", str]]) -> None:
     check_version("deepspeed>=0.13.0")
@@ -175,3 +182,66 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
     elif model_type == "jetmoe":
         setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
+
+
+class Qwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.norm_topk_prob = config.norm_topk_prob
+
+        # gating
+        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
+        self.experts = nn.ModuleList(
+            [
+                modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextMLP(
+                    config, intermediate_size=config.moe_intermediate_size
+                )
+                for _ in range(self.num_experts)
+            ]
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        # calculate the routing weights for all experts
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+
+        # retain the top_k weights and set the weights of all other experts to 0
+        # (instead of keeping only the top_k experts themselves)
+        top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
+        # initialize an all-zero weight matrix with the same shape as the full routing weights
+        full_routing_weights = torch.zeros_like(routing_weights)
+        # keep only the top_k expert weights; every other expert stays at 0
+        full_routing_weights.scatter_(1, top_k_indices, top_k_weights)
+
+        # normalize the top_k weights (consistent with the original routing logic)
+        if self.norm_topk_prob:
+            # sum of the top_k weights in each row (used for normalization)
+            top_k_sum = full_routing_weights.sum(dim=-1, keepdim=True)
+            # avoid division by zero
+            top_k_sum = torch.clamp(top_k_sum, min=1e-9)
+            full_routing_weights /= top_k_sum
+
+        # cast back to the input dtype
+        full_routing_weights = full_routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+        )
+
+        # iterate over all experts (not only the selected ones)
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            # weight of the current expert (inactive experts have a weight of 0 here)
+            expert_weights = full_routing_weights[:, expert_idx, None]  # shape: (batch*seq, 1)
+            # every token goes through the current expert; its weight may be 0
+            current_hidden_states = expert_layer(hidden_states) * expert_weights
+            # accumulate all expert outputs (experts with a weight of 0 do not affect the result)
+            final_hidden_states += current_hidden_states
+
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
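The block above deliberately runs every expert on every token and relies on the zeroed routing weights to cancel the unselected experts, instead of dispatching each token only to its top_k experts. A minimal self-contained sketch (illustration only, not part of this diff; the sizes, tensors, and toy experts are made up) that checks the dense loop matches sparse top_k routing on random data:

import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(0)
num_experts, top_k, num_tokens, hidden_dim = 4, 2, 3, 8
experts = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(num_experts))

with torch.no_grad():
    hidden_states = torch.randn(num_tokens, hidden_dim)
    router_logits = torch.randn(num_tokens, num_experts)
    routing_weights = F.softmax(router_logits, dim=1)

    # zero out every weight except the top_k ones, as the patched forward does
    top_k_weights, top_k_indices = torch.topk(routing_weights, top_k, dim=-1)
    full_routing_weights = torch.zeros_like(routing_weights).scatter_(1, top_k_indices, top_k_weights)

    # dense path: run every expert on every token and scale by the (possibly zero) weight
    dense_out = sum(full_routing_weights[:, e, None] * experts[e](hidden_states) for e in range(num_experts))

    # sparse path: run only the selected experts for each token
    sparse_out = torch.zeros(num_tokens, hidden_dim)
    for t in range(num_tokens):
        for w, e in zip(top_k_weights[t], top_k_indices[t]):
            sparse_out[t] += w * experts[int(e)](hidden_states[t])

    print(torch.allclose(dense_out, sparse_out, atol=1e-6))  # True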
@@ -43,10 +43,20 @@ if TYPE_CHECKING:
     from ..hparams import ModelArguments

+if is_transformers_version_greater_than("4.57.0"):
+    from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe
+

 logger = logging.get_logger(__name__)


+def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
+    if is_transformers_version_greater_than("4.57.0"):
+        from .model_utils.moe import Qwen3OmniMoeThinkerTextSparseMoeBlock
+
+        modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock
+
+
 def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
@@ -136,6 +146,9 @@ def patch_config(
     if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
         raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")

+    if getattr(config, "model_type", None) == "qwen3_omni_moe":
+        patch_qwen3_omni_moe_thinker_text_sparse_moe_block()
+
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
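patch_config applies the swap before the model weights are instantiated, which is what makes the monkey patch effective: the transformers modeling code resolves Qwen3OmniMoeThinkerTextSparseMoeBlock from the module namespace when the model is built, so only models constructed after the reassignment pick up the replacement. A small sanity-check sketch (assumptions: the patch function is importable from llamafactory.model.patcher and transformers>=4.57.0 is installed; neither path is shown in this diff):

from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe

from llamafactory.model.patcher import patch_qwen3_omni_moe_thinker_text_sparse_moe_block  # assumed import path

patch_qwen3_omni_moe_thinker_text_sparse_moe_block()
# after patching, the module attribute should point at LLaMA-Factory's dense-routing block
print(modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock.__module__)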
@@ -15,11 +15,15 @@
 from typing import Callable, TypedDict

-from typing_extensions import NotRequired
+from typing_extensions import NotRequired, Required

+from ....extras import logging
 from ...extras.types import DPOSample, Sample, SFTSample


+logger = logging.get_logger(__name__)
+
+
 class AlpacaSample(TypedDict, total=False):
     system: NotRequired[str]
     instruction: NotRequired[str]
@@ -27,6 +31,21 @@ class AlpacaSample(TypedDict, total=False):
     output: NotRequired[str]


+ShareGPTMessage = TypedDict(
+    "ShareGPTMessage",
+    {
+        "from": Required[str],  # role of the message sender (e.g., "human", "gpt", "system")
+        "value": Required[str],  # content of the message
+    },
+)
+
+
+class ShareGPTSample(TypedDict, total=False):
+    """Type definition for raw ShareGPT sample."""
+
+    conversations: Required[list[ShareGPTMessage]]
+
+
 class PairSample(TypedDict, total=False):
     prompt: NotRequired[str]
     chosen: NotRequired[list[dict]]
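For reference, a raw record that satisfies the ShareGPTSample and ShareGPTMessage types added above could look like the following (hypothetical data):

example = {  # conforms to ShareGPTSample
    "conversations": [
        {"from": "system", "value": "You are a helpful assistant."},
        {"from": "human", "value": "What is 2 + 2?"},
        {"from": "gpt", "value": "2 + 2 = 4."},
    ]
}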
@@ -48,6 +67,20 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
             {"role": "system", "content": [{"type": "text", "value": raw_sample["system"]}], "loss_weight": 0.0}
         )

+    if "history" in raw_sample:
+        for idx, item in enumerate(raw_sample["history"]):
+            if len(item) != 2:
+                logger.warning_rank0(
+                    f"Warning: History item at index {idx} has invalid length (expected 2, got {len(item)}). Skipping."
+                )
+                continue
+
+            old_prompt, old_response = item
+            messages.append({"role": "user", "content": [{"type": "text", "value": old_prompt}], "loss_weight": 0.0})
+            messages.append(
+                {"role": "assistant", "content": [{"type": "text", "value": old_response}], "loss_weight": 1.0}
+            )
+
     if "instruction" in raw_sample or "input" in raw_sample:
         messages.append(
             {
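The new history branch above emits each earlier (prompt, response) pair as a zero-loss user turn followed by a loss-bearing assistant turn before the current instruction is handled. A hand-traced illustration of that branch alone (hypothetical values, not from this diff):

raw_sample = {"history": [["Hi", "Hello!"]], "instruction": "How are you?", "output": "Fine."}

# messages produced by the history loop (the current instruction/output turn is appended afterwards):
history_messages = [
    {"role": "user", "content": [{"type": "text", "value": "Hi"}], "loss_weight": 0.0},
    {"role": "assistant", "content": [{"type": "text", "value": "Hello!"}], "loss_weight": 1.0},
]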
@@ -67,6 +100,62 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
     return {"messages": messages}


+def sharegpt_converter(raw_sample: ShareGPTSample) -> SFTSample:
+    """Converts a raw ShareGPT sample into a formatted SFT (Supervised Fine-Tuning) sample.
+
+    Retains only SFT-relevant scenarios and removes parity checks.
+
+    Args:
+        raw_sample (ShareGPTSample): A raw sample in ShareGPT format.
+
+    Returns:
+        dict: A dictionary containing the formatted 'messages' list for SFT training.
+            Returns an empty list if the input data is invalid.
+    """
+    tag_mapping = {
+        "human": "user",
+        "gpt": "assistant",
+        "observation": "observation",
+        "function_call": "function",
+    }
+    messages = raw_sample.get("conversations", [])
+    aligned_messages = []
+    system_content = ""
+
+    # Extract system message if present (typically the first message)
+    if messages and messages[0]["from"] == "system":
+        system_content = messages[0]["value"]
+        messages = messages[1:]
+
+    if system_content:
+        aligned_messages.append(
+            {"role": "system", "content": [{"type": "text", "value": system_content}], "loss_weight": 0.0}
+        )
+
+    has_invalid_role = False
+    for message in messages:
+        sender = message["from"]
+        # validate sender is in supported tags
+        if sender not in tag_mapping:
+            logger.warning_rank0(f"Unsupported role tag '{sender}' in message: {message}")
+            has_invalid_role = True
+            break
+
+        aligned_messages.append(
+            {
+                "role": tag_mapping[sender],
+                "content": [{"type": "text", "value": message["value"]}],
+                "loss_weight": 0.0 if sender in ("human", "observation") else 1.0,
+            }
+        )
+
+    if has_invalid_role:
+        logger.warning_rank0("Skipping invalid example due to unsupported role tags.")
+        return {"messages": []}
+
+    return {"messages": aligned_messages}
+
+
 def pair_converter(raw_sample: PairSample) -> DPOSample:
     """Convert Pair sample to standard DPO sample.
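To make the role and loss-weight mapping of sharegpt_converter concrete, here is a hand-traced tool-use conversation (hypothetical values; the mapping follows directly from tag_mapping and the loss_weight rule above):

raw = {
    "conversations": [
        {"from": "human", "value": "What is the weather in Paris?"},
        {"from": "function_call", "value": '{"name": "get_weather", "arguments": {"city": "Paris"}}'},
        {"from": "observation", "value": '{"temperature": "21C"}'},
        {"from": "gpt", "value": "It is 21C in Paris."},
    ]
}
# sharegpt_converter(raw)["messages"] maps the four turns as:
#   human         -> role "user",        loss_weight 0.0
#   function_call -> role "function",    loss_weight 1.0
#   observation   -> role "observation", loss_weight 0.0
#   gpt           -> role "assistant",   loss_weight 1.0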
@@ -148,6 +237,7 @@ def pair_converter(raw_sample: PairSample) -> DPOSample:
 CONVERTERS = {
     "alpaca": alpaca_converter,
     "pair": pair_converter,
+    "sharegpt": sharegpt_converter,
 }
@@ -19,6 +19,7 @@ from datasets import load_dataset
 from llamafactory.v1.config.data_args import DataArguments
 from llamafactory.v1.core.data_engine import DataEngine
+from llamafactory.v1.plugins.data_plugins.converter import get_converter


 @pytest.mark.parametrize("num_samples", [16])
@@ -48,6 +49,96 @@ def test_alpaca_converter(num_samples: int):
         assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}


+def test_sharegpt_converter_invalid():
+    example = {
+        "conversations": [
+            {
+                "from": "system",
+                "value": "Processes historical market data to generate trading signals "
+                "based on specified technical indicators.",
+            },
+            {
+                "from": "human",
+                "value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. "
+                "Could you proceed with these function calls to assist me with the task?",
+            },
+            {
+                "from": "gpt",
+                "value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', "
+                "'name': 'backtest_trading_signals'}```\n",
+            },
+            {
+                "from": "tool",
+                "value": '<tool id="D2">\n{"analysis": {"RSI_signals": [{"date": "2025-01-10", '
+                '"symbol": "AAPL", "signal": "Buy"}]}}}\n</tool>\n',
+            },
+        ]
+    }
+    dataset_converter = get_converter("sharegpt")
+    assert dataset_converter(example) == {"messages": []}
+
+
+def test_sharegpt_converter_valid():
+    example = {
+        "conversations": [
+            {
+                "from": "system",
+                "value": "Processes historical market data to generate trading signals based on "
+                "specified technical indicators.",
+            },
+            {
+                "from": "human",
+                "value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. "
+                "Could you proceed with these function calls to assist me with the task?",
+            },
+            {
+                "from": "gpt",
+                "value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', "
+                "'name': 'backtest_trading_signals'}```\n",
+            },
+        ]
+    }
+    dataset_converter = get_converter("sharegpt")
+    expected_data = {
+        "messages": [
+            {
+                "content": [
+                    {
+                        "type": "text",
+                        "value": "Processes historical market data to generate trading signals based on "
+                        "specified technical indicators.",
+                    }
+                ],
+                "loss_weight": 0.0,
+                "role": "system",
+            },
+            {
+                "content": [
+                    {
+                        "type": "text",
+                        "value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. "
+                        "Could you proceed with these function calls to assist me with the task?",
+                    }
+                ],
+                "loss_weight": 0.0,
+                "role": "user",
+            },
+            {
+                "content": [
+                    {
+                        "type": "text",
+                        "value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', "
+                        "'name': 'backtest_trading_signals'}```\n",
+                    }
+                ],
+                "loss_weight": 1.0,
+                "role": "assistant",
+            },
+        ]
+    }
+    assert dataset_converter(example) == expected_data
+
+
 @pytest.mark.parametrize("num_samples", [16])
 def test_pair_converter(num_samples: int):
     data_args = DataArguments(dataset="frozenleaves/tiny-dpo/dataset_info.yaml")
@@ -98,4 +189,6 @@ def test_pair_converter(num_samples: int):
 if __name__ == "__main__":
     test_alpaca_converter(1)
+    test_sharegpt_converter_invalid()
+    test_sharegpt_converter_valid()
     test_pair_converter(1)