mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-07-31 10:42:50 +08:00
use pre-commit
Former-commit-id: 21db8ed2f4a0eba203754a92ce0741538e8ee709
This commit is contained in:
parent
163cf2ba5c
commit
0d8aa6e6ef
6
.github/workflows/publish.yml
vendored
6
.github/workflows/publish.yml
vendored
@ -26,15 +26,15 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.8"
|
||||
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install build
|
||||
|
||||
|
||||
- name: Build package
|
||||
run: |
|
||||
python -m build
|
||||
|
||||
|
||||
- name: Publish package
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
|
1630
assets/benchmark.svg
1630
assets/benchmark.svg
File diff suppressed because it is too large
Load Diff
Before Width: | Height: | Size: 29 KiB After Width: | Height: | Size: 28 KiB |
@ -4999,4 +4999,4 @@
|
||||
"input": "Time waits for no one.",
|
||||
"output": "No one can stop time from moving forward."
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -4999,4 +4999,4 @@
|
||||
"input": "",
|
||||
"output": "安第斯山脉位于南美洲,横跨七个国家,包括委内瑞拉,哥伦比亚,厄瓜多尔,秘鲁,玻利维亚,智利和阿根廷。安第斯山脉是世界上最长的山脉之一,全长约7,000千米(4,350英里),其山脉沿着南美洲西海岸蜿蜒延伸,平均海拔约为4,000米(13,000英尺)。在其南部,安第斯山脉宽度达到700千米(430英里),在其北部宽度约为500千米(310英里)。"
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -17,9 +17,9 @@ _CITATION = """\
|
||||
}
|
||||
"""
|
||||
|
||||
_HOMEPAGE = "{}/datasets/BelleGroup/multiturn_chat_0.8M".format(_HF_ENDPOINT)
|
||||
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M"
|
||||
_LICENSE = "gpl-3.0"
|
||||
_URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json".format(_HF_ENDPOINT)
|
||||
_URL = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"
|
||||
|
||||
|
||||
class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
@ -38,7 +38,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
|
||||
|
||||
def _generate_examples(self, filepath: str):
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
with open(filepath, encoding="utf-8") as f:
|
||||
for key, row in enumerate(f):
|
||||
data = json.loads(row)
|
||||
conversations = []
|
||||
|
File diff suppressed because one or more lines are too long
@ -625,4 +625,4 @@
|
||||
},
|
||||
"folder": "python"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7223,4 +7223,4 @@
|
||||
"value": "Abstraction is a principle in object-oriented programming that refers to the process of focusing on the essential features of an object or concept without emphasizing the details that are not important for its functionality. It enables programmers to create classes that represent the objects they want to work with, and specifies only the behavior of those objects, leaving out unnecessary details. Abstraction helps to make the code more maintainable, modular, and scalable. It also improves the productivity of developers by reducing the amount of code they need to write."
|
||||
}
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -5055,4 +5055,4 @@
|
||||
"value": "C. 参与讨论"
|
||||
}
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -9155,4 +9155,4 @@
|
||||
],
|
||||
"tools": "[]"
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -9019,4 +9019,4 @@
|
||||
],
|
||||
"tools": "[]"
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -8,9 +8,9 @@ import datasets
|
||||
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
|
||||
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
|
||||
_CITATION = ""
|
||||
_HOMEPAGE = "{}/datasets/Anthropic/hh-rlhf".format(_HF_ENDPOINT)
|
||||
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf"
|
||||
_LICENSE = "mit"
|
||||
_URL = "{}/datasets/Anthropic/hh-rlhf/resolve/main/".format(_HF_ENDPOINT)
|
||||
_URL = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf/resolve/main/"
|
||||
_URLS = {
|
||||
"train": [
|
||||
_URL + "harmless-base/train.jsonl.gz",
|
||||
@ -53,7 +53,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
||||
def _generate_examples(self, filepaths: List[str]):
|
||||
key = 0
|
||||
for filepath in filepaths:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
with open(filepath, encoding="utf-8") as f:
|
||||
for row in f:
|
||||
data = json.loads(row)
|
||||
chosen = data["chosen"]
|
||||
|
@ -454,4 +454,4 @@
|
||||
"input": "",
|
||||
"output": "抱歉,我不是 OpenAI 开发的 ChatGPT,我是 {{author}} 开发的 {{name}},旨在为用户提供智能化的回答和帮助。"
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -5395,4 +5395,4 @@
|
||||
],
|
||||
"label": false
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -137,4 +137,4 @@
|
||||
"mllm_demo_data/3.jpg"
|
||||
]
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -44,4 +44,4 @@
|
||||
"mllm_demo_data/3.mp4"
|
||||
]
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -20,9 +20,9 @@ _CITATION = """\
|
||||
}
|
||||
"""
|
||||
|
||||
_HOMEPAGE = "{}/datasets/stingning/ultrachat".format(_HF_ENDPOINT)
|
||||
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat"
|
||||
_LICENSE = "cc-by-nc-4.0"
|
||||
_BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl".format(_HF_ENDPOINT)
|
||||
_BASE_DATA_URL = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl"
|
||||
|
||||
|
||||
class UltraChat(datasets.GeneratorBasedBuilder):
|
||||
@ -42,7 +42,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
|
||||
|
||||
def _generate_examples(self, filepaths: List[str]):
|
||||
for filepath in filepaths:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
with open(filepath, encoding="utf-8") as f:
|
||||
for row in f:
|
||||
try:
|
||||
data = json.loads(row)
|
||||
|
File diff suppressed because one or more lines are too long
@ -207,4 +207,4 @@
|
||||
"name": "兽医学",
|
||||
"category": "STEM"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -267,4 +267,4 @@
|
||||
"name": "世界宗教",
|
||||
"category": "Humanities"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -227,4 +227,4 @@
|
||||
"name": "world religions",
|
||||
"category": "Humanities"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -158,5 +158,4 @@ class MMLU(datasets.GeneratorBasedBuilder):
|
||||
df = pd.read_csv(filepath, header=None)
|
||||
df.columns = ["question", "A", "B", "C", "D", "answer"]
|
||||
|
||||
for i, instance in enumerate(df.to_dict(orient="records")):
|
||||
yield i, instance
|
||||
yield from enumerate(df.to_dict(orient="records"))
|
||||
|
@ -25,4 +25,4 @@
|
||||
"contiguous_gradients": true,
|
||||
"round_robin_gradients": true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -25,4 +25,4 @@
|
||||
"contiguous_gradients": true,
|
||||
"round_robin_gradients": true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -29,4 +29,4 @@
|
||||
"contiguous_gradients": true,
|
||||
"round_robin_gradients": true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -27,4 +27,4 @@
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -35,4 +35,4 @@
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the Microsoft's DeepSpeed library.
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 imoneoi and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the imoneoi's OpenChat library.
|
||||
@ -74,7 +73,7 @@ def calculate_lr(
|
||||
elif stage == "sft":
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
||||
else:
|
||||
raise NotImplementedError("Stage does not supported: {}.".format(stage))
|
||||
raise NotImplementedError(f"Stage does not supported: {stage}.")
|
||||
|
||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||
valid_tokens, total_tokens = 0, 0
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -100,7 +99,7 @@ def compute_device_flops(world_size: int) -> float:
|
||||
elif "4090" in device_name:
|
||||
return 98 * 1e12 * world_size
|
||||
else:
|
||||
raise NotImplementedError("Device not supported: {}.".format(device_name))
|
||||
raise NotImplementedError(f"Device not supported: {device_name}.")
|
||||
|
||||
|
||||
def calculate_mfu(
|
||||
@ -140,10 +139,10 @@ def calculate_mfu(
|
||||
"bf16": True,
|
||||
}
|
||||
if deepspeed_stage in [2, 3]:
|
||||
args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)
|
||||
args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
|
||||
|
||||
run_exp(args)
|
||||
with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
|
||||
result = json.load(f)
|
||||
|
||||
if dist.is_initialized():
|
||||
@ -157,7 +156,7 @@ def calculate_mfu(
|
||||
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
|
||||
/ compute_device_flops(world_size)
|
||||
)
|
||||
print("MFU: {:.2f}%".format(mfu_value * 100))
|
||||
print(f"MFU: {mfu_value * 100:.2f}%")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -100,7 +99,7 @@ def calculate_ppl(
|
||||
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Stage does not supported: {}.".format(stage))
|
||||
raise NotImplementedError(f"Stage does not supported: {stage}.")
|
||||
|
||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||
criterion = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
@ -125,8 +124,8 @@ def calculate_ppl(
|
||||
with open(save_name, "w", encoding="utf-8") as f:
|
||||
json.dump(perplexities, f, indent=2)
|
||||
|
||||
print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
|
||||
print("Perplexities have been saved at {}.".format(save_name))
|
||||
print(f"Average perplexity is {total_ppl / len(perplexities):.2f}")
|
||||
print(f"Perplexities have been saved at {save_name}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -61,7 +60,7 @@ def length_cdf(
|
||||
for length, count in length_tuples:
|
||||
count_accu += count
|
||||
prob_accu += count / total_num * 100
|
||||
print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
|
||||
print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Tencent Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the Tencent's LLaMA-Pro library.
|
||||
@ -40,7 +39,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def change_name(name: str, old_index: int, new_index: int) -> str:
|
||||
return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))
|
||||
return name.replace(f".{old_index:d}.", f".{new_index:d}.")
|
||||
|
||||
|
||||
def block_expansion(
|
||||
@ -76,27 +75,27 @@ def block_expansion(
|
||||
state_dict = model.state_dict()
|
||||
|
||||
if num_layers % num_expand != 0:
|
||||
raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand))
|
||||
raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
|
||||
|
||||
split = num_layers // num_expand
|
||||
layer_cnt = 0
|
||||
output_state_dict = OrderedDict()
|
||||
for i in range(num_layers):
|
||||
for key, value in state_dict.items():
|
||||
if ".{:d}.".format(i) in key:
|
||||
if f".{i:d}." in key:
|
||||
output_state_dict[change_name(key, i, layer_cnt)] = value
|
||||
|
||||
print("Add layer {} copied from layer {}".format(layer_cnt, i))
|
||||
print(f"Add layer {layer_cnt} copied from layer {i}")
|
||||
layer_cnt += 1
|
||||
if (i + 1) % split == 0:
|
||||
for key, value in state_dict.items():
|
||||
if ".{:d}.".format(i) in key:
|
||||
if f".{i:d}." in key:
|
||||
if "down_proj" in key or "o_proj" in key:
|
||||
output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
|
||||
else:
|
||||
output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
|
||||
|
||||
print("Add layer {} expanded from layer {}".format(layer_cnt, i))
|
||||
print(f"Add layer {layer_cnt} expanded from layer {i}")
|
||||
layer_cnt += 1
|
||||
|
||||
for key, value in state_dict.items():
|
||||
@ -113,17 +112,17 @@ def block_expansion(
|
||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||
|
||||
if index is None:
|
||||
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
|
||||
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
|
||||
else:
|
||||
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, indent=2, sort_keys=True)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
print(f"Model weights saved in {output_dir}")
|
||||
|
||||
print("- Fine-tune this model with:")
|
||||
print("model_name_or_path: {}".format(output_dir))
|
||||
print(f"model_name_or_path: {output_dir}")
|
||||
print("finetuning_type: freeze")
|
||||
print("freeze_trainable_layers: {}".format(num_expand))
|
||||
print(f"freeze_trainable_layers: {num_expand}")
|
||||
print("use_llama_pro: true")
|
||||
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -63,16 +62,16 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
|
||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||
|
||||
if index is None:
|
||||
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
|
||||
print(f"Model weights saved in {os.path.join(output_dir, WEIGHTS_NAME)}")
|
||||
else:
|
||||
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, indent=2, sort_keys=True)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
print(f"Model weights saved in {output_dir}")
|
||||
|
||||
|
||||
def save_config(input_dir: str, output_dir: str):
|
||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
|
||||
llama2_config_dict: Dict[str, Any] = json.load(f)
|
||||
|
||||
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
|
||||
@ -82,7 +81,7 @@ def save_config(input_dir: str, output_dir: str):
|
||||
|
||||
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
|
||||
json.dump(llama2_config_dict, f, indent=2)
|
||||
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
||||
print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
|
||||
|
||||
|
||||
def llamafy_baichuan2(
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -86,7 +85,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
|
||||
elif "lm_head" in key:
|
||||
llama2_state_dict[key] = value
|
||||
else:
|
||||
raise KeyError("Unable to process key {}".format(key))
|
||||
raise KeyError(f"Unable to process key {key}")
|
||||
|
||||
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
|
||||
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
|
||||
@ -98,18 +97,18 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
|
||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||
|
||||
if index is None:
|
||||
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
|
||||
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
|
||||
else:
|
||||
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, indent=2, sort_keys=True)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
print(f"Model weights saved in {output_dir}")
|
||||
|
||||
return str(torch_dtype).replace("torch.", "")
|
||||
|
||||
|
||||
def save_config(input_dir: str, output_dir: str, torch_dtype: str):
|
||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
|
||||
qwen_config_dict: Dict[str, Any] = json.load(f)
|
||||
|
||||
llama2_config_dict: Dict[str, Any] = OrderedDict()
|
||||
@ -135,7 +134,7 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
|
||||
|
||||
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
|
||||
json.dump(llama2_config_dict, f, indent=2)
|
||||
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
||||
print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
|
||||
|
||||
|
||||
def llamafy_qwen(
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is based on the HuggingFace's PEFT library.
|
||||
@ -70,19 +69,19 @@ def quantize_loftq(
|
||||
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
|
||||
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
|
||||
peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
|
||||
print("Adapter weights saved in {}".format(loftq_dir))
|
||||
print(f"Adapter weights saved in {loftq_dir}")
|
||||
|
||||
# Save base model
|
||||
base_model: "PreTrainedModel" = peft_model.unload()
|
||||
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
print(f"Model weights saved in {output_dir}")
|
||||
|
||||
print("- Fine-tune this model with:")
|
||||
print("model_name_or_path: {}".format(output_dir))
|
||||
print("adapter_name_or_path: {}".format(loftq_dir))
|
||||
print(f"model_name_or_path: {output_dir}")
|
||||
print(f"adapter_name_or_path: {loftq_dir}")
|
||||
print("finetuning_type: lora")
|
||||
print("quantization_bit: {}".format(loftq_bits))
|
||||
print(f"quantization_bit: {loftq_bits}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is based on the HuggingFace's PEFT library.
|
||||
@ -54,7 +53,7 @@ def quantize_pissa(
|
||||
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
||||
lora_dropout=lora_dropout,
|
||||
target_modules=lora_target,
|
||||
init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
|
||||
init_lora_weights="pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}",
|
||||
)
|
||||
|
||||
# Init PiSSA model
|
||||
@ -65,17 +64,17 @@ def quantize_pissa(
|
||||
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
|
||||
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
|
||||
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
|
||||
print("Adapter weights saved in {}".format(pissa_dir))
|
||||
print(f"Adapter weights saved in {pissa_dir}")
|
||||
|
||||
# Save base model
|
||||
base_model: "PreTrainedModel" = peft_model.unload()
|
||||
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
print(f"Model weights saved in {output_dir}")
|
||||
|
||||
print("- Fine-tune this model with:")
|
||||
print("model_name_or_path: {}".format(output_dir))
|
||||
print("adapter_name_or_path: {}".format(pissa_dir))
|
||||
print(f"model_name_or_path: {output_dir}")
|
||||
print(f"adapter_name_or_path: {pissa_dir}")
|
||||
print("finetuning_type: lora")
|
||||
print("pissa_init: false")
|
||||
print("pissa_convert: true")
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
8
setup.py
8
setup.py
@ -20,7 +20,7 @@ from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def get_version() -> str:
|
||||
with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
|
||||
file_content = f.read()
|
||||
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
|
||||
(version,) = re.findall(pattern, file_content)
|
||||
@ -28,7 +28,7 @@ def get_version() -> str:
|
||||
|
||||
|
||||
def get_requires() -> List[str]:
|
||||
with open("requirements.txt", "r", encoding="utf-8") as f:
|
||||
with open("requirements.txt", encoding="utf-8") as f:
|
||||
file_content = f.read()
|
||||
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
|
||||
return lines
|
||||
@ -61,7 +61,7 @@ extra_require = {
|
||||
"qwen": ["transformers_stream_generator"],
|
||||
"modelscope": ["modelscope"],
|
||||
"openmind": ["openmind"],
|
||||
"dev": ["ruff", "pytest"],
|
||||
"dev": ["pre-commit", "ruff", "pytest"],
|
||||
}
|
||||
|
||||
|
||||
@ -72,7 +72,7 @@ def main():
|
||||
author="hiyouga",
|
||||
author_email="hiyouga" "@" "buaa.edu.cn",
|
||||
description="Easy-to-use LLM fine-tuning framework",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description=open("README.md", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
|
||||
license="Apache 2.0 License",
|
||||
|
@ -25,7 +25,7 @@ def main():
|
||||
app = create_app(chat_model)
|
||||
api_host = os.environ.get("API_HOST", "0.0.0.0")
|
||||
api_port = int(os.environ.get("API_PORT", "8000"))
|
||||
print("Visit http://localhost:{}/docs for API document.".format(api_port))
|
||||
print(f"Visit http://localhost:{api_port}/docs for API document.")
|
||||
uvicorn.run(app, host=api_host, port=api_port)
|
||||
|
||||
|
||||
|
@ -130,5 +130,5 @@ def run_api() -> None:
|
||||
app = create_app(chat_model)
|
||||
api_host = os.environ.get("API_HOST", "0.0.0.0")
|
||||
api_port = int(os.environ.get("API_PORT", "8000"))
|
||||
print("Visit http://localhost:{}/docs for API document.".format(api_port))
|
||||
print(f"Visit http://localhost:{api_port}/docs for API document.")
|
||||
uvicorn.run(app, host=api_host, port=api_port)
|
||||
|
@ -70,7 +70,7 @@ ROLE_MAPPING = {
|
||||
def _process_request(
|
||||
request: "ChatCompletionRequest",
|
||||
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
|
||||
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
|
||||
logger.info(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
|
||||
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||
@ -142,7 +142,7 @@ def _create_stream_chat_completion_chunk(
|
||||
async def create_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> "ChatCompletionResponse":
|
||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||
input_messages, system, tools, image = _process_request(request)
|
||||
responses = await chat_model.achat(
|
||||
input_messages,
|
||||
@ -169,7 +169,7 @@ async def create_chat_completion_response(
|
||||
tool_calls = []
|
||||
for tool in result:
|
||||
function = Function(name=tool[0], arguments=tool[1])
|
||||
tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
|
||||
tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
|
||||
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
|
||||
finish_reason = Finish.TOOL
|
||||
@ -193,7 +193,7 @@ async def create_chat_completion_response(
|
||||
async def create_stream_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> AsyncGenerator[str, None]:
|
||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||
input_messages, system, tools, image = _process_request(request)
|
||||
if tools:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||
@ -229,7 +229,7 @@ async def create_stream_chat_completion_response(
|
||||
async def create_score_evaluation_response(
|
||||
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
|
||||
) -> "ScoreEvaluationResponse":
|
||||
score_id = "scoreval-{}".format(uuid.uuid4().hex)
|
||||
score_id = f"scoreval-{uuid.uuid4().hex}"
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
|
@ -53,7 +53,7 @@ class ChatModel:
|
||||
elif model_args.infer_backend == "vllm":
|
||||
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
|
||||
else:
|
||||
raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
|
||||
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
|
||||
|
||||
self._loop = asyncio.new_event_loop()
|
||||
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
||||
|
@ -105,7 +105,7 @@ class VllmEngine(BaseEngine):
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncIterator["RequestOutput"]:
|
||||
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
request_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||
if image is not None:
|
||||
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
|
||||
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
|
||||
@ -159,7 +159,7 @@ class VllmEngine(BaseEngine):
|
||||
|
||||
if image is not None: # add image features
|
||||
if not isinstance(image, (str, ImageObject)):
|
||||
raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
|
||||
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
|
||||
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image).convert("RGB")
|
||||
|
@ -47,7 +47,7 @@ USAGE = (
|
||||
WELCOME = (
|
||||
"-" * 58
|
||||
+ "\n"
|
||||
+ "| Welcome to LLaMA Factory, version {}".format(VERSION)
|
||||
+ f"| Welcome to LLaMA Factory, version {VERSION}"
|
||||
+ " " * (21 - len(VERSION))
|
||||
+ "|\n|"
|
||||
+ " " * 56
|
||||
@ -90,7 +90,7 @@ def main():
|
||||
if force_torchrun or get_device_count() > 1:
|
||||
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
|
||||
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
|
||||
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
|
||||
logger.info(f"Initializing distributed tasks at: {master_addr}:{master_port}")
|
||||
process = subprocess.run(
|
||||
(
|
||||
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
||||
@ -118,4 +118,4 @@ def main():
|
||||
elif command == Command.HELP:
|
||||
print(USAGE)
|
||||
else:
|
||||
raise NotImplementedError("Unknown command: {}.".format(command))
|
||||
raise NotImplementedError(f"Unknown command: {command}.")
|
||||
|
@ -161,7 +161,7 @@ def convert_sharegpt(
|
||||
broken_data = False
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
||||
logger.warning("Invalid role tag in {}.".format(messages))
|
||||
logger.warning(f"Invalid role tag in {messages}.")
|
||||
broken_data = True
|
||||
|
||||
aligned_messages.append(
|
||||
@ -171,7 +171,7 @@ def convert_sharegpt(
|
||||
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
|
||||
dataset_attr.ranking and len(aligned_messages) % 2 == 0
|
||||
):
|
||||
logger.warning("Invalid message count in {}.".format(messages))
|
||||
logger.warning(f"Invalid message count in {messages}.")
|
||||
broken_data = True
|
||||
|
||||
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
|
||||
@ -192,7 +192,7 @@ def convert_sharegpt(
|
||||
chosen[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
):
|
||||
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
|
||||
logger.warning(f"Invalid role tag in {[chosen, rejected]}.")
|
||||
broken_data = True
|
||||
|
||||
prompt = aligned_messages
|
||||
|
@ -137,9 +137,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
for key in ("chosen", "rejected"):
|
||||
for feature in features:
|
||||
target_feature = {
|
||||
"input_ids": feature["{}_input_ids".format(key)],
|
||||
"attention_mask": feature["{}_attention_mask".format(key)],
|
||||
"labels": feature["{}_labels".format(key)],
|
||||
"input_ids": feature[f"{key}_input_ids"],
|
||||
"attention_mask": feature[f"{key}_attention_mask"],
|
||||
"labels": feature[f"{key}_labels"],
|
||||
"images": feature["images"],
|
||||
"videos": feature["videos"],
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ def merge_dataset(
|
||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
|
||||
raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
|
||||
|
||||
|
||||
def split_dataset(
|
||||
|
@ -83,14 +83,14 @@ class StringFormatter(Formatter):
|
||||
if isinstance(slot, str):
|
||||
for name, value in kwargs.items():
|
||||
if not isinstance(value, str):
|
||||
raise RuntimeError("Expected a string, got {}".format(value))
|
||||
raise RuntimeError(f"Expected a string, got {value}")
|
||||
|
||||
slot = slot.replace("{{" + name + "}}", value, 1)
|
||||
elements.append(slot)
|
||||
elif isinstance(slot, (dict, set)):
|
||||
elements.append(slot)
|
||||
else:
|
||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
||||
|
||||
return elements
|
||||
|
||||
@ -113,7 +113,7 @@ class FunctionFormatter(Formatter):
|
||||
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
|
||||
|
||||
except json.JSONDecodeError:
|
||||
raise RuntimeError("Invalid JSON format in function message: {}".format(str([content]))) # flat string
|
||||
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
|
||||
|
||||
elements = []
|
||||
for name, arguments in functions:
|
||||
@ -124,7 +124,7 @@ class FunctionFormatter(Formatter):
|
||||
elif isinstance(slot, (dict, set)):
|
||||
elements.append(slot)
|
||||
else:
|
||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
||||
|
||||
return elements
|
||||
|
||||
@ -141,7 +141,7 @@ class ToolFormatter(Formatter):
|
||||
tools = json.loads(content)
|
||||
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
|
||||
except json.JSONDecodeError:
|
||||
raise RuntimeError("Invalid JSON format in tool description: {}".format(str([content]))) # flat string
|
||||
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
|
||||
|
||||
@override
|
||||
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||
|
@ -51,7 +51,7 @@ def _load_single_dataset(
|
||||
r"""
|
||||
Loads a single dataset and aligns it to the standard format.
|
||||
"""
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
logger.info(f"Loading dataset {dataset_attr}...")
|
||||
data_path, data_name, data_dir, data_files = None, None, None, None
|
||||
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
|
||||
data_path = dataset_attr.dataset_name
|
||||
@ -77,12 +77,12 @@ def _load_single_dataset(
|
||||
data_files.append(local_path)
|
||||
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
||||
else:
|
||||
raise ValueError("File {} not found.".format(local_path))
|
||||
raise ValueError(f"File {local_path} not found.")
|
||||
|
||||
if data_path is None:
|
||||
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
|
||||
else:
|
||||
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
|
||||
raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
|
||||
|
||||
if dataset_attr.load_from == "ms_hub":
|
||||
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
||||
@ -145,7 +145,7 @@ def _load_single_dataset(
|
||||
|
||||
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
|
||||
dataset = dataset.select(indexes)
|
||||
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
|
||||
logger.info(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
|
||||
|
||||
if data_args.max_samples is not None: # truncate dataset
|
||||
max_samples = min(data_args.max_samples, len(dataset))
|
||||
@ -243,7 +243,7 @@ def get_dataset(
|
||||
if has_tokenized_data(data_args.tokenized_path):
|
||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
|
||||
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
||||
logger.info(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
|
||||
|
||||
dataset_module: Dict[str, "Dataset"] = {}
|
||||
if "train" in dataset_dict:
|
||||
@ -294,8 +294,8 @@ def get_dataset(
|
||||
if data_args.tokenized_path is not None:
|
||||
if training_args.should_save:
|
||||
dataset_dict.save_to_disk(data_args.tokenized_path)
|
||||
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
|
||||
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
|
||||
logger.info(f"Tokenized dataset saved at {data_args.tokenized_path}.")
|
||||
logger.info(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
|
@ -111,7 +111,7 @@ class BasePlugin:
|
||||
image = Image.open(image["path"])
|
||||
|
||||
if not isinstance(image, ImageObject):
|
||||
raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
|
||||
raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
|
||||
|
||||
results.append(self._preprocess_image(image, **kwargs))
|
||||
|
||||
@ -253,7 +253,7 @@ class LlavaPlugin(BasePlugin):
|
||||
message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||
|
||||
return messages
|
||||
|
||||
@ -302,7 +302,7 @@ class LlavaNextPlugin(BasePlugin):
|
||||
message["content"] = content.replace("{{image}}", self.image_token)
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||
return messages
|
||||
|
||||
@override
|
||||
@ -366,10 +366,10 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||
message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError("The number of videos does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
||||
raise ValueError(f"The number of videos does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||
|
||||
return messages
|
||||
|
||||
@ -408,7 +408,7 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
message["content"] = content.replace("{{image}}", "")
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||
|
||||
return messages
|
||||
|
||||
@ -493,7 +493,7 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if num_image_tokens >= len(image_grid_thw):
|
||||
raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER))
|
||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER,
|
||||
@ -506,7 +506,7 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
if num_video_tokens >= len(video_grid_thw):
|
||||
raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER))
|
||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
content = content.replace(
|
||||
VIDEO_PLACEHOLDER,
|
||||
@ -520,10 +520,10 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
message["content"] = content
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))
|
||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens")
|
||||
|
||||
return messages
|
||||
|
||||
@ -583,10 +583,10 @@ class VideoLlavaPlugin(BasePlugin):
|
||||
message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError("The number of images does not match the number of {} tokens".format(self.image_token))
|
||||
raise ValueError(f"The number of images does not match the number of {self.image_token} tokens")
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError("The number of videos does not match the number of {} tokens".format(self.video_token))
|
||||
raise ValueError(f"The number of videos does not match the number of {self.video_token} tokens")
|
||||
|
||||
return messages
|
||||
|
||||
@ -622,6 +622,6 @@ def get_mm_plugin(
|
||||
) -> "BasePlugin":
|
||||
plugin_class = PLUGINS.get(name, None)
|
||||
if plugin_class is None:
|
||||
raise ValueError("Multimodal plugin `{}` not found.".format(name))
|
||||
raise ValueError(f"Multimodal plugin `{name}` not found.")
|
||||
|
||||
return plugin_class(image_token, video_token)
|
||||
|
@ -87,11 +87,11 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
|
||||
config_path = os.path.join(dataset_dir, DATA_CONFIG)
|
||||
|
||||
try:
|
||||
with open(config_path, "r") as f:
|
||||
with open(config_path) as f:
|
||||
dataset_info = json.load(f)
|
||||
except Exception as err:
|
||||
if len(dataset_names) != 0:
|
||||
raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
|
||||
raise ValueError(f"Cannot open {config_path} due to {str(err)}.")
|
||||
|
||||
dataset_info = None
|
||||
|
||||
@ -109,7 +109,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
|
||||
continue
|
||||
|
||||
if name not in dataset_info:
|
||||
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
|
||||
raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")
|
||||
|
||||
has_hf_url = "hf_hub_url" in dataset_info[name]
|
||||
has_ms_url = "ms_hub_url" in dataset_info[name]
|
||||
|
@ -110,8 +110,8 @@ def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "Pr
|
||||
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
|
||||
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
|
||||
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
|
||||
print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
|
||||
print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
|
||||
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
|
||||
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
|
||||
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
|
||||
print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
|
||||
print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
|
||||
|
@ -160,7 +160,7 @@ def preprocess_packed_supervised_dataset(
|
||||
)
|
||||
length = len(input_ids)
|
||||
if length > data_args.cutoff_len:
|
||||
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
|
||||
logger.warning(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
|
||||
else:
|
||||
lengths.append(length)
|
||||
length2indexes[length].append(valid_num)
|
||||
@ -212,4 +212,4 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
|
||||
print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")
|
||||
|
@ -147,7 +147,7 @@ class Template:
|
||||
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
|
||||
token_ids += [tokenizer.eos_token_id]
|
||||
else:
|
||||
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
|
||||
raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
|
||||
|
||||
return token_ids
|
||||
|
||||
@ -275,9 +275,9 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
|
||||
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
|
||||
|
||||
if is_added:
|
||||
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
||||
logger.info(f"Add eos token: {tokenizer.eos_token}")
|
||||
else:
|
||||
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
|
||||
logger.info(f"Replace eos token: {tokenizer.eos_token}")
|
||||
|
||||
if num_added_tokens > 0:
|
||||
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
|
||||
@ -365,13 +365,13 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
||||
else:
|
||||
template = TEMPLATES.get(data_args.template, None)
|
||||
if template is None:
|
||||
raise ValueError("Template {} does not exist.".format(data_args.template))
|
||||
raise ValueError(f"Template {data_args.template} does not exist.")
|
||||
|
||||
if data_args.train_on_prompt and template.efficient_eos:
|
||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||
|
||||
if data_args.tool_format is not None:
|
||||
logger.info("Using tool format: {}.".format(data_args.tool_format))
|
||||
logger.info(f"Using tool format: {data_args.tool_format}.")
|
||||
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
|
||||
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
|
||||
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
|
||||
@ -389,7 +389,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
||||
logger.info(f"Add pad token: {tokenizer.pad_token}")
|
||||
|
||||
if stop_words:
|
||||
num_added_tokens = tokenizer.add_special_tokens(
|
||||
|
@ -177,6 +177,6 @@ TOOLS = {
|
||||
def get_tool_utils(name: str) -> "ToolUtils":
|
||||
tool_utils = TOOLS.get(name, None)
|
||||
if tool_utils is None:
|
||||
raise ValueError("Tool utils `{}` not found.".format(name))
|
||||
raise ValueError(f"Tool utils `{name}` not found.")
|
||||
|
||||
return tool_utils
|
||||
|
@ -87,7 +87,7 @@ class Evaluator:
|
||||
token=self.model_args.hf_hub_token,
|
||||
)
|
||||
|
||||
with open(mapping, "r", encoding="utf-8") as f:
|
||||
with open(mapping, encoding="utf-8") as f:
|
||||
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
||||
|
||||
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
|
||||
@ -139,7 +139,7 @@ class Evaluator:
|
||||
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
|
||||
score_info = "\n".join(
|
||||
[
|
||||
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
|
||||
f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
|
||||
for category_name, category_correct in category_corrects.items()
|
||||
if len(category_correct)
|
||||
]
|
||||
|
@ -61,7 +61,7 @@ def _register_eval_template(name: str, system: str, choice: str, answer: str) ->
|
||||
|
||||
def get_eval_template(name: str) -> "EvalTemplate":
|
||||
eval_template = eval_templates.get(name, None)
|
||||
assert eval_template is not None, "Template {} does not exist.".format(name)
|
||||
assert eval_template is not None, f"Template {name} does not exist."
|
||||
return eval_template
|
||||
|
||||
|
||||
|
@ -72,4 +72,4 @@ def print_env() -> None:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
|
||||
print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
|
||||
|
@ -75,7 +75,7 @@ def _get_default_logging_level() -> "logging._Level":
|
||||
if env_level_str.upper() in logging._nameToLevel:
|
||||
return logging._nameToLevel[env_level_str.upper()]
|
||||
else:
|
||||
raise ValueError("Unknown logging level: {}.".format(env_level_str))
|
||||
raise ValueError(f"Unknown logging level: {env_level_str}.")
|
||||
|
||||
return _default_log_level
|
||||
|
||||
|
@ -75,7 +75,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
|
||||
Plots loss curves and saves the image.
|
||||
"""
|
||||
plt.switch_backend("agg")
|
||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
for key in keys:
|
||||
@ -92,7 +92,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
|
||||
plt.figure()
|
||||
plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
|
||||
plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
|
||||
plt.title("training {} of {}".format(key, save_dictionary))
|
||||
plt.title(f"training {key} of {save_dictionary}")
|
||||
plt.xlabel("step")
|
||||
plt.ylabel(key)
|
||||
plt.legend()
|
||||
|
@ -67,8 +67,8 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
|
||||
|
||||
if unknown_args:
|
||||
print(parser.format_help())
|
||||
print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
|
||||
raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
|
||||
return (*parsed_args,)
|
||||
|
||||
@ -323,7 +323,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
|
||||
if last_checkpoint is not None:
|
||||
training_args.resume_from_checkpoint = last_checkpoint
|
||||
logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
|
||||
logger.info(f"Resuming training from {training_args.resume_from_checkpoint}.")
|
||||
logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
|
||||
|
||||
if (
|
||||
|
@ -182,7 +182,7 @@ def _setup_lora_tuning(
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if len(adapter_to_merge) > 0:
|
||||
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
|
||||
logger.info(f"Merged {len(adapter_to_merge)} adapter(s).")
|
||||
|
||||
if adapter_to_resume is not None: # resume lora training
|
||||
if model_args.use_unsloth:
|
||||
@ -239,8 +239,8 @@ def _setup_lora_tuning(
|
||||
logger.info("Using PiSSA initialization.")
|
||||
peft_kwargs["init_lora_weights"] = "pissa"
|
||||
else:
|
||||
logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
|
||||
peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)
|
||||
logger.info(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
|
||||
peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"
|
||||
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
@ -300,6 +300,6 @@ def init_adapter(
|
||||
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))
|
||||
raise NotImplementedError(f"Unknown finetuning type: {finetuning_args.finetuning_type}.")
|
||||
|
||||
return model
|
||||
|
@ -100,7 +100,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
|
||||
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
|
||||
patch_processor(processor, config, tokenizer, model_args)
|
||||
except Exception as e:
|
||||
logger.warning("Processor was not found: {}.".format(e))
|
||||
logger.warning(f"Processor was not found: {e}.")
|
||||
processor = None
|
||||
|
||||
# Avoid load tokenizer, see:
|
||||
@ -180,7 +180,7 @@ def load_model(
|
||||
vhead_params = load_valuehead_params(vhead_path, model_args)
|
||||
if vhead_params is not None:
|
||||
model.load_state_dict(vhead_params, strict=False)
|
||||
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
|
||||
logger.info(f"Loaded valuehead from checkpoint: {vhead_path}")
|
||||
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False)
|
||||
@ -198,7 +198,7 @@ def load_model(
|
||||
trainable_params, all_param, 100 * trainable_params / all_param
|
||||
)
|
||||
else:
|
||||
param_stats = "all params: {:,}".format(all_param)
|
||||
param_stats = f"all params: {all_param:,}"
|
||||
|
||||
logger.info(param_stats)
|
||||
|
||||
|
@ -65,7 +65,7 @@ def configure_attn_implementation(
|
||||
|
||||
requested_attn_implementation = "flash_attention_2"
|
||||
else:
|
||||
raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
|
||||
raise NotImplementedError(f"Unknown attention type: {model_args.flash_attn}")
|
||||
|
||||
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
|
||||
setattr(config, "attn_implementation", requested_attn_implementation)
|
||||
|
@ -111,7 +111,7 @@ def _gradient_checkpointing_enable(
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
if not self.supports_gradient_checkpointing:
|
||||
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
|
||||
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
||||
|
||||
if gradient_checkpointing_kwargs is None:
|
||||
gradient_checkpointing_kwargs = {"use_reentrant": True}
|
||||
|
@ -69,4 +69,4 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
|
||||
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
|
||||
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
|
||||
|
||||
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
|
||||
logger.info(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
|
||||
|
@ -67,12 +67,12 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
|
||||
|
||||
if num_layers % num_layer_trainable != 0:
|
||||
raise ValueError(
|
||||
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
|
||||
f"`num_layers` {num_layers} should be divisible by `num_layer_trainable` {num_layer_trainable}."
|
||||
)
|
||||
|
||||
stride = num_layers // num_layer_trainable
|
||||
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
|
||||
trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
|
||||
trainable_layers = [f".{idx:d}." for idx in trainable_layer_ids]
|
||||
module_names = []
|
||||
for name, _ in model.named_modules():
|
||||
if any(target_module in name for target_module in target_modules) and any(
|
||||
|
@ -130,7 +130,7 @@ def configure_quantization(
|
||||
quantization_config["bits"] = 2
|
||||
|
||||
quant_bits = quantization_config.get("bits", "?")
|
||||
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
|
||||
logger.info(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
|
||||
|
||||
elif model_args.export_quantization_bit is not None: # auto-gptq
|
||||
if model_args.export_quantization_bit not in [8, 4, 3, 2]:
|
||||
@ -149,7 +149,7 @@ def configure_quantization(
|
||||
)
|
||||
init_kwargs["device_map"] = "auto"
|
||||
init_kwargs["max_memory"] = get_max_memory()
|
||||
logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))
|
||||
logger.info(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
|
||||
|
||||
elif model_args.quantization_bit is not None: # on-the-fly
|
||||
if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
|
||||
@ -179,7 +179,7 @@ def configure_quantization(
|
||||
else:
|
||||
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
|
||||
|
||||
logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
|
||||
logger.info(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
|
||||
elif model_args.quantization_method == QuantizationMethod.HQQ.value:
|
||||
if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
|
||||
raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
|
||||
@ -191,7 +191,7 @@ def configure_quantization(
|
||||
init_kwargs["quantization_config"] = HqqConfig(
|
||||
nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
|
||||
) # use ATEN kernel (axis=0) for performance
|
||||
logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
|
||||
logger.info(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
|
||||
elif model_args.quantization_method == QuantizationMethod.EETQ.value:
|
||||
if model_args.quantization_bit != 8:
|
||||
raise ValueError("EETQ only accepts 8-bit quantization.")
|
||||
@ -201,4 +201,4 @@ def configure_quantization(
|
||||
|
||||
require_version("eetq", "To fix: pip install eetq")
|
||||
init_kwargs["quantization_config"] = EetqConfig()
|
||||
logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))
|
||||
logger.info(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
|
||||
|
@ -48,9 +48,7 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
|
||||
|
||||
current_max_length = getattr(config, "max_position_embeddings", None)
|
||||
if current_max_length and model_args.model_max_length > current_max_length:
|
||||
logger.info(
|
||||
"Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length)
|
||||
)
|
||||
logger.info(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
|
||||
setattr(config, "max_position_embeddings", model_args.model_max_length)
|
||||
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
|
||||
else:
|
||||
@ -60,6 +58,4 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
|
||||
scaling_factor = 2.0
|
||||
|
||||
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
|
||||
logger.info(
|
||||
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
|
||||
)
|
||||
logger.info(f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}")
|
||||
|
@ -54,7 +54,7 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
|
||||
except Exception as err:
|
||||
err_text = str(err)
|
||||
|
||||
logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text))
|
||||
logger.info(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
|
||||
logger.info("Ignore the above message if you are not resuming the training of a value head model.")
|
||||
return None
|
||||
|
||||
|
@ -99,7 +99,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
|
||||
else:
|
||||
return
|
||||
|
||||
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
|
||||
logger.info(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
|
||||
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
|
||||
|
||||
|
||||
|
@ -92,7 +92,7 @@ def fix_valuehead_checkpoint(
|
||||
else:
|
||||
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
|
||||
|
||||
logger.info("Value head model saved at: {}".format(output_dir))
|
||||
logger.info(f"Value head model saved at: {output_dir}")
|
||||
|
||||
|
||||
class FixValueHeadModelCallback(TrainerCallback):
|
||||
@ -106,7 +106,7 @@ class FixValueHeadModelCallback(TrainerCallback):
|
||||
Event called after a checkpoint save.
|
||||
"""
|
||||
if args.should_save:
|
||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
|
||||
fix_valuehead_checkpoint(
|
||||
model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
|
||||
)
|
||||
@ -123,7 +123,7 @@ class SaveProcessorCallback(TrainerCallback):
|
||||
@override
|
||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
if args.should_save:
|
||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
|
||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||
|
||||
@override
|
||||
@ -145,7 +145,7 @@ class PissaConvertCallback(TrainerCallback):
|
||||
if args.should_save:
|
||||
model = kwargs.pop("model")
|
||||
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
|
||||
logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir))
|
||||
logger.info(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
|
||||
if isinstance(model, PeftModel):
|
||||
init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
|
||||
setattr(model.peft_config["default"], "init_lora_weights", True)
|
||||
@ -159,7 +159,7 @@ class PissaConvertCallback(TrainerCallback):
|
||||
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
|
||||
pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
|
||||
pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
|
||||
logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir))
|
||||
logger.info(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
|
||||
# 1. save a pissa backup with init_lora_weights: True
|
||||
# 2. save a converted lora with init_lora_weights: pissa
|
||||
# 3. load the pissa backup with init_lora_weights: True
|
||||
|
@ -156,7 +156,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
elif self.loss_type == "simpo":
|
||||
losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
|
||||
else:
|
||||
raise NotImplementedError("Unknown loss type: {}.".format(self.loss_type))
|
||||
raise NotImplementedError(f"Unknown loss type: {self.loss_type}.")
|
||||
|
||||
chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
|
||||
rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
|
||||
@ -245,16 +245,16 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu()
|
||||
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu()
|
||||
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu()
|
||||
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu()
|
||||
metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().mean().cpu()
|
||||
metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().mean().cpu()
|
||||
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().mean().cpu()
|
||||
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().mean().cpu()
|
||||
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
|
||||
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
|
||||
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
|
||||
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
|
||||
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
|
||||
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
|
||||
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
|
||||
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
|
||||
if self.loss_type == "orpo":
|
||||
metrics["{}sft_loss".format(prefix)] = sft_loss.detach().mean().cpu()
|
||||
metrics["{}odds_ratio_loss".format(prefix)] = ((losses - sft_loss) / self.beta).detach().mean().cpu()
|
||||
metrics[f"{prefix}sft_loss"] = sft_loss.detach().mean().cpu()
|
||||
metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).detach().mean().cpu()
|
||||
|
||||
return losses.mean(), metrics
|
||||
|
@ -129,11 +129,11 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
"""
|
||||
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
|
||||
model_inputs = {
|
||||
"input_ids": batch["{}input_ids".format(prefix)],
|
||||
"attention_mask": batch["{}attention_mask".format(prefix)],
|
||||
"input_ids": batch[f"{prefix}input_ids"],
|
||||
"attention_mask": batch[f"{prefix}attention_mask"],
|
||||
}
|
||||
if "{}token_type_ids".format(prefix) in batch:
|
||||
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
|
||||
if f"{prefix}token_type_ids" in batch:
|
||||
model_inputs["token_type_ids"] = batch[f"{prefix}token_type_ids"]
|
||||
|
||||
if "pixel_values" in batch:
|
||||
model_inputs["pixel_values"] = batch["pixel_values"]
|
||||
@ -142,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
|
||||
|
||||
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
|
||||
logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
|
||||
return logps, logps / valid_length
|
||||
|
||||
@override
|
||||
|
@ -62,8 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
|
||||
setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone())
|
||||
|
||||
device = v_head_layer.weight.device
|
||||
v_head_layer.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
|
||||
v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
|
||||
v_head_layer.weight.data = model.get_buffer(f"{target}_head_weight").detach().clone().to(device)
|
||||
v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)
|
||||
|
||||
|
||||
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
|
||||
|
@ -218,18 +218,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
if self.is_world_process_zero():
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = {:,}".format(num_examples))
|
||||
logger.info(" Num Epochs = {:,}".format(num_train_epochs))
|
||||
logger.info(" Instantaneous batch size per device = {:,}".format(self.args.per_device_train_batch_size))
|
||||
logger.info(f" Num examples = {num_examples:,}")
|
||||
logger.info(f" Num Epochs = {num_train_epochs:,}")
|
||||
logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
|
||||
total_train_batch_size
|
||||
)
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = {:,}".format(self.args.gradient_accumulation_steps))
|
||||
logger.info(" Num optimization epochs per batch = {:,}".format(self.finetuning_args.ppo_epochs))
|
||||
logger.info(" Total training steps = {:,}".format(max_steps))
|
||||
logger.info(" Number of trainable parameters = {:,}".format(count_parameters(self.model)[0]))
|
||||
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
|
||||
logger.info(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
|
||||
logger.info(f" Total training steps = {max_steps:,}")
|
||||
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]:,}")
|
||||
|
||||
dataiter = iter(self.dataloader)
|
||||
loss_meter = AverageMeter()
|
||||
@ -290,7 +290,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
if (step + 1) % self.args.save_steps == 0: # save checkpoint
|
||||
self.save_model(
|
||||
os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
|
||||
os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
|
||||
)
|
||||
self.callback_handler.on_save(self.args, self.state, self.control)
|
||||
|
||||
|
@ -116,7 +116,7 @@ def create_ref_model(
|
||||
ref_model = load_model(
|
||||
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
|
||||
)
|
||||
logger.info("Created reference model from {}".format(finetuning_args.ref_model))
|
||||
logger.info(f"Created reference model from {finetuning_args.ref_model}")
|
||||
else:
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
ref_model = None
|
||||
@ -140,7 +140,7 @@ def create_reward_model(
|
||||
"""
|
||||
if finetuning_args.reward_model_type == "api":
|
||||
assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
|
||||
logger.info("Use reward server {}".format(finetuning_args.reward_model))
|
||||
logger.info(f"Use reward server {finetuning_args.reward_model}")
|
||||
return finetuning_args.reward_model
|
||||
elif finetuning_args.reward_model_type == "lora":
|
||||
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
|
||||
@ -157,7 +157,7 @@ def create_reward_model(
|
||||
model.register_buffer(
|
||||
"default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
|
||||
)
|
||||
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
|
||||
logger.info(f"Loaded adapter weights of reward model from {finetuning_args.reward_model}")
|
||||
return None
|
||||
else:
|
||||
reward_model_args = ModelArguments.copyfrom(
|
||||
@ -171,7 +171,7 @@ def create_reward_model(
|
||||
reward_model = load_model(
|
||||
tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
|
||||
)
|
||||
logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
|
||||
logger.info(f"Loaded full weights of reward model from {finetuning_args.reward_model}")
|
||||
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
|
||||
return reward_model
|
||||
|
||||
@ -231,7 +231,7 @@ def _create_galore_optimizer(
|
||||
elif training_args.optim == "adafactor":
|
||||
optim_class = GaLoreAdafactor
|
||||
else:
|
||||
raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
|
||||
raise NotImplementedError(f"Unknow optim: {training_args.optim}")
|
||||
|
||||
if finetuning_args.galore_layerwise:
|
||||
if training_args.gradient_accumulation_steps != 1:
|
||||
@ -305,7 +305,7 @@ def _create_loraplus_optimizer(
|
||||
dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
|
||||
]
|
||||
optimizer = optim_class(param_groups, **optim_kwargs)
|
||||
logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
|
||||
logger.info(f"Using LoRA+ optimizer with loraplus lr ratio {finetuning_args.loraplus_lr_ratio:.2f}.")
|
||||
return optimizer
|
||||
|
||||
|
||||
|
@ -57,7 +57,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
|
||||
elif finetuning_args.stage == "kto":
|
||||
run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
else:
|
||||
raise ValueError("Unknown task: {}.".format(finetuning_args.stage))
|
||||
raise ValueError(f"Unknown task: {finetuning_args.stage}.")
|
||||
|
||||
|
||||
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
||||
@ -91,18 +91,18 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
||||
|
||||
setattr(model.config, "torch_dtype", output_dtype)
|
||||
model = model.to(output_dtype)
|
||||
logger.info("Convert model dtype to: {}.".format(output_dtype))
|
||||
logger.info(f"Convert model dtype to: {output_dtype}.")
|
||||
|
||||
model.save_pretrained(
|
||||
save_directory=model_args.export_dir,
|
||||
max_shard_size="{}GB".format(model_args.export_size),
|
||||
max_shard_size=f"{model_args.export_size}GB",
|
||||
safe_serialization=(not model_args.export_legacy_format),
|
||||
)
|
||||
if model_args.export_hub_model_id is not None:
|
||||
model.push_to_hub(
|
||||
model_args.export_hub_model_id,
|
||||
token=model_args.hf_hub_token,
|
||||
max_shard_size="{}GB".format(model_args.export_size),
|
||||
max_shard_size=f"{model_args.export_size}GB",
|
||||
safe_serialization=(not model_args.export_legacy_format),
|
||||
)
|
||||
|
||||
@ -117,13 +117,13 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
||||
os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
|
||||
os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
|
||||
)
|
||||
logger.info("Copied valuehead to {}.".format(model_args.export_dir))
|
||||
logger.info(f"Copied valuehead to {model_args.export_dir}.")
|
||||
elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
|
||||
shutil.copy(
|
||||
os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
|
||||
os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
|
||||
)
|
||||
logger.info("Copied valuehead to {}.".format(model_args.export_dir))
|
||||
logger.info(f"Copied valuehead to {model_args.export_dir}.")
|
||||
|
||||
try:
|
||||
tokenizer.padding_side = "left" # restore padding side
|
||||
@ -140,4 +140,4 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Cannot save tokenizer, please copy the files manually: {}.".format(e))
|
||||
logger.warning(f"Cannot save tokenizer, please copy the files manually: {e}.")
|
||||
|
@ -75,7 +75,7 @@ def load_config() -> Dict[str, Any]:
|
||||
Loads user config if exists.
|
||||
"""
|
||||
try:
|
||||
with open(get_config_path(), "r", encoding="utf-8") as f:
|
||||
with open(get_config_path(), encoding="utf-8") as f:
|
||||
return safe_load(f)
|
||||
except Exception:
|
||||
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
|
||||
@ -172,14 +172,14 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
|
||||
Loads dataset_info.json.
|
||||
"""
|
||||
if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
|
||||
logger.info("dataset_dir is {}, using online dataset.".format(dataset_dir))
|
||||
logger.info(f"dataset_dir is {dataset_dir}, using online dataset.")
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as err:
|
||||
logger.warning("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
|
||||
logger.warning(f"Cannot open {os.path.join(dataset_dir, DATA_CONFIG)} due to {str(err)}.")
|
||||
return {}
|
||||
|
||||
|
||||
|
@ -41,7 +41,7 @@ def next_page(page_index: int, total_num: int) -> int:
|
||||
|
||||
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
|
||||
try:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
except Exception:
|
||||
return gr.Button(interactive=False)
|
||||
@ -57,7 +57,7 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
|
||||
|
||||
|
||||
def _load_data_file(file_path: str) -> List[Any]:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
if file_path.endswith(".json"):
|
||||
return json.load(f)
|
||||
elif file_path.endswith(".jsonl"):
|
||||
@ -67,7 +67,7 @@ def _load_data_file(file_path: str) -> List[Any]:
|
||||
|
||||
|
||||
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
|
||||
|
@ -56,9 +56,9 @@ class Engine:
|
||||
if not self.pure_chat:
|
||||
current_time = get_time()
|
||||
init_dict["train.current_time"] = {"value": current_time}
|
||||
init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)}
|
||||
init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)}
|
||||
init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)}
|
||||
init_dict["train.output_dir"] = {"value": f"train_{current_time}"}
|
||||
init_dict["train.config_path"] = {"value": f"{current_time}.yaml"}
|
||||
init_dict["eval.output_dir"] = {"value": f"eval_{current_time}"}
|
||||
init_dict["infer.mm_box"] = {"visible": False}
|
||||
|
||||
if user_config.get("last_model", None):
|
||||
|
@ -29,7 +29,7 @@ class Manager:
|
||||
Adds elements to manager.
|
||||
"""
|
||||
for elem_name, elem in elem_dict.items():
|
||||
elem_id = "{}.{}".format(tab_name, elem_name)
|
||||
elem_id = f"{tab_name}.{elem_name}"
|
||||
self._id_to_elem[elem_id] = elem
|
||||
self._elem_to_id[elem] = elem_id
|
||||
|
||||
|
@ -231,7 +231,7 @@ class Runner:
|
||||
if get("train.ds_stage") != "none":
|
||||
ds_stage = get("train.ds_stage")
|
||||
ds_offload = "offload_" if get("train.ds_offload") else ""
|
||||
args["deepspeed"] = os.path.join(DEFAULT_CACHE_DIR, "ds_z{}_{}config.json".format(ds_stage, ds_offload))
|
||||
args["deepspeed"] = os.path.join(DEFAULT_CACHE_DIR, f"ds_z{ds_stage}_{ds_offload}config.json")
|
||||
|
||||
return args
|
||||
|
||||
@ -313,7 +313,7 @@ class Runner:
|
||||
if args.get("deepspeed", None) is not None:
|
||||
env["FORCE_TORCHRUN"] = "1"
|
||||
|
||||
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
|
||||
self.trainer = Popen(f"llamafactory-cli train {save_cmd(args)}", env=env, shell=True)
|
||||
yield from self.monitor()
|
||||
|
||||
def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
|
||||
|
@ -111,14 +111,14 @@ def gen_cmd(args: Dict[str, Any]) -> str:
|
||||
"""
|
||||
cmd_lines = ["llamafactory-cli train "]
|
||||
for k, v in clean_cmd(args).items():
|
||||
cmd_lines.append(" --{} {} ".format(k, str(v)))
|
||||
cmd_lines.append(f" --{k} {str(v)} ")
|
||||
|
||||
if os.name == "nt":
|
||||
cmd_text = "`\n".join(cmd_lines)
|
||||
else:
|
||||
cmd_text = "\\\n".join(cmd_lines)
|
||||
|
||||
cmd_text = "```bash\n{}\n```".format(cmd_text)
|
||||
cmd_text = f"```bash\n{cmd_text}\n```"
|
||||
return cmd_text
|
||||
|
||||
|
||||
@ -139,9 +139,9 @@ def get_eval_results(path: os.PathLike) -> str:
|
||||
r"""
|
||||
Gets scores after evaluation.
|
||||
"""
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
result = json.dumps(json.load(f), indent=4)
|
||||
return "```json\n{}\n```\n".format(result)
|
||||
return f"```json\n{result}\n```\n"
|
||||
|
||||
|
||||
def get_time() -> str:
|
||||
@ -161,13 +161,13 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr
|
||||
|
||||
running_log_path = os.path.join(output_path, RUNNING_LOG)
|
||||
if os.path.isfile(running_log_path):
|
||||
with open(running_log_path, "r", encoding="utf-8") as f:
|
||||
with open(running_log_path, encoding="utf-8") as f:
|
||||
running_log = f.read()
|
||||
|
||||
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
|
||||
if os.path.isfile(trainer_log_path):
|
||||
trainer_log: List[Dict[str, Any]] = []
|
||||
with open(trainer_log_path, "r", encoding="utf-8") as f:
|
||||
with open(trainer_log_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
trainer_log.append(json.loads(line))
|
||||
|
||||
@ -193,7 +193,7 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
|
||||
Loads saved arguments.
|
||||
"""
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
return safe_load(f)
|
||||
except Exception:
|
||||
return None
|
||||
@ -211,7 +211,7 @@ def list_config_paths(current_time: str) -> "gr.Dropdown":
|
||||
r"""
|
||||
Lists all the saved configuration files.
|
||||
"""
|
||||
config_files = ["{}.yaml".format(current_time)]
|
||||
config_files = [f"{current_time}.yaml"]
|
||||
if os.path.isdir(DEFAULT_CONFIG_DIR):
|
||||
for file_name in os.listdir(DEFAULT_CONFIG_DIR):
|
||||
if file_name.endswith(".yaml") and file_name not in config_files:
|
||||
@ -224,7 +224,7 @@ def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_ti
|
||||
r"""
|
||||
Lists all the directories that can resume from.
|
||||
"""
|
||||
output_dirs = ["train_{}".format(current_time)]
|
||||
output_dirs = [f"train_{current_time}"]
|
||||
if model_name:
|
||||
save_dir = get_save_dir(model_name, finetuning_type)
|
||||
if save_dir and os.path.isdir(save_dir):
|
||||
|
@ -61,7 +61,7 @@ OS_NAME = os.environ.get("OS_NAME", "")
|
||||
],
|
||||
)
|
||||
def test_run_exp(stage: str, dataset: str):
|
||||
output_dir = "train_{}".format(stage)
|
||||
output_dir = f"train_{stage}"
|
||||
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
|
||||
assert os.path.exists(output_dir)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user