use pre-commit

Former-commit-id: 21db8ed2f4a0eba203754a92ce0741538e8ee709
Author: hiyouga
Date: 2024-10-29 09:07:46 +00:00
Parent: 163cf2ba5c
Commit: 0d8aa6e6ef
86 changed files with 1048 additions and 1064 deletions
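
Taken together, the hunks below are the mechanical style fixes that a pre-commit hook chain enforces automatically (this commit also adds pre-commit next to ruff in the dev extras of setup.py): str.format() calls become f-strings, the redundant "r" mode is dropped from open(), obsolete # coding=utf-8 headers are removed, and one indexed yield loop becomes yield from. A minimal before/after sketch of the main rewrites, with a made-up file name used purely for illustration:

import json

_HF_ENDPOINT = "https://huggingface.co"
sample = "sample.jsonl"  # hypothetical path, written only for this demo

with open(sample, "w", encoding="utf-8") as f:
    f.write(json.dumps({"prompt": "hi", "response": "hello"}) + "\n")

# Before: str.format interpolation and an explicit (default) "r" mode.
url_old = "{}/datasets/BelleGroup/multiturn_chat_0.8M".format(_HF_ENDPOINT)
mfu_old = "MFU: {:.2f}%".format(0.4321 * 100)
with open(sample, "r", encoding="utf-8") as f:
    rows_old = [json.loads(line) for line in f]

# After: f-strings (format specs move inside the braces) and the default read mode.
url_new = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M"
mfu_new = f"MFU: {0.4321 * 100:.2f}%"
with open(sample, encoding="utf-8") as f:
    rows_new = [json.loads(line) for line in f]

assert (url_old, mfu_old, rows_old) == (url_new, mfu_new, rows_new)  # behavior is unchanged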

View File

@@ -17,9 +17,9 @@ _CITATION = """\
 }
 """
-_HOMEPAGE = "{}/datasets/BelleGroup/multiturn_chat_0.8M".format(_HF_ENDPOINT)
+_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M"
 _LICENSE = "gpl-3.0"
-_URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json".format(_HF_ENDPOINT)
+_URL = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"
 class BelleMultiturn(datasets.GeneratorBasedBuilder):
@@ -38,7 +38,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
 return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
 def _generate_examples(self, filepath: str):
-with open(filepath, "r", encoding="utf-8") as f:
+with open(filepath, encoding="utf-8") as f:
 for key, row in enumerate(f):
 data = json.loads(row)
 conversations = []

View File

@@ -8,9 +8,9 @@ import datasets
 _HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
 _DESCRIPTION = "Human preference data about helpfulness and harmlessness."
 _CITATION = ""
-_HOMEPAGE = "{}/datasets/Anthropic/hh-rlhf".format(_HF_ENDPOINT)
+_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf"
 _LICENSE = "mit"
-_URL = "{}/datasets/Anthropic/hh-rlhf/resolve/main/".format(_HF_ENDPOINT)
+_URL = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf/resolve/main/"
 _URLS = {
 "train": [
 _URL + "harmless-base/train.jsonl.gz",
@@ -53,7 +53,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
 def _generate_examples(self, filepaths: List[str]):
 key = 0
 for filepath in filepaths:
-with open(filepath, "r", encoding="utf-8") as f:
+with open(filepath, encoding="utf-8") as f:
 for row in f:
 data = json.loads(row)
 chosen = data["chosen"]

View File

@@ -20,9 +20,9 @@ _CITATION = """\
 }
 """
-_HOMEPAGE = "{}/datasets/stingning/ultrachat".format(_HF_ENDPOINT)
+_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat"
 _LICENSE = "cc-by-nc-4.0"
-_BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl".format(_HF_ENDPOINT)
+_BASE_DATA_URL = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl"
 class UltraChat(datasets.GeneratorBasedBuilder):
@@ -42,7 +42,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
 def _generate_examples(self, filepaths: List[str]):
 for filepath in filepaths:
-with open(filepath, "r", encoding="utf-8") as f:
+with open(filepath, encoding="utf-8") as f:
 for row in f:
 try:
 data = json.loads(row)

View File

@@ -158,5 +158,4 @@ class MMLU(datasets.GeneratorBasedBuilder):
 df = pd.read_csv(filepath, header=None)
 df.columns = ["question", "A", "B", "C", "D", "answer"]
-for i, instance in enumerate(df.to_dict(orient="records")):
-    yield i, instance
+yield from enumerate(df.to_dict(orient="records"))
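
The yield from rewrite above is behavior-preserving: enumerate already produces the same (index, record) pairs the removed loop yielded one at a time. A small self-contained check, using a plain list of dicts in place of the DataFrame records:

records = [{"question": "1+1=?", "answer": "B"}, {"question": "2+2=?", "answer": "C"}]

def gen_loop(rows):
    # old form: index manually and yield each pair
    for i, instance in enumerate(rows):
        yield i, instance

def gen_delegate(rows):
    # new form: delegate straight to the enumerate iterator
    yield from enumerate(rows)

assert list(gen_loop(records)) == list(gen_delegate(records))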

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 Microsoft Corporation and the LlamaFactory team.
 #
 # This code is inspired by the Microsoft's DeepSpeed library.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 imoneoi and the LlamaFactory team.
 #
 # This code is inspired by the imoneoi's OpenChat library.
@@ -74,7 +73,7 @@ def calculate_lr(
 elif stage == "sft":
 data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
 else:
-raise NotImplementedError("Stage does not supported: {}.".format(stage))
+raise NotImplementedError(f"Stage does not supported: {stage}.")
 dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
 valid_tokens, total_tokens = 0, 0

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -100,7 +99,7 @@ def compute_device_flops(world_size: int) -> float:
 elif "4090" in device_name:
 return 98 * 1e12 * world_size
 else:
-raise NotImplementedError("Device not supported: {}.".format(device_name))
+raise NotImplementedError(f"Device not supported: {device_name}.")
 def calculate_mfu(
@@ -140,10 +139,10 @@ def calculate_mfu(
 "bf16": True,
 }
 if deepspeed_stage in [2, 3]:
-args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)
+args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
 run_exp(args)
-with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
+with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
 result = json.load(f)
 if dist.is_initialized():
@@ -157,7 +156,7 @@ def calculate_mfu(
 * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
 / compute_device_flops(world_size)
 )
-print("MFU: {:.2f}%".format(mfu_value * 100))
+print(f"MFU: {mfu_value * 100:.2f}%")
 if __name__ == "__main__":

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -100,7 +99,7 @@ def calculate_ppl(
 tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
 )
 else:
-raise NotImplementedError("Stage does not supported: {}.".format(stage))
+raise NotImplementedError(f"Stage does not supported: {stage}.")
 dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
 criterion = torch.nn.CrossEntropyLoss(reduction="none")
@@ -125,8 +124,8 @@ def calculate_ppl(
 with open(save_name, "w", encoding="utf-8") as f:
 json.dump(perplexities, f, indent=2)
-print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
-print("Perplexities have been saved at {}.".format(save_name))
+print(f"Average perplexity is {total_ppl / len(perplexities):.2f}")
+print(f"Perplexities have been saved at {save_name}.")
 if __name__ == "__main__":

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -61,7 +60,7 @@ def length_cdf(
 for length, count in length_tuples:
 count_accu += count
 prob_accu += count / total_num * 100
-print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
+print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")
 if __name__ == "__main__":

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 Tencent Inc. and the LlamaFactory team.
 #
 # This code is inspired by the Tencent's LLaMA-Pro library.
@@ -40,7 +39,7 @@ if TYPE_CHECKING:
 def change_name(name: str, old_index: int, new_index: int) -> str:
-return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))
+return name.replace(f".{old_index:d}.", f".{new_index:d}.")
 def block_expansion(
@@ -76,27 +75,27 @@ def block_expansion(
 state_dict = model.state_dict()
 if num_layers % num_expand != 0:
-raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand))
+raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
 split = num_layers // num_expand
 layer_cnt = 0
 output_state_dict = OrderedDict()
 for i in range(num_layers):
 for key, value in state_dict.items():
-if ".{:d}.".format(i) in key:
+if f".{i:d}." in key:
 output_state_dict[change_name(key, i, layer_cnt)] = value
-print("Add layer {} copied from layer {}".format(layer_cnt, i))
+print(f"Add layer {layer_cnt} copied from layer {i}")
 layer_cnt += 1
 if (i + 1) % split == 0:
 for key, value in state_dict.items():
-if ".{:d}.".format(i) in key:
+if f".{i:d}." in key:
 if "down_proj" in key or "o_proj" in key:
 output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
 else:
 output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
-print("Add layer {} expanded from layer {}".format(layer_cnt, i))
+print(f"Add layer {layer_cnt} expanded from layer {i}")
 layer_cnt += 1
 for key, value in state_dict.items():
@@ -113,17 +112,17 @@ def block_expansion(
 torch.save(shard, os.path.join(output_dir, shard_file))
 if index is None:
-print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
+print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
 else:
 index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
 with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
 json.dump(index, f, indent=2, sort_keys=True)
-print("Model weights saved in {}".format(output_dir))
+print(f"Model weights saved in {output_dir}")
 print("- Fine-tune this model with:")
-print("model_name_or_path: {}".format(output_dir))
+print(f"model_name_or_path: {output_dir}")
 print("finetuning_type: freeze")
-print("freeze_trainable_layers: {}".format(num_expand))
+print(f"freeze_trainable_layers: {num_expand}")
 print("use_llama_pro: true")

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -63,16 +62,16 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
 torch.save(shard, os.path.join(output_dir, shard_file))
 if index is None:
-print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
+print(f"Model weights saved in {os.path.join(output_dir, WEIGHTS_NAME)}")
 else:
 index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
 with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
 json.dump(index, f, indent=2, sort_keys=True)
-print("Model weights saved in {}".format(output_dir))
+print(f"Model weights saved in {output_dir}")
 def save_config(input_dir: str, output_dir: str):
-with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
+with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
 llama2_config_dict: Dict[str, Any] = json.load(f)
 llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
@@ -82,7 +81,7 @@ def save_config(input_dir: str, output_dir: str):
 with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
 json.dump(llama2_config_dict, f, indent=2)
-print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
+print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
 def llamafy_baichuan2(

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -86,7 +85,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
 elif "lm_head" in key:
 llama2_state_dict[key] = value
 else:
-raise KeyError("Unable to process key {}".format(key))
+raise KeyError(f"Unable to process key {key}")
 weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
 shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
@@ -98,18 +97,18 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
 torch.save(shard, os.path.join(output_dir, shard_file))
 if index is None:
-print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
+print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
 else:
 index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
 with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
 json.dump(index, f, indent=2, sort_keys=True)
-print("Model weights saved in {}".format(output_dir))
+print(f"Model weights saved in {output_dir}")
 return str(torch_dtype).replace("torch.", "")
 def save_config(input_dir: str, output_dir: str, torch_dtype: str):
-with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
+with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
 qwen_config_dict: Dict[str, Any] = json.load(f)
 llama2_config_dict: Dict[str, Any] = OrderedDict()
@@ -135,7 +134,7 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
 with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
 json.dump(llama2_config_dict, f, indent=2)
-print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
+print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
 def llamafy_qwen(

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is based on the HuggingFace's PEFT library.
@@ -70,19 +69,19 @@ def quantize_loftq(
 setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
 setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
 peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
-print("Adapter weights saved in {}".format(loftq_dir))
+print(f"Adapter weights saved in {loftq_dir}")
 # Save base model
 base_model: "PreTrainedModel" = peft_model.unload()
 base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
 tokenizer.save_pretrained(output_dir)
-print("Model weights saved in {}".format(output_dir))
+print(f"Model weights saved in {output_dir}")
 print("- Fine-tune this model with:")
-print("model_name_or_path: {}".format(output_dir))
-print("adapter_name_or_path: {}".format(loftq_dir))
+print(f"model_name_or_path: {output_dir}")
+print(f"adapter_name_or_path: {loftq_dir}")
 print("finetuning_type: lora")
-print("quantization_bit: {}".format(loftq_bits))
+print(f"quantization_bit: {loftq_bits}")
 if __name__ == "__main__":

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is based on the HuggingFace's PEFT library.
@@ -54,7 +53,7 @@ def quantize_pissa(
 lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
 lora_dropout=lora_dropout,
 target_modules=lora_target,
-init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
+init_lora_weights="pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}",
 )
 # Init PiSSA model
@@ -65,17 +64,17 @@ def quantize_pissa(
 setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
 setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
 peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
-print("Adapter weights saved in {}".format(pissa_dir))
+print(f"Adapter weights saved in {pissa_dir}")
 # Save base model
 base_model: "PreTrainedModel" = peft_model.unload()
 base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
 tokenizer.save_pretrained(output_dir)
-print("Model weights saved in {}".format(output_dir))
+print(f"Model weights saved in {output_dir}")
 print("- Fine-tune this model with:")
-print("model_name_or_path: {}".format(output_dir))
-print("adapter_name_or_path: {}".format(pissa_dir))
+print(f"model_name_or_path: {output_dir}")
+print(f"adapter_name_or_path: {pissa_dir}")
 print("finetuning_type: lora")
 print("pissa_init: false")
 print("pissa_convert: true")

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -20,7 +20,7 @@ from setuptools import find_packages, setup
 def get_version() -> str:
-with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
+with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
 file_content = f.read()
 pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
 (version,) = re.findall(pattern, file_content)
@@ -28,7 +28,7 @@ def get_version() -> str:
 def get_requires() -> List[str]:
-with open("requirements.txt", "r", encoding="utf-8") as f:
+with open("requirements.txt", encoding="utf-8") as f:
 file_content = f.read()
 lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
 return lines
@@ -61,7 +61,7 @@ extra_require = {
 "qwen": ["transformers_stream_generator"],
 "modelscope": ["modelscope"],
 "openmind": ["openmind"],
-"dev": ["ruff", "pytest"],
+"dev": ["pre-commit", "ruff", "pytest"],
 }
@@ -72,7 +72,7 @@ def main():
 author="hiyouga",
 author_email="hiyouga" "@" "buaa.edu.cn",
 description="Easy-to-use LLM fine-tuning framework",
-long_description=open("README.md", "r", encoding="utf-8").read(),
+long_description=open("README.md", encoding="utf-8").read(),
 long_description_content_type="text/markdown",
 keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
 license="Apache 2.0 License",

View File

@@ -25,7 +25,7 @@ def main():
 app = create_app(chat_model)
 api_host = os.environ.get("API_HOST", "0.0.0.0")
 api_port = int(os.environ.get("API_PORT", "8000"))
-print("Visit http://localhost:{}/docs for API document.".format(api_port))
+print(f"Visit http://localhost:{api_port}/docs for API document.")
 uvicorn.run(app, host=api_host, port=api_port)

View File

@@ -130,5 +130,5 @@ def run_api() -> None:
 app = create_app(chat_model)
 api_host = os.environ.get("API_HOST", "0.0.0.0")
 api_port = int(os.environ.get("API_PORT", "8000"))
-print("Visit http://localhost:{}/docs for API document.".format(api_port))
+print(f"Visit http://localhost:{api_port}/docs for API document.")
 uvicorn.run(app, host=api_host, port=api_port)

View File

@@ -70,7 +70,7 @@ ROLE_MAPPING = {
 def _process_request(
 request: "ChatCompletionRequest",
 ) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
-logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
+logger.info(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
 if len(request.messages) == 0:
 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
@@ -142,7 +142,7 @@ def _create_stream_chat_completion_chunk(
 async def create_chat_completion_response(
 request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> "ChatCompletionResponse":
-completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
+completion_id = f"chatcmpl-{uuid.uuid4().hex}"
 input_messages, system, tools, image = _process_request(request)
 responses = await chat_model.achat(
 input_messages,
@@ -169,7 +169,7 @@ async def create_chat_completion_response(
 tool_calls = []
 for tool in result:
 function = Function(name=tool[0], arguments=tool[1])
-tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
+tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
 response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
 finish_reason = Finish.TOOL
@@ -193,7 +193,7 @@ async def create_chat_completion_response(
 async def create_stream_chat_completion_response(
 request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> AsyncGenerator[str, None]:
-completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
+completion_id = f"chatcmpl-{uuid.uuid4().hex}"
 input_messages, system, tools, image = _process_request(request)
 if tools:
 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@@ -229,7 +229,7 @@ async def create_stream_chat_completion_response(
 async def create_score_evaluation_response(
 request: "ScoreEvaluationRequest", chat_model: "ChatModel"
 ) -> "ScoreEvaluationResponse":
-score_id = "scoreval-{}".format(uuid.uuid4().hex)
+score_id = f"scoreval-{uuid.uuid4().hex}"
 if len(request.messages) == 0:
 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

View File

@@ -53,7 +53,7 @@ class ChatModel:
 elif model_args.infer_backend == "vllm":
 self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
 else:
-raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
+raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
 self._loop = asyncio.new_event_loop()
 self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)

View File

@@ -105,7 +105,7 @@ class VllmEngine(BaseEngine):
 video: Optional["VideoInput"] = None,
 **input_kwargs,
 ) -> AsyncIterator["RequestOutput"]:
-request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
+request_id = f"chatcmpl-{uuid.uuid4().hex}"
 if image is not None:
 if IMAGE_PLACEHOLDER not in messages[0]["content"]:
 messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
@@ -159,7 +159,7 @@ class VllmEngine(BaseEngine):
 if image is not None: # add image features
 if not isinstance(image, (str, ImageObject)):
-raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
+raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
 if isinstance(image, str):
 image = Image.open(image).convert("RGB")

View File

@@ -47,7 +47,7 @@ USAGE = (
 WELCOME = (
 "-" * 58
 + "\n"
-+ "| Welcome to LLaMA Factory, version {}".format(VERSION)
++ f"| Welcome to LLaMA Factory, version {VERSION}"
 + " " * (21 - len(VERSION))
 + "|\n|"
 + " " * 56
@@ -90,7 +90,7 @@ def main():
 if force_torchrun or get_device_count() > 1:
 master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
 master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
-logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
+logger.info(f"Initializing distributed tasks at: {master_addr}:{master_port}")
 process = subprocess.run(
 (
 "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
@@ -118,4 +118,4 @@ def main():
 elif command == Command.HELP:
 print(USAGE)
 else:
-raise NotImplementedError("Unknown command: {}.".format(command))
+raise NotImplementedError(f"Unknown command: {command}.")

View File

@@ -161,7 +161,7 @@ def convert_sharegpt(
 broken_data = False
 for turn_idx, message in enumerate(messages):
 if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
-logger.warning("Invalid role tag in {}.".format(messages))
+logger.warning(f"Invalid role tag in {messages}.")
 broken_data = True
 aligned_messages.append(
@@ -171,7 +171,7 @@ def convert_sharegpt(
 if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
 dataset_attr.ranking and len(aligned_messages) % 2 == 0
 ):
-logger.warning("Invalid message count in {}.".format(messages))
+logger.warning(f"Invalid message count in {messages}.")
 broken_data = True
 if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
@@ -192,7 +192,7 @@ def convert_sharegpt(
 chosen[dataset_attr.role_tag] not in accept_tags[-1]
 or rejected[dataset_attr.role_tag] not in accept_tags[-1]
 ):
-logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
+logger.warning(f"Invalid role tag in {[chosen, rejected]}.")
 broken_data = True
 prompt = aligned_messages

View File

@@ -137,9 +137,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
 for key in ("chosen", "rejected"):
 for feature in features:
 target_feature = {
-"input_ids": feature["{}_input_ids".format(key)],
-"attention_mask": feature["{}_attention_mask".format(key)],
-"labels": feature["{}_labels".format(key)],
+"input_ids": feature[f"{key}_input_ids"],
+"attention_mask": feature[f"{key}_attention_mask"],
+"labels": feature[f"{key}_labels"],
 "images": feature["images"],
 "videos": feature["videos"],
 }

View File

@@ -70,7 +70,7 @@ def merge_dataset(
 stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
 )
 else:
-raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
+raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
 def split_dataset(

View File

@@ -83,14 +83,14 @@ class StringFormatter(Formatter):
 if isinstance(slot, str):
 for name, value in kwargs.items():
 if not isinstance(value, str):
-raise RuntimeError("Expected a string, got {}".format(value))
+raise RuntimeError(f"Expected a string, got {value}")
 slot = slot.replace("{{" + name + "}}", value, 1)
 elements.append(slot)
 elif isinstance(slot, (dict, set)):
 elements.append(slot)
 else:
-raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
 return elements
@@ -113,7 +113,7 @@ class FunctionFormatter(Formatter):
 functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
 except json.JSONDecodeError:
-raise RuntimeError("Invalid JSON format in function message: {}".format(str([content]))) # flat string
+raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
 elements = []
 for name, arguments in functions:
@@ -124,7 +124,7 @@ class FunctionFormatter(Formatter):
 elif isinstance(slot, (dict, set)):
 elements.append(slot)
 else:
-raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
 return elements
@@ -141,7 +141,7 @@ class ToolFormatter(Formatter):
 tools = json.loads(content)
 return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
 except json.JSONDecodeError:
-raise RuntimeError("Invalid JSON format in tool description: {}".format(str([content]))) # flat string
+raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
 @override
 def extract(self, content: str) -> Union[str, List["FunctionCall"]]:

View File

@@ -51,7 +51,7 @@ def _load_single_dataset(
 r"""
 Loads a single dataset and aligns it to the standard format.
 """
-logger.info("Loading dataset {}...".format(dataset_attr))
+logger.info(f"Loading dataset {dataset_attr}...")
 data_path, data_name, data_dir, data_files = None, None, None, None
 if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
 data_path = dataset_attr.dataset_name
@@ -77,12 +77,12 @@ def _load_single_dataset(
 data_files.append(local_path)
 data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
 else:
-raise ValueError("File {} not found.".format(local_path))
+raise ValueError(f"File {local_path} not found.")
 if data_path is None:
 raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
 else:
-raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
+raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
 if dataset_attr.load_from == "ms_hub":
 require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
@@ -145,7 +145,7 @@ def _load_single_dataset(
 assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
 dataset = dataset.select(indexes)
-logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
+logger.info(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
 if data_args.max_samples is not None: # truncate dataset
 max_samples = min(data_args.max_samples, len(dataset))
@@ -243,7 +243,7 @@ def get_dataset(
 if has_tokenized_data(data_args.tokenized_path):
 logger.warning("Loading dataset from disk will ignore other data arguments.")
 dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
-logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
+logger.info(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
 dataset_module: Dict[str, "Dataset"] = {}
 if "train" in dataset_dict:
@@ -294,8 +294,8 @@ def get_dataset(
 if data_args.tokenized_path is not None:
 if training_args.should_save:
 dataset_dict.save_to_disk(data_args.tokenized_path)
-logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
-logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
+logger.info(f"Tokenized dataset saved at {data_args.tokenized_path}.")
+logger.info(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
 sys.exit(0)

View File

@@ -111,7 +111,7 @@ class BasePlugin:
 image = Image.open(image["path"])
 if not isinstance(image, ImageObject):
-raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
+raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
 results.append(self._preprocess_image(image, **kwargs))
@@ -253,7 +253,7 @@ class LlavaPlugin(BasePlugin):
 message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
 if len(images) != num_image_tokens:
-raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
 return messages
@@ -302,7 +302,7 @@ class LlavaNextPlugin(BasePlugin):
 message["content"] = content.replace("{{image}}", self.image_token)
 if len(images) != num_image_tokens:
-raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
 return messages
 @override
@@ -366,10 +366,10 @@ class LlavaNextVideoPlugin(BasePlugin):
 message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
 if len(images) != num_image_tokens:
-raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
 if len(videos) != num_video_tokens:
-raise ValueError("The number of videos does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+raise ValueError(f"The number of videos does not match the number of {IMAGE_PLACEHOLDER} tokens")
 return messages
@@ -408,7 +408,7 @@ class PaliGemmaPlugin(BasePlugin):
 message["content"] = content.replace("{{image}}", "")
 if len(images) != num_image_tokens:
-raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
 return messages
@@ -493,7 +493,7 @@ class Qwen2vlPlugin(BasePlugin):
 content = message["content"]
 while IMAGE_PLACEHOLDER in content:
 if num_image_tokens >= len(image_grid_thw):
-raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER))
+raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
 content = content.replace(
 IMAGE_PLACEHOLDER,
@@ -506,7 +506,7 @@ class Qwen2vlPlugin(BasePlugin):
 while VIDEO_PLACEHOLDER in content:
 if num_video_tokens >= len(video_grid_thw):
-raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER))
+raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 content = content.replace(
 VIDEO_PLACEHOLDER,
@@ -520,10 +520,10 @@ class Qwen2vlPlugin(BasePlugin):
 message["content"] = content
 if len(images) != num_image_tokens:
-raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
 if len(videos) != num_video_tokens:
-raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))
+raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens")
 return messages
@@ -583,10 +583,10 @@ class VideoLlavaPlugin(BasePlugin):
 message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
 if len(images) != num_image_tokens:
-raise ValueError("The number of images does not match the number of {} tokens".format(self.image_token))
+raise ValueError(f"The number of images does not match the number of {self.image_token} tokens")
 if len(videos) != num_video_tokens:
-raise ValueError("The number of videos does not match the number of {} tokens".format(self.video_token))
+raise ValueError(f"The number of videos does not match the number of {self.video_token} tokens")
 return messages
@@ -622,6 +622,6 @@ def get_mm_plugin(
 ) -> "BasePlugin":
 plugin_class = PLUGINS.get(name, None)
 if plugin_class is None:
-raise ValueError("Multimodal plugin `{}` not found.".format(name))
+raise ValueError(f"Multimodal plugin `{name}` not found.")
 return plugin_class(image_token, video_token)

View File

@@ -87,11 +87,11 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
 config_path = os.path.join(dataset_dir, DATA_CONFIG)
 try:
-with open(config_path, "r") as f:
+with open(config_path) as f:
 dataset_info = json.load(f)
 except Exception as err:
 if len(dataset_names) != 0:
-raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
+raise ValueError(f"Cannot open {config_path} due to {str(err)}.")
 dataset_info = None
@@ -109,7 +109,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
 continue
 if name not in dataset_info:
-raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
+raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")
 has_hf_url = "hf_hub_url" in dataset_info[name]
 has_ms_url = "ms_hub_url" in dataset_info[name]

View File

@@ -110,8 +110,8 @@ def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "Pr
 print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
 print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
 print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
-print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
+print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
 print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
 print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
 print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
-print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
+print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")

View File

@@ -160,7 +160,7 @@ def preprocess_packed_supervised_dataset(
 )
 length = len(input_ids)
 if length > data_args.cutoff_len:
-logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
+logger.warning(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
 else:
 lengths.append(length)
 length2indexes[length].append(valid_num)
@@ -212,4 +212,4 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
 print("input_ids:\n{}".format(example["input_ids"]))
 print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
 print("label_ids:\n{}".format(example["labels"]))
-print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
+print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")

View File

@@ -147,7 +147,7 @@ class Template:
 elif "eos_token" in elem and tokenizer.eos_token_id is not None:
 token_ids += [tokenizer.eos_token_id]
 else:
-raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
+raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
 return token_ids
@@ -275,9 +275,9 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) ->
 num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
 if is_added:
-logger.info("Add eos token: {}".format(tokenizer.eos_token))
+logger.info(f"Add eos token: {tokenizer.eos_token}")
 else:
-logger.info("Replace eos token: {}".format(tokenizer.eos_token))
+logger.info(f"Replace eos token: {tokenizer.eos_token}")
 if num_added_tokens > 0:
 logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
@@ -365,13 +365,13 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
 else:
 template = TEMPLATES.get(data_args.template, None)
 if template is None:
-raise ValueError("Template {} does not exist.".format(data_args.template))
+raise ValueError(f"Template {data_args.template} does not exist.")
 if data_args.train_on_prompt and template.efficient_eos:
 raise ValueError("Current template does not support `train_on_prompt`.")
 if data_args.tool_format is not None:
-logger.info("Using tool format: {}.".format(data_args.tool_format))
+logger.info(f"Using tool format: {data_args.tool_format}.")
 eos_slots = [] if template.efficient_eos else [{"eos_token"}]
 template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
 template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
@@ -389,7 +389,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
 if tokenizer.pad_token_id is None:
 tokenizer.pad_token = tokenizer.eos_token
-logger.info("Add pad token: {}".format(tokenizer.pad_token))
+logger.info(f"Add pad token: {tokenizer.pad_token}")
 if stop_words:
 num_added_tokens = tokenizer.add_special_tokens(

View File

@@ -177,6 +177,6 @@ TOOLS = {
 def get_tool_utils(name: str) -> "ToolUtils":
 tool_utils = TOOLS.get(name, None)
 if tool_utils is None:
-raise ValueError("Tool utils `{}` not found.".format(name))
+raise ValueError(f"Tool utils `{name}` not found.")
 return tool_utils

View File

@@ -87,7 +87,7 @@ class Evaluator:
 token=self.model_args.hf_hub_token,
 )
-with open(mapping, "r", encoding="utf-8") as f:
+with open(mapping, encoding="utf-8") as f:
 categorys: Dict[str, Dict[str, str]] = json.load(f)
 category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
@@ -139,7 +139,7 @@ class Evaluator:
 def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
 score_info = "\n".join(
 [
-"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
 for category_name, category_correct in category_corrects.items()
 if len(category_correct)
 ]

View File

@@ -61,7 +61,7 @@ def _register_eval_template(name: str, system: str, choice: str, answer: str) ->
 def get_eval_template(name: str) -> "EvalTemplate":
 eval_template = eval_templates.get(name, None)
-assert eval_template is not None, "Template {} does not exist.".format(name)
+assert eval_template is not None, f"Template {name} does not exist."
 return eval_template

View File

@@ -72,4 +72,4 @@ def print_env() -> None:
 except Exception:
 pass
-print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
+print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")

View File

@@ -75,7 +75,7 @@ def _get_default_logging_level() -> "logging._Level":
 if env_level_str.upper() in logging._nameToLevel:
 return logging._nameToLevel[env_level_str.upper()]
 else:
-raise ValueError("Unknown logging level: {}.".format(env_level_str))
+raise ValueError(f"Unknown logging level: {env_level_str}.")
 return _default_log_level

View File

@@ -75,7 +75,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
 Plots loss curves and saves the image.
 """
 plt.switch_backend("agg")
-with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
+with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
 data = json.load(f)
 for key in keys:
@@ -92,7 +92,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
 plt.figure()
 plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
 plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
-plt.title("training {} of {}".format(key, save_dictionary))
+plt.title(f"training {key} of {save_dictionary}")
 plt.xlabel("step")
 plt.ylabel(key)
 plt.legend()

View File

@@ -67,8 +67,8 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
     if unknown_args:
         print(parser.format_help())
-        print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
-        raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
+        print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
+        raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
     return (*parsed_args,)
@@ -323,7 +323,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         if last_checkpoint is not None:
             training_args.resume_from_checkpoint = last_checkpoint
-            logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
+            logger.info(f"Resuming training from {training_args.resume_from_checkpoint}.")
             logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
     if (

View File

@@ -182,7 +182,7 @@ def _setup_lora_tuning(
             model = model.merge_and_unload()
         if len(adapter_to_merge) > 0:
-            logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
+            logger.info(f"Merged {len(adapter_to_merge)} adapter(s).")
         if adapter_to_resume is not None:  # resume lora training
             if model_args.use_unsloth:
@@ -239,8 +239,8 @@ def _setup_lora_tuning(
                 logger.info("Using PiSSA initialization.")
                 peft_kwargs["init_lora_weights"] = "pissa"
             else:
-                logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
-                peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)
+                logger.info(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
+                peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"
         lora_config = LoraConfig(
             task_type=TaskType.CAUSAL_LM,
@@ -300,6 +300,6 @@ def init_adapter(
             config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
         )
     else:
-        raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))
+        raise NotImplementedError(f"Unknown finetuning type: {finetuning_args.finetuning_type}.")
     return model

View File

@@ -100,7 +100,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         patch_processor(processor, config, tokenizer, model_args)
     except Exception as e:
-        logger.warning("Processor was not found: {}.".format(e))
+        logger.warning(f"Processor was not found: {e}.")
         processor = None
     # Avoid load tokenizer, see:
@@ -180,7 +180,7 @@ def load_model(
         vhead_params = load_valuehead_params(vhead_path, model_args)
         if vhead_params is not None:
             model.load_state_dict(vhead_params, strict=False)
-            logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
+            logger.info(f"Loaded valuehead from checkpoint: {vhead_path}")
     if not is_trainable:
         model.requires_grad_(False)
@@ -198,7 +198,7 @@ def load_model(
             trainable_params, all_param, 100 * trainable_params / all_param
         )
     else:
-        param_stats = "all params: {:,}".format(all_param)
+        param_stats = f"all params: {all_param:,}"
     logger.info(param_stats)

View File

@@ -65,7 +65,7 @@ def configure_attn_implementation(
         requested_attn_implementation = "flash_attention_2"
     else:
-        raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
+        raise NotImplementedError(f"Unknown attention type: {model_args.flash_attn}")
     if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
         setattr(config, "attn_implementation", requested_attn_implementation)

View File

@@ -111,7 +111,7 @@ def _gradient_checkpointing_enable(
     from torch.utils.checkpoint import checkpoint
     if not self.supports_gradient_checkpointing:
-        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
+        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
     if gradient_checkpointing_kwargs is None:
         gradient_checkpointing_kwargs = {"use_reentrant": True}

View File

@@ -69,4 +69,4 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
         _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
         _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
-        logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
+        logger.info(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")

View File

@@ -67,12 +67,12 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
     if num_layers % num_layer_trainable != 0:
         raise ValueError(
-            "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
+            f"`num_layers` {num_layers} should be divisible by `num_layer_trainable` {num_layer_trainable}."
         )
     stride = num_layers // num_layer_trainable
     trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
-    trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
+    trainable_layers = [f".{idx:d}." for idx in trainable_layer_ids]
     module_names = []
     for name, _ in model.named_modules():
         if any(target_module in name for target_module in target_modules) and any(

View File

@@ -130,7 +130,7 @@ def configure_quantization(
             quantization_config["bits"] = 2
         quant_bits = quantization_config.get("bits", "?")
-        logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
+        logger.info(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
     elif model_args.export_quantization_bit is not None:  # auto-gptq
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
@@ -149,7 +149,7 @@ def configure_quantization(
         )
         init_kwargs["device_map"] = "auto"
         init_kwargs["max_memory"] = get_max_memory()
-        logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))
+        logger.info(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
     elif model_args.quantization_bit is not None:  # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
@@ -179,7 +179,7 @@ def configure_quantization(
             else:
                 init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference
-            logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
+            logger.info(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
         elif model_args.quantization_method == QuantizationMethod.HQQ.value:
             if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
                 raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
@@ -191,7 +191,7 @@ def configure_quantization(
             init_kwargs["quantization_config"] = HqqConfig(
                 nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
             )  # use ATEN kernel (axis=0) for performance
-            logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
+            logger.info(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
         elif model_args.quantization_method == QuantizationMethod.EETQ.value:
             if model_args.quantization_bit != 8:
                 raise ValueError("EETQ only accepts 8-bit quantization.")
@@ -201,4 +201,4 @@ def configure_quantization(
             require_version("eetq", "To fix: pip install eetq")
             init_kwargs["quantization_config"] = EetqConfig()
-            logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))
+            logger.info(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")

View File

@@ -48,9 +48,7 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         current_max_length = getattr(config, "max_position_embeddings", None)
         if current_max_length and model_args.model_max_length > current_max_length:
-            logger.info(
-                "Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length)
-            )
+            logger.info(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
             setattr(config, "max_position_embeddings", model_args.model_max_length)
             scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
         else:
@@ -60,6 +58,4 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         scaling_factor = 2.0
     setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
-    logger.info(
-        "Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
-    )
+    logger.info(f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}")

View File

@@ -54,7 +54,7 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
     except Exception as err:
         err_text = str(err)
-    logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text))
+    logger.info(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
     logger.info("Ignore the above message if you are not resuming the training of a value head model.")
     return None

View File

@@ -99,7 +99,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
     else:
         return
-    logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
+    logger.info(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
     mm_projector.register_forward_hook(_mm_projector_forward_post_hook)

View File

@@ -92,7 +92,7 @@ def fix_valuehead_checkpoint(
     else:
         torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
-    logger.info("Value head model saved at: {}".format(output_dir))
+    logger.info(f"Value head model saved at: {output_dir}")
 class FixValueHeadModelCallback(TrainerCallback):
@@ -106,7 +106,7 @@ class FixValueHeadModelCallback(TrainerCallback):
         Event called after a checkpoint save.
         """
         if args.should_save:
-            output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+            output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
             fix_valuehead_checkpoint(
                 model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
             )
@@ -123,7 +123,7 @@ class SaveProcessorCallback(TrainerCallback):
     @override
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         if args.should_save:
-            output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+            output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
             getattr(self.processor, "image_processor").save_pretrained(output_dir)
     @override
@@ -145,7 +145,7 @@ class PissaConvertCallback(TrainerCallback):
         if args.should_save:
             model = kwargs.pop("model")
             pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
-            logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir))
+            logger.info(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
             if isinstance(model, PeftModel):
                 init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                 setattr(model.peft_config["default"], "init_lora_weights", True)
@@ -159,7 +159,7 @@ class PissaConvertCallback(TrainerCallback):
             pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
             pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
             pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
-            logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir))
+            logger.info(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
             # 1. save a pissa backup with init_lora_weights: True
             # 2. save a converted lora with init_lora_weights: pissa
             # 3. load the pissa backup with init_lora_weights: True

View File

@@ -156,7 +156,7 @@ class CustomDPOTrainer(DPOTrainer):
         elif self.loss_type == "simpo":
             losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
         else:
-            raise NotImplementedError("Unknown loss type: {}.".format(self.loss_type))
+            raise NotImplementedError(f"Unknown loss type: {self.loss_type}.")
         chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
         rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
@@ -245,16 +245,16 @@ class CustomDPOTrainer(DPOTrainer):
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
         prefix = "eval_" if train_eval == "eval" else ""
-        metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu()
-        metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu()
-        metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu()
-        metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu()
-        metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().mean().cpu()
-        metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().mean().cpu()
-        metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().mean().cpu()
-        metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().mean().cpu()
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
+        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
+        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
+        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
+        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
+        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
         if self.loss_type == "orpo":
-            metrics["{}sft_loss".format(prefix)] = sft_loss.detach().mean().cpu()
-            metrics["{}odds_ratio_loss".format(prefix)] = ((losses - sft_loss) / self.beta).detach().mean().cpu()
+            metrics[f"{prefix}sft_loss"] = sft_loss.detach().mean().cpu()
+            metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).detach().mean().cpu()
         return losses.mean(), metrics

View File

@@ -129,11 +129,11 @@ class CustomKTOTrainer(KTOTrainer):
         """
         batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
         model_inputs = {
-            "input_ids": batch["{}input_ids".format(prefix)],
-            "attention_mask": batch["{}attention_mask".format(prefix)],
+            "input_ids": batch[f"{prefix}input_ids"],
+            "attention_mask": batch[f"{prefix}attention_mask"],
         }
-        if "{}token_type_ids".format(prefix) in batch:
-            model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
+        if f"{prefix}token_type_ids" in batch:
+            model_inputs["token_type_ids"] = batch[f"{prefix}token_type_ids"]
         if "pixel_values" in batch:
             model_inputs["pixel_values"] = batch["pixel_values"]
@@ -142,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer):
             model_inputs["image_grid_thw"] = batch["image_grid_thw"]
         logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
-        logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
+        logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
         return logps, logps / valid_length
     @override

View File

@@ -62,8 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
         setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone())
     device = v_head_layer.weight.device
-    v_head_layer.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
-    v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
+    v_head_layer.weight.data = model.get_buffer(f"{target}_head_weight").detach().clone().to(device)
+    v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)
 def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:

View File

@@ -218,18 +218,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if self.is_world_process_zero():
             logger.info("***** Running training *****")
-            logger.info(" Num examples = {:,}".format(num_examples))
-            logger.info(" Num Epochs = {:,}".format(num_train_epochs))
-            logger.info(" Instantaneous batch size per device = {:,}".format(self.args.per_device_train_batch_size))
+            logger.info(f" Num examples = {num_examples:,}")
+            logger.info(f" Num Epochs = {num_train_epochs:,}")
+            logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
             logger.info(
                 " Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
                     total_train_batch_size
                 )
             )
-            logger.info(" Gradient Accumulation steps = {:,}".format(self.args.gradient_accumulation_steps))
-            logger.info(" Num optimization epochs per batch = {:,}".format(self.finetuning_args.ppo_epochs))
-            logger.info(" Total training steps = {:,}".format(max_steps))
-            logger.info(" Number of trainable parameters = {:,}".format(count_parameters(self.model)[0]))
+            logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
+            logger.info(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
+            logger.info(f" Total training steps = {max_steps:,}")
+            logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]:,}")
         dataiter = iter(self.dataloader)
         loss_meter = AverageMeter()
@@ -290,7 +290,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             if (step + 1) % self.args.save_steps == 0:  # save checkpoint
                 self.save_model(
-                    os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
+                    os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
                 )
                 self.callback_handler.on_save(self.args, self.state, self.control)

View File

@@ -116,7 +116,7 @@ def create_ref_model(
         ref_model = load_model(
             tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
         )
-        logger.info("Created reference model from {}".format(finetuning_args.ref_model))
+        logger.info(f"Created reference model from {finetuning_args.ref_model}")
     else:
         if finetuning_args.finetuning_type == "lora":
             ref_model = None
@@ -140,7 +140,7 @@ def create_reward_model(
     """
     if finetuning_args.reward_model_type == "api":
         assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
-        logger.info("Use reward server {}".format(finetuning_args.reward_model))
+        logger.info(f"Use reward server {finetuning_args.reward_model}")
         return finetuning_args.reward_model
     elif finetuning_args.reward_model_type == "lora":
         model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
@@ -157,7 +157,7 @@ def create_reward_model(
         model.register_buffer(
             "default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
         )
-        logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
+        logger.info(f"Loaded adapter weights of reward model from {finetuning_args.reward_model}")
         return None
     else:
         reward_model_args = ModelArguments.copyfrom(
@@ -171,7 +171,7 @@ def create_reward_model(
         reward_model = load_model(
             tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
         )
-        logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
+        logger.info(f"Loaded full weights of reward model from {finetuning_args.reward_model}")
         logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
         return reward_model
@@ -231,7 +231,7 @@ def _create_galore_optimizer(
     elif training_args.optim == "adafactor":
         optim_class = GaLoreAdafactor
     else:
-        raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
+        raise NotImplementedError(f"Unknow optim: {training_args.optim}")
     if finetuning_args.galore_layerwise:
         if training_args.gradient_accumulation_steps != 1:
@@ -305,7 +305,7 @@ def _create_loraplus_optimizer(
         dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
     ]
     optimizer = optim_class(param_groups, **optim_kwargs)
-    logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
+    logger.info(f"Using LoRA+ optimizer with loraplus lr ratio {finetuning_args.loraplus_lr_ratio:.2f}.")
    return optimizer

View File

@@ -57,7 +57,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
     elif finetuning_args.stage == "kto":
         run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
     else:
-        raise ValueError("Unknown task: {}.".format(finetuning_args.stage))
+        raise ValueError(f"Unknown task: {finetuning_args.stage}.")
 def export_model(args: Optional[Dict[str, Any]] = None) -> None:
@@ -91,18 +91,18 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
         setattr(model.config, "torch_dtype", output_dtype)
         model = model.to(output_dtype)
-        logger.info("Convert model dtype to: {}.".format(output_dtype))
+        logger.info(f"Convert model dtype to: {output_dtype}.")
     model.save_pretrained(
         save_directory=model_args.export_dir,
-        max_shard_size="{}GB".format(model_args.export_size),
+        max_shard_size=f"{model_args.export_size}GB",
         safe_serialization=(not model_args.export_legacy_format),
     )
     if model_args.export_hub_model_id is not None:
         model.push_to_hub(
             model_args.export_hub_model_id,
             token=model_args.hf_hub_token,
-            max_shard_size="{}GB".format(model_args.export_size),
+            max_shard_size=f"{model_args.export_size}GB",
             safe_serialization=(not model_args.export_legacy_format),
         )
@@ -117,13 +117,13 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
                 os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
                 os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
             )
-            logger.info("Copied valuehead to {}.".format(model_args.export_dir))
+            logger.info(f"Copied valuehead to {model_args.export_dir}.")
         elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
             shutil.copy(
                 os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
                 os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
             )
-            logger.info("Copied valuehead to {}.".format(model_args.export_dir))
+            logger.info(f"Copied valuehead to {model_args.export_dir}.")
     try:
         tokenizer.padding_side = "left"  # restore padding side
@@ -140,4 +140,4 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
             )
     except Exception as e:
-        logger.warning("Cannot save tokenizer, please copy the files manually: {}.".format(e))
+        logger.warning(f"Cannot save tokenizer, please copy the files manually: {e}.")

View File

@@ -75,7 +75,7 @@ def load_config() -> Dict[str, Any]:
     Loads user config if exists.
     """
     try:
-        with open(get_config_path(), "r", encoding="utf-8") as f:
+        with open(get_config_path(), encoding="utf-8") as f:
             return safe_load(f)
     except Exception:
         return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
@@ -172,14 +172,14 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
     Loads dataset_info.json.
     """
     if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
-        logger.info("dataset_dir is {}, using online dataset.".format(dataset_dir))
+        logger.info(f"dataset_dir is {dataset_dir}, using online dataset.")
        return {}
     try:
-        with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
+        with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
             return json.load(f)
     except Exception as err:
-        logger.warning("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
+        logger.warning(f"Cannot open {os.path.join(dataset_dir, DATA_CONFIG)} due to {str(err)}.")
         return {}

View File

@@ -41,7 +41,7 @@ def next_page(page_index: int, total_num: int) -> int:
 def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
     try:
-        with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
+        with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
             dataset_info = json.load(f)
     except Exception:
         return gr.Button(interactive=False)
@@ -57,7 +57,7 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
 def _load_data_file(file_path: str) -> List[Any]:
-    with open(file_path, "r", encoding="utf-8") as f:
+    with open(file_path, encoding="utf-8") as f:
         if file_path.endswith(".json"):
             return json.load(f)
         elif file_path.endswith(".jsonl"):
@@ -67,7 +67,7 @@ def _load_data_file(file_path: str) -> List[Any]:
 def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
-    with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
+    with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
         dataset_info = json.load(f)
     data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])

View File

@@ -56,9 +56,9 @@ class Engine:
         if not self.pure_chat:
             current_time = get_time()
             init_dict["train.current_time"] = {"value": current_time}
-            init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)}
-            init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)}
-            init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)}
+            init_dict["train.output_dir"] = {"value": f"train_{current_time}"}
+            init_dict["train.config_path"] = {"value": f"{current_time}.yaml"}
+            init_dict["eval.output_dir"] = {"value": f"eval_{current_time}"}
             init_dict["infer.mm_box"] = {"visible": False}
             if user_config.get("last_model", None):

View File

@@ -29,7 +29,7 @@ class Manager:
         Adds elements to manager.
         """
         for elem_name, elem in elem_dict.items():
-            elem_id = "{}.{}".format(tab_name, elem_name)
+            elem_id = f"{tab_name}.{elem_name}"
             self._id_to_elem[elem_id] = elem
             self._elem_to_id[elem] = elem_id

View File

@@ -231,7 +231,7 @@ class Runner:
         if get("train.ds_stage") != "none":
             ds_stage = get("train.ds_stage")
             ds_offload = "offload_" if get("train.ds_offload") else ""
-            args["deepspeed"] = os.path.join(DEFAULT_CACHE_DIR, "ds_z{}_{}config.json".format(ds_stage, ds_offload))
+            args["deepspeed"] = os.path.join(DEFAULT_CACHE_DIR, f"ds_z{ds_stage}_{ds_offload}config.json")
         return args
@@ -313,7 +313,7 @@ class Runner:
         if args.get("deepspeed", None) is not None:
             env["FORCE_TORCHRUN"] = "1"
-        self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
+        self.trainer = Popen(f"llamafactory-cli train {save_cmd(args)}", env=env, shell=True)
         yield from self.monitor()
     def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:

View File

@@ -111,14 +111,14 @@ def gen_cmd(args: Dict[str, Any]) -> str:
     """
     cmd_lines = ["llamafactory-cli train "]
     for k, v in clean_cmd(args).items():
-        cmd_lines.append(" --{} {} ".format(k, str(v)))
+        cmd_lines.append(f" --{k} {str(v)} ")
     if os.name == "nt":
         cmd_text = "`\n".join(cmd_lines)
     else:
         cmd_text = "\\\n".join(cmd_lines)
-    cmd_text = "```bash\n{}\n```".format(cmd_text)
+    cmd_text = f"```bash\n{cmd_text}\n```"
     return cmd_text
@@ -139,9 +139,9 @@ def get_eval_results(path: os.PathLike) -> str:
     r"""
     Gets scores after evaluation.
     """
-    with open(path, "r", encoding="utf-8") as f:
+    with open(path, encoding="utf-8") as f:
         result = json.dumps(json.load(f), indent=4)
-    return "```json\n{}\n```\n".format(result)
+    return f"```json\n{result}\n```\n"
 def get_time() -> str:
@@ -161,13 +161,13 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr
     running_log_path = os.path.join(output_path, RUNNING_LOG)
     if os.path.isfile(running_log_path):
-        with open(running_log_path, "r", encoding="utf-8") as f:
+        with open(running_log_path, encoding="utf-8") as f:
             running_log = f.read()
     trainer_log_path = os.path.join(output_path, TRAINER_LOG)
     if os.path.isfile(trainer_log_path):
         trainer_log: List[Dict[str, Any]] = []
-        with open(trainer_log_path, "r", encoding="utf-8") as f:
+        with open(trainer_log_path, encoding="utf-8") as f:
             for line in f:
                 trainer_log.append(json.loads(line))
@@ -193,7 +193,7 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
     Loads saved arguments.
     """
     try:
-        with open(config_path, "r", encoding="utf-8") as f:
+        with open(config_path, encoding="utf-8") as f:
             return safe_load(f)
     except Exception:
         return None
@@ -211,7 +211,7 @@ def list_config_paths(current_time: str) -> "gr.Dropdown":
     r"""
     Lists all the saved configuration files.
     """
-    config_files = ["{}.yaml".format(current_time)]
+    config_files = [f"{current_time}.yaml"]
     if os.path.isdir(DEFAULT_CONFIG_DIR):
         for file_name in os.listdir(DEFAULT_CONFIG_DIR):
             if file_name.endswith(".yaml") and file_name not in config_files:
@@ -224,7 +224,7 @@ def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_ti
     r"""
     Lists all the directories that can resume from.
     """
-    output_dirs = ["train_{}".format(current_time)]
+    output_dirs = [f"train_{current_time}"]
     if model_name:
         save_dir = get_save_dir(model_name, finetuning_type)
         if save_dir and os.path.isdir(save_dir):

View File

@@ -61,7 +61,7 @@ OS_NAME = os.environ.get("OS_NAME", "")
     ],
 )
 def test_run_exp(stage: str, dataset: str):
-    output_dir = "train_{}".format(stage)
+    output_dir = f"train_{stage}"
     run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
     assert os.path.exists(output_dir)