From 9b152d9cb543999c21f33e473bb51a819dabd456 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 28 May 2025 18:11:07 +0800 Subject: [PATCH] [webui] fix skip args (#8195) --- src/llamafactory/data/data_utils.py | 62 ++++++++++++----------------- src/llamafactory/webui/common.py | 9 ++++- 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py index 12697b6a..14e26129 100644 --- a/src/llamafactory/data/data_utils.py +++ b/src/llamafactory/data/data_utils.py @@ -14,7 +14,7 @@ import json from enum import Enum, unique -from typing import TYPE_CHECKING, Optional, TypedDict, Union +from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union import fsspec from datasets import DatasetDict, concatenate_datasets, interleave_datasets @@ -142,61 +142,49 @@ def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModu return dataset_module -def setup_fs(path, anon=False): - """Set up a filesystem object based on the path protocol.""" +def setup_fs(path: str, anon: bool = False) -> "fsspec.AbstractFileSystem": + r"""Set up a filesystem object based on the path protocol.""" storage_options = {"anon": anon} if anon else {} - if path.startswith("s3://"): fs = fsspec.filesystem("s3", **storage_options) elif path.startswith(("gs://", "gcs://")): fs = fsspec.filesystem("gcs", **storage_options) else: - raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'") + raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'.") + if not fs.exists(path): - raise ValueError(f"Path does not exist: {path}") + raise ValueError(f"Path does not exist: {path}.") + return fs -def read_cloud_json(cloud_path): - """Read a JSON/JSONL file from cloud storage (S3 or GCS). +def _read_json_with_fs(fs: "fsspec.AbstractFileSystem", path: str) -> list[Any]: + r"""Helper function to read JSON/JSONL files using fsspec.""" + with fs.open(path, "r") as f: + if path.endswith(".jsonl"): + return [json.loads(line) for line in f if line.strip()] + else: + return json.load(f) + + +def read_cloud_json(cloud_path: str) -> list[Any]: + r"""Read a JSON/JSONL file from cloud storage (S3 or GCS). Args: - cloud_path : str + cloud_path: str Cloud path in the format: - 's3://bucket-name/file.json' for AWS S3 - 'gs://bucket-name/file.jsonl' or 'gcs://bucket-name/file.jsonl' for Google Cloud Storage - lines : bool, default=True - If True, read the file as JSON Lines format (one JSON object per line) """ try: - # Try with anonymous access first - fs = setup_fs(cloud_path, anon=True) + fs = setup_fs(cloud_path, anon=True) # try with anonymous access first except Exception: - # Try again with credentials - fs = setup_fs(cloud_path) + fs = setup_fs(cloud_path) # try again with credentials - if fs.isdir(cloud_path): - files = [x["Key"] for x in fs.listdir(cloud_path)] - else: - files = [cloud_path] # filter out non-JSON files - files = [file for file in files if file.endswith(".json") or file.endswith(".jsonl")] + files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path] + files = filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files) if not files: - raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}") - data = [] - for file in files: - data.extend(_read_json_with_fs(fs, file, lines=file.endswith(".jsonl"))) - return data + raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.") - -def _read_json_with_fs(fs, path, lines=True): - """Helper function to read JSON/JSONL files using fsspec.""" - with fs.open(path, "r") as f: - if lines: - # Read JSONL (JSON Lines) format - one JSON object per line - data = [json.loads(line) for line in f if line.strip()] - else: - # Read regular JSON format - data = json.load(f) - - return data + return sum([_read_json_with_fs(fs, file) for file in files], []) diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py index 7b682d25..ccfe9947 100644 --- a/src/llamafactory/webui/common.py +++ b/src/llamafactory/webui/common.py @@ -163,7 +163,14 @@ def save_args(config_path: str, config_dict: dict[str, Any]) -> None: def _clean_cmd(args: dict[str, Any]) -> dict[str, Any]: r"""Remove args with NoneType or False or empty string value.""" - no_skip_keys = ["packing", "freeze_vision_tower", "freeze_multi_modal_projector", "freeze_language_model"] + no_skip_keys = [ + "packing", + "enable_thinking", + "use_reentrant_gc", + "double_quantization", + "freeze_vision_tower", + "freeze_multi_modal_projector", + ] return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}