diff --git a/plotting.py b/plotting.py
index 93683c6d980f3a41fbd998b79e42592d903ebec3..95d63923560ea164f674f41397e13f7e380b4c6a 100644
--- a/plotting.py
+++ b/plotting.py
@@ -73,14 +73,11 @@ def subplot_shared_labels(axes, xlabel=None, ylabel=None, clear=True):
     """
     for r in range(axes.shape[0]):
         for c in range(axes.shape[1]):
-            if ylabel:
-                if clear: axes[r,c].set(ylabel=None)
-                if c == 0 or not axes[r,c].get_shared_y_axes().joined(axes[r,c], axes[r,0]):
-                    axes[r,c].set(ylabel=ylabel)
-            if xlabel:
-                if clear: axes[r,c].set(xlabel=None)
-                if r == axes.shape[0]-1 or not axes[r,c].get_shared_x_axes().joined(axes[r,c], axes[-1,c]):
-                    axes[r,c].set(xlabel=xlabel)
+            if clear: axes[r,c].set(xlabel=None, ylabel=None)
+            if c == 0 or not axes[r,c].get_shared_y_axes().joined(axes[r,c], axes[r,0]):
+                axes[r,c].set(ylabel=ylabel)
+            if r == axes.shape[0]-1 or not axes[r,c].get_shared_x_axes().joined(axes[r,c], axes[-1,c]):
+                axes[r,c].set(xlabel=xlabel)
 
 def grid_diagonal(ax, **kwargs):
     for k, v in dict(color='lightgray',lw=1,zorder=-100).items():