Log-Sum-Exp Trick


# Why?

For the POMDP code, we are required to implement our transition kernel $\mathcal{T}$ in the form of `T_cartesian`, which computes it for all state pairs $x_{t+1}, x_{t}$. This approach is advantageous because:

- We can vectorize the computation of all probabilities
- We can avoid a lot of overhead in computing the cdf of the Gaussian and instead just compute a hasty representation of the pdf and normalize

The issue with this basic approach is that, for small values of the variance, underflow can make the rows so small that summing them gives a value of $0$, which then leads to normalized rows of `nan`.
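To make the failure mode concrete, here's a minimal sketch in the spirit of that pdf-and-normalize approach (not the actual `T_cartesian`; the grid, drift, and $\sigma_{w}$ below are invented purely for illustration):

```python
import numpy as np

# Hypothetical discretized state space; the grid, drift, and noise level
# below are made up purely to illustrate the failure mode.
states = np.linspace(0.0, 1.0, 101)          # grid spacing of 0.01
sigma_w = 1e-4                               # very small transition noise
mu = states + 0.005                          # predicted mean lands between grid points

# Vectorized "hasty pdf" over all state pairs (x_{t+1}, x_t): no loops,
# no Gaussian cdf, just the unnormalized density evaluated on the grid.
diff = states[None, :] - mu[:, None]         # (N, N) array of x_{t+1} - mean
unnorm = np.exp(-diff**2 / (2.0 * sigma_w**2))

row_sums = unnorm.sum(axis=1)
print(row_sums.min())                        # 0.0 -- every entry has underflowed
T = unnorm / row_sums[:, None]               # 0 / 0 for the dead rows
print(np.isnan(T).all())                     # True -- rows of nan
```

With a larger $\sigma_{w}$ the nearest grid points keep each row comfortably above zero and the normalization goes through fine; it's exactly the small-variance regime that breaks.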

This is where the Log-Sum-Exp trick comes in. It was first suggested to me by Cursor, which kinda just took a “well duh, this works” approach, sneaking it in without ever naming the concept it was wielding. I sought to record my understanding of the concept, since the person whose blog I learned it from aptly put this quote on their website homepage:

> I learned very early the difference between knowing the name of something and knowing something. - Richard Feynman

So let’s jump into it.

# The Log-Sum-Exp Trick

In general, in ML or stats we will often work in log scale. This lets us handle very small values, e.g. $x, y$, while avoiding issues like underflow: $\log(xy) = \log x + \log y$. Where the LHS could be subject to underflow (due to precomputing $xy$), the $\log$ operation sufficiently (in most cases) blows up the respective individual values so that the RHS is not victim to the same issue.
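A two-line sanity check (the numbers are picked just to force float64 underflow):

```python
import numpy as np

x, y = 1e-200, 1e-200
print(x * y)                  # 0.0 -- the product underflows in float64
print(np.log(x) + np.log(y))  # about -921.0 -- perfectly representable in log scale
```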

Overall, working in log scale is so much more numerically stable that even scipy's `multivariate_normal` uses it to compute pdfs.
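As a toy illustration (not from the POMDP code): a narrow Gaussian evaluated slightly away from its mean already underflows as a pdf, while its log-pdf is an unremarkable number:

```python
from scipy.stats import multivariate_normal

mvn = multivariate_normal(mean=[0.0], cov=[[1e-6]])  # sigma_w^2 = 1e-6
print(mvn.pdf([0.1]))      # 0.0 -- exp(-5000) underflows
print(mvn.logpdf([0.1]))   # about -4994 -- no trouble in log scale
```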

Sometimes, though, we need to convert our log probabilities (denoted $x_{i}$) to real probabilities (denoted $p_{i}$). One may naively compute the following:

$$p_{i}=\frac{\exp(x_{i})}{\sum_{j}\exp(x_{j})},\quad \sum_{n=1}^{N}p_{n}=1 \tag{1}$$

These $x_{i}$'s individually may be very large in magnitude, where specifically in our application:

- They may be “large” due to a really small $\sigma_{w}^{2}$ in $\frac{-(x_{i}-\mu)^{2}}{2\sigma_{w}^{2}}$
- But tininess doesn't seem to be too much of an issue, since if we're going over all of $\mathbb{X}_{n}$ then there will be an $x\in \mathbb{X}_{n}$ s.t. $x_{i}-x$ is not too small. Since the term is negative, we get really tiny exponents for pretty much everything unless we have a sufficiently large distance in the numerator to combat this issue.

Enter the Log-Sum-Exp operation.

>[!def] LogSumExp
>Given a list of values $\{ x_{1},\dots, x_{N} \}$, the LogSumExp of these is defined as
>$$\text{LSE}(x_{1},\dots,x_{N})=\log \sum_{n=1}^{N}\exp(x_{n})$$
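Before deriving the fix, here's the naive $(1)$ falling over on some made-up log probabilities of the magnitude a tiny $\sigma_{w}^{2}$ produces:

```python
import numpy as np

# Hypothetical log probabilities -- hugely negative, as with a tiny sigma_w^2.
x = np.array([-1200.0, -1210.0, -1230.0])

p_naive = np.exp(x) / np.exp(x).sum()  # every exp(x_i) underflows to 0.0, so this is 0/0
print(p_naive)                         # [nan nan nan]
```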

Now let us rewrite the naive $(1)$:

$$
\begin{align*}
\exp(x_{i}) &= p_{i}\sum_{n=1}^{N}\exp(x_{n})\\
x_{i} &= \log p_{i}+\log \sum_{n=1}^{N}\exp(x_{n})\\
\log p_{i} &= x_{i}-\log \sum_{n=1}^{N}\exp(x_{n})\\
p_{i} &= \exp\left( x_{i}-\log \sum_{n=1}^{N}\exp(x_{n}) \right)\\
p_{i} &= \exp(x_{i}-\text{LSE}(x_{1},\dots,x_{N}))
\end{align*}
$$

Still, you may observe that we're dealing with individual $\exp(x_{n})$ values inside the LSE, which brings us back to our original problem. Well, consider the following:

$$
\begin{align*}
y &= \log \sum_{n=1}^{N}\exp(x_{n})\\
e^{y} &= \sum_{n=1}^{N}\exp(x_{n})\\
e^{y} &= e^{c}\sum_{n=1}^{N}\exp(x_{n}-c)\\
y &= c+\log \sum_{n=1}^{N}\exp(x_{n}-c).
\end{align*}
$$

Let $c=\max\{ x_{1},\dots,x_{N} \}$; then

$$\max_{1\le n\le N}\exp(x_{n}-c)=e^{0}=1,\quad \min_{1\le n \le N}\exp(x_{n}-c)\ge 0,$$

so the sum inside the $\log$ always contains a term equal to $1$ and can never underflow to $0$. Now, using the following inequality from Wikipedia:

$$\max\{ x_{1},\dots,x_{N} \}\le \text{LSE}(x_{1},\dots,x_{N})\le \max \{ x_{1},\dots,x_{N} \}+\log N$$

we get, for the largest $x_{i}$ (i.e. $x_{i}=c$),

$$-\log N\le x_{i}-\text{LSE}(x_{1},\dots,x_{N})\le 0$$

(the upper bound $x_{i}-\text{LSE}\le 0$ actually holds for every $i$), hence

$$\frac{1}{N}\le \exp(x_{i}-\text{LSE}(x_{1},\dots,x_{N}))\le 1$$

which is nice! This means the largest $p_{i}$ never underflows, so the probabilities can no longer all collapse to $0$, and since every exponent is $\le 0$, nothing overflows either.
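Putting it together, here's a minimal sketch of the shifted computation; `scipy.special.logsumexp` implements the same max-shift, so it's easy to check against:

```python
import numpy as np
from scipy.special import logsumexp

def exp_normalize(x):
    """Stable p_i = exp(x_i - LSE(x_1, ..., x_N)) via the max-shift trick."""
    c = np.max(x)                            # c = max{x_1, ..., x_N}
    lse = c + np.log(np.sum(np.exp(x - c)))  # largest term in the sum is exp(0) = 1
    return np.exp(x - lse)                   # every exponent is <= 0, so nothing overflows

x = np.array([-1200.0, -1210.0, -1230.0])    # the same log probabilities as before
p = exp_normalize(x)
print(p, p.sum())                            # roughly [1.0, 4.5e-05, 9.4e-14], summing to 1.0
print(np.allclose(p, np.exp(x - logsumexp(x))))  # True -- matches scipy
```

Applying the same shift row-wise (e.g. `logsumexp(log_T, axis=1, keepdims=True)` on a matrix of log densities) is what rescues the rows that were collapsing to `nan` in the transition-kernel setting above.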