PyTorch Stack - Use the PyTorch Stack operation (torch.stack) to turn a list of PyTorch Tensors into one tensor
This video will show you how to use the PyTorch stack operation to turn a list of PyTorch tensors into one tensor.
First, we import PyTorch.
import torch
Then we print the PyTorch version we are using.
print(torch.__version__)
We are using PyTorch 0.4.0.
Let’s now create three tensors manually that we’ll later combine into a Python list.
We create our first PyTorch tensor using torch.tensor.
tensor_one = torch.tensor([[1,2,3],[4,5,6]])
Here, we can see the data structure: a nested Python list with two rows of three integers.
We assign it to the Python variable tensor_one.
Let’s print the tensor_one Python variable to see what we have.
print(tensor_one)
We see our PyTorch tensor, with our data inside it.
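As a quick check on what torch.tensor produced, the sketch below rebuilds tensor_one and inspects its shape, number of dimensions, and element type. It assumes a current PyTorch install. The int64 dtype comes from torch.tensor inferring the type from plain Python integers.

```python
import torch

# Same nested-list data as tensor_one: two rows, three columns
tensor_one = torch.tensor([[1, 2, 3], [4, 5, 6]])

print(tensor_one.shape)  # torch.Size([2, 3])
print(tensor_one.dim())  # 2, i.e. a 2-D tensor
print(tensor_one.dtype)  # torch.int64, inferred from the Python ints
```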
Next, we create our second PyTorch tensor, again using the torch.tensor operation.
tensor_two = torch.tensor([[7,8,9],[10,11,12]])
Then we create our third tensor and assign it to the Python variable tensor_tre.
tensor_tre = torch.tensor([[13,14,15],[16,17,18]])
Again, we use torch.tensor and pass in more data.
So now that we have our three example tensors initialized, let’s put them in a Python list.
To build the list, we use Python's square bracket syntax.
tensor_list = [tensor_one, tensor_two, tensor_tre]
We put tensor_one, tensor_two, and tensor_tre inside the brackets and assign this list to the Python variable tensor_list.
We can then print the tensor_list Python variable to see what we have.
print(tensor_list)
The output shows a tensor, then a comma, then a second tensor, then a comma, and then a third tensor.
So we have a list of three tensors.
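At this point tensor_list is an ordinary Python list, not a tensor. The sketch below, assuming the same three tensors, confirms that and also checks that all three shapes match, since torch.stack requires every input tensor to have the same size.

```python
import torch

tensor_one = torch.tensor([[1, 2, 3], [4, 5, 6]])
tensor_two = torch.tensor([[7, 8, 9], [10, 11, 12]])
tensor_tre = torch.tensor([[13, 14, 15], [16, 17, 18]])

# A plain Python list holding three separate tensor objects
tensor_list = [tensor_one, tensor_two, tensor_tre]

print(type(tensor_list).__name__)  # list
print(len(tensor_list))            # 3

# torch.stack needs every tensor to have the same shape; check it up front
same_shape = all(t.shape == tensor_list[0].shape for t in tensor_list)
print(same_shape)                  # True
```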
Let’s now turn this list of tensors into one tensor by using the PyTorch stack operation.
stacked_tensor = torch.stack(tensor_list)
So we see torch.stack, and then we pass in our Python list that contains three tensors.
Then the result of this will be assigned to the Python variable stacked_tensor.
Note that the default setting in PyTorch stack is to insert a new dimension as the first dimension.
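The position of the new dimension is controlled by the dim argument of torch.stack, which defaults to 0. The sketch below, using the same three 2x3 tensors, shows that omitting dim matches dim=0, and what other values of dim produce.

```python
import torch

tensor_list = [
    torch.tensor([[1, 2, 3], [4, 5, 6]]),
    torch.tensor([[7, 8, 9], [10, 11, 12]]),
    torch.tensor([[13, 14, 15], [16, 17, 18]]),
]

# Omitting dim is the same as dim=0: the new axis comes first
default_stack = torch.stack(tensor_list)
explicit_stack = torch.stack(tensor_list, dim=0)
print(torch.equal(default_stack, explicit_stack))  # True
print(default_stack.shape)  # torch.Size([3, 2, 3])

# dim=1 puts the new axis between rows and columns; dim=2 puts it last.
# Both shapes happen to be 2x3x3 here, but the elements are arranged differently.
middle_stack = torch.stack(tensor_list, dim=1)
last_stack = torch.stack(tensor_list, dim=2)
print(middle_stack.shape)  # torch.Size([2, 3, 3])
print(last_stack.shape)    # torch.Size([2, 3, 3])
print(torch.equal(middle_stack, last_stack))  # False
```

Only the default dim=0 is used in the rest of this walkthrough.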
Our initial three tensors were all of shape 2x3.
We can see this by checking the .shape attribute of the tensor_one example we constructed earlier.
tensor_one.shape
When we do that, we see that the shape is torch.Size([2, 3]).
So two rows, three columns.
So by default, torch.stack inserts a new dimension in front of the 2, and we end up with a 3x2x3 tensor.
The reason it’s 3 is because we have three tensors in this list we are converting to one tensor.
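One way to picture "inserting a new dimension": stacking along dim 0 gives the same result as first giving each 2x3 tensor a leading size-1 dimension with unsqueeze and then joining them with torch.cat. This is shown as an equivalent formulation for illustration, not as a claim about how torch.stack is implemented internally. Plain torch.cat on the original list, by contrast, joins along the existing first dimension and adds no new one.

```python
import torch

tensor_list = [
    torch.tensor([[1, 2, 3], [4, 5, 6]]),
    torch.tensor([[7, 8, 9], [10, 11, 12]]),
    torch.tensor([[13, 14, 15], [16, 17, 18]]),
]

# Give each 2x3 tensor a leading size-1 dimension: 2x3 -> 1x2x3
unsqueezed = [t.unsqueeze(0) for t in tensor_list]
# Joining three 1x2x3 tensors along dim 0 gives 3x2x3
via_cat = torch.cat(unsqueezed, dim=0)
stacked_tensor = torch.stack(tensor_list)

print(via_cat.shape)                         # torch.Size([3, 2, 3])
print(torch.equal(via_cat, stacked_tensor))  # True

# The leading 3 is simply the number of tensors in the list
expected_shape = (len(tensor_list),) + tuple(tensor_list[0].shape)
print(stacked_tensor.shape == expected_shape)  # True

# torch.cat without unsqueeze concatenates rows instead: 6x3, still 2-D
print(torch.cat(tensor_list, dim=0).shape)   # torch.Size([6, 3])
```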
We can check that by using our stacked_tensor Python variable and checking the shape of it.
stacked_tensor.shape
We see that we get 3x2x3 because there are now three tensors of size 2x3 stacked up on top of each other.
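Because the new dimension comes first, indexing stacked_tensor along dim 0 gives back each original 2x3 tensor in list order. The sketch below checks that, and also shows that torch.stack raises a RuntimeError when the input shapes differ. The mismatched 3x3 tensor is a hypothetical example, not part of the walkthrough.

```python
import torch

tensor_one = torch.tensor([[1, 2, 3], [4, 5, 6]])
tensor_two = torch.tensor([[7, 8, 9], [10, 11, 12]])
tensor_tre = torch.tensor([[13, 14, 15], [16, 17, 18]])
stacked_tensor = torch.stack([tensor_one, tensor_two, tensor_tre])

# Index i along the new first dimension is the i-th tensor from the list
print(torch.equal(stacked_tensor[0], tensor_one))  # True
print(torch.equal(stacked_tensor[2], tensor_tre))  # True

# Hypothetical mismatched input: a 3x3 tensor cannot be stacked with a 2x3 one
mismatched = torch.zeros(3, 3, dtype=torch.int64)
try:
    torch.stack([tensor_one, mismatched])
    stack_failed = False
except RuntimeError as err:
    stack_failed = True
    print("torch.stack failed:", err)
```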
We can then print the stacked_tensor Python variable to see what we have.
print(stacked_tensor)
When we print stacked_tensor, we see that it is one tensor rather than a list of tensors as before.
Before, we had three separate tensors inside a list.
This time, we only have one tensor.
Its shape is 3x2x3: three blocks, each with two rows of three columns.
So you can see our three tensors have now been combined into one tensor.
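If you later need the separate tensors back, torch.unbind removes a dimension and returns a tuple of the slices along it, which undoes the default dim-0 stack. A minimal sketch, assuming the same three tensors and a current PyTorch install:

```python
import torch

tensor_list = [
    torch.tensor([[1, 2, 3], [4, 5, 6]]),
    torch.tensor([[7, 8, 9], [10, 11, 12]]),
    torch.tensor([[13, 14, 15], [16, 17, 18]]),
]
stacked_tensor = torch.stack(tensor_list)

# Remove dim 0 again: the result is a tuple of three 2x3 tensors
pieces = torch.unbind(stacked_tensor, dim=0)
print(type(pieces).__name__)  # tuple
print(len(pieces))            # 3
print(all(torch.equal(p, t) for p, t in zip(pieces, tensor_list)))  # True
```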
Perfect! We were able to use the PyTorch stack operation to turn a list of PyTorch tensors into one tensor.