Merge branch 'pixtral-patch' of https://github.com/Kuangdd01/LLaMA-Factory-X into pixtral-patch

Former-commit-id: 10c58488558549c382f9bba43c487d7f9222f16e
This commit is contained in:
KUANGDD 2024-10-23 15:32:50 +08:00
commit 22dbe694e9
2 changed files with 5 additions and 4 deletions

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from enum import Enum from enum import Enum
from typing import Dict, Optional from typing import Dict, Optional
@ -47,7 +48,7 @@ FILEEXT2TYPE = {
IGNORE_INDEX = -100 IGNORE_INDEX = -100
IMAGE_PLACEHOLDER = "<image>" IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")
LAYERNORM_NAMES = {"norm", "ln"} LAYERNORM_NAMES = {"norm", "ln"}
@ -95,7 +96,7 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
VIDEO_PLACEHOLDER = "<video>" VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
V_HEAD_WEIGHTS_NAME = "value_head.bin" V_HEAD_WEIGHTS_NAME = "value_head.bin"

View File

@ -37,9 +37,9 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
assert set(state_dict_a.keys()) == set(state_dict_b.keys()) assert set(state_dict_a.keys()) == set(state_dict_b.keys())
for name in state_dict_a.keys(): for name in state_dict_a.keys():
if any(key in name for key in diff_keys): if any(key in name for key in diff_keys):
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is False assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-2, atol=1e-3) is False
else: else:
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is True assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-2, atol=1e-3) is True
def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]: def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]: