fix int8 inference

This commit is contained in:
hiyouga
2023-06-03 23:22:05 +08:00
parent 926291940d
commit 1bd13d7ca1
3 changed files with 3 additions and 24 deletions

View File

@@ -17,15 +17,6 @@ def main():
model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
model, tokenizer = load_pretrained(model_args, finetuning_args)
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model, infer_auto_device_map
device_map = infer_auto_device_map(model)
model = dispatch_model(model, device_map)
else:
model = model.cuda()
model.eval()
def format_example_alpaca(query, history):
prompt = "Below is an instruction that describes a task. "
prompt += "Write a response that appropriately completes the request.\n"