diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py
index 4bebcd68..5ae79774 100644
--- a/src/llamafactory/data/parser.py
+++ b/src/llamafactory/data/parser.py
@@ -31,31 +31,31 @@ class DatasetAttr:
     Dataset attributes.
     """

-    """ basic configs """
+    # basic configs
     load_from: Literal["hf_hub", "ms_hub", "script", "file"]
     dataset_name: str
     formatting: Literal["alpaca", "sharegpt"] = "alpaca"
     ranking: bool = False
-    """ extra configs """
+    # extra configs
     subset: Optional[str] = None
     folder: Optional[str] = None
     num_samples: Optional[int] = None
-    """ common columns """
+    # common columns
     system: Optional[str] = None
     tools: Optional[str] = None
     images: Optional[str] = None
-    """ rlhf columns """
+    # rlhf columns
     chosen: Optional[str] = None
     rejected: Optional[str] = None
     kto_tag: Optional[str] = None
-    """ alpaca columns """
+    # alpaca columns
     prompt: Optional[str] = "instruction"
     query: Optional[str] = "input"
     response: Optional[str] = "output"
     history: Optional[str] = None
-    """ sharegpt columns """
+    # sharegpt columns
     messages: Optional[str] = "conversations"
-    """ sharegpt tags """
+    # sharegpt tags
     role_tag: Optional[str] = "from"
     content_tag: Optional[str] = "value"
     user_tag: Optional[str] = "human"
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index a00411c3..d8892a96 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -509,15 +509,19 @@ register_model_group(
         },
         "Gemma-2-9B": {
             DownloadSource.DEFAULT: "google/gemma-2-9b",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b",
         },
         "Gemma-2-27B": {
             DownloadSource.DEFAULT: "google/gemma-2-27b",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b",
         },
         "Gemma-2-9B-Chat": {
             DownloadSource.DEFAULT: "google/gemma-2-9b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
         },
         "Gemma-2-27B-Chat": {
             DownloadSource.DEFAULT: "google/gemma-2-27b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
         },
     },
     template="gemma",
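
The registry entries above map each model name to one repository ID per download source, so the loader can fall back to a ModelScope mirror instead of the Hugging Face Hub. A minimal, self-contained sketch of that lookup under assumed enum values (the resolve_model_path helper and the use_modelscope flag are illustrative, not part of the codebase):

from enum import Enum, unique


@unique
class DownloadSource(str, Enum):
    DEFAULT = "hf"      # Hugging Face Hub repository ID
    MODELSCOPE = "ms"   # ModelScope mirror repository ID


# hypothetical registry entry mirroring the Gemma-2-9B addition above
GEMMA2_9B = {
    DownloadSource.DEFAULT: "google/gemma-2-9b",
    DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b",
}


def resolve_model_path(entry, use_modelscope=False):
    # prefer the ModelScope mirror when requested and registered, else use the Hub ID
    if use_modelscope and DownloadSource.MODELSCOPE in entry:
        return entry[DownloadSource.MODELSCOPE]
    return entry[DownloadSource.DEFAULT]


print(resolve_model_path(GEMMA2_9B, use_modelscope=True))  # LLM-Research/gemma-2-9b
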
diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 8b89e38a..2f9978a5 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -27,6 +27,7 @@ from accelerate.utils import DistributedDataParallelKwargs
 from tqdm import tqdm
 from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
 from transformers.optimization import get_scheduler
+from transformers.trainer import DEFAULT_CALLBACKS
 from transformers.trainer_callback import CallbackHandler
 from transformers.trainer_pt_utils import remove_dummy_checkpoint
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
@@ -105,6 +106,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
         ]
         ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
+        if ppo_config.log_with == "tensorboard":  # tensorboard raises error about accelerator_kwargs
+            ppo_config.log_with = None

         # Create optimizer and scheduler
         if training_args.max_steps > 0:
@@ -143,6 +146,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.control = TrainerControl()
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+        callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
         self.callback_handler = CallbackHandler(
             callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
         )
@@ -339,11 +343,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 batch[k] = v[:, start_index:]

         with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
-            unwrapped_model = self.accelerator.unwrap_model(self.model)  # issue in trl v0.8.6
+            unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
             if self.model_args.upcast_layernorm:
                 layernorm_params = dump_layernorm(unwrapped_model)

-            generate_output: torch.Tensor = unwrapped_model.generate(
+            generate_output: "torch.Tensor" = unwrapped_model.generate(
                 generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
             )
             if self.model_args.upcast_layernorm:
@@ -354,12 +358,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         queries, responses = [], []
         for i in range(len(query)):
             query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
-            response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
+            response_indexes = (response[i] != self.tokenizer.pad_token_id).nonzero()

-            if len(response_index) == 0:
-                response_length = 1  # allow empty response
+            if len(response_indexes) == 0:  # allow empty response
+                response_length = 1
+            elif self.tokenizer.eos_token_id == self.tokenizer.pad_token_id:  # include eos token
+                response_length = response_indexes[-1].item() + 2
             else:
-                response_length = response_index[-1].item() + 1
+                response_length = response_indexes[-1].item() + 1

             queries.append(query[i, query_start_index:])  # remove padding from left
             responses.append(response[i, :response_length])  # remove padding from right
@@ -382,7 +388,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
             return get_rewards_from_server(self.reward_model, messages)

-        batch = self.prepare_model_inputs(queries, responses)
+        batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)

         if self.finetuning_args.reward_model_type == "lora":
@@ -392,7 +398,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             reward_model = self.reward_model

         with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context:  # support bf16
-            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
+            _, _, values = reward_model(**batch, return_dict=True, use_cache=False)

         if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="default")
@@ -400,13 +406,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if self.is_chatglm_model:  # assume same architecture
             values = torch.transpose(values, 0, 1)

-        rewards = []
-        for i in range(values.size(0)):
-            end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
-            end_index = end_indexes[-1].item() if len(end_indexes) else 0
-            rewards.append(values[i, end_index].float().detach().cpu())  # use fp32 type
-
-        return rewards
+        rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
+        return rewards.to(torch.float32).detach().cpu()  # use fp32 type

     @PPODecorators.empty_device_cache()
     def batched_forward_pass(
@@ -440,7 +441,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             attention_mask = input_kwargs["attention_mask"]

             with self.amp_context:  # support bf16
-                logits, _, values = model(**input_kwargs)
+                logits, _, values = model(**input_kwargs, return_dict=True, use_cache=False)

             if self.is_chatglm_model:
                 values = torch.transpose(values, 0, 1)
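
Two behavioral changes above are easy to miss. When the pad token doubles as the eos token, the last non-pad index stops one position before the eos, so the response length is extended by 2 rather than 1 to keep the eos token. And the per-sample loop in get_rewards is replaced by a single gather: attention_mask.sum(dim=-1) - 1 is the index of the last real token in each right-padded sequence, so gathering the value-head output at that index yields one reward per sequence. A minimal sketch of the gather on a toy batch (tensors invented for illustration):

import torch

# toy value-head outputs, shape (batch=2, seq_len=5)
values = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5],
                       [1.0, 2.0, 3.0, 0.0, 0.0]])
# right-padded batch: the second sequence has only 3 real tokens
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])

# index of the last non-padded token per sequence
last_token_index = attention_mask.sum(dim=-1, keepdim=True) - 1   # [[4], [2]]
rewards = values.gather(dim=-1, index=last_token_index)           # [[0.5], [3.0]]
print(rewards.to(torch.float32).squeeze(-1))                      # tensor([0.5000, 3.0000])
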
diff --git a/src/llamafactory/train/rm/metric.py b/src/llamafactory/train/rm/metric.py
index fb880b1c..0dfae013 100644
--- a/src/llamafactory/train/rm/metric.py
+++ b/src/llamafactory/train/rm/metric.py
@@ -12,11 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Dict, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict

 import numpy as np


-def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
-    preds, _ = eval_preds
-    return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}
+if TYPE_CHECKING:
+    from transformers import EvalPrediction
+
+
+def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
+    return {"accuracy": np.mean(eval_preds.predictions[0] > eval_preds.predictions[1])}
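
The new compute_accuracy takes a transformers EvalPrediction whose predictions field carries the chosen and rejected reward arrays; the metric is the fraction of pairs where the chosen score beats the rejected one. A small usage sketch with synthetic scores (it assumes predictions[0] and predictions[1] are 1-D arrays of per-pair scores, which is how the pairwise trainer emits them):

import numpy as np
from transformers import EvalPrediction

# synthetic chosen / rejected reward scores for 4 preference pairs
chosen = np.array([1.2, 0.3, 2.0, -0.5])
rejected = np.array([0.8, 0.9, 1.5, -1.0])
eval_preds = EvalPrediction(predictions=(chosen, rejected), label_ids=np.zeros(4))

accuracy = np.mean(eval_preds.predictions[0] > eval_preds.predictions[1])
print({"accuracy": float(accuracy)})  # 3 of 4 pairs ranked correctly -> 0.75
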
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index accc877d..f7160cfc 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -1,7 +1,7 @@
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by the CarperAI's trlx library.
-# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/reward_model.py
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,28 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
-# MIT License
-#
-# Copyright (c) 2022 CarperAI
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.

 import json
 import os
@@ -53,6 +31,7 @@ from ..trainer_utils import create_custom_optimzer, create_custom_scheduler

 if TYPE_CHECKING:
     from transformers import PreTrainedModel, ProcessorMixin
     from transformers.trainer import PredictionOutput
+    from trl import AutoModelForCausalLMWithValueHead

     from ...hparams import FinetuningArguments
@@ -108,46 +87,23 @@ class PairwiseTrainer(Trainer):
         See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842
         """
         # Compute rewards
-        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
+        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True, use_cache=False)

-        unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
             values = torch.transpose(values, 0, 1)

-        # Split the inputs and rewards into two parts, chosen and rejected
         batch_size = inputs["input_ids"].size(0) // 2
-        chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
-        chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
-        chosen_scores, rejected_scores = [], []
-
-        # Compute pairwise loss. Only backprop on the different tokens before padding
-        loss = 0
-        for i in range(batch_size):
-            chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
-            rejected_length = (rejected_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
-            check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
-
-            if len(check_divergence) == 0:
-                end_index = chosen_length
-                div_index = end_index - 1
-            else:
-                end_index = max(chosen_length, rejected_length)
-                div_index = check_divergence[0]
-
-            assert div_index > 0
-            chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
-            rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
-            if return_outputs:  # use the score on the last token except pad token for inference
-                chosen_scores.append(chosen_rewards[i, chosen_length - 1])
-                rejected_scores.append(rejected_rewards[i, rejected_length - 1])
-            loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
-
-        loss = loss / batch_size
+        chosen_masks, rejected_masks = torch.split(inputs["attention_mask"], batch_size, dim=0)
+        chosen_rewards, rejected_rewards = torch.split(values, batch_size, dim=0)
+        chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
+        rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
+        chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
+        loss = -torch.nn.functional.logsigmoid(chosen_scores - rejected_scores).mean()
         if return_outputs:
-            chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
-            return loss, [loss, chosen_scores, rejected_scores]
-
-        return loss
+            return loss, (loss, chosen_scores, rejected_scores)
+        else:
+            return loss

     def save_predictions(self, predict_results: "PredictionOutput") -> None:
         r"""
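
The rewritten compute_loss drops the per-sample divergence loop: the forward pass sees the chosen and rejected halves stacked along dim 0, the score of each sequence is the value-head output at its last real token (found via the attention mask), and the loss is the pairwise -log sigmoid(chosen - rejected) averaged over the batch. Note that only the final-token score enters the loss now, rather than every token after the first divergence point as before. A minimal sketch of the same tensor operations on toy data (values and masks are invented for illustration):

import torch

batch_size = 2  # 2 chosen + 2 rejected sequences stacked along dim 0
values = torch.tensor([            # value-head outputs, shape (4, seq_len=4)
    [0.2, 0.5, 0.9, 0.0],          # chosen #1
    [0.1, 0.4, 0.0, 0.0],          # chosen #2
    [0.3, 0.1, 0.2, 0.0],          # rejected #1
    [0.0, 0.2, 0.0, 0.0],          # rejected #2
])
attention_mask = torch.tensor([
    [1, 1, 1, 0],
    [1, 1, 0, 0],
    [1, 1, 1, 0],
    [1, 1, 0, 0],
])

chosen_masks, rejected_masks = torch.split(attention_mask, batch_size, dim=0)
chosen_rewards, rejected_rewards = torch.split(values, batch_size, dim=0)
# score of each sequence = value at its last non-padded position
chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1)).squeeze()
rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1)).squeeze()

loss = -torch.nn.functional.logsigmoid(chosen_scores - rejected_scores).mean()
print(chosen_scores, rejected_scores, loss)  # tensor([0.9, 0.4]) tensor([0.2, 0.2]) ...
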
diff --git a/src/llamafactory/train/rm/workflow.py b/src/llamafactory/train/rm/workflow.py
index e0b32b77..384814cc 100644
--- a/src/llamafactory/train/rm/workflow.py
+++ b/src/llamafactory/train/rm/workflow.py
@@ -1,7 +1,7 @@
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by the CarperAI's trlx library.
-# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,28 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
-# MIT License
-#
-# Copyright (c) 2022 CarperAI
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.

 from typing import TYPE_CHECKING, List, Optional
diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py
index c69608c0..86f8bb15 100644
--- a/src/llamafactory/train/sft/metric.py
+++ b/src/llamafactory/train/sft/metric.py
@@ -21,7 +21,6 @@ from typing import TYPE_CHECKING, Dict

 import numpy as np
 import torch
-from transformers import EvalPrediction
 from transformers.utils import is_jieba_available, is_nltk_available

 from ...extras.constants import IGNORE_INDEX
@@ -29,7 +28,7 @@ from ...extras.packages import is_rouge_available


 if TYPE_CHECKING:
-    from transformers import PreTrainedTokenizer
+    from transformers import EvalPrediction, PreTrainedTokenizer


 if is_jieba_available():
diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py
index dc982e07..99f2b660 100644
--- a/src/llamafactory/train/tuner.py
+++ b/src/llamafactory/train/tuner.py
@@ -57,7 +57,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
     elif finetuning_args.stage == "kto":
         run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
     else:
-        raise ValueError("Unknown task.")
+        raise ValueError("Unknown task: {}.".format(finetuning_args.stage))


 def export_model(args: Optional[Dict[str, Any]] = None) -> None:
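
As a footnote on the metric-module import changes: EvalPrediction is now imported only under typing.TYPE_CHECKING, so type checkers see the name while nothing extra is imported at runtime, and the quoted annotation keeps the module valid either way. A minimal sketch of the idiom (the function body is illustrative):

from typing import TYPE_CHECKING, Dict

if TYPE_CHECKING:  # evaluated by static type checkers only, never at runtime
    from transformers import EvalPrediction


def summarize(eval_preds: "EvalPrediction") -> Dict[str, int]:
    # the quoted annotation avoids a NameError even though the import above was skipped
    return {"num_prediction_arrays": len(eval_preds.predictions)}
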