mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-29 02:00:36 +08:00
refactor mllm param logic
Former-commit-id: b895c190945cf5d991cb4e4dea2ae73cc9c8d246
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
|
||||
|
||||
import torch
|
||||
@@ -202,12 +203,8 @@ def load_model(
|
||||
|
||||
logger.info_rank0(param_stats)
|
||||
|
||||
if model_args.print_param_status:
|
||||
if model_args.print_param_status and int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
for name, param in model.named_parameters():
|
||||
print(
|
||||
"name: {}, dtype: {}, device: {}, trainable: {}".format(
|
||||
name, param.dtype, param.device, param.requires_grad
|
||||
)
|
||||
)
|
||||
print(f"name: {name}, dtype: {param.dtype}, device: {param.device}, trainable: {param.requires_grad}")
|
||||
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user