refactor mllm param logic

Former-commit-id: b895c190945cf5d991cb4e4dea2ae73cc9c8d246
hiyouga
2025-01-10 15:41:54 +00:00
parent 1675712a4c
commit dc65ecdf09
10 changed files with 198 additions and 62 deletions


@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
 
 import torch
@@ -202,12 +203,8 @@ def load_model(
     logger.info_rank0(param_stats)
 
-    if model_args.print_param_status:
+    if model_args.print_param_status and int(os.getenv("LOCAL_RANK", "0")) == 0:
         for name, param in model.named_parameters():
-            print(
-                "name: {}, dtype: {}, device: {}, trainable: {}".format(
-                    name, param.dtype, param.device, param.requires_grad
-                )
-            )
+            print(f"name: {name}, dtype: {param.dtype}, device: {param.device}, trainable: {param.requires_grad}")
 
     return model
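
The second hunk gates the per-parameter dump on `LOCAL_RANK`, so multi-GPU launches (e.g. via `torchrun`, which sets `LOCAL_RANK` for each process) print the table only once instead of once per rank; the `import os` added in the first hunk supports that check. A minimal standalone sketch of the same pattern follows; the `dump_param_status` helper and the `torch.nn.Linear` model are illustrative, not part of the repository:

```python
import os

import torch


def dump_param_status(model: torch.nn.Module) -> None:
    # torchrun / accelerate set LOCAL_RANK per process; default to "0" so the dump
    # still appears in single-process runs where no launcher sets the variable.
    if int(os.getenv("LOCAL_RANK", "0")) != 0:
        return  # non-zero ranks skip printing to avoid duplicated output

    for name, param in model.named_parameters():
        print(f"name: {name}, dtype: {param.dtype}, device: {param.device}, trainable: {param.requires_grad}")


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)  # any nn.Module works
    dump_param_status(model)
```

Launched with `torchrun --nproc_per_node=2`, only the process with `LOCAL_RANK=0` prints; a plain `python` run prints as before because the variable defaults to "0".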