Skip to content

Commit

Permalink
Merge pull request #282 from PowerGridModel/feature/remove-enable-if
Browse files Browse the repository at this point in the history
Feature/remove enable if
  • Loading branch information
TonyXiang8787 committed Jun 26, 2023
2 parents 70ed5ce + 8203849 commit e93174d
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 134 deletions.
5 changes: 3 additions & 2 deletions docs/advanced_documentation/build-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ You need a C++ compiler with C++20 support. Below is a list of tested compilers:

**Linux**

* gcc >= 10.0
* Version 10.0 tested using the version in the `manylinux2014` container.
* gcc >= 11.0
* Version 12.x tested using the version in the `manylinux_2_28` container.
* Version 12.x tested using the musllinux build with custom compiler
* Version 11.x tested in CI
* Clang >= 14.0
* Version 14.x tested in CI
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class DataPointer {

// conversion to const iterator
template <class UX = DataPointer<true>>
operator std::enable_if_t<!is_const, UX>() const {
requires(!is_const) operator UX() const {
return DataPointer<true>{ptr_, indptr_, batch_size_, elements_per_scenario_};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,8 @@ struct DataAttribute {
};

template <class BaseType, auto member_ptr>
inline std::enable_if_t<std::is_same_v<BaseType, typename trait_pointer_to_member<decltype(member_ptr)>::struct_type>,
size_t>
get_offset() {
requires std::same_as<BaseType, typename trait_pointer_to_member<decltype(member_ptr)>::struct_type>
inline size_t get_offset() {
using struct_type = typename trait_pointer_to_member<decltype(member_ptr)>::struct_type;
struct_type const obj{};
return (size_t)(&(obj.*member_ptr)) - (size_t)&obj;
Expand All @@ -149,9 +148,8 @@ constexpr bool is_little_endian() {
}

template <class BaseType, auto member_ptr>
inline std::enable_if_t<std::is_same_v<BaseType, typename trait_pointer_to_member<decltype(member_ptr)>::struct_type>,
DataAttribute>
get_data_attribute(std::string const& name) {
requires std::same_as<BaseType, typename trait_pointer_to_member<decltype(member_ptr)>::struct_type>
inline DataAttribute get_data_attribute(std::string_view const& name) {
using value_type = typename trait_pointer_to_member<decltype(member_ptr)>::value_type;
using single_data_type = data_type<value_type>;
DataAttribute attr{};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,8 @@ namespace meta_data {
// template function to add meta data
template <class CT>
void add_meta_data(AllPowerGridMetaData& meta) {
// TODO, remove this separate definition for UpdateType after migrating to gcc-11
// this is due to a wired bug in gcc-10
using UpdateType = typename CT::UpdateType;
meta["input"][CT::name] = get_meta<typename CT::InputType>{}();
meta["update"][CT::name] = get_meta<UpdateType>{}();
meta["update"][CT::name] = get_meta<typename CT::UpdateType>{}();
meta["sym_output"][CT::name] = get_meta<typename CT::template OutputType<true>>{}();
meta["asym_output"][CT::name] = get_meta<typename CT::template OutputType<false>>{}();
meta["sc_output"][CT::name] = get_meta<typename CT::ShortCircuitOutputType>{}();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,11 @@ class Container<RetrievableTypes<GettableTypes...>, StorageableTypes...> {

// get item per type
template <class GettableBaseType, class StorageableSubType>
GettableBaseType& get_raw(Idx pos) {
static_assert(std::is_base_of_v<GettableBaseType, StorageableSubType>);
requires std::derived_from<StorageableSubType, GettableBaseType> GettableBaseType& get_raw(Idx pos) {
return std::get<std::vector<StorageableSubType>>(vectors_)[pos];
}
template <class GettableBaseType, class StorageableSubType>
GettableBaseType const& get_raw(Idx pos) const {
static_assert(std::is_base_of_v<GettableBaseType, StorageableSubType>);
requires std::derived_from<StorageableSubType, GettableBaseType> GettableBaseType const& get_raw(Idx pos) const {
return std::get<std::vector<StorageableSubType>>(vectors_)[pos];
}

Expand All @@ -248,8 +246,8 @@ class Container<RetrievableTypes<GettableTypes...>, StorageableTypes...> {
static constexpr GetItemFuncPtrConst<GettableBaseType> ptr_const = nullptr;
};
template <class GettableBaseType, class StorageableSubType>
struct select_get_item_func_ptr<GettableBaseType, StorageableSubType,
std::enable_if_t<std::is_base_of_v<GettableBaseType, StorageableSubType>>> {
requires std::derived_from<StorageableSubType, GettableBaseType>
struct select_get_item_func_ptr<GettableBaseType, StorageableSubType> {
static constexpr GetItemFuncPtr<GettableBaseType> ptr =
&Container::get_raw<GettableBaseType, StorageableSubType>;
static constexpr GetItemFuncPtrConst<GettableBaseType> ptr_const =
Expand Down Expand Up @@ -309,7 +307,7 @@ class Container<RetrievableTypes<GettableTypes...>, StorageableTypes...> {
}
// conversion to const iterator
template <class ConstGettable = Gettable>
operator std::enable_if_t<!is_const, Iterator<ConstGettable const>>() const {
requires(!is_const) explicit operator Iterator<ConstGettable const>() const {
return Iterator<ConstGettable const>{container_ptr_, idx_};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,9 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
// template to construct components
// using forward interators
// different selection based on component type
template <class CompType, class ForwardIterator>
std::enable_if_t<std::is_base_of_v<Base, CompType>> add_component(ForwardIterator begin, ForwardIterator end) {
template <std::derived_from<Base> CompType, std::forward_iterator ForwardIterator>
void add_component(ForwardIterator begin, ForwardIterator end) {
assert(!construction_complete_);
// check forward iterator
static_assert(std::is_base_of_v<std::forward_iterator_tag,
typename std::iterator_traits<ForwardIterator>::iterator_category>);
size_t size = std::distance(begin, end);
components_.template reserve<CompType>(size);
// loop to add component
Expand All @@ -158,7 +155,7 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
double const u1 = components_.template get_item<Node>(input.from_node).u_rated();
double const u2 = components_.template get_item<Node>(input.to_node).u_rated();
// set system frequency for line
if constexpr (std::is_same_v<CompType, Line>) {
if constexpr (std::same_as<CompType, Line>) {
components_.template emplace<CompType>(id, input, system_frequency_, u1, u2);
}
else {
Expand Down Expand Up @@ -226,12 +223,9 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
// using forward interators
// different selection based on component type
// if sequence_idx is given, it will be used to load the object instead of using IDs via hash map.
template <class CompType, class CacheType, class ForwardIterator>
template <class CompType, class CacheType, std::forward_iterator ForwardIterator>
void update_component(ForwardIterator begin, ForwardIterator end, std::vector<Idx2D> const& sequence_idx = {}) {
assert(construction_complete_);
// check forward iterator
static_assert(std::is_base_of_v<std::forward_iterator_tag,
typename std::iterator_traits<ForwardIterator>::iterator_category>);
bool const has_sequence_id = !sequence_idx.empty();
Idx seq = 0;
// loop to to update component
Expand Down Expand Up @@ -719,12 +713,8 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
}

// output node
template <bool sym, class Component, class ResIt>
std::enable_if_t<
std::is_base_of_v<std::forward_iterator_tag, typename std::iterator_traits<ResIt>::iterator_category> &&
std::is_same_v<Node, Component>,
ResIt>
output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
template <bool sym, std::same_as<Node> Component, std::forward_iterator ResIt>
ResIt output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
assert(construction_complete_);
return std::transform(components_.template citer<Component>().begin(),
components_.template citer<Component>().end(), comp_coup_->node.cbegin(), res_it,
Expand All @@ -738,12 +728,8 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
}

// output branch
template <bool sym, class Component, class ResIt>
std::enable_if_t<
std::is_base_of_v<std::forward_iterator_tag, typename std::iterator_traits<ResIt>::iterator_category> &&
std::is_base_of_v<Branch, Component>,
ResIt>
output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
template <bool sym, std::derived_from<Branch> Component, std::forward_iterator ResIt>
ResIt output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
assert(construction_complete_);
return std::transform(components_.template citer<Component>().begin(),
components_.template citer<Component>().end(),
Expand All @@ -757,12 +743,8 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
}

// output branch3
template <bool sym, class Component, class ResIt>
std::enable_if_t<
std::is_base_of_v<std::forward_iterator_tag, typename std::iterator_traits<ResIt>::iterator_category> &&
std::is_base_of_v<Branch3, Component>,
ResIt>
output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
template <bool sym, std::derived_from<Branch3> Component, std::forward_iterator ResIt>
ResIt output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
assert(construction_complete_);
return std::transform(components_.template citer<Component>().begin(),
components_.template citer<Component>().end(),
Expand All @@ -779,12 +761,8 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
}

// output source, load_gen, shunt individually
template <bool sym, class Component, class ResIt>
std::enable_if_t<
std::is_base_of_v<std::forward_iterator_tag, typename std::iterator_traits<ResIt>::iterator_category> &&
std::is_same_v<Appliance, Component>,
ResIt>
output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
template <bool sym, std::same_as<Appliance> Component, std::forward_iterator ResIt>
ResIt output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
assert(construction_complete_);
res_it = output_result<sym, Source>(math_output, res_it);
res_it = output_result<sym, GenericLoadGen>(math_output, res_it);
Expand All @@ -793,12 +771,8 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
}

// output source
template <bool sym, class Component, class ResIt>
std::enable_if_t<
std::is_base_of_v<std::forward_iterator_tag, typename std::iterator_traits<ResIt>::iterator_category> &&
std::is_same_v<Source, Component>,
ResIt>
output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
template <bool sym, std::same_as<Source> Component, std::forward_iterator ResIt>
ResIt output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
assert(construction_complete_);
return std::transform(components_.template citer<Component>().begin(),
components_.template citer<Component>().end(), comp_coup_->source.cbegin(), res_it,
Expand All @@ -811,12 +785,8 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
}

// output load gen
template <bool sym, class Component, class ResIt>
std::enable_if_t<
std::is_base_of_v<std::forward_iterator_tag, typename std::iterator_traits<ResIt>::iterator_category> &&
std::is_base_of_v<GenericLoadGen, Component>,
ResIt>
output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
template <bool sym, std::derived_from<GenericLoadGen> Component, std::forward_iterator ResIt>
ResIt output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
assert(construction_complete_);
return std::transform(
components_.template citer<Component>().begin(), components_.template citer<Component>().end(),
Expand All @@ -830,12 +800,8 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
}

// output shunt
template <bool sym, class Component, class ResIt>
std::enable_if_t<
std::is_base_of_v<std::forward_iterator_tag, typename std::iterator_traits<ResIt>::iterator_category> &&
std::is_same_v<Shunt, Component>,
ResIt>
output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
template <bool sym, std::same_as<Shunt> Component, std::forward_iterator ResIt>
ResIt output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
assert(construction_complete_);
return std::transform(components_.template citer<Component>().begin(),
components_.template citer<Component>().end(), comp_coup_->shunt.cbegin(), res_it,
Expand All @@ -848,12 +814,8 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
}

// output voltage sensor
template <bool sym, class Component, class ResIt>
std::enable_if_t<
std::is_base_of_v<std::forward_iterator_tag, typename std::iterator_traits<ResIt>::iterator_category> &&
std::is_base_of_v<GenericVoltageSensor, Component>,
ResIt>
output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
template <bool sym, std::derived_from<GenericVoltageSensor> Component, std::forward_iterator ResIt>
ResIt output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
assert(construction_complete_);
return std::transform(
components_.template citer<Component>().begin(), components_.template citer<Component>().end(),
Expand All @@ -869,12 +831,8 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
}

// output power sensor
template <bool sym, class Component, class ResIt>
std::enable_if_t<
std::is_base_of_v<std::forward_iterator_tag, typename std::iterator_traits<ResIt>::iterator_category> &&
std::is_base_of_v<GenericPowerSensor, Component>,
ResIt>
output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
template <bool sym, std::derived_from<GenericPowerSensor> Component, std::forward_iterator ResIt>
ResIt output_result(std::vector<MathOutput<sym>> const& math_output, ResIt res_it) {
assert(construction_complete_);
return std::transform(
components_.template citer<Component>().begin(), components_.template citer<Component>().end(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace power_grid_model {
// hide implementation in inside namespace
namespace math_model_impl {

template <class T, bool sym, bool is_tensor, int n_sub_block, class = std::enable_if_t<check_scalar_v<T>>>
template <scalar_value T, bool sym, bool is_tensor, int n_sub_block>
struct block_trait {
static constexpr int n_row = sym ? n_sub_block : n_sub_block * 3;
static constexpr int n_col = is_tensor ? (sym ? n_sub_block : n_sub_block * 3) : 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,30 @@ template <class Tensor, class RHSVector, class XVector, class = void>
struct sparse_lu_entry_trait;

template <class Tensor, class RHSVector, class XVector>
using enable_scalar_lu_t =
std::enable_if_t<std::is_same_v<Tensor, RHSVector> && std::is_same_v<Tensor, XVector> && check_scalar_v<Tensor>>;
concept scalar_value_lu = scalar_value<Tensor> && std::same_as<Tensor, RHSVector> && std::same_as<Tensor, XVector>;

// TODO(mgovers) improve this concept
template <class Derived>
int check_array_base(Eigen::ArrayBase<Derived> const&) {
return 0;
}
template <class ArrayLike>
concept eigen_array = std::same_as<decltype(check_array_base(ArrayLike{})), int>; // should be an eigen array

template <class LHSArrayLike, class RHSArrayLike>
concept matrix_multiplicable = eigen_array<LHSArrayLike> && eigen_array<RHSArrayLike> &&
(static_cast<Idx>(LHSArrayLike::ColsAtCompileTime) == static_cast<Idx>(RHSArrayLike::RowsAtCompileTime));

template <class Tensor, class RHSVector, class XVector>
using enable_tensor_lu_t = std::enable_if_t<
std::is_same_v<decltype(check_array_base(Tensor{})), int> && // tensor should be an eigen array
std::is_same_v<decltype(check_array_base(RHSVector{})), int> && // rhs vector should be an eigen array
std::is_same_v<decltype(check_array_base(XVector{})), int> && // x vector should be an eigen array
(Idx)Tensor::RowsAtCompileTime == (Idx)Tensor::ColsAtCompileTime && // tensor should be square
RHSVector::ColsAtCompileTime == 1 && // rhs vector should be column vector
(Idx)RHSVector::RowsAtCompileTime == (Idx)Tensor::RowsAtCompileTime && // rhs vector should be column vector
XVector::ColsAtCompileTime == 1 && // x vector should be column vector
(Idx)XVector::RowsAtCompileTime == (Idx)Tensor::RowsAtCompileTime && // x vector should be column vector
std::is_same_v<typename Tensor::Scalar, typename RHSVector::Scalar> && // all entries should have same scalar type
std::is_same_v<typename Tensor::Scalar, typename XVector::Scalar> && // all entries should have same scalar type
check_scalar_v<typename Tensor::Scalar>>; // scalar can only be double or complex double
concept tensor_lu = rk2_tensor<Tensor> && column_vector<RHSVector> && column_vector<XVector> &&
matrix_multiplicable<Tensor, RHSVector> && matrix_multiplicable<Tensor, XVector> &&
std::same_as<typename Tensor::Scalar, typename RHSVector::Scalar> && // all entries should have same scalar type
std::same_as<typename Tensor::Scalar, typename XVector::Scalar> && // all entries should have same scalar type
scalar_value<typename Tensor::Scalar>; // scalar can only be double or complex double

template <class Tensor, class RHSVector, class XVector>
struct sparse_lu_entry_trait<Tensor, RHSVector, XVector, enable_scalar_lu_t<Tensor, RHSVector, XVector>> {
requires scalar_value_lu<Tensor, RHSVector, XVector>
struct sparse_lu_entry_trait<Tensor, RHSVector, XVector> {
static constexpr bool is_block = false;
static constexpr Idx block_size = 1;
using Scalar = Tensor;
Expand All @@ -55,7 +55,8 @@ struct sparse_lu_entry_trait<Tensor, RHSVector, XVector, enable_scalar_lu_t<Tens
};

template <class Tensor, class RHSVector, class XVector>
struct sparse_lu_entry_trait<Tensor, RHSVector, XVector, enable_tensor_lu_t<Tensor, RHSVector, XVector>> {
requires tensor_lu<Tensor, RHSVector, XVector>
struct sparse_lu_entry_trait<Tensor, RHSVector, XVector> {
static constexpr bool is_block = true;
static constexpr Idx block_size = Tensor::RowsAtCompileTime;
using Scalar = typename Tensor::Scalar;
Expand Down

0 comments on commit e93174d

Please sign in to comment.