2022. 11. 15. 13:17ㆍTensorflow
tf.gather(params, indices, validate_indices=None, name=None, axis=None, batch_dims=0)
전달인자 params(=tensor)에서 axis 축 기준으로 indices 값들을 뽑아온다.
예시:
v1 = tf.constant([1, 3, 5, 7, 9, 0, 2, 4, 6, 8])
v2 = tf.constant([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])
with tf.Session() as sess:
print(sess.run(tf.gather(v1, [2, 5, 2, 5], axis=0)))
print(sess.run(tf.gather(v2, [0, 1], axis=0)))
print(sess.run(tf.gather(v2, [0, 1], axis=1)))
output:
[5 0 5 0]
[[ 1 2 3 4 5 6] [ 7 8 9 10 11 12]]
[[1 2] [7 8]]
v1 텐서에서 행을 기준으로 2,5,2,5 index 값들을 뽑아온다.
v2 텐서에서 행을 기준으로 0,1 index 값들을 뽑아온다.
- 이때 다른점은 v1은 1차원 텐서이고, v2는 2차원 텐서이다.
- 따라서 v2 에서 행 기준 0,1 index는 [0][i] 와 [1][i]를 의미하므로 첫번째행, 두번째행 전체가 출력된다.
v2 텐서에서 열을 기준으로 0,1 index 값들을 뽑아온다.
- 이는 [i][0] 와 [i][1]을 의미하므로 첫번째 행의 열 기준 0,1 index, 두번째 행의 열 기준 0,1 index가 출력된다.
'Tensorflow' 카테고리의 다른 글
텐서플로우(Tensorflow) - tensorflow==1.x 버전 for MacOS - M1 (0) | 2022.11.16 |
---|---|
텐서플로우(Tensorflow) - tf.reduce_sum() 이란? (0) | 2022.11.15 |
텐서플로우(Tensorflow) - tf.expand_dims() 란? (0) | 2022.11.15 |
텐서플로우(Tensorflow) tf.nn.embedding_lookup() 이란? (0) | 2022.11.14 |
텐서플로우(Tensorflow) get_variable 함수 - tf.get_variable 이란? (0) | 2022.11.14 |