fix ppo trainer save zero3 model

accelerator.get_state_dict(ds_model) should be called at all ranks


Former-commit-id: 4489d73ac7
This commit is contained in:
hiyouga
2024-06-07 05:14:19 +08:00
parent f76d427332
commit 4f3c89a6eb
2 changed files with 22 additions and 10 deletions

View File

@@ -10,12 +10,15 @@ from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_a
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
if is_jieba_available():
import jieba # type: ignore
if is_nltk_available():
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
if is_rouge_available():
from rouge_chinese import Rouge