Path Tracer
BFloat16.h
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef EIGEN_BFLOAT16_H
17 #define EIGEN_BFLOAT16_H
18 
// Generates the specialization METHOD<PACKET_BF16>: the bfloat16 packet is
// widened with Bf16ToF32, METHOD<PACKET_F> is applied to the float packet, and
// the result is narrowed back with F32ToBf16.
#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD)                 \
  template <>                                                               \
  EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED          \
  PACKET_BF16 METHOD<PACKET_BF16>(const PACKET_BF16& bf16_packet) {         \
    return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(bf16_packet)));             \
  }
25 
26 namespace Eigen {
27 
28 struct bfloat16;
29 
30 namespace bfloat16_impl {
31 
32 // Make our own __bfloat16_raw definition.
34  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() : value(0) {}
35  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : value(raw) {}
36  unsigned short value;
37 };
38 
39 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
40 template <bool AssumeArgumentIsNormalOrInfinityOrZero>
41 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
42 // Forward declarations of template specializations, to avoid Visual C++ 2019 errors, saying:
43 // > error C2908: explicit specialization; 'float_to_bfloat16_rtne' has already been instantiated
44 template <>
45 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff);
46 template <>
47 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff);
48 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);
49 
50 struct bfloat16_base : public __bfloat16_raw {
51  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base() {}
52  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base(const __bfloat16_raw& h) : __bfloat16_raw(h) {}
53 };
54 
55 } // namespace bfloat16_impl
56 
57 // Class definition.
59 
61 
62  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16() {}
63 
64  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {}
65 
66  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(bool b)
67  : bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
68 
69  template<class T>
70  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const T& val)
71  : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}
72 
73  explicit EIGEN_DEVICE_FUNC bfloat16(float f)
74  : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}
75 
76  // Following the convention of numpy, converting between complex and
77  // float will lead to loss of imag value.
78  template<typename RealScalar>
79  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const std::complex<RealScalar>& val)
80  : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.real()))) {}
81 
82  EIGEN_DEVICE_FUNC operator float() const { // NOLINT: Allow implicit conversion to float, because it is lossless.
83  return bfloat16_impl::bfloat16_to_float(*this);
84  }
85 };
86 } // namespace Eigen
87 
88 namespace std {
89 template<>
90 struct numeric_limits<Eigen::bfloat16> {
91  static const bool is_specialized = true;
92  static const bool is_signed = true;
93  static const bool is_integer = false;
94  static const bool is_exact = false;
95  static const bool has_infinity = true;
96  static const bool has_quiet_NaN = true;
97  static const bool has_signaling_NaN = true;
98  static const float_denorm_style has_denorm = std::denorm_absent;
99  static const bool has_denorm_loss = false;
100  static const std::float_round_style round_style = numeric_limits<float>::round_style;
101  static const bool is_iec559 = false;
102  static const bool is_bounded = true;
103  static const bool is_modulo = false;
104  static const int digits = 8;
105  static const int digits10 = 2;
106  static const int max_digits10 = 4;
107  static const int radix = 2;
108  static const int min_exponent = numeric_limits<float>::min_exponent;
109  static const int min_exponent10 = numeric_limits<float>::min_exponent10;
110  static const int max_exponent = numeric_limits<float>::max_exponent;
111  static const int max_exponent10 = numeric_limits<float>::max_exponent10;
112  static const bool traps = numeric_limits<float>::traps;
113  static const bool tinyness_before = numeric_limits<float>::tinyness_before;
114 
115  static Eigen::bfloat16 (min)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
116  static Eigen::bfloat16 lowest() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
117  static Eigen::bfloat16 (max)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
118  static Eigen::bfloat16 epsilon() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
119  static Eigen::bfloat16 round_error() { return Eigen::bfloat16(0x3f00); }
120  static Eigen::bfloat16 infinity() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
121  static Eigen::bfloat16 quiet_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
122  static Eigen::bfloat16 signaling_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f81); }
123  static Eigen::bfloat16 denorm_min() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
124 };
125 
126 // If std::numeric_limits<T> is specialized, should also specialize
127 // std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
128 // std::numeric_limits<const volatile T>
129 // https://stackoverflow.com/a/16519653/
130 template<>
131 struct numeric_limits<const Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
132 template<>
133 struct numeric_limits<volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
134 template<>
135 struct numeric_limits<const volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
136 } // namespace std
137 
138 namespace Eigen {
139 
140 namespace bfloat16_impl {
141 
142 // We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler,
143 // invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation
144 // of the functions, while the latter can only deal with one of them.
145 #if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for bfloat16 floats
146 
147 #if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
148 // We need to provide emulated *host-side* BF16 operators for clang.
149 #pragma push_macro("EIGEN_DEVICE_FUNC")
150 #undef EIGEN_DEVICE_FUNC
151 #if defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_NATIVE_BF16)
152 #define EIGEN_DEVICE_FUNC __host__
153 #else // both host and device need emulated ops.
154 #define EIGEN_DEVICE_FUNC __host__ __device__
155 #endif
156 #endif
157 
158 // Definitions for CPUs, mostly working through conversion
159 // to/from fp32.
160 
161 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const bfloat16& b) {
162  return bfloat16(float(a) + float(b));
163 }
164 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const int& b) {
165  return bfloat16(float(a) + static_cast<float>(b));
166 }
167 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const int& a, const bfloat16& b) {
168  return bfloat16(static_cast<float>(a) + float(b));
169 }
170 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator * (const bfloat16& a, const bfloat16& b) {
171  return bfloat16(float(a) * float(b));
172 }
173 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a, const bfloat16& b) {
174  return bfloat16(float(a) - float(b));
175 }
176 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, const bfloat16& b) {
177  return bfloat16(float(a) / float(b));
178 }
179 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a) {
180  bfloat16 result;
181  result.value = a.value ^ 0x8000;
182  return result;
183 }
184 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator += (bfloat16& a, const bfloat16& b) {
185  a = bfloat16(float(a) + float(b));
186  return a;
187 }
188 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator *= (bfloat16& a, const bfloat16& b) {
189  a = bfloat16(float(a) * float(b));
190  return a;
191 }
192 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator -= (bfloat16& a, const bfloat16& b) {
193  a = bfloat16(float(a) - float(b));
194  return a;
195 }
196 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator /= (bfloat16& a, const bfloat16& b) {
197  a = bfloat16(float(a) / float(b));
198  return a;
199 }
200 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
201  a += bfloat16(1);
202  return a;
203 }
204 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
205  a -= bfloat16(1);
206  return a;
207 }
208 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a, int) {
209  bfloat16 original_value = a;
210  ++a;
211  return original_value;
212 }
213 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a, int) {
214  bfloat16 original_value = a;
215  --a;
216  return original_value;
217 }
218 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const bfloat16& a, const bfloat16& b) {
219  return numext::equal_strict(float(a),float(b));
220 }
221 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const bfloat16& a, const bfloat16& b) {
222  return numext::not_equal_strict(float(a), float(b));
223 }
224 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const bfloat16& a, const bfloat16& b) {
225  return float(a) < float(b);
226 }
227 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const bfloat16& a, const bfloat16& b) {
228  return float(a) <= float(b);
229 }
230 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const bfloat16& a, const bfloat16& b) {
231  return float(a) > float(b);
232 }
233 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const bfloat16& a, const bfloat16& b) {
234  return float(a) >= float(b);
235 }
236 
237 #if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
238 #pragma pop_macro("EIGEN_DEVICE_FUNC")
239 #endif
240 #endif // Emulate support for bfloat16 floats
241 
242 // Division by an index. Do it in full float precision to avoid accuracy
243 // issues in converting the denominator to bfloat16.
244 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, Index b) {
245  return bfloat16(static_cast<float>(a) / static_cast<float>(b));
246 }
247 
248 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) {
249  __bfloat16_raw output;
250  if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
251  output.value = std::signbit(v) ? 0xFFC0: 0x7FC0;
252  return output;
253  } else if (std::fabs(v) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
254  // Flush denormal to +/- 0.
255  output.value = std::signbit(v) ? 0x8000 : 0;
256  return output;
257  }
258  const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
259 #if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
260  output.value = p[0];
261 #else
262  output.value = p[1];
263 #endif
264  return output;
265 }
266 
267 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) {
268  return __bfloat16_raw(value);
269 }
270 
271 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(const __bfloat16_raw& bf) {
272  return bf.value;
273 }
274 
275 // float_to_bfloat16_rtne template specialization that does not make any
276 // assumption about the value of its function argument (ff).
277 template <>
278 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff) {
279 #if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
280  // Nothing to do here
281 #else
282  __bfloat16_raw output;
283 
284  if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) {
285  // If the value is a NaN, squash it to a qNaN with msb of fraction set,
286  // this makes sure after truncation we don't end up with an inf.
287  //
288  // qNaN magic: All exponent bits set + most significant bit of fraction
289  // set.
290  output.value = std::signbit(ff) ? 0xFFC0: 0x7FC0;
291  } else if (std::fabs(ff) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
292  // Flush denormal to +/- 0.0
293  output.value = std::signbit(ff) ? 0x8000 : 0;
294  } else {
295  // Fast rounding algorithm that rounds a half value to nearest even. This
296  // reduces expected error when we convert a large number of floats. Here
297  // is how it works:
298  //
299  // Definitions:
300  // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
301  // with the following tags:
302  //
303  // Sign | Exp (8 bits) | Frac (23 bits)
304  // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
305  //
306  // S: Sign bit.
307  // E: Exponent bits.
308  // F: First 6 bits of fraction.
309  // L: Least significant bit of resulting bfloat16 if we truncate away the
310  // rest of the float32. This is also the 7th bit of fraction
311  // R: Rounding bit, 8th bit of fraction.
312  // T: Sticky bits, rest of fraction, 15 bits.
313  //
314  // To round half to nearest even, there are 3 cases where we want to round
315  // down (simply truncate the result of the bits away, which consists of
316  // rounding bit and sticky bits) and two cases where we want to round up
317  // (truncate then add one to the result).
318  //
319  // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
320  // 1s) as the rounding bias, adds the rounding bias to the input, then
321  // truncates the last 16 bits away.
322  //
323  // To understand how it works, we can analyze this algorithm case by case:
324  //
325  // 1. L = 0, R = 0:
326  // Expect: round down, this is less than half value.
327  //
328  // Algorithm:
329  // - Rounding bias: 0x7fff + 0 = 0x7fff
330  // - Adding rounding bias to input may create any carry, depending on
331  // whether there is any value set to 1 in T bits.
332  // - R may be set to 1 if there is a carry.
333  // - L remains 0.
334  // - Note that this case also handles Inf and -Inf, where all fraction
335  // bits, including L, R and Ts are all 0. The output remains Inf after
336  // this algorithm.
337  //
338  // 2. L = 1, R = 0:
339  // Expect: round down, this is less than half value.
340  //
341  // Algorithm:
342  // - Rounding bias: 0x7fff + 1 = 0x8000
343  // - Adding rounding bias to input doesn't change sticky bits but
344  // adds 1 to rounding bit.
345  // - L remains 1.
346  //
347  // 3. L = 0, R = 1, all of T are 0:
348  // Expect: round down, this is exactly at half, the result is already
349  // even (L=0).
350  //
351  // Algorithm:
352  // - Rounding bias: 0x7fff + 0 = 0x7fff
353  // - Adding rounding bias to input sets all sticky bits to 1, but
354  // doesn't create a carry.
355  // - R remains 1.
356  // - L remains 0.
357  //
358  // 4. L = 1, R = 1:
359  // Expect: round up, this is exactly at half, the result needs to be
360  // round to the next even number.
361  //
362  // Algorithm:
363  // - Rounding bias: 0x7fff + 1 = 0x8000
364  // - Adding rounding bias to input doesn't change sticky bits, but
365  // creates a carry from rounding bit.
366  // - The carry sets L to 0, creates another carry bit and propagate
367  // forward to F bits.
368  // - If all the F bits are 1, a carry then propagates to the exponent
369  // bits, which then creates the minimum value with the next exponent
370  // value. Note that we won't have the case where exponents are all 1,
371  // since that's either a NaN (handled in the other if condition) or inf
372  // (handled in case 1).
373  //
374  // 5. L = 0, R = 1, any of T is 1:
375  // Expect: round up, this is greater than half.
376  //
377  // Algorithm:
378  // - Rounding bias: 0x7fff + 0 = 0x7fff
379  // - Adding rounding bias to input creates a carry from sticky bits,
380  // sets rounding bit to 0, then create another carry.
381  // - The second carry sets L to 1.
382  //
383  // Examples:
384  //
385  // Exact half value that is already even:
386  // Input:
387  // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
388  // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
389  // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
390  //
391  // This falls into case 3. We truncate the rest of 16 bits and no
392  // carry is created into F and L:
393  //
394  // Output:
395  // Sign | Exp (8 bit) | Frac (first 7 bit)
396  // S E E E E E E E E F F F F F F L
397  // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
398  //
399  // Exact half value, round to next even number:
400  // Input:
401  // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
402  // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
403  // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
404  //
405  // This falls into case 4. We create a carry from R and T,
406  // which then propagates into L and F:
407  //
408  // Output:
409  // Sign | Exp (8 bit) | Frac (first 7 bit)
410  // S E E E E E E E E F F F F F F L
411  // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
412  //
413  //
414  // Max denormal value round to min normal value:
415  // Input:
416  // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
417  // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
418  // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
419  //
420  // This falls into case 4. We create a carry from R and T,
421  // propagate into L and F, which then propagates into exponent
422  // bits:
423  //
424  // Output:
425  // Sign | Exp (8 bit) | Frac (first 7 bit)
426  // S E E E E E E E E F F F F F F L
427  // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
428  //
429  // Max normal value round to Inf:
430  // Input:
431  // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
432  // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
433  // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
434  //
435  // This falls into case 4. We create a carry from R and T,
436  // propagate into L and F, which then propagates into exponent
437  // bits:
438  //
439  // Sign | Exp (8 bit) | Frac (first 7 bit)
440  // S E E E E E E E E F F F F F F L
441  // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
442 
443  // At this point, ff must be either a normal float, or +/-infinity.
444  output = float_to_bfloat16_rtne<true>(ff);
445  }
446  return output;
447 #endif
448 }
449 
450 // float_to_bfloat16_rtne template specialization that assumes that its function
451 // argument (ff) is either a normal floating point number, or +/-infinity, or
452 // zero. Used to improve the runtime performance of conversion from an integer
453 // type to bfloat16.
454 template <>
455 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff) {
456 #if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
457  // Nothing to do here
458 #else
459  numext::uint32_t input = numext::bit_cast<numext::uint32_t>(ff);
460  __bfloat16_raw output;
461 
462  // Least significant bit of resulting bfloat.
463  numext::uint32_t lsb = (input >> 16) & 1;
464  numext::uint32_t rounding_bias = 0x7fff + lsb;
465  input += rounding_bias;
466  output.value = static_cast<numext::uint16_t>(input >> 16);
467  return output;
468 #endif
469 }
470 
471 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
472  float result = 0;
473  unsigned short* q = reinterpret_cast<unsigned short*>(&result);
474 #if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
475  q[0] = h.value;
476 #else
477  q[1] = h.value;
478 #endif
479  return result;
480 }
481 // --- standard functions ---
482 
483 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const bfloat16& a) {
484  EIGEN_USING_STD(isinf);
485  return (isinf)(float(a));
486 }
487 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const bfloat16& a) {
488  EIGEN_USING_STD(isnan);
489  return (isnan)(float(a));
490 }
491 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const bfloat16& a) {
492  return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
493 }
494 
495 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) {
496  bfloat16 result;
497  result.value = a.value & 0x7FFF;
498  return result;
499 }
500 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) {
501  return bfloat16(::expf(float(a)));
502 }
503 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) {
504  return bfloat16(numext::expm1(float(a)));
505 }
506 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(const bfloat16& a) {
507  return bfloat16(::logf(float(a)));
508 }
509 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) {
510  return bfloat16(numext::log1p(float(a)));
511 }
512 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) {
513  return bfloat16(::log10f(float(a)));
514 }
515 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(const bfloat16& a) {
516  return bfloat16(static_cast<float>(EIGEN_LOG2E) * ::logf(float(a)));
517 }
518 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) {
519  return bfloat16(::sqrtf(float(a)));
520 }
521 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
522  return bfloat16(::powf(float(a), float(b)));
523 }
524 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) {
525  return bfloat16(::sinf(float(a)));
526 }
527 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(const bfloat16& a) {
528  return bfloat16(::cosf(float(a)));
529 }
530 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(const bfloat16& a) {
531  return bfloat16(::tanf(float(a)));
532 }
533 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(const bfloat16& a) {
534  return bfloat16(::asinf(float(a)));
535 }
536 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(const bfloat16& a) {
537  return bfloat16(::acosf(float(a)));
538 }
539 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(const bfloat16& a) {
540  return bfloat16(::atanf(float(a)));
541 }
542 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(const bfloat16& a) {
543  return bfloat16(::sinhf(float(a)));
544 }
545 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(const bfloat16& a) {
546  return bfloat16(::coshf(float(a)));
547 }
548 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) {
549  return bfloat16(::tanhf(float(a)));
550 }
#if EIGEN_HAS_CXX11_MATH
// Inverse hyperbolic functions; only compiled when EIGEN_HAS_CXX11_MATH is
// set.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) {
  const float x = static_cast<float>(a);
  return bfloat16(::asinhf(x));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) {
  const float x = static_cast<float>(a);
  return bfloat16(::acoshf(x));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) {
  const float x = static_cast<float>(a);
  return bfloat16(::atanhf(x));
}
#endif
562 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) {
563  return bfloat16(::floorf(float(a)));
564 }
565 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(const bfloat16& a) {
566  return bfloat16(::rintf(float(a)));
567 }
568 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) {
569  return bfloat16(::ceilf(float(a)));
570 }
571 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) {
572  return bfloat16(::fmodf(float(a), float(b)));
573 }
574 
575 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (min)(const bfloat16& a, const bfloat16& b) {
576  const float f1 = static_cast<float>(a);
577  const float f2 = static_cast<float>(b);
578  return f2 < f1 ? b : a;
579 }
580 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (max)(const bfloat16& a, const bfloat16& b) {
581  const float f1 = static_cast<float>(a);
582  const float f2 = static_cast<float>(b);
583  return f1 < f2 ? b : a;
584 }
585 
586 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(const bfloat16& a, const bfloat16& b) {
587  const float f1 = static_cast<float>(a);
588  const float f2 = static_cast<float>(b);
589  return bfloat16(::fminf(f1, f2));
590 }
591 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfloat16& b) {
592  const float f1 = static_cast<float>(a);
593  const float f2 = static_cast<float>(b);
594  return bfloat16(::fmaxf(f1, f2));
595 }
596 
597 #ifndef EIGEN_NO_IO
598 EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const bfloat16& v) {
599  os << static_cast<float>(v);
600  return os;
601 }
602 #endif
603 
604 } // namespace bfloat16_impl
605 
606 namespace internal {
607 
608 template<>
609 struct random_default_impl<bfloat16, false, false>
610 {
611  static inline bfloat16 run(const bfloat16& x, const bfloat16& y)
612  {
613  return x + (y-x) * bfloat16(float(std::rand()) / float(RAND_MAX));
614  }
615  static inline bfloat16 run()
616  {
617  return run(bfloat16(-1.f), bfloat16(1.f));
618  }
619 };
620 
621 template<> struct is_arithmetic<bfloat16> { enum { value = true }; };
622 
623 } // namespace internal
624 
625 template<> struct NumTraits<Eigen::bfloat16>
626  : GenericNumTraits<Eigen::bfloat16>
627 {
628  enum {
629  IsSigned = true,
630  IsInteger = false,
631  IsComplex = false,
632  RequireInitialization = false
633  };
634 
635  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() {
636  return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
637  }
638  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() {
639  return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D); // bfloat16(5e-2f);
640 
641  }
642  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
643  return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
644  }
645  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() {
646  return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
647  }
648  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() {
649  return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
650  }
651  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() {
652  return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
653  }
654 };
655 
656 } // namespace Eigen
657 
658 namespace std {
659 
660 #if __cplusplus > 199711L
661 template <>
662 struct hash<Eigen::bfloat16> {
663  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::bfloat16& a) const {
664  return hash<float>()(static_cast<float>(a));
665  }
666 };
667 #endif
668 
669 } // namespace std
670 
671 
672 namespace Eigen {
673 namespace numext {
674 
675 template<>
676 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
677 bool (isnan)(const Eigen::bfloat16& h) {
678  return (bfloat16_impl::isnan)(h);
679 }
680 
681 template<>
682 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
683 bool (isinf)(const Eigen::bfloat16& h) {
684  return (bfloat16_impl::isinf)(h);
685 }
686 
687 template<>
688 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
689 bool (isfinite)(const Eigen::bfloat16& h) {
690  return (bfloat16_impl::isfinite)(h);
691 }
692 
693 template <>
694 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src) {
695  return Eigen::bfloat16(Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src));
696 }
697 
698 template <>
699 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(const Eigen::bfloat16& src) {
700  return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
701 }
702 
703 } // namespace numext
704 } // namespace Eigen
705 
706 #endif // EIGEN_BFLOAT16_H
Eigen
Namespace containing all symbols from the Eigen library.
Definition: LDLT.h:16
Eigen::bfloat16_impl::__bfloat16_raw
Definition: BFloat16.h:33
Eigen::operator*
EIGEN_DEVICE_FUNC const Product< MatrixDerived, PermutationDerived, AliasFreeProduct > operator*(const MatrixBase< MatrixDerived > &matrix, const PermutationBase< PermutationDerived > &permutation)
Definition: PermutationMatrix.h:515
Eigen::bfloat16
Definition: BFloat16.h:58
Eigen::GenericNumTraits
Definition: NumTraits.h:144
Eigen::internal::random_default_impl
Definition: thirdparty/Eigen/src/Core/MathFunctions.h:718
Eigen::internal::is_integral
Definition: Meta.h:126
Eigen::bfloat16_impl::bfloat16_base
Definition: BFloat16.h:50
Eigen::NumTraits
Holds information about the various numeric (i.e. scalar) types allowed by Eigen.
Definition: NumTraits.h:213
Eigen::internal::is_arithmetic
Definition: Meta.h:100
Eigen::Index
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:42