Skip to content

Commit

Permalink
[XLA:TPU] Refactor HandleInputStreaming to process tuples using ForEa…
Browse files Browse the repository at this point in the history
…chSubshape

Previously, tuples in parameters were separately processed from non tuple parameters. This cleans up the code to handle both tuple and non tuple parameters with one block.

PiperOrigin-RevId: 632648325
  • Loading branch information
jvstokes authored and tensorflower-gardener committed May 11, 2024
1 parent 02fa87f commit ad631f7
Showing 1 changed file with 12 additions and 32 deletions.
44 changes: 12 additions & 32 deletions third_party/xla/xla/service/host_offloader.cc
Expand Up @@ -692,38 +692,18 @@ Status HostOffloader::HandleInputStreaming(HloComputation* computation) {
LOG(WARNING) << "Token parameters are not supported for streaming.";
continue;
}
if (entry_computation_layout.parameter_shape(i).IsTuple()) {
// Handle tuple parameters, which may contain streamed elements. Nested
// tuples are not supported.
const Shape& tuple_shape = entry_computation_layout.parameter_shape(i);
for (int j = 0; j < tuple_shape.tuple_shapes_size(); ++j) {
const Shape& tuple_element_shape = tuple_shape.tuple_shapes(j);
// TODO(b/335498881): Support nested tuples.
if (tuple_element_shape.IsTuple()) {
LOG(WARNING)
<< "Nested tuple parameters are not supported for streaming.";
continue;
}
TF_RET_CHECK(tuple_element_shape.has_layout());
if (tuple_element_shape.layout().memory_space() ==
kHostMemorySpaceColor) {
VLOG(4) << "Handling streamed element in tuple parameter: "
<< tuple_element_shape.ToString(/*print_layout=*/true);
const HloBuffer& unique_buffer = alias_analysis_->GetUniqueBufferAt(
computation->parameter_instruction(i), {j});
TF_RETURN_IF_ERROR(HandleStreamedBuffer(unique_buffer));
}
}
} else if (entry_computation_layout.parameter_layout(i)
.layout()
.memory_space() == kHostMemorySpaceColor) {
HloInstruction* streamed_input = computation->parameter_instruction(i);
VLOG(4) << "Handling streamed input: " << streamed_input->ToString();
const HloBuffer& unique_buffer =
alias_analysis_->GetUniqueBufferAt(streamed_input);

TF_RETURN_IF_ERROR(HandleStreamedBuffer(unique_buffer));
}
ShapeUtil::ForEachSubshape(
entry_computation_layout.parameter_shape(i),
[&](const Shape& subshape, const ShapeIndex& index) {
if (subshape.has_layout() &&
subshape.layout().memory_space() == kHostMemorySpaceColor) {
VLOG(4) << "Handling streamed element in input with shape: "
<< subshape.ToString(true);
const HloBuffer& unique_buffer = alias_analysis_->GetUniqueBufferAt(
computation->parameter_instruction(i), {index});
TF_CHECK_OK(HandleStreamedBuffer(unique_buffer));
}
});
}
return OkStatus();
}
Expand Down

0 comments on commit ad631f7

Please sign in to comment.