Merge branch 'hiyouga:main' into pixtral-patch

Former-commit-id: 93a441a6b7
This commit is contained in:
Kingsley
2024-10-08 21:04:08 +08:00
committed by GitHub
15 changed files with 27 additions and 17 deletions

View File

@@ -82,6 +82,8 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
padding_side="right",
**init_kwargs,
)
except Exception as e:
raise OSError("Failed to load tokenizer.") from e
if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens(
@@ -97,12 +99,13 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, config, tokenizer, model_args)
except Exception:
except Exception as e:
logger.warning("Processor was not found: {}.".format(e))
processor = None
# Avoid load tokenizer, see:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
if "Processor" not in processor.__class__.__name__:
if processor is not None and "Processor" not in processor.__class__.__name__:
processor = None
return {"tokenizer": tokenizer, "processor": processor}

View File

@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None:
require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward

View File

@@ -114,7 +114,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
def _patch_for_block_diag_attn(model_type: str) -> None:
require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
if is_transformers_version_greater_than_4_43():
import transformers.modeling_flash_attention_utils

View File

@@ -110,6 +110,9 @@ def patch_config(
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn
if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())