add Baichuan2 models

This commit is contained in:
hiyouga
2023-09-06 18:36:04 +08:00
parent 9224db90ea
commit 62ce65c628
3 changed files with 11 additions and 2 deletions

View File

@@ -1,3 +1,4 @@
import gc
import torch
from typing import TYPE_CHECKING, List, Optional, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
@@ -98,6 +99,7 @@ def torch_gc() -> None:
r"""
Collects GPU memory.
"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()