Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove redundant code paths: reference BGemm, BGemm functor. #510

Merged
Merged 1 commit on Sep 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
41 changes: 2 additions & 39 deletions larq_compute_engine/core/BUILD
Expand Up @@ -9,16 +9,6 @@ cc_library(
],
)

# Header-only target (no srcs listed) that exposes bgemm_functor.h.
# Its single dependency is the :types target from this package.
cc_library(
name = "bgemm_functor",
hdrs = [
"bgemm_functor.h",
],
deps = [
":types",
],
)

cc_library(
name = "bitpack",
hdrs = ["bitpack.h"] + select({
Expand Down Expand Up @@ -51,19 +41,6 @@ cc_library(
hdrs = ["padding_functor.h"],
)

# Header-only target (no srcs listed) that exposes bgemm_impl_ref.h.
# Depends on the bgemm_functor target of this package, the TFLite
# cpu_backend_context and cpu_backend_gemm kernel targets, and ruy's
# profiler instrumentation target.
cc_library(
name = "bgemm_ref",
hdrs = [
"bgemm_impl_ref.h",
],
deps = [
"//larq_compute_engine/core:bgemm_functor",
"@org_tensorflow//tensorflow/lite/kernels:cpu_backend_context",
"@org_tensorflow//tensorflow/lite/kernels:cpu_backend_gemm",
"@ruy//ruy/profiler:instrumentation",
],
)
cc_library(
name = "bconv2d_output_transform",
hdrs = [
Expand Down Expand Up @@ -108,15 +85,14 @@ cc_library(
deps = [
":bgemm_kernels_arm",
":bitpack",
"//larq_compute_engine/core:bgemm_functor",
"@ruy//ruy/profiler:instrumentation",
],
)

cc_library(
name = "bgemm_ruy",
name = "bgemm_impl",
hdrs = [
"bgemm_impl_ruy.h",
"bgemm_impl.h",
"bgemm_trmul_params.h",
"ruy_pack.h",
],
Expand All @@ -128,26 +104,13 @@ cc_library(
],
)

# Header-only target (no srcs listed) that exposes bgemm_impl.h.
# It depends on both the :bgemm_ref and :bgemm_ruy targets, plus ruy's
# profiler instrumentation target.
cc_library(
name = "bgemm_impl",
hdrs = [
"bgemm_impl.h",
],
deps = [
":bgemm_ref",
":bgemm_ruy",
"@ruy//ruy/profiler:instrumentation",
],
)
# Header-only target (no srcs listed) that exposes bconv2d_impl_ref.h.
# Depends on :bconv2d_output_transform and :bgemm_functor from this package
# and on TFLite's kernels/internal:types target.
cc_library(
name = "bconv2d_impl_ref",
hdrs = [
"bconv2d_impl_ref.h",
],
deps = [
":bconv2d_output_transform",
":bgemm_functor",
"@org_tensorflow//tensorflow/lite/kernels/internal:types",
],
)
Expand Down
1 change: 0 additions & 1 deletion larq_compute_engine/core/bconv2d_impl_ref.h
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
#define COMPUTE_ENGINE_CORE_BCONV2D_IMPL_REF_H_

#include "larq_compute_engine/core/bconv2d_output_transform.h"
#include "larq_compute_engine/core/bgemm_functor.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"

Expand Down
71 changes: 0 additions & 71 deletions larq_compute_engine/core/bgemm_functor.h

This file was deleted.

90 changes: 61 additions & 29 deletions larq_compute_engine/core/bgemm_impl.h
Expand Up @@ -2,53 +2,85 @@
#define COMPUTE_ENGINE_CORE_BGEMM_IMPL_H_

#include "bgemm_kernels_common.h"
#include "bgemm_trmul_params.h"
#include "ruy/context.h"
#include "ruy/context_get_ctx.h"
#include "ruy/matrix.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"

// TODO: currently only ref. impl. is supported
#ifndef TFLITE_WITH_RUY
#include "bgemm_impl_ref.h"
#else
#include "bgemm_impl_ruy.h"
#endif
#include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h"

using namespace tflite;
using namespace tflite::cpu_backend_gemm;

namespace compute_engine {
namespace tflite {

// Compile-time selection of the BGemm backend. BGemmImpl adds no members of
// its own; it only inherits from the implementation chosen by the
// TFLITE_WITH_RUY macro.
#ifndef TFLITE_WITH_RUY
// TFLITE_WITH_RUY not defined: inherit from BGemmImplRef.
// NOTE(review): LhsScalar and RhsScalar are not template parameters of
// BGemmImpl; unless they are declared elsewhere this branch cannot compile.
// Confirm.
template <typename AccumScalar, typename DstScalar>
struct BGemmImpl : BGemmImplRef<LhsScalar, RhsScalar, AccumScalar, DstScalar> {
};
#else
// TFLITE_WITH_RUY defined: inherit from BGemmImplUsingRuy.
template <typename AccumScalar, typename DstScalar>
struct BGemmImpl : BGemmImplUsingRuy<AccumScalar, DstScalar> {};
#endif
using compute_engine::core::TBitpacked;

// BGemm: entry point for a GEMM whose LHS and RHS operands are TBitpacked
// matrices and whose destination holds DstScalar values.
//
// Template parameters:
//   AccumScalar - accumulator type; forwarded to BinaryMulParams / BGemmImpl.
//   DstScalar   - destination element type; must be signed (static_assert).
//
// Parameters:
//   lhs_params, lhs_data - layout description and data pointer of the LHS.
//   rhs_params, rhs_data - layout description and data pointer of the RHS.
//   dst_params, dst_data - layout description and data pointer of the output.
//   params / output_transform - OutputTransform handed to the kernel.
//   context - CpuBackendContext; its ruy_context() is used for the ruy calls.
//
// NOTE(review): this listing declares two OutputTransform parameters
// (`params` and `output_transform`) and two locals named `label`. It appears
// to interleave two revisions of the function: one that forwards to
// BGemmImpl<...>::Run and one that drives ruy directly. Confirm which lines
// are current before relying on this text.
template <typename AccumScalar, typename DstScalar>
void BGemm(const MatrixParams<TBitpacked>& lhs_params,
const TBitpacked* lhs_data,
const MatrixParams<TBitpacked>& rhs_params,
const TBitpacked* rhs_data,
const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
const OutputTransform<DstScalar>& params,
const OutputTransform<DstScalar>& output_transform,
CpuBackendContext* context) {
// Profiler scope label for the whole call.
ruy::profiler::ScopeLabel label("BGemm");
// TODO: special fast bgemm impl. for matrix-vector multiplication
// if (dst_params.cols == 1) {
// // GEMV case: try a custom fast GEMV path.
// if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
// dst_params, dst_data, params, context)) {
// return;
// }
// }
ruy::profiler::ScopeLabel label2("BGemm/GeneralBGEMM");
// Forward every argument unchanged to the selected BGemmImpl backend.
BGemmImpl<AccumScalar, DstScalar>::Run(lhs_params, lhs_data, rhs_params,
rhs_data, dst_params, dst_data, params,
context);
// Profiler scope label for the ruy-based path.
ruy::profiler::ScopeLabel label("BGemm (Ruy)");

static_assert(std::is_signed<DstScalar>::value,
"Output of BGEMM should be of a signed type.");

// Get ruy context
auto ruy_ctx = get_ctx(context->ruy_context());

// Set up the matrix layouts and mul_params.
ruy::Matrix<TBitpacked> lhs;
ruy::Matrix<TBitpacked> rhs;
ruy::Matrix<DstScalar> dst;
// We allow these matrices to be cached. Note that this doesn't force them
// to be cached; it means that the `cache_policy` of the MatrixParams will
// be respected.
cpu_backend_gemm::detail::MakeRuyMatrix(lhs_params, lhs_data, &lhs,
/*use_caching=*/true);
cpu_backend_gemm::detail::MakeRuyMatrix(rhs_params, rhs_data, &rhs,
/*use_caching=*/true);
// The destination is built without the use_caching argument.
cpu_backend_gemm::detail::MakeRuyMatrix(dst_params, dst_data, &dst);

// We have to make this a `const` matrix because otherwise gcc will try to
// use the non-const versions of `matrix.data()`
ruy::Mat<TBitpacked> internal_lhs =
ruy::ToInternal((const ruy::Matrix<TBitpacked>)lhs);
ruy::Mat<TBitpacked> internal_rhs =
ruy::ToInternal((const ruy::Matrix<TBitpacked>)rhs);
ruy::Mat<DstScalar> internal_dst = ruy::ToInternal(dst);

// Carry the caller's output transform into the multiplication parameters.
BinaryMulParams<AccumScalar, DstScalar> mul_params;
mul_params.output_transform = output_transform;

// Compile-time kernel path selection. With RUY_PLATFORM_NEON, kNeon is chosen
// when AccumScalar is std::int16_t or DstScalar is float or std::int8_t;
// every other combination, and every non-NEON build, uses kStandardCpp.
#if RUY_PLATFORM_NEON
constexpr bool HasOptimizedNeonKernel =
std::is_same<AccumScalar, std::int16_t>::value ||
std::is_same<DstScalar, float>::value ||
std::is_same<DstScalar, std::int8_t>::value;
constexpr auto SelectedPath =
HasOptimizedNeonKernel ? ruy::Path::kNeon : ruy::Path::kStandardCpp;
#else
constexpr auto SelectedPath = ruy::Path::kStandardCpp;
#endif

// Transpose is applied to a second Mat constructed from internal_lhs, not to
// internal_lhs itself; the transposed object is what TrMul receives.
ruy::Mat<TBitpacked> transposed_lhs(internal_lhs);
Transpose(&transposed_lhs);

// Fill the TrMulParams for the selected path from the three matrices and
// mul_params.
ruy::TrMulParams bgemm_trmul_params;
PopulateBGemmTrMulParams<SelectedPath>(transposed_lhs, internal_rhs,
internal_dst, mul_params,
&bgemm_trmul_params);

// NOTE(review): presumably applies the cache policy requested through
// use_caching above; confirm against HandlePrepackedCaching's definition.
HandlePrepackedCaching(&bgemm_trmul_params, ruy_ctx);
// Run the multiplication described by bgemm_trmul_params.
ruy::TrMul(&bgemm_trmul_params, ruy_ctx);
}

} // namespace tflite
Expand Down
64 changes: 0 additions & 64 deletions larq_compute_engine/core/bgemm_impl_ref.h

This file was deleted.