텐서플로우(Tensorflow) - tf.expand_dims() 란?

2022. 11. 15. 13:03Tensorflow

import tensorflow as tf
seeds = tf.expand_dims(seeds, axis=1)
expand_dims 영어 단어가 의미하듯이 배열의 차원을 늘려주는 것 이다. 
전달인자로는 배열, axis 가 있는데 지정한 axis에 차원을 하나 삽입하는 것 이다.
 
예를 들면 seeds 라는 배열이 (2,3)의 shape을 가지고 있다고 하자.

 

axis0 = tf.expand_dims(seeds, axis=0) -> shape : (1,2,3)
axis1 = tf.expand_dims(seeds, axis=1) -> shape : (2,1,3)
axis2 = tf.expand_dims(seeds, axis=2) -> shape : (2,3,1)
 
다음과 같이 2차원에서 3차원으로 확장되었으며 해당 차원의 원소 개수는 1로 확장된다.