How does the axis parameter from NumPy work?


Python Problem Overview


Can someone explain exactly what the axis parameter in NumPy does?

I am terribly confused.

I'm trying to use the function myArray.sum(axis=num)

At first I thought that if the array itself has 3 dimensions, axis=0 would return three elements, each the sum of all nested items in that same position. If each dimension contained five elements, I expected axis=1 to return a result of five items, and so on.

However, this is not the case, and the documentation does not do a good job of helping me out (it uses a 3x3x3 array, so it's hard to tell what's happening).

Here's what I did:

>>> e
array([[[1, 0],
        [0, 0]],

       [[1, 1],
        [1, 0]],

       [[1, 0],
        [0, 1]]])
>>> e.sum(axis = 0)
array([[3, 1],
       [1, 1]])
>>> e.sum(axis=1)
array([[1, 0],
       [2, 1],
       [1, 1]])
>>> e.sum(axis=2)
array([[1, 0],
       [2, 1],
       [1, 1]])
>>>

Clearly the result is not intuitive.

Python Solutions


Solution 1 - Python

Clearly,

e.shape == (3, 2, 2)

Sum over an axis is a reduction operation so the specified axis disappears. Hence,

e.sum(axis=0).shape == (2, 2)
e.sum(axis=1).shape == (3, 2)
e.sum(axis=2).shape == (3, 2)

Intuitively, we are "squashing" the array along the chosen axis, and summing the numbers that get squashed together.
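The "squashing" can be checked directly. The following is a minimal sketch using the array from the question; the variable names sum0, manual0 and kept are only illustrative. It also uses the keepdims argument of sum, which keeps the squashed axis around with length 1 instead of dropping it.

```python
import numpy as np

# The array from the question, shape (3, 2, 2)
e = np.array([[[1, 0], [0, 0]],
              [[1, 1], [1, 0]],
              [[1, 0], [0, 1]]])

# Summing over axis 0 adds the three 2x2 blocks elementwise
sum0 = e.sum(axis=0)
manual0 = e[0] + e[1] + e[2]

# keepdims=True keeps the squashed axis, with length 1
kept = e.sum(axis=0, keepdims=True)

print(e.shape)     # (3, 2, 2)
print(sum0.shape)  # (2, 2)
print(kept.shape)  # (1, 2, 2)
print(np.array_equal(sum0, manual0))  # True
```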

Solution 2 - Python

To understand the axis intuitively, refer to the picture below (source: Physics Dept, Cornell Uni).

[Image: array axes illustration]

The (boolean) array in the above figure has shape=(8, 3). ndarray.shape returns a tuple whose entries are the lengths of the corresponding dimensions. In our example, 8 is the length of axis 0 and 3 is the length of axis 1.
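As a rough sketch of the same idea in code, here is a hypothetical boolean array with shape (8, 3). The contents of mask are made up for illustration and are not the values from the figure.

```python
import numpy as np

# Hypothetical boolean array with the same shape as the one in the figure
mask = np.zeros((8, 3), dtype=bool)
mask[0, 1] = True
mask[5, 1] = True
mask[5, 2] = True

n_axis0, n_axis1 = mask.shape  # lengths of axis 0 and axis 1

per_column = mask.sum(axis=0)  # squashes axis 0, leaving 3 counts
per_row = mask.sum(axis=1)     # squashes axis 1, leaving 8 counts

print(n_axis0, n_axis1)  # 8 3
print(per_column)        # [0 2 1]
print(per_row.shape)     # (8,)
```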

Solution 3 - Python

If someone needs this visual description:

[Image: numpy axis]

Solution 4 - Python

There are good answers for visualization; however, it might help to think about this from a purely analytical perspective.

You can create an array of arbitrary dimension with NumPy. For example, here's a 5-dimensional array:

>>> a = np.random.rand(2, 3, 4, 5, 6)
>>> a.shape
(2, 3, 4, 5, 6)

You can access any element of this array by specifying indices. For example, here's the first element of this array:

>>> a[0, 0, 0, 0, 0]
0.0038908603263844155

Now if you slice along one of the dimensions while fixing the other indices, you get all the elements along that dimension:

>>> a[0, 0, :, 0, 0]
array([0.00389086, 0.27394775, 0.26565889, 0.62125279])

When you apply a function like sum with the axis parameter, that dimension gets eliminated and an array with one dimension fewer than the original is created. For each cell in the new array, the operator gets the list of elements along the eliminated dimension and applies the reduction function to produce a scalar.

>>> np.sum(a, axis=2).shape
(2, 3, 5, 6)

Now you can check that the first element of this array is the sum of the elements above:

>>> np.sum(a, axis=2)[0, 0, 0, 0]
1.1647502999560164

>>> a[0, 0, :, 0, 0].sum()
1.1647502999560164

axis=None has a special meaning: the array is flattened and the function is applied to all of its numbers.
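A small sketch of the axis=None case, using a made-up integer array so the total is easy to verify by hand (0 + 1 + ... + 23 = 276):

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)

total_default = np.sum(a)          # axis defaults to None
total_none = np.sum(a, axis=None)  # explicit axis=None
total_flat = a.ravel().sum()       # flatten first, then sum

print(total_default, total_none, total_flat)  # 276 276 276
print(np.ndim(total_none))                    # 0, i.e. a scalar
```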

Now you can think about more complex cases where axis is not just a number but a tuple:

>>> np.sum(a, axis=(2,3)).shape
(2, 3, 6)

Note that we can use the same technique to figure out how this reduction was done:

>>> np.sum(a, axis=(2,3))[0,0,0]
7.889432081931909

>>> a[0, 0, :, :, 0].sum()
7.88943208193191
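Another way to look at a tuple axis is as two single-axis reductions in a row. This is only a sketch with a seeded random array (the seed and the names both and step are arbitrary). Note that the higher axis is reduced first, so the position of the lower axis does not shift in between.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((2, 3, 4, 5, 6))

# Reduce axes 2 and 3 in one call
both = np.sum(a, axis=(2, 3))

# Reduce axis 3 first, then axis 2, so the remaining axis keeps its index
step = np.sum(np.sum(a, axis=3), axis=2)

print(both.shape)               # (2, 3, 6)
print(np.allclose(both, step))  # True
```

np.allclose is used rather than exact equality because floating-point sums can differ in the last digits depending on the order of addition, as the two printed values above also show.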

You can also use the same reasoning for adding a dimension to an array instead of removing one:

>>> x = np.random.rand(3, 4)
>>> y = np.random.rand(3, 4)

# New dimension is created on specified axis
>>> np.stack([x, y], axis=2).shape
(3, 4, 2)
>>> np.stack([x, y], axis=0).shape
(2, 3, 4)

# To retrieve item i from the stack, index position i on that axis
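Retrieving an item from the stack can be sketched as follows. The arrays here are deterministic stand-ins for the random x and y above, so the result is easy to check; np.take is shown as one alternative to plain indexing.

```python
import numpy as np

x = np.arange(12).reshape(3, 4)
y = x + 100

s0 = np.stack([x, y], axis=0)  # shape (2, 3, 4)
s2 = np.stack([x, y], axis=2)  # shape (3, 4, 2)

# Index position 1 on the stacking axis to get y back
y_from_s0 = s0[1]                # same as s0[1, :, :]
y_from_s2 = s2[:, :, 1]          # same as s2[..., 1]
y_take = np.take(s2, 1, axis=2)  # axis given by number instead of position

print(np.array_equal(y_from_s0, y))  # True
print(np.array_equal(y_from_s2, y))  # True
print(np.array_equal(y_take, y))     # True
```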

Hope this gives you a general and complete understanding of this important parameter.

Solution 5 - Python

Some answers are too specific or do not address the main source of confusion. This answer attempts to provide a more general but simple explanation of the concept, with a simple example.

The main source of confusion is related to expressions such as "Axis along which the means are computed", which is the documentation of the axis argument of the numpy.mean function. What the heck does "along which" even mean here? "Along which" essentially means that you add the rows together (and divide by the number of rows, given that we are computing the mean) if the axis is 0, and add the columns together if the axis is 1. When the axis is 0 (or 1), the rows (or columns) can be scalars, vectors, or even other multi-dimensional arrays.

In [1]: import numpy as np

In [2]: a=np.array([[1, 2], [3, 4]])

In [3]: a
Out[3]: 
array([[1, 2],
       [3, 4]])

In [4]: np.mean(a, axis=0)
Out[4]: array([2., 3.])

In [5]: np.mean(a, axis=1)
Out[5]: array([1.5, 3.5])

So, in the example above, np.mean(a, axis=0) returns array([2., 3.]) because (1 + 3)/2 = 2 and (2 + 4)/2 = 3. It returns an array of two numbers because it averages over the rows, giving one value per column (and there are two columns). Likewise, np.mean(a, axis=1) returns array([1.5, 3.5]) because (1 + 2)/2 = 1.5 and (3 + 4)/2 = 3.5.
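To make "adding the rows together" concrete, here is a minimal sketch that rebuilds the means by hand. The 3D array b is a made-up example in which the items along axis 0 are 2x2 matrices rather than plain rows.

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])

# axis=0: add the rows together, divide by the number of rows
mean0 = (a[0] + a[1]) / a.shape[0]
# axis=1: add the columns together, divide by the number of columns
mean1 = (a[:, 0] + a[:, 1]) / a.shape[1]

# In 3D, the "rows" along axis 0 are themselves 2x2 matrices
b = np.arange(8).reshape(2, 2, 2)
mean_b = (b[0] + b[1]) / b.shape[0]

print(mean0)  # [2. 3.]
print(mean1)  # [1.5 3.5]
print(np.array_equal(mean_b, np.mean(b, axis=0)))  # True
```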

Solution 6 - Python

Both the 1st and 2nd replies are great for understanding the ndarray concept in NumPy. I am giving a simple example.

And according to this image by @debaonline4u

https://i.stack.imgur.com/O5hBF.jpg

Suppose you have a 2D array: [1, 2, 3] [4, 5, 6]

In NumPy format it will be:

c = np.array([[1, 2, 3], 
              [4, 5, 6]])  

Now,

c.ndim = 2 (number of axes: axis0 and axis1)
c.shape = (2, 3) (length of axis0, length of axis1)
c.sum(axis=0) = [1+4, 2+5, 3+6] = [5, 7, 9] (sum of the elements at the same position in each row, so along axis0)
c.sum(axis=1) = [1+2+3, 4+5+6] = [6, 15]    (sum of the elements within each row, so along axis1)

So for your 3D array: [Image: 3d Numpy array sum]
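The same reasoning can be checked on the 3D array from the question by writing out each sum with explicit slices. This is only a sketch; the names s0, s1 and s2 are illustrative.

```python
import numpy as np

e = np.array([[[1, 0], [0, 0]],
              [[1, 1], [1, 0]],
              [[1, 0], [0, 1]]])

# axis=0: add the three 2x2 blocks together
s0 = e[0] + e[1] + e[2]
# axis=1: inside each block, add the two rows together
s1 = e[:, 0, :] + e[:, 1, :]
# axis=2: inside each block, add the two columns together
s2 = e[:, :, 0] + e[:, :, 1]

print(np.array_equal(s0, e.sum(axis=0)))  # True
print(np.array_equal(s1, e.sum(axis=1)))  # True
print(np.array_equal(s2, e.sum(axis=2)))  # True
```

In this particular array every 2x2 block is symmetric, so adding its rows and adding its columns give the same numbers. That is why axis=1 and axis=2 print identical results in the question.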

Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type | Original Author | Original Content on Stackoverflow
Question | CodyBugstein | View Question on Stackoverflow
Solution 1 - Python | Martin | View Answer on Stackoverflow
Solution 2 - Python | kmario23 | View Answer on Stackoverflow
Solution 3 - Python | debaonline4u | View Answer on Stackoverflow
Solution 4 - Python | Shital Shah | View Answer on Stackoverflow
Solution 5 - Python | nbro | View Answer on Stackoverflow
Solution 6 - Python | 33Anika33 | View Answer on Stackoverflow