fast and stable x * tanh(log1pexp(x)) computation
Problem
$$f(x) = x \tanh(\log(1 + e^x))$$
The function (mish activation) can be easily implemented using a stable log1pexp without any significant loss of precision. Unfortunately, this is computationally heavy.
Is it possible to write a more direct numerically stable implementation which is faster?
Accuracy as good as x * std::tanh(std::log1p(std::exp(x))) would be nice. There are no strict constraints, but it should be reasonably accurate for use in neural networks. The distribution of inputs spans $(-\infty, \infty)$; it should work everywhere.
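For reference, here is a minimal sketch of the straightforward implementation the question refers to, using a numerically stable log1pexp; the overflow cutoff of 20 is an illustrative choice, not taken from the question:
#include <cmath>

// Stable log(1 + e^x): for large x, 1 + e^x is dominated by e^x,
// so log1pexp(x) equals x to within float precision. The branch
// avoids overflow in std::exp for large positive inputs.
float log1pexp(float x)
{
    if (x > 20.0f)  // illustrative cutoff: the log1p correction is below 1 ulp here
        return x;
    return std::log1p(std::exp(x));
}

float mish_reference(float x)
{
    return x * std::tanh(log1pexp(x));
}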
Solution
With some algebraic manipulation (as pointed out in @orlp's answer), we can deduce the following:
$$f(x) = x \tanh(\log(1+e^x)) \tag{1}$$
$$ = x\frac{(1+e^x)^2 - 1}{(1+e^x)^2 + 1} = x\frac{e^{2x} + 2e^x}{e^{2x} + 2e^x + 2}\tag{2}$$
$$ = x - \frac{2x}{(1 + e^x)^2 + 1} \tag{3}$$
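The step from $(1)$ to $(2)$ uses $\tanh y = \frac{e^{2y} - 1}{e^{2y} + 1}$ with $y = \log(1 + e^x)$, so that $e^{2y} = (1 + e^x)^2$. Expression $(3)$ then follows by writing $\frac{n}{n + 2} = 1 - \frac{2}{n + 2}$ with $n = (1 + e^x)^2 - 1 = e^{2x} + 2e^x$.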
Expression $(3)$ works great when $x$ is positive, with very little loss of precision. Expression $(2)$ is not suitable for large positive values of $x$, since both the numerator and the denominator blow up.
The function $(1)$ asymptotically approaches zero as $x \to -\infty$. As $x$ grows in magnitude in the negative direction, expression $(3)$ suffers from catastrophic cancellation: two nearly equal terms cancel to leave a very small number. Expression $(2)$ is more suitable in this range.
Expression $(2)$ works fairly well down to about $-18$; beyond that, multiple significant figures are lost.
Let's take a closer look at the function and try to approximate $f(x)$ as $x \to-\infty$.
$$f(x) = x \frac{e^{2x} + 2e^x}{e^{2x} + 2e^x + 2}$$
For such $x$, $e^{2x}$ is orders of magnitude smaller than $e^x$, and $e^x$ is orders of magnitude smaller than $1$. Using these two facts, we can approximate $f(x)$ by:
$$f(x) \approx x\frac{e^x}{e^x+1} \approx xe^x$$
(At $x = -18$, $e^x \approx 1.5 \times 10^{-8}$, already below single-precision machine epsilon of about $1.2 \times 10^{-7}$, so the dropped terms are invisible in float arithmetic.)
Result:
$$f(x) \approx \begin{cases}
xe^x, & \text{if } x \le -18 \\
x\frac{e^{2x} + 2e^x}{e^{2x} + 2e^x + 2}, & \text{if } -18 < x \le -0.6 \\
x - \frac{2x}{(1 + e^x)^2 + 1}, & \text{otherwise}
\end{cases}$$
Fast CUDA implementation:
__device__ float mish(float x)
{
    // fast hardware exponential intrinsic
    auto e = __expf(x);

    // asymptotic regime: f(x) ~ x * e^x
    if (x <= -18.0f)
        return x * e;

    // n = e^(2x) + 2e^x, shared by the remaining branches
    auto n = e * e + 2 * e;
    if (x <= -0.6f)
        return x * __fdividef(n, n + 2);   // expression (2)
    return x - 2 * __fdividef(x, n + 2);   // expression (3)
}
EDIT:
An even faster and more accurate version:
$$f(x) \approx \begin{cases}
x\frac{e^{2x} + 2e^x}{e^{2x} + 2e^x + 2}, & \text{if } x \le -0.6 \\
x - \frac{2x}{(1 + e^x)^2 + 1}, & \text{otherwise}
\end{cases}$$
__device__ float mish(float x)
{
    auto e = __expf(x);
    auto n = e * e + 2 * e;                // n = e^(2x) + 2e^x
    if (x <= -0.6f)
        return x * __fdividef(n, n + 2);   // expression (2)
    return x - 2 * __fdividef(x, n + 2);   // expression (3)
}
Code: https://gist.github.com/YashasSamaga/8ad0cd3b30dbd0eb588c1f4c035db28c
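For context, here is a sketch of how the device function might be applied elementwise; the kernel name and launch configuration are illustrative, not taken from the gist:
__global__ void mish_kernel(const float* __restrict__ input,
                            float* __restrict__ output, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        output[i] = mish(input[i]);
}

// one thread per element, e.g.:
// mish_kernel<<<(n + 255) / 256, 256>>>(d_input, d_output, n);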
Benchmarks with other implementations: https://github.com/YashasSamaga/ConvolutionBuildingBlocks/tree/master/mish
$$\begin{array}{c|c|c|c}
\text{Kernel} & \text{Time (float)} & \text{Time (float4)} & \text{L2 norm of error vector} \\ \hline
\text{mish} & 1.49\,\text{ms} & 1.39\,\text{ms} & 2.4583 \times 10^{-5} \\ \hline
\text{relu} & 1.47\,\text{ms} & 1.39\,\text{ms} & \text{N/A} \\ \hline
\end{array}$$
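An L2 error norm like the one above can be estimated on the host by comparing against a double-precision reference; the sketch below substitutes plain std::exp and ordinary division for the device-only intrinsics, and the input grid is an arbitrary choice:
#include <cmath>
#include <cstdio>

// double-precision reference: x * tanh(log(1 + e^x))
double mish_ref(double x)
{
    return x * std::tanh(std::log1p(std::exp(x)));
}

// host stand-in for the fast version (__expf/__fdividef are device-only)
float mish_fast(float x)
{
    float e = std::exp(x);
    float n = e * e + 2.0f * e;
    if (x <= -0.6f)
        return x * (n / (n + 2.0f));
    return x - 2.0f * (x / (n + 2.0f));
}

int main()
{
    double l2 = 0.0;
    for (int i = 0; i <= 200000; ++i) {
        float x = -100.0f + i * 0.001f;   // grid over [-100, 100]
        double d = (double)mish_fast(x) - mish_ref((double)x);
        l2 += d * d;
    }
    std::printf("L2 norm of error vector: %g\n", std::sqrt(l2));
    return 0;
}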
Context
StackExchange Computer Science Q#125002, answer score: 6