Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-30 02:30:35 +08:00)

Commit: fixup
@@ -500,13 +500,17 @@ class ErnieVLPlugin(BasePlugin):
             while IMAGE_PLACEHOLDER in content:
                 image_seqlen = image_grid_thw[image_idx].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
-                    IMAGE_PLACEHOLDER, f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>", 1
+                    IMAGE_PLACEHOLDER,
+                    f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>",
+                    1,
                 )
                 image_idx += 1
             while VIDEO_PLACEHOLDER in content:
                 video_seqlen = video_grid_thw[video_idx].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
-                    VIDEO_PLACEHOLDER, f"Video {video_idx + 1}:<|VIDEO_START|>{video_token * video_seqlen}<|VIDEO_END|>", 1
+                    VIDEO_PLACEHOLDER,
+                    f"Video {video_idx + 1}:<|VIDEO_START|>{video_token * video_seqlen}<|VIDEO_END|>",
+                    1,
                 )
                 video_idx += 1
             message["content"] = content
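Note on what the loop above produces: each image placeholder in the message is rewritten to "Picture N:" followed by a run of image tokens bounded by <|IMAGE_START|>/<|IMAGE_END|>, where the run length is the image's patch grid size divided by the merge length. A minimal self-contained sketch of that expansion, using made-up placeholder and token strings rather than the plugin's real constants:

def expand_image_placeholders(content, grid_thw, merge_length, image_token="<img>"):
    """Toy version of the expansion above: replace each "<image>" with an expanded token run."""
    image_idx = 0
    while "<image>" in content:
        t, h, w = grid_thw[image_idx]               # (temporal, height, width) patch grid
        image_seqlen = (t * h * w) // merge_length  # tokens left after patch merging
        content = content.replace(
            "<image>",
            f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>",
            1,
        )
        image_idx += 1
    return content


# One image with a 1x4x4 patch grid and 4-to-1 merging -> 4 image tokens are inserted.
print(expand_image_placeholders("Describe this image: <image>", [(1, 4, 4)], merge_length=4))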
@@ -332,3 +332,7 @@ def fix_proxy(ipv6_enabled: bool = False) -> None:
     if ipv6_enabled:
         os.environ.pop("http_proxy", None)
         os.environ.pop("HTTP_PROXY", None)
+        os.environ.pop("https_proxy", None)
+        os.environ.pop("HTTPS_PROXY", None)
+        os.environ.pop("all_proxy", None)
+        os.environ.pop("ALL_PROXY", None)
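The four added lines extend the existing cleanup from the HTTP variables to the HTTPS and ALL_PROXY variables, in both lowercase and uppercase forms. A standalone sketch of the same idea (a hypothetical helper, not the project's fix_proxy):

import os


def clear_proxy_env() -> None:
    """Drop the common proxy variables so HTTP clients stop routing through a proxy."""
    for name in ("http_proxy", "https_proxy", "all_proxy"):
        os.environ.pop(name, None)          # lowercase form
        os.environ.pop(name.upper(), None)  # uppercase form


os.environ["HTTPS_PROXY"] = "http://127.0.0.1:7890"
clear_proxy_env()
print("HTTPS_PROXY" in os.environ)  # False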
@@ -15,7 +15,6 @@
import os
from typing import TYPE_CHECKING, Any, Optional, TypedDict

import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
@@ -158,6 +157,7 @@ def load_model(
     if model is None and not lazy_load:
         init_kwargs["config"] = config
         init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
+        init_kwargs["torch_dtype"] = "auto"

         if model_args.mixture_of_depths == "load":
             model = load_mod_pretrained_model(**init_kwargs)
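Setting torch_dtype="auto" tells transformers to pick the dtype recorded in the checkpoint's config instead of a caller-chosen dtype. A rough illustration of how a kwargs dict built this way is consumed by from_pretrained (the model id here is a placeholder, not what load_model actually receives):

from transformers import AutoConfig, AutoModelForCausalLM

model_name = "Qwen/Qwen2.5-0.5B"  # placeholder model id
init_kwargs = {
    "pretrained_model_name_or_path": model_name,
    "config": AutoConfig.from_pretrained(model_name),
    "torch_dtype": "auto",        # take the dtype stored in the checkpoint config
    "low_cpu_mem_usage": True,
}

model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
print(next(model.parameters()).dtype)  # e.g. torch.bfloat16, as declared by the checkpoint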
@@ -156,16 +156,13 @@ def patch_config(
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())

-    # do not cast data type of the model deepspeed zero3 without qlora
-    if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None):
-        init_kwargs["torch_dtype"] = "auto"
-    # fsdp/deepspeed zero3 does not need device map
-    if not (is_deepspeed_zero3_enabled() or is_fsdp_enabled()) and init_kwargs["low_cpu_mem_usage"]:
-        if "device_map" not in init_kwargs and model_args.device_map:
-            init_kwargs["device_map"] = model_args.device_map  # device map requires low_cpu_mem_usage=True
-
-    if init_kwargs.get("device_map", None) == "auto":
-        init_kwargs["offload_folder"] = model_args.offload_folder
+    if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled():  # fsdp does not need device map
+        if "device_map" not in init_kwargs and model_args.device_map:
+            init_kwargs["device_map"] = model_args.device_map  # device map requires low_cpu_mem_usage=True
+
+        if init_kwargs.get("device_map", None) == "auto":
+            init_kwargs["offload_folder"] = model_args.offload_folder


 def patch_model(
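Net effect of this hunk: the torch_dtype decision leaves patch_config (it now lives in load_model above), and device_map plus offload_folder are only set when low_cpu_mem_usage is actually in effect and FSDP is not, with DeepSpeed ZeRO-3 already having cleared low_cpu_mem_usage a line earlier. A small sketch of that gating in isolation, with plain booleans standing in for the is_deepspeed_zero3_enabled()/is_fsdp_enabled() checks:

def resolve_device_kwargs(low_cpu_mem_usage, zero3_enabled, fsdp_enabled, device_map, offload_folder="offload"):
    """Toy version of the gating: device_map needs low_cpu_mem_usage and no ZeRO-3/FSDP."""
    kwargs = {"low_cpu_mem_usage": low_cpu_mem_usage and not zero3_enabled}
    if kwargs["low_cpu_mem_usage"] and not fsdp_enabled:
        if device_map:
            kwargs["device_map"] = device_map  # device map requires low_cpu_mem_usage=True
        if kwargs.get("device_map") == "auto":
            kwargs["offload_folder"] = offload_folder  # only "auto" maps may offload to disk
    return kwargs


print(resolve_device_kwargs(True, False, False, "auto"))
# {'low_cpu_mem_usage': True, 'device_map': 'auto', 'offload_folder': 'offload'}
print(resolve_device_kwargs(True, True, False, "auto"))
# {'low_cpu_mem_usage': False}  ZeRO-3 manages placement itself, so no device map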
@@ -84,7 +84,7 @@ def load_reference_model(
        model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map="auto"
        )

        return model

    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
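For orientation, AutoModelForCausalLMWithValueHead comes from trl and wraps a causal LM with an extra scalar value head, which is why the reference loader has two paths. A hedged sketch of those two paths (the flag name and checkpoint path are placeholders, not the project's actual signature):

import torch
from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead


def load_reference(model_path: str, add_valuehead: bool = False):
    """Load either a value-head-wrapped model (e.g. reward-style checkpoints) or a plain causal LM."""
    if add_valuehead:
        return AutoModelForCausalLMWithValueHead.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map="auto"
        )
    return AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")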
@@ -35,35 +35,40 @@ LOCALES = {
         "value": (
             "<h3><center>Visit <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
             "GitHub Page</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
-            "Documentation</a></center></h3>"
+            "Documentation</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+            "Blog</a></center></h3>"
         ),
     },
     "ru": {
         "value": (
             "<h3><center>Посетить <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
             "страницу GitHub</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
-            "Документацию</a></center></h3>"
+            "Документацию</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+            "Блог</a></center></h3>"
         ),
     },
     "zh": {
         "value": (
             "<h3><center>访问 <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
             "GitHub 主页</a> <a href='https://llamafactory.readthedocs.io/zh-cn/latest/' target='_blank'>"
-            "官方文档</a></center></h3>"
+            "官方文档</a> <a href='https://blog.llamafactory.net/' target='_blank'>"
+            "博客</a></center></h3>"
         ),
     },
     "ko": {
         "value": (
             "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
             "GitHub 페이지</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
-            "공식 문서</a>를 방문하세요.</center></h3>"
+            "공식 문서</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+            "블로그</a>를 방문하세요.</center></h3>"
         ),
     },
     "ja": {
         "value": (
             "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
             "GitHub ページ</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
-            "ドキュメント</a>にアクセスする</center></h3>"
+            "ドキュメント</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+            "ブログ</a>にアクセスする</center></h3>"
         ),
     },
 },
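Each entry in LOCALES maps a UI element to per-language Gradio component kwargs, so adding the Blog link means touching every language block of this banner. A toy example of how such a table is typically consumed (hypothetical element id and lookup helper, not the WebUI's actual code):

LOCALES = {
    "banner": {
        "en": {"value": "<h3><center>Visit GitHub / Documentation / Blog</center></h3>"},
        "zh": {"value": "<h3><center>访问 GitHub / 官方文档 / 博客</center></h3>"},
    },
}


def locale_kwargs(elem_id: str, lang: str) -> dict:
    """Return the kwargs for one component in one language, falling back to English."""
    entry = LOCALES[elem_id]
    return entry.get(lang, entry["en"])


print(locale_kwargs("banner", "zh")["value"])  # Chinese banner text
print(locale_kwargs("banner", "ko")["value"])  # no Korean entry in this toy table -> English fallback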