diff --git a/paddle/fluid/operators/bce_loss_op.cu b/paddle/fluid/operators/bce_loss_op.cu index 99153101fc326..8bd2b7fe2d127 100644 --- a/paddle/fluid/operators/bce_loss_op.cu +++ b/paddle/fluid/operators/bce_loss_op.cu @@ -32,6 +32,11 @@ __global__ void GPUBCELossForward(const T* x_data, const T* label_data, T one = static_cast(1.); T neg_100 = static_cast(-100.); + PADDLE_ENFORCE( + (x >= static_cast(0)) && (x <= one), + "Input is expected to be within the interval [0, 1], but recieved %f.", + x); + T term1 = max(real_log(x), neg_100); T term2 = max(real_log(one - x), neg_100); @@ -64,29 +69,13 @@ class BCELossCUDAKernel : public framework::OpKernel { auto* labels = ctx.Input("Label"); auto* out = ctx.Output("Out"); - auto x_data = x->data(); - auto out_data = out->mutable_data(ctx.GetPlace()); + const auto* x_data = x->data(); + auto* out_data = out->mutable_data(ctx.GetPlace()); auto x_numel = x->numel(); - platform::GpuLaunchConfig config = - platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), x_numel); - - Tensor x_cpu; - framework::TensorCopy(*x, platform::CPUPlace(), &x_cpu); - T* x_cpu_data = x_cpu.data(); - - for (int64_t i = 0; i < x_numel; ++i) { - PADDLE_ENFORCE_GE( - x_cpu_data[i], static_cast(0), - platform::errors::InvalidArgument( - "Illegal input, input must be greater than or equal to 0")); - PADDLE_ENFORCE_LE( - x_cpu_data[i], static_cast(1), - platform::errors::InvalidArgument( - "Illegal input, input must be less than or equal to 1")); - } - auto& dev_ctx = ctx.cuda_device_context(); + platform::GpuLaunchConfig config = + platform::GetGpuLaunchConfig1D(dev_ctx, x_numel); GPUBCELossForward<<>>(x_data, labels->data(), @@ -102,9 +91,10 @@ class BCELossGradCUDAKernel : public framework::OpKernel { auto* labels = ctx.Input("Label"); auto* dout = ctx.Input(framework::GradVarName("Out")); auto* dx = ctx.Output(framework::GradVarName("X")); - auto dx_data = dx->mutable_data(ctx.GetPlace()); int x_numel = x->numel(); + auto* dx_data = dx->mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.cuda_device_context(); platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(dev_ctx, x_numel);