
Let $\vec X$ be a vector. The $\vec V = \mathrm{logsoftmax}(\vec{X})$ function is defined as:

$$v_i = \ln\left(\frac{e^{x_i}}{\sum_j e^{x_j}}\right)$$

This function is provided in machine learning numerical packages because computing it naively from the definition is numerically unstable.

Is there a numerically stable implementation of:

$$\ln\left(1 - \frac{e^{x_i}}{\sum_j e^{x_j}}\right)$$

provided in standard packages (e.g. PyTorch, scipy)? What's a good way of computing this?

For example, if $e^{x_i}$ represents the (non-normalized) probability of a label, then this is the log of the probability of an incorrect label.
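
For a concrete (made-up) illustration of the instability, in Julia: with one dominant component the naive computation underflows, even though the true value is finite (about $-99.99$ here).

x = [100.0, 0.0, -5.0]
p = exp.(x) ./ sum(exp.(x))   # softmax; p[1] rounds to exactly 1.0
log.(1 .- p)                  # first entry is -Inf instead of roughly -99.99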

becko

3 Answers


Usually these values are not computed alone: the entire collection of $v_i$ and $\log(1 - \exp(v_i))$ is needed. That changes the analysis of computational effort.

To this end, let

$$\bar x = \log\left(\sum_{j} e^{x_j}\right) = x_k + \log\left(1 + \sum_{j\ne k} e^{x_j-x_k}\right)$$

for any index $k.$ The right hand expression shows how $\bar x - x_k$ can be computed in a numerically stable way when $k$ is the index of the largest argument, for then the argument of the logarithm is between $1$ and $n$ (the number of the $x_i$) and the sum can be accurately computed using exp (especially when the $x_j$ are ordered from smallest to largest in the sum).
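
For instance, a minimal base-Julia sketch of this max-shifted computation (my own function name, not a package API), returning $\bar x - x_k$ together with $x_k$:

# bar_x - x_k = log1p(s), where s = sum over j != k of exp(x_j - x_k) lies in (0, n-1]
function logsumexp_shift(x::AbstractVector)
    xk, k = findmax(x)                                        # x_k = largest argument
    s = sum(exp(x[j] - xk) for j in eachindex(x) if j != k)   # every term lies in (0, 1]
    return log1p(s), xk
end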

The relation

$$\begin{aligned} \log\left(1 - \frac{e^{x_i}}{\sum_{j} e^{x_j}}\right) &= \log\left(\frac{\sum_{j\ne i} e^{x_j}}{\sum_{j} e^{x_j}}\right) \\ &= \log\left(\sum_{j\ne i} e^{x_j}\right) - x_k + \log\left(\frac{e^{x_k}}{\sum_{j} e^{x_j}}\right) \\ &= \log\left(-e^{x_i} + \sum_{j} e^{x_j}\right) - x_k + v_k \\ &= \log\left(1 - e^{x_i - \bar x}\right) + (\bar x - x_k) + v_k \end{aligned}$$

reduces the problem to finding that last logarithm, which is accomplished by applying a log1mexp function to the difference $x_i - \bar x.$

For $n$ arguments $x_i,$ $i=1,2,\ldots, n,$ the total effort to compute all $2n$ values is

  • $n$ computations of the $v_i$ using logsoftmax.

  • One computation of $\bar x - x_k$ (using $n-1$ exponentials and a logarithm).

  • $n$ invocations of log1mexp.
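
Putting those three steps together, here is a minimal base-Julia sketch (my own names, not a package API; log1mexp uses the standard expm1/log1p branch at $-\log 2$):

# log(1 - e^a) for a <= 0, choosing the branch that avoids cancellation
log1mexp(a) = a > -log(2.0) ? log(-expm1(a)) : log1p(-exp(a))

# Returns (v, w) with v_i = logsoftmax(x)_i and w_i = log(1 - softmax(x)_i)
function logsoftmax_and_complement(x::AbstractVector)
    xk, k = findmax(x)
    s = sum(exp(x[j] - xk) for j in eachindex(x) if j != k)
    d = log1p(s)                  # d = bar_x - x_k
    v = (x .- xk) .- d            # v_i = x_i - bar_x
    w = log1mexp.(v)              # w_i = log(1 - e^{v_i})
    return v, w
end

For example, with $x = (-37, 0)$ this returns finite values for both entries of $w$, whereas the direct log.(1 .- exp.(x) ./ sum(exp.(x))) returns $-\infty$ for the dominant component.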

whuber
  • It is possible to get rid of the last $n$ invocations of `log1mexp`. See my answer below. – becko Jun 01 '20 at 21:53

One option is to use the numerically stable log-softmax implementation in combination with a numerically stable $\text{log1m_exp}(x):=\log(1-\exp(x))$ function.

I believe the following is pretty good for $\text{log1m_exp}$:

  • if $x > -0.693147$ (that is, $x > -\log 2$), use $\log(-\text{expm1}(x))$,
  • otherwise use $\text{log1m}(\exp(x))$,

where $\text{log1m}(x):= \text{log1p}(-x)$ and $\text{log1p}(x):=\log(1+x)$ is usually provided as a built-in (e.g. in numpy or base R), as is its inverse $\text{expm1}(x) = \text{log1p}^{-1}(x) = e^x - 1$.
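
As a sketch (the crossover $-0.693147$ is just $\log\tfrac12 = -\log 2$; written in base Julia here, but the numpy or R version is the same one-liner):

# log1m_exp(x) = log(1 - exp(x)) for x <= 0; branch to avoid cancellation on either side
log1m_exp(x) = x > -log(2.0) ? log(-expm1(x)) : log1p(-exp(x))

log1m_exp(-1e-20)   # ≈ -46.05, where log(1 - exp(-1e-20)) gives -Inf
log1m_exp(-40.0)    # ≈ -4.25e-18, where log(1 - exp(-40.0)) gives 0.0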

Björn
  • Yes I thought of that. The one thing I don't like is that it's less efficient than `logsoftmax`, because it involves the additional `log1mexp` call. – becko Jun 01 '20 at 13:32
  • I guess you could also write your own C-function. You are essentially implementing log_sum_exp(x[-i]) - log_sum_exp(x) - see https://mc-stan.org/math/d8/d23/prim_2mat_2fun_2log__sum__exp_8hpp_source.html for a good implementation of log_sum_exp - where the [-i] is meant to indicate the vector without index i. – Björn Jun 01 '20 at 13:55
  • Another idea, which is probably slower than writing your own C function, but perhaps easier, is to use numba. You should be able to just define the above as I indicated and see whether that gets you a speed-up, although it is usually pretty hard to beat a combination of numpy functions (which is essentially what my main solution is, and which is usually pretty good efficiency-wise). I suspect that if you want a speed-up versus my main answer you may have to really go into writing a dedicated and optimized function in a good compiled language. – Björn Jun 01 '20 at 14:01

Here is a possible way to do this, in Julia code (which I think is quite readable even if one does not know Julia):

function log1msoftmax(x::AbstractArray; dims=1)
  m = maximum(x; dims=dims)
  e = exp.(x .- m)
  s = sum(e; dims=dims)
  return log.((s .- e) ./ s)
end

This has the same complexity as a logsoftmax (it does one exp and one log call per entry).

becko
  • I guess there's even more to be discovered in [StatsFuns](https://github.com/JuliaStats/StatsFuns.jl/blob/master/src/basicfuns.jl). – phipsgabler Jun 02 '20 at 09:49
  • This will fail when the $x_i$ are dominated by one value: in that case, the expression `s .- e` will involve catastrophic cancellation of precision and can even result in a $\log(0)$ error. Try it with $x=(-37,0)$ in double precision. – whuber Jun 02 '20 at 10:29
  • @whuber Good catch. Let's see if I can fix it without having additional exp/log calls. – becko Jun 02 '20 at 10:31