diff --git a/larq_compute_engine/core/BUILD b/larq_compute_engine/core/BUILD
index 48060387e..0bbdcdfad 100644
--- a/larq_compute_engine/core/BUILD
+++ b/larq_compute_engine/core/BUILD
@@ -9,16 +9,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "bgemm_functor",
-    hdrs = [
-        "bgemm_functor.h",
-    ],
-    deps = [
-        ":types",
-    ],
-)
-
 cc_library(
     name = "bitpack",
     hdrs = ["bitpack.h"] + select({
@@ -51,19 +41,6 @@ cc_library(
     hdrs = ["padding_functor.h"],
 )
 
-cc_library(
-    name = "bgemm_ref",
-    hdrs = [
-        "bgemm_impl_ref.h",
-    ],
-    deps = [
-        "//larq_compute_engine/core:bgemm_functor",
-        "@org_tensorflow//tensorflow/lite/kernels:cpu_backend_context",
-        "@org_tensorflow//tensorflow/lite/kernels:cpu_backend_gemm",
-        "@ruy//ruy/profiler:instrumentation",
-    ],
-)
-
 cc_library(
     name = "bconv2d_output_transform",
     hdrs = [
@@ -108,15 +85,14 @@ cc_library(
     deps = [
         ":bgemm_kernels_arm",
         ":bitpack",
-        "//larq_compute_engine/core:bgemm_functor",
         "@ruy//ruy/profiler:instrumentation",
     ],
 )
 
 cc_library(
-    name = "bgemm_ruy",
+    name = "bgemm_impl",
     hdrs = [
-        "bgemm_impl_ruy.h",
+        "bgemm_impl.h",
         "bgemm_trmul_params.h",
         "ruy_pack.h",
     ],
@@ -128,18 +104,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "bgemm_impl",
-    hdrs = [
-        "bgemm_impl.h",
-    ],
-    deps = [
-        ":bgemm_ref",
-        ":bgemm_ruy",
-        "@ruy//ruy/profiler:instrumentation",
-    ],
-)
-
 cc_library(
     name = "bconv2d_impl_ref",
     hdrs = [
@@ -147,7 +111,6 @@ cc_library(
     ],
     deps = [
         ":bconv2d_output_transform",
-        ":bgemm_functor",
         "@org_tensorflow//tensorflow/lite/kernels/internal:types",
     ],
 )
diff --git a/larq_compute_engine/core/bconv2d_impl_ref.h b/larq_compute_engine/core/bconv2d_impl_ref.h
index e24397f4c..2996e8eec 100644
--- a/larq_compute_engine/core/bconv2d_impl_ref.h
+++ b/larq_compute_engine/core/bconv2d_impl_ref.h
@@ -17,7 +17,6 @@ limitations under the License.
 #define COMPUTE_ENGINE_CORE_BCONV2D_IMPL_REF_H_
 
 #include "larq_compute_engine/core/bconv2d_output_transform.h"
-#include "larq_compute_engine/core/bgemm_functor.h"
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/types.h"
diff --git a/larq_compute_engine/core/bgemm_functor.h b/larq_compute_engine/core/bgemm_functor.h
deleted file mode 100644
index d974f914f..000000000
--- a/larq_compute_engine/core/bgemm_functor.h
+++ /dev/null
@@ -1,71 +0,0 @@
-#ifndef COMPUTE_ENGINE_KERNELS_BGEMM_FUNCTORS_H_
-#define COMPUTE_ENGINE_KERNELS_BGEMM_FUNCTORS_H_
-
-#include <cstddef>
-#include <cstdint>
-#include <type_traits>
-
-#include "larq_compute_engine/core/types.h"
-
-namespace compute_engine {
-namespace core {
-
-enum class Layout { RowMajor, ColMajor };
-
-using compute_engine::core::bitpacking_bitwidth;
-using compute_engine::core::TBitpacked;
-
-inline std::int32_t compute_binary_inner_prod(const TBitpacked& a,
-                                              const TBitpacked& b) {
-  // TODO: __builtin_popcount works only with GCC compiler -> implement a
-  // generalized version.
-  return bitpacking_bitwidth -
-         2 * static_cast<std::int32_t>(__builtin_popcount(a ^ b));
-}
-
-inline std::int32_t xor_popcount(const TBitpacked& a, const TBitpacked& b) {
-  return __builtin_popcount(a ^ b);
-}
-
-// A naive implementation of binary matrix multiplication, useful for
-// debugging and understanding the algorithm.
-template <typename TAccum, class TOut, Layout LLhs, Layout LRhs, Layout LOut>
-class ReferenceBGemmFunctor {
- public:
-  void operator()(const std::size_t m, const std::size_t n, const std::size_t k,
-                  const TBitpacked* a, const std::size_t lda,
-                  const TBitpacked* b, const std::size_t ldb, TOut* c,
-                  const std::size_t ldc, const int bitpaddding = 0) {
-    static_assert(std::is_signed<TOut>::value,
-                  "Output of BGEMM should be of a signed type.");
-
-    const std::size_t a_i_stride = (LLhs == Layout::RowMajor ? lda : 1);
-    const std::size_t a_l_stride = (LLhs == Layout::RowMajor ? 1 : lda);
-    const std::size_t b_j_stride = (LRhs == Layout::RowMajor ? 1 : ldb);
-    const std::size_t b_l_stride = (LRhs == Layout::RowMajor ? ldb : 1);
-    const std::size_t c_i_stride = (LOut == Layout::RowMajor ? ldc : 1);
-    const std::size_t c_j_stride = (LOut == Layout::RowMajor ? 1 : ldc);
-
-    std::size_t i, j, l;
-    // The j-loop should be the inner loop for weight-stationary computations
-    for (i = 0; i < m; ++i) {
-      for (j = 0; j < n; ++j) {
-        TAccum total(0);
-        for (l = 0; l < k; ++l) {
-          const std::size_t a_index = ((i * a_i_stride) + (l * a_l_stride));
-          const std::size_t b_index = ((j * b_j_stride) + (l * b_l_stride));
-          total += compute_binary_inner_prod(a[a_index], b[b_index]);
-        }
-        const std::size_t c_index = ((i * c_i_stride) + (j * c_j_stride));
-        c[c_index] = static_cast<TOut>(total - bitpaddding);
-      }  // end of j loop
-    }    // end of i loop
-  }
-};
-
-}  // namespace core
-}  // namespace compute_engine
-
-#endif  // COMPUTE_ENGINE_KERNELS_BGEMM_FUNCTORS_H_
diff --git a/larq_compute_engine/core/bgemm_impl.h b/larq_compute_engine/core/bgemm_impl.h
index 9f1245cf2..0c9db42bb 100644
--- a/larq_compute_engine/core/bgemm_impl.h
+++ b/larq_compute_engine/core/bgemm_impl.h
@@ -2,16 +2,15 @@
 #define COMPUTE_ENGINE_CORE_BGEMM_IMPL_H_
 
 #include "bgemm_kernels_common.h"
+#include "bgemm_trmul_params.h"
+#include "ruy/context.h"
+#include "ruy/context_get_ctx.h"
+#include "ruy/matrix.h"
+#include "ruy/platform.h"
 #include "ruy/profiler/instrumentation.h"
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
 #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
-
-// TODO: currently only ref. impl. is supported
-#ifndef TFLITE_WITH_RUY
-#include "bgemm_impl_ref.h"
-#else
-#include "bgemm_impl_ruy.h"
-#endif
+#include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h"
 
 using namespace tflite;
 using namespace tflite::cpu_backend_gemm;
@@ -19,14 +18,7 @@ using namespace tflite::cpu_backend_gemm;
 namespace compute_engine {
 namespace tflite {
 
-#ifndef TFLITE_WITH_RUY
-template <typename AccumScalar, typename DstScalar>
-struct BGemmImpl : BGemmImplRef<AccumScalar, DstScalar> {
-};
-#else
-template <typename AccumScalar, typename DstScalar>
-struct BGemmImpl : BGemmImplUsingRuy<AccumScalar, DstScalar> {};
-#endif
+using compute_engine::core::TBitpacked;
 
 template <typename AccumScalar, typename DstScalar>
 void BGemm(const MatrixParams<TBitpacked>& lhs_params,
@@ -34,21 +26,61 @@ void BGemm(const MatrixParams<TBitpacked>& lhs_params,
            const MatrixParams<TBitpacked>& rhs_params,
            const TBitpacked* rhs_data,
            const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
-           const OutputTransform<DstScalar>& params,
+           const OutputTransform<DstScalar>& output_transform,
            CpuBackendContext* context) {
-  ruy::profiler::ScopeLabel label("BGemm");
-  // TODO: special fast bgemm impl. for matrix-vector multiplication
-  // if (dst_params.cols == 1) {
-  //   // GEMV case: try a custom fast GEMV path.
-  //   if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
-  //                          dst_params, dst_data, params, context)) {
-  //     return;
-  //   }
-  // }
-  ruy::profiler::ScopeLabel label2("BGemm/GeneralBGEMM");
-  BGemmImpl<AccumScalar, DstScalar>::Run(lhs_params, lhs_data, rhs_params,
-                                         rhs_data, dst_params, dst_data, params,
-                                         context);
+  ruy::profiler::ScopeLabel label("BGemm (Ruy)");
+
+  static_assert(std::is_signed<DstScalar>::value,
+                "Output of BGEMM should be of a signed type.");
+
+  // Get ruy context
+  auto ruy_ctx = get_ctx(context->ruy_context());
+
+  // Set up the matrix layouts and mul_params.
+  ruy::Matrix<TBitpacked> lhs;
+  ruy::Matrix<TBitpacked> rhs;
+  ruy::Matrix<DstScalar> dst;
+  // We allow these matrices to be cached. Note that this doesn't force them
+  // to be cached; it means that the `cache_policy` of the MatrixParams will
+  // be respected.
+  cpu_backend_gemm::detail::MakeRuyMatrix(lhs_params, lhs_data, &lhs,
+                                          /*use_caching=*/true);
+  cpu_backend_gemm::detail::MakeRuyMatrix(rhs_params, rhs_data, &rhs,
+                                          /*use_caching=*/true);
+  cpu_backend_gemm::detail::MakeRuyMatrix(dst_params, dst_data, &dst);
+
+  // We have to make this a `const` matrix because otherwise gcc will try to
+  // use the non-const versions of `matrix.data()`
+  ruy::Mat<TBitpacked> internal_lhs =
+      ruy::ToInternal((const ruy::Matrix<TBitpacked>)lhs);
+  ruy::Mat<TBitpacked> internal_rhs =
+      ruy::ToInternal((const ruy::Matrix<TBitpacked>)rhs);
+  ruy::Mat<DstScalar> internal_dst = ruy::ToInternal(dst);
+
+  BinaryMulParams<AccumScalar, DstScalar> mul_params;
+  mul_params.output_transform = output_transform;
+
+#if RUY_PLATFORM_NEON
+  constexpr bool HasOptimizedNeonKernel =
+      std::is_same<AccumScalar, std::int16_t>::value ||
+      std::is_same<DstScalar, float>::value ||
+      std::is_same<DstScalar, std::int8_t>::value;
+  constexpr auto SelectedPath =
+      HasOptimizedNeonKernel ? ruy::Path::kNeon : ruy::Path::kStandardCpp;
+#else
+  constexpr auto SelectedPath = ruy::Path::kStandardCpp;
+#endif
+
+  ruy::Mat<TBitpacked> transposed_lhs(internal_lhs);
+  Transpose(&transposed_lhs);
+
+  ruy::TrMulParams bgemm_trmul_params;
+  PopulateBGemmTrMulParams<SelectedPath>(transposed_lhs, internal_rhs,
+                                         internal_dst, mul_params,
+                                         &bgemm_trmul_params);
+
+  HandlePrepackedCaching(&bgemm_trmul_params, ruy_ctx);
+  ruy::TrMul(&bgemm_trmul_params, ruy_ctx);
 }
 
 }  // namespace tflite
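Reviewer note on bgemm_impl.h: with the `TFLITE_WITH_RUY` dispatch and the `BGemmImpl` indirection gone, `BGemm` above is the single entry point into the binary GEMM. The sketch below shows the caller's-eye contract; the function name, shapes, and `output_transform` are hypothetical (the real call sites are the bconv2d kernels), while `MatrixParams`/`Order` are TF Lite cpu_backend_gemm types that bgemm_impl.h already pulls in via `using namespace`.

```cpp
#include "larq_compute_engine/core/bgemm_impl.h"

using compute_engine::core::TBitpacked;

// Hypothetical caller: an m x n float GEMM over operands whose depth is
// k bitpacked words, with std::int32_t accumulators.
void ExampleBGemmCall(const TBitpacked* lhs_data, const TBitpacked* rhs_data,
                      float* dst_data, int m, int n, int k,
                      const compute_engine::tflite::OutputTransform<float>&
                          output_transform,
                      CpuBackendContext* context) {
  MatrixParams<TBitpacked> lhs_params;
  lhs_params.order = Order::kRowMajor;  // bitpacked filters
  lhs_params.rows = m;
  lhs_params.cols = k;

  MatrixParams<TBitpacked> rhs_params;
  rhs_params.order = Order::kColMajor;  // bitpacked activations
  rhs_params.rows = k;
  rhs_params.cols = n;

  MatrixParams<float> dst_params;
  dst_params.order = Order::kColMajor;
  dst_params.rows = m;
  dst_params.cols = n;

  compute_engine::tflite::BGemm<std::int32_t, float>(
      lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
      output_transform, context);
}
```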
diff --git a/larq_compute_engine/core/bgemm_impl_ref.h b/larq_compute_engine/core/bgemm_impl_ref.h
deleted file mode 100644
index 17d14df91..000000000
--- a/larq_compute_engine/core/bgemm_impl_ref.h
+++ /dev/null
@@ -1,64 +0,0 @@
-#ifndef COMPUTE_ENGINE_CORE_BGEMM_IMPL_REF_H_
-#define COMPUTE_ENGINE_CORE_BGEMM_IMPL_REF_H_
-
-#include "larq_compute_engine/core/bgemm_functor.h"
-#include "ruy/profiler/instrumentation.h"
-#include "tensorflow/lite/kernels/cpu_backend_context.h"
-#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
-
-using namespace tflite;
-using namespace tflite::cpu_backend_gemm;
-
-namespace compute_engine {
-
-namespace ce = compute_engine;
-
-namespace tflite {
-
-using ce::core::TBitpacked;
-
-template <typename AccumScalar, typename DstScalar>
-struct BGemmImplRef {
-  static void Run(
-      const MatrixParams<TBitpacked>& lhs_params, const TBitpacked* lhs_data,
-      const MatrixParams<TBitpacked>& rhs_params, const TBitpacked* rhs_data,
-      const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
-      const GemmParams<AccumScalar, DstScalar>& params,
-      CpuBackendContext* context) {
-    ruy::profiler::ScopeLabel label("BGemmRef");
-
-    static_assert(std::is_signed<DstScalar>::value,
-                  "Output of BGEMM should be of a signed type.");
-
-    // This code assumes specific memory layout
-    // assert(rhs_params.order == cpu_backend_gemm::Order::kColMajor);
-    using TBGemmFunctor =
-        ce::core::ReferenceBGemmFunctor<AccumScalar, DstScalar,
-                                        ce::core::Layout::RowMajor,
-                                        ce::core::Layout::ColMajor,
-                                        ce::core::Layout::ColMajor>;
-
-    // LHS (n, k) -> RowMajor -> (n, k)
-    // RHS (m, k) -> ColMajor -> (k, m)
-    // DST (n, m) -> ColMajor -> (m, n)
-    const auto n = lhs_params.rows;
-    const auto k = lhs_params.cols;
-    const auto m = rhs_params.cols;
-    const auto lda = lhs_params.cols;
-    // use number of rows for col-major layout
-    const auto ldb = rhs_params.rows;
-    const auto ldc = dst_params.rows;
-    TBGemmFunctor bgemm_functor;
-    // TODO: Currently GemmParmas is not used the same way as
-    // as its used in the TF Lite codebase. Here, we abuse the
-    // 'multiplier_exponent' which is used only for non-floating-point
-    // cases to pass the bitpadding correction value (int) to BGemm
-    bgemm_functor(n, m, k, lhs_data, lda, rhs_data, ldb, dst_data, ldc,
-                  params.multiplier_exponent);
-  }
-};
-
-}  // namespace tflite
-}  // namespace compute_engine
-
-#endif  // COMPUTE_ENGINE_CORE_BGEMM_IMPL_REF_H_
diff --git a/larq_compute_engine/core/bgemm_impl_ruy.h b/larq_compute_engine/core/bgemm_impl_ruy.h
deleted file mode 100644
index 50e4d398c..000000000
--- a/larq_compute_engine/core/bgemm_impl_ruy.h
+++ /dev/null
@@ -1,93 +0,0 @@
-#ifndef COMPUTE_ENGINE_CORE_BGEMM_IMPL_RUY_H_
-#define COMPUTE_ENGINE_CORE_BGEMM_IMPL_RUY_H_
-
-#include "bgemm_kernels_common.h"
-#include "bgemm_trmul_params.h"
-#include "ruy/context.h"
-#include "ruy/context_get_ctx.h"
-#include "ruy/matrix.h"
-#include "ruy/platform.h"
-#include "ruy/profiler/instrumentation.h"
-#include "tensorflow/lite/kernels/cpu_backend_context.h"
-#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
-#include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h"
-
-using namespace tflite;
-using namespace tflite::cpu_backend_gemm;
-
-namespace compute_engine {
-
-namespace tflite {
-
-using compute_engine::core::TBitpacked;
-
-template <typename AccumScalar, typename DstScalar>
-struct BGemmImplUsingRuy {
-  static void Run(const MatrixParams<TBitpacked>& lhs_params,
-                  const TBitpacked* lhs_data,
-                  const MatrixParams<TBitpacked>& rhs_params,
-                  const TBitpacked* rhs_data,
-                  const MatrixParams<DstScalar>& dst_params,
-                  DstScalar* dst_data,
-                  const OutputTransform<DstScalar>& output_transform,
-                  CpuBackendContext* context) {
-    ruy::profiler::ScopeLabel label("BGemmRuy");
-
-    static_assert(std::is_signed<DstScalar>::value,
-                  "Output of BGEMM should be of a signed type.");
-
-    // Get ruy context
-    auto ruy_ctx = get_ctx(context->ruy_context());
-
-    // Set up the matrix layouts and mul_params.
-    ruy::Matrix<TBitpacked> lhs;
-    ruy::Matrix<TBitpacked> rhs;
-    ruy::Matrix<DstScalar> dst;
-    // We allow these matrices to be cached. Note that this doesn't force them
-    // to be cached; it means that the `cache_policy` of the MatrixParams will
-    // be respected.
-    cpu_backend_gemm::detail::MakeRuyMatrix(lhs_params, lhs_data, &lhs,
-                                            /*use_caching=*/true);
-    cpu_backend_gemm::detail::MakeRuyMatrix(rhs_params, rhs_data, &rhs,
-                                            /*use_caching=*/true);
-    cpu_backend_gemm::detail::MakeRuyMatrix(dst_params, dst_data, &dst);
-
-    // We have to make this a `const` matrix because otherwise gcc will try to
-    // use the non-const versions of `matrix.data()`
-    ruy::Mat<TBitpacked> internal_lhs =
-        ruy::ToInternal((const ruy::Matrix<TBitpacked>)lhs);
-    ruy::Mat<TBitpacked> internal_rhs =
-        ruy::ToInternal((const ruy::Matrix<TBitpacked>)rhs);
-    ruy::Mat<DstScalar> internal_dst = ruy::ToInternal(dst);
-
-    BinaryMulParams<AccumScalar, DstScalar> mul_params;
-    mul_params.output_transform = output_transform;
-
-#if RUY_PLATFORM_NEON
-    constexpr bool HasOptimizedNeonKernel =
-        std::is_same<AccumScalar, std::int16_t>::value ||
-        std::is_same<DstScalar, float>::value ||
-        std::is_same<DstScalar, std::int8_t>::value;
-    constexpr auto SelectedPath =
-        HasOptimizedNeonKernel ? ruy::Path::kNeon : ruy::Path::kStandardCpp;
-#else
-    constexpr auto SelectedPath = ruy::Path::kStandardCpp;
-#endif
-
-    ruy::Mat<TBitpacked> transposed_lhs(internal_lhs);
-    Transpose(&transposed_lhs);
-
-    ruy::TrMulParams bgemm_trmul_params;
-    PopulateBgemmTrMulParams<SelectedPath>(transposed_lhs, internal_rhs,
-                                           internal_dst, mul_params,
-                                           &bgemm_trmul_params);
-
-    HandlePrepackedCaching(&bgemm_trmul_params, ruy_ctx);
-    ruy::TrMul(&bgemm_trmul_params, ruy_ctx);
-  }
-};
-
-}  // namespace tflite
-}  // namespace compute_engine
-
-#endif  // COMPUTE_ENGINE_CORE_BGEMM_IMPL_RUY_H_
diff --git a/larq_compute_engine/core/bgemm_kernels_arm.h b/larq_compute_engine/core/bgemm_kernels_arm.h
index 76a79d733..f9809533f 100644
--- a/larq_compute_engine/core/bgemm_kernels_arm.h
+++ b/larq_compute_engine/core/bgemm_kernels_arm.h
@@ -25,12 +25,12 @@ using compute_engine::core::TBitpacked;
 
 // Optimised Arm32 kernel. Supports float or int8 output.
 template <typename DstScalar>
-struct BgemmKernel<ruy::Path::kNeon, DstScalar,
-                   BinaryMulParams<std::int32_t, DstScalar>> {
+struct BGemmKernel<ruy::Path::kNeon, DstScalar,
+                   BinaryMulParams<std::int32_t, DstScalar>> {
   Tuning tuning = Tuning::kAuto;
   using LhsLayout = FixedKernelLayout<...>;
   using RhsLayout = FixedKernelLayout<...>;
-  explicit BgemmKernel(Tuning tuning_) : tuning(tuning_) {}
+  explicit BGemmKernel(Tuning tuning_) : tuning(tuning_) {}
   void Run(const ruy::PMat<TBitpacked>& lhs, const ruy::PMat<TBitpacked>& rhs,
            const BinaryMulParams<std::int32_t, DstScalar>& mul_params,
            int start_row, int start_col, int end_row, int end_col,
@@ -54,12 +54,12 @@ struct BgemmKernel<ruy::Path::kNeon, DstScalar,
 
 // Optimised Arm64 kernel. Supports float or int8 output.
 template <typename DstScalar>
-struct BgemmKernel<ruy::Path::kNeon, DstScalar,
-                   BinaryMulParams<std::int32_t, DstScalar>> {
+struct BGemmKernel<ruy::Path::kNeon, DstScalar,
+                   BinaryMulParams<std::int32_t, DstScalar>> {
   Tuning tuning = Tuning::kAuto;
   using LhsLayout = FixedKernelLayout<...>;
   using RhsLayout = FixedKernelLayout<...>;
-  explicit BgemmKernel(Tuning tuning_) : tuning(tuning_) {}
+  explicit BGemmKernel(Tuning tuning_) : tuning(tuning_) {}
   void Run(const ruy::PMat<TBitpacked>& lhs, const ruy::PMat<TBitpacked>& rhs,
            const BinaryMulParams<std::int32_t, DstScalar>& mul_params,
            int start_row, int start_col, int end_row, int end_col,
@@ -78,12 +78,12 @@ struct BgemmKernel<ruy::Path::kNeon, DstScalar,
 
 // Optimised Arm64 kernel with 16-bit accumulators.
 template <typename DstScalar>
-struct BgemmKernel<ruy::Path::kNeon, DstScalar,
-                   BinaryMulParams<std::int16_t, DstScalar>> {
+struct BGemmKernel<ruy::Path::kNeon, DstScalar,
+                   BinaryMulParams<std::int16_t, DstScalar>> {
   Tuning tuning = Tuning::kAuto;
   using LhsLayout = FixedKernelLayout<...>;
   using RhsLayout = FixedKernelLayout<...>;
-  explicit BgemmKernel(Tuning tuning_) : tuning(tuning_) {}
+  explicit BGemmKernel(Tuning tuning_) : tuning(tuning_) {}
   void Run(const ruy::PMat<TBitpacked>& lhs, const ruy::PMat<TBitpacked>& rhs,
            const BinaryMulParams<std::int16_t, DstScalar>& mul_params,
            int start_row, int start_col, int end_row, int end_col,
diff --git a/larq_compute_engine/core/bgemm_kernels_ruy.h b/larq_compute_engine/core/bgemm_kernels_ruy.h
index e84905366..87815ed53 100644
--- a/larq_compute_engine/core/bgemm_kernels_ruy.h
+++ b/larq_compute_engine/core/bgemm_kernels_ruy.h
@@ -1,7 +1,6 @@
 #ifndef COMPUTE_ENGINE_CORE_BGEMM_KERNELS_RUY_H_
 #define COMPUTE_ENGINE_CORE_BGEMM_KERNELS_RUY_H_
 
-#include "larq_compute_engine/core/bgemm_functor.h"
 #include "larq_compute_engine/core/bitpack_utils.h"
 #include "ruy/platform.h"
 #include "ruy/profiler/instrumentation.h"
@@ -16,7 +15,7 @@ using ce::core::bitpacking_bitwidth;
 using ce::core::TBitpacked;
 
 template <ruy::Path ThePath, typename DstScalar, typename Spec>
-struct BgemmKernel {};
+struct BGemmKernel {};
 
 // TODO: this is hacky
 #if RUY_PLATFORM_NEON
@@ -24,11 +23,11 @@
 #endif
 
 template <typename DstScalar, typename Spec>
-struct BgemmKernel<ruy::Path::kStandardCpp, DstScalar, Spec> {
+struct BGemmKernel<ruy::Path::kStandardCpp, DstScalar, Spec> {
   using AccumScalar = typename Spec::AccumScalar;
   using LhsLayout = typename Spec::StandardCppKernelLhsLayout;
   using RhsLayout = typename Spec::StandardCppKernelRhsLayout;
-  explicit BgemmKernel(ruy::Tuning) {}
+  explicit BGemmKernel(ruy::Tuning) {}
   void Run(const ruy::PMat<TBitpacked>& lhs, const ruy::PMat<TBitpacked>& rhs,
            const Spec& spec, int start_row, int start_col, int end_row,
            int end_col, ruy::Mat<DstScalar>* dst) const {
@@ -67,11 +66,11 @@ struct BgemmKernel<ruy::Path::kStandardCpp, DstScalar, Spec> {
 
 // A template specialisation for writing bitpacked output.
 template <typename Spec>
-struct BgemmKernel<ruy::Path::kStandardCpp, TBitpacked, Spec> {
+struct BGemmKernel<ruy::Path::kStandardCpp, TBitpacked, Spec> {
   using AccumScalar = typename Spec::AccumScalar;
   using LhsLayout = typename Spec::StandardCppKernelLhsLayout;
   using RhsLayout = typename Spec::StandardCppKernelRhsLayout;
-  explicit BgemmKernel(ruy::Tuning) {}
+  explicit BGemmKernel(ruy::Tuning) {}
   void Run(const ruy::PMat<TBitpacked>& lhs, const ruy::PMat<TBitpacked>& rhs,
            const Spec& spec, int start_row, int start_col, int end_row,
            int end_col, ruy::Mat<TBitpacked>* dst) const {
@@ -137,11 +136,11 @@ struct BgemmKernel<ruy::Path::kStandardCpp, TBitpacked, Spec> {
 
 template <ruy::Path ThePath, typename DstScalar, typename Spec>
-void RunBgemmKernelTyped(ruy::Tuning tuning, const ruy::PMat<TBitpacked>& lhs,
+void RunBGemmKernelTyped(ruy::Tuning tuning, const ruy::PMat<TBitpacked>& lhs,
                          const ruy::PMat<TBitpacked>& rhs, const Spec& spec,
                          int start_row, int start_col, int end_row,
                          int end_col, ruy::Mat<DstScalar>* dst) {
-  using BKernel = BgemmKernel<ThePath, DstScalar, Spec>;
+  using BKernel = BGemmKernel<ThePath, DstScalar, Spec>;
   BKernel kernel(tuning);
   using LhsLayout = typename BKernel::LhsLayout;
   using RhsLayout = typename BKernel::RhsLayout;
@@ -173,11 +172,11 @@ void RunBgemmKernelTyped(ruy::Tuning tuning, const ruy::PMat<TBitpacked>& lhs,
 
 template <ruy::Path ThePath, typename DstScalar, typename Spec>
-void RunBgemmKernel(ruy::Tuning tuning, const ruy::SidePair<ruy::PEMat>& src,
+void RunBGemmKernel(ruy::Tuning tuning, const ruy::SidePair<ruy::PEMat>& src,
                     void* spec, const ruy::SidePair<int>& start,
                     const ruy::SidePair<int>& end, ruy::EMat* dst) {
   ruy::Mat<DstScalar> mdst = ruy::UneraseType<DstScalar>(*dst);
-  RunBgemmKernelTyped<ThePath, DstScalar, Spec>(
+  RunBGemmKernelTyped<ThePath, DstScalar, Spec>(
       tuning, ruy::UneraseType<TBitpacked>(src[ruy::Side::kLhs]),
       ruy::UneraseType<TBitpacked>(src[ruy::Side::kRhs]),
       *static_cast<const Spec*>(spec), start[ruy::Side::kLhs],
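Reviewer note on the kernel rename: the `kStandardCpp` specialisations compute exactly the arithmetic the deleted `ReferenceBGemmFunctor` spelled out, i.e. accumulate `xor_popcount` along the bitpacked depth and map the count back to a +1/-1 inner product. A self-contained sketch of that arithmetic follows, with hypothetical plain row-major/col-major operands rather than ruy's packed-matrix types:

```cpp
#include <cstdint>

#include "larq_compute_engine/core/types.h"

using compute_engine::core::bitpacking_bitwidth;
using compute_engine::core::TBitpacked;
using compute_engine::core::xor_popcount;

// C = A * B where A is m x k words (row-major), B is k x n words
// (col-major), and each word packs `bitpacking_bitwidth` +1/-1 values.
void NaiveBGemm(int m, int n, int k, const TBitpacked* a, const TBitpacked* b,
                std::int32_t* c) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      std::int32_t accum = 0;
      for (int l = 0; l < k; ++l) {
        accum += xor_popcount(a[i * k + l], b[j * k + l]);
      }
      // Inner product of the unpacked +1/-1 vectors (cf. the deleted
      // compute_binary_inner_prod): bitwidth * depth - 2 * popcount.
      c[i * n + j] =
          static_cast<std::int32_t>(bitpacking_bitwidth) * k - 2 * accum;
    }
  }
}
```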
diff --git a/larq_compute_engine/core/bgemm_trmul_params.h b/larq_compute_engine/core/bgemm_trmul_params.h
index 69f79fdbd..4f06c10a3 100644
--- a/larq_compute_engine/core/bgemm_trmul_params.h
+++ b/larq_compute_engine/core/bgemm_trmul_params.h
@@ -12,7 +12,7 @@ namespace compute_engine {
 namespace tflite {
 
 template <ruy::Path ThePath, typename DstScalar, typename MulParamsType>
-void PopulateBgemmTrMulParams(const Mat<TBitpacked>& lhs,
+void PopulateBGemmTrMulParams(const Mat<TBitpacked>& lhs,
                               const Mat<TBitpacked>& rhs, Mat<DstScalar>& dst,
                               const MulParamsType& mul_params,
                               ruy::TrMulParams* params) {
@@ -23,12 +23,12 @@ void PopulateBgemmTrMulParams(const Mat<TBitpacked>& lhs,
 
   // Optimised code paths only support all matrices being column-major
   if (!ruy::IsColMajorTrMul(*params) && ThePath != ruy::Path::kStandardCpp) {
-    PopulateBgemmTrMulParams<ruy::Path::kStandardCpp>(lhs, rhs, dst, mul_params,
+    PopulateBGemmTrMulParams<ruy::Path::kStandardCpp>(lhs, rhs, dst, mul_params,
                                                       params);
     return;
   };
 
-  using Kernel = BgemmKernel<ThePath, DstScalar, MulParamsType>;
+  using Kernel = BGemmKernel<ThePath, DstScalar, MulParamsType>;
   using LhsKernelLayout = typename Kernel::LhsLayout;
   using RhsKernelLayout = typename Kernel::RhsLayout;
 
@@ -42,7 +42,7 @@ void PopulateBgemmTrMulParams(const Mat<TBitpacked>& lhs,
       &compute_engine::tflite::RunPack<ThePath, LhsKernelLayout>;
   params->run_pack[Side::kRhs] =
       &compute_engine::tflite::RunPack<ThePath, RhsKernelLayout>;
-  params->run_kernel = &RunBgemmKernel<ThePath, DstScalar, MulParamsType>;
+  params->run_kernel = &RunBGemmKernel<ThePath, DstScalar, MulParamsType>;
 }
 
 }  // namespace tflite
diff --git a/larq_compute_engine/core/padding_functor.h b/larq_compute_engine/core/padding_functor.h
index f165e211f..efc1cca58 100644
--- a/larq_compute_engine/core/padding_functor.h
+++ b/larq_compute_engine/core/padding_functor.h
@@ -121,7 +121,7 @@ class PaddingFunctor {
 
           // filter = +1.0 --> bit = 0 ; correction += +1.0
           // filter = -1.0 --> bit = 1 ; correction += -1.0
-          popcount += __builtin_popcount(filter_data[filter_idx]);
+          popcount += xor_popcount(filter_data[filter_idx], 0);
         }
         float cur_correction = input_channels - 2 * popcount;
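Reviewer note on padding_functor.h: since `x ^ 0 == x`, `xor_popcount(x, 0)` equals `popcount(x)`, so this substitution is behaviour-preserving while removing the direct use of the GCC-specific `__builtin_popcount` here. A two-line sanity check, assuming the patched types.h:

```cpp
#include <cassert>

#include "larq_compute_engine/core/types.h"

int main() {
  const compute_engine::core::TBitpacked x = 0b01101001;  // four bits set
  // XOR with zero is the identity, so this counts the set bits of x itself.
  assert(compute_engine::core::xor_popcount(x, 0) == 4);
  return 0;
}
```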
diff --git a/larq_compute_engine/core/tests/BUILD b/larq_compute_engine/core/tests/BUILD
index ee36b6e2a..6e9033062 100644
--- a/larq_compute_engine/core/tests/BUILD
+++ b/larq_compute_engine/core/tests/BUILD
@@ -2,16 +2,6 @@ licenses(["notice"])  # Apache 2.0
 
 package(default_visibility = ["//visibility:public"])
 
-cc_test(
-    name = "bgemm_tests",
-    size = "small",
-    srcs = ["bgemm_tests.cc"],
-    deps = [
-        "//larq_compute_engine/core:bgemm_functor",
-        "@com_google_googletest//:gtest_main",
-    ],
-)
-
 cc_test(
     name = "bitpack_tests",
     size = "small",
@@ -38,7 +28,6 @@ cc_test(
 test_suite(
     name = "cc_tests",
     tests = [
-        "bgemm_tests",
         "bitpack_tests",
     ],
 )
diff --git a/larq_compute_engine/core/tests/bgemm_tests.cc b/larq_compute_engine/core/tests/bgemm_tests.cc
deleted file mode 100644
index 927aac63e..000000000
--- a/larq_compute_engine/core/tests/bgemm_tests.cc
+++ /dev/null
@@ -1,74 +0,0 @@
-#include <gmock/gmock.h>
-
-#include <array>
-
-#include "larq_compute_engine/core/bgemm_functor.h"
-
-namespace compute_engine {
-namespace testing {
-
-namespace ce = compute_engine;
-using ce::core::bitpacking_bitwidth;
-using ce::core::Layout;
-using ce::core::TBitpacked;
-
-TEST(BGemmTests, BinaryInnerProd) {
-  const auto a = static_cast<TBitpacked>(0b01101110000111111101011001101000);
-  const auto b = static_cast<TBitpacked>(0b01100110000110111001011011101001);
-  // a and b are off by five bits so POP_CNT(a XOR b) = 5
-  const auto expected = static_cast<std::int32_t>(bitpacking_bitwidth - 2 * 5);
-  auto c = ce::core::compute_binary_inner_prod(a, b);
-  EXPECT_EQ(c, expected);
-}
-
-template <int m, int n, int k, typename TBgemmFunctor>
-void test_bgemm() {
-  const int lda = k;
-  const int ldb = n;
-  const int ldc = n;
-
-  const int a_size = m * k;
-  const int b_size = k * n;
-  const int c_size = m * n;
-
-  std::array<TBitpacked, a_size> a;
-  a.fill(1);
-
-  std::array<TBitpacked, b_size> b;
-  b.fill(1);
-
-  // each row of matrix "a" and column of "b" contains k same values so
-  // a[i, k] XOR b[k, j] = 0 and therefore
-  // c[i, j] = k * (bitpacking_bitwidth - 2 * POP_CNT(0))
-  std::int32_t expected_value = k * bitpacking_bitwidth;
-  std::array<std::int32_t, c_size> expected;
-  expected.fill(expected_value);
-
-  std::array<std::int32_t, c_size> c;
-  TBgemmFunctor bgemm_functor;
-  bgemm_functor(m, n, k, a.data(), lda, b.data(), ldb, c.data(), ldc);
-  EXPECT_THAT(c, ::testing::ElementsAreArray(expected));
-}
-
-TEST(BGemmTests, BGemmTestRowMajor) {
-  using BGemmFunctor =
-      ce::core::ReferenceBGemmFunctor<std::int32_t, std::int32_t,
-                                      Layout::RowMajor, Layout::RowMajor,
-                                      Layout::RowMajor>;
-  const int m = 20;
-  const int k = 200;
-  const int n = 30;
-  test_bgemm<m, n, k, BGemmFunctor>();
-}
-
-TEST(BGemmTests, BGemmTestColMajor) {
-  using BGemmFunctor =
-      ce::core::ReferenceBGemmFunctor<std::int32_t, std::int32_t,
-                                      Layout::RowMajor, Layout::ColMajor,
-                                      Layout::RowMajor>;
-  const int m = 20;
-  const int k = 200;
-  const int n = 30;
-  test_bgemm<m, n, k, BGemmFunctor>();
-}
-
-}  // end namespace testing
-}  // end namespace compute_engine
diff --git a/larq_compute_engine/core/types.h b/larq_compute_engine/core/types.h
index e03c7ef79..87c93274e 100644
--- a/larq_compute_engine/core/types.h
+++ b/larq_compute_engine/core/types.h
@@ -1,6 +1,7 @@
 #ifndef COMPUTE_ENGINE_CORE_TYPES_H_
 #define COMPUTE_ENGINE_CORE_TYPES_H_
 
+#include <bitset>
 #include <cstdint>
 #include <limits>
 #include <type_traits>
@@ -18,6 +19,10 @@ using TBitpacked = std::int32_t;
 
 constexpr std::size_t bitpacking_bitwidth =
     std::numeric_limits<typename std::make_unsigned<TBitpacked>::type>::digits;
 
+inline int xor_popcount(const TBitpacked& a, const TBitpacked& b) {
+  return std::bitset<bitpacking_bitwidth>(a ^ b).count();
+}
+
 }  // namespace core
 }  // namespace compute_engine
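Reviewer note on the deleted bgemm_tests.cc: the `BinaryInnerProd` check can be re-expressed against the portable `std::bitset`-based `xor_popcount` that types.h now provides. A minimal sketch (hypothetical test helper, reusing the bit patterns from the deleted file):

```cpp
#include <cstdint>

#include "larq_compute_engine/core/types.h"

namespace ce = compute_engine;

bool BinaryInnerProdStillMatches() {
  using ce::core::bitpacking_bitwidth;
  using ce::core::TBitpacked;
  const auto a = static_cast<TBitpacked>(0b01101110000111111101011001101000);
  const auto b = static_cast<TBitpacked>(0b01100110000110111001011011101001);
  // a and b differ in five bits, so the +1/-1 inner product is
  // bitwidth - 2 * popcount(a ^ b) = bitwidth - 10.
  const std::int32_t inner_prod = static_cast<std::int32_t>(
      bitpacking_bitwidth - 2 * ce::core::xor_popcount(a, b));
  return inner_prod == static_cast<std::int32_t>(bitpacking_bitwidth) - 10;
}
```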