Both batch norm and layer norm are common normalization techniques for neural network training.
I am wondering why transformers primarily use layer norm.
It seems that it has been the standard to use batchnorm in CV tasks and layernorm in NLP tasks. The original Attention Is All You Need paper tested only NLP tasks, and thus used layernorm. Even with the rise of transformers in CV applications, layernorm still seems to be the standard choice, so I'm not completely certain about the pros and cons of each. But I do have some personal intuitions -- which I'll admit aren't grounded in theory, but which I'll nevertheless try to elaborate on in the following.
Recall that in batchnorm, the mean and variance statistics used for normalization are calculated across all elements of all instances in a batch, for each feature independently. By "element" and "instance," I mean "word" and "sentence" respectively for an NLP task, and "pixel" and "image" for a CV task. On the other hand, for layernorm, the statistics are calculated across the feature dimension, for each element and instance independently (source). In transformers, it is calculated across all features and all elements, for each instance independently. This illustration from this recent article conveys the difference between batchnorm and layernorm:
(In the case of transformers, where the normalization statistics are calculated across all features and all elements for each instance independently, that would correspond to the entire left face of the cube being colored blue in the image.)
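To make the three cases concrete, here is a minimal PyTorch sketch showing which dimensions the statistics are reduced over; the (batch, seq, features) shapes are just illustrative assumptions, not anything from the papers above:

```python
import torch

# Toy activations: (instances, elements, features) = (batch, seq_len, d_model)
x = torch.randn(4, 10, 16)

# Batch-norm style: one (mean, var) per feature, computed over all elements
# of all instances in the batch
bn_mean = x.mean(dim=(0, 1))                           # shape: (16,)
bn_var  = x.var(dim=(0, 1), unbiased=False)            # shape: (16,)

# Layer-norm style (per element): one (mean, var) per word of each sentence,
# computed across the feature dimension -- this is what nn.LayerNorm(16) does
ln_mean = x.mean(dim=-1, keepdim=True)                 # shape: (4, 10, 1)
ln_var  = x.var(dim=-1, unbiased=False, keepdim=True)  # shape: (4, 10, 1)

# Whole-instance variant described above: one (mean, var) per sentence,
# computed over all elements and all features of that sentence
inst_mean = x.mean(dim=(1, 2), keepdim=True)           # shape: (4, 1, 1)
inst_var  = x.var(dim=(1, 2), unbiased=False, keepdim=True)

# Normalizing with the per-element (layer-norm) statistics:
x_ln = (x - ln_mean) / torch.sqrt(ln_var + 1e-5)
```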
Now onto the reasons why batchnorm is less suitable for NLP tasks. In NLP tasks, sentence length often varies -- thus, with batchnorm, it is unclear what the appropriate normalization constant (the total number of elements to divide by during normalization) should be. Different batches have different normalization constants, which leads to instability over the course of training. According to the paper that provided the image linked above, "statistics of NLP data across the batch dimension exhibit large fluctuations throughout training. This results in instability, if BN is naively implemented." (That paper proposes an improvement on batchnorm for use in transformers, called PowerNorm, which improves performance on NLP tasks compared to either batchnorm or layernorm.)
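As a rough illustration of that fluctuation (a toy sketch with made-up sizes, not the paper's experiment): if sentences are zero-padded to each batch's maximum length, both the number of real elements and the per-feature batch statistics change from batch to batch:

```python
import torch

torch.manual_seed(0)
d_model = 16

for step in range(3):
    # Hypothetical sentence lengths for this batch, padded up to the max length
    lengths = torch.randint(3, 50, (8,)).tolist()
    x = torch.zeros(8, max(lengths), d_model)
    for i, n in enumerate(lengths):
        x[i, :n] = torch.randn(n, d_model)

    # Batch-norm statistic for feature 0, computed over the real (non-pad) elements:
    # the normalization constant (number of elements) differs every batch,
    # so the statistics fluctuate from step to step.
    n_real = sum(lengths)
    mean0 = (x[..., 0].sum() / n_real).item()
    print(f"step {step}: {n_real} real elements, mean of feature 0 = {mean0:.3f}")
```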
Another intuition is that in the past (before transformers), RNN architectures were the norm. Within recurrent layers, it is again unclear how to compute the normalization statistics. (Should you consider previous words that already passed through the recurrent layer?) It is therefore much more straightforward to normalize each word independently of the others in the same sentence. Of course this reason does not apply to transformers, since computation on a word in a transformer has no time dependency on previous words, and thus you can normalize across the sentence dimension too (in the picture above, that would correspond to the entire left face of the cube being colored blue).
It may also be worth checking out instance normalization and group normalization; I'm no expert on either, but apparently each has its merits.
A lesser-known issue with batch norm is how hard it is to parallelize batch-normalized models. Since the statistics create a dependence between elements of the batch, additional synchronization across devices is needed. While this is not an issue for most vision models, which tend to be trained on a small set of devices, transformers really suffer from this problem, as they rely on large-scale setups to counter their quadratic complexity. In this regard, layer norm provides some degree of normalization while incurring no batch-wise dependence.
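To illustrate the point (a sketch using PyTorch's stock modules, not any particular transformer codebase): keeping batch norm exact under data-parallel training typically means converting it to a synchronized variant, which adds an all-reduce of the per-device statistics on every forward pass, whereas layer norm needs no such communication:

```python
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """A typical vision-style block whose normalization statistics depend on the whole batch."""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm = nn.BatchNorm2d(channels)

    def forward(self, x):
        return torch.relu(self.norm(self.conv(x)))

model = ConvBlock(64)

# Under multi-device data parallelism, exact batch statistics require an extra
# synchronization (all-reduce of per-device means/variances) every step, e.g.:
#   model = nn.SyncBatchNorm.convert_sync_batchnorm(model)   # inside a DDP setup

# A layer-norm block has no such cross-device dependence: each token is
# normalized using only its own statistics, so devices never need to
# communicate for normalization.
layer_norm = nn.LayerNorm(64)
tokens = torch.randn(8, 128, 64)   # (batch, seq_len, d_model)
out = layer_norm(tokens)           # purely per-token computation
```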
If you want to normalize over a slice of data that contains all the features of a single sample, independently of how many samples are grouped into the batch being dispatched -> layer norm
For transformers such normalization is efficient, since the relevance (attention) matrix can be computed in one go over all the entities.
And the first answer explains this very well for both modalities (text and image).