Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)

update parser

Former-commit-id: d98258aa08d93494ad50d7786064e7fda15f6ca9
parent 7ff8a064f3
commit 7c492864e9
@@ -1,6 +1,6 @@
 import json
 import datasets
-from typing import Any, Dict, List
+from typing import Any, Dict, Generator, List, Tuple
 
 
 _DESCRIPTION = "An example of dataset."
@@ -40,7 +40,7 @@ class ExampleDataset(datasets.GeneratorBasedBuilder):
             )
         ]
 
-    def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]:
+    def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
         example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
         for key, example in enumerate(example_dataset):
             yield key, example
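For context, the new return annotation matches what this loader actually does: _generate_examples is a generator that yields (key, example) pairs rather than a function returning a dict keyed by int. A minimal standalone sketch of the same pattern (independent of the loader above; the file name is a placeholder):

import json
from typing import Any, Dict, Generator, Tuple

def generate_examples(filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
    # Yields (index, example) tuples one at a time, so Generator[Tuple[int, Dict[str, Any]], None, None]
    # describes the return type; Dict[int, Dict[str, Any]] would wrongly claim a mapping is returned.
    with open(filepath, "r", encoding="utf-8") as f:
        examples = json.load(f)
    for key, example in enumerate(examples):
        yield key, example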
@@ -73,19 +73,6 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
 
-    if model_args.infer_backend == "vllm":
-        if finetuning_args.stage != "sft":
-            raise ValueError("vLLM engine only supports auto-regressive models.")
-
-        if model_args.adapter_name_or_path is not None:
-            raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
-
-        if model_args.quantization_bit is not None:
-            raise ValueError("vLLM engine does not support quantization.")
-
-        if model_args.rope_scaling is not None:
-            raise ValueError("vLLM engine does not support RoPE scaling.")
-
 
 def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
@@ -154,6 +141,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         if training_args.fp16 or training_args.bf16:
             raise ValueError("Turn off mixed precision training when using `pure_bf16`.")
 
+    if model_args.infer_backend == "vllm":
+        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+
     _verify_model_args(model_args, finetuning_args)
 
     if (
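The guard added to get_train_args above makes the vLLM backend a hard error on the training path. An isolated sketch of that behavior (a stand-in dataclass, not the project's ModelArguments):

from dataclasses import dataclass

@dataclass
class StubModelArgs:
    # Hypothetical stand-in for ModelArguments; only the field used by the check.
    infer_backend: str = "huggingface"

def reject_vllm_for_training(model_args: StubModelArgs) -> None:
    # Mirrors the check added in get_train_args: vLLM is for inference entry points only.
    if model_args.infer_backend == "vllm":
        raise ValueError("vLLM backend is only available for API, CLI and Web.")

reject_vllm_for_training(StubModelArgs())                        # passes
# reject_vllm_for_training(StubModelArgs(infer_backend="vllm"))  # would raise ValueError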
@@ -252,12 +242,27 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
 
     _set_transformers_logging()
-    _verify_model_args(model_args, finetuning_args)
-    model_args.device_map = "auto"
 
     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
 
+    if model_args.infer_backend == "vllm":
+        if finetuning_args.stage != "sft":
+            raise ValueError("vLLM engine only supports auto-regressive models.")
+
+        if model_args.adapter_name_or_path is not None:
+            raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
+
+        if model_args.quantization_bit is not None:
+            raise ValueError("vLLM engine does not support quantization.")
+
+        if model_args.rope_scaling is not None:
+            raise ValueError("vLLM engine does not support RoPE scaling.")
+
+    _verify_model_args(model_args, finetuning_args)
+
+    model_args.device_map = "auto"
+
     return model_args, data_args, finetuning_args, generating_args
 
 
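In other words, the vLLM constraints now live in get_infer_args and run before _verify_model_args and the device_map assignment. A hedged usage sketch of what a caller would hit after this change (the import path and the placeholder values are assumptions; only template, stage, infer_backend and adapter_name_or_path are taken from the diff above):

from llmtuner.hparams import get_infer_args  # import path is an assumption

args = {
    "model_name_or_path": "path/to/model",   # placeholder
    "template": "default",
    "stage": "sft",
    "infer_backend": "vllm",
    "adapter_name_or_path": "path/to/lora",  # placeholder
}
# Expected to fail fast with:
# ValueError: vLLM engine does not support LoRA adapters. Merge them first.
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)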
@@ -265,12 +270,17 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
 
     _set_transformers_logging()
-    _verify_model_args(model_args, finetuning_args)
-    model_args.device_map = "auto"
 
     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
 
+    if model_args.infer_backend == "vllm":
+        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+
+    _verify_model_args(model_args, finetuning_args)
+
+    model_args.device_map = "auto"
+
     transformers.set_seed(eval_args.seed)
 
     return model_args, data_args, eval_args, finetuning_args