--- a/bitshift.m
+++ b/bitshift.m
@@ -57,16 +57,17 @@
   if (nargin == 3)
     n = fix (n);
     if ( is_scalar(n) & (!is_scalar(k)) )
-      if (!is_scalar (A) & (size (A) != size (k)))
+      ## Compare ndims first (!= on size vectors of unequal length errors); any () catches a mismatch in any single dimension.
+      if (!is_scalar (A) && ((ndims (A) != ndims (k)) || any (size (A) != size (k))))
         error ("size of A and k must match");
       endif 
       n = n .* ones (size (k));
     elseif (!is_scalar (n)) & is_scalar (k)
-	  if (!is_scalar (A) & (size (A) != size (n)))
+      if (!is_scalar (A) && ((ndims (A) != ndims (n)) || any (size (A) != size (n))))
         error ("size of A and n must match");
       endif
       k = fix (k) .* ones (size (n));
-    elseif (size (n) != size (k))
+    elseif ((ndims (n) != ndims (k)) || any (size (n) != size (k)))
       error ("size of n and k must match");
     endif
   else