Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Synced 2025-08-03 04:02:49 +08:00
Merge branch 'hiyouga:main' into pixtral-patch

Former-commit-id: 0d3106e9fad565fbe56b8de57dd6ea373944eb99
Commit a7a5a5671f
[Binary image file changed, not shown: 199 KiB → 166 KiB]
[Binary image file changed, not shown: 168 KiB → 165 KiB]
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 from collections import OrderedDict, defaultdict
 from enum import Enum
 from typing import Dict, Optional
@@ -47,7 +48,7 @@ FILEEXT2TYPE = {

 IGNORE_INDEX = -100

-IMAGE_PLACEHOLDER = "<image>"
+IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")

 LAYERNORM_NAMES = {"norm", "ln"}

@@ -95,7 +96,7 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {

 SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

-VIDEO_PLACEHOLDER = "<video>"
+VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")

 V_HEAD_WEIGHTS_NAME = "value_head.bin"

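Both placeholder constants now honor an environment override, keeping the old literals as defaults. Below is a minimal sketch of the pattern in isolation (the `[IMG]` value and the `train.py` entry point are hypothetical, for illustration only). Because the lookup runs at module import time, the variable must be exported before the module is first imported:

import os

# Same pattern as the diff: the environment variable wins when it is set,
# otherwise the previous hard-coded literal is used as the fallback.
IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")
VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")

print(IMAGE_PLACEHOLDER)  # "<image>" unless IMAGE_PLACEHOLDER was exported

For example, IMAGE_PLACEHOLDER='[IMG]' python train.py would swap the image token for a model whose template expects a different tag, without editing the source. The commit's other change relaxes the comparison tolerances in the test helper compare_model: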
@@ -37,9 +37,9 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
     assert set(state_dict_a.keys()) == set(state_dict_b.keys())
     for name in state_dict_a.keys():
         if any(key in name for key in diff_keys):
-            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is False
+            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-2, atol=1e-3) is False
         else:
-            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is True
+            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-2, atol=1e-3) is True


 def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:
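The new bounds are ten times looser in both the relative and absolute term. torch.allclose(a, b, rtol=..., atol=...) passes when |a - b| <= atol + rtol * |b| holds elementwise, so tensor pairs that differ by roughly half a percent now compare equal where they previously did not. A minimal, self-contained sketch (tensor values chosen for illustration):

import torch

# torch.allclose checks elementwise: |a - b| <= atol + rtol * |b|.
a = torch.tensor([1.000, 2.000])
b = torch.tensor([1.005, 2.010])  # ~0.5% relative difference

print(torch.allclose(a, b, rtol=1e-3, atol=1e-4))  # False: old, tighter bound
print(torch.allclose(a, b, rtol=1e-2, atol=1e-3))  # True: new, looser bound

In the test this cuts both ways: weights whose names match diff_keys must now differ by more than the wider band to be flagged as changed, while all other weights are allowed up to that band of numerical noise before the equality assertion fails.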