16 #ifndef EIGEN_BFLOAT16_H
17 #define EIGEN_BFLOAT16_H
// Helper macro that defines METHOD<PACKET_BF16> on top of the float-packet
// implementation: the argument is converted with Bf16ToF32, METHOD<PACKET_F>
// is applied to it, and the result is converted back with F32ToBf16.
19 #define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) \
21 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \
22 PACKET_BF16 METHOD<PACKET_BF16>(const PACKET_BF16& _x) { \
23 return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x))); \
30 namespace bfloat16_impl {
// Wraps a raw 16-bit pattern: `raw` is stored in the `value` member unchanged.
// Marked explicit, so an integer never converts to __bfloat16_raw implicitly.
35 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
__bfloat16_raw(
unsigned short raw) : value(raw) {}
// Forward declaration; the definition later in this file simply constructs a
// __bfloat16_raw from the given 16-bit pattern.
39 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
__bfloat16_raw raw_uint16_to_bfloat16(
unsigned short value);
// Forward declarations of the float -> bfloat16 conversion.
// The boolean template argument selects the variant: the <false> definition
// later in this file screens NaN and very small inputs itself and then
// delegates to <true>, which performs the biased add-and-shift rounding
// without those checks.
40 template <
bool AssumeArgumentIsNormalOrInfinityOrZero>
41 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
__bfloat16_raw float_to_bfloat16_rtne(
float ff);
// Variant that makes no assumption about the argument.
45 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
__bfloat16_raw float_to_bfloat16_rtne<false>(
float ff);
// Variant for callers that already know the argument needs no special-casing.
47 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
__bfloat16_raw float_to_bfloat16_rtne<true>(
float ff);
// Forward declaration of the bfloat16 -> float conversion; takes the raw
// representation by value and returns a float.
48 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
float bfloat16_to_float(
__bfloat16_raw h);
// Default constructor with an empty body: it performs no work of its own.
// NOTE(review): whether the stored bits end up zeroed depends on how the base
// or member is initialised elsewhere -- confirm before relying on it.
62 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
bfloat16() {}
// Converting constructors. Every one of them is explicit, so a bfloat16 is
// never created by an implicit conversion from these argument types.
//
// From bool.
66 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
bfloat16(
bool b)
// From a generic value of type T.
70 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
bfloat16(
const T& val)
// From float. Unlike its siblings this overload is not marked
// EIGEN_CONSTEXPR.
73 explicit EIGEN_DEVICE_FUNC
bfloat16(
float f)
// From a std::complex value.
// NOTE(review): presumably only the real part is kept -- confirm against the
// initialiser.
78 template<
typename RealScalar>
79 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
bfloat16(
const std::complex<RealScalar>& val)
// Conversion to float, delegated to bfloat16_impl::bfloat16_to_float.
// The operator is not declared explicit, so bfloat16 values convert to float
// implicitly wherever a float is expected.
82 EIGEN_DEVICE_FUNC
operator float()
const {
83 return bfloat16_impl::bfloat16_to_float(*
this);
// std::numeric_limits specialisation describing Eigen::bfloat16.
// The precision-related constants (digits, digits10, max_digits10) are
// hard-coded; round_style, the exponent range, traps and tinyness_before are
// copied from numeric_limits<float>.
90 struct numeric_limits<
Eigen::bfloat16> {
91 static const bool is_specialized =
true;
92 static const bool is_signed =
true;
93 static const bool is_integer =
false;
94 static const bool is_exact =
false;
95 static const bool has_infinity =
true;
96 static const bool has_quiet_NaN =
true;
97 static const bool has_signaling_NaN =
true;
// NOTE(review): denormals are advertised as absent here, yet denorm_min() in
// this same specialisation returns the raw pattern 0x0001 instead of min().
// Confirm which of the two reflects the intended behaviour.
98 static const float_denorm_style has_denorm = std::denorm_absent;
99 static const bool has_denorm_loss =
false;
100 static const std::float_round_style round_style = numeric_limits<float>::round_style;
101 static const bool is_iec559 =
false;
102 static const bool is_bounded =
true;
103 static const bool is_modulo =
false;
// Significand precision in bits, and the derived decimal digit counts.
104 static const int digits = 8;
105 static const int digits10 = 2;
106 static const int max_digits10 = 4;
107 static const int radix = 2;
// Exponent limits are shared with float.
108 static const int min_exponent = numeric_limits<float>::min_exponent;
109 static const int min_exponent10 = numeric_limits<float>::min_exponent10;
110 static const int max_exponent = numeric_limits<float>::max_exponent;
111 static const int max_exponent10 = numeric_limits<float>::max_exponent10;
112 static const bool traps = numeric_limits<float>::traps;
113 static const bool tinyness_before = numeric_limits<float>::tinyness_before;
// Each accessor below wraps a fixed 16-bit pattern with
// raw_uint16_to_bfloat16. The per-constant remarks assume the layout implied
// by float_to_bfloat16_rtne<true> (bfloat16 == upper 16 bits of the float bit
// pattern: 1 sign bit, 8 exponent bits, 7 mantissa bits) -- TODO confirm.
//
// 0x0080: exponent field 1, mantissa 0. The name is parenthesised, presumably
// to keep a function-like `min` macro from expanding.
115 static Eigen::bfloat16 (min)() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
// 0xff7f: same pattern as (max)() with the sign bit set.
116 static Eigen::bfloat16 lowest() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
// 0x7f7f: exponent field 0xfe, mantissa all ones.
117 static Eigen::bfloat16 (max)() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
// 0x3c00: exponent field 0x78 (i.e. 2^-7), mantissa 0.
118 static Eigen::bfloat16 epsilon() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
// 0x7f80: exponent field all ones, mantissa 0.
120 static Eigen::bfloat16 infinity() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
// 0x7fc0: exponent field all ones, top mantissa bit set.
121 static Eigen::bfloat16 quiet_NaN() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
// 0x7f81: exponent field all ones, top mantissa bit clear, lowest bit set.
122 static Eigen::bfloat16 signaling_NaN() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f81); }
// 0x0001: exponent field 0, lowest mantissa bit set.
// NOTE(review): has_denorm in this specialisation is std::denorm_absent, which
// does not match returning a subnormal pattern here -- confirm.
123 static Eigen::bfloat16 denorm_min() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
// cv-qualified variants: each one inherits every member from
// numeric_limits<Eigen::bfloat16> and adds nothing of its own.
131 struct numeric_limits<const
Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
133 struct numeric_limits<volatile
Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
135 struct numeric_limits<const volatile
Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
140 namespace bfloat16_impl {
// The operators that follow are compiled when no native bfloat16 support is
// available, or when the compiler is clang and not nvcc.
145 #if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for bfloat16 floats
// Under clang compiling CUDA code, EIGEN_DEVICE_FUNC is saved and temporarily
// redefined for the emulated operators.
147 #if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
149 #pragma push_macro("EIGEN_DEVICE_FUNC")
150 #undef EIGEN_DEVICE_FUNC
// With both CUDA and native bf16 available only the host side gets the
// emulated operators; otherwise they are emitted for host and device.
151 #if defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_NATIVE_BF16)
152 #define EIGEN_DEVICE_FUNC __host__
153 #else // both host and device need emulated ops.
154 #define EIGEN_DEVICE_FUNC __host__ __device__
// Emulated binary arithmetic. Every operator converts its operands to float,
// performs the operation in float, and converts the result back through the
// explicit bfloat16(float) constructor.
//
// bfloat16 + bfloat16.
161 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (
const bfloat16& a,
const bfloat16& b) {
162 return bfloat16(
float(a) +
float(b));
// bfloat16 + int: the integer operand is converted to float first.
164 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (
const bfloat16& a,
const int& b) {
165 return bfloat16(
float(a) +
static_cast<float>(b));
// int + bfloat16: mirror image of the overload above.
167 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (
const int& a,
const bfloat16& b) {
168 return bfloat16(
static_cast<float>(a) +
float(b));
// bfloat16 * bfloat16.
170 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
operator * (
const bfloat16& a,
const bfloat16& b) {
171 return bfloat16(
float(a) *
float(b));
// bfloat16 - bfloat16.
173 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (
const bfloat16& a,
const bfloat16& b) {
174 return bfloat16(
float(a) -
float(b));
// bfloat16 / bfloat16.
176 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (
const bfloat16& a,
const bfloat16& b) {
177 return bfloat16(
float(a) /
float(b));
// Unary minus. Works directly on the stored bits: bit 15 (mask 0x8000, the
// same bit used for signed zero elsewhere in this file) is toggled with XOR,
// so no conversion through float takes place.
179 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (
const bfloat16& a) {
181 result.value = a.value ^ 0x8000;
// Compound assignment operators. Each one evaluates the operation in float
// and assigns the converted result back to the left-hand operand `a`; the
// declared return type is a reference to bfloat16.
//
// a += b
184 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator += (bfloat16& a,
const bfloat16& b) {
185 a = bfloat16(
float(a) +
float(b));
// a *= b
188 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator *= (bfloat16& a,
const bfloat16& b) {
189 a = bfloat16(
float(a) *
float(b));
// a -= b
192 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator -= (bfloat16& a,
const bfloat16& b) {
193 a = bfloat16(
float(a) -
float(b));
// a /= b
196 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator /= (bfloat16& a,
const bfloat16& b) {
197 a = bfloat16(
float(a) /
float(b));
// Prefix increment / decrement. Note that both are declared to return
// bfloat16 by value rather than a reference to the operand.
200 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
204 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
// Postfix increment: the value held before the update is copied and that
// copy is what gets returned.
208 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a,
int) {
209 bfloat16 original_value = a;
211 return original_value;
// Postfix decrement: same pattern as postfix increment.
213 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a,
int) {
214 bfloat16 original_value = a;
216 return original_value;
// Comparison operators. All of them compare the float conversions of the two
// operands, never the raw bit patterns.
//
// Equality goes through numext::equal_strict.
218 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator == (
const bfloat16& a,
const bfloat16& b) {
219 return numext::equal_strict(
float(a),
float(b));
// Inequality goes through numext::not_equal_strict.
221 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator != (
const bfloat16& a,
const bfloat16& b) {
222 return numext::not_equal_strict(
float(a),
float(b));
// The ordering operators use the built-in float comparisons directly.
224 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator < (
const bfloat16& a,
const bfloat16& b) {
225 return float(a) < float(b);
227 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator <= (
const bfloat16& a,
const bfloat16& b) {
228 return float(a) <= float(b);
230 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator > (
const bfloat16& a,
const bfloat16& b) {
231 return float(a) > float(b);
233 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator >= (
const bfloat16& a,
const bfloat16& b) {
234 return float(a) >= float(b);
// Restore the EIGEN_DEVICE_FUNC definition that was saved with push_macro
// before the emulated operators.
237 #if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
238 #pragma pop_macro("EIGEN_DEVICE_FUNC")
240 #endif // Emulate support for bfloat16 floats
// Division of a bfloat16 by an Index. Both operands are converted to float
// with static_cast, divided, and the quotient is converted back to bfloat16.
// This overload sits after the end of the emulation-only section.
244 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (
const bfloat16& a,
Index b) {
245 return bfloat16(
static_cast<float>(a) /
static_cast<float>(b));
// Converts a float to the raw bfloat16 representation with special handling
// for two input classes before the general path.
248 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(
const float v) {
249 __bfloat16_raw output;
// NaN input: emit a fixed NaN pattern that keeps only the sign of v.
250 if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
251 output.value = std::signbit(v) ? 0xFFC0: 0x7FC0;
253 }
// Magnitude below numeric_limits<float>::min(): flush to a signed zero.
else if (std::fabs(v) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
255 output.value = std::signbit(v) ? 0x8000 : 0;
// General path: view the storage of v as 16-bit words; the byte-order check
// below decides which word is used (presumably the high-order half).
// NOTE(review): reading a float through a uint16_t pointer is type punning;
// float_to_bfloat16_rtne<true> uses numext::bit_cast for the same purpose.
258 const uint16_t* p =
reinterpret_cast<const uint16_t*
>(&v);
259 #if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
// Builds a __bfloat16_raw holding exactly the given 16-bit pattern.
// NOTE(review): the earlier declaration spells the parameter type as
// `unsigned short` while this definition uses numext::uint16_t; presumably the
// two name the same type -- confirm.
267 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) {
268 return __bfloat16_raw(value);
// Inverse direction: takes a raw bfloat16 and yields a numext::uint16_t.
271 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(
const __bfloat16_raw& bf) {
// float -> bfloat16 conversion for arbitrary input. NaN and very small values
// are handled here; everything else is forwarded to the <true> variant.
278 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(
float ff) {
// NOTE(review): this branch is taken only when BOTH the CUDA and the HIP
// macro are defined -- confirm that `&&` (rather than `||`) is intended.
279 #if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
282 __bfloat16_raw output;
// NaN input: emit a fixed NaN pattern that keeps only the sign of ff.
284 if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) {
290 output.value = std::signbit(ff) ? 0xFFC0: 0x7FC0;
291 }
// Magnitude below numeric_limits<float>::min(): flush to a signed zero.
else if (std::fabs(ff) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
293 output.value = std::signbit(ff) ? 0x8000 : 0;
// Remaining inputs take the rounding path of the <true> variant.
444 output = float_to_bfloat16_rtne<true>(ff);
// float -> bfloat16 conversion without NaN / small-value screening.
// The result is the upper 16 bits of the float bit pattern after adding a
// rounding bias.
455 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(
float ff) {
// NOTE(review): requires both the CUDA and the HIP macro to be defined --
// confirm that `&&` is intended.
456 #if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
// Reinterpret the float as its 32-bit pattern.
459 numext::uint32_t input = numext::bit_cast<numext::uint32_t>(ff);
460 __bfloat16_raw output;
// lsb is the lowest bit that survives the shift. Adding 0x7fff + lsb carries
// into the kept half when the discarded 16 bits exceed 0x8000, or equal
// 0x8000 while lsb is 1 -- i.e. round to nearest, ties to even.
463 numext::uint32_t lsb = (input >> 16) & 1;
464 numext::uint32_t rounding_bias = 0x7fff + lsb;
465 input += rounding_bias;
// Keep the upper 16 bits of the biased pattern.
466 output.value =
static_cast<numext::uint16_t
>(input >> 16);
// bfloat16 -> float conversion. The float result is accessed as 16-bit words
// through `q`; the byte-order check decides which word receives the bfloat16
// bits (presumably the high-order half, the other being zero -- confirm).
471 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
float bfloat16_to_float(__bfloat16_raw h) {
// NOTE(review): writing a float through an unsigned short pointer is type
// punning via reinterpret_cast.
473 unsigned short* q =
reinterpret_cast<unsigned short*
>(&result);
474 #if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
// Classification helpers. The function names are parenthesised, presumably so
// that function-like macros of the same name do not expand.
//
// True when the float conversion of `a` is an infinity.
483 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(
const bfloat16& a) {
484 EIGEN_USING_STD(isinf);
485 return (isinf)(float(a));
// True when the float conversion of `a` is a NaN.
487 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(
const bfloat16& a) {
488 EIGEN_USING_STD(isnan);
489 return (isnan)(float(a));
// Finite means neither infinite nor NaN; built from the two helpers above.
491 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(
const bfloat16& a) {
492 return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
// Absolute value computed on the stored bits: masking with 0x7FFF clears
// bit 15 (the bit used for the sign elsewhere in this file), so no float
// conversion is involved.
495 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(
const bfloat16& a) {
497 result.value = a.value & 0x7FFF;
// Exponential, logarithm, square-root and power functions. Each converts the
// argument(s) to float, calls the float routine, and converts the result back
// to bfloat16.
//
// exp via ::expf.
500 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(
const bfloat16& a) {
501 return bfloat16(::expf(
float(a)));
// expm1 via numext::expm1 (not the global C function).
503 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(
const bfloat16& a) {
504 return bfloat16(numext::expm1(
float(a)));
// Natural logarithm via ::logf.
506 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(
const bfloat16& a) {
507 return bfloat16(::logf(
float(a)));
// log1p via numext::log1p.
509 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(
const bfloat16& a) {
510 return bfloat16(numext::log1p(
float(a)));
// Base-10 logarithm via ::log10f.
512 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(
const bfloat16& a) {
513 return bfloat16(::log10f(
float(a)));
// Base-2 logarithm obtained by scaling the natural logarithm with
// EIGEN_LOG2E instead of calling a dedicated log2 routine.
515 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(
const bfloat16& a) {
516 return bfloat16(
static_cast<float>(EIGEN_LOG2E) * ::logf(
float(a)));
// Square root via ::sqrtf.
518 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(
const bfloat16& a) {
519 return bfloat16(::sqrtf(
float(a)));
// a raised to the power b via ::powf.
521 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(
const bfloat16& a,
const bfloat16& b) {
522 return bfloat16(::powf(
float(a),
float(b)));
// Trigonometric and hyperbolic functions. Each one forwards the float
// conversion of its argument to the global single-precision C routine of the
// same name (suffix `f`) and converts the result back to bfloat16.
524 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(
const bfloat16& a) {
525 return bfloat16(::sinf(
float(a)));
527 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(
const bfloat16& a) {
528 return bfloat16(::cosf(
float(a)));
530 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(
const bfloat16& a) {
531 return bfloat16(::tanf(
float(a)));
// Inverse trigonometric functions.
533 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(
const bfloat16& a) {
534 return bfloat16(::asinf(
float(a)));
536 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(
const bfloat16& a) {
537 return bfloat16(::acosf(
float(a)));
539 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(
const bfloat16& a) {
540 return bfloat16(::atanf(
float(a)));
// Hyperbolic functions.
542 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(
const bfloat16& a) {
543 return bfloat16(::sinhf(
float(a)));
545 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(
const bfloat16& a) {
546 return bfloat16(::coshf(
float(a)));
548 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(
const bfloat16& a) {
549 return bfloat16(::tanhf(
float(a)));
// Inverse hyperbolic functions; provided only when EIGEN_HAS_CXX11_MATH is
// set. Same pattern as the other math wrappers: compute in float with the
// global single-precision routine and convert back.
551 #if EIGEN_HAS_CXX11_MATH
552 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(
const bfloat16& a) {
553 return bfloat16(::asinhf(
float(a)));
555 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(
const bfloat16& a) {
556 return bfloat16(::acoshf(
float(a)));
558 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(
const bfloat16& a) {
559 return bfloat16(::atanhf(
float(a)));
// Rounding helpers and floating-point remainder, each evaluated in float by
// the corresponding global C routine and converted back to bfloat16.
//
// floor via ::floorf.
562 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(
const bfloat16& a) {
563 return bfloat16(::floorf(
float(a)));
// rint via ::rintf.
565 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(
const bfloat16& a) {
566 return bfloat16(::rintf(
float(a)));
// ceil via ::ceilf.
568 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(
const bfloat16& a) {
569 return bfloat16(::ceilf(
float(a)));
// fmod(a, b) via ::fmodf.
571 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(
const bfloat16& a,
const bfloat16& b) {
572 return bfloat16(::fmodf(
float(a),
float(b)));
// Minimum of two values, decided on their float conversions. `b` is returned
// only when it compares strictly less than `a`; in every other case
// (including when the comparison is false because an operand is NaN) the
// result is `a`. One of the original operands is returned unchanged.
575 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (min)(
const bfloat16& a,
const bfloat16& b) {
576 const float f1 =
static_cast<float>(a);
577 const float f2 =
static_cast<float>(b);
578 return f2 < f1 ? b : a;
// Maximum of two values: `b` is returned only when `a` compares strictly less
// than `b`; otherwise the result is `a`.
580 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (max)(
const bfloat16& a,
const bfloat16& b) {
581 const float f1 =
static_cast<float>(a);
582 const float f2 =
static_cast<float>(b);
583 return f1 < f2 ? b : a;
// fmin delegates the choice to ::fminf and converts its float result back,
// so its handling of special values follows ::fminf rather than (min) above.
586 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(
const bfloat16& a,
const bfloat16& b) {
587 const float f1 =
static_cast<float>(a);
588 const float f2 =
static_cast<float>(b);
589 return bfloat16(::fminf(f1, f2));
// fmax delegates to ::fmaxf in the same way.
591 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(
const bfloat16& a,
const bfloat16& b) {
592 const float f1 =
static_cast<float>(a);
593 const float f2 =
static_cast<float>(b);
594 return bfloat16(::fmaxf(f1, f2));
// Stream insertion: the value is written as its float conversion, so the
// output format is whatever the stream applies to float.
598 EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os,
const bfloat16& v) {
599 os << static_cast<float>(v);
// Linear interpolation between x and y: the factor is std::rand() scaled by
// RAND_MAX (computed in float, then converted to bfloat16), and the
// interpolation itself is carried out with bfloat16 arithmetic.
613 return x + (y-x) *
bfloat16(
float(std::rand()) / float(RAND_MAX));
// Trait flag: bfloat16 values do not require initialisation.
632 RequireInitialization =
false
// The accessors below each return a constant built from a fixed 16-bit
// pattern via bfloat16_impl::raw_uint16_to_bfloat16.
//
// Same pattern (0x3c00) as numeric_limits<bfloat16>::epsilon().
635 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE
Eigen::bfloat16 epsilon() {
636 return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
// Default comparison tolerance.
// NOTE(review): 0x3D4D presumably encodes a value close to 5e-2 -- confirm.
638 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE
Eigen::bfloat16 dummy_precision() {
639 return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D);
// Same pattern (0x7F7F) as numeric_limits<bfloat16>::max().
642 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE
Eigen::bfloat16 highest() {
643 return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
// highest() with bit 15 set (0xFF7F).
645 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE
Eigen::bfloat16 lowest() {
646 return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
// Same pattern (0x7f80) as numeric_limits<bfloat16>::infinity().
648 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE
Eigen::bfloat16 infinity() {
649 return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
// Same pattern (0x7fc0) as numeric_limits<bfloat16>::quiet_NaN().
651 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE
Eigen::bfloat16 quiet_NaN() {
652 return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
// Hash support, compiled only for language versions newer than C++98.
660 #if __cplusplus > 199711L
// The hash of a bfloat16 is the std::hash<float> of its float conversion, so
// the hash is computed from the float value rather than the raw 16-bit
// pattern.
662 struct hash<
Eigen::bfloat16> {
663 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t operator()(
const Eigen::bfloat16& a)
const {
664 return hash<float>()(
static_cast<float>(a));
// Thin forwarders to the bfloat16_impl classification helpers; the
// parenthesised callee names presumably prevent macro expansion.
//
// Forwards to bfloat16_impl::isnan.
676 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
678 return (bfloat16_impl::isnan)(h);
// Forwards to bfloat16_impl::isinf.
682 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
684 return (bfloat16_impl::isinf)(h);
// Forwards to bfloat16_impl::isfinite.
688 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
690 return (bfloat16_impl::isfinite)(h);
// bit_cast from a 16-bit integer to bfloat16: the integer is used as the raw
// bit pattern (via raw_uint16_to_bfloat16); no numeric conversion happens.
694 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(
const uint16_t& src) {
695 return Eigen::bfloat16(Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src));
// bit_cast from bfloat16 to a 16-bit integer: returns the raw bit pattern via
// raw_bfloat16_as_uint16.
699 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(
const Eigen::bfloat16& src) {
700 return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
706 #endif // EIGEN_BFLOAT16_H