Use TensorFlow reshape to change the shape of a TensorFlow Tensor as long as the number of elements stays the same
We start by importing TensorFlow as tf.
import tensorflow as tf
Then we print the version of TensorFlow that we are using.
print(tf.__version__)
We are using TensorFlow 1.5.0.
In this video, we're going to use tf.reshape to change the shape of a TensorFlow tensor as long as the number of elements stays the same.
We will do three examples to show how reshape works.
Let's start out with an initial TensorFlow constant tensor of shape 2x3x4 holding the integers 1 through 24, all of which have the data type int32.
tf_initial_tensor_constant = tf.constant(
    [
        [
            [ 1,  2,  3,  4],
            [ 5,  6,  7,  8],
            [ 9, 10, 11, 12]
        ],
        [
            [13, 14, 15, 16],
            [17, 18, 19, 20],
            [21, 22, 23, 24]
        ]
    ],
    dtype="int32"
)
So we use tf.constant to build our 2x3x4 tensor with the data type int32, filled with the numbers 1, 2, 3, 4, all the way through 24, and we assign it to the Python variable tf_initial_tensor_constant.
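To double-check the 2x3x4 shape and the 24-element count without a TensorFlow session, here is a minimal pure-Python sketch. `nested_shape` is a hypothetical helper written for this illustration, not part of TensorFlow, and it assumes every sublist at a given depth has the same length.

```python
import math

def nested_shape(data):
    """Return the shape of a regularly nested list, e.g. (2, 3, 4).

    Assumes every sublist at the same depth has the same length,
    so it only needs to follow the first element at each level.
    """
    shape = []
    while isinstance(data, list):
        shape.append(len(data))
        data = data[0]  # descend along the first element only
    return tuple(shape)

# The same values as tf_initial_tensor_constant, as plain Python lists
initial_values = [
    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
    [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]],
]

shape = nested_shape(initial_values)
print(shape)             # (2, 3, 4)
print(math.prod(shape))  # 24
```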
Now that we have it, let's print out the tf_initial_tensor_constant Python variable to see what we have.
print(tf_initial_tensor_constant)
We see that it's a TensorFlow tensor produced by a constant operation, the shape is 2x3x4, and the data type is int32.
Because we haven't run it in a TensorFlow session yet, the printout shows no values, even though we defined it as a constant; printing it only describes a node in the computational graph.
The same will apply to the other reshapes we're about to create.
For the first example, let's go from a tensor whose shape is 2x3x4 to a tensor whose shape is 2x12.
tf_ex_one_reshaped_tensor_2_by_12 = tf.reshape(tf_initial_tensor_constant, [2, 12])
So we're going to use tf.reshape, pass in tf_initial_tensor_constant, and then pass in the new shape we want.
So it'll be [2, 12], and then we assign the whole thing to the Python variable tf_ex_one_reshaped_tensor_2_by_12.
Note that the number of elements stays the same, because 2 x 3 x 4 is 24 and 2 x 12 is also 24.
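The element-count rule boils down to comparing two products. The sketch below uses a hypothetical helper, `is_valid_reshape`, which is not a TensorFlow function; TensorFlow performs its own equivalent check and raises an error when the counts differ.

```python
import math

def is_valid_reshape(old_shape, new_shape):
    """True when both shapes hold the same number of elements.

    This simplified check does not handle a -1 placeholder dimension.
    """
    return math.prod(old_shape) == math.prod(new_shape)

print(is_valid_reshape([2, 3, 4], [2, 12]))       # True: 24 == 24
print(is_valid_reshape([2, 3, 4], [2, 3, 2, 2]))  # True: 24 == 24
print(is_valid_reshape([2, 3, 4], [5, 5]))        # False: 24 != 25
```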
Let's print out the tf_ex_one_reshaped_tensor_2_by_12 Python variable to see what we have.
print(tf_ex_one_reshaped_tensor_2_by_12)
We see that it's a TensorFlow tensor, we see that the shape is 2x12, and we see that the data type is int32.
It's not showing any values yet because we're still building the TensorFlow graph and we haven't run it in a TensorFlow session.
For the second example, let's change a tensor whose shape is 2x3x4 to a tensor whose shape is 2x3x2x2.
tf_ex_two_reshaped_tensor_2_by_3_by_2_by_2 = tf.reshape(tf_initial_tensor_constant, [2, 3, 2, 2])
So we use tf.reshape, we pass in our initial tensor, and then we specify what the shape is going to be.
So we pass in 2, 3, 2, 2 and we assign it to the Python variable, tf_ex_two_reshaped_tensor_2_by_3_by_2_by_2.
Note that the number of elements stays the same, because 2 x 3 x 4 is 24 and 2 x 3 x 2 x 2 is 24 as well.
Let's print out the tf_ex_two_reshaped_tensor_2_by_3_by_2_by_2 Python variable to see what we have.
print(tf_ex_two_reshaped_tensor_2_by_3_by_2_by_2)
We see that it's a TensorFlow tensor, we see that the shape is 2x3x2x2, which is what we would expect, and the data type is int32.
For the third example, we're going to change a TensorFlow tensor whose shape is 2x3x4 to a vector of 24 elements.
tf_ex_tre_reshaped_tensor_1_by_24 = tf.reshape(tf_initial_tensor_constant, [-1])
The way we do that is we use the tf.reshape operation, pass in our initial tensor, and give [-1] as the new shape.
A -1 tells TensorFlow to infer the size of that dimension from the total number of elements; since it's the only dimension, the tensor is flattened into a vector of 24 elements.
We assign it to the Python variable, tf_ex_tre_reshaped_tensor_1_by_24.
Let's print out the tf_ex_tre_reshaped_tensor_1_by_24 Python variable to see what we have.
print(tf_ex_tre_reshaped_tensor_1_by_24)
We see that it's a TensorFlow tensor whose shape is (24,), which means it's a rank-1 vector, and the data type is int32.
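The way a -1 gets resolved can be sketched as: divide the total element count by the product of the known dimensions. `infer_shape` is a hypothetical helper that mimics the rule in the tf.reshape documentation (at most one dimension may be -1); it is a simplified illustration, not TensorFlow's implementation.

```python
import math

def infer_shape(shape, total):
    """Replace a single -1 in `shape` with the size implied by `total` elements."""
    if shape.count(-1) > 1:
        raise ValueError("at most one dimension can be -1")
    if -1 not in shape:
        if math.prod(shape) != total:
            raise ValueError("element counts differ")
        return list(shape)
    # Product of the dimensions we already know (1 if -1 is the only entry)
    known = math.prod(d for d in shape if d != -1)
    if known == 0 or total % known != 0:
        raise ValueError(f"cannot reshape {total} elements into {shape}")
    return [total // known if d == -1 else d for d in shape]

print(infer_shape([-1], 24))        # [24]
print(infer_shape([1, -1], 24))     # [1, 24]
print(infer_shape([2, -1], 24))     # [2, 12]
print(infer_shape([-1, 2, 2], 24))  # [6, 2, 2]
```

The `[1, -1]` case is worth a note: despite the `_1_by_24` in its name, tf_ex_tre_reshaped_tensor_1_by_24 has shape (24,), a rank-1 vector. Passing [1, -1] instead would give a 1x24 matrix.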
Now that we have created our TensorFlow tensors, it's time to run the computational graph.
First, we launch the graph in a session.
sess = tf.Session()
Then we initialize all the global variables in the graph.
sess.run(tf.global_variables_initializer())
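In our case, the graph contains no tf.Variable objects, only a constant and three reshape operations, so this initializer has nothing to do; constants and reshape results don't need initialization, but running it is harmless.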
Next, we are going to print out the four tensors to see how tf.reshape works.
Let's print out our initial tensor constant.
print(sess.run(tf_initial_tensor_constant))
So we do a print(sess.run(tf_initial_tensor_constant)).
We see that it's a 2x3x4 tensor, the numbers go from 1 to 24, and none of them have decimal points, which is consistent with the int32 data type.
Let's now print our first reshaped tensor.
print(sess.run(tf_ex_one_reshaped_tensor_2_by_12))
So we use the print(sess.run(tf_ex_one_reshaped_tensor_2_by_12)).
We see that it's a matrix with two rows and 12 columns.
The first row holds the numbers 1 through 12.
The second row holds the numbers 13 through 24.
So all of our elements are there, 1 through 24.
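The first row holding 1 through 12 and the second holding 13 through 24 follows from reshape keeping the elements in row-major order (the last index changes fastest). The pure-Python sketch below reproduces that ordering; `reshape_row_major` is a hypothetical helper for illustration, not TensorFlow's implementation.

```python
import math

def reshape_row_major(flat, shape):
    """Split a flat list into nested lists of `shape`, filling in row-major order."""
    if math.prod(shape) != len(flat):
        raise ValueError("element counts differ")
    if len(shape) == 1:
        return list(flat)
    step = len(flat) // shape[0]  # number of elements per slice along the first axis
    return [reshape_row_major(flat[i * step:(i + 1) * step], shape[1:])
            for i in range(shape[0])]

# 1..24 in the row-major order in which tf_initial_tensor_constant's elements are read
values = list(range(1, 25))
two_by_twelve = reshape_row_major(values, [2, 12])
print(two_by_twelve[0])  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
print(two_by_twelve[1])  # [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
```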
Let's now print our second reshaped tensor, Python variable tf_ex_two_reshaped_tensor_2_by_3_by_2_by_2.
print(sess.run(tf_ex_two_reshaped_tensor_2_by_3_by_2_by_2))
Awesome.
We see that it's a tensor that has two interior tensors, each of which has three matrices that are 2x2.
Perfect.
So two rows, two columns; two rows, two columns; two rows, two columns.
Then two rows, two columns; two rows, two columns; two rows, two columns.
So overall, we can see that the shape is 2x3x2x2 and all our numbers are there.
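Because the first two dimensions are unchanged, each 2x2 block is simply one row of four from the original tensor split in half. Row-major index arithmetic makes this precise: element [i][j][k][m] sits at flat position ((i*3 + j)*2 + k)*2 + m. The sketch below assumes the values are 1 through 24 in order, as in this example, so each value is its flat position plus 1; `flat_index` and `value_at` are hypothetical helpers.

```python
# Row-major index arithmetic for the 2x3x2x2 result.
# Assumes the values are 1..24 in row-major order, as in this example.
shape = (2, 3, 2, 2)

def flat_index(i, j, k, m):
    """Row-major offset of [i][j][k][m] in a tensor of shape (2, 3, 2, 2)."""
    return ((i * shape[1] + j) * shape[2] + k) * shape[3] + m

def value_at(i, j, k, m):
    """Value stored at [i][j][k][m] when the elements are 1..24."""
    return flat_index(i, j, k, m) + 1

# The first 2x2 block, [0][0], is the original row [1, 2, 3, 4] split in half
first_block = [[value_at(0, 0, k, m) for m in range(2)] for k in range(2)]
print(first_block)  # [[1, 2], [3, 4]]

# The last 2x2 block, [1][2], is the original row [21, 22, 23, 24] split in half
last_block = [[value_at(1, 2, k, m) for m in range(2)] for k in range(2)]
print(last_block)  # [[21, 22], [23, 24]]
```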
Finally, let's print our third reshaped TensorFlow example.
print(sess.run(tf_ex_tre_reshaped_tensor_1_by_24))
This is the Python variable, tf_ex_tre_reshaped_tensor_1_by_24.
Awesome.
We see that it's a vector that's 24 elements long.
So we see the number 1 all the way to 24.
So all our numbers are there.
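Flattening with [-1] reads every element in the same row-major order. A minimal pure-Python equivalent, using a hypothetical `flatten` helper, recovers 1 through 24 from the nested initial values:

```python
def flatten(data):
    """Yield the leaves of a nested list in row-major (depth-first, left-to-right) order."""
    for item in data:
        if isinstance(item, list):
            yield from flatten(item)
        else:
            yield item

# The same values as tf_initial_tensor_constant, as plain Python lists
initial_values = [
    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
    [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]],
]

vector = list(flatten(initial_values))
print(len(vector))                   # 24
print(vector == list(range(1, 25)))  # True
```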
Perfect - we were able to use tf.reshape to change the shape of a TensorFlow tensor as long as the number of elements stays the same.