
NumPy - expand_dims() function

The NumPy expand_dims() function expands the shape of an array by inserting a new axis of length 1. The new axis appears at the given axis position in the expanded array shape.


Syntax

numpy.expand_dims(a, axis)


Parameters

a Required. Specify the input array (array_like).
axis Required. Specify the position in the expanded axes where the new axis (or axes) is placed. It can be an int or a tuple of ints.
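
As a rough illustration of how the axis parameter is interpreted, the sketch below applies a few axis values to a 2-D array and checks the resulting shapes. The array name arr and the chosen axis values are illustrative assumptions, not part of the original reference; positions are counted in the expanded shape, and negative values count from the end.

```python
import numpy as np

# hypothetical 2-D input used only for illustration
arr = np.zeros((2, 3))

# a single int inserts one new axis at that position
middle = np.expand_dims(arr, axis=1)

# a negative int counts from the end of the expanded shape
last = np.expand_dims(arr, axis=-1)

# a tuple of ints inserts several axes at once
both_ends = np.expand_dims(arr, axis=(0, -1))

print(middle.shape)     # (2, 1, 3)
print(last.shape)       # (2, 3, 1)
print(both_ends.shape)  # (1, 2, 3, 1)
```

Each new axis has length 1, so the total number of elements stays the same.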

Return Value

Returns a view of a with the number of dimensions increased.
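
Because the result is a view, it shares memory with the input. The minimal sketch below, with variable names chosen for illustration, shows that a write through the expanded array is visible in the original. It also compares the result with np.newaxis indexing, which is a common alternative way to insert an axis.

```python
import numpy as np

x = np.array([1, 2, 3])

# y is a view of x with an extra leading axis
y = np.expand_dims(x, axis=0)
print(np.shares_memory(x, y))  # True

# modifying the view also modifies the original array
y[0, 0] = 99
print(x)  # [99  2  3]

# indexing with np.newaxis gives an array of the same shape and contents
z = x[np.newaxis, :]
print(z.shape == y.shape)  # True
```

If an independent array is needed, calling .copy() on the result avoids this sharing.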


In the example below, an array is expanded on a given axis.

import numpy as np

x = np.array([1, 2, 3])

#expanding the dimension of x on axis=0
x1 = np.expand_dims(x, axis=0)

#expanding the dimension of x on axis=1
x2 = np.expand_dims(x, axis=1)

#expanding the dimension of x on axis=(0,1)
x3 = np.expand_dims(x, axis=(0,1))

#displaying results
print("shape of x:", x.shape)
print("x contains:")
print(x)
print("\nshape of x1:", x1.shape)
print("x1 contains:")
print(x1)
print("\nshape of x2:", x2.shape)
print("x2 contains:")
print(x2)
print("\nshape of x3:", x3.shape)
print("x3 contains:")
print(x3)

The output of the above code will be:

shape of x: (3,)
x contains:
[1 2 3]

shape of x1: (1, 3)
x1 contains:
[[1 2 3]]

shape of x2: (3, 1)
x2 contains:
[[1]
 [2]
 [3]]

shape of x3: (1, 1, 3)
x3 contains:
[[[1 2 3]]]
