mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-04 18:02:19 +08:00

support mllm hf inference

Former-commit-id: 2c7c01282acd7ddabbb17ce3246b8dae4bc4b8cf
This commit is contained in:
  parent 10a6c395bb
  commit 23b881bff1
@@ -18,7 +18,8 @@ If you are using a custom dataset, please provide your dataset definition in the
     "history": "the column name in the dataset containing the histories. (default: None)",
     "messages": "the column name in the dataset containing the messages. (default: conversations)",
     "system": "the column name in the dataset containing the system prompts. (default: None)",
-    "tools": "the column name in the dataset containing the tool description. (default: None)"
+    "tools": "the column name in the dataset containing the tool description. (default: None)",
+    "images": "the column name in the dataset containing the image inputs. (default: None)"
   },
   "tags (optional, used for the sharegpt format)": {
     "role_tag": "the key in the message represents the identity. (default: from)",
@@ -18,7 +18,8 @@
     "history": "数据集代表历史对话的表头名称(默认:None)",
     "messages": "数据集代表消息列表的表头名称(默认:conversations)",
     "system": "数据集代表系统提示的表头名称(默认:None)",
-    "tools": "数据集代表工具描述的表头名称(默认:None)"
+    "tools": "数据集代表工具描述的表头名称(默认:None)",
+    "images": "数据集代表图像输入的表头名称(默认:None)"
   },
   "tags(可选,用于 sharegpt 格式)": {
     "role_tag": "消息中代表发送者身份的键名(默认:from)",
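The new "images" column lets a dataset definition point at per-sample image inputs. As a rough illustration only (the dataset name, file name, and paths below are hypothetical and not taken from this commit), a dataset_info.json entry and a matching sharegpt-style record could look like this, shown as Python dicts:

# Hypothetical dataset_info.json entry for a multimodal dataset.
dataset_info_entry = {
    "my_mllm_dataset": {                      # hypothetical dataset name
        "file_name": "my_mllm_dataset.json",  # hypothetical data file
        "formatting": "sharegpt",
        "columns": {
            "messages": "conversations",
            "images": "images",               # the column introduced by this change
        },
    }
}

# One hypothetical record: the conversation plus the image(s) it refers to.
sample_record = {
    "conversations": [
        {"from": "human", "value": "What is shown in the picture? <image>"},
        {"from": "gpt", "value": "A dog playing in the park."},
    ],
    "images": ["images/0001.jpg"],            # hypothetical relative path
}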
@@ -9,6 +9,7 @@ examples/
 │   ├── ppo.sh: Do PPO training using LoRA
 │   ├── dpo.sh: Do DPO training using LoRA
 │   ├── orpo.sh: Do ORPO training using LoRA
+│   ├── sft_mllm.sh: Do supervised fine-tuning on multimodal data using LoRA
 │   ├── prepare.sh: Save tokenized dataset
 │   └── predict.sh: Do batch predict and compute BLEU and ROUGE scores after LoRA tuning
 ├── qlora_single_gpu/
@@ -9,6 +9,7 @@ examples/
 │   ├── ppo.sh: 基于 LoRA 进行 PPO 训练
 │   ├── dpo.sh: 基于 LoRA 进行 DPO 训练
 │   ├── orpo.sh: 基于 LoRA 进行 ORPO 训练
+│   ├── sft_mllm.sh: 基于 LoRA 进行多模态指令监督微调
 │   ├── prepare.sh: 保存预处理后的数据集
 │   └── predict.sh: 基于 LoRA 进行批量预测并计算 BLEU 和 ROUGE 分数
 ├── qlora_single_gpu/
@@ -1,32 +1,33 @@
 #!/bin/bash
 
-CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
-    --stage sft_mm \
+CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
+    --stage sft \
     --do_train \
     --model_name_or_path llava-hf/llava-1.5-7b-hf \
-    --dataset mllm_instruct_example \
-    --dataset_dir data \
-    --template default \
+    --visual_inputs \
+    --dataset mllm_demo \
+    --dataset_dir ../../data \
+    --template vicuna \
     --finetuning_type lora \
-    --lora_target all \
-    --output_dir saves/llava-1.5-7b/lora/sft \
+    --lora_target q_proj,v_proj \
+    --output_dir ../../saves/LLaMA2-7B/lora/sft_mllm \
     --overwrite_cache \
     --overwrite_output_dir \
     --cutoff_len 1024 \
     --preprocessing_num_workers 16 \
-    --per_device_train_batch_size 3 \
+    --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 1 \
-    --gradient_accumulation_steps 1 \
+    --gradient_accumulation_steps 8 \
     --lr_scheduler_type cosine \
-    --logging_steps 1 \
+    --logging_steps 10 \
     --warmup_steps 20 \
     --save_steps 100 \
     --eval_steps 100 \
     --evaluation_strategy steps \
     --load_best_model_at_end \
     --learning_rate 5e-5 \
-    --num_train_epochs 100 \
+    --num_train_epochs 100.0 \
     --max_samples 3000 \
     --val_size 0.1 \
     --plot_loss \
-    --bf16
+    --fp16
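The same launch can also be expressed programmatically. The sketch below is an assumption: it presumes that train_bash.py simply forwards to llmtuner.train.tuner.run_exp and that run_exp accepts a plain dict of the CLI arguments; neither is shown in this diff, and the paths are placeholders.

# Hedged sketch: assumes run_exp(args: dict) mirrors the CLI flags in sft_mllm.sh above.
from llmtuner.train.tuner import run_exp

run_exp(dict(
    stage="sft",
    do_train=True,
    model_name_or_path="llava-hf/llava-1.5-7b-hf",
    visual_inputs=True,                 # the new multimodal switch
    dataset="mllm_demo",
    template="vicuna",
    finetuning_type="lora",
    lora_target="q_proj,v_proj",
    output_dir="saves/LLaMA2-7B/lora/sft_mllm",   # placeholder output path
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=5e-5,
    num_train_epochs=100.0,
    fp16=True,
))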
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti


 if TYPE_CHECKING:
+    from numpy.typing import NDArray
     from transformers import PreTrainedModel, PreTrainedTokenizer
     from vllm import AsyncLLMEngine


@@ -46,6 +47,7 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> List["Response"]: ...


@@ -55,6 +57,7 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]: ...


@@ -8,6 +8,8 @@ from .vllm_engine import VllmEngine


 if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
     from .base_engine import BaseEngine, Response


@@ -36,9 +38,10 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> List["Response"]:
-        task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop)
+        task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
         return task.result()

     async def achat(

@@ -46,18 +49,20 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> List["Response"]:
-        return await self.engine.chat(messages, system, tools, **input_kwargs)
+        return await self.engine.chat(messages, system, tools, image, **input_kwargs)

     def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
-        generator = self.astream_chat(messages, system, tools, **input_kwargs)
+        generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
         while True:
             try:
                 task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)

@@ -70,9 +75,10 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
-        async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs):
+        async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
             yield new_token

     def get_scores(
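With the image argument threaded through ChatModel, a multimodal chat call can be driven from Python. The following is a minimal sketch rather than code from this commit: the model path, template, and image file are placeholders, and it assumes ChatModel accepts the same argument dict as the CLI.

# Hedged usage sketch for the image-aware chat API (argument values are hypothetical).
import numpy as np
from PIL import Image
from llmtuner.chat import ChatModel

chat_model = ChatModel(dict(
    model_name_or_path="llava-hf/llava-1.5-7b-hf",   # placeholder model
    template="vicuna",                               # placeholder template
    visual_inputs=True,
))

image = np.asarray(Image.open("example.jpg"))        # NDArray, as the new signature expects
messages = [{"role": "user", "content": "What is in this image? <image>"}]

for response in chat_model.chat(messages, image=image, max_new_tokens=128):
    print(response.response_text)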
@@ -14,7 +14,9 @@ from .base_engine import BaseEngine, Response


 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, PreTrainedTokenizer
+    from numpy.typing import NDArray
+    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+    from transformers.image_processing_utils import BaseImageProcessor
     from trl import PreTrainedModelWrapper

     from ..data import Template

@@ -30,7 +32,9 @@ class HuggingfaceEngine(BaseEngine):
         generating_args: "GeneratingArguments",
     ) -> None:
         self.can_generate = finetuning_args.stage == "sft"
-        self.tokenizer = load_tokenizer(model_args)
+        tokenizer_module = load_tokenizer(model_args)
+        self.tokenizer = tokenizer_module["tokenizer"]
+        self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left" if self.can_generate else "right"
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
         self.model = load_model(
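Throughout this commit, load_tokenizer switches from returning a bare tokenizer to returning a small module dict, so callers that need the multimodal processor can pick it up next to the tokenizer. The sketch below is only an illustration of the shape implied by the call sites (hf_engine, vllm_engine, the training workflows, and the evaluator), not the actual implementation.

# Illustrative shape of the new load_tokenizer return value, inferred from its call sites.
from typing import Any, Dict

def load_tokenizer(model_args) -> Dict[str, Any]:
    tokenizer = ...   # loaded as before, e.g. AutoTokenizer.from_pretrained(...)
    # A processor is presumably only loaded for visual models; otherwise it stays None,
    # which is why hf_engine guards on `processor is not None and image is not None`.
    processor = ...
    return {"tokenizer": tokenizer, "processor": processor}

# Callers either unpack it explicitly or forward both entries as keyword arguments:
# tokenizer_module = load_tokenizer(model_args)
# tokenizer = tokenizer_module["tokenizer"]
# dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)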
@@ -42,13 +46,18 @@ class HuggingfaceEngine(BaseEngine):
     def _process_args(
         model: "PreTrainedModel",
         tokenizer: "PreTrainedTokenizer",
+        processor: Optional["ProcessorMixin"],
         template: "Template",
         generating_args: Dict[str, Any],
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
+        if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
+            messages[0]["content"] = messages[0]["content"] + "<image>"
+
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         prompt_ids, _ = template.encode_oneturn(
             tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools

@@ -95,6 +104,11 @@ class HuggingfaceEngine(BaseEngine):
             logits_processor=get_logits_processor(),
         )

+        if processor is not None and image is not None:
+            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+            pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
+            gen_kwargs["pixel_values"] = pixel_values.to(model.device)
+
         return gen_kwargs, prompt_length

     @staticmethod
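For context, the same pixel-value plumbing can be exercised directly with transformers. The snippet below is an independent illustration of the underlying Hugging Face API rather than code from this repository, and the prompt string and image path are simplified stand-ins.

# Standalone transformers sketch of image preprocessing for a LLaVA-style model.
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

image = Image.open("example.jpg")                               # placeholder image file
prompt = "USER: <image>\nWhat is in this picture? ASSISTANT:"   # simplified vicuna-style prompt

# The processor tokenizes the text and produces pixel_values for the vision tower,
# mirroring what _process_args assembles into gen_kwargs above.
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])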
@@ -102,15 +116,17 @@ class HuggingfaceEngine(BaseEngine):
     def _chat(
         model: "PreTrainedModel",
         tokenizer: "PreTrainedTokenizer",
+        processor: Optional["ProcessorMixin"],
         template: "Template",
         generating_args: Dict[str, Any],
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
-            model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
+            model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
         )
         generate_output = model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]

@@ -135,15 +151,17 @@ class HuggingfaceEngine(BaseEngine):
     def _stream_chat(
         model: "PreTrainedModel",
         tokenizer: "PreTrainedTokenizer",
+        processor: Optional["ProcessorMixin"],
         template: "Template",
         generating_args: Dict[str, Any],
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
-            model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
+            model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
         )
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer

@@ -199,6 +217,7 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> List["Response"]:
         if not self.can_generate:

@@ -208,11 +227,13 @@ class HuggingfaceEngine(BaseEngine):
         input_args = (
             self.model,
             self.tokenizer,
+            self.processor,
             self.template,
             self.generating_args,
             messages,
             system,
             tools,
+            image,
             input_kwargs,
         )
         async with self._semaphore:

@@ -224,6 +245,7 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:

@@ -233,11 +255,13 @@ class HuggingfaceEngine(BaseEngine):
         input_args = (
             self.model,
             self.tokenizer,
+            self.processor,
             self.template,
             self.generating_args,
             messages,
             system,
             tools,
+            image,
             input_kwargs,
         )
         async with self._semaphore:
@@ -12,7 +12,10 @@ if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest


 if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


@@ -29,7 +32,9 @@ class VllmEngine(BaseEngine):
         infer_dtype = str(infer_dtype).split(".")[-1]

         self.can_generate = finetuning_args.stage == "sft"
-        self.tokenizer = load_tokenizer(model_args)
+        tokenizer_module = load_tokenizer(model_args)
+        self.tokenizer = tokenizer_module["tokenizer"]
+        self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left"
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
         self.generating_args = generating_args.to_dict()

@@ -58,6 +63,7 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = "chatcmpl-{}".format(uuid.uuid4().hex)

@@ -121,10 +127,11 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> List["Response"]:
         final_output = None
-        generator = await self._generate(messages, system, tools, **input_kwargs)
+        generator = await self._generate(messages, system, tools, image, **input_kwargs)
         async for request_output in generator:
             final_output = request_output

@@ -146,10 +153,11 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         generated_text = ""
-        generator = await self._generate(messages, system, tools, **input_kwargs)
+        generator = await self._generate(messages, system, tools, image, **input_kwargs)
         async for result in generator:
             delta_text = result.outputs[0].text[len(generated_text) :]
             generated_text = result.outputs[0].text
@@ -8,7 +8,7 @@ from .utils import Role


 if TYPE_CHECKING:
-    from PIL import Image
+    from PIL.Image import Image
     from transformers import ProcessorMixin, Seq2SeqTrainingArguments
     from transformers.image_processing_utils import BaseImageProcessor
     from transformers.tokenization_utils import PreTrainedTokenizer

@@ -271,7 +271,11 @@ def get_preprocess_and_print_func(
     processor: Optional["ProcessorMixin"],
 ) -> Tuple[Callable, Callable]:
     if stage == "pt":
-        preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
+        preprocess_func = partial(
+            preprocess_pretrain_dataset,
+            tokenizer=tokenizer,
+            data_args=data_args,
+        )
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not training_args.predict_with_generate:
         if data_args.packing:
@@ -21,7 +21,7 @@ from .template import get_eval_template
 class Evaluator:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
-        self.tokenizer = load_tokenizer(self.model_args)
+        self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
         self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
         self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
         self.model = load_model(self.tokenizer, self.model_args, finetuning_args)

@@ -196,6 +196,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")

+    if model_args.visual_inputs and data_args.packing:
+        raise ValueError("Cannot use packing in MLLM fine-tuning.")
+
     _verify_model_args(model_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)
@@ -24,8 +24,9 @@ def run_dpo(
     finetuning_args: "FinetuningArguments",
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    tokenizer = load_tokenizer(model_args)
-    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

     data_collator = PairwiseDataCollatorWithPadding(
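The same refactor repeats in run_dpo above and in the other workflows below: load_tokenizer now yields both tokenizer and processor, and forwarding the whole dict with ** hands them to get_dataset as keyword arguments. A small hedged sketch of that pattern, assuming from these call sites that get_dataset now accepts tokenizer and processor keywords (import paths follow the package layout implied by the diff and may differ slightly):

# Hedged illustration of the **tokenizer_module forwarding used in the training workflows.
from llmtuner.data import get_dataset
from llmtuner.model import load_model, load_tokenizer

def build_inputs(model_args, data_args, training_args, finetuning_args):
    tokenizer_module = load_tokenizer(model_args)      # {"tokenizer": ..., "processor": ...}
    tokenizer = tokenizer_module["tokenizer"]
    # "**tokenizer_module" expands to tokenizer=... and processor=... keyword arguments,
    # so get_dataset can hand the processor to the multimodal preprocessing path.
    dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
    return dataset, model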
@@ -24,8 +24,9 @@ def run_orpo(
     finetuning_args: "FinetuningArguments",
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    tokenizer = load_tokenizer(model_args)
-    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

     data_collator = PairwiseDataCollatorWithPadding(
@@ -27,8 +27,9 @@ def run_ppo(
     generating_args: "GeneratingArguments",
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    tokenizer = load_tokenizer(model_args)
-    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)

     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
@@ -25,8 +25,9 @@ def run_pt(
     finetuning_args: "FinetuningArguments",
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    tokenizer = load_tokenizer(model_args)
-    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt")
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
@@ -25,8 +25,9 @@ def run_rm(
     finetuning_args: "FinetuningArguments",
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    tokenizer = load_tokenizer(model_args)
-    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
     data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
@@ -29,9 +29,9 @@ def run_sft(
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
-    dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
     tokenizer = tokenizer_module["tokenizer"]
-    model = load_model(tokenizer, model_args, finetuning_args, is_trainable=training_args.do_train)
+    dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

     if training_args.predict_with_generate:
         tokenizer.padding_side = "left"  # use left-padding in generation
@@ -52,7 +52,7 @@ def export_model(args: Optional[Dict[str, Any]] = None):
     if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
         raise ValueError("Please merge adapters before quantizing the model.")

-    tokenizer = load_tokenizer(model_args)
+    tokenizer = load_tokenizer(model_args)["tokenizer"]
     get_template_and_fix_tokenizer(tokenizer, data_args.template)
     model = load_model(tokenizer, model_args, finetuning_args)  # must after fixing tokenizer to resize vocab
@@ -91,7 +91,7 @@ def create_ref_model(
         )
         ref_model_args = ModelArguments(**ref_model_args_dict)
         ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
-        tokenizer = load_tokenizer(ref_model_args)
+        tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
         ref_model = load_model(
             tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
         )

@@ -100,7 +100,7 @@ def create_ref_model(
         if finetuning_args.finetuning_type == "lora":
             ref_model = None
         else:
-            tokenizer = load_tokenizer(model_args)
+            tokenizer = load_tokenizer(model_args)["tokenizer"]
             ref_model = load_model(
                 tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
             )

@@ -147,7 +147,7 @@ def create_reward_model(
         )
         reward_model_args = ModelArguments(**reward_model_args_dict)
         reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
-        tokenizer = load_tokenizer(reward_model_args)
+        tokenizer = load_tokenizer(reward_model_args)["tokenizer"]
         reward_model = load_model(
             tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
         )
@@ -2,6 +2,8 @@ import json
 import os
 from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple

+from numpy.typing import NDArray
+
 from ..chat import ChatModel
 from ..data import Role
 from ..extras.misc import torch_gc

@@ -112,6 +114,7 @@ class WebChatModel(ChatModel):
         messages: Sequence[Dict[str, str]],
         system: str,
         tools: str,
+        image: Optional[NDArray],
         max_new_tokens: int,
         top_p: float,
         temperature: float,

@@ -119,7 +122,7 @@ class WebChatModel(ChatModel):
         chatbot[-1][1] = ""
         response = ""
         for new_text in self.stream_chat(
-            messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            messages, system, tools, image, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
         ):
             response += new_text
             if tools:
@@ -23,9 +23,15 @@ def create_chat_box(
         messages = gr.State([])
         with gr.Row():
             with gr.Column(scale=4):
-                role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
-                system = gr.Textbox(show_label=False)
-                tools = gr.Textbox(show_label=False, lines=2)
+                with gr.Row():
+                    with gr.Column():
+                        role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
+                        system = gr.Textbox(show_label=False)
+                        tools = gr.Textbox(show_label=False, lines=4)
+
+                    with gr.Column():
+                        image = gr.Image(type="numpy")
+
                 query = gr.Textbox(show_label=False, lines=8)
                 submit_btn = gr.Button(variant="primary")

@@ -43,7 +49,7 @@ def create_chat_box(
         [chatbot, messages, query],
     ).then(
         engine.chatter.stream,
-        [chatbot, messages, system, tools, max_new_tokens, top_p, temperature],
+        [chatbot, messages, system, tools, image, max_new_tokens, top_p, temperature],
         [chatbot, messages],
     )
     clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])

@@ -56,6 +62,7 @@ def create_chat_box(
             role=role,
             system=system,
             tools=tools,
+            image=image,
             query=query,
             submit_btn=submit_btn,
             max_new_tokens=max_new_tokens,
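In the Web UI, the new gr.Image(type="numpy") component hands the uploaded picture to the chat engine as a NumPy array. The following is a stripped-down sketch of that wiring, independent of this repository's component factory; the component names and the dummy reply function are illustrative only.

# Minimal gradio sketch of the image -> NDArray -> streaming chat wiring (illustrative only).
from typing import Optional

import gradio as gr
import numpy as np

def stream_reply(query: str, image: Optional[np.ndarray]):
    # Stand-in for WebChatModel.stream(): a real implementation would forward
    # (messages, system, tools, image, ...) to ChatModel.stream_chat().
    note = "no image attached" if image is None else "got an image of shape {}".format(image.shape)
    reply = ""
    for token in (query + " (" + note + ")").split():
        reply += token + " "
        yield reply

with gr.Blocks() as demo:
    image = gr.Image(type="numpy")          # yields a NumPy array, as in create_chat_box
    query = gr.Textbox(lines=4)
    output = gr.Textbox()
    gr.Button("Submit").click(stream_reply, [query, image], [output])

demo.queue().launch()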
@@ -1073,6 +1073,17 @@ LOCALES = {
             "placeholder": "工具列表(非必填)",
         },
     },
+    "image": {
+        "en": {
+            "label": "Image (optional)",
+        },
+        "ru": {
+            "label": "Изображение (по желанию)",
+        },
+        "zh": {
+            "label": "图像(非必填)",
+        },
+    },
     "query": {
         "en": {
             "placeholder": "Input...",