Skip to content

Commit

Permalink
Remove redundant code paths: reference BGemm, BGemm functor. (#510)
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamHillier committed Sep 18, 2020
1 parent 7254930 commit 35e0670
Show file tree
Hide file tree
Showing 13 changed files with 88 additions and 403 deletions.
41 changes: 2 additions & 39 deletions larq_compute_engine/core/BUILD
Expand Up @@ -9,16 +9,6 @@ cc_library(
],
)

cc_library(
name = "bgemm_functor",
hdrs = [
"bgemm_functor.h",
],
deps = [
":types",
],
)

cc_library(
name = "bitpack",
hdrs = ["bitpack.h"] + select({
Expand Down Expand Up @@ -51,19 +41,6 @@ cc_library(
hdrs = ["padding_functor.h"],
)

cc_library(
name = "bgemm_ref",
hdrs = [
"bgemm_impl_ref.h",
],
deps = [
"//larq_compute_engine/core:bgemm_functor",
"@org_tensorflow//tensorflow/lite/kernels:cpu_backend_context",
"@org_tensorflow//tensorflow/lite/kernels:cpu_backend_gemm",
"@ruy//ruy/profiler:instrumentation",
],
)

cc_library(
name = "bconv2d_output_transform",
hdrs = [
Expand Down Expand Up @@ -108,15 +85,14 @@ cc_library(
deps = [
":bgemm_kernels_arm",
":bitpack",
"//larq_compute_engine/core:bgemm_functor",
"@ruy//ruy/profiler:instrumentation",
],
)

cc_library(
name = "bgemm_ruy",
name = "bgemm_impl",
hdrs = [
"bgemm_impl_ruy.h",
"bgemm_impl.h",
"bgemm_trmul_params.h",
"ruy_pack.h",
],
Expand All @@ -128,26 +104,13 @@ cc_library(
],
)

cc_library(
name = "bgemm_impl",
hdrs = [
"bgemm_impl.h",
],
deps = [
":bgemm_ref",
":bgemm_ruy",
"@ruy//ruy/profiler:instrumentation",
],
)

cc_library(
name = "bconv2d_impl_ref",
hdrs = [
"bconv2d_impl_ref.h",
],
deps = [
":bconv2d_output_transform",
":bgemm_functor",
"@org_tensorflow//tensorflow/lite/kernels/internal:types",
],
)
Expand Down
1 change: 0 additions & 1 deletion larq_compute_engine/core/bconv2d_impl_ref.h
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
#define COMPUTE_ENGINE_CORE_BCONV2D_IMPL_REF_H_

#include "larq_compute_engine/core/bconv2d_output_transform.h"
#include "larq_compute_engine/core/bgemm_functor.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"

Expand Down
71 changes: 0 additions & 71 deletions larq_compute_engine/core/bgemm_functor.h

This file was deleted.

90 changes: 61 additions & 29 deletions larq_compute_engine/core/bgemm_impl.h
Expand Up @@ -2,53 +2,85 @@
#define COMPUTE_ENGINE_CORE_BGEMM_IMPL_H_

#include "bgemm_kernels_common.h"
#include "bgemm_trmul_params.h"
#include "ruy/context.h"
#include "ruy/context_get_ctx.h"
#include "ruy/matrix.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"

// TODO: currently only ref. impl. is supported
#ifndef TFLITE_WITH_RUY
#include "bgemm_impl_ref.h"
#else
#include "bgemm_impl_ruy.h"
#endif
#include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h"

using namespace tflite;
using namespace tflite::cpu_backend_gemm;

namespace compute_engine {
namespace tflite {

#ifndef TFLITE_WITH_RUY
template <typename AccumScalar, typename DstScalar>
struct BGemmImpl : BGemmImplRef<LhsScalar, RhsScalar, AccumScalar, DstScalar> {
};
#else
template <typename AccumScalar, typename DstScalar>
struct BGemmImpl : BGemmImplUsingRuy<AccumScalar, DstScalar> {};
#endif
using compute_engine::core::TBitpacked;

template <typename AccumScalar, typename DstScalar>
void BGemm(const MatrixParams<TBitpacked>& lhs_params,
const TBitpacked* lhs_data,
const MatrixParams<TBitpacked>& rhs_params,
const TBitpacked* rhs_data,
const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
const OutputTransform<DstScalar>& params,
const OutputTransform<DstScalar>& output_transform,
CpuBackendContext* context) {
ruy::profiler::ScopeLabel label("BGemm");
// TODO: special fast bgemm impl. for matrix-vector multiplication
// if (dst_params.cols == 1) {
// // GEMV case: try a custom fast GEMV path.
// if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
// dst_params, dst_data, params, context)) {
// return;
// }
// }
ruy::profiler::ScopeLabel label2("BGemm/GeneralBGEMM");
BGemmImpl<AccumScalar, DstScalar>::Run(lhs_params, lhs_data, rhs_params,
rhs_data, dst_params, dst_data, params,
context);
ruy::profiler::ScopeLabel label("BGemm (Ruy)");

static_assert(std::is_signed<DstScalar>::value,
"Output of BGEMM should be of a signed type.");

// Get ruy context
auto ruy_ctx = get_ctx(context->ruy_context());

// Set up the matrix layouts and mul_params.
ruy::Matrix<TBitpacked> lhs;
ruy::Matrix<TBitpacked> rhs;
ruy::Matrix<DstScalar> dst;
// We allow these matrices to be cached. Note that this doesn't force them
// to be cached; it means that the `cache_policy` of the MatrixParams will
// be respected.
cpu_backend_gemm::detail::MakeRuyMatrix(lhs_params, lhs_data, &lhs,
/*use_caching=*/true);
cpu_backend_gemm::detail::MakeRuyMatrix(rhs_params, rhs_data, &rhs,
/*use_caching=*/true);
cpu_backend_gemm::detail::MakeRuyMatrix(dst_params, dst_data, &dst);

// We have to make this a `const` matrix because otherwise gcc will try to
// use the non-const versions of `matrix.data()`
ruy::Mat<TBitpacked> internal_lhs =
ruy::ToInternal((const ruy::Matrix<TBitpacked>)lhs);
ruy::Mat<TBitpacked> internal_rhs =
ruy::ToInternal((const ruy::Matrix<TBitpacked>)rhs);
ruy::Mat<DstScalar> internal_dst = ruy::ToInternal(dst);

BinaryMulParams<AccumScalar, DstScalar> mul_params;
mul_params.output_transform = output_transform;

#if RUY_PLATFORM_NEON
constexpr bool HasOptimizedNeonKernel =
std::is_same<AccumScalar, std::int16_t>::value ||
std::is_same<DstScalar, float>::value ||
std::is_same<DstScalar, std::int8_t>::value;
constexpr auto SelectedPath =
HasOptimizedNeonKernel ? ruy::Path::kNeon : ruy::Path::kStandardCpp;
#else
constexpr auto SelectedPath = ruy::Path::kStandardCpp;
#endif

ruy::Mat<TBitpacked> transposed_lhs(internal_lhs);
Transpose(&transposed_lhs);

ruy::TrMulParams bgemm_trmul_params;
PopulateBGemmTrMulParams<SelectedPath>(transposed_lhs, internal_rhs,
internal_dst, mul_params,
&bgemm_trmul_params);

HandlePrepackedCaching(&bgemm_trmul_params, ruy_ctx);
ruy::TrMul(&bgemm_trmul_params, ruy_ctx);
}

} // namespace tflite
Expand Down
64 changes: 0 additions & 64 deletions larq_compute_engine/core/bgemm_impl_ref.h

This file was deleted.

0 comments on commit 35e0670

Please sign in to comment.