Mirror of https://github.com/hiyouga/LLaMA-Factory.git
support RM metrics, add generating Args
Former-commit-id: cec6524d6b1be65c5d171a5b3dcaae7818132bc5
parent 5fe70c9350
commit 1fbda5d139
Binary file not shown. Size before: 141 KiB, size after: 146 KiB.
@@ -1 +1 @@
-f437d58b7791609ee91f064551c5c5734a0fd97a
+f5cb08305ff5dc9c17a09809c54c8c8834aadc70

@@ -1 +1 @@
-0e346cf70e633456c7e83f68765361016005447a
+aee47b7b443496e37808d7f34ef10403ff99bcc3
@@ -79,11 +79,11 @@
   },
   "comparison_gpt4_en": {
     "file_name": "comparison_gpt4_data_en.json",
-    "file_sha1": "eeb295ce0ab011c37af52596460c8a57d07ad19f"
+    "file_sha1": "96fa18313544e22444fe20eead7754b17da452ae"
   },
   "comparison_gpt4_zh": {
     "file_name": "comparison_gpt4_data_zh.json",
-    "file_sha1": "b99a41c1c864019d9b0c07dbcd5df0560cf33ce0"
+    "file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd"
   },
   "hh_rlhf_en": {
     "script_url": "hh_rlhf_en",
@@ -103,14 +103,5 @@
       "response": "",
       "history": ""
     }
   },
-  "pretrain_data": {
-    "file_name": "pretrain_data",
-    "columns": {
-      "prompt": "content",
-      "query": "",
-      "response": "",
-      "history": ""
-    }
-  }
 }
@@ -2,11 +2,11 @@ torch>=1.13.1
 protobuf
 cpm_kernels
 sentencepiece
-transformers>=4.27.4
-datasets>=2.10.0
-accelerate>=0.18.0
+transformers>=4.29.1
+datasets>=2.12.0
+accelerate>=0.19.0
 peft>=0.3.0
-trl>=0.4.1
+trl>=0.4.4
 jieba
 rouge_chinese
 nltk
@@ -42,7 +42,7 @@ app = FastAPI()

 @app.post("/")
 async def create_item(request: Request):
-    global model, tokenizer, prompt_template
+    global model, tokenizer, prompt_template, generating_args

     # Parse the request JSON
     json_post_raw = await request.json()

@@ -56,16 +56,9 @@ async def create_item(request: Request):
     input_ids = input_ids.to(model.device)

     # Generation arguments
-    gen_kwargs = {
-        "input_ids": input_ids,
-        "do_sample": True,
-        "top_p": 0.7,
-        "temperature": 0.95,
-        "num_beams": 1,
-        "max_new_tokens": 512,
-        "repetition_penalty": 1.0,
-        "logits_processor": get_logits_processor()
-    }
+    gen_kwargs = generating_args.to_dict()
+    gen_kwargs["input_ids"] = input_ids
+    gen_kwargs["logits_processor"] = get_logits_processor()

     # Generate response
     with torch.no_grad():

@@ -95,7 +88,7 @@ async def create_item(request: Request):


 if __name__ == "__main__":
-    model_args, data_args, finetuning_args = prepare_infer_args()
+    model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
     model, tokenizer = load_pretrained(model_args, finetuning_args)
     prompt_template = Template(data_args.prompt_template)

@@ -15,7 +15,7 @@ from transformers import TextIteratorStreamer

 def main():

-    model_args, data_args, finetuning_args = prepare_infer_args()
+    model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
     model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
     model, tokenizer = load_pretrained(model_args, finetuning_args)

@@ -25,17 +25,10 @@ def main():
     def predict_and_print(query, history: list):
         input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"]
         input_ids = input_ids.to(model.device)
-        gen_kwargs = {
-            "input_ids": input_ids,
-            "do_sample": True,
-            "top_p": 0.7,
-            "temperature": 0.95,
-            "num_beams": 1,
-            "max_new_tokens": 512,
-            "repetition_penalty": 1.0,
-            "logits_processor": get_logits_processor(),
-            "streamer": streamer
-        }
+        gen_kwargs = generating_args.to_dict()
+        gen_kwargs["input_ids"] = input_ids
+        gen_kwargs["logits_processor"] = get_logits_processor()
+        gen_kwargs["streamer"] = streamer
         thread = Thread(target=model.generate, kwargs=gen_kwargs)
         thread.start()
         response = ""
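Note: both demo scripts above now build their generation kwargs from the shared GeneratingArguments dataclass instead of hard-coded literals. Below is a minimal sketch of the resulting generation call, assuming `model`, `tokenizer`, `generating_args` and `get_logits_processor` are in scope as in the API demo above; the helper name `generate_reply` is illustrative and not part of the commit.

import torch

def generate_reply(prompt: str) -> str:
    # Tokenize the prompt and move it to the model device
    input_ids = tokenizer([prompt], return_tensors="pt")["input_ids"].to(model.device)

    # Start from the decoding defaults defined in GeneratingArguments
    # (do_sample, temperature, top_p, top_k, num_beams, max_new_tokens, repetition_penalty)
    gen_kwargs = generating_args.to_dict()
    gen_kwargs["input_ids"] = input_ids
    gen_kwargs["logits_processor"] = get_logits_processor()

    with torch.no_grad():
        generation_output = model.generate(**gen_kwargs)

    # Decode only the newly generated tokens, skipping the prompt
    return tokenizer.decode(generation_output[0][len(input_ids[0]):], skip_special_tokens=True)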
@@ -6,18 +6,17 @@
 import math

 from torch.optim import AdamW
-
 from transformers.optimization import get_scheduler
 from trl import PPOConfig

 from utils import (
-    prepare_args,
-    prepare_data,
-    load_pretrained,
-    preprocess_data,
     DynamicDataCollatorWithPadding,
     PPOPeftTrainer,
     LogCallback,
+    load_pretrained,
+    prepare_args,
+    prepare_data,
+    preprocess_data,
     plot_loss
 )

@@ -29,7 +28,7 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
-    data_collator = DynamicDataCollatorWithPadding(tokenizer, model.pretrained_model)
+    data_collator = DynamicDataCollatorWithPadding(tokenizer)

     ppo_config = PPOConfig(
         model_name=model_args.model_name_or_path,
@@ -5,14 +5,15 @@


 import math

 from utils import (
+    DynamicDataCollatorWithPadding,
+    PeftTrainer,
+    LogCallback,
     load_pretrained,
     prepare_args,
     prepare_data,
     preprocess_data,
-    DynamicDataCollatorWithPadding,
-    PeftTrainer,
-    LogCallback,
     plot_loss
 )

@@ -24,7 +25,7 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")
-    data_collator = DynamicDataCollatorWithPadding(tokenizer, model, data_args.ignore_pad_token_for_loss)
+    data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)

     # Split the dataset
     if training_args.do_train:
@@ -6,13 +6,14 @@


 from utils import (
-    prepare_args,
-    prepare_data,
-    load_pretrained,
-    preprocess_data,
     PairwiseDataCollatorWithPadding,
     PairwisePeftTrainer,
     LogCallback,
+    load_pretrained,
+    prepare_args,
+    prepare_data,
+    preprocess_data,
+    compute_accuracy,
     plot_loss
 )

@@ -23,7 +24,7 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
-    data_collator = PairwiseDataCollatorWithPadding(tokenizer, model.pretrained_model)
+    data_collator = PairwiseDataCollatorWithPadding(tokenizer)

     training_args.remove_unused_columns = False # important for pairwise dataset

@@ -45,6 +46,7 @@
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=[LogCallback()],
+        compute_metrics=compute_accuracy,
         **trainer_kwargs
     )

@@ -5,14 +5,14 @@


 from utils import (
-    load_pretrained,
-    prepare_args,
-    prepare_data,
-    preprocess_data,
     DynamicDataCollatorWithPadding,
     Seq2SeqPeftTrainer,
     ComputeMetrics,
     LogCallback,
+    load_pretrained,
+    prepare_args,
+    prepare_data,
+    preprocess_data,
     get_logits_processor,
     plot_loss
 )

@@ -25,7 +25,7 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
-    data_collator = DynamicDataCollatorWithPadding(tokenizer, model, data_args.ignore_pad_token_for_loss)
+    data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)

     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length if \
@@ -11,7 +11,7 @@ from .data_collator import DynamicDataCollatorWithPadding
 from .peft_trainer import PeftTrainer, LogCallback

 from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
-from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer
+from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer, compute_accuracy
 from .ppo import PPOPeftTrainer

 from .template import Template
@@ -36,7 +36,8 @@ from trl import AutoModelForCausalLMWithValueHead
 from .config import (
     ModelArguments,
     DataTrainingArguments,
-    FinetuningArguments
+    FinetuningArguments,
+    GeneratingArguments
 )

 from .template import Template

@@ -54,7 +55,8 @@ check_min_version("4.29.1")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
 require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
-require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")
+require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")


 logger = get_logger(__name__)

@@ -91,12 +93,10 @@ def _init_adapter(

     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
-            assert is_mergeable and len(
-                model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
+            assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
             load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
         else:
-            assert is_mergeable or len(
-                model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
+            assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."

     if finetuning_args.finetuning_type == "lora":
         logger.info("Fine-tuning method: LoRA")

@@ -106,8 +106,7 @@ def _init_adapter(
             assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
                 "The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."

-            if (is_trainable and model_args.resume_lora_training) or (
-                    not is_mergeable): # continually train on the lora weights
+            if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
                 checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
             else:
                 checkpoints_to_merge = model_args.checkpoint_dir

@@ -186,11 +185,9 @@ def load_pretrained(
         )
     elif model_args.quantization_bit == 4:
         require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-        require_version("transformers>=4.30.0.dev0",
-                        "To fix: pip install git+https://github.com/huggingface/transformers.git")
+        require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
+        require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
         require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
-        require_version("accelerate>=0.20.0.dev0",
-                        "To fix: pip install git+https://github.com/huggingface/accelerate.git")
         config_kwargs["load_in_4bit"] = True
         config_kwargs["quantization_config"] = BitsAndBytesConfig(
             load_in_4bit=True,

@@ -201,10 +198,10 @@ def load_pretrained(
         else:
             raise NotImplementedError
         is_mergeable = False
-        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK") or 0)}
+        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

-    if not is_trainable:
+    if not is_trainable: # `device_map=auto` should be used for inference only
        config_kwargs["device_map"] = "auto"

    # Load and prepare pretrained models (without valuehead).

@@ -221,6 +218,13 @@
     if stage == "rm" or stage == "ppo": # add value head
         model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

+        if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
+            load_valuehead_params(model, model_args.checkpoint_dir[0])
+            model.v_head.load_state_dict({
+                "summary.weight": getattr(model, "reward_head_weight"),
+                "summary.bias": getattr(model, "reward_head_bias")
+            })
+
         if stage == "ppo": # load reward model
             assert is_trainable, "PPO stage cannot be performed at evaluation."
             assert model_args.reward_model is not None, "Reward model is necessary for PPO training."

@@ -228,11 +232,6 @@
             model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
             load_valuehead_params(model, model_args.reward_model)

-    # Set the parameter _is_int8_training_enabled for the AutoModelForCausalLMWithValueHead model
-    # To meet the compliance requirements of the transformers library
-    if model_args.quantization_bit is not None:
-        model._is_int8_training_enabled = True
-
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
         model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16

@@ -245,11 +244,11 @@
 def prepare_args(
         stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:

     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))

     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
-        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(
-            json_file=os.path.abspath(sys.argv[1]))
+        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
         model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()

@@ -313,13 +312,14 @@ def prepare_args(
     return model_args, data_args, training_args, finetuning_args


-def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments]:
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments))
+def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments]:
+
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments))

     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
-        model_args, data_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+        model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
-        model_args, data_args, finetuning_args = parser.parse_args_into_dataclasses()
+        model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()

     if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
         raise ValueError("Quantization is only compatible with the LoRA method.")

@@ -327,13 +327,14 @@ def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, Finetun
     if data_args.prompt_template == "alpaca":
         logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")

-    return model_args, data_args, finetuning_args
+    return model_args, data_args, finetuning_args, generating_args


 def prepare_data(
         model_args: ModelArguments,
         data_args: DataTrainingArguments
 ) -> Dataset:
+
     def checksum(file_path, hash):
         with open(file_path, "rb") as datafile:
             binary_data = datafile.read()

@@ -358,10 +359,12 @@ def prepare_data(
         elif dataset_attr.load_from == "file":
             data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
             extension = dataset_attr.file_name.split(".")[-1]

             if dataset_attr.file_sha1 is not None:
                 checksum(data_file, dataset_attr.file_sha1)
+            else:
+                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")

             raw_datasets = load_dataset(
                 extension if extension in ["csv", "json"] else "text",
                 data_files=data_file,

@@ -406,6 +409,7 @@ def preprocess_data(
         training_args: Seq2SeqTrainingArguments,
         stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Dataset:
+
     column_names = list(dataset.column_names)
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
     prompt_template = Template(data_args.prompt_template)
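Note: the checksum helper in the hunks above is only shown up to the point where the dataset file is read. A plausible completion of the SHA-1 comparison it implies is sketched below; the hashing step and the warning message are assumptions for illustration, not part of the diff, and `logger` is assumed to be the module logger used above.

import hashlib

def checksum(file_path, hash):
    # Read the dataset file and compare its SHA-1 digest against the value
    # recorded in dataset_info.json (assumed behaviour; only the read is shown in the diff)
    with open(file_path, "rb") as datafile:
        binary_data = datafile.read()
    sha1 = hashlib.sha1(binary_data).hexdigest()
    if sha1 != hash:
        logger.warning("Checksum failed for {}.".format(file_path))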
@@ -1,12 +1,13 @@
 import os
 import json
 import torch
-from typing import List, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional
 from dataclasses import asdict, dataclass, field


 @dataclass
 class DatasetAttr:

     load_from: str
     dataset_name: Optional[str] = None
     file_name: Optional[str] = None

@@ -55,11 +56,11 @@ class ModelArguments:
     )
     quantization_type: Optional[Literal["fp4", "nf4"]] = field(
         default="nf4",
-        metadata={"help": "Quantization data type to use."}
+        metadata={"help": "Quantization data type to use in int4 training."}
     )
     double_quantization: Optional[bool] = field(
         default=True,
-        metadata={"help": "Compress the quantization statistics through double quantization."}
+        metadata={"help": "Whether to use double quantization in int4 training or not."}
     )
     compute_dtype: Optional[torch.dtype] = field(
         default=None,

@@ -67,8 +68,7 @@ class ModelArguments:
     )
     checkpoint_dir: Optional[str] = field(
         default=None,
-        metadata={
-            "help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
+        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
     )
     reward_model: Optional[str] = field(
         default=None,

@@ -76,8 +76,7 @@ class ModelArguments:
     )
     resume_lora_training: Optional[bool] = field(
         default=True,
-        metadata={
-            "help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
+        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
     )
     plot_loss: Optional[bool] = field(
         default=False,

@@ -156,41 +155,24 @@ class DataTrainingArguments:
         for name in dataset_names:
             if name not in dataset_info:
                 raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
-            dataset_attrs = []
-            dataset_attr = None

             if "hf_hub_url" in dataset_info[name]:
                 dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
             elif "script_url" in dataset_info[name]:
                 dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
-            elif os.path.isfile(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])):
+            else:
                 dataset_attr = DatasetAttr(
                     "file",
                     file_name=dataset_info[name]["file_name"],
                     file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
                 )
-            else:
-                # Support Directory
-                for file_name in os.listdir(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])):
-                    path = os.path.join(dataset_info[name]["file_name"], file_name)
-                    dataset_attrs.append(DatasetAttr(
-                        "file",
-                        file_name=path,
-                        file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
-                    ))
-            if dataset_attr is not None:
-                if "columns" in dataset_info[name]:
-                    dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
-                    dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
-                    dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
-                    dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
-                self.dataset_list.append(dataset_attr)
-            else:
-                for i, dataset_attr in enumerate(dataset_attrs):
+
+            if "columns" in dataset_info[name]:
+                dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
+                dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
+                dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
+                dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
+
+            self.dataset_list.append(dataset_attr)

@@ -228,22 +210,20 @@ class FinetuningArguments:
     lora_target: Optional[str] = field(
         default="q_proj,v_proj",
         metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
-                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\"], \
-                  BLOOM choices: [\"query_key_value\", \"dense\", \"dense_\"]"}
+                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"gate_proj\", \"down_proj\"], \
+                  BLOOM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"]"}
     )

     def __post_init__(self):
-        if isinstance(self.lora_target, str):
-            self.lora_target = [target.strip() for target in
-                                self.lora_target.split(",")] # support custom target modules of LoRA
+        if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
+            self.lora_target = [target.strip() for target in self.lora_target.split(",")]

         if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
             trainable_layer_ids = [27 - k for k in range(self.num_layer_trainable)]
         else: # fine-tuning the first n layers if num_layer_trainable < 0
             trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]

-        self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in
-                                 trainable_layer_ids]
+        self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]

         assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."

@@ -259,3 +239,44 @@
         with open(json_path, "r", encoding="utf-8") as f:
             text = f.read()
         return cls(**json.loads(text))
+
+
+@dataclass
+class GeneratingArguments:
+    """
+    Arguments pertaining to specify the decoding parameters.
+    """
+    do_sample: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
+    )
+    temperature: Optional[float] = field(
+        default=0.95,
+        metadata={"help": "The value used to modulate the next token probabilities."}
+    )
+    top_p: Optional[float] = field(
+        default=0.7,
+        metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
+    )
+    top_k: Optional[int] = field(
+        default=50,
+        metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
+    )
+    infer_num_beams: Optional[int] = field(
+        default=1,
+        metadata={"help": "Number of beams for beam search. 1 means no beam search."}
+    )
+    max_new_tokens: Optional[int] = field(
+        default=512,
+        metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
+    )
+    repetition_penalty: Optional[float] = field(
+        default=1.0,
+        metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
+    )
+
+    def to_dict(self) -> Dict[str, Any]:
+        data_dict = asdict(self)
+        num_beams = data_dict.pop("infer_num_beams")
+        data_dict["num_beams"] = num_beams
+        return data_dict
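Note: GeneratingArguments.to_dict() renames infer_num_beams to num_beams so the dictionary can be passed directly to model.generate(). A small usage sketch based only on the defaults shown above:

generating_args = GeneratingArguments()
gen_kwargs = generating_args.to_dict()

assert "infer_num_beams" not in gen_kwargs  # renamed by to_dict()
assert gen_kwargs["num_beams"] == 1
assert gen_kwargs["do_sample"] is True
assert gen_kwargs["top_p"] == 0.7 and gen_kwargs["temperature"] == 0.95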
@@ -3,7 +3,6 @@ import torch
 from typing import Dict, Optional, Sequence, Union

 from transformers import DataCollatorWithPadding, BatchEncoding
-from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils import PreTrainedTokenizer

 from .other import IGNORE_INDEX

@@ -16,11 +15,9 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
     def __init__(
             self,
             tokenizer: PreTrainedTokenizer,
-            model: PreTrainedModel,
             ignore_pad_token_for_loss: Optional[bool] = False
     ):
         super().__init__(tokenizer, padding=True)
-        self.model = model
         self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id

     def get_attention_masks(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
@@ -1,5 +1,6 @@
 import torch
-from typing import Dict, Sequence, Union
+import numpy as np
+from typing import Dict, Sequence, Tuple, Union

 from .data_collator import DynamicDataCollatorWithPadding

@@ -10,6 +11,12 @@ from .other import get_logger
 logger = get_logger(__name__)


+def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
+    preds, _ = eval_preds
+    preds = np.array(preds)
+    return {"accuracy": (preds[:, 0] > preds[:, 1]).sum() / len(preds)}
+
+
 class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding):
     r"""
     Data collator for pairwise data.

@@ -47,5 +54,4 @@ class PairwisePeftTrainer(PeftTrainer):
         _, _, values = model(**inputs)
         r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
         loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
-        outputs = {"r_accept": r_accept, "r_reject": r_reject}
-        return (loss, outputs) if return_outputs else loss
+        return (loss, torch.stack((r_accept, r_reject), dim=-1)) if return_outputs else loss
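Note: compute_accuracy pairs with the trainer change above: the pairwise loss function now returns the accepted and rejected rewards stacked along the last dimension, so the metric simply counts how often the accepted reward is larger. A toy check (the scores are invented for illustration):

import numpy as np

# Stacked (r_accept, r_reject) scores, shaped like the trainer outputs above
preds = np.array([
    [1.3, 0.2],  # accepted response scored higher -> counted as correct
    [0.1, 0.9],  # rejected response scored higher -> counted as incorrect
    [2.0, 1.5],  # correct
])
print(compute_accuracy((preds, None)))  # {'accuracy': 0.666...}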
@@ -14,27 +14,25 @@ class Template:
         return getattr(self, "_format_{}".format(self.name))(query, history, prefix)

     def _format_vanilla(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
-        prompt = prefix
-        if history:
-            for old_query, response in history:
-                prompt += old_query + "\n" + response + "\n"
-        prompt += query
-        return prompt
+        r"""
+        Use for language model inference without histories.
+        """
+        return query

     def _format_alpaca(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
         r"""
         Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
                   https://github.com/ymcui/Chinese-LLaMA-Alpaca
         """
         if prefix:
             prompt = prefix
         else:
             prompt = "Below is an instruction that describes a task. "
             prompt += "Write a response that appropriately completes the request.\n\n"
-        prompt += "Instruction:\n"
         if history:
             for old_query, response in history:
-                prompt += "Human:\n{}\n\nAssistant:\n{}\n\n".format(old_query, response)
-        prompt += "Human:\n{}\n\nAssistant:".format(query)
+                prompt += "### Instruction:\n{}\n\n### Response:\n{}\n\n".format(old_query, response)
+        prompt += "### Instruction:\n{}\n\n### Response:\n".format(query)
         return prompt

     def _format_vicuna(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
@@ -21,7 +21,7 @@ from transformers.utils.versions import require_version
 require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")


-model_args, data_args, finetuning_args = prepare_infer_args()
+model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
 model, tokenizer = load_pretrained(model_args, finetuning_args)

 prompt_template = Template(data_args.prompt_template)

@@ -87,9 +87,9 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
         "do_sample": True,
         "top_p": top_p,
         "temperature": temperature,
-        "num_beams": 1,
+        "num_beams": generating_args.infer_num_beams,
         "max_length": max_length,
-        "repetition_penalty": 1.0,
+        "repetition_penalty": generating_args.repetition_penalty,
         "logits_processor": get_logits_processor(),
         "streamer": streamer
     }

@@ -133,8 +133,8 @@ with gr.Blocks() as demo:
         with gr.Column(scale=1):
             emptyBtn = gr.Button("Clear History")
             max_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum length", interactive=True)
-            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
-            temperature = gr.Slider(0, 1.5, value=0.95, step=0.01, label="Temperature", interactive=True)
+            top_p = gr.Slider(0, 1, value=generating_args.top_p, step=0.01, label="Top P", interactive=True)
+            temperature = gr.Slider(0, 1.5, value=generating_args.temperature, step=0.01, label="Temperature", interactive=True)

         history = gr.State([])
