Automated Code Change
PiperOrigin-RevId: 609293597
tensorflower-gardener committed Feb 22, 2024
1 parent a3751b8 commit 9e40776
Showing 16 changed files with 39 additions and 32 deletions.
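
Every change below is the same mechanical substitution: the unqualified `StatusOr` alias is spelled out as `absl::StatusOr`. For readers unfamiliar with the type, here is a minimal standalone sketch of the pattern these signatures rely on; it is illustrative only and not taken from this commit (`ParsePort` is a hypothetical function).

#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/numbers.h"

// Hypothetical parser used only to illustrate the absl::StatusOr pattern.
absl::StatusOr<int> ParsePort(const std::string& s) {
  int port = 0;
  if (!absl::SimpleAtoi(s, &port) || port <= 0) {
    return absl::InvalidArgumentError("expected a positive port number");
  }
  return port;  // Implicitly wraps the value in an OK StatusOr.
}

int main() {
  absl::StatusOr<int> port = ParsePort("8080");
  if (port.ok()) {
    // `*port` or `port.value()` yields the contained int.
  } else {
    // `port.status()` carries the error code and message.
  }
  return 0;
}
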
6 changes: 3 additions & 3 deletions tensorflow/core/data/service/common.cc
@@ -64,7 +64,7 @@ Status ValidateProcessingMode(const ProcessingModeDef& processing_mode) {
   return absl::OkStatus();
 }
 
-StatusOr<AutoShardPolicy> ToAutoShardPolicy(
+absl::StatusOr<AutoShardPolicy> ToAutoShardPolicy(
     const ProcessingModeDef::ShardingPolicy sharding_policy) {
   switch (sharding_policy) {
     case ProcessingModeDef::FILE:
@@ -87,7 +87,7 @@ StatusOr<AutoShardPolicy> ToAutoShardPolicy(
   }
 }
 
-StatusOr<TargetWorkers> ParseTargetWorkers(absl::string_view s) {
+absl::StatusOr<TargetWorkers> ParseTargetWorkers(absl::string_view s) {
   std::string str_upper = absl::AsciiStrToUpper(s);
   if (str_upper.empty() || str_upper == kAuto) {
     return TARGET_WORKERS_AUTO;
@@ -115,7 +115,7 @@ std::string TargetWorkersToString(TargetWorkers target_workers) {
   }
 }
 
-StatusOr<DeploymentMode> ParseDeploymentMode(absl::string_view s) {
+absl::StatusOr<DeploymentMode> ParseDeploymentMode(absl::string_view s) {
   std::string str_upper = absl::AsciiStrToUpper(s);
   if (str_upper == kColocated) {
     return DEPLOYMENT_MODE_COLOCATED;
6 changes: 3 additions & 3 deletions tensorflow/core/data/service/common.h
@@ -71,19 +71,19 @@ Status ValidateProcessingMode(const ProcessingModeDef& processing_mode);
 
 // Converts tf.data service `sharding_policy` to `AutoShardPolicy`. Returns an
 // internal error if `sharding_policy` is not supported.
-StatusOr<AutoShardPolicy> ToAutoShardPolicy(
+absl::StatusOr<AutoShardPolicy> ToAutoShardPolicy(
     ProcessingModeDef::ShardingPolicy sharding_policy);
 
 // Parses a string representing a `TargetWorkers` (case-insensitive).
 // Returns InvalidArgument if the string is not recognized.
-StatusOr<TargetWorkers> ParseTargetWorkers(absl::string_view s);
+absl::StatusOr<TargetWorkers> ParseTargetWorkers(absl::string_view s);
 
 // Converts a `TargetWorkers` enum to string.
 std::string TargetWorkersToString(TargetWorkers target_workers);
 
 // Parses a string representing a `DeploymentMode` (case-insensitive).
 // Returns InvalidArgument if the string is not recognized.
-StatusOr<DeploymentMode> ParseDeploymentMode(absl::string_view s);
+absl::StatusOr<DeploymentMode> ParseDeploymentMode(absl::string_view s);
 
 // Returns true if `status` is a retriable error that indicates preemption.
 bool IsPreemptedError(const Status& status);
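
As a usage note for the declarations above: callers typically unwrap these results with TF_ASSIGN_OR_RETURN, the TensorFlow macro used elsewhere in this commit. A hedged sketch, assuming the includes of common.h (`ConfigureFromFlags` and its flag arguments are hypothetical):

absl::Status ConfigureFromFlags(absl::string_view target_workers_flag,
                                absl::string_view deployment_mode_flag) {
  // Both parsers are case-insensitive and return InvalidArgument on
  // unrecognized input; TF_ASSIGN_OR_RETURN propagates that error.
  TF_ASSIGN_OR_RETURN(TargetWorkers target_workers,
                      ParseTargetWorkers(target_workers_flag));
  TF_ASSIGN_OR_RETURN(DeploymentMode deployment_mode,
                      ParseDeploymentMode(deployment_mode_flag));
  LOG(INFO) << "target_workers=" << TargetWorkersToString(target_workers)
            << " deployment_mode=" << deployment_mode;
  return absl::OkStatus();
}
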
8 changes: 4 additions & 4 deletions tensorflow/core/data/service/cross_trainer_cache_test.cc
@@ -51,7 +51,7 @@ using ::testing::UnorderedElementsAreArray;
 
 class InfiniteRange : public CachableSequence<int64_t> {
  public:
-  StatusOr<int64_t> GetNext() override { return next_++; }
+  absl::StatusOr<int64_t> GetNext() override { return next_++; }
   size_t GetElementSizeBytes(const int64_t& element) const override {
     return sizeof(element);
   }
@@ -63,7 +63,7 @@ class InfiniteRange : public CachableSequence<int64_t> {
 
 class TensorDataset : public CachableSequence<Tensor> {
  public:
-  StatusOr<Tensor> GetNext() override { return Tensor("Test Tensor"); }
+  absl::StatusOr<Tensor> GetNext() override { return Tensor("Test Tensor"); }
   size_t GetElementSizeBytes(const Tensor& element) const override {
     return element.TotalBytes();
   }
@@ -73,7 +73,7 @@ class SlowDataset : public CachableSequence<Tensor> {
  public:
   explicit SlowDataset(absl::Duration delay) : delay_(delay) {}
 
-  StatusOr<Tensor> GetNext() override {
+  absl::StatusOr<Tensor> GetNext() override {
     Env::Default()->SleepForMicroseconds(absl::ToInt64Microseconds(delay_));
     return Tensor("Test Tensor");
   }
@@ -369,7 +369,7 @@ TEST(CrossTrainerCacheTest, Cancel) {
         /*thread_options=*/{}, /*name=*/absl::StrCat("Trainer_", i),
         [&cache, &status, &mu]() {
          for (int j = 0; true; ++j) {
-            StatusOr<std::shared_ptr<const Tensor>> tensor =
+            absl::StatusOr<std::shared_ptr<const Tensor>> tensor =
                cache.Get(absl::StrCat("Trainer_", j % 1000));
            {
              mutex_lock l(mu);
2 changes: 1 addition & 1 deletion tensorflow/core/data/service/data_service_test.cc
@@ -235,7 +235,7 @@ TEST(DataServiceTest, GcMissingClientsWithSmallTimeout) {
   TF_ASSERT_OK(dataset_client.GetTasks(iteration_client_id).status());
   // Iteration should be garbage collected within 10 seconds.
   absl::Time wait_start = absl::Now();
-  TF_ASSERT_OK(WaitWhile([&]() -> StatusOr<bool> {
+  TF_ASSERT_OK(WaitWhile([&]() -> absl::StatusOr<bool> {
     TF_ASSIGN_OR_RETURN(size_t num_iterations, cluster.NumActiveIterations());
     return num_iterations > 0;
   }));
4 changes: 2 additions & 2 deletions tensorflow/core/data/service/data_transfer.h
@@ -91,7 +91,7 @@ class DataTransferClient {
 
   // Returns a string describing properties of the client relevant for checking
   // compatibility with a server for a given protocol.
-  virtual StatusOr<std::string> GetCompatibilityInfo() const {
+  virtual absl::StatusOr<std::string> GetCompatibilityInfo() const {
     return std::string();
   }
 
@@ -130,7 +130,7 @@ class DataTransferServer {
 
   // Returns a string describing properties of the server relevant for checking
   // compatibility with a client for a given protocol.
-  virtual StatusOr<std::string> GetCompatibilityInfo() const {
+  virtual absl::StatusOr<std::string> GetCompatibilityInfo() const {
     return std::string();
   }
 };
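
For context, a concrete transfer implementation supplies this compatibility string by overriding the virtual method above. A hedged sketch (`MyTransferServer` is hypothetical, and the class's other virtual methods are omitted):

class MyTransferServer : public DataTransferServer {
 public:
  absl::StatusOr<std::string> GetCompatibilityInfo() const override {
    // Any string the matching DataTransferClient can compare against;
    // returning an absl::Status error is also valid for this signature.
    return std::string("my-protocol/1");
  }
};
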
3 changes: 2 additions & 1 deletion tensorflow/core/data/service/dispatcher_client.cc
@@ -84,7 +84,8 @@ Status DataServiceDispatcherClient::Initialize() {
   return absl::OkStatus();
 }
 
-StatusOr<WorkerHeartbeatResponse> DataServiceDispatcherClient::WorkerHeartbeat(
+absl::StatusOr<WorkerHeartbeatResponse>
+DataServiceDispatcherClient::WorkerHeartbeat(
     const WorkerHeartbeatRequest& request) {
   WorkerHeartbeatResponse response;
   grpc::ClientContext client_ctx;
2 changes: 1 addition & 1 deletion tensorflow/core/data/service/dispatcher_client.h
@@ -50,7 +50,7 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
   // registered with the dispatcher, this will register the worker. The
   // dispatcher will report which new tasks the worker should run, and which
   // tasks it should delete.
-  StatusOr<WorkerHeartbeatResponse> WorkerHeartbeat(
+  absl::StatusOr<WorkerHeartbeatResponse> WorkerHeartbeat(
       const WorkerHeartbeatRequest& request);
 
   // Updates the dispatcher with information about the worker's state.
4 changes: 2 additions & 2 deletions tensorflow/core/data/service/dispatcher_client_test.cc
@@ -89,7 +89,7 @@ class DispatcherClientTest : public ::testing::Test {
   }
 
   // Creates a dataset and returns the dataset ID.
-  StatusOr<std::string> RegisterDataset(
+  absl::StatusOr<std::string> RegisterDataset(
       const DatasetDef& dataset, const DataServiceMetadata& metadata,
       const std::optional<std::string>& requested_dataset_id = std::nullopt) {
     std::string dataset_id;
@@ -99,7 +99,7 @@ class DispatcherClientTest : public ::testing::Test {
   }
 
   // Starts snapshots and returns the directories.
-  StatusOr<absl::flat_hash_set<std::string>> StartDummySnapshots(
+  absl::StatusOr<absl::flat_hash_set<std::string>> StartDummySnapshots(
       int64_t num_snapshots) {
     DistributedSnapshotMetadata metadata =
         CreateDummyDistributedSnapshotMetadata();
3 changes: 2 additions & 1 deletion tensorflow/core/data/service/dispatcher_impl.cc
@@ -622,7 +622,8 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset(
   return absl::OkStatus();
 }
 
-StatusOr<std::optional<std::string>> DataServiceDispatcherImpl::FindDataset(
+absl::StatusOr<std::optional<std::string>>
+DataServiceDispatcherImpl::FindDataset(
     const GetOrRegisterDatasetRequest& request)
     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   std::shared_ptr<const Dataset> existing_dataset;
2 changes: 1 addition & 1 deletion tensorflow/core/data/service/dispatcher_impl.h
@@ -217,7 +217,7 @@ class DataServiceDispatcherImpl {
       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
   // Finds the dataset ID with the requested dataset ID.
   // Returns nullptr if no such dataset exists.
-  StatusOr<std::optional<std::string>> FindDataset(
+  absl::StatusOr<std::optional<std::string>> FindDataset(
      const GetOrRegisterDatasetRequest& request);
   // Gets a worker's stub from `worker_stubs_`, or if none exists, creates a
   // stub and stores it in `worker_stubs_`. A borrowed pointer to the stub is
2 changes: 1 addition & 1 deletion tensorflow/core/data/service/dispatcher_state.cc
@@ -467,7 +467,7 @@ Status DispatcherState::ValidateWorker(absl::string_view worker_address) const {
   return worker_index_resolver_.ValidateWorker(worker_address);
 }
 
-StatusOr<int64_t> DispatcherState::GetWorkerIndex(
+absl::StatusOr<int64_t> DispatcherState::GetWorkerIndex(
     absl::string_view worker_address) const {
   return worker_index_resolver_.GetWorkerIndex(worker_address);
 }
3 changes: 2 additions & 1 deletion tensorflow/core/data/service/dispatcher_state.h
@@ -288,7 +288,8 @@ class DispatcherState {
   // If the dispatcher config specifies worker addresses, `GetWorkerIndex`
   // returns the worker index according to the list. This is useful for
   // deterministically sharding a dataset among a fixed set of workers.
-  StatusOr<int64_t> GetWorkerIndex(absl::string_view worker_address) const;
+  absl::StatusOr<int64_t> GetWorkerIndex(
+      absl::string_view worker_address) const;
 
   // Returns the paths of all snapshots initiated during the lifetime of this
   // journal.
9 changes: 5 additions & 4 deletions tensorflow/core/data/service/graph_rewriters.cc
@@ -93,7 +93,7 @@ bool ShouldReplaceDynamicPort(absl::string_view config_address,
 }
 }  // namespace
 
-StatusOr<GraphDef>
+absl::StatusOr<GraphDef>
 RemoveCompressionMapRewriter::ApplyRemoveCompressionMapRewrite(
     const GraphDef& graph_def) {
   grappler::RemoveCompressionMap remove_compression_map;
@@ -122,15 +122,16 @@ RemoveCompressionMapRewriter::GetRewriteConfig() const {
   return config;
 }
 
-StatusOr<AutoShardRewriter> AutoShardRewriter::Create(const TaskDef& task_def) {
+absl::StatusOr<AutoShardRewriter> AutoShardRewriter::Create(
+    const TaskDef& task_def) {
   TF_ASSIGN_OR_RETURN(
       AutoShardPolicy auto_shard_policy,
       ToAutoShardPolicy(task_def.processing_mode_def().sharding_policy()));
   return AutoShardRewriter(auto_shard_policy, task_def.num_workers(),
                            task_def.worker_index());
 }
 
-StatusOr<GraphDef> AutoShardRewriter::ApplyAutoShardRewrite(
+absl::StatusOr<GraphDef> AutoShardRewriter::ApplyAutoShardRewrite(
     const GraphDef& graph_def) {
   if (auto_shard_policy_ == AutoShardPolicy::OFF) {
     return graph_def;
@@ -214,7 +215,7 @@ void WorkerIndexResolver::AddWorker(absl::string_view worker_address) {
   }
 }
 
-StatusOr<int64_t> WorkerIndexResolver::GetWorkerIndex(
+absl::StatusOr<int64_t> WorkerIndexResolver::GetWorkerIndex(
     absl::string_view worker_address) const {
   const auto it = absl::c_find(worker_addresses_, worker_address);
   if (it == worker_addresses_.cend()) {
9 changes: 5 additions & 4 deletions tensorflow/core/data/service/graph_rewriters.h
@@ -37,7 +37,7 @@ namespace data {
 class RemoveCompressionMapRewriter {
  public:
   // Returns `graph_def` with the compression map removed.
-  StatusOr<GraphDef> ApplyRemoveCompressionMapRewrite(
+  absl::StatusOr<GraphDef> ApplyRemoveCompressionMapRewrite(
       const GraphDef& graph_def);
 
  private:
@@ -49,11 +49,11 @@ class AutoShardRewriter {
  public:
   // Creates an `AutoShardRewriter` according to `task_def`. Returns an error if
   // the sharding policy is not a valid auto-shard policy.
-  static StatusOr<AutoShardRewriter> Create(const TaskDef& task_def);
+  static absl::StatusOr<AutoShardRewriter> Create(const TaskDef& task_def);
 
   // Applies auto-sharding to `graph_def`. If auto-shard policy is OFF, returns
   // the same graph as `graph_def`. Otherwise, returns the re-written graph.
-  StatusOr<GraphDef> ApplyAutoShardRewrite(const GraphDef& graph_def);
+  absl::StatusOr<GraphDef> ApplyAutoShardRewrite(const GraphDef& graph_def);
 
  private:
   AutoShardRewriter(AutoShardPolicy auto_shard_policy, int64_t num_workers,
@@ -97,7 +97,8 @@ class WorkerIndexResolver {
 
   // Returns the worker index for the worker at `worker_address`. Returns a
   // NotFound error if the worker is not registered.
-  StatusOr<int64_t> GetWorkerIndex(absl::string_view worker_address) const;
+  absl::StatusOr<int64_t> GetWorkerIndex(
+      absl::string_view worker_address) const;
 
  private:
   std::vector<std::string> worker_addresses_;
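
Taken together, the signatures above compose as follows; this is a hedged sketch based only on those declarations (`RewriteForWorker` is a hypothetical helper, not part of this commit):

absl::StatusOr<GraphDef> RewriteForWorker(const TaskDef& task_def,
                                          const GraphDef& graph_def) {
  // Create() returns an error if the task's sharding policy is not a
  // valid auto-shard policy.
  TF_ASSIGN_OR_RETURN(AutoShardRewriter rewriter,
                      AutoShardRewriter::Create(task_def));
  // Returns `graph_def` unchanged when the auto-shard policy is OFF.
  return rewriter.ApplyAutoShardRewrite(graph_def);
}
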
6 changes: 4 additions & 2 deletions tensorflow/core/data/service/graph_rewriters_test.cc
@@ -49,7 +49,8 @@ using ::tensorflow::testing::StatusIs;
 using ::testing::HasSubstr;
 using ::testing::SizeIs;
 
-StatusOr<NodeDef> GetNode(const GraphDef& graph_def, absl::string_view name) {
+absl::StatusOr<NodeDef> GetNode(const GraphDef& graph_def,
+                                absl::string_view name) {
   for (const NodeDef& node : graph_def.node()) {
     if (node.name() == name) {
       return node;
@@ -59,7 +60,8 @@ StatusOr<NodeDef> GetNode(const GraphDef& graph_def, absl::string_view name) {
       name, graph_def.ShortDebugString()));
 }
 
-StatusOr<int64_t> GetValue(const GraphDef& graph_def, absl::string_view name) {
+absl::StatusOr<int64_t> GetValue(const GraphDef& graph_def,
+                                 absl::string_view name) {
   for (const NodeDef& node : graph_def.node()) {
     if (node.name() == name) {
       return node.attr().at("value").tensor().int64_val()[0];
2 changes: 1 addition & 1 deletion tensorflow/core/data/service/server_lib.cc
@@ -212,7 +212,7 @@ void WorkerGrpcDataServer::MaybeStartAlternativeDataTransferServer(
       str_util::StringReplace(config_.data_transfer_address(), kPortPlaceholder,
                               absl::StrCat(transfer_server_->Port()),
                               /*replace_all=*/false));
-  StatusOr<std::string> compatibility_info =
+  absl::StatusOr<std::string> compatibility_info =
       transfer_server_->GetCompatibilityInfo();
   if (!compatibility_info.ok()) {
     LOG(ERROR)
