support BLOOM models

This commit is contained in:
hiyouga
2023-05-31 16:54:06 +08:00
parent a72492e649
commit 740a5daf56
16 changed files with 134 additions and 90 deletions

View File

@@ -1,10 +1,10 @@
# coding=utf-8
# Implements stream chat in command line for LLaMA fine-tuned with PEFT.
# Implements stream chat in command line for fine-tuned models.
# Usage: python cli_demo.py --checkpoint_dir path_to_checkpoint
import torch
from utils import ModelArguments, auto_configure_device_map, load_pretrained
from utils import ModelArguments, load_pretrained
from transformers import HfArgumentParser
@@ -12,10 +12,11 @@ def main():
parser = HfArgumentParser(ModelArguments)
model_args, = parser.parse_args_into_dataclasses()
model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
model, tokenizer = load_pretrained(model_args)
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
device_map = auto_configure_device_map(torch.cuda.device_count())
from accelerate import dispatch_model, infer_auto_device_map
device_map = infer_auto_device_map(model)
model = dispatch_model(model, device_map)
else:
model = model.cuda()
@@ -47,7 +48,7 @@ def main():
return response, history
history = []
print("欢迎使用 LLaMA-7B 模型输入内容即可对话clear清空对话历史stop终止程序")
print("欢迎使用 {} 模型输入内容即可对话clear清空对话历史stop终止程序".format(model_name))
while True:
try:
query = input("\nInput: ")
@@ -65,7 +66,7 @@ def main():
continue
response, history = predict(query, history)
print("LLaMA-7B:", response)
print("{}:".format(model_name), response)
if __name__ == "__main__":