mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-04 02:35:59 +08:00
[tracker] Add Trackio Integration for LlamaFactory (#10165)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -100,6 +100,52 @@ def _parse_args(
|
|||||||
return tuple(parsed_args)
|
return tuple(parsed_args)
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_trackio_args(training_args: "TrainingArguments") -> None:
    """Validate the Trackio-related settings carried by ``training_args``.

    The function is a no-op unless ``"trackio"`` is one of the configured
    ``report_to`` backends. When Trackio is requested it enforces that a
    project name is set and emits advisory log messages about the space id,
    the default project name, repository privacy and a missing run name.

    Advisory messages go through the rank-0 logger helpers, matching the other
    argument checks in this module, so a multi-process launch prints each
    message once instead of once per process.

    Args:
        training_args: TrainingArguments instance (not a dictionary). The
            attributes read are ``report_to``, ``project``,
            ``trackio_space_id``, ``hub_private_repo`` and ``run_name``.

    Raises:
        ValueError: If Trackio reporting is requested but ``project`` is empty.
    """
    report_to = training_args.report_to
    if not report_to:
        return

    # Wrap a bare string so the membership test below compares whole backend
    # names rather than doing a substring search.
    if isinstance(report_to, str):
        report_to = [report_to]

    if "trackio" not in report_to:
        return

    # --- Enforce project (required by Trackio) ---
    if not training_args.project:
        raise ValueError("`--project` must be specified when using Trackio.")

    # --- Validate trackio_space_id format ---
    # The literal "trackio" is accepted as-is; any other value without a "/"
    # only triggers a warning, never an error.
    space_id = training_args.trackio_space_id
    if space_id and space_id != "trackio" and "/" not in space_id:
        logger.warning_rank0(
            f"trackio_space_id '{space_id}' should typically be in format "
            "'org/space' for Hugging Face Spaces deployment."
        )

    # --- Inform about default project usage ---
    if training_args.project == "huggingface":
        logger.info_rank0(
            "Using default project name 'huggingface'. "
            "Consider setting a custom project name with --project "
            "for better organization."
        )

    # --- Validate hub repo privacy flag ---
    if training_args.hub_private_repo:
        logger.info_rank0("Repository will be created as private on Hugging Face Hub.")

    # --- Recommend run_name for experiment clarity ---
    if not training_args.run_name:
        logger.warning_rank0("Consider setting --run_name for better experiment tracking clarity.")
|
||||||
|
|
||||||
|
|
||||||
def _set_transformers_logging() -> None:
|
def _set_transformers_logging() -> None:
|
||||||
if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
|
if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
|
||||||
transformers.utils.logging.set_verbosity_info()
|
transformers.utils.logging.set_verbosity_info()
|
||||||
@@ -278,8 +324,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
|||||||
if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
|
if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
|
||||||
raise ValueError("Unsloth does not support lora reward model.")
|
raise ValueError("Unsloth does not support lora reward model.")
|
||||||
|
|
||||||
if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
|
if training_args.report_to and any(
|
||||||
raise ValueError("PPO only accepts wandb or tensorboard logger.")
|
logger not in ("wandb", "tensorboard", "trackio", "none") for logger in training_args.report_to
|
||||||
|
):
|
||||||
|
raise ValueError("PPO only accepts wandb, tensorboard, or trackio logger.")
|
||||||
|
|
||||||
if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
|
if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
|
||||||
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
|
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
|
||||||
@@ -352,6 +400,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
|||||||
_set_env_vars()
|
_set_env_vars()
|
||||||
_verify_model_args(model_args, data_args, finetuning_args)
|
_verify_model_args(model_args, data_args, finetuning_args)
|
||||||
_check_extra_dependencies(model_args, finetuning_args, training_args)
|
_check_extra_dependencies(model_args, finetuning_args, training_args)
|
||||||
|
_verify_trackio_args(training_args)
|
||||||
|
|
||||||
if not finetuning_args.use_mca and training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
|
if not finetuning_args.use_mca and training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
|
||||||
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
|
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
|
||||||
|
|||||||
@@ -371,6 +371,18 @@ class ReporterCallback(TrainerCallback):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "trackio" in args.report_to:
|
||||||
|
import trackio
|
||||||
|
|
||||||
|
trackio.config.update(
|
||||||
|
{
|
||||||
|
"model_args": self.model_args.to_dict(),
|
||||||
|
"data_args": self.data_args.to_dict(),
|
||||||
|
"finetuning_args": self.finetuning_args.to_dict(),
|
||||||
|
"generating_args": self.generating_args.to_dict(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if self.finetuning_args.use_swanlab:
|
if self.finetuning_args.use_swanlab:
|
||||||
import swanlab # type: ignore
|
import swanlab # type: ignore
|
||||||
|
|
||||||
|
|||||||
@@ -108,11 +108,26 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
|
|||||||
with gr.Column():
|
with gr.Column():
|
||||||
enable_thinking = gr.Checkbox(value=True)
|
enable_thinking = gr.Checkbox(value=True)
|
||||||
report_to = gr.Dropdown(
|
report_to = gr.Dropdown(
|
||||||
choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "all"],
|
choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "trackio", "all"],
|
||||||
value="none",
|
value="none",
|
||||||
allow_custom_value=True,
|
allow_custom_value=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with gr.Accordion("Trackio Settings", open=False):
|
||||||
|
project = gr.Textbox(
|
||||||
|
value="huggingface",
|
||||||
|
label="Project Name",
|
||||||
|
info="Project name for experiment tracking (used by Trackio, W&B, etc.)",
|
||||||
|
)
|
||||||
|
|
||||||
|
trackio_space_id = gr.Textbox(
|
||||||
|
value="trackio", label="Trackio Space ID", info="Hugging Face Space ID for Trackio deployment"
|
||||||
|
)
|
||||||
|
|
||||||
|
hub_private_repo = gr.Checkbox(
|
||||||
|
value=False, label="Private Repository", info="Make the Hugging Face repository private"
|
||||||
|
)
|
||||||
|
|
||||||
input_elems.update(
|
input_elems.update(
|
||||||
{
|
{
|
||||||
logging_steps,
|
logging_steps,
|
||||||
@@ -128,6 +143,9 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
|
|||||||
use_llama_pro,
|
use_llama_pro,
|
||||||
enable_thinking,
|
enable_thinking,
|
||||||
report_to,
|
report_to,
|
||||||
|
project,
|
||||||
|
trackio_space_id,
|
||||||
|
hub_private_repo,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
elem_dict.update(
|
elem_dict.update(
|
||||||
@@ -146,6 +164,9 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
|
|||||||
use_llama_pro=use_llama_pro,
|
use_llama_pro=use_llama_pro,
|
||||||
enable_thinking=enable_thinking,
|
enable_thinking=enable_thinking,
|
||||||
report_to=report_to,
|
report_to=report_to,
|
||||||
|
project=project,
|
||||||
|
trackio_space_id=trackio_space_id,
|
||||||
|
hub_private_repo=hub_private_repo,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user