Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2026-01-13 09:30:34 +08:00
[v1] upgrade batching (#9751)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
tests_v1/accelerator/test_interface.py

@@ -58,3 +58,10 @@ def test_multi_device():
     master_port = find_available_port()
     world_size = 2
     mp.spawn(_all_reduce_tests, args=(world_size, master_port), nprocs=world_size)
+
+
+if __name__ == "__main__":
+    """
+    python tests_v1/accelerator/test_interface.py
+    """
+    test_all_device()
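The runner above picks a free rendezvous port before spawning the two ranks. find_available_port itself is not part of this diff; a minimal sketch of such a helper, assuming the usual bind-to-port-0 trick:

    import socket

    def find_available_port() -> int:
        # Bind to port 0 so the OS hands out a free ephemeral port,
        # then close the socket and reuse that port for the rendezvous.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("127.0.0.1", 0))
            return sock.getsockname()[1]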
tests_v1/config/test_args_parser.py

@@ -12,41 +12,41 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import pathlib
 import sys
+from pathlib import Path
 from unittest.mock import patch
 
 from llamafactory.v1.config.arg_parser import get_args
 
 
-def test_get_args_from_yaml(tmp_path: pathlib.Path):
+def test_get_args_from_yaml(tmp_path: Path):
     config_yaml = """
 ### model
-model: "llamafactory/tiny-random-qwen2.5"
+model: llamafactory/tiny-random-qwen3
 trust_remote_code: true
-model_class: "llm"
+model_class: llm
 kernel_config:
-  name: "auto"
-  include_kernels: "auto" # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
+  name: auto
+  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
 peft_config:
-  name: "lora"
-  lora_rank: 0.8
+  name: lora
+  lora_rank: 0.8
 quant_config: null
 
 ### data
-dataset: "llamafactory/tiny-supervised-dataset"
-cutoff_len: 2048
+dataset: llamafactory/v1-sft-demo
 
 ### training
-output_dir: "outputs/test_run"
+output_dir: outputs/test_run
 micro_batch_size: 1
 global_batch_size: 1
+cutoff_len: 2048
 learning_rate: 1.0e-4
 bf16: false
 dist_config: null
 
 ### sample
-sample_backend: "hf"
+sample_backend: hf
 max_new_tokens: 128
 """
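The hunk drops the redundant quoting in the YAML fixture, swaps the tiny model and dataset for their v1 counterparts, and moves cutoff_len from the data section to training. get_args itself is not shown in this diff; a minimal sketch of the argv-plus-YAML pattern the test exercises (PyYAML and the helper name are assumptions, not the project's code):

    import sys

    import yaml  # PyYAML; an assumption, the real parser may differ

    def load_yaml_args(argv: list[str]) -> dict:
        # Treat the first CLI argument as a path to a YAML config,
        # mirroring get_args() reading the file named in sys.argv.
        with open(argv[1], encoding="utf-8") as f:
            return yaml.safe_load(f)

    # e.g. load_yaml_args(["prog", "config.yaml"]) -> {"model": ..., "kernel_config": {...}}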
@@ -57,14 +57,26 @@ def test_get_args_from_yaml(tmp_path: pathlib.Path):
 
     with patch.object(sys, "argv", test_argv):
         data_args, model_args, training_args, sample_args = get_args()
-        assert model_args.model == "llamafactory/tiny-random-qwen2.5"
-        assert model_args.kernel_config.name == "auto"
-        assert model_args.kernel_config.get("include_kernels") == "auto"
-        assert model_args.peft_config.name == "lora"
-        assert model_args.peft_config.get("lora_rank") == 0.8
+        assert data_args.dataset == "llamafactory/v1-sft-demo"
+        assert model_args.model == "llamafactory/tiny-random-qwen3"
+        assert model_args.kernel_config.name == "auto"
+        assert model_args.kernel_config.get("include_kernels") == "auto"
+        assert model_args.peft_config.name == "lora"
+        assert model_args.peft_config.get("lora_rank") == 0.8
+        assert training_args.output_dir == "outputs/test_run"
+        assert training_args.micro_batch_size == 1
+        assert training_args.global_batch_size == 1
+        assert training_args.learning_rate == 1.0e-4
+        assert training_args.bf16 is False
+        assert training_args.dist_config is None
         assert sample_args.sample_backend == "hf"
+
+
+if __name__ == "__main__":
+    """
+    python -m tests_v1.config.test_args_parser
+    """
+    import tempfile
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        test_get_args_from_yaml(tmp_path=Path(tmp_dir))
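Patching sys.argv lets the test drive an argv-based parser without touching the real command line, and the direct runner supplies its own tmp_path since pytest is not there to inject the fixture. A self-contained sketch of the same pattern (parse_cli is illustrative, not a project function):

    import sys
    from unittest.mock import patch

    def parse_cli() -> list[str]:
        # Stand-in for an argv-driven parser such as get_args().
        return sys.argv[1:]

    with patch.object(sys, "argv", ["prog", "config.yaml"]):
        assert parse_cli() == ["config.yaml"]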
tests_v1/core/test_data_engine.py

@@ -33,4 +33,7 @@ def test_map_dataset(num_samples: int):
 
 
+if __name__ == "__main__":
+    """
+    python -m tests_v1.core.test_data_engine
+    """
+    test_map_dataset(1)
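The body of test_map_dataset is outside this diff; given that the fixtures are Hugging Face datasets, it presumably exercises a map-style transform over the data engine. A minimal sketch of that idea with the datasets library (the transform itself is invented for illustration):

    from datasets import Dataset

    def add_length(example: dict) -> dict:
        # Attach a derived field to each row, as a map transform would.
        example["length"] = len(example["text"])
        return example

    ds = Dataset.from_dict({"text": ["hello", "world!"]}).map(add_length)
    assert ds[1]["length"] == 6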
tests_v1/core/test_model_loader.py

@@ -44,5 +44,8 @@ def test_tiny_qwen_with_kernel_plugin():
 
 
+if __name__ == "__main__":
+    """
+    python -m tests_v1.core.test_model_loader
+    """
+    test_tiny_qwen()
+    test_tiny_qwen_with_kernel_plugin()
tests_v1/core/utils/test_batching.py

@@ -46,4 +46,7 @@ def test_normal_batching():
 
 
+if __name__ == "__main__":
+    """
+    python -m tests_v1.core.utils.test_batching
+    """
+    test_normal_batching()
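The batching upgrade the commit title refers to is not visible in this excerpt, only the new test runner. For orientation, the micro_batch_size and global_batch_size knobs from the config fixture above commonly relate through gradient accumulation; a sketch of that arithmetic (not necessarily what test_normal_batching checks):

    def grad_accum_steps(global_batch_size: int, micro_batch_size: int, world_size: int = 1) -> int:
        # global batch = micro batch per device * number of devices * accumulation steps
        per_optimizer_step = micro_batch_size * world_size
        if global_batch_size % per_optimizer_step:
            raise ValueError("global_batch_size must divide evenly by micro_batch_size * world_size")
        return global_batch_size // per_optimizer_step

    assert grad_accum_steps(global_batch_size=1, micro_batch_size=1) == 1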
tests_v1/core/utils/test_rendering.py

@@ -219,6 +219,9 @@ def test_process_dpo_samples():
 
 
+if __name__ == "__main__":
+    """
+    python -m tests_v1.core.utils.test_rendering
+    """
+    test_chatml_rendering()
+    test_chatml_parse()
+    test_chatml_rendering_remote(16)
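The rendering tests target the ChatML template, which brackets every turn with <|im_start|> and <|im_end|> markers; a minimal render function in that format (an illustration, not the project's renderer):

    def render_chatml(messages: list[dict]) -> str:
        # Each turn becomes: <|im_start|>{role}\n{content}<|im_end|>\n
        return "".join(
            f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
        )

    print(render_chatml([{"role": "user", "content": "hi"}]))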
tests_v1/plugins/data_plugins/test_converter.py

@@ -120,6 +120,9 @@ def test_pair_converter(num_samples: int):
 
 
+if __name__ == "__main__":
+    """
+    python -m tests_v1.plugins.data_plugins.test_converter
+    """
+    test_alpaca_converter(1)
+    test_sharegpt_converter()
+    test_pair_converter(1)
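The converters normalize different dataset schemas into one message format. As a reference point, the standard Alpaca schema folds instruction/input/output into a user/assistant pair; a sketch of that conversion (the field handling is the conventional one, not copied from the plugin):

    def alpaca_to_messages(sample: dict) -> list[dict]:
        # Alpaca rows look like {"instruction": ..., "input": ..., "output": ...}.
        prompt = sample["instruction"]
        if sample.get("input"):
            prompt += "\n" + sample["input"]
        return [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": sample["output"]},
        ]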
tests_v1/plugins/model_plugins/test_init_plugin.py

@@ -52,3 +52,12 @@ def test_init_on_default():
     )
     model_engine = ModelEngine(model_args=model_args)
     assert model_engine.model.device == DistributedInterface().current_device
+
+
+if __name__ == "__main__":
+    """
+    python tests_v1/plugins/model_plugins/test_init_plugin.py
+    """
+    test_init_on_meta()
+    test_init_on_rank0()
+    test_init_on_default()
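The three runners cover meta-device, rank-0-only, and default initialization. The plugin code is outside this diff; in plain PyTorch, meta initialization looks like the sketch below: parameters get shape and dtype but no storage until they are materialized.

    import torch
    import torch.nn as nn

    # Construct on the meta device: no memory is allocated for weights.
    with torch.device("meta"):
        layer = nn.Linear(4096, 4096)
    assert layer.weight.device.type == "meta"

    # Materialize empty storage later, then load real weights into it.
    layer = layer.to_empty(device="cpu")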
tests_v1/sampler/test_cli_sampler.py

@@ -38,4 +38,7 @@ def test_sync_sampler():
 
 
+if __name__ == "__main__":
+    """
+    python tests_v1/sampler/test_cli_sampler.py
+    """
+    test_sync_sampler()