forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ActivationGeluKernel.cu
88 lines (76 loc) · 3.77 KB
/
ActivationGeluKernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#define TORCH_ASSERT_NO_OPERATORS
#define _USE_MATH_DEFINES
#include <ATen/native/Activation.h>
#include <cmath>
#include <thrust/tuple.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/cuda/ApplyGridUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/Loops.cuh>
namespace at::native {
// Forward GELU on CUDA.
//
// Dispatches on it.dtype() over the floating types plus Half and BFloat16 and
// hands a per-element device lambda to gpu_kernel together with the iterator.
// Arithmetic is carried out in at::opmath_type<scalar_t>; the value is
// converted back to scalar_t when the lambda returns.
//
//   approximate == GeluType::Tanh:
//       0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
//   any other value:
//       x * 0.5 * (1 + erf(x / sqrt(2)))
void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) {
  if (approximate == GeluType::Tanh) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        it.dtype(),
        "GeluCUDAKernelImpl",
        [&]() {
          gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
            using opmath_t = at::opmath_type<scalar_t>;
            // sqrt(2) * (2 / sqrt(pi)) * 0.5 == sqrt(2 / pi)
            constexpr opmath_t kSqrt2OverPi =
                M_SQRT2 * M_2_SQRTPI * opmath_t(0.5);
            constexpr opmath_t kCubicCoeff = 0.044715;

            const opmath_t xf = static_cast<opmath_t>(x);
            const auto cubed = xf * xf * xf;
            const auto tanh_arg = kSqrt2OverPi * (xf + kCubicCoeff * cubed);
            return opmath_t(0.5) * xf *
                (opmath_t(1) + c10::cuda::compat::tanh(tanh_arg));
          });
        });
    return;
  }

  // Erf-based formulation.
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      it.dtype(),
      "GeluCUDAKernelImpl",
      [&]() {
        gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
          using opmath_t = at::opmath_type<scalar_t>;
          // 1 / sqrt(2)
          constexpr opmath_t kInvSqrt2 = M_SQRT1_2;

          const opmath_t xf = static_cast<opmath_t>(x);
          return xf * opmath_t(0.5) *
              (opmath_t(1) + ::erf(xf * kInvSqrt2));
        });
      });
}
// Backward GELU on CUDA.
//
// Dispatches on it.dtype() over the floating types plus Half and BFloat16 and
// hands a two-input device lambda (dy, x) to gpu_kernel together with the
// iterator. The lambda returns dy multiplied by the derivative of the selected
// GELU formulation evaluated at x. Arithmetic is carried out in
// at::opmath_type<scalar_t>; the value is converted back to scalar_t when the
// lambda returns.
//
//   approximate == GeluType::Tanh:
//       with u = sqrt(2/pi) * (x + 0.044715 * x^3) and t = tanh(u):
//       d/dx = 0.5 * (1 + t) + 0.5 * x * (1 - t^2) * du/dx
//   any other value:
//       d/dx = cdf(x) + x * pdf(x), where
//       cdf(x) = 0.5 * (1 + erf(x / sqrt(2))) and
//       pdf(x) = exp(-0.5 * x^2) / sqrt(2 * pi)
void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) {
  if (approximate == GeluType::Tanh) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        it.dtype(),
        "GeluBackwardCUDAKernelImpl",
        [&]() {
          gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
            using opmath_t = at::opmath_type<scalar_t>;
            // sqrt(2) * (2 / sqrt(pi)) * 0.5 == sqrt(2 / pi)
            constexpr opmath_t kSqrt2OverPi =
                M_SQRT2 * M_2_SQRTPI * opmath_t(0.5);
            constexpr opmath_t kCubicCoeff = 0.044715;

            const opmath_t xf = static_cast<opmath_t>(x);
            const auto squared = xf * xf;
            const auto cubed = squared * xf;

            // t = tanh(u), u = sqrt(2/pi) * (x + 0.044715 * x^3)
            const auto t =
                c10::cuda::compat::tanh(kSqrt2OverPi * (xf + kCubicCoeff * cubed));

            // Product rule over (0.5 * x) * (1 + t).
            // First term: derivative of 0.5 * x, times (1 + t).
            const auto grad_from_half_x = opmath_t(0.5) * (opmath_t(1) + t);

            // Second term: (0.5 * x) times the derivative of (1 + t),
            // i.e. (1 - t^2) * du/dx.
            const auto dtanh = opmath_t(1) - t * t;
            const auto du_dx =
                kSqrt2OverPi * (opmath_t(1) + opmath_t(3) * kCubicCoeff * squared);
            const auto grad_from_tanh = opmath_t(0.5) * xf * dtanh * du_dx;

            return static_cast<opmath_t>(dy) *
                (grad_from_half_x + grad_from_tanh);
          });
        });
    return;
  }

  // Erf-based formulation.
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      it.dtype(),
      "GeluBackwardCUDAKernelImpl",
      [&]() {
        gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
          using opmath_t = at::opmath_type<scalar_t>;
          // (2 / sqrt(pi)) * (1 / sqrt(2)) * 0.5 == 1 / sqrt(2 * pi)
          constexpr opmath_t kInvSqrt2Pi =
              M_2_SQRTPI * M_SQRT1_2 * opmath_t(0.5);
          // 1 / sqrt(2)
          constexpr opmath_t kInvSqrt2 = M_SQRT1_2;

          const opmath_t xf = static_cast<opmath_t>(x);
          const opmath_t normal_cdf =
              opmath_t(0.5) * (opmath_t(1) + ::erf(xf * kInvSqrt2));
          const opmath_t normal_pdf =
              c10::cuda::compat::exp(opmath_t(-0.5) * xf * xf) * kInvSqrt2Pi;

          return static_cast<opmath_t>(dy) * (normal_cdf + xf * normal_pdf);
        });
      });
}
} // namespace at::native