Former-commit-id: 38b6b0f52e
This commit is contained in:
hiyouga
2024-06-16 01:06:41 +08:00
parent 96b82ccd4d
commit c0c6b8075a
22 changed files with 27 additions and 25 deletions

View File

@@ -1,6 +1,6 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's TRL library.
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's TRL library.
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -114,8 +114,8 @@ class CustomKTOTrainer(KTOTrainer):
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
super()._save(output_dir, state_dict)
output_dir = output_dir if output_dir is not None else self.args.output_dir
if self.processor is not None:
output_dir = output_dir if output_dir is not None else self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir)
def forward(

View File

@@ -1,6 +1,6 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's TRL library.
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's TRL library.
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's TRL library.
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
# Copyright 2024 the LlamaFactory team.
#
# This code is inspired by CarperAI's trlx library.
# This code is inspired by the CarperAI's trlx library.
# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/reward_model.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -89,8 +89,8 @@ class PairwiseTrainer(Trainer):
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
super()._save(output_dir, state_dict)
output_dir = output_dir if output_dir is not None else self.args.output_dir
if self.processor is not None:
output_dir = output_dir if output_dir is not None else self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir)
def compute_loss(

View File

@@ -1,6 +1,6 @@
# Copyright 2024 the LlamaFactory team.
#
# This code is inspired by CarperAI's trlx library.
# This code is inspired by the CarperAI's trlx library.
# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
# Copyright 2024 HuggingFace Inc., THUDM, and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library and THUDM's ChatGLM implementation.
# This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
# https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
#

View File

@@ -1,6 +1,6 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
#
# Licensed under the Apache License, Version 2.0 (the "License");