# Blog: 활연개랑
#
# [tensorflow] Checking weights in a checkpoint
#
# Category: 딥러닝 (deep learning)
#
# Author: 승해tmdhey — 2023-01-13 14:18
import sys

import numpy as np
import tensorflow as tf
from tensorflow.python.training import py_checkpoint_reader

checkpoint_path = f'/cp-{123:0>4}.ckpt'
reader     = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
dtype_map  = reader.get_variable_to_dtype_map()
shape_map  = reader.get_variable_to_shape_map()
state_dict = { v: reader.get_tensor(v) for v in shape_map}


with open('weight_check.txt', 'w') as file:
    file.write(str(state_dict))