텐서플로우(Tensorflow) get_variable 함수 - tf.get_variable 이란?

2022. 11. 14. 22:29Tensorflow

import tensorflow as tf

self.user_emb_matrix = tf.get_variable(
            shape=[n_user, self.dim], initializer=KGCN.get_initializer(), name='user_emb_matrix')

 

get_variable() 는 새로운 객체를 생성한다.

이때 기존에 같은 변수 이름을 사용하는 객체가 존재한다면, 이어서 그 값을 받아올 수 도 있다!

 

tf.get_variable('var2', None, initializer=value)

 var2 라는 이름을 가진 객체를 생성하기 전에, 해당 이름을 사용하는 객체가 존재하는지 탐색한다.

이때 기존에 같은 변수 이름을 사용하는 객체가 존재하고, 그 값을 받아오고 싶다면 reuse=tf.AUTO_REUSE 파라미터를 추가한다.

 

tf.variable_scope('var_scope', reuse=tf.AUTO_REUSE)

이렇게하면 기존 매개변수 값을 그대로 받아와서 사용할 수 있다~