
I have been trying to implement a WaveNet. Based on the papers and the implementations I have looked at on GitHub, I have come up with the following:

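```python
from tensorflow.keras import layers

def is_last(seq):  # assumed helper: pairs each item with "is this the final item?"
    return [(i == len(seq) - 1, v) for i, v in enumerate(seq)]

r = layers.Conv1D(64, 1)(layers.Input(shape=(None, 1)))  # assumed: project input to 64 residual channels
s = 0  # running sum of the skip outputs
for i, (last, d) in enumerate(is_last([1, 2, 4, 8, 16, 32, 64, 128, 256] * 4)):
    h = layers.Conv1D(64, 2, dilation_rate=d, padding='causal', activation='tanh', name='h_%d' % i)(r)
    t = layers.Conv1D(64, 2, dilation_rate=d, padding='causal', activation='sigmoid', name='t_%d' % i)(r)
    x = h * t  # gated activation unit
    s = s + layers.Conv1D(256, 1, name='s_%d' % i)(x)  # skip connection
    if not last:
        r = r + layers.Conv1D(64, 1, name='r_%d' % i)(x)  # residual connection: linear 1x1 conv
```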

In this code block, `h` and `t` are the dilated, gated convolutions. `s` is my skip connection, which will eventually have a ReLU applied to it before the post-processing layers. `r` is my residual connection, which is fed into the next layer. What I don't understand is why the 1x1 convolution that is added to `r` has no activation function. I know that two linear layers in a row can be simplified to a single linear layer. Am I missing something here? What is the point of having a linear convolution?
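The "two linear layers collapse into one" observation can be checked numerically. The sketch below uses NumPy (an assumption; the question itself uses Keras) and a hand-written causal convolution laid out like Keras `Conv1D(padding='causal')`. It assumes the residual 1x1 convolution has no bias, which keeps the algebra exact at the zero-padded start of the sequence. Feeding `r + x @ W_r` into the next layer's dilated convolution (before its tanh/sigmoid) gives the same result as convolving `r` and `x` separately, with `x` seen through one merged kernel.

```python
import numpy as np

def causal_conv1d(x, W, b, dilation=1):
    """Causal 1-D convolution, laid out like Keras Conv1D(padding='causal').

    x: (T, C_in), W: (kernel_size, C_in, C_out), b: (C_out,)
    """
    T, k = x.shape[0], W.shape[0]
    out = np.tile(b, (T, 1)).astype(float)
    for t in range(T):
        for j in range(k):
            src = t - (k - 1 - j) * dilation  # tap j looks back (k-1-j)*dilation steps
            if src >= 0:                      # positions before the start are zero padding
                out[t] += x[src] @ W[j]
    return out

rng = np.random.default_rng(0)
T, C = 16, 8                          # small demo sizes (the question uses 64 channels)
r = rng.standard_normal((T, C))       # residual stream entering layer i
x = rng.standard_normal((T, C))       # gated output h * t of layer i
W_r = rng.standard_normal((C, C))     # 1x1 residual convolution, assumed bias-free
W_h = rng.standard_normal((2, C, C))  # next layer's width-2 dilated kernel
b_h = rng.standard_normal(C)
d = 2                                 # dilation of the next layer

# Two linear steps: add the 1x1 projection to r, then apply the dilated convolution.
two_steps = causal_conv1d(r + x @ W_r, W_h, b_h, dilation=d)

# One linear step on x: fold W_r into each tap of the dilated kernel.
W_merged = np.stack([W_r @ W_h[j] for j in range(W_h.shape[0])])
one_step = (causal_conv1d(r, W_h, b_h, dilation=d)
            + causal_conv1d(x, W_merged, np.zeros(C), dilation=d))

print(np.allclose(two_steps, one_step))  # True
```

The merge only holds up to the next layer's pre-activation; the tanh and sigmoid that follow are where the nonlinearity enters.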

chasep255

1 Answer


To me it appears that the residual 1x1 convolution for $r$ is computing 64 channel-wise linear combinations of $x$: each of the 64 output channels is a weighted sum of the channels of $x$ at the same time step. During training, that convolution will 'learn' which linear combinations of the channels of $x$ are useful for each channel of $r$ when applied as a residual connection.
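Because a 1x1 convolution has a kernel of width 1, it multiplies each time step's channel vector by the same matrix and never looks at neighbouring steps. The NumPy sketch below (small placeholder sizes and random weights, not taken from the question) spells out each output channel as a weighted sum of the input channels and checks it against a single matrix multiply.

```python
import numpy as np

rng = np.random.default_rng(1)
T, C_in, C_out = 5, 16, 16                 # small demo sizes (the question uses 64)
x = rng.standard_normal((T, C_in))         # gated activations h * t for one sequence
W = rng.standard_normal((1, C_in, C_out))  # Conv1D kernel layout: (width=1, in, out)
b = rng.standard_normal(C_out)

# Explicit form: output channel c at time t is a weighted sum of input channels at time t only.
explicit = np.empty((T, C_out))
for t in range(T):
    for c in range(C_out):
        explicit[t, c] = sum(x[t, i] * W[0, i, c] for i in range(C_in)) + b[c]

# Compact form: one shared matrix multiply applied to every time step.
compact = x @ W[0] + b

print(np.allclose(explicit, compact))  # True
```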

A common form of residual connection adds $x$ directly to $r$, possibly scaled by a learnable scalar parameter $\alpha$, so that $r = r + x \cdot \alpha$. The linear 1x1 convolution can learn exactly the same thing (a weight matrix of $\alpha I$ and a zero bias), but it can also learn any other linear mapping from the channels of $x$ to the channels of $r$.
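A quick NumPy check of that claim, assuming (as in the question) that $x$ and $r$ both have 64 channels; `alpha` and the per-channel scales are hypothetical values chosen for illustration. A kernel of $\alpha I$ reproduces the scalar residual, and even a diagonal kernel already goes beyond it by giving each channel its own scale.

```python
import numpy as np

rng = np.random.default_rng(2)
T, C = 10, 64                    # 64 channels, as in the question
r = rng.standard_normal((T, C))  # residual stream
x = rng.standard_normal((T, C))  # gated activations h * t
alpha = 0.5                      # hypothetical value of the learnable scalar

# Scalar-scaled residual: r + x * alpha
scalar_residual = r + x * alpha

# Same update from a 1x1 convolution with weight matrix alpha * I and zero bias.
W = alpha * np.eye(C)[np.newaxis]  # Conv1D kernel layout: (width=1, in, out)
b = np.zeros(C)
conv_residual = r + (x @ W[0] + b)
print(np.allclose(scalar_residual, conv_residual))  # True

# A diagonal kernel gives every channel its own scale, which one scalar cannot express.
per_channel = rng.uniform(0.1, 1.0, C)
W_diag = np.diag(per_channel)[np.newaxis]
print(np.allclose(r + x @ W_diag[0], r + x * per_channel))  # True
```

A full (non-diagonal) kernel additionally mixes channels, which is the "more than a scalar" case the 1x1 convolution can learn.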

The lack of a nonlinear activation is most likely a matter of simplicity. The purpose of a residual connection is to carry information across layers that might otherwise be 'lost', and a nonlinearity in the residual path itself would reintroduce the possibility of information loss (think of saturating sigmoids or dying ReLUs). It could also add further sources of instability during training.
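A small numeric illustration of the saturation and dying-ReLU point, using only the Python standard library; the input values are arbitrary examples. A linear path keeps the full difference between two inputs, a sigmoid squashes a large difference to almost nothing (and passes back almost no gradient), and a ReLU maps all negative inputs to the same value.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def relu(z):
    return max(0.0, z)

a, b = 10.0, 20.0  # two clearly different pre-activations

print(b - a)                    # linear path keeps the difference: 10.0
print(sigmoid(b) - sigmoid(a))  # sigmoid squashes it to roughly 4.5e-05
print(relu(-3.0) - relu(-1.0))  # ReLU maps both negatives to 0: 0.0

# Gradient of the sigmoid at z = 20 is s * (1 - s), roughly 2e-09: almost nothing flows back.
s = sigmoid(b)
print(s * (1 - s))
```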

I don't have much experience with WaveNets compared to other network architectures, so take my answer with a grain of salt.

Avelina