mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-01-13 01:20:35 +08:00
[v1] upgrade batching (#9751)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -12,41 +12,41 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pathlib
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from llamafactory.v1.config.arg_parser import get_args
|
||||
|
||||
|
||||
def test_get_args_from_yaml(tmp_path: pathlib.Path):
|
||||
def test_get_args_from_yaml(tmp_path: Path):
|
||||
config_yaml = """
|
||||
### model
|
||||
model: "llamafactory/tiny-random-qwen2.5"
|
||||
model: llamafactory/tiny-random-qwen3
|
||||
trust_remote_code: true
|
||||
model_class: "llm"
|
||||
model_class: llm
|
||||
kernel_config:
|
||||
name: "auto"
|
||||
include_kernels: "auto" # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
|
||||
name: auto
|
||||
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
|
||||
peft_config:
|
||||
name: "lora"
|
||||
lora_rank: 0.8
|
||||
name: lora
|
||||
lora_rank: 0.8
|
||||
quant_config: null
|
||||
|
||||
### data
|
||||
dataset: "llamafactory/tiny-supervised-dataset"
|
||||
cutoff_len: 2048
|
||||
dataset: llamafactory/v1-sft-demo
|
||||
|
||||
### training
|
||||
output_dir: "outputs/test_run"
|
||||
output_dir: outputs/test_run
|
||||
micro_batch_size: 1
|
||||
global_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: false
|
||||
dist_config: null
|
||||
|
||||
### sample
|
||||
sample_backend: "hf"
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
"""
|
||||
|
||||
@@ -57,14 +57,26 @@ def test_get_args_from_yaml(tmp_path: pathlib.Path):
|
||||
|
||||
with patch.object(sys, "argv", test_argv):
|
||||
data_args, model_args, training_args, sample_args = get_args()
|
||||
assert data_args.dataset == "llamafactory/v1-sft-demo"
|
||||
assert model_args.model == "llamafactory/tiny-random-qwen3"
|
||||
assert model_args.kernel_config.name == "auto"
|
||||
assert model_args.kernel_config.get("include_kernels") == "auto"
|
||||
assert model_args.peft_config.name == "lora"
|
||||
assert model_args.peft_config.get("lora_rank") == 0.8
|
||||
assert training_args.output_dir == "outputs/test_run"
|
||||
assert training_args.micro_batch_size == 1
|
||||
assert training_args.global_batch_size == 1
|
||||
assert training_args.learning_rate == 1.0e-4
|
||||
assert training_args.bf16 is False
|
||||
assert training_args.dist_config is None
|
||||
assert model_args.model == "llamafactory/tiny-random-qwen2.5"
|
||||
assert model_args.kernel_config.name == "auto"
|
||||
assert model_args.kernel_config.get("include_kernels") == "auto"
|
||||
assert model_args.peft_config.name == "lora"
|
||||
assert model_args.peft_config.get("lora_rank") == 0.8
|
||||
assert sample_args.sample_backend == "hf"
|
||||
|
||||
|
||||
if __name__ == "__main__":
    """
    python -m tests_v1.config.test_args_parser
    """
    # Manual entry point: run the YAML parsing test outside of pytest.
    # pytest normally injects ``tmp_path``; here a throwaway directory is
    # created instead and removed again when the ``with`` block exits.
    import tempfile

    with tempfile.TemporaryDirectory() as scratch_dir:
        scratch_path = Path(scratch_dir)
        test_get_args_from_yaml(tmp_path=scratch_path)
|
||||
|
||||
Reference in New Issue
Block a user