PyTorch View - how to use the PyTorch View (.view(...)) operation to reshape a PyTorch tensor
First, we import PyTorch.
import torch
Then we print the PyTorch version we are using.
print(torch.__version__)
We are using PyTorch 0.3.1.post2.
Let’s now create a PyTorch tensor using torch.Tensor.
pt_initial_tensor_ex = torch.Tensor(
[
[
[ 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18],
[19, 20, 21, 22, 23, 24]
],
[
[25, 26, 27, 28, 29, 30],
[31, 32, 33, 34, 35, 36],
[37, 38, 39, 40, 41, 42],
[43, 44, 45, 46, 47, 48]
]
])
We pass in a nested list holding two matrices of four rows and six columns each, so the tensor is going to be 2x4x6, and we assign that tensor to the Python variable pt_initial_tensor_ex.
Let’s print the pt_initial_tensor_ex Python variable to see what we have.
print(pt_initial_tensor_ex)
We see that it’s a torch.FloatTensor of size 2x4x6, and it has all the numbers we passed in.
The values run from 1 all the way to 48, so we have 48 elements.
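As a quick sanity check on the shape and the element count, here is a minimal sketch. It builds the same 1-to-48 values with a list comprehension instead of typing the literal out, and the variable names shape, rank and num_elements are only illustrative.

```python
import torch

# Build the same 2x4x6 nested list of the values 1..48 programmatically
# (a shortcut for the literal typed out in the walkthrough).
data = [[[24 * i + 6 * j + k + 1 for k in range(6)]
         for j in range(4)]
        for i in range(2)]
pt_initial_tensor_ex = torch.Tensor(data)

shape = tuple(pt_initial_tensor_ex.size())    # size of each dimension
rank = pt_initial_tensor_ex.dim()             # number of dimensions
num_elements = pt_initial_tensor_ex.numel()   # total number of elements

print(shape, rank, num_elements)
```

The element count is the number that has to stay fixed in every reshape that follows.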
For the first example, we’re going to reshape our PyTorch tensor, but we’re going to retain the same rank.
pt_reshaped_6_by_4_by_2_tensor_ex = pt_initial_tensor_ex.view(6, 4, 2)
So it will still be rank 3, only this time it’s going to be a 6x4x2 tensor rather than a 2x4x6.
We call .view(6, 4, 2) on our initial tensor and assign the result to the Python variable pt_reshaped_6_by_4_by_2_tensor_ex.
Now we print the pt_reshaped_6_by_4_by_2_tensor_ex Python variable to see what we have.
print(pt_reshaped_6_by_4_by_2_tensor_ex)
We see that it’s still a torch.FloatTensor.
We see that the size has now been changed to 6x4x2.
We see that the first number is 1 and the last number in the final matrix is 48.
Counting the matrices, we have six of them.
Each matrix has four rows and two columns.
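To see how the 48 values get regrouped, here is a small sketch that looks at the first and last of the six matrices. It assumes a recent PyTorch release (0.4 or later) for .tolist(), and the names first_matrix, last_matrix and flat_values are just for illustration.

```python
import torch

# Same values 1..48 as in the walkthrough, built with a shortcut.
pt_initial_tensor_ex = torch.Tensor(list(range(1, 49))).view(2, 4, 6)
pt_reshaped_6_by_4_by_2_tensor_ex = pt_initial_tensor_ex.view(6, 4, 2)

# view keeps the elements in the same row-major order and only regroups them,
# so the first 4x2 matrix holds 1..8 and the last one holds 41..48.
first_matrix = pt_reshaped_6_by_4_by_2_tensor_ex[0].tolist()
last_matrix = pt_reshaped_6_by_4_by_2_tensor_ex[5].tolist()

# Flattening the reshaped tensor gives back 1..48 in the original order.
flat_values = pt_reshaped_6_by_4_by_2_tensor_ex.view(48).tolist()

print(first_matrix)
print(last_matrix)
```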
Just to double check, we print our original tensor to see that the view operation returns a new tensor and leaves the shape of the original tensor unchanged (the two tensors do share the same underlying data).
print(pt_initial_tensor_ex)
So we still have our original tensor that is 2x4x6.
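One thing worth knowing here: the tensor returned by view has its own shape, but it shares memory with the original. The sketch below assumes a recent PyTorch release (0.4 or later) for .item(), and the names original, reshaped, same_memory and value_in_original are hypothetical.

```python
import torch

original = torch.Tensor(list(range(1, 49))).view(2, 4, 6)
reshaped = original.view(6, 4, 2)

# The original keeps its own shape after the call to view.
original_shape = tuple(original.size())

# Both tensors start at the same memory address.
same_memory = original.data_ptr() == reshaped.data_ptr()

# Writing through the view is therefore visible in the original as well.
reshaped[0, 0, 0] = -1
value_in_original = original[0, 0, 0].item()

print(original_shape, same_memory, value_in_original)
```

If an independent copy is needed, calling .clone() on the result is one way to get it.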
For the second reshaped tensor example, we’re going to decrease the tensor rank from 3 to 2.
pt_reshaped_4_by_12_tensor_ex = pt_initial_tensor_ex.view(4, 12)
So we call .view on our initial tensor, and we want the new shape to be 4 by 12.
4 times 12 is 48, which matches the 48 elements we started with.
So we’re not changing the number of elements.
We’re just shaping them differently.
So rather than being a 2x4x6 tensor, it’s going to be a 4x12 tensor.
The result of this operation is going to be assigned to the Python variable pt_reshaped_4_by_12_tensor_ex.
Let’s print this tensor and we see that it’s a torch.FloatTensor of size 4x12.
print(pt_reshaped_4_by_12_tensor_ex)
So the rank is now 2, whereas before it was 3.
We see that we have all our numbers from 1 to 48.
We have four rows and 12 columns.
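The requirement that the element count stays at 48 can be checked directly. This is a minimal sketch, assuming a recent PyTorch release where a mismatched shape raises RuntimeError. The 5x10 shape and the names first_row and mismatch_raised are made up for the illustration.

```python
import torch

pt_initial_tensor_ex = torch.Tensor(list(range(1, 49))).view(2, 4, 6)

# 4 * 12 == 48, so this reshape is valid.
pt_reshaped_4_by_12_tensor_ex = pt_initial_tensor_ex.view(4, 12)
first_row = pt_reshaped_4_by_12_tensor_ex[0].tolist()  # the values 1..12

# 5 * 10 == 50, which does not match 48 elements, so view refuses.
try:
    pt_initial_tensor_ex.view(5, 10)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

print(first_row)
print(mismatch_raised)
```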
For the third and last PyTorch reshape example, let’s increase the rank of the tensor so that we go from a 2x4x6 tensor to a 3x2x2x4 tensor.
pt_reshaped_3_by_2_by_2_by_4_tensor_ex = pt_initial_tensor_ex.view(3, 2, 2, 4)
So our rank will increase from 3 to 4.
We assign that result to pt_reshaped_3_by_2_by_2_by_4_tensor_ex.
So we can then print the pt_reshaped_3_by_2_by_2_by_4_tensor_ex to see what we get.
print(pt_reshaped_3_by_2_by_2_by_4_tensor_ex)
In the printed output, the first index of the slice labels runs 0, 0, 1, 1, 2, 2.
It takes three distinct values, so that’s where our initial 3 is.
Inside of that, the second index alternates 0 and 1, 0 and 1, 0 and 1.
That’s where the 2 in the second position is.
Then each internal matrix has two rows and four columns.
So that’s where the last 2 and the 4 come from.
3 times 2 is 6, times 2 is 12, and 12 times 4 is 48.
So we have all our 48 numbers there.
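The same bookkeeping can be done in code rather than by reading the printout. This sketch assumes a recent PyTorch release (0.4 or later) for .item(), and the names shape, product, first_value and last_value are illustrative only.

```python
import torch

pt_initial_tensor_ex = torch.Tensor(list(range(1, 49))).view(2, 4, 6)
pt_reshaped_3_by_2_by_2_by_4_tensor_ex = pt_initial_tensor_ex.view(3, 2, 2, 4)

shape = tuple(pt_reshaped_3_by_2_by_2_by_4_tensor_ex.size())

# Multiply the dimension sizes together: 3 * 2 * 2 * 4.
product = 1
for dim_size in shape:
    product *= dim_size

# The first and last index positions hold the first and last values.
first_value = pt_reshaped_3_by_2_by_2_by_4_tensor_ex[0, 0, 0, 0].item()
last_value = pt_reshaped_3_by_2_by_2_by_4_tensor_ex[2, 1, 1, 3].item()

print(shape, product, first_value, last_value)
```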
Perfect! We were able to reshape a PyTorch tensor by using the PyTorch view operation.