16 #ifndef dealii_differentiation_ad_ad_number_traits_h 17 #define dealii_differentiation_ad_ad_number_traits_h 26 #include <boost/type_traits.hpp> 29 #include <type_traits> 50 template <
typename ScalarType,
71 template <
typename ADNumberType,
typename T =
void>
113 template <
typename ScalarType,
150 template <
typename ADNumberType,
typename T =
void>
185 template <
typename ADNumberType,
typename T =
void>
203 template <
typename ADNumberTrait,
typename T =
void>
223 template <
typename T>
233 template <
typename Number>
247 template <
typename NumberType>
259 template <
typename NumberType,
typename =
void>
271 template <
typename NumberType,
typename =
void>
283 template <
typename NumberType,
typename =
void>
295 template <
typename NumberType,
typename =
void>
314 template <
typename ADNumberTrait,
typename>
315 struct HasRequiredADInfo : std::false_type
330 template <
typename ADNumberTrait>
331 struct HasRequiredADInfo<
333 decltype((void)ADNumberTrait::type_code,
334 (void)ADNumberTrait::is_taped,
335 (void)std::declval<typename ADNumberTrait::real_type>(),
336 (void)std::declval<typename ADNumberTrait::derivative_type>(),
339 std::is_floating_point<typename ADNumberTrait::real_type>::value,
341 std::true_type>::type
349 template <
typename ScalarType>
350 struct ADNumberInfoFromEnum<
353 typename std::enable_if<
354 std::is_floating_point<ScalarType>::value>::type>
356 static const bool is_taped =
false;
357 using real_type = ScalarType;
358 using derivative_type = ScalarType;
359 static const unsigned int n_supported_derivative_levels = 0;
368 template <
typename ScalarType>
369 struct Marking<ScalarType,
370 typename
std::enable_if<
371 std::is_floating_point<ScalarType>::value>::type>
380 template <
typename ADNumberType>
382 independent_variable(
const ScalarType &in,
393 template <
typename ADNumberType>
395 dependent_variable(ADNumberType &,
const ScalarType &)
400 "Floating point numbers cannot be marked as dependent variables."));
408 template <
typename ADNumberType>
411 typename
std::enable_if<boost::is_complex<ADNumberType>::value>::type>
416 template <
typename ScalarType>
418 independent_variable(
const ScalarType &in,
426 "Marking for complex numbers has not yet been implemented."));
433 template <
typename ScalarType>
435 dependent_variable(ADNumberType &,
const ScalarType &)
440 "Marking for complex numbers has not yet been implemented."));
447 template <
typename NumberType,
typename>
452 template <
typename NumberType,
typename>
457 template <
typename NumberType,
typename>
462 template <
typename NumberType,
typename>
472 template <
typename NumberType>
475 ADNumberTraits<typename std::decay<NumberType>::type>>
483 template <
typename NumberType>
486 typename
std::enable_if<
487 ADNumberTraits<typename std::decay<NumberType>::type>::is_taped>::type>
496 template <
typename NumberType>
499 typename std::enable_if<ADNumberTraits<
500 typename std::decay<NumberType>::type>::is_tapeless>::type>
510 template <
typename NumberType>
513 typename std::enable_if<ADNumberTraits<
514 typename std::decay<NumberType>::type>::is_real_valued>::type>
524 template <
typename NumberType>
527 typename std::enable_if<ADNumberTraits<
528 typename std::decay<NumberType>::type>::is_complex_valued>::type>
539 template <
typename Number>
540 struct RemoveComplexWrapper
551 template <
typename Number>
552 struct RemoveComplexWrapper<std::complex<Number>>
554 using type =
typename RemoveComplexWrapper<Number>::type;
563 template <
typename NumberType>
564 struct ExtractData<NumberType,
565 typename std::enable_if<
566 std::is_floating_point<NumberType>::value>::type>
571 static const NumberType &
572 value(
const NumberType &x)
582 n_directional_derivatives(
const NumberType &)
592 directional_derivative(
const NumberType &,
const unsigned int)
604 template <
typename ADNumberType>
605 struct ExtractData<std::complex<ADNumberType>>
608 "Expected an auto-differentiable number.");
614 static std::complex<typename ADNumberTraits<ADNumberType>::scalar_type>
615 value(
const std::complex<ADNumberType> &x)
628 n_directional_derivatives(
const std::complex<ADNumberType> &x)
630 return ExtractData<ADNumberType>::n_directional_derivatives(x.real());
639 directional_derivative(
const std::complex<ADNumberType> &x,
640 const unsigned int direction)
644 ExtractData<ADNumberType>::directional_derivative(x.real(),
646 ExtractData<ADNumberType>::directional_derivative(x.imag(),
652 template <
typename T>
658 template <
typename F>
675 template <
typename F>
694 template <
typename F>
704 template <
typename T>
705 struct NumberType<std::complex<T>>
710 template <
typename F>
719 return ::internal::NumberType<std::complex<T>>
::value(f);
727 template <
typename F>
728 static std::complex<T>
737 return std::complex<T>(
741 template <
typename F>
742 static std::complex<T>
743 value(
const std::complex<F> &f)
776 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
780 typename std::enable_if<
781 std::is_floating_point<ScalarType>::value ||
782 (boost::is_complex<ScalarType>::value &&
783 std::is_floating_point<typename internal::RemoveComplexWrapper<
784 ScalarType>::type>::value)>::type>
789 static constexpr
enum NumberTypes type_code = ADNumberTypeCode;
803 static const bool is_taped;
810 static const bool is_tapeless;
817 static const bool is_real_valued;
824 static const bool is_complex_valued;
831 static const unsigned int n_supported_derivative_levels;
841 ADNumberTypeCode>::is_taped;
848 static constexpr
bool is_tapeless =
856 static constexpr
bool is_real_valued =
864 static constexpr
bool is_complex_valued =
872 static constexpr
unsigned int n_supported_derivative_levels =
874 typename internal::RemoveComplexWrapper<ScalarType>::type,
875 ADNumberTypeCode>::n_supported_derivative_levels;
884 using scalar_type = ScalarType;
891 typename internal::RemoveComplexWrapper<ScalarType>::type,
892 ADNumberTypeCode>::real_type;
898 using complex_type = std::complex<real_type>;
904 using ad_type =
typename std::
905 conditional<is_real_valued, real_type, complex_type>::type;
910 using derivative_type =
typename std::conditional<
913 typename internal::RemoveComplexWrapper<ScalarType>::type,
914 ADNumberTypeCode>::derivative_type,
916 typename internal::RemoveComplexWrapper<ScalarType>::type,
917 ADNumberTypeCode>::derivative_type>>::type;
923 static scalar_type get_scalar_value(
const ad_type &x)
939 static derivative_type get_directional_derivative(
940 const ad_type &x,
const unsigned int direction)
951 static unsigned int n_directional_derivatives(
const ad_type &x)
957 static_assert((is_real_valued ==
true ?
960 "Incorrect template type selected for ad_type");
962 static_assert((is_complex_valued ==
true ?
965 "Expected a complex float_type");
967 static_assert((is_complex_valued ==
true ?
970 "Expected a complex ad_type");
975 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
979 typename std::enable_if<
983 ScalarType>::type>
::value)>::type>::is_taped =
985 typename internal::RemoveComplexWrapper<ScalarType>::type,
986 ADNumberTypeCode>::is_taped;
989 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
993 typename std::enable_if<
997 ScalarType>::type>
::value)>::type>::is_tapeless =
1001 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
1005 typename std::enable_if<
1009 ScalarType>::type>
::value)>::type>::is_real_valued =
1013 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
1017 typename std::enable_if<
1021 ScalarType>::type>
::value)>::type>::is_complex_valued =
1025 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
1029 typename std::enable_if<
1033 ScalarType>::type>
::value)>::type>::n_supported_derivative_levels =
1035 typename internal::RemoveComplexWrapper<ScalarType>::type,
1036 ADNumberTypeCode>::n_supported_derivative_levels;
1040 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
1044 typename std::enable_if<
1048 ScalarType>::type>
::value)>::type>::is_taped;
1051 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
1055 typename std::enable_if<
1059 ScalarType>::type>
::value)>::type>::is_tapeless;
1062 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
1066 typename std::enable_if<
1070 ScalarType>::type>
::value)>::type>::is_real_valued;
1073 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
1077 typename std::enable_if<
1081 ScalarType>::type>
::value)>::type>::is_complex_valued;
1084 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
1088 typename std::enable_if<
1092 ScalarType>::type>
::value)>::type>::n_supported_derivative_levels;
1107 template <
typename ScalarType>
1111 typename std::enable_if<
1112 std::is_floating_point<ScalarType>::value ||
1113 (boost::is_complex<ScalarType>::value &&
1114 std::is_floating_point<typename internal::RemoveComplexWrapper<
1115 ScalarType>::type>::value)>::type>
1134 static const bool is_taped;
1141 static const bool is_tapeless;
1148 static const bool is_real_valued;
1155 static const bool is_complex_valued;
1162 static const unsigned int n_supported_derivative_levels;
1170 static constexpr
bool is_taped =
false;
1177 static constexpr
bool is_tapeless =
false;
1184 static constexpr
bool is_real_valued =
1192 static constexpr
bool is_complex_valued = !is_real_valued;
1199 static constexpr
unsigned int n_supported_derivative_levels = 0;
1208 using scalar_type = ScalarType;
1215 typename ::numbers::NumberTraits<scalar_type>::real_type;
1221 using complex_type = std::complex<real_type>;
1227 using ad_type = ScalarType;
1232 using derivative_type = ScalarType;
1239 get_scalar_value(
const ad_type &x)
1248 static derivative_type
1249 get_directional_derivative(
const ad_type &,
const unsigned int)
1254 "Floating point/arithmetic numbers have no directional derivatives."));
1255 return derivative_type();
1264 n_directional_derivatives(
const ad_type &)
1269 "Floating point/arithmetic numbers have no directional derivatives."));
1276 template <
typename ScalarType>
1280 typename std::enable_if<
1284 ScalarType>::type>
::value)>::type>::is_taped =
false;
1287 template <
typename ScalarType>
1291 typename std::enable_if<
1295 ScalarType>::type>
::value)>::type>::is_tapeless =
false;
1298 template <
typename ScalarType>
1302 typename std::enable_if<
1306 ScalarType>::type>
::value)>::type>::is_real_valued =
1310 template <
typename ScalarType>
1314 typename std::enable_if<
1318 ScalarType>::type>
::value)>::type>::is_complex_valued =
1322 template <
typename ScalarType>
1326 typename std::enable_if<
1330 ScalarType>::type>
::value)>::type>::n_supported_derivative_levels =
1335 template <
typename ScalarType>
1339 typename std::enable_if<
1343 ScalarType>::type>
::value)>::type>::is_taped;
1346 template <
typename ScalarType>
1350 typename std::enable_if<
1354 ScalarType>::type>
::value)>::type>::is_tapeless;
1357 template <
typename ScalarType>
1361 typename std::enable_if<
1365 ScalarType>::type>
::value)>::type>::is_real_valued;
1368 template <
typename ScalarType>
1372 typename std::enable_if<
1376 ScalarType>::type>
::value)>::type>::is_complex_valued;
1379 template <
typename ScalarType>
1383 typename std::enable_if<
1387 ScalarType>::type>
::value)>::type>::n_supported_derivative_levels;
1410 template <
typename ScalarType>
1413 typename std::enable_if<std::is_floating_point<ScalarType>::value>::type>
1420 template <
typename ComplexScalarType>
1423 typename std::enable_if<
1424 boost::is_complex<ComplexScalarType>::value &&
1425 std::is_floating_point<typename ComplexScalarType::value_type>::value>::
1426 type> :
NumberTraits<ComplexScalarType, NumberTypes::none>
static constexpr const T & value(const T &t)
#define AssertThrow(cond, exc)
Tensor< 2, dim, Number > F(const Tensor< 2, dim, Number > &Grad_u)
static ::ExceptionBase & ExcMessage(std::string arg1)
#define Assert(cond, exc)
#define DEAL_II_NAMESPACE_CLOSE
#define DEAL_II_NAMESPACE_OPEN