support falcon model #72
Former-commit-id: c136f362c1aa75d3374b151188ba4a55d9313a59
Parent: 827ff46008
Commit: c87910ada3
@@ -9,6 +9,8 @@
 ## Changelog
 
+[23/07/05] Now we support training the Falcon-7B/40B models in this repo. Try the `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments to use the Falcon model.
+
 [23/06/29] We provide a reproducible example of training a chat model using instruction-following datasets; see this [HuggingFace Repo](https://huggingface.co/hiyouga/baichuan-7b-sft) for details.
 
 [23/06/22] Now we align the [demo API](src/api_demo.py) with [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format, so you can use the fine-tuned model in arbitrary ChatGPT-based applications.
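A note on the changelog entry above: the `--lora_target query_key_value` flag refers to Falcon's fused attention projection, the module that LoRA attaches to. Below is a minimal sketch, not part of this commit, of what that setting amounts to when wrapping Falcon-7B with a `peft` LoRA adapter; the hyperparameter values and the `trust_remote_code` flag are illustrative assumptions.

```python
# Minimal sketch (not the repository's code): attaching a LoRA adapter to
# Falcon-7B on the module named by `--lora_target query_key_value`.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_name = "tiiuae/falcon-7b"  # the 40B variant works the same way
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    trust_remote_code=True,  # Falcon shipped custom modeling code at the time
)

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,                     # illustrative values, not the repo's defaults
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["query_key_value"],  # Falcon's fused QKV projection
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```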
@@ -23,6 +25,7 @@
 - [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
 - [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
+- [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
 - [baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B)
 
 ## Supported Training Approaches
@@ -283,6 +286,8 @@ Please follow the [Model Card](https://github.com/facebookresearch/llama/blob/ma
 Please follow the [RAIL License](https://huggingface.co/spaces/bigscience/license) to use the BLOOM & BLOOMZ models.
 
+Please follow the [Apache-2.0 License](LICENSE) to use the Falcon models.
+
 Please follow the [baichuan-7B License](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) to use the baichuan-7B model.
 
 ## Citation
@@ -13,7 +13,6 @@ def main():
     model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
     tokenizer.save_pretrained(training_args.output_dir)
     print("model and tokenizer have been saved at:", training_args.output_dir)
-    print("Remember to copy the *.py files from the original directory.")
 
 
 if __name__ == "__main__":
@@ -183,6 +183,7 @@ def load_pretrained(
                 load_in_8bit=True,
                 llm_int8_threshold=6.0
             )
 
         elif model_args.quantization_bit == 4:
             require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
             require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
@@ -195,6 +196,7 @@ def load_pretrained(
                 bnb_4bit_use_double_quant=model_args.double_quantization,
                 bnb_4bit_quant_type=model_args.quantization_type
             )
 
         is_mergeable = False
         config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
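The two hunks above are context from the quantized-loading path: the 8-bit branch configures `load_in_8bit` and `llm_int8_threshold`, while the 4-bit (QLoRA-style) branch requires `bitsandbytes>=0.39.0` plus `transformers>=4.30.1` and is driven by the `bnb_4bit_*` fields. As a rough sketch with illustrative values only (the `quantization_bit` variable stands in for `model_args.quantization_bit`), those fields fit into a `BitsAndBytesConfig` roughly like this:

```python
# Sketch only: how the quantization fields shown above map onto
# transformers' BitsAndBytesConfig (the 4-bit path needs bitsandbytes>=0.39.0
# and transformers>=4.30.1). Values are illustrative, not the repo's defaults.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_bit = 4  # stand-in for model_args.quantization_bit

if quantization_bit == 8:
    quant_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=6.0)
else:
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,  # mirrors model_args.compute_dtype
        bnb_4bit_use_double_quant=True,         # model_args.double_quantization
        bnb_4bit_quant_type="nf4",              # model_args.quantization_type
    )

model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",
    quantization_config=quant_config,
    device_map={"": 0},  # the repo pins this to LOCAL_RANK for distributed runs
    trust_remote_code=True,
)
```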
@@ -211,10 +213,20 @@ def load_pretrained(
     model = AutoModelForCausalLM.from_pretrained(
         model_to_load,
         config=config,
-        torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
+        torch_dtype=model_args.compute_dtype,
         low_cpu_mem_usage=True,
         **config_kwargs
     )
 
+    # Register auto class to save the custom code files.
+    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
+        config.__class__.register_for_auto_class()
+    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
+        tokenizer.__class__.register_for_auto_class()
+    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map:
+        model.__class__.register_for_auto_class()
+
+    # Initialize adapters
     model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
     model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
 
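The registration block added above only fires when the checkpoint ships custom code, i.e. when its config defines an `auto_map`, as Falcon and Baichuan did at the time. Registering the dynamically loaded classes makes `save_pretrained()` copy the custom `*.py` files next to the exported weights, which appears to be why the "remember to copy the `*.py` files" reminder could be dropped from the export script earlier in this commit. A self-contained sketch of the pattern; the model name and output directory are illustrative:

```python
# Sketch: for a checkpoint with custom code (config.auto_map is set),
# registering the auto classes lets save_pretrained() copy the custom
# *.py files alongside the exported weights.
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

name = "tiiuae/falcon-7b"  # illustrative; any repo that defines auto_map
config = AutoConfig.from_pretrained(name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True)

if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
    config.__class__.register_for_auto_class()
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
    tokenizer.__class__.register_for_auto_class()
if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map:
    model.__class__.register_for_auto_class()

model.save_pretrained("exported_model", max_shard_size="10GB")
tokenizer.save_pretrained("exported_model")
```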
@@ -487,49 +499,49 @@ def preprocess_data(
         # for input with history, we build multiple input-label pairs just like:
         # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
         model_inputs = {"input_ids": [], "labels": []}
+        max_length = data_args.max_source_length + data_args.max_target_length
 
         for dialog in get_dialog(examples):
             input_ids, labels = [], []
 
             for i in range(len(dialog) // 2):
-                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=False)
+                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=True)
                 target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
 
-                if len(source_ids) > data_args.max_source_length - 1: # bos token
-                    source_ids = source_ids[:data_args.max_source_length - 1]
+                if len(source_ids) > data_args.max_source_length:
+                    source_ids = source_ids[:data_args.max_source_length]
                 if len(target_ids) > data_args.max_target_length - 1: # eos token
                     target_ids = target_ids[:data_args.max_target_length - 1]
 
-                input_ids += [tokenizer.bos_token_id] + source_ids + target_ids + [tokenizer.eos_token_id]
-                labels += [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [tokenizer.eos_token_id]
+                if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
+                    break
 
-            if len(input_ids) > data_args.max_source_length + data_args.max_target_length:
-                input_ids = input_ids[:data_args.max_source_length + data_args.max_target_length]
-            if len(labels) > data_args.max_source_length + data_args.max_target_length:
-                labels = labels[:data_args.max_source_length + data_args.max_target_length]
+                input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
+                labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
 
             model_inputs["input_ids"].append(input_ids)
             model_inputs["labels"].append(labels)
 
         return model_inputs
 
     def preprocess_unsupervised_dataset(examples):
         # build inputs with format `<bos> X` and labels with format `<bos> Y`
         model_inputs = {"input_ids": [], "labels": []}
 
         for dialog in get_dialog(examples):
             prompt, answer = "".join(dialog[:-1]), dialog[-1]
 
-            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
-            target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
+            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
+            target_ids = tokenizer.encode(text=answer, add_special_tokens=True)
 
-            if len(source_ids) > data_args.max_source_length - 1: # bos token
-                source_ids = source_ids[:data_args.max_source_length - 1]
-            if len(target_ids) > data_args.max_target_length - 1: # bos token
-                target_ids = target_ids[:data_args.max_target_length - 1]
+            if len(source_ids) > data_args.max_source_length:
+                source_ids = source_ids[:data_args.max_source_length]
+            if len(target_ids) > data_args.max_target_length:
+                target_ids = target_ids[:data_args.max_target_length]
 
-            input_ids = [tokenizer.bos_token_id] + source_ids
-            labels = [tokenizer.bos_token_id] + target_ids
+            model_inputs["input_ids"].append(source_ids)
+            model_inputs["labels"].append(target_ids)
 
-            model_inputs["input_ids"].append(input_ids)
-            model_inputs["labels"].append(labels)
         return model_inputs
 
     def preprocess_pairwise_dataset(examples):
@@ -538,19 +550,19 @@ def preprocess_data(
         for dialog in get_dialog(examples):
             prompt, answer = "".join(dialog[:-1]), dialog[-1]
 
-            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
+            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
             accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
             reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
 
-            if len(source_ids) > data_args.max_source_length - 1: # bos token
-                source_ids = source_ids[:data_args.max_source_length - 1]
+            if len(source_ids) > data_args.max_source_length:
+                source_ids = source_ids[:data_args.max_source_length]
             if len(accept_ids) > data_args.max_target_length - 1: # eos token
                 accept_ids = accept_ids[:data_args.max_target_length - 1]
             if len(reject_ids) > data_args.max_target_length - 1: # eos token
                 reject_ids = reject_ids[:data_args.max_target_length - 1]
 
-            accept_ids = [tokenizer.bos_token_id] + source_ids + accept_ids + [tokenizer.eos_token_id]
-            reject_ids = [tokenizer.bos_token_id] + source_ids + reject_ids + [tokenizer.eos_token_id]
+            accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
+            reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]
 
             model_inputs["accept_ids"].append(accept_ids)
             model_inputs["reject_ids"].append(reject_ids)
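Two things change in the preprocessing hunks above: prompts are now encoded with `add_special_tokens=True`, so each tokenizer supplies its own special tokens instead of a hard-coded `bos_token_id` (which not every tokenizer defines; Falcon's, for instance, appears to lack a BOS token), and multi-turn dialogs are packed turn by turn until the combined length would exceed `max_source_length + max_target_length`, with prompt positions masked out of the loss via `IGNORE_INDEX`. A minimal, self-contained sketch of that packing loop follows; it uses toy token ids instead of a real tokenizer and is not the repository's exact code.

```python
# Minimal sketch of the packing scheme above (not the repo's exact code):
# per turn, prompt tokens are masked with IGNORE_INDEX in the labels, and
# turns are appended until the next one would overflow max_length.
from typing import List, Tuple

IGNORE_INDEX = -100  # same sentinel value the repository uses

def pack_dialog(
    turns: List[Tuple[List[int], List[int]]],  # (source_ids, target_ids) per turn
    eos_token_id: int,
    max_length: int,
) -> Tuple[List[int], List[int]]:
    input_ids: List[int] = []
    labels: List[int] = []
    for source_ids, target_ids in turns:
        if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
            break  # drop the remaining turns rather than truncating mid-turn
        input_ids += source_ids + target_ids + [eos_token_id]
        labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [eos_token_id]
    return input_ids, labels

# Toy usage: two turns, eos id 0, generous length budget.
ids, labs = pack_dialog([([1, 2, 3], [4, 5]), ([6, 7], [8, 9, 10])], 0, 32)
assert len(ids) == len(labs) == 12
```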
@@ -198,6 +198,7 @@ class FinetuningArguments:
         metadata={"help": "Number of decoder blocks in the model. \
                   LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
                   BLOOM choices: [\"24\", \"30\", \"70\"], \
+                  Falcon choices: [\"32\", \"60\"], \
                   Baichuan choices: [\"32\"]"}
     )
     num_layer_trainable: Optional[int] = field(
@@ -208,7 +209,7 @@ class FinetuningArguments:
         default="mlp",
         metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
                   LLaMA choices: [\"mlp\", \"self_attn\"], \
-                  BLOOM choices: [\"mlp\", \"self_attention\"], \
+                  BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
                   Baichuan choices: [\"mlp\", \"self_attn\"]"}
     )
     lora_rank: Optional[int] = field(
@@ -227,7 +228,7 @@ class FinetuningArguments:
         default="q_proj,v_proj",
         metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
                   LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
-                  BLOOM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
+                  BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
                   Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
     )
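The help strings above note that `--lora_target` takes one or more module names separated by commas (for BLOOM and Falcon, `query_key_value` plus optionally `self_attention.dense` and `mlp.dense`). As a hypothetical illustration, not the repository's own argument parsing, such a flag can be mapped onto `LoraConfig.target_modules` like this:

```python
# Hypothetical helper (not the repo's code): split a comma-separated
# --lora_target value into the target_modules list that LoraConfig expects.
from peft import LoraConfig

def lora_config_from_flag(lora_target: str, rank: int = 8) -> LoraConfig:
    target_modules = [name.strip() for name in lora_target.split(",") if name.strip()]
    return LoraConfig(task_type="CAUSAL_LM", r=rank, target_modules=target_modules)

# e.g. the BLOOM / Falcon setting from the help string above:
config = lora_config_from_flag("query_key_value,self_attention.dense")
print(config.target_modules)
```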
@@ -74,7 +74,7 @@ def prepare_model_for_training(
     finetuning_type: str,
     output_embedding_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
-    layer_norm_names: Optional[List[str]] = ["norm", "ln_f"] # for LLaMA and BLOOM setting
+    layer_norm_names: Optional[List[str]] = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
 ) -> PreTrainedModel:
 
     for name, param in model.named_parameters():
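Finally, extending `layer_norm_names` with `ln_attn` and `ln_mlp` lets Falcon's layer norms be matched by the same training-preparation routine that already covered LLaMA (`norm`) and BLOOM (`ln_f`). The sketch below is a generic reconstruction of that common pattern (upcasting matched layer norms to float32 and enabling gradient checkpointing), not the repository's exact `prepare_model_for_training`:

```python
# Generic sketch of the preparation pattern the hunk above parameterizes:
# layer norms whose parameter names match the list are upcast to float32
# for stable low-precision training, and gradient checkpointing is enabled.
import torch
from transformers import PreTrainedModel

def prepare_for_training_sketch(
    model: PreTrainedModel,
    layer_norm_names=("norm", "ln_f", "ln_attn", "ln_mlp"),  # LLaMA, BLOOM, Falcon
    use_gradient_checkpointing: bool = True,
) -> PreTrainedModel:
    for name, param in model.named_parameters():
        if param.ndim == 1 and any(ln_name in name for ln_name in layer_norm_names):
            param.data = param.data.to(torch.float32)
    if use_gradient_checkpointing:
        model.gradient_checkpointing_enable()
        model.config.use_cache = False  # checkpointing is incompatible with the KV cache
    return model
```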