Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-07 22:00:03 +08:00)

Compare commits: 3 commits (934b3084ee ... 56f45e826f)
| Author | SHA1 | Date |
|---|---|---|
| | 56f45e826f | |
| | 14abb75126 | |
| | 5a9939050e | |
```diff
@@ -110,6 +110,10 @@ def is_starlette_available():
 def is_transformers_version_greater_than(content: str):
     return _get_package_version("transformers") >= version.parse(content)
 
 
+@lru_cache
+def is_torch_version_greater_than(content: str):
+    return _get_package_version("torch") >= version.parse(content)
+
+
 def is_uvicorn_available():
     return _is_package_available("uvicorn")
```
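The new helper mirrors the existing `is_transformers_version_greater_than` and is wrapped in `@lru_cache` so repeated checks do not re-query package metadata. A minimal self-contained sketch of how these helpers fit together, assuming `_get_package_version` wraps `importlib.metadata` (the diff only shows the two public helpers, so the private helper's body here is an assumption):

```python
# Minimal sketch of the version-check helpers; `_get_package_version` is an
# assumption -- the diff above only shows the public functions that call it.
from functools import lru_cache
from importlib import metadata

from packaging import version


def _get_package_version(name: str) -> "version.Version":
    try:
        return version.parse(metadata.version(name))
    except metadata.PackageNotFoundError:
        return version.parse("0.0.0")  # a missing package behaves as "very old"


@lru_cache
def is_torch_version_greater_than(content: str) -> bool:
    # Note: despite the name, the comparison is >=, matching the diff above.
    return _get_package_version("torch") >= version.parse(content)
```

`is_torch_version_greater_than("2.1.1")` then answers the SDPA gate in the next file with a single cached metadata lookup.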
```diff
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING
 
 from ...extras import logging
 from ...extras.constants import AttentionFunction
+from ...extras.packages import is_torch_version_greater_than
 
 
 if TYPE_CHECKING:
```
```diff
@@ -51,15 +52,14 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
         requested_attn_implementation = "eager"
 
     elif model_args.flash_attn == AttentionFunction.SDPA:
-        from transformers.utils import is_torch_sdpa_available
-
-        if not is_torch_sdpa_available():
+        if not is_torch_version_greater_than("2.1.1"):
             logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
             return
 
         requested_attn_implementation = "sdpa"
     elif model_args.flash_attn == AttentionFunction.FA2:
-        if not is_flash_attn_2_available():
+        from transformers import is_torch_npu_available
+
+        if not (is_flash_attn_2_available() or is_torch_npu_available()):
             logger.warning_rank0("FlashAttention-2 is not installed.")
             return
```
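Two behavioral changes land here: the SDPA branch now gates on the installed torch version rather than `is_torch_sdpa_available()`, and the FA2 branch accepts Ascend NPU builds, which provide fused attention without the `flash-attn` wheel. A simplified stand-alone sketch of the resulting dispatch (string literals stand in for the `AttentionFunction` enum, and the version check is inlined, so this is an illustration rather than the project's exact code):

```python
# Stand-alone sketch of the attention dispatch after this change; string
# values stand in for the AttentionFunction enum used in the real module.
from typing import Optional

import torch
from packaging import version
from transformers import is_torch_npu_available
from transformers.utils import is_flash_attn_2_available


def choose_attn_implementation(flash_attn: str) -> Optional[str]:
    if flash_attn == "sdpa":
        if version.parse(torch.__version__) < version.parse("2.1.1"):
            print("torch>=2.1.1 is required for SDPA attention.")
            return None
        return "sdpa"
    if flash_attn == "fa2":
        # NPU builds pass this check even without the flash-attn wheel,
        # because torch-npu ships its own fused attention kernels.
        if not (is_flash_attn_2_available() or is_torch_npu_available()):
            print("FlashAttention-2 is not installed.")
            return None
        return "flash_attention_2"
    return "eager"
```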
```diff
@@ -355,7 +355,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="qwen3_vl",
     projector_key="visual.merger",
-    vision_model_keys=["visual.patch_embed", "visual.blocks"],
+    vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
```
```diff
@@ -364,7 +364,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="qwen3_vl_moe",
     projector_key="visual.merger",
-    vision_model_keys=["visual.patch_embed", "visual.blocks"],
+    vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
```
```diff
@@ -373,7 +373,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="qwen3_omni_moe_thinker",
     projector_key="visual.merger",
-    vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
+    vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
     language_model_keys=["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
```
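All three Qwen3-VL variants gain `visual.deepstack_merger_list` in `vision_model_keys`, so the deepstack merger layers are handled together with the rest of the vision tower, for example when the tower is frozen or excluded from LoRA targeting. As an illustration of how such prefix lists are typically consumed (this traversal helper is an assumption about usage, not the project's actual code):

```python
# Illustrative consumer of vision_model_keys: freeze every parameter whose
# qualified name falls under one of the registered prefixes. This is an
# assumption about usage, not LLaMA-Factory's actual implementation.
from torch import nn


def freeze_vision_tower(model: nn.Module, vision_model_keys: list[str]) -> None:
    for name, param in model.named_parameters():
        if any(name.startswith(prefix) for prefix in vision_model_keys):
            param.requires_grad_(False)


# With the updated registration, the deepstack mergers are covered as well:
# freeze_vision_tower(model, ["visual.patch_embed", "visual.blocks",
#                             "visual.deepstack_merger_list"])
```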
```diff
@@ -203,7 +203,7 @@ class CustomDPOTrainer(DPOTrainer):
             bco_losses = self.bco_loss(
                 policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
             )
-            losses += bco_losses * self.bco_gemma
+            losses = (losses + bco_losses * self.bco_gemma) / (1.0 + self.bco_gemma)  # re-weight W_p and W_q
 
         return losses, chosen_rewards, rejected_rewards
```
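The rewritten line folds the normalization into the point of combination: with γ being `bco_gemma` (the attribute's spelling in the source), the DPO and BCO terms are now mixed as a convex combination with weights W_p = 1/(1+γ) and W_q = γ/(1+γ), instead of being rescaled in a later block (see the removal in the next hunk). A small numeric sketch with example values:

```python
# Numeric sketch of the re-weighting with example values.
import torch

bco_gemma = 0.5  # gamma, the BCO mixing weight
dpo_losses = torch.tensor([1.0, 2.0])
bco_losses = torch.tensor([0.4, 0.8])

# Before: plain accumulation; normalization happened in a later block.
old = dpo_losses + bco_losses * bco_gemma           # tensor([1.2000, 2.4000])

# After: convex combination, W_p = 1/(1+gamma), W_q = gamma/(1+gamma).
new = (dpo_losses + bco_losses * bco_gemma) / (1.0 + bco_gemma)
print(new)                                          # tensor([0.8000, 1.6000])
```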
```diff
@@ -284,9 +284,6 @@ class CustomDPOTrainer(DPOTrainer):
         sft_loss = -policy_chosen_logps_avg
         if self.ftx_gamma > 1e-6:
             losses += self.ftx_gamma * sft_loss
-        if self.bco_gemma > 1e-6:
-            # re-weigthing for MPO
-            losses /= self.ftx_gamma + self.bco_gemma + 1.0
 
         prefix = "eval_" if train_eval == "eval" else ""
         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
```
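This removal is the counterpart of the change above: since the DPO+BCO mixture is now normalized where it is formed, the post-hoc division, which also pulled `ftx_gamma` into the denominator, is no longer needed, and the SFT regularizer keeps its own weight untouched. After the change, the composition reduces to the following shape (a sketch with surrounding trainer code elided):

```python
# Sketch of the loss composition after this change. `losses` already holds
# the normalized DPO+BCO mixture from compute_preference_loss.
def combine_losses(losses, sft_loss, ftx_gamma: float):
    if ftx_gamma > 1e-6:
        losses = losses + ftx_gamma * sft_loss  # no trailing re-division
    return losses
```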