Skip to content

Commit

Permalink
fix data transform bug of interpolate op (#44401)
Browse files Browse the repository at this point in the history
  • Loading branch information
zyfncg committed Jul 18, 2022
1 parent b2224e6 commit c6bf881
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 20 deletions.
25 changes: 20 additions & 5 deletions paddle/phi/kernels/cpu/interpolate_grad_kernel.cc
Expand Up @@ -1041,28 +1041,43 @@ PD_REGISTER_KERNEL(bilinear_interp_v2_grad,
ALL_LAYOUT,
phi::BilinearInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(nearest_interp_v2_grad,
CPU,
ALL_LAYOUT,
phi::NearestInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(trilinear_interp_v2_grad,
CPU,
ALL_LAYOUT,
phi::TrilinearInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(linear_interp_v2_grad,
CPU,
ALL_LAYOUT,
phi::LinearInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(bicubic_interp_v2_grad,
CPU,
ALL_LAYOUT,
phi::BicubicInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
25 changes: 20 additions & 5 deletions paddle/phi/kernels/cpu/interpolate_kernel.cc
Expand Up @@ -1193,7 +1193,10 @@ PD_REGISTER_KERNEL(bilinear_interp_v2,
phi::BilinearInterpKernel,
float,
double,
uint8_t) {}
uint8_t) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(nearest_interp_v2,
CPU,
ALL_LAYOUT,
Expand All @@ -1202,24 +1205,36 @@ PD_REGISTER_KERNEL(nearest_interp_v2,
double,
int,
int64_t,
uint8_t) {}
uint8_t) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(trilinear_interp_v2,
CPU,
ALL_LAYOUT,
phi::TrilinearInterpKernel,
float,
double,
uint8_t) {}
uint8_t) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(linear_interp_v2,
CPU,
ALL_LAYOUT,
phi::LinearInterpKernel,
float,
double,
uint8_t) {}
uint8_t) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(bicubic_interp_v2,
CPU,
ALL_LAYOUT,
phi::BicubicInterpKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
25 changes: 20 additions & 5 deletions paddle/phi/kernels/gpu/interpolate_grad_kernel.cu
Expand Up @@ -1574,28 +1574,43 @@ PD_REGISTER_KERNEL(bilinear_interp_v2_grad,
ALL_LAYOUT,
phi::BilinearInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(nearest_interp_v2_grad,
GPU,
ALL_LAYOUT,
phi::NearestInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(trilinear_interp_v2_grad,
GPU,
ALL_LAYOUT,
phi::TrilinearInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(linear_interp_v2_grad,
GPU,
ALL_LAYOUT,
phi::LinearInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(bicubic_interp_v2_grad,
GPU,
ALL_LAYOUT,
phi::BicubicInterpGradKernel,
float,
double) {}
double) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
25 changes: 20 additions & 5 deletions paddle/phi/kernels/gpu/interpolate_kernel.cu
Expand Up @@ -1446,33 +1446,48 @@ PD_REGISTER_KERNEL(bilinear_interp_v2,
phi::BilinearInterpKernel,
float,
double,
int) {}
int) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(nearest_interp_v2,
GPU,
ALL_LAYOUT,
phi::NearestInterpKernel,
float,
double,
int,
int64_t) {}
int64_t) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(trilinear_interp_v2,
GPU,
ALL_LAYOUT,
phi::TrilinearInterpKernel,
float,
double,
int) {}
int) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(linear_interp_v2,
GPU,
ALL_LAYOUT,
phi::LinearInterpKernel,
float,
double,
int) {}
int) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(bicubic_interp_v2,
GPU,
ALL_LAYOUT,
phi::BicubicInterpKernel,
float,
double,
int) {}
int) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}

0 comments on commit c6bf881

Please sign in to comment.