Skip to content

Commit

Permalink
Use the input's shapes directly instead of getting them from the op
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 617953833
  • Loading branch information
tensorflower-gardener committed Apr 19, 2024
1 parent b33007d commit 992db01
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
1 change: 1 addition & 0 deletions tensorflow/core/tpu/kernels/BUILD
Expand Up @@ -986,6 +986,7 @@ cc_library(
name = "tpu_embedding_ops",
srcs = ["tpu_embedding_ops.cc"],
deps = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_context",
"//tensorflow/compiler/tf2xla:xla_helpers",
Expand Down
11 changes: 6 additions & 5 deletions tensorflow/core/tpu/kernels/tpu_embedding_ops.cc
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include "absl/cleanup/cleanup.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "xla/client/xla_builder.h"
Expand Down Expand Up @@ -191,11 +192,11 @@ void CompileSendTPUEmbeddingGradients(
std::vector<xla::Shape> gradient_shapes;
auto builder = ctx->builder();
gradient_shapes.reserve(gradients.size());
for (xla::XlaOp op : gradients) {
// Gradient layout information is added by XLA, so we can just create
// default layout information.
xla::Shape gradient_shape = builder->GetShape(op).value();
xla::LayoutUtil::SetToDefaultLayout(&gradient_shape);
for (int i = 0; i < gradients.size(); ++i) {
DataType dtype = ctx->input_type(i);
xla::Shape gradient_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, tf_gradient_shapes[i],
&gradient_shape));
gradient_shapes.push_back(gradient_shape);
}

Expand Down

0 comments on commit 992db01

Please sign in to comment.