diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 81762635b..308eecade 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -100,6 +100,52 @@ def _parse_args( return tuple(parsed_args) +def _verify_trackio_args(training_args: "TrainingArguments") -> None: + """Validates Trackio-specific arguments. + + Args: + training_args: TrainingArguments instance (not a dictionary) + """ + report_to = training_args.report_to + if not report_to: + return + + if isinstance(report_to, str): + report_to = [report_to] + + if "trackio" not in report_to: + return + + # --- Enforce project (required by Trackio) --- + if not training_args.project: + raise ValueError("`--project` must be specified when using Trackio.") + + # --- Validate trackio_space_id format --- + space_id = training_args.trackio_space_id + if space_id: + if space_id != "trackio" and "/" not in space_id: + logger.warning_rank0( + f"trackio_space_id '{space_id}' should typically be in format " + "'org/space' for Hugging Face Spaces deployment." + ) + + # --- Inform about default project usage --- + if training_args.project == "huggingface": + logger.info( + "Using default project name 'huggingface'. " + "Consider setting a custom project name with --project " + "for better organization." 
+ ) + + # --- Validate hub repo privacy flag --- + if training_args.hub_private_repo: + logger.info("Repository will be created as private on Hugging Face Hub.") + + # --- Recommend run_name for experiment clarity --- + if not training_args.run_name: + logger.warning_rank0("Consider setting --run_name for better experiment tracking clarity.") + + def _set_transformers_logging() -> None: if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]: transformers.utils.logging.set_verbosity_info() @@ -278,8 +324,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth: raise ValueError("Unsloth does not support lora reward model.") - if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]: - raise ValueError("PPO only accepts wandb or tensorboard logger.") + if training_args.report_to and any( + logger_name not in ("wandb", "tensorboard", "trackio", "none") for logger_name in training_args.report_to + ): + raise ValueError("PPO only accepts wandb, tensorboard, or trackio logger.") if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED: raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.") @@ -352,6 +400,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS _set_env_vars() _verify_model_args(model_args, data_args, finetuning_args) _check_extra_dependencies(model_args, finetuning_args, training_args) + _verify_trackio_args(training_args) if not finetuning_args.use_mca and training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8: logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. 
Setting fp8=True.") diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index d164c0443..ac574ae7c 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -371,6 +371,18 @@ class ReporterCallback(TrainerCallback): } ) + if "trackio" in args.report_to: + import trackio + + trackio.config.update( + { + "model_args": self.model_args.to_dict(), + "data_args": self.data_args.to_dict(), + "finetuning_args": self.finetuning_args.to_dict(), + "generating_args": self.generating_args.to_dict(), + } + ) + if self.finetuning_args.use_swanlab: import swanlab # type: ignore diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index 8b7aa6e94..f1bc1aea7 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -108,11 +108,26 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]: with gr.Column(): enable_thinking = gr.Checkbox(value=True) report_to = gr.Dropdown( - choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "all"], + choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "trackio", "all"], value="none", allow_custom_value=True, ) + with gr.Accordion("Trackio Settings", open=False): + project = gr.Textbox( + value="huggingface", + label="Project Name", + info="Project name for experiment tracking (used by Trackio, W&B, etc.)", + ) + + trackio_space_id = gr.Textbox( + value="trackio", label="Trackio Space ID", info="Hugging Face Space ID for Trackio deployment" + ) + + hub_private_repo = gr.Checkbox( + value=False, label="Private Repository", info="Make the Hugging Face repository private" + ) + input_elems.update( { logging_steps, @@ -128,6 +143,9 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]: use_llama_pro, enable_thinking, report_to, + project, + trackio_space_id, + hub_private_repo, } ) elem_dict.update( @@ -146,6 +164,9 @@ def 
create_train_tab(engine: "Engine") -> dict[str, "Component"]: use_llama_pro=use_llama_pro, enable_thinking=enable_thinking, report_to=report_to, + project=project, + trackio_space_id=trackio_space_id, + hub_private_repo=hub_private_repo, ) )