add e2e tests

Former-commit-id: 0156a37450604641c4f5f9756ad84324698fc88c

parent d6ce902d80
commit 9bdba2f6a8
@@ -175,7 +175,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [Llama 2](https://huggingface.co/meta-llama)                      | 7B/13B/70B                       | llama2    |
 | [Llama 3/Llama 3.1](https://huggingface.co/meta-llama)            | 8B/70B                           | llama3    |
 | [LLaVA-1.5](https://huggingface.co/llava-hf)                      | 7B/13B                           | llava     |
-| [MiniCPM](https://huggingface.co/openbmb)                         | 1B/2B                            | cpm       |
+| [MiniCPM](https://huggingface.co/openbmb)                         | 1B/2B/4B                         | cpm/cpm3  |
 | [Mistral/Mixtral](https://huggingface.co/mistralai)               | 7B/8x7B/8x22B                    | mistral   |
 | [OLMo](https://huggingface.co/allenai)                            | 1B/7B                            | -         |
 | [PaliGemma](https://huggingface.co/google)                        | 3B                               | paligemma |
@@ -176,7 +176,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | [Llama 2](https://huggingface.co/meta-llama)                      | 7B/13B/70B                       | llama2    |
 | [Llama 3/Llama 3.1](https://huggingface.co/meta-llama)            | 8B/70B                           | llama3    |
 | [LLaVA-1.5](https://huggingface.co/llava-hf)                      | 7B/13B                           | llava     |
-| [MiniCPM](https://huggingface.co/openbmb)                         | 1B/2B                            | cpm       |
+| [MiniCPM](https://huggingface.co/openbmb)                         | 1B/2B/4B                         | cpm/cpm3  |
 | [Mistral/Mixtral](https://huggingface.co/mistralai)               | 7B/8x7B/8x22B                    | mistral   |
 | [OLMo](https://huggingface.co/allenai)                            | 1B/7B                            | -         |
 | [PaliGemma](https://huggingface.co/google)                        | 3B                               | paligemma |
@@ -8,7 +8,7 @@ finetuning_type: lora
 lora_target: all
 
 ### dataset
-dataset: mllm_demo,identity
+dataset: mllm_demo,identity  # video: mllm_video_demo
 template: qwen2_vl
 cutoff_len: 1024
 max_samples: 1000
@@ -19,7 +19,6 @@ if is_pyav_available():
 
 if TYPE_CHECKING:
     import torch
-    from numpy.typing import NDArray
     from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.image_processing_utils import BaseImageProcessor
 
@@ -31,11 +30,17 @@ if TYPE_CHECKING:
     VideoInput = str
 
 
-def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> List["ImageObject"]:
+def _regularize_images(
+    images: Sequence["ImageInput"],
+    processor: "ProcessorMixin",
+    max_resolution: Optional[int] = None,
+) -> List["ImageObject"]:
     r"""
     Regularizes images to avoid error. Including reading, resizing and converting.
     """
-    image_resolution: int = getattr(processor, "image_resolution", 512)
+    if max_resolution is None:
+        max_resolution: int = getattr(processor, "image_resolution", 512)
 
     results = []
     for image in images:
         if isinstance(image, str):
@@ -49,9 +54,9 @@ def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixi
         if not isinstance(image, ImageObject):
             raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
 
-        if max(image.width, image.height) > image_resolution:
-            factor = image_resolution / max(image.width, image.height)
-            image = image.resize((int(image.width * factor), int(image.height * factor)))
+        if max(image.width, image.height) > max_resolution:
+            factor = max_resolution / max(image.width, image.height)
+            image = image.resize((int(image.width * factor), int(image.height * factor)), resample=Image.NEAREST)
 
         if image.mode != "RGB":
             image = image.convert("RGB")
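For reference, a minimal standalone sketch of the downscaling rule applied above, assuming only Pillow is installed; the 768x512 input size is a made-up example, not part of the commit:

from PIL import Image

max_resolution = 512
image = Image.new("RGB", (768, 512))  # hypothetical oversized input
if max(image.width, image.height) > max_resolution:
    factor = max_resolution / max(image.width, image.height)  # 512 / 768
    image = image.resize((int(image.width * factor), int(image.height * factor)), resample=Image.NEAREST)

print(image.size)  # (512, 341): the longer side is capped at max_resolution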
@@ -61,11 +66,16 @@ def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixi
     return results
 
 
-def _regularize_videos(videos: Sequence["VideoInput"], processor: "ProcessorMixin") -> List["NDArray"]:
+def _regularize_videos(
+    videos: Sequence["VideoInput"],
+    processor: "ProcessorMixin",
+) -> List[List["ImageObject"]]:
     r"""
     Regularizes videos to avoid error. Including reading, resizing and converting.
     """
     video_resolution: int = getattr(processor, "video_resolution", 128)
     video_fps: float = getattr(processor, "video_fps", 1.0)
     video_maxlen: int = getattr(processor, "video_maxlen", 64)
+    video_factor: int = getattr(processor, "video_factor", 1)
     results = []
     for video in videos:
@@ -73,6 +83,7 @@ def _regularize_videos(videos: Sequence["VideoInput"], processor: "ProcessorMixi
         video_stream = next(stream for stream in container.streams if stream.type == "video")
         total_frames = video_stream.frames
         sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
         sample_frames = min(video_maxlen, sample_frames)  # reduce length <= maxlen
+        sample_frames = round(sample_frames / video_factor) * video_factor  # for qwen2_vl
         sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
         frames: List["ImageObject"] = []
@@ -81,7 +92,7 @@ def _regularize_videos(videos: Sequence["VideoInput"], processor: "ProcessorMixi
             if frame_idx in sample_indices:
                 frames.append(frame.to_image())
 
-        frames = _regularize_images(frames, processor)
+        frames = _regularize_images(frames, processor, video_resolution)
         results.append(frames)
 
     return results
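A hedged sketch of the frame-sampling arithmetic above, with made-up numbers (a 30-second clip at 30 fps); video_factor=2 mirrors the qwen2_vl setting applied in the loader change further below:

import numpy as np

total_frames, duration = 900, 30.0  # hypothetical clip: 30 s at 30 fps
video_fps, video_maxlen, video_factor = 2.0, 64, 2

sample_frames = duration * video_fps                                 # 60.0 frames requested
sample_frames = min(video_maxlen, sample_frames)                     # capped at 64 -> still 60.0
sample_frames = round(sample_frames / video_factor) * video_factor   # rounded to a multiple of 2 -> 60
sample_indices = np.linspace(0, total_frames - 1, int(sample_frames)).astype(np.int32)
print(sample_indices[:3], len(sample_indices))  # [ 0 15 30] 60 evenly spaced frame indices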
@@ -562,8 +562,8 @@ _register_template(
 _register_template(
     name="cpm3",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|im_end|>"],
 )
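For illustration, the prompt string the cpm3 template should produce for a single user turn with a system prompt, assembled by hand from the formatter slots above ("<s>" stands in for the tokenizer's actual bos_token, and the system text is a made-up example):

prompt = (
    "<s>"                                                              # format_prefix: bos_token
    + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"   # format_system
    + "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n"        # format_user
)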
@@ -23,12 +23,133 @@ from typing_extensions import Self
 
 
 @dataclass
-class ModelArguments:
+class QuantizationArguments:
     r"""
+    Arguments pertaining to the quantization method.
+    """
+
+    quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
+        default="bitsandbytes",
+        metadata={"help": "Quantization method to use for on-the-fly quantization."},
+    )
+    quantization_bit: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
+    )
+    quantization_type: Literal["fp4", "nf4"] = field(
+        default="nf4",
+        metadata={"help": "Quantization data type to use in bitsandbytes int4 training."},
+    )
+    double_quantization: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
+    )
+    quantization_device_map: Optional[Literal["auto"]] = field(
+        default=None,
+        metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
+    )
+
+
+@dataclass
+class ProcessorArguments:
+    r"""
+    Arguments pertaining to the image processor.
+    """
+
+    image_resolution: int = field(
+        default=512,
+        metadata={"help": "Keeps the height or width of image below this resolution."},
+    )
+    video_resolution: int = field(
+        default=128,
+        metadata={"help": "Keeps the height or width of video below this resolution."},
+    )
+    video_fps: float = field(
+        default=2.0,
+        metadata={"help": "The frames to sample per second for video inputs."},
+    )
+    video_maxlen: int = field(
+        default=64,
+        metadata={"help": "The maximum number of sampled frames for video inputs."},
+    )
+
+
+@dataclass
+class ExportArguments:
+    r"""
+    Arguments pertaining to the model export.
+    """
+
+    export_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory to save the exported model."},
+    )
+    export_size: int = field(
+        default=1,
+        metadata={"help": "The file shard size (in GB) of the exported model."},
+    )
+    export_device: Literal["cpu", "auto"] = field(
+        default="cpu",
+        metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
+    )
+    export_quantization_bit: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of bits to quantize the exported model."},
+    )
+    export_quantization_dataset: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
+    )
+    export_quantization_nsamples: int = field(
+        default=128,
+        metadata={"help": "The number of samples used for quantization."},
+    )
+    export_quantization_maxlen: int = field(
+        default=1024,
+        metadata={"help": "The maximum length of the model inputs used for quantization."},
+    )
+    export_legacy_format: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
+    )
+    export_hub_model_id: Optional[str] = field(
+        default=None,
+        metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
+    )
+
+
+@dataclass
+class VllmArguments:
+    r"""
+    Arguments pertaining to the vLLM worker.
+    """
+
+    vllm_maxlen: int = field(
+        default=2048,
+        metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
+    )
+    vllm_gpu_util: float = field(
+        default=0.9,
+        metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
+    )
+    vllm_enforce_eager: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
+    )
+    vllm_max_lora_rank: int = field(
+        default=32,
+        metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
+    )
+
+
+@dataclass
+class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments, VllmArguments):
+    r"""
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
     """
 
-    model_name_or_path: str = field(
+    model_name_or_path: Optional[str] = field(
+        default=None,
         metadata={
             "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
         },
@@ -74,26 +195,6 @@ class ModelArguments:
         default=True,
         metadata={"help": "Whether or not to use memory-efficient model loading."},
     )
-    quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
-        default="bitsandbytes",
-        metadata={"help": "Quantization method to use for on-the-fly quantization."},
-    )
-    quantization_bit: Optional[int] = field(
-        default=None,
-        metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
-    )
-    quantization_type: Literal["fp4", "nf4"] = field(
-        default="nf4",
-        metadata={"help": "Quantization data type to use in int4 training."},
-    )
-    double_quantization: bool = field(
-        default=True,
-        metadata={"help": "Whether or not to use double quantization in int4 training."},
-    )
-    quantization_device_map: Optional[Literal["auto"]] = field(
-        default=None,
-        metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
-    )
     rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
@@ -138,34 +239,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to randomly initialize the model weights."},
     )
-    image_resolution: int = field(
-        default=512,
-        metadata={"help": "Keeps the height or width of image below this resolution."},
-    )
-    video_fps: float = field(
-        default=2.0,
-        metadata={"help": "The frames to sample per second for video training."},
-    )
     infer_backend: Literal["huggingface", "vllm"] = field(
         default="huggingface",
         metadata={"help": "Backend engine used at inference."},
     )
-    vllm_maxlen: int = field(
-        default=2048,
-        metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
-    )
-    vllm_gpu_util: float = field(
-        default=0.9,
-        metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
-    )
-    vllm_enforce_eager: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
-    )
-    vllm_max_lora_rank: int = field(
-        default=32,
-        metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
-    )
     offload_folder: str = field(
         default="offload",
         metadata={"help": "Path to offload model weights."},
@@ -186,42 +263,6 @@ class ModelArguments:
         default=None,
         metadata={"help": "Auth token to log in with ModelScope Hub."},
     )
-    export_dir: Optional[str] = field(
-        default=None,
-        metadata={"help": "Path to the directory to save the exported model."},
-    )
-    export_size: int = field(
-        default=1,
-        metadata={"help": "The file shard size (in GB) of the exported model."},
-    )
-    export_device: Literal["cpu", "auto"] = field(
-        default="cpu",
-        metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
-    )
-    export_quantization_bit: Optional[int] = field(
-        default=None,
-        metadata={"help": "The number of bits to quantize the exported model."},
-    )
-    export_quantization_dataset: Optional[str] = field(
-        default=None,
-        metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
-    )
-    export_quantization_nsamples: int = field(
-        default=128,
-        metadata={"help": "The number of samples used for quantization."},
-    )
-    export_quantization_maxlen: int = field(
-        default=1024,
-        metadata={"help": "The maximum length of the model inputs used for quantization."},
-    )
-    export_legacy_format: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
-    )
-    export_hub_model_id: Optional[str] = field(
-        default=None,
-        metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
-    )
     print_param_status: bool = field(
         default=False,
         metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
@@ -248,6 +289,9 @@ class ModelArguments:
     )
 
     def __post_init__(self):
+        if self.model_name_or_path is None:
+            raise ValueError("Please provide `model_name_or_path`.")
+
         if self.split_special_tokens and self.use_fast_tokenizer:
             raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
 
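Because ModelArguments now simply inherits the regrouped argument dataclasses, code that parses a flat argument namespace should keep working unchanged. A hedged sketch of this, assuming the usual HfArgumentParser flow; the import path and CLI values are assumptions, not taken from this diff:

from transformers import HfArgumentParser

from llamafactory.hparams import ModelArguments  # import path is an assumption

parser = HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(
    ["--model_name_or_path", "meta-llama/Meta-Llama-3-8B-Instruct", "--quantization_bit", "4"]
)
# inherited fields are flattened into one namespace
print(model_args.quantization_bit, model_args.image_resolution, model_args.vllm_maxlen)  # 4 512 2048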
@@ -100,7 +100,9 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         setattr(processor, "tokenizer", tokenizer)
         setattr(processor, "image_seqlen", get_image_seqlen(config))
         setattr(processor, "image_resolution", model_args.image_resolution)
+        setattr(processor, "video_resolution", model_args.video_resolution)
         setattr(processor, "video_fps", model_args.video_fps)
         setattr(processor, "video_maxlen", model_args.video_maxlen)
+        if getattr(config, "model_type", None) == "qwen2_vl":
+            setattr(processor, "video_factor", 2)
+        else:
tests/e2e/test_chat.py (new file, 49 lines)
@@ -0,0 +1,49 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from llamafactory.chat import ChatModel


TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")

INFER_ARGS = {
    "model_name_or_path": TINY_LLAMA,
    "finetuning_type": "lora",
    "template": "llama3",
    "infer_dtype": "float16",
    "do_sample": False,
    "max_new_tokens": 1,
}

MESSAGES = [
    {"role": "user", "content": "Hi"},
]

EXPECTED_RESPONSE = "_rho"


def test_chat():
    chat_model = ChatModel(INFER_ARGS)
    assert chat_model.chat(MESSAGES)[0].response_text == EXPECTED_RESPONSE


def test_stream_chat():
    chat_model = ChatModel(INFER_ARGS)
    response = ""
    for token in chat_model.stream_chat(MESSAGES):
        response += token

    assert response == EXPECTED_RESPONSE
tests/e2e/test_train.py (new file, 70 lines)
@@ -0,0 +1,70 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest

from llamafactory.train.tuner import export_model, run_exp


DEMO_DATA = os.environ.get("DEMO_DATA", "llamafactory/demo_data")

TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")

TINY_LLAMA_ADAPTER = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")

TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA,
    "do_train": True,
    "finetuning_type": "lora",
    "dataset_dir": "REMOTE:" + DEMO_DATA,
    "template": "llama3",
    "cutoff_len": 1,
    "overwrite_cache": True,
    "overwrite_output_dir": True,
    "per_device_train_batch_size": 1,
    "max_steps": 1,
    "fp16": True,
}

INFER_ARGS = {
    "model_name_or_path": TINY_LLAMA,
    "adapter_name_or_path": TINY_LLAMA_ADAPTER,
    "finetuning_type": "lora",
    "template": "llama3",
    "infer_dtype": "float16",
    "export_dir": "llama3_export",
}


@pytest.mark.parametrize(
    "stage,dataset",
    [
        ("pt", "c4_demo"),
        ("sft", "alpaca_en_demo"),
        ("rm", "dpo_en_demo"),
        ("dpo", "dpo_en_demo"),
        ("kto", "kto_en_demo"),
    ],
)
def test_train(stage: str, dataset: str):
    output_dir = "train_{}".format(stage)
    run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
    assert os.path.exists(output_dir)


def test_export():
    export_model(INFER_ARGS)
    assert os.path.exists("llama3_export")
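The new suites are ordinary pytest modules. A hedged usage sketch for running them from a repository checkout; the environment-variable overrides simply mirror the defaults defined above:

import os

import pytest

# defaults match the constants in the test modules above
os.environ.setdefault("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
os.environ.setdefault("DEMO_DATA", "llamafactory/demo_data")

# equivalent to `pytest tests/e2e -v` on the command line
exit_code = pytest.main(["tests/e2e/test_chat.py", "tests/e2e/test_train.py", "-v"])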