From 99ce085415bc6c2eeec41a7b288fff56c481bae0 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Thu, 13 Jun 2024 00:48:44 +0800 Subject: [PATCH] fix lint Former-commit-id: 713fde4259233af645bade7790211064a07a2a6f --- README.md | 2 +- README_zh.md | 2 +- src/llamafactory/hparams/finetuning_args.py | 17 ++++++++--------- src/llamafactory/hparams/model_args.py | 12 ++++++++---- 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 65964560..994a62c6 100644 --- a/README.md +++ b/README.md @@ -448,7 +448,7 @@ docker run -it --gpus=all \ ```bash docker-compose up -d -docker-compose exec -it llamafactory bash +docker-compose exec llamafactory bash ```
Details about volume diff --git a/README_zh.md b/README_zh.md index 7962a6d1..fa395c6b 100644 --- a/README_zh.md +++ b/README_zh.md @@ -448,7 +448,7 @@ docker run -it --gpus=all \ ```bash docker-compose up -d -docker-compose exec -it llamafactory bash +docker-compose exec llamafactory bash ```
数据卷详情 diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 08af31e4..facbe792 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Literal, Optional +from typing import List, Literal, Optional @dataclass @@ -319,20 +319,19 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA return [item.strip() for item in arg.split(",")] return arg - self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules) - self.freeze_extra_modules = split_arg(self.freeze_extra_modules) - self.lora_alpha = self.lora_alpha or self.lora_rank * 2 - self.lora_target = split_arg(self.lora_target) - self.additional_target = split_arg(self.additional_target) - self.galore_target = split_arg(self.galore_target) + self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules) + self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules) + self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2 + self.lora_target: List[str] = split_arg(self.lora_target) + self.additional_target: Optional[List[str]] = split_arg(self.additional_target) + self.galore_target: List[str] = split_arg(self.galore_target) self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only + self.use_ref_model = self.pref_loss not in ["orpo", "simpo"] assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method." assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." - self.use_ref_model = self.pref_loss not in ["orpo", "simpo"] - if self.stage == "ppo" and self.reward_model is None: raise ValueError("`reward_model` is necessary for PPO training.") diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 71467770..359beafd 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -1,9 +1,13 @@ from dataclasses import asdict, dataclass, field -from typing import Any, Dict, Literal, Optional +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union from typing_extensions import Self +if TYPE_CHECKING: + import torch + + @dataclass class ModelArguments: r""" @@ -194,9 +198,9 @@ class ModelArguments: ) def __post_init__(self): - self.compute_dtype = None - self.device_map = None - self.model_max_length = None + self.compute_dtype: Optional["torch.dtype"] = None + self.device_map: Optional[Union[str, Dict[str, Any]]] = None + self.model_max_length: Optional[int] = None if self.split_special_tokens and self.use_fast_tokenizer: raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")