Merge branch 'hiyouga:main' into pixtral-patch

Former-commit-id: 0d3106e9fad565fbe56b8de57dd6ea373944eb99
This commit is contained in:
Kingsley 2024-10-23 15:30:03 +08:00 committed by GitHub
commit a7a5a5671f
4 changed files with 5 additions and 4 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 199 KiB

After

Width:  |  Height:  |  Size: 166 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 168 KiB

After

Width:  |  Height:  |  Size: 165 KiB

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from enum import Enum from enum import Enum
from typing import Dict, Optional from typing import Dict, Optional
@ -47,7 +48,7 @@ FILEEXT2TYPE = {
IGNORE_INDEX = -100 IGNORE_INDEX = -100
IMAGE_PLACEHOLDER = "<image>" IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")
LAYERNORM_NAMES = {"norm", "ln"} LAYERNORM_NAMES = {"norm", "ln"}
@ -95,7 +96,7 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
VIDEO_PLACEHOLDER = "<video>" VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
V_HEAD_WEIGHTS_NAME = "value_head.bin" V_HEAD_WEIGHTS_NAME = "value_head.bin"

View File

@ -37,9 +37,9 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
assert set(state_dict_a.keys()) == set(state_dict_b.keys()) assert set(state_dict_a.keys()) == set(state_dict_b.keys())
for name in state_dict_a.keys(): for name in state_dict_a.keys():
if any(key in name for key in diff_keys): if any(key in name for key in diff_keys):
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is False assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-2, atol=1e-3) is False
else: else:
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is True assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-2, atol=1e-3) is True
def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]: def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]: