diff --git a/tests/llamafy_baichuan2.py b/tests/llamafy_baichuan2.py
new file mode 100644
index 00000000..4b4f04ce
--- /dev/null
+++ b/tests/llamafy_baichuan2.py
@@ -0,0 +1,68 @@
+# coding=utf-8
+# Converts the Baichuan2-7B model into the same format as LLaMA2-7B.
+# Usage: python llamafy_baichuan2.py --baichuan2_json baichuan2.index.json --llama2_json llama2.index.json
+#        --input_dir baichuan2_original --output_dir baichuan2_llamafied
+# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
+# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
+
+import os
+import fire
+import json
+import torch
+from collections import OrderedDict
+
+
+SHARD_A = "pytorch_model-00001-of-00002.bin"
+SHARD_B = "pytorch_model-00002-of-00002.bin"
+
+
+def llamafy_baichuan2(
+    baichuan2_json: str,
+    llama2_json: str,
+    input_dir: str,
+    output_dir: str
+):
+    weight_shard_a = torch.load(os.path.join(input_dir, SHARD_A), map_location="cpu")
+    weight_shard_b = torch.load(os.path.join(input_dir, SHARD_B), map_location="cpu")
+
+    baichuan2_state_dict = OrderedDict()
+    baichuan2_state_dict.update(weight_shard_a)
+    baichuan2_state_dict.update(weight_shard_b)
+
+    llama2_state_dict = OrderedDict()
+    for key, value in baichuan2_state_dict.items():
+        if "W_pack" in key:
+            llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:4096, :]
+            llama2_state_dict[key.replace("W_pack", "k_proj")] = value[4096:2*4096, :]
+            llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2*4096:, :]
+        elif "lm_head" in key:
+            llama2_state_dict[key] = torch.nn.functional.normalize(value)
+        else:
+            llama2_state_dict[key] = value
+
+    with open(os.path.join(input_dir, baichuan2_json), "r", encoding="utf-8") as f:
+        baichuan2_index = json.load(f)
+    with open(os.path.join(input_dir, llama2_json), "r", encoding="utf-8") as f:
+        llama2_index = json.load(f)
+
+    merged_index = OrderedDict()
+    merged_index["metadata"] = baichuan2_index["metadata"]
+    merged_index["weight_map"] = llama2_index["weight_map"]
+
+    state_dict_a, state_dict_b = OrderedDict(), OrderedDict()
+    for key, value in llama2_state_dict.items():
+        if merged_index["weight_map"][key] == SHARD_A:
+            state_dict_a[key] = value
+        else:
+            state_dict_b[key] = value
+
+    os.makedirs(output_dir, exist_ok=True)
+    torch.save(state_dict_a, os.path.join(output_dir, SHARD_A))
+    torch.save(state_dict_b, os.path.join(output_dir, SHARD_B))
+    with open(os.path.join(output_dir, "pytorch_model.bin.index.json"), "w", encoding="utf-8") as f:
+        json.dump(merged_index, f)
+    print("Completed!")
+
+
+if __name__ == "__main__":
+    fire.Fire(llamafy_baichuan2)
diff --git a/tests/modeling_baichuan.py b/tests/modeling_baichuan.py
index 2a2d4357..326a9c58 100644
--- a/tests/modeling_baichuan.py
+++ b/tests/modeling_baichuan.py
@@ -1,4 +1,6 @@
 # Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
+# Modified by hiyouga to support attention mask; the alibi implementation is largely borrowed from
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py
 
 import math
 from typing import List, Optional, Tuple, Union
@@ -12,7 +14,6 @@
 from transformers import PreTrainedModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.utils import logging
-from transformers.generation.utils import GenerationConfig
 
 from .configuration_baichuan import BaichuanConfig
@@ -128,7 +129,7 @@ class MLP(nn.Module):
 
 
 class BaichuanAttention(nn.Module):
-    def __init__(self, config: BaichuanConfig):
+    def __init__(self, config: "BaichuanConfig"):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -223,7+224,7 @@ class BaichuanAttention(nn.Module):
 
 
 class BaichuanLayer(nn.Module):
-    def __init__(self, config: BaichuanConfig):
+    def __init__(self, config: "BaichuanConfig"):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = BaichuanAttention(config=config)
@@ -342,7 +343,7 @@ class BaichuanPreTrainedModel(PreTrainedModel):
 
 
 class BaichuanModel(BaichuanPreTrainedModel):
-    def __init__(self, config: BaichuanConfig):
+    def __init__(self, config: "BaichuanConfig"):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -651,93 +652,3 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             for layer_past in standardized_past
         )
         return self._convert_to_baichuan_cache(reordered_past)
-
-    def quantize(self, bits: int):
-        try:
-            from .quantizer import QLinear
-        except ImportError:
-            raise ImportError(
-                f"Needs QLinear to run quantize."
-            )
-
-        for layer in self.model.layers:
-            layer.self_attn.W_pack = QLinear(
-                bits=bits,
-                weight=layer.self_attn.W_pack.weight,
-                bias = None,
-            )
-            layer.self_attn.o_proj = QLinear(
-                bits=bits,
-                weight=layer.self_attn.o_proj.weight,
-                bias = None,
-            )
-            layer.mlp.gate_proj = QLinear(
-                bits=bits,
-                weight=layer.mlp.gate_proj.weight,
-                bias = None,
-            )
-            layer.mlp.down_proj = QLinear(
-                bits=bits,
-                weight=layer.mlp.down_proj.weight,
-                bias = None,
-            )
-            layer.mlp.up_proj = QLinear(
-                bits=bits,
-                weight=layer.mlp.up_proj.weight,
-                bias = None,
-            )
-        return self
-
-    def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
-        max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
-        max_input_tokens = self.config.model_max_length - max_new_tokens
-        max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens)
-        total_input, round_input = [], []
-        for i, message in enumerate(messages[::-1]):
-            content_tokens = tokenizer.encode(message['content'])
-            if message['role'] == 'user':
-                round_input = [self.generation_config.user_token_id] + content_tokens + round_input
-                if total_input and len(total_input) + len(round_input) > max_input_tokens:
-                    break
-                else:
-                    total_input = round_input + total_input
-                    if len(total_input) >= max_input_tokens:
-                        break
-                    else:
-                        round_input = []
-            elif message['role'] == 'assistant':
-                round_input = [
-                    self.generation_config.assistant_token_id
-                ] + content_tokens + [
-                    self.generation_config.eos_token_id
-                ] + round_input
-            else:
-                raise ValueError(f"message role not supported yet: {message['role']}")
-        total_input = total_input[-max_input_tokens:]  # truncate left
-        total_input.append(self.generation_config.assistant_token_id)
-        total_input = torch.LongTensor([total_input]).to(self.device)
-        return total_input
-
-    @torch.no_grad()
-    def chat(self, tokenizer, messages: List[dict], stream=False,
-             generation_config: Optional[GenerationConfig]=None):
-        generation_config = generation_config or self.generation_config
-        input_ids = self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens)
-        if stream:
-            from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
-            self.__class__.generate = NewGenerationMixin.generate
-            self.__class__.sample_stream = NewGenerationMixin.sample_stream
-            stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
-
-            def stream_generator():
-                outputs = []
-                for token in self.generate(input_ids, generation_config=stream_config):
-                    outputs.append(token.item())
-                    yield tokenizer.decode(outputs, skip_special_tokens=True)
-
-            return stream_generator()
-        else:
-            self.__class__.generate = PreTrainedModel.generate  # disable stream
-            outputs = self.generate(input_ids, generation_config=generation_config)
-            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
-            return response
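A quick way to sanity-check the conversion performed by llamafy_baichuan2.py is to compare slices of the packed W_pack matrix against the exported q_proj/k_proj/v_proj weights. The sketch below is illustrative and not part of the patch: the shard file name matches the SHARD_A constant above, but the directory names and the assumption that layer 0 lands in the first shard of both checkpoints are hypothetical.

```python
# Hypothetical sanity check for the W_pack -> q/k/v split (not part of the patch).
# Assumes layer 0 weights live in the first shard of both checkpoints; adjust paths as needed.
import os
import torch

input_dir = "baichuan2_original"    # original Baichuan2-7B shards (assumed path)
output_dir = "baichuan2_llamafied"  # output of llamafy_baichuan2.py (assumed path)
shard = "pytorch_model-00001-of-00002.bin"

original = torch.load(os.path.join(input_dir, shard), map_location="cpu")
converted = torch.load(os.path.join(output_dir, shard), map_location="cpu")

w_pack = original["model.layers.0.self_attn.W_pack.weight"]  # packed (3 * 4096, 4096) matrix
hidden_size = 4096
for i, proj in enumerate(("q_proj", "k_proj", "v_proj")):
    expected = w_pack[i * hidden_size:(i + 1) * hidden_size, :]
    actual = converted["model.layers.0.self_attn.{}.weight".format(proj)]
    assert torch.equal(actual, expected), "mismatch in {}".format(proj)
print("W_pack split matches the q/k/v projections for layer 0.")
```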