use pre-commit

Former-commit-id: 21db8ed2f4a0eba203754a92ce0741538e8ee709
hiyouga 2024-10-29 09:07:46 +00:00
parent 163cf2ba5c
commit 0d8aa6e6ef
86 changed files with 1048 additions and 1064 deletions
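The file diffs below repeat the same mechanical modernizations: str.format() calls become f-strings, the redundant "r" mode is dropped from open() calls, and "# coding=utf-8" headers are removed. A minimal, self-contained Python sketch of the before/after pattern, reusing values from the first diff below (the url_old/url_new names are illustrative only):

import os

_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")

# Before: explicit str.format() call
url_old = "{}/datasets/BelleGroup/multiturn_chat_0.8M".format(_HF_ENDPOINT)
# After: equivalent f-string, as rewritten throughout this commit
url_new = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M"
assert url_old == url_new

# open() defaults to text read mode, so dropping the explicit "r" changes nothing:
#   open(path, "r", encoding="utf-8")  ->  open(path, encoding="utf-8")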

View File

@ -17,9 +17,9 @@ _CITATION = """\
}
"""
_HOMEPAGE = "{}/datasets/BelleGroup/multiturn_chat_0.8M".format(_HF_ENDPOINT)
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M"
_LICENSE = "gpl-3.0"
_URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json".format(_HF_ENDPOINT)
_URL = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"
class BelleMultiturn(datasets.GeneratorBasedBuilder):
@ -38,7 +38,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
def _generate_examples(self, filepath: str):
with open(filepath, "r", encoding="utf-8") as f:
with open(filepath, encoding="utf-8") as f:
for key, row in enumerate(f):
data = json.loads(row)
conversations = []

View File

@ -8,9 +8,9 @@ import datasets
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
_CITATION = ""
_HOMEPAGE = "{}/datasets/Anthropic/hh-rlhf".format(_HF_ENDPOINT)
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf"
_LICENSE = "mit"
_URL = "{}/datasets/Anthropic/hh-rlhf/resolve/main/".format(_HF_ENDPOINT)
_URL = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf/resolve/main/"
_URLS = {
"train": [
_URL + "harmless-base/train.jsonl.gz",
@ -53,7 +53,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
def _generate_examples(self, filepaths: List[str]):
key = 0
for filepath in filepaths:
with open(filepath, "r", encoding="utf-8") as f:
with open(filepath, encoding="utf-8") as f:
for row in f:
data = json.loads(row)
chosen = data["chosen"]

View File

@ -20,9 +20,9 @@ _CITATION = """\
}
"""
_HOMEPAGE = "{}/datasets/stingning/ultrachat".format(_HF_ENDPOINT)
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat"
_LICENSE = "cc-by-nc-4.0"
_BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl".format(_HF_ENDPOINT)
_BASE_DATA_URL = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl"
class UltraChat(datasets.GeneratorBasedBuilder):
@ -42,7 +42,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
def _generate_examples(self, filepaths: List[str]):
for filepath in filepaths:
with open(filepath, "r", encoding="utf-8") as f:
with open(filepath, encoding="utf-8") as f:
for row in f:
try:
data = json.loads(row)

View File

@ -158,5 +158,4 @@ class MMLU(datasets.GeneratorBasedBuilder):
df = pd.read_csv(filepath, header=None)
df.columns = ["question", "A", "B", "C", "D", "answer"]
for i, instance in enumerate(df.to_dict(orient="records")):
yield i, instance
yield from enumerate(df.to_dict(orient="records"))
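The MMLU loader change above collapses an explicit index-and-yield loop into a single yield from. A small equivalence sketch, using a plain list of dicts in place of the pandas records from the original:

records = [{"question": "q1", "answer": "A"}, {"question": "q2", "answer": "B"}]

def gen_explicit():
    # original form: enumerate and yield each (index, instance) pair by hand
    for i, instance in enumerate(records):
        yield i, instance

def gen_delegated():
    # rewritten form: delegate straight to the enumerate iterator
    yield from enumerate(records)

assert list(gen_explicit()) == list(gen_delegated())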

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
#
# This code is inspired by the Microsoft's DeepSpeed library.

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 imoneoi and the LlamaFactory team.
#
# This code is inspired by the imoneoi's OpenChat library.
@ -74,7 +73,7 @@ def calculate_lr(
elif stage == "sft":
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
else:
raise NotImplementedError("Stage does not supported: {}.".format(stage))
raise NotImplementedError(f"Stage does not supported: {stage}.")
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
valid_tokens, total_tokens = 0, 0

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -100,7 +99,7 @@ def compute_device_flops(world_size: int) -> float:
elif "4090" in device_name:
return 98 * 1e12 * world_size
else:
raise NotImplementedError("Device not supported: {}.".format(device_name))
raise NotImplementedError(f"Device not supported: {device_name}.")
def calculate_mfu(
@ -140,10 +139,10 @@ def calculate_mfu(
"bf16": True,
}
if deepspeed_stage in [2, 3]:
args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)
args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
run_exp(args)
with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
result = json.load(f)
if dist.is_initialized():
@ -157,7 +156,7 @@ def calculate_mfu(
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
/ compute_device_flops(world_size)
)
print("MFU: {:.2f}%".format(mfu_value * 100))
print(f"MFU: {mfu_value * 100:.2f}%")
if __name__ == "__main__":

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -100,7 +99,7 @@ def calculate_ppl(
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
)
else:
raise NotImplementedError("Stage does not supported: {}.".format(stage))
raise NotImplementedError(f"Stage does not supported: {stage}.")
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
criterion = torch.nn.CrossEntropyLoss(reduction="none")
@ -125,8 +124,8 @@ def calculate_ppl(
with open(save_name, "w", encoding="utf-8") as f:
json.dump(perplexities, f, indent=2)
print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
print("Perplexities have been saved at {}.".format(save_name))
print(f"Average perplexity is {total_ppl / len(perplexities):.2f}")
print(f"Perplexities have been saved at {save_name}.")
if __name__ == "__main__":

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -61,7 +60,7 @@ def length_cdf(
for length, count in length_tuples:
count_accu += count
prob_accu += count / total_num * 100
print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")
if __name__ == "__main__":

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 Tencent Inc. and the LlamaFactory team.
#
# This code is inspired by the Tencent's LLaMA-Pro library.
@ -40,7 +39,7 @@ if TYPE_CHECKING:
def change_name(name: str, old_index: int, new_index: int) -> str:
return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))
return name.replace(f".{old_index:d}.", f".{new_index:d}.")
def block_expansion(
@ -76,27 +75,27 @@ def block_expansion(
state_dict = model.state_dict()
if num_layers % num_expand != 0:
raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand))
raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
split = num_layers // num_expand
layer_cnt = 0
output_state_dict = OrderedDict()
for i in range(num_layers):
for key, value in state_dict.items():
if ".{:d}.".format(i) in key:
if f".{i:d}." in key:
output_state_dict[change_name(key, i, layer_cnt)] = value
print("Add layer {} copied from layer {}".format(layer_cnt, i))
print(f"Add layer {layer_cnt} copied from layer {i}")
layer_cnt += 1
if (i + 1) % split == 0:
for key, value in state_dict.items():
if ".{:d}.".format(i) in key:
if f".{i:d}." in key:
if "down_proj" in key or "o_proj" in key:
output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
else:
output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
print("Add layer {} expanded from layer {}".format(layer_cnt, i))
print(f"Add layer {layer_cnt} expanded from layer {i}")
layer_cnt += 1
for key, value in state_dict.items():
@ -113,17 +112,17 @@ def block_expansion(
torch.save(shard, os.path.join(output_dir, shard_file))
if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
else:
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir))
print(f"Model weights saved in {output_dir}")
print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir))
print(f"model_name_or_path: {output_dir}")
print("finetuning_type: freeze")
print("freeze_trainable_layers: {}".format(num_expand))
print(f"freeze_trainable_layers: {num_expand}")
print("use_llama_pro: true")

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -63,16 +62,16 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
torch.save(shard, os.path.join(output_dir, shard_file))
if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
print(f"Model weights saved in {os.path.join(output_dir, WEIGHTS_NAME)}")
else:
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir))
print(f"Model weights saved in {output_dir}")
def save_config(input_dir: str, output_dir: str):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
llama2_config_dict: Dict[str, Any] = json.load(f)
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
@ -82,7 +81,7 @@ def save_config(input_dir: str, output_dir: str):
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
json.dump(llama2_config_dict, f, indent=2)
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
def llamafy_baichuan2(

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -86,7 +85,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
elif "lm_head" in key:
llama2_state_dict[key] = value
else:
raise KeyError("Unable to process key {}".format(key))
raise KeyError(f"Unable to process key {key}")
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
@ -98,18 +97,18 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
torch.save(shard, os.path.join(output_dir, shard_file))
if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
else:
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir))
print(f"Model weights saved in {output_dir}")
return str(torch_dtype).replace("torch.", "")
def save_config(input_dir: str, output_dir: str, torch_dtype: str):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
qwen_config_dict: Dict[str, Any] = json.load(f)
llama2_config_dict: Dict[str, Any] = OrderedDict()
@ -135,7 +134,7 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
json.dump(llama2_config_dict, f, indent=2)
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
def llamafy_qwen(

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is based on the HuggingFace's PEFT library.
@ -70,19 +69,19 @@ def quantize_loftq(
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(loftq_dir))
print(f"Adapter weights saved in {loftq_dir}")
# Save base model
base_model: "PreTrainedModel" = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir)
print("Model weights saved in {}".format(output_dir))
print(f"Model weights saved in {output_dir}")
print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir))
print("adapter_name_or_path: {}".format(loftq_dir))
print(f"model_name_or_path: {output_dir}")
print(f"adapter_name_or_path: {loftq_dir}")
print("finetuning_type: lora")
print("quantization_bit: {}".format(loftq_bits))
print(f"quantization_bit: {loftq_bits}")
if __name__ == "__main__":

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is based on the HuggingFace's PEFT library.
@ -54,7 +53,7 @@ def quantize_pissa(
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
lora_dropout=lora_dropout,
target_modules=lora_target,
init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
init_lora_weights="pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}",
)
# Init PiSSA model
@ -65,17 +64,17 @@ def quantize_pissa(
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(pissa_dir))
print(f"Adapter weights saved in {pissa_dir}")
# Save base model
base_model: "PreTrainedModel" = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir)
print("Model weights saved in {}".format(output_dir))
print(f"Model weights saved in {output_dir}")
print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir))
print("adapter_name_or_path: {}".format(pissa_dir))
print(f"model_name_or_path: {output_dir}")
print(f"adapter_name_or_path: {pissa_dir}")
print("finetuning_type: lora")
print("pissa_init: false")
print("pissa_convert: true")

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -20,7 +20,7 @@ from setuptools import find_packages, setup
def get_version() -> str:
with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
file_content = f.read()
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content)
@ -28,7 +28,7 @@ def get_version() -> str:
def get_requires() -> List[str]:
with open("requirements.txt", "r", encoding="utf-8") as f:
with open("requirements.txt", encoding="utf-8") as f:
file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
return lines
@ -61,7 +61,7 @@ extra_require = {
"qwen": ["transformers_stream_generator"],
"modelscope": ["modelscope"],
"openmind": ["openmind"],
"dev": ["ruff", "pytest"],
"dev": ["pre-commit", "ruff", "pytest"],
}
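With pre-commit added to the dev extras (matching the commit title), contributors would typically install the hooks once with `pre-commit install` and apply them across the tree with `pre-commit run --all-files`; this is a usage note only, since the repository's actual hook configuration is not part of this diff.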
@ -72,7 +72,7 @@ def main():
author="hiyouga",
author_email="hiyouga" "@" "buaa.edu.cn",
description="Easy-to-use LLM fine-tuning framework",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
license="Apache 2.0 License",

View File

@ -25,7 +25,7 @@ def main():
app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000"))
print("Visit http://localhost:{}/docs for API document.".format(api_port))
print(f"Visit http://localhost:{api_port}/docs for API document.")
uvicorn.run(app, host=api_host, port=api_port)

View File

@ -130,5 +130,5 @@ def run_api() -> None:
app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000"))
print("Visit http://localhost:{}/docs for API document.".format(api_port))
print(f"Visit http://localhost:{api_port}/docs for API document.")
uvicorn.run(app, host=api_host, port=api_port)

View File

@ -70,7 +70,7 @@ ROLE_MAPPING = {
def _process_request(
request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
logger.info(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
@ -142,7 +142,7 @@ def _create_stream_chat_completion_chunk(
async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse":
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, image = _process_request(request)
responses = await chat_model.achat(
input_messages,
@ -169,7 +169,7 @@ async def create_chat_completion_response(
tool_calls = []
for tool in result:
function = Function(name=tool[0], arguments=tool[1])
tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
finish_reason = Finish.TOOL
@ -193,7 +193,7 @@ async def create_chat_completion_response(
async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]:
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, image = _process_request(request)
if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@ -229,7 +229,7 @@ async def create_stream_chat_completion_response(
async def create_score_evaluation_response(
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
) -> "ScoreEvaluationResponse":
score_id = "scoreval-{}".format(uuid.uuid4().hex)
score_id = f"scoreval-{uuid.uuid4().hex}"
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

View File

@ -53,7 +53,7 @@ class ChatModel:
elif model_args.infer_backend == "vllm":
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
else:
raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
self._loop = asyncio.new_event_loop()
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)

View File

@ -105,7 +105,7 @@ class VllmEngine(BaseEngine):
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
request_id = f"chatcmpl-{uuid.uuid4().hex}"
if image is not None:
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
@ -159,7 +159,7 @@ class VllmEngine(BaseEngine):
if image is not None: # add image features
if not isinstance(image, (str, ImageObject)):
raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
if isinstance(image, str):
image = Image.open(image).convert("RGB")

View File

@ -47,7 +47,7 @@ USAGE = (
WELCOME = (
"-" * 58
+ "\n"
+ "| Welcome to LLaMA Factory, version {}".format(VERSION)
+ f"| Welcome to LLaMA Factory, version {VERSION}"
+ " " * (21 - len(VERSION))
+ "|\n|"
+ " " * 56
@ -90,7 +90,7 @@ def main():
if force_torchrun or get_device_count() > 1:
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
logger.info(f"Initializing distributed tasks at: {master_addr}:{master_port}")
process = subprocess.run(
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
@ -118,4 +118,4 @@ def main():
elif command == Command.HELP:
print(USAGE)
else:
raise NotImplementedError("Unknown command: {}.".format(command))
raise NotImplementedError(f"Unknown command: {command}.")

View File

@ -161,7 +161,7 @@ def convert_sharegpt(
broken_data = False
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning("Invalid role tag in {}.".format(messages))
logger.warning(f"Invalid role tag in {messages}.")
broken_data = True
aligned_messages.append(
@ -171,7 +171,7 @@ def convert_sharegpt(
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning("Invalid message count in {}.".format(messages))
logger.warning(f"Invalid message count in {messages}.")
broken_data = True
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
@ -192,7 +192,7 @@ def convert_sharegpt(
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
logger.warning(f"Invalid role tag in {[chosen, rejected]}.")
broken_data = True
prompt = aligned_messages

View File

@ -137,9 +137,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
for key in ("chosen", "rejected"):
for feature in features:
target_feature = {
"input_ids": feature["{}_input_ids".format(key)],
"attention_mask": feature["{}_attention_mask".format(key)],
"labels": feature["{}_labels".format(key)],
"input_ids": feature[f"{key}_input_ids"],
"attention_mask": feature[f"{key}_attention_mask"],
"labels": feature[f"{key}_labels"],
"images": feature["images"],
"videos": feature["videos"],
}

View File

@ -70,7 +70,7 @@ def merge_dataset(
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
)
else:
raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
def split_dataset(

View File

@ -83,14 +83,14 @@ class StringFormatter(Formatter):
if isinstance(slot, str):
for name, value in kwargs.items():
if not isinstance(value, str):
raise RuntimeError("Expected a string, got {}".format(value))
raise RuntimeError(f"Expected a string, got {value}")
slot = slot.replace("{{" + name + "}}", value, 1)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
return elements
@ -113,7 +113,7 @@ class FunctionFormatter(Formatter):
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
except json.JSONDecodeError:
raise RuntimeError("Invalid JSON format in function message: {}".format(str([content]))) # flat string
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
elements = []
for name, arguments in functions:
@ -124,7 +124,7 @@ class FunctionFormatter(Formatter):
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
return elements
@ -141,7 +141,7 @@ class ToolFormatter(Formatter):
tools = json.loads(content)
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError:
raise RuntimeError("Invalid JSON format in tool description: {}".format(str([content]))) # flat string
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
@override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:

View File

@ -51,7 +51,7 @@ def _load_single_dataset(
r"""
Loads a single dataset and aligns it to the standard format.
"""
logger.info("Loading dataset {}...".format(dataset_attr))
logger.info(f"Loading dataset {dataset_attr}...")
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
data_path = dataset_attr.dataset_name
@ -77,12 +77,12 @@ def _load_single_dataset(
data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
raise ValueError("File {} not found.".format(local_path))
raise ValueError(f"File {local_path} not found.")
if data_path is None:
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
else:
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
if dataset_attr.load_from == "ms_hub":
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
@ -145,7 +145,7 @@ def _load_single_dataset(
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
dataset = dataset.select(indexes)
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
logger.info(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
if data_args.max_samples is not None: # truncate dataset
max_samples = min(data_args.max_samples, len(dataset))
@ -243,7 +243,7 @@ def get_dataset(
if has_tokenized_data(data_args.tokenized_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
logger.info(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
dataset_module: Dict[str, "Dataset"] = {}
if "train" in dataset_dict:
@ -294,8 +294,8 @@ def get_dataset(
if data_args.tokenized_path is not None:
if training_args.should_save:
dataset_dict.save_to_disk(data_args.tokenized_path)
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
logger.info(f"Tokenized dataset saved at {data_args.tokenized_path}.")
logger.info(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
sys.exit(0)

View File

@ -111,7 +111,7 @@ class BasePlugin:
image = Image.open(image["path"])
if not isinstance(image, ImageObject):
raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
results.append(self._preprocess_image(image, **kwargs))
@ -253,7 +253,7 @@ class LlavaPlugin(BasePlugin):
message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
return messages
@ -302,7 +302,7 @@ class LlavaNextPlugin(BasePlugin):
message["content"] = content.replace("{{image}}", self.image_token)
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
return messages
@override
@ -366,10 +366,10 @@ class LlavaNextVideoPlugin(BasePlugin):
message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
if len(videos) != num_video_tokens:
raise ValueError("The number of videos does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
raise ValueError(f"The number of videos does not match the number of {IMAGE_PLACEHOLDER} tokens")
return messages
@ -408,7 +408,7 @@ class PaliGemmaPlugin(BasePlugin):
message["content"] = content.replace("{{image}}", "")
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
return messages
@ -493,7 +493,7 @@ class Qwen2vlPlugin(BasePlugin):
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(image_grid_thw):
raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER))
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
content = content.replace(
IMAGE_PLACEHOLDER,
@ -506,7 +506,7 @@ class Qwen2vlPlugin(BasePlugin):
while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(video_grid_thw):
raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER))
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
content = content.replace(
VIDEO_PLACEHOLDER,
@ -520,10 +520,10 @@ class Qwen2vlPlugin(BasePlugin):
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
if len(videos) != num_video_tokens:
raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens")
return messages
@ -583,10 +583,10 @@ class VideoLlavaPlugin(BasePlugin):
message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(self.image_token))
raise ValueError(f"The number of images does not match the number of {self.image_token} tokens")
if len(videos) != num_video_tokens:
raise ValueError("The number of videos does not match the number of {} tokens".format(self.video_token))
raise ValueError(f"The number of videos does not match the number of {self.video_token} tokens")
return messages
@ -622,6 +622,6 @@ def get_mm_plugin(
) -> "BasePlugin":
plugin_class = PLUGINS.get(name, None)
if plugin_class is None:
raise ValueError("Multimodal plugin `{}` not found.".format(name))
raise ValueError(f"Multimodal plugin `{name}` not found.")
return plugin_class(image_token, video_token)

View File

@ -87,11 +87,11 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
config_path = os.path.join(dataset_dir, DATA_CONFIG)
try:
with open(config_path, "r") as f:
with open(config_path) as f:
dataset_info = json.load(f)
except Exception as err:
if len(dataset_names) != 0:
raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
raise ValueError(f"Cannot open {config_path} due to {str(err)}.")
dataset_info = None
@ -109,7 +109,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
continue
if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")
has_hf_url = "hf_hub_url" in dataset_info[name]
has_ms_url = "ms_hub_url" in dataset_info[name]

View File

@ -110,8 +110,8 @@ def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "Pr
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")

View File

@ -160,7 +160,7 @@ def preprocess_packed_supervised_dataset(
)
length = len(input_ids)
if length > data_args.cutoff_len:
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
logger.warning(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
else:
lengths.append(length)
length2indexes[length].append(valid_num)
@ -212,4 +212,4 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")

View File

@ -147,7 +147,7 @@ class Template:
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
token_ids += [tokenizer.eos_token_id]
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
return token_ids
@ -275,9 +275,9 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
if is_added:
logger.info("Add eos token: {}".format(tokenizer.eos_token))
logger.info(f"Add eos token: {tokenizer.eos_token}")
else:
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
logger.info(f"Replace eos token: {tokenizer.eos_token}")
if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
@ -365,13 +365,13 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
else:
template = TEMPLATES.get(data_args.template, None)
if template is None:
raise ValueError("Template {} does not exist.".format(data_args.template))
raise ValueError(f"Template {data_args.template} does not exist.")
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
if data_args.tool_format is not None:
logger.info("Using tool format: {}.".format(data_args.tool_format))
logger.info(f"Using tool format: {data_args.tool_format}.")
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
@ -389,7 +389,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
logger.info(f"Add pad token: {tokenizer.pad_token}")
if stop_words:
num_added_tokens = tokenizer.add_special_tokens(

View File

@ -177,6 +177,6 @@ TOOLS = {
def get_tool_utils(name: str) -> "ToolUtils":
tool_utils = TOOLS.get(name, None)
if tool_utils is None:
raise ValueError("Tool utils `{}` not found.".format(name))
raise ValueError(f"Tool utils `{name}` not found.")
return tool_utils

View File

@ -87,7 +87,7 @@ class Evaluator:
token=self.model_args.hf_hub_token,
)
with open(mapping, "r", encoding="utf-8") as f:
with open(mapping, encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f)
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
@ -139,7 +139,7 @@ class Evaluator:
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
score_info = "\n".join(
[
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
for category_name, category_correct in category_corrects.items()
if len(category_correct)
]

View File

@ -61,7 +61,7 @@ def _register_eval_template(name: str, system: str, choice: str, answer: str) ->
def get_eval_template(name: str) -> "EvalTemplate":
eval_template = eval_templates.get(name, None)
assert eval_template is not None, "Template {} does not exist.".format(name)
assert eval_template is not None, f"Template {name} does not exist."
return eval_template

View File

@ -72,4 +72,4 @@ def print_env() -> None:
except Exception:
pass
print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")

View File

@ -75,7 +75,7 @@ def _get_default_logging_level() -> "logging._Level":
if env_level_str.upper() in logging._nameToLevel:
return logging._nameToLevel[env_level_str.upper()]
else:
raise ValueError("Unknown logging level: {}.".format(env_level_str))
raise ValueError(f"Unknown logging level: {env_level_str}.")
return _default_log_level

View File

@ -75,7 +75,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
Plots loss curves and saves the image.
"""
plt.switch_backend("agg")
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
data = json.load(f)
for key in keys:
@ -92,7 +92,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
plt.figure()
plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
plt.title("training {} of {}".format(key, save_dictionary))
plt.title(f"training {key} of {save_dictionary}")
plt.xlabel("step")
plt.ylabel(key)
plt.legend()

View File

@ -67,8 +67,8 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
if unknown_args:
print(parser.format_help())
print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
return (*parsed_args,)
@ -323,7 +323,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if last_checkpoint is not None:
training_args.resume_from_checkpoint = last_checkpoint
logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
logger.info(f"Resuming training from {training_args.resume_from_checkpoint}.")
logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
if (

View File

@ -182,7 +182,7 @@ def _setup_lora_tuning(
model = model.merge_and_unload()
if len(adapter_to_merge) > 0:
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
logger.info(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth:
@ -239,8 +239,8 @@ def _setup_lora_tuning(
logger.info("Using PiSSA initialization.")
peft_kwargs["init_lora_weights"] = "pissa"
else:
logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)
logger.info(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
@ -300,6 +300,6 @@ def init_adapter(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
)
else:
raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))
raise NotImplementedError(f"Unknown finetuning type: {finetuning_args.finetuning_type}.")
return model

View File

@ -100,7 +100,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, config, tokenizer, model_args)
except Exception as e:
logger.warning("Processor was not found: {}.".format(e))
logger.warning(f"Processor was not found: {e}.")
processor = None
# Avoid load tokenizer, see:
@ -180,7 +180,7 @@ def load_model(
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
logger.info(f"Loaded valuehead from checkpoint: {vhead_path}")
if not is_trainable:
model.requires_grad_(False)
@ -198,7 +198,7 @@ def load_model(
trainable_params, all_param, 100 * trainable_params / all_param
)
else:
param_stats = "all params: {:,}".format(all_param)
param_stats = f"all params: {all_param:,}"
logger.info(param_stats)

View File

@ -65,7 +65,7 @@ def configure_attn_implementation(
requested_attn_implementation = "flash_attention_2"
else:
raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
raise NotImplementedError(f"Unknown attention type: {model_args.flash_attn}")
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", requested_attn_implementation)

View File

@ -111,7 +111,7 @@ def _gradient_checkpointing_enable(
from torch.utils.checkpoint import checkpoint
if not self.supports_gradient_checkpointing:
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}

View File

@ -69,4 +69,4 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
logger.info(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")

View File

@ -67,12 +67,12 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
if num_layers % num_layer_trainable != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
f"`num_layers` {num_layers} should be divisible by `num_layer_trainable` {num_layer_trainable}."
)
stride = num_layers // num_layer_trainable
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
trainable_layers = [f".{idx:d}." for idx in trainable_layer_ids]
module_names = []
for name, _ in model.named_modules():
if any(target_module in name for target_module in target_modules) and any(

View File

@ -130,7 +130,7 @@ def configure_quantization(
quantization_config["bits"] = 2
quant_bits = quantization_config.get("bits", "?")
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
logger.info(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
elif model_args.export_quantization_bit is not None: # auto-gptq
if model_args.export_quantization_bit not in [8, 4, 3, 2]:
@ -149,7 +149,7 @@ def configure_quantization(
)
init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))
logger.info(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
elif model_args.quantization_bit is not None: # on-the-fly
if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
@ -179,7 +179,7 @@ def configure_quantization(
else:
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
logger.info(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
elif model_args.quantization_method == QuantizationMethod.HQQ.value:
if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
@ -191,7 +191,7 @@ def configure_quantization(
init_kwargs["quantization_config"] = HqqConfig(
nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
) # use ATEN kernel (axis=0) for performance
logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
logger.info(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
elif model_args.quantization_method == QuantizationMethod.EETQ.value:
if model_args.quantization_bit != 8:
raise ValueError("EETQ only accepts 8-bit quantization.")
@ -201,4 +201,4 @@ def configure_quantization(
require_version("eetq", "To fix: pip install eetq")
init_kwargs["quantization_config"] = EetqConfig()
logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))
logger.info(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")

View File

@ -48,9 +48,7 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
logger.info(
"Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length)
)
logger.info(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
setattr(config, "max_position_embeddings", model_args.model_max_length)
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else:
@ -60,6 +58,4 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info(
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
)
logger.info(f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}")

View File

@ -54,7 +54,7 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
except Exception as err:
err_text = str(err)
logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text))
logger.info(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
logger.info("Ignore the above message if you are not resuming the training of a value head model.")
return None

View File

@ -99,7 +99,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
else:
return
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
logger.info(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)

View File

@ -92,7 +92,7 @@ def fix_valuehead_checkpoint(
else:
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
logger.info("Value head model saved at: {}".format(output_dir))
logger.info(f"Value head model saved at: {output_dir}")
class FixValueHeadModelCallback(TrainerCallback):
@ -106,7 +106,7 @@ class FixValueHeadModelCallback(TrainerCallback):
Event called after a checkpoint save.
"""
if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
fix_valuehead_checkpoint(
model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
)
@ -123,7 +123,7 @@ class SaveProcessorCallback(TrainerCallback):
@override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
getattr(self.processor, "image_processor").save_pretrained(output_dir)
@override
@ -145,7 +145,7 @@ class PissaConvertCallback(TrainerCallback):
if args.should_save:
model = kwargs.pop("model")
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir))
logger.info(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
if isinstance(model, PeftModel):
init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
setattr(model.peft_config["default"], "init_lora_weights", True)
@ -159,7 +159,7 @@ class PissaConvertCallback(TrainerCallback):
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir))
logger.info(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
# 1. save a pissa backup with init_lora_weights: True
# 2. save a converted lora with init_lora_weights: pissa
# 3. load the pissa backup with init_lora_weights: True

View File

@ -156,7 +156,7 @@ class CustomDPOTrainer(DPOTrainer):
elif self.loss_type == "simpo":
losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
else:
raise NotImplementedError("Unknown loss type: {}.".format(self.loss_type))
raise NotImplementedError(f"Unknown loss type: {self.loss_type}.")
chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
@ -245,16 +245,16 @@ class CustomDPOTrainer(DPOTrainer):
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu()
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu()
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu()
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu()
metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().mean().cpu()
metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().mean().cpu()
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().mean().cpu()
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().mean().cpu()
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
if self.loss_type == "orpo":
metrics["{}sft_loss".format(prefix)] = sft_loss.detach().mean().cpu()
metrics["{}odds_ratio_loss".format(prefix)] = ((losses - sft_loss) / self.beta).detach().mean().cpu()
metrics[f"{prefix}sft_loss"] = sft_loss.detach().mean().cpu()
metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).detach().mean().cpu()
return losses.mean(), metrics

View File

@ -129,11 +129,11 @@ class CustomKTOTrainer(KTOTrainer):
"""
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
model_inputs = {
"input_ids": batch["{}input_ids".format(prefix)],
"attention_mask": batch["{}attention_mask".format(prefix)],
"input_ids": batch[f"{prefix}input_ids"],
"attention_mask": batch[f"{prefix}attention_mask"],
}
if "{}token_type_ids".format(prefix) in batch:
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
if f"{prefix}token_type_ids" in batch:
model_inputs["token_type_ids"] = batch[f"{prefix}token_type_ids"]
if "pixel_values" in batch:
model_inputs["pixel_values"] = batch["pixel_values"]
@ -142,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer):
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
return logps, logps / valid_length
@override

View File

@ -62,8 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone())
device = v_head_layer.weight.device
v_head_layer.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
v_head_layer.weight.data = model.get_buffer(f"{target}_head_weight").detach().clone().to(device)
v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:

View File

@ -218,18 +218,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.is_world_process_zero():
logger.info("***** Running training *****")
logger.info(" Num examples = {:,}".format(num_examples))
logger.info(" Num Epochs = {:,}".format(num_train_epochs))
logger.info(" Instantaneous batch size per device = {:,}".format(self.args.per_device_train_batch_size))
logger.info(f" Num examples = {num_examples:,}")
logger.info(f" Num Epochs = {num_train_epochs:,}")
logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
logger.info(
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
total_train_batch_size
)
)
logger.info(" Gradient Accumulation steps = {:,}".format(self.args.gradient_accumulation_steps))
logger.info(" Num optimization epochs per batch = {:,}".format(self.finetuning_args.ppo_epochs))
logger.info(" Total training steps = {:,}".format(max_steps))
logger.info(" Number of trainable parameters = {:,}".format(count_parameters(self.model)[0]))
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
logger.info(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
logger.info(f" Total training steps = {max_steps:,}")
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]:,}")
dataiter = iter(self.dataloader)
loss_meter = AverageMeter()
@@ -290,7 +290,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if (step + 1) % self.args.save_steps == 0: # save checkpoint
self.save_model(
os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
)
self.callback_handler.on_save(self.args, self.state, self.control)
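Note that format specifications survive the rewrite unchanged: the thousands separator {:,} used in the log lines above, and the {:.2f} spec that appears later in the LoRA+ hunk, simply move inside the braces of the f-string. A small sketch with made-up numbers:

# Made-up values; num_examples and lr_ratio are not taken from the trainer.
num_examples = 52002
lr_ratio = 16.0
assert "  Num examples = {:,}".format(num_examples) == f"  Num examples = {num_examples:,}"
assert "lr ratio {:.2f}.".format(lr_ratio) == f"lr ratio {lr_ratio:.2f}."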

View File

@@ -116,7 +116,7 @@ def create_ref_model(
ref_model = load_model(
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
)
logger.info("Created reference model from {}".format(finetuning_args.ref_model))
logger.info(f"Created reference model from {finetuning_args.ref_model}")
else:
if finetuning_args.finetuning_type == "lora":
ref_model = None
@@ -140,7 +140,7 @@ def create_reward_model(
"""
if finetuning_args.reward_model_type == "api":
assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
logger.info("Use reward server {}".format(finetuning_args.reward_model))
logger.info(f"Use reward server {finetuning_args.reward_model}")
return finetuning_args.reward_model
elif finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
@@ -157,7 +157,7 @@ def create_reward_model(
model.register_buffer(
"default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
)
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
logger.info(f"Loaded adapter weights of reward model from {finetuning_args.reward_model}")
return None
else:
reward_model_args = ModelArguments.copyfrom(
@@ -171,7 +171,7 @@ def create_reward_model(
reward_model = load_model(
tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
)
logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
logger.info(f"Loaded full weights of reward model from {finetuning_args.reward_model}")
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
return reward_model
@@ -231,7 +231,7 @@ def _create_galore_optimizer(
elif training_args.optim == "adafactor":
optim_class = GaLoreAdafactor
else:
raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
raise NotImplementedError(f"Unknown optim: {training_args.optim}")
if finetuning_args.galore_layerwise:
if training_args.gradient_accumulation_steps != 1:
@@ -305,7 +305,7 @@ def _create_loraplus_optimizer(
dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
]
optimizer = optim_class(param_groups, **optim_kwargs)
logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
logger.info(f"Using LoRA+ optimizer with loraplus lr ratio {finetuning_args.loraplus_lr_ratio:.2f}.")
return optimizer

View File

@@ -57,7 +57,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
elif finetuning_args.stage == "kto":
run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
else:
raise ValueError("Unknown task: {}.".format(finetuning_args.stage))
raise ValueError(f"Unknown task: {finetuning_args.stage}.")
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
@@ -91,18 +91,18 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
setattr(model.config, "torch_dtype", output_dtype)
model = model.to(output_dtype)
logger.info("Convert model dtype to: {}.".format(output_dtype))
logger.info(f"Convert model dtype to: {output_dtype}.")
model.save_pretrained(
save_directory=model_args.export_dir,
max_shard_size="{}GB".format(model_args.export_size),
max_shard_size=f"{model_args.export_size}GB",
safe_serialization=(not model_args.export_legacy_format),
)
if model_args.export_hub_model_id is not None:
model.push_to_hub(
model_args.export_hub_model_id,
token=model_args.hf_hub_token,
max_shard_size="{}GB".format(model_args.export_size),
max_shard_size=f"{model_args.export_size}GB",
safe_serialization=(not model_args.export_legacy_format),
)
@@ -117,13 +117,13 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
)
logger.info("Copied valuehead to {}.".format(model_args.export_dir))
logger.info(f"Copied valuehead to {model_args.export_dir}.")
elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
shutil.copy(
os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
)
logger.info("Copied valuehead to {}.".format(model_args.export_dir))
logger.info(f"Copied valuehead to {model_args.export_dir}.")
try:
tokenizer.padding_side = "left" # restore padding side
@@ -140,4 +140,4 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
)
except Exception as e:
logger.warning("Cannot save tokenizer, please copy the files manually: {}.".format(e))
logger.warning(f"Cannot save tokenizer, please copy the files manually: {e}.")

View File

@@ -75,7 +75,7 @@ def load_config() -> Dict[str, Any]:
Loads user config if exists.
"""
try:
with open(get_config_path(), "r", encoding="utf-8") as f:
with open(get_config_path(), encoding="utf-8") as f:
return safe_load(f)
except Exception:
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
@@ -172,14 +172,14 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
Loads dataset_info.json.
"""
if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
logger.info("dataset_dir is {}, using online dataset.".format(dataset_dir))
logger.info(f"dataset_dir is {dataset_dir}, using online dataset.")
return {}
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
return json.load(f)
except Exception as err:
logger.warning("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
logger.warning(f"Cannot open {os.path.join(dataset_dir, DATA_CONFIG)} due to {str(err)}.")
return {}
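The other recurring change in the web UI modules is dropping the explicit "r" argument from open(), since reading text is already the default mode; only the redundant argument goes away. A minimal sketch, assuming a throwaway file created just for the comparison:

# Hypothetical path used only for illustration.
path = "dataset_info.json"
with open(path, "w", encoding="utf-8") as f:
    f.write("{}")

with open(path, "r", encoding="utf-8") as f:  # explicit mode (before)
    a = f.read()
with open(path, encoding="utf-8") as f:       # default mode (after)
    b = f.read()
assert a == b == "{}"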

View File

@@ -41,7 +41,7 @@ def next_page(page_index: int, total_num: int) -> int:
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
except Exception:
return gr.Button(interactive=False)
@@ -57,7 +57,7 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
def _load_data_file(file_path: str) -> List[Any]:
with open(file_path, "r", encoding="utf-8") as f:
with open(file_path, encoding="utf-8") as f:
if file_path.endswith(".json"):
return json.load(f)
elif file_path.endswith(".jsonl"):
@@ -67,7 +67,7 @@ def _load_data_file(file_path: str) -> List[Any]:
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])

View File

@@ -56,9 +56,9 @@ class Engine:
if not self.pure_chat:
current_time = get_time()
init_dict["train.current_time"] = {"value": current_time}
init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)}
init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)}
init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)}
init_dict["train.output_dir"] = {"value": f"train_{current_time}"}
init_dict["train.config_path"] = {"value": f"{current_time}.yaml"}
init_dict["eval.output_dir"] = {"value": f"eval_{current_time}"}
init_dict["infer.mm_box"] = {"visible": False}
if user_config.get("last_model", None):

View File

@@ -29,7 +29,7 @@ class Manager:
Adds elements to manager.
"""
for elem_name, elem in elem_dict.items():
elem_id = "{}.{}".format(tab_name, elem_name)
elem_id = f"{tab_name}.{elem_name}"
self._id_to_elem[elem_id] = elem
self._elem_to_id[elem] = elem_id

View File

@@ -231,7 +231,7 @@ class Runner:
if get("train.ds_stage") != "none":
ds_stage = get("train.ds_stage")
ds_offload = "offload_" if get("train.ds_offload") else ""
args["deepspeed"] = os.path.join(DEFAULT_CACHE_DIR, "ds_z{}_{}config.json".format(ds_stage, ds_offload))
args["deepspeed"] = os.path.join(DEFAULT_CACHE_DIR, f"ds_z{ds_stage}_{ds_offload}config.json")
return args
@@ -313,7 +313,7 @@ class Runner:
if args.get("deepspeed", None) is not None:
env["FORCE_TORCHRUN"] = "1"
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
self.trainer = Popen(f"llamafactory-cli train {save_cmd(args)}", env=env, shell=True)
yield from self.monitor()
def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:

View File

@@ -111,14 +111,14 @@ def gen_cmd(args: Dict[str, Any]) -> str:
"""
cmd_lines = ["llamafactory-cli train "]
for k, v in clean_cmd(args).items():
cmd_lines.append(" --{} {} ".format(k, str(v)))
cmd_lines.append(f" --{k} {str(v)} ")
if os.name == "nt":
cmd_text = "`\n".join(cmd_lines)
else:
cmd_text = "\\\n".join(cmd_lines)
cmd_text = "```bash\n{}\n```".format(cmd_text)
cmd_text = f"```bash\n{cmd_text}\n```"
return cmd_text
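For gen_cmd the behaviour is unchanged: the argument lines are still joined with a PowerShell backtick continuation on Windows and a backslash elsewhere, then wrapped in a bash code block; only the string construction switches to f-strings. A rough, self-contained sketch of what the helper assembles (the arguments below are placeholders, not a real training config):

import os

args = {"stage": "sft", "model_name_or_path": "my-model"}  # placeholder arguments
cmd_lines = ["llamafactory-cli train "]
for k, v in args.items():
    cmd_lines.append(f" --{k} {v} ")
continuation = "`\n" if os.name == "nt" else "\\\n"
cmd_text = f"```bash\n{continuation.join(cmd_lines)}\n```"
print(cmd_text)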
@@ -139,9 +139,9 @@ def get_eval_results(path: os.PathLike) -> str:
r"""
Gets scores after evaluation.
"""
with open(path, "r", encoding="utf-8") as f:
with open(path, encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)
return "```json\n{}\n```\n".format(result)
return f"```json\n{result}\n```\n"
def get_time() -> str:
@@ -161,13 +161,13 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr
running_log_path = os.path.join(output_path, RUNNING_LOG)
if os.path.isfile(running_log_path):
with open(running_log_path, "r", encoding="utf-8") as f:
with open(running_log_path, encoding="utf-8") as f:
running_log = f.read()
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
if os.path.isfile(trainer_log_path):
trainer_log: List[Dict[str, Any]] = []
with open(trainer_log_path, "r", encoding="utf-8") as f:
with open(trainer_log_path, encoding="utf-8") as f:
for line in f:
trainer_log.append(json.loads(line))
@@ -193,7 +193,7 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
Loads saved arguments.
"""
try:
with open(config_path, "r", encoding="utf-8") as f:
with open(config_path, encoding="utf-8") as f:
return safe_load(f)
except Exception:
return None
@@ -211,7 +211,7 @@ def list_config_paths(current_time: str) -> "gr.Dropdown":
r"""
Lists all the saved configuration files.
"""
config_files = ["{}.yaml".format(current_time)]
config_files = [f"{current_time}.yaml"]
if os.path.isdir(DEFAULT_CONFIG_DIR):
for file_name in os.listdir(DEFAULT_CONFIG_DIR):
if file_name.endswith(".yaml") and file_name not in config_files:
@@ -224,7 +224,7 @@ def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_ti
r"""
Lists all the directories that can resume from.
"""
output_dirs = ["train_{}".format(current_time)]
output_dirs = [f"train_{current_time}"]
if model_name:
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):

View File

@@ -61,7 +61,7 @@ OS_NAME = os.environ.get("OS_NAME", "")
],
)
def test_run_exp(stage: str, dataset: str):
output_dir = "train_{}".format(stage)
output_dir = f"train_{stage}"
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
assert os.path.exists(output_dir)