텐서플로우(Tensorflow) - tf.gather() 란?

2022. 11. 15. 13:17Tensorflow

tf.gather(params, indices, validate_indices=None, name=None, axis=None, batch_dims=0)

전달인자 params(=tensor)에서 axis 축 기준으로 indices 값들을 뽑아온다.

 

예시:

v1 = tf.constant([1, 3, 5, 7, 9, 0, 2, 4, 6, 8])

v2 = tf.constant([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])

 

with tf.Session() as sess:

    print(sess.run(tf.gather(v1, [2, 5, 2, 5], axis=0)))

    print(sess.run(tf.gather(v2, [0, 1], axis=0)))

    print(sess.run(tf.gather(v2, [0, 1], axis=1)))

 

output:

[5 0 5 0]

[[ 1 2 3 4 5 6] [ 7 8 9 10 11 12]]

[[1 2] [7 8]]

 

v1 텐서에서 행을 기준으로 2,5,2,5 index 값들을 뽑아온다.

v2 텐서에서 행을 기준으로 0,1 index 값들을 뽑아온다.

- 이때 다른점은 v1은 1차원 텐서이고, v2는 2차원 텐서이다. 

- 따라서 v2 에서 행 기준 0,1 index는 [0][i] 와 [1][i]를 의미하므로 첫번째행, 두번째행 전체가 출력된다.

 

v2 텐서에서 열을 기준으로 0,1 index 값들을 뽑아온다.

- 이는 [i][0] 와 [i][1]을 의미하므로 첫번째 행의 열 기준 0,1 index, 두번째 행의 열 기준 0,1 index가 출력된다.