fix autoset attn impl, update data readme

Former-commit-id: 521ad765521bb65aff5a29a8125a2b26ef00bff4
This commit is contained in:
hiyouga 2024-01-31 11:58:07 +08:00
parent 1aeca5abdf
commit 7beeae2209
3 changed files with 20 additions and 14 deletions

View File

@ -115,7 +115,9 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
},
"tags": {
"role_tag": "from",
"content_tag": "value"
"content_tag": "value",
"user_tag": "human",
"assistant_tag": "gpt"
}
}
```

View File

@ -115,7 +115,9 @@
},
"tags": {
"role_tag": "from",
"content_tag": "value"
"content_tag": "value",
"user_tag": "human",
"assistant_tag": "gpt"
}
}
```

View File

@ -101,6 +101,18 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
return samples
def _configure_attn_implementation(model_args: "ModelArguments", config_kwargs: Dict[str, Any]) -> None:
if model_args.flash_attn:
if is_flash_attn2_available():
config_kwargs["attn_implementation"] = "flash_attention_2"
logger.info("Using FlashAttention-2 for faster training and inference.")
else:
logger.warning("FlashAttention2 is not installed.")
config_kwargs["attn_implementation"] = None
else:
config_kwargs["attn_implementation"] = "eager"
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
@ -128,15 +140,6 @@ def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is
)
def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
if not is_flash_attn2_available():
logger.warning("FlashAttention2 is not installed.")
return
config_kwargs["use_flash_attention_2"] = True
logger.info("Using FlashAttention-2 for faster training and inference.")
def _configure_longlora(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
@ -257,12 +260,11 @@ def patch_config(
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, model_args.compute_dtype == dtype)
_configure_attn_implementation(model_args, config_kwargs)
if model_args.rope_scaling is not None:
_configure_rope(config, model_args, is_trainable)
if model_args.flash_attn:
_configure_flashattn(config_kwargs)
if is_trainable and model_args.shift_attn:
_configure_longlora(config)