딥러닝
[TensorFlow] Checking weights in a checkpoint
승해tmdhey
2023. 1. 13. 14:18
반응형
import sys

import numpy as np
import tensorflow as tf
from tensorflow.python.training import py_checkpoint_reader
# Dump every variable stored in a TensorFlow checkpoint (name, shape, dtype
# and full value) to a text file so the weights can be inspected by eye.

# Checkpoint path prefix, built as "<dir>/cp-<step zero-padded to 4>.ckpt".
# NOTE(review): checkpoint_dir looks like a placeholder that was stripped for
# the post -- point it at the real training directory before running.
checkpoint_dir = ''
checkpoint_step = 123
checkpoint_path = f'{checkpoint_dir}/cp-{checkpoint_step:0>4}.ckpt'

reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
# Per-variable metadata keyed by variable name.
dtype_map = reader.get_variable_to_dtype_map()
shape_map = reader.get_variable_to_shape_map()
# Variable name -> stored value, for every variable the checkpoint lists.
state_dict = {v: reader.get_tensor(v) for v in shape_map}

# One section per variable, in sorted name order so repeated dumps diff
# cleanly. numpy elides the middle of large arrays with "..." by default;
# raising the threshold to sys.maxsize makes every weight value appear.
with open('weight_check.txt', 'w', encoding='utf-8') as file:
    with np.printoptions(threshold=sys.maxsize):
        for name in sorted(state_dict):
            file.write(
                f'{name}\tshape={shape_map[name]}\t'
                f'dtype={dtype_map[name].name}\n'
            )
            file.write(f'{np.asarray(state_dict[name])}\n\n')