mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
fix packages
Former-commit-id: 8e04794b2da067a4123b9d7091a54c5647f44244
This commit is contained in:
parent
a5537f3ee8
commit
3d483e0914
@ -2,6 +2,7 @@ from dataclasses import dataclass
|
|||||||
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
|
from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
|
||||||
@ -32,6 +33,10 @@ class ComputeMetrics:
|
|||||||
r"""
|
r"""
|
||||||
Uses the model predictions to compute metrics.
|
Uses the model predictions to compute metrics.
|
||||||
"""
|
"""
|
||||||
|
require_version("jieba", "To fix: pip install jieba")
|
||||||
|
require_version("nltk", "To fix: pip install nltk")
|
||||||
|
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
|
||||||
|
|
||||||
preds, labels = eval_preds
|
preds, labels = eval_preds
|
||||||
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||||
|
|
||||||
|
@ -160,7 +160,7 @@ def _create_galore_optimizer(
|
|||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
) -> "torch.optim.Optimizer":
|
) -> "torch.optim.Optimizer":
|
||||||
require_version("galore_torch", "To fix: pip install git+https://github.com/hiyouga/GaLore.git")
|
require_version("galore_torch", "To fix: pip install galore-torch")
|
||||||
|
|
||||||
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
|
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
|
||||||
galore_targets = find_all_linear_modules(model)
|
galore_targets = find_all_linear_modules(model)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user