
I need to interleave PyTorch tensors (similar to NumPy arrays) with rows and columns of zeros, like this:

Input =>  [[ 1, 2, 3],
           [ 4, 5, 6],
           [ 7, 8, 9]]

Output => [[ 1, 0, 2, 0, 3],
           [ 0, 0, 0, 0, 0],
           [ 4, 0, 5, 0, 6],
           [ 0, 0, 0, 0, 0],
           [ 7, 0, 8, 0, 9]]

I am using the accepted answer to this question, which proposes the following:

import numpy as np

def insert_zeros(a, N=1):
    # a : input array
    # N : number of zeros to be inserted between consecutive rows and cols
    out = np.zeros((N + 1) * np.array(a.shape) - N, dtype=a.dtype)
    out[::N + 1, ::N + 1] = a
    return out

The answer works perfectly, except that I need to perform this many times on many arrays, and the time it takes has become the bottleneck. It is the step-sized slicing that takes most of the time.

For what it's worth, the matrices I use this on are 4D; an example size is 32x18x16x16, and I insert the alternating rows/cols only in the last two dimensions.
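To make the 4D case concrete, here is a minimal sketch of the variant I have in mind (the helper name insert_zeros_last2 is just for illustration); only the last two dimensions are expanded:

import numpy as np

def insert_zeros_last2(a, N=1):
    # Interleave zeros in the last two dims only; leading dims are untouched.
    B, C, H, W = a.shape
    out = np.zeros((B, C, (N + 1) * H - N, (N + 1) * W - N), dtype=a.dtype)
    out[..., ::N + 1, ::N + 1] = a
    return out

a = np.random.rand(32, 18, 16, 16)
print(insert_zeros_last2(a).shape)  # (32, 18, 31, 31)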

So my question is: is there another implementation with the same functionality but with reduced runtime?

  • Can you guarantee that your matrix will have exactly this shape? Will the number of zero-rows/zero-columns being inserted always be the same? Commented May 19, 2022 at 19:17
  • There are a few differently shaped matrices for which I need to do this, but these few shapes repeat all the time and the zero rows and cols are always inserted the same way. Commented May 19, 2022 at 19:41
  • I tried setting up the indexes for "advanced indexing" in advance, but that doesn't seem to save any time. I haven't found anything faster than the answer you found Commented May 19, 2022 at 19:43

4 Answers


I am not familiar with PyTorch, but to accelerate the code that you provided, I think the JAX library will help a lot. For example:

import numpy as np
import jax
import jax.numpy as jnp
from functools import partial

a = np.arange(10000).reshape(100, 100)
b = jnp.array(a)

@partial(jax.jit, static_argnums=1)
def new(a, N):
    out = jnp.zeros((N + 1) * np.array(a.shape) - N, dtype=a.dtype)
    out = out.at[::N + 1, ::N + 1].set(a)
    return out

This will improve the runtime by about 10 times on a GPU. It depends on the array size and N (the larger the sizes, the better the performance). You can see benchmarks in my Colab link based on the 4 answers proposed so far (JAX beats the others).
I believe JAX can be one of the best libraries for your case if you can adapt it to your problem (it is possible).
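A minimal usage/timing sketch (my addition, not from the Colab notebook): JAX dispatches asynchronously, so call block_until_ready() when benchmarking, and remember that the first call for a given shape and N pays the compilation cost.

out = new(b, 1)          # first call compiles for this shape and N
out.block_until_ready()

%timeit new(b, 1).block_until_ready()  # subsequent calls reuse the compiled kernel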


1 Comment

Of course JAX beats the others; it is amazing. That being said, any numpy-based answer can be JAXed.

I found a few methods to achieve this result, and the indexing method seems to be consistently the fastest.

There might be some improvement to be made on the other methods though, because I tried to generalize them from 1D to 2D with an arbitrary number of leading dimensions, and I might not have done it in the best way possible.

Edit: Yet another method using numpy, not faster.

Performance test (CPU):

In [4]: N, C, H, W = 11, 5, 128, 128
   ...: x = torch.rand(N, C, H, W)
   ...: k = 3
   ...:
   ...: x1 = interleave_index(x, k)
   ...: x2 = interleave_view(x, k)
   ...: x3 = interleave_einops(x, k)
   ...: x4 = interleave_convtranspose(x, k)
   ...: x5 = interleave_numpy(x, k)
   ...:
   ...: assert torch.all(x1 == x2)
   ...: assert torch.all(x2 == x3)
   ...: assert torch.all(x3 == x4)
   ...: assert torch.all(x4 == x5)
   ...:
   ...: %timeit interleave_index(x, k)
   ...: %timeit interleave_view(x, k)
   ...: %timeit interleave_einops(x, k)
   ...: %timeit interleave_convtranspose(x, k)
   ...: %timeit interleave_numpy(x, k)

9.51 ms ± 2.21 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
12.6 ms ± 4.98 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
23.3 ms ± 4.19 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
62.5 ms ± 19.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
50.6 ms ± 809 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Performance test (GPU):

(numpy method not tested)

...: ...
...: x = torch.rand(N, C, H, W, device="cuda")
...: ...
260 µs ± 1.92 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
861 µs ± 6.77 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
912 µs ± 14.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
429 µs ± 5.08 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Implementations:

import numpy as np
import torch
import torch.nn.functional as F
import einops


def interleave_index(x, k):
    *cdims, Hin, Win = x.shape
    Hout = (k + 1) * (Hin - 1) + 1
    Wout = (k + 1) * (Win - 1) + 1
    out = x.new_zeros(*cdims, Hout, Wout)
    out[..., :: k + 1, :: k + 1] = x
    return out


def interleave_view(x, k):
    """
    From
    https://discuss.pytorch.org/t/how-to-interleave-two-tensors-along-certain-dimension/11332/4
    """
    *cdims, Hin, Win = x.shape
    Hout = (k + 1) * (Hin - 1) + 1
    Wout = (k + 1) * (Win - 1) + 1
    zeros = [torch.zeros_like(x)] * k
    out = torch.stack([x, *zeros], dim=-1).view(*cdims, Hin, Wout + k)[..., :-k]
    zeros = [torch.zeros_like(out)] * k
    out = torch.stack([out, *zeros], dim=-2).view(*cdims, Hout + k, Wout)[..., :-k, :]
    return out


def interleave_einops(x, k):
    """
    From
    https://discuss.pytorch.org/t/how-to-interleave-two-tensors-along-certain-dimension/11332/6
    """
    zeros = [torch.zeros_like(x)] * k
    out = einops.rearrange([x, *zeros], "t ... h w -> ... h (w t)")[..., :-k]
    zeros = [torch.zeros_like(out)] * k
    out = einops.rearrange([out, *zeros], "t ... h w -> ... (h t) w")[..., :-k, :]
    return out


def interleave_convtranspose(x, k):
    """
    From
    https://github.com/pytorch/pytorch/issues/7911#issuecomment-515493009
    """
    C = x.shape[-3]
    # A per-channel 1x1 identity kernel; the stride of the transposed
    # convolution inserts k zeros between neighbouring input elements.
    weight = x.new_ones(C, 1, 1, 1)
    return F.conv_transpose2d(x, weight=weight, stride=k + 1, groups=C)


def interleave_numpy(x, k):
    """
    From https://stackoverflow.com/a/53179919
    """
    pos = np.repeat(np.arange(1, x.shape[-1]), k)
    out = np.insert(x, pos, 0, axis=-1)
    pos = np.repeat(np.arange(1, x.shape[-2]), k)
    out = np.insert(out, pos, 0, axis=-2)
    return out
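As a quick sanity check on the question's shape (illustrative only, not part of the benchmark above):

x = torch.rand(32, 18, 16, 16)
y = interleave_index(x, 1)
print(y.shape)  # torch.Size([32, 18, 31, 31])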



Since you know the size of the array in advance, the first step in optimizing is to create the out array outside the function. Then, try numba to JIT-compile the function and work in place on the out array. This achieves a 5x speedup over the numpy version you posted.

import numpy as np
from numba import njit

@njit
def insert_zeros_n(a, out, N=1):
    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            out[(N + 1) * i, (N + 1) * j] = a[i, j]

and call it with the specified N and a:

N = 1
a = np.arange(16*16).reshape(16, 16)
out = np.zeros((N + 1) * np.array(a.shape) - N, dtype=a.dtype)
insert_zeros_n(a, out, N)
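If it helps, here is a sketch (my assumption, not benchmarked) of how the same idea might extend to the 4D case from the question, with prange parallelizing the batch dimension:

from numba import njit, prange

@njit(parallel=True)
def insert_zeros_4d(a, out, N=1):
    # Write a into every (N+1)-th row/col of the last two dims, in place.
    B, C, H, W = a.shape
    for b in prange(B):
        for c in range(C):
            for i in range(H):
                for j in range(W):
                    out[b, c, (N + 1) * i, (N + 1) * j] = a[b, c, i, j]

a4 = np.random.rand(32, 18, 16, 16)
out4 = np.zeros((32, 18, (N + 1) * 16 - N, (N + 1) * 16 - N), dtype=a4.dtype)
insert_zeros_4d(a4, out4, N)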

2 Comments

IMO, numba can be useful on smaller arrays and with smaller N, and will get nearly the same performance as the OP's answer (at least for the code as written; perhaps it could be written with numba in a faster scheme).
Sure, if you don't count the array creation inside the function it will speed things up; still, that numba method achieves a 15% speedup over the OP's, with the OP's dims 32x18x16x16, but not 5x faster.

Encapsulated for any N, what about using numpy.kron with 4D inputs,

a = np.arange(1, 19).reshape((1, 2, 3, 3))
print(a)
# array([[[[ 1,  2,  3],
#          [ 4,  5,  6],
#          [ 7,  8,  9]],
# 
#         [[10, 11, 12],
#          [13, 14, 15],
#          [16, 17, 18]]]])


def interleave_kron(a, N=1):
    n = N + 1
    return np.kron(
        a, np.hstack((1, np.zeros(pow(n, 2) - 1))).reshape((1, 1, n, n))
    )[..., :-N, :-N]

where np.hstack((1, np.zeros(pow(n, 2) - 1))).reshape((1, 1, n, n)) could be externalized/defaulted once and for all for the sake of performance, as sketched below.
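For instance, a minimal sketch of precomputing that kernel once (the function names here are illustrative):

def make_kron_kernel(N=1):
    n = N + 1
    kernel = np.zeros((1, 1, n, n))
    kernel[0, 0, 0, 0] = 1
    return kernel

KERNEL_N2 = make_kron_kernel(N=2)  # built once, reused across calls

def interleave_kron_cached(a, kernel=KERNEL_N2, N=2):
    return np.kron(a, kernel)[..., :-N, :-N]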

and then

>>> interleave_kron(a, N=2)
array([[[[ 1.,  0.,  0.,  2.,  0.,  0.,  3.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 4.,  0.,  0.,  5.,  0.,  0.,  6.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 7.,  0.,  0.,  8.,  0.,  0.,  9.]],

        [[10.,  0.,  0., 11.,  0.,  0., 12.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [13.,  0.,  0., 14.,  0.,  0., 15.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [16.,  0.,  0., 17.,  0.,  0., 18.]]]])

?

4 Comments

I don't think it can beat the OP's performance if N > 1 on large arrays. I think c can be np.array([1, *(N**2-1)*[0]]).reshape((N, N)) to be more readable.
@Ali_Sh I timed it and it is faster to instantiate c with np.zeros
What do you mean @paime ? Like this ? See edit.
@keepAlive I meant the straightforward b=np.zeros(n,n); b[0, 0]=1
