PyTorch Concatenate - Use PyTorch cat to concatenate a list of PyTorch tensors along a given dimension
We import PyTorch.
import torch
We print out the PyTorch version we are using.
print(torch.__version__)
We see that we are using PyTorch version 0.2.0_4.
First, we define a PyTorch tensor and initialize it with torch.rand, which draws random numbers uniformly from the interval [0, 1).
We then multiply it by 100 so that the values fall between 0 and 100. We cast the result to an integer tensor with .int() so the numbers are easier to read when we print them.
x = (torch.rand(2, 3, 4) * 100).int()
When we print it, we can see that it is a PyTorch IntTensor of size 2x3x4.
print(x)
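The minimal sketch below is a quick sanity check on the scaling and casting steps. It assumes a recent PyTorch release rather than 0.2.0. In recent releases an integer tensor prints as tensor(..., dtype=torch.int32) instead of as an IntTensor. The sketch also adds torch.manual_seed(0), which is not in the original, so that runs are repeatable. The numbers it produces will differ from the ones quoted in this walkthrough.

```python
import torch

# Seed added for repeatability (not part of the original walkthrough)
torch.manual_seed(0)

# torch.rand samples uniformly from [0, 1); scaling by 100 gives [0, 100)
x = (torch.rand(2, 3, 4) * 100).int()

# .int() truncates toward zero, so every entry is a whole number from 0 to 99
print(x.shape)                       # torch.Size([2, 3, 4])
print(x.dtype)                       # torch.int32
print(x.min().item(), x.max().item())
```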
We repeat the same process, this time defining a variable y, and cast it to an integer tensor as well.
y = (torch.rand(2, 3, 4) * 100).int()
When we print it, we can see that we have a PyTorch IntTensor of size 2x3x4.
print(y)
Reading down the first column of y, we have 85, 56, 58.
Reading down the first column of x, we have 58, 85, 74.
So we have two different PyTorch IntTensors.
In this video, we want to concatenate PyTorch tensors along a given dimension.
Here, x and y are both three-dimensional PyTorch tensors.
Each has size 2x3x4.
So we can concatenate them along the first, the second, or the third dimension.
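Before going through each case, here is a minimal sketch that concatenates two 2x3x4 tensors along each of the three dimensions and records the resulting shapes. The seed and the shapes dictionary are illustrative additions. Only the concatenation dimension grows, and the other two sizes stay the same.

```python
import torch

torch.manual_seed(0)  # added for repeatability
x = (torch.rand(2, 3, 4) * 100).int()
y = (torch.rand(2, 3, 4) * 100).int()

# Concatenate along dimension 0, 1, and 2 in turn and record each resulting shape
shapes = {dim: tuple(torch.cat((x, y), dim).shape) for dim in range(x.dim())}
for dim, shape in shapes.items():
    print(dim, shape)
# 0 (4, 3, 4)
# 1 (2, 6, 4)
# 2 (2, 3, 8)
```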
We define a variable z_zero and call torch.cat. We pass in a tuple containing our two PyTorch tensors, x and y, along with the dimension 0, which is the first dimension.
z_zero = torch.cat((x, y), 0)
When we print this z_zero variable, we see that it is 4x3x4.
print(z_zero)
Remember that x was 2x3x4 and y was 2x3x4.
So we have concatenated them along the first dimension.
Reading down the first column of each 3x4 block, we see:
- 58, 85, 74 from the first block of x;
- then 43, 80, 13 from the second block of x;
- then 85, 56, 58 from the first block of y;
- and finally 86, 7, 52 from the second block of y.
So we have 4x3x4.
That was z_zero.
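To confirm where each block of z_zero comes from, the minimal sketch below slices it back apart along dimension 0 and compares the pieces with torch.equal. The seed is an illustrative addition, so the values will not match the ones quoted above.

```python
import torch

torch.manual_seed(0)  # added for repeatability
x = (torch.rand(2, 3, 4) * 100).int()
y = (torch.rand(2, 3, 4) * 100).int()

z_zero = torch.cat((x, y), 0)

# Along dim 0, the first two 3x4 blocks come from x and the last two from y
print(torch.equal(z_zero[:2], x))  # True
print(torch.equal(z_zero[2:], y))  # True
```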
Next, we create a second variable called z_one.
z_one = torch.cat((x, y), 1)
We again use the PyTorch concatenation function.
We pass in a tuple of our two PyTorch tensors, and this time we concatenate them along the second dimension.
Python uses zero-based indexing, so we pass in 1 rather than 2.
When we print the z_one variable, we can see that it is of size 2x6x4.
print(z_one)
Remember that x was 2x3x4 and y was 2x3x4. Concatenating them along the second dimension is why we get 2x6x4.
So, reading down the first column of the first block, we see 58, 85, 74, 85, 56, 58.
Scrolling back up, the first column of x was 58, 85, 74 and the first column of y was 85, 56, 58. That is exactly what we see here: 58, 85, 74, followed by 85, 56, 58.
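The minimal sketch below checks the same thing programmatically. It slices z_one along dimension 1 and confirms that the first column of the first block is x's column followed by y's. The seed is an illustrative addition.

```python
import torch

torch.manual_seed(0)  # added for repeatability
x = (torch.rand(2, 3, 4) * 100).int()
y = (torch.rand(2, 3, 4) * 100).int()

z_one = torch.cat((x, y), 1)

# Along dim 1, rows 0-2 of each block come from x and rows 3-5 come from y
print(torch.equal(z_one[:, :3], x))  # True
print(torch.equal(z_one[:, 3:], y))  # True

# Reading down the first column of the first block gives x's column, then y's
first_column = z_one[0, :, 0]
print(torch.equal(first_column, torch.cat((x[0, :, 0], y[0, :, 0]))))  # True
```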
The last concatenation variable that we define is z_two.
z_two = torch.cat((x, y), 2)
We use the PyTorch concatenation function again, passing in the tuple of x and y PyTorch tensors, and this time we concatenate along the third dimension.
Remember that Python uses zero-based indexing, so we pass in 2 rather than 3.
Because x was 2x3x4 and y was 2x3x4, we should expect this PyTorch Tensor to be 2x3x8.
When we print z_two, we see that it is 2x3x8.
print(z_two)
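The shape prediction above follows a general rule. Every dimension except the concatenation dimension must match across the inputs, and the output size along the concatenation dimension is the sum of the input sizes. The helper expected_cat_shape below is hypothetical and written only for illustration; it is not part of PyTorch. This minimal sketch encodes the rule and compares the prediction against the shape torch.cat actually returns.

```python
import torch

def expected_cat_shape(shapes, dim):
    """Hypothetical helper: predict the shape torch.cat returns for inputs of these shapes."""
    result = list(shapes[0])
    for shape in shapes[1:]:
        if len(shape) != len(result):
            raise ValueError("all tensors must have the same number of dimensions")
        # Every dimension except `dim` has to match exactly
        for i, (a, b) in enumerate(zip(result, shape)):
            if i != dim and a != b:
                raise ValueError(f"size mismatch in dimension {i}: {a} vs {b}")
    # The concatenation dimension is the sum of the input sizes along `dim`
    result[dim] = sum(shape[dim] for shape in shapes)
    return tuple(result)

x = (torch.rand(2, 3, 4) * 100).int()
y = (torch.rand(2, 3, 4) * 100).int()
z_two = torch.cat((x, y), 2)

predicted = expected_cat_shape([tuple(x.shape), tuple(y.shape)], 2)
print(predicted)           # (2, 3, 8)
print(tuple(z_two.shape))  # (2, 3, 8)
```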
Let's look at the first row.
We see 58, 39, 98, 97, and then it continues with 85.
If we scroll back up to x and y, the first row of x starts with 58 and the first row of y starts with 85.
So each row of z_two is a row of x followed by the matching row of y.
So the z_two variable is the concatenation along the third dimension.
Again, we used 2 because Python uses zero-based indexing.
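As with the other two cases, the minimal sketch below slices z_two back apart, this time along the last dimension. It also confirms that its first row is x's first row followed by y's first row. The seed is an illustrative addition.

```python
import torch

torch.manual_seed(0)  # added for repeatability
x = (torch.rand(2, 3, 4) * 100).int()
y = (torch.rand(2, 3, 4) * 100).int()

z_two = torch.cat((x, y), 2)

# Along dim 2, the first four entries of every row come from x and the last four from y
print(torch.equal(z_two[:, :, :4], x))  # True
print(torch.equal(z_two[:, :, 4:], y))  # True

# The first row is x's first row (4 values) followed by y's first row (4 values)
print(torch.equal(z_two[0, 0], torch.cat((x[0, 0], y[0, 0]))))  # True
```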