mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
parent
77f6647e8f
commit
871f7de3d0
@ -73,7 +73,12 @@ def load_model_and_tokenizer(
|
|||||||
if tokenizer.pad_token_id is None: # add pad token
|
if tokenizer.pad_token_id is None: # add pad token
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
|
||||||
|
model_to_load = model_args.checkpoint_dir[0]
|
||||||
|
else:
|
||||||
|
model_to_load = model_args.model_name_or_path
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
|
||||||
is_mergeable = True
|
is_mergeable = True
|
||||||
|
|
||||||
# Quantization configurations (using bitsandbytes library).
|
# Quantization configurations (using bitsandbytes library).
|
||||||
@ -100,11 +105,6 @@ def load_model_and_tokenizer(
|
|||||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
|
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
|
||||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||||
|
|
||||||
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
|
|
||||||
model_to_load = model_args.checkpoint_dir[0]
|
|
||||||
else:
|
|
||||||
model_to_load = model_args.model_name_or_path
|
|
||||||
|
|
||||||
# Load and prepare pretrained models (without valuehead).
|
# Load and prepare pretrained models (without valuehead).
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_to_load,
|
model_to_load,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user