[misc] fix ci with uv (#9676)

This commit is contained in:
Yaowei Zheng
2025-12-27 01:39:13 +08:00
committed by GitHub
parent a1b1931b4a
commit 55590f5ece
22 changed files with 118 additions and 121 deletions

View File

@@ -26,6 +26,7 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
import numpy as np
import torch
import torchaudio
from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
from transformers.models.mllama.processing_mllama import (
convert_sparse_cross_attention_mask_to_dense,
@@ -34,16 +35,7 @@ from transformers.models.mllama.processing_mllama import (
from typing_extensions import NotRequired, override
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import (
is_librosa_available,
is_pillow_available,
is_pyav_available,
is_transformers_version_greater_than,
)
if is_librosa_available():
import librosa
from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
if is_pillow_available():
@@ -316,7 +308,14 @@ class MMPluginMixin:
results, sampling_rates = [], []
for audio in audios:
if not isinstance(audio, np.ndarray):
audio, sampling_rate = librosa.load(audio, sr=sampling_rate)
audio, sr = torchaudio.load(audio)
if audio.shape[0] > 1:
audio = audio.mean(dim=0, keepdim=True)
if sr != sampling_rate:
audio = torchaudio.functional.resample(audio, sr, sampling_rate)
audio = audio.squeeze(0).numpy()
results.append(audio)
sampling_rates.append(sampling_rate)
@@ -500,13 +499,17 @@ class ErnieVLPlugin(BasePlugin):
while IMAGE_PLACEHOLDER in content:
image_seqlen = image_grid_thw[image_idx].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER, f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>", 1
IMAGE_PLACEHOLDER,
f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>",
1,
)
image_idx += 1
while VIDEO_PLACEHOLDER in content:
video_seqlen = video_grid_thw[video_idx].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
VIDEO_PLACEHOLDER, f"Video {video_idx + 1}:<|VIDEO_START|>{video_token * video_seqlen}<|VIDEO_END|>", 1
VIDEO_PLACEHOLDER,
f"Video {video_idx + 1}:<|VIDEO_START|>{video_token * video_seqlen}<|VIDEO_END|>",
1,
)
video_idx += 1
message["content"] = content

View File

@@ -332,3 +332,7 @@ def fix_proxy(ipv6_enabled: bool = False) -> None:
if ipv6_enabled:
os.environ.pop("http_proxy", None)
os.environ.pop("HTTP_PROXY", None)
os.environ.pop("https_proxy", None)
os.environ.pop("HTTPS_PROXY", None)
os.environ.pop("all_proxy", None)
os.environ.pop("ALL_PROXY", None)

View File

@@ -15,7 +15,6 @@
import os
from typing import TYPE_CHECKING, Any, Optional, TypedDict
import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -158,6 +157,7 @@ def load_model(
if model is None and not lazy_load:
init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
init_kwargs["torch_dtype"] = "auto"
if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs)

View File

@@ -156,16 +156,13 @@ def patch_config(
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
# do not cast data type of the model deepspeed zero3 without qlora
if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None):
init_kwargs["torch_dtype"] = "auto"
# fsdp/deepspeed zero3 does not need device map
if not (is_deepspeed_zero3_enabled() or is_fsdp_enabled()) and init_kwargs["low_cpu_mem_usage"]:
if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map # device map requires low_cpu_mem_usage=True
if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled(): # fsdp does not need device map
if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map # device map requires low_cpu_mem_usage=True
if init_kwargs.get("device_map", None) == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder
if init_kwargs.get("device_map", None) == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder
def patch_model(

View File

@@ -84,7 +84,7 @@ def load_reference_model(
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
model_path, torch_dtype=torch.float16, device_map="auto"
)
return model
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")

View File

@@ -35,35 +35,40 @@ LOCALES = {
"value": (
"<h3><center>Visit <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
"GitHub Page</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
"Documentation</a></center></h3>"
"Documentation</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
"Blog</a></center></h3>"
),
},
"ru": {
"value": (
"<h3><center>Посетить <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
"страницу GitHub</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
"Документацию</a></center></h3>"
"Документацию</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
"Блог</a></center></h3>"
),
},
"zh": {
"value": (
"<h3><center>访问 <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
"GitHub 主页</a> <a href='https://llamafactory.readthedocs.io/zh-cn/latest/' target='_blank'>"
"官方文档</a></center></h3>"
"官方文档</a> <a href='https://blog.llamafactory.net/' target='_blank'>"
"博客</a></center></h3>"
),
},
"ko": {
"value": (
"<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
"GitHub 페이지</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
"공식 문서</a>를 방문하세요.</center></h3>"
"공식 문서</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
"블로그</a>를 방문하세요.</center></h3>"
),
},
"ja": {
"value": (
"<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
"GitHub ページ</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
"ドキュメント</a>にアクセスする</center></h3>"
"ドキュメント</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
"ブログ</a>にアクセスする</center></h3>"
),
},
},