16 #ifndef dealii_differentiation_sd_symengine_number_visitor_internal_h 17 #define dealii_differentiation_sd_symengine_number_visitor_internal_h 21 #ifdef DEAL_II_WITH_SYMENGINE 25 # include <symengine/basic.h> 26 # include <symengine/dict.h> 27 # include <symengine/symengine_exception.h> 28 # include <symengine/symengine_rcp.h> 31 # include <symengine/visitor.h> 41 # include <boost/serialization/split_member.hpp> 62 template <
typename ReturnType,
typename ExpressionType>
66 std::vector<std::pair<SD::Expression, SD::Expression>>;
132 call(ReturnType * output_values,
134 const ReturnType * substitution_values);
140 template <
class Archive>
142 save(Archive &archive,
const unsigned int version)
const;
148 template <
class Archive>
150 load(Archive &archive,
const unsigned int version);
157 template <
class Archive>
159 serialize(Archive &archive,
const unsigned int version);
163 BOOST_SERIALIZATION_SPLIT_MEMBER()
171 template <
typename StreamType>
173 print(StreamType &stream)
const;
207 init(
const SymEngine::vec_basic &dependent_functions);
234 call(ReturnType * output_values,
235 const SymEngine::vec_basic &independent_symbols,
236 const ReturnType * substitution_values);
269 template <
typename ReturnType,
typename ExpressionType>
271 :
public SymEngine::BaseVisitor<
272 DictionarySubstitutionVisitor<ReturnType, ExpressionType>>
311 const bool use_cse =
false);
337 init(
const SymEngine::vec_basic &independent_symbols,
338 const SymEngine::Basic & dependent_function,
339 const bool use_cse =
false);
364 const bool use_cse =
false);
390 init(
const SymEngine::vec_basic &independent_symbols,
391 const SymEngine::vec_basic &dependent_functions,
392 const bool use_cse =
false);
427 call(ReturnType *output_values,
const ReturnType *substitution_values);
447 call(
const std::vector<ReturnType> &substitution_values);
453 template <
class Archive>
455 save(Archive &archive,
const unsigned int version)
const;
461 template <
class Archive>
463 load(Archive &archive,
const unsigned int version);
470 template <
class Archive>
472 serialize(Archive &archive,
const unsigned int version);
476 BOOST_SERIALIZATION_SPLIT_MEMBER()
492 template <
typename StreamType>
494 print(StreamType &stream,
495 const bool print_independent_symbols =
false,
496 const bool print_dependent_functions =
false,
497 const bool print_cse_reductions =
false)
const;
505 # define IMPLEMENT_DSV_BVISIT(Argument) \ 506 void bvisit(const Argument &) \ 508 AssertThrow(false, ExcNotImplemented()); \ 511 IMPLEMENT_DSV_BVISIT(SymEngine::Basic)
512 IMPLEMENT_DSV_BVISIT(SymEngine::Symbol)
513 IMPLEMENT_DSV_BVISIT(SymEngine::Constant)
514 IMPLEMENT_DSV_BVISIT(SymEngine::Integer)
515 IMPLEMENT_DSV_BVISIT(SymEngine::Rational)
516 IMPLEMENT_DSV_BVISIT(SymEngine::RealDouble)
517 IMPLEMENT_DSV_BVISIT(SymEngine::ComplexDouble)
518 IMPLEMENT_DSV_BVISIT(SymEngine::Add)
519 IMPLEMENT_DSV_BVISIT(SymEngine::Mul)
520 IMPLEMENT_DSV_BVISIT(SymEngine::Pow)
521 IMPLEMENT_DSV_BVISIT(SymEngine::Log)
522 IMPLEMENT_DSV_BVISIT(SymEngine::Sin)
523 IMPLEMENT_DSV_BVISIT(SymEngine::Cos)
524 IMPLEMENT_DSV_BVISIT(SymEngine::Tan)
525 IMPLEMENT_DSV_BVISIT(SymEngine::Csc)
526 IMPLEMENT_DSV_BVISIT(SymEngine::Sec)
527 IMPLEMENT_DSV_BVISIT(SymEngine::Cot)
528 IMPLEMENT_DSV_BVISIT(SymEngine::ASin)
529 IMPLEMENT_DSV_BVISIT(SymEngine::ACos)
530 IMPLEMENT_DSV_BVISIT(SymEngine::ATan)
531 IMPLEMENT_DSV_BVISIT(SymEngine::ATan2)
532 IMPLEMENT_DSV_BVISIT(SymEngine::ACsc)
533 IMPLEMENT_DSV_BVISIT(SymEngine::ASec)
534 IMPLEMENT_DSV_BVISIT(SymEngine::ACot)
535 IMPLEMENT_DSV_BVISIT(SymEngine::Sinh)
536 IMPLEMENT_DSV_BVISIT(SymEngine::Cosh)
537 IMPLEMENT_DSV_BVISIT(SymEngine::Tanh)
538 IMPLEMENT_DSV_BVISIT(SymEngine::Csch)
539 IMPLEMENT_DSV_BVISIT(SymEngine::Sech)
540 IMPLEMENT_DSV_BVISIT(SymEngine::Coth)
541 IMPLEMENT_DSV_BVISIT(SymEngine::ASinh)
542 IMPLEMENT_DSV_BVISIT(SymEngine::ACosh)
543 IMPLEMENT_DSV_BVISIT(SymEngine::ATanh)
544 IMPLEMENT_DSV_BVISIT(SymEngine::ACsch)
545 IMPLEMENT_DSV_BVISIT(SymEngine::ACoth)
546 IMPLEMENT_DSV_BVISIT(SymEngine::ASech)
547 IMPLEMENT_DSV_BVISIT(SymEngine::Abs)
548 IMPLEMENT_DSV_BVISIT(SymEngine::Gamma)
549 IMPLEMENT_DSV_BVISIT(SymEngine::LogGamma)
550 IMPLEMENT_DSV_BVISIT(SymEngine::Erf)
551 IMPLEMENT_DSV_BVISIT(SymEngine::Erfc)
552 IMPLEMENT_DSV_BVISIT(SymEngine::Max)
553 IMPLEMENT_DSV_BVISIT(SymEngine::Min)
555 # undef IMPLEMENT_DSV_BVISIT 593 template <
typename ReturnType,
typename ExpressionType>
599 dependent_functions));
604 template <
typename ReturnType,
typename ExpressionType>
607 const SymEngine::vec_basic &dependent_functions)
625 SymEngine::vec_pair se_replacements;
626 SymEngine::vec_basic se_reduced_exprs;
627 SymEngine::cse(se_replacements, se_reduced_exprs, dependent_functions);
638 template <
typename ReturnType,
typename ExpressionType>
641 ReturnType * output_values,
643 const ReturnType * substitution_values)
647 independent_symbols),
648 substitution_values);
653 template <
typename ReturnType,
typename ExpressionType>
656 ReturnType * output_values,
657 const SymEngine::vec_basic &independent_symbols,
658 const ReturnType * substitution_values)
663 SymEngine::map_basic_basic substitution_value_map;
664 for (
unsigned i = 0; i < independent_symbols.size(); ++i)
665 substitution_value_map[independent_symbols[i]] =
666 static_cast<const SymEngine::RCP<const SymEngine::Basic> &
>(
667 ExpressionType(substitution_values[i]));
673 const SymEngine::RCP<const SymEngine::Basic> &cse_symbol =
675 const SymEngine::RCP<const SymEngine::Basic> &cse_expr =
677 Assert(substitution_value_map.find(cse_symbol) ==
678 substitution_value_map.end(),
680 "Reduced symbol already appears in substitution map. " 681 "Is there a clash between the reduced symbol name and " 682 "the symbol used for an independent variable?"));
683 substitution_value_map[cse_symbol] =
684 static_cast<const SymEngine::RCP<const SymEngine::Basic> &
>(
685 ExpressionType(ExpressionType(cse_expr)
686 .
template substitute_and_evaluate<ReturnType>(
687 substitution_value_map)));
693 .template substitute_and_evaluate<ReturnType>(
694 substitution_value_map);
699 template <
typename ReturnType,
typename ExpressionType>
700 template <
class Archive>
704 const unsigned int )
const 714 template <
typename ReturnType,
typename ExpressionType>
715 template <
class Archive>
732 template <
typename ReturnType,
typename ExpressionType>
733 template <
typename StreamType>
736 StreamType &stream)
const 738 stream <<
"Common subexpression elimination: \n";
739 stream <<
" Intermediate reduced expressions: \n";
742 const SymEngine::RCP<const SymEngine::Basic> &cse_symbol =
744 const SymEngine::RCP<const SymEngine::Basic> &cse_expr =
746 stream <<
" " << i <<
": " << cse_symbol <<
" = " << cse_expr
750 stream <<
" Final reduced expressions for dependent variables: \n";
754 stream << std::flush;
759 template <
typename ReturnType,
typename ExpressionType>
773 template <
typename ReturnType,
typename ExpressionType>
783 template <
typename ReturnType,
typename ExpressionType>
796 template <
typename ReturnType,
typename ExpressionType>
808 template <
typename ReturnType,
typename ExpressionType>
811 const SymEngine::vec_basic &inputs,
812 const SymEngine::Basic & output,
822 template <
typename ReturnType,
typename ExpressionType>
825 const SymEngine::vec_basic &inputs,
826 const SymEngine::vec_basic &outputs,
836 template <
typename ReturnType,
typename ExpressionType>
843 independent_symbols.clear();
844 dependent_functions.clear();
846 independent_symbols = inputs;
853 if (use_cse ==
false)
854 dependent_functions = outputs;
863 template <
typename ReturnType,
typename ExpressionType>
866 const std::vector<ReturnType> &substitution_values)
869 dependent_functions.size() == 1,
871 "Cannot use this call function when more than one symbolic expression is to be evaluated."));
873 substitution_values.size() == independent_symbols.size(),
875 "Input substitution vector does not match size of symbol vector."));
878 call(&out, substitution_values.data());
884 template <
typename ReturnType,
typename ExpressionType>
887 ReturnType * output_values,
888 const ReturnType *substitution_values)
893 cse.call(output_values, independent_symbols, substitution_values);
898 SymEngine::map_basic_basic substitution_value_map;
899 for (
unsigned i = 0; i < independent_symbols.size(); ++i)
900 substitution_value_map[independent_symbols[i]] =
901 static_cast<const SymEngine::RCP<const SymEngine::Basic> &
>(
902 ExpressionType(substitution_values[i]));
909 for (
unsigned i = 0; i < dependent_functions.size(); ++i)
911 ExpressionType(dependent_functions[i])
912 .template substitute_and_evaluate<ReturnType>(
913 substitution_value_map);
919 template <
typename ReturnType,
typename ExpressionType>
920 template <
class Archive>
924 const unsigned int version)
const 935 ar &independent_symbols;
936 cse.save(ar, version);
937 ar &dependent_functions;
942 template <
typename ReturnType,
typename ExpressionType>
943 template <
class Archive>
947 const unsigned int version)
957 ar &independent_symbols;
958 cse.load(ar, version);
959 ar &dependent_functions;
964 template <
typename ReturnType,
typename ExpressionType>
965 template <
typename StreamType>
969 const bool print_independent_symbols,
970 const bool print_dependent_functions,
971 const bool print_cse_reductions)
const 973 if (print_independent_symbols)
975 stream <<
"Independent variables: \n";
976 for (
unsigned i = 0; i < independent_symbols.size(); ++i)
977 stream <<
" " << i <<
": " << independent_symbols[i] <<
"\n";
979 stream << std::flush;
983 if (print_cse_reductions && cse.executed())
991 if (print_dependent_functions)
993 stream <<
"Dependent variables: \n";
994 for (
unsigned i = 0; i < dependent_functions.size(); ++i)
995 stream <<
" " << i << dependent_functions[i] <<
"\n";
997 stream << std::flush;
1011 #endif // DEAL_II_WITH_SYMENGINE 1013 #endif // dealii_differentiation_sd_symengine_number_visitor_internal_h std::vector< std::pair< SD::Expression, SD::Expression > > symbol_vector_pair
virtual ~CSEDictionaryVisitor()=default
void init(const types::symbol_vector &dependent_functions)
static constexpr const T & value(const T &t)
symbol_vector_pair intermediate_symbols_exprs
void init(const types::symbol_vector &independent_symbols, const Expression &dependent_function, const bool use_cse=false)
SD::types::symbol_vector convert_basic_vector_to_expression_vector(const SymEngine::vec_basic &symbol_vector)
void load(Archive &archive, const unsigned int version)
unsigned int n_intermediate_expressions() const
void save(Archive &archive, const unsigned int version) const
std::vector< std::pair< Expression, Expression > > convert_basic_pair_vector_to_expression_pair_vector(const SymEngine::vec_pair &symbol_value_vector)
#define DEAL_II_DISABLE_EXTRA_DIAGNOSTICS
void print(StreamType &stream) const
unsigned int n_reduced_expressions() const
static ::ExceptionBase & ExcMessage(std::string arg1)
types::symbol_vector reduced_exprs
#define Assert(cond, exc)
SD::types::symbol_vector independent_symbols
#define DEAL_II_NAMESPACE_CLOSE
void call(ReturnType *output_values, const types::symbol_vector &independent_symbols, const ReturnType *substitution_values)
SymEngine::vec_basic convert_expression_vector_to_basic_vector(const SD::types::symbol_vector &symbol_vector)
void serialize(Archive &archive, const unsigned int version)
std::vector< SD::Expression > symbol_vector
CSEDictionaryVisitor< ReturnType, ExpressionType > cse
void load(Archive &archive, const unsigned int version)
void call(ReturnType *output_values, const ReturnType *substitution_values)
#define DEAL_II_ENABLE_EXTRA_DIAGNOSTICS
#define DEAL_II_NAMESPACE_OPEN
CSEDictionaryVisitor()=default
void print(StreamType &stream, const bool print_independent_symbols=false, const bool print_dependent_functions=false, const bool print_cse_reductions=false) const
SD::types::symbol_vector dependent_functions
static ::ExceptionBase & ExcInternalError()
void save(Archive &archive, const unsigned int version) const