From 4d134bad0403ebb5722144d8f859a04a5f21efc2 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Tue, 22 May 2018 13:14:18 -0700 Subject: Move executor_test.cc to tensorflow/core/common_runtime/. PiperOrigin-RevId: 197611583 --- tensorflow/core/BUILD | 38 ++ tensorflow/core/common_runtime/executor_test.cc | 413 +++++++++++++++++++++ tensorflow/core/common_runtime/testlib_ops.cc | 95 +++++ tensorflow/core/distributed_runtime/BUILD | 19 +- .../core/distributed_runtime/executor_test.cc | 413 --------------------- tensorflow/core/distributed_runtime/master_test.cc | 2 +- tensorflow/core/distributed_runtime/rpc/BUILD | 16 +- .../distributed_runtime/rpc/grpc_testlib_ops.cc | 85 ----- 8 files changed, 560 insertions(+), 521 deletions(-) create mode 100644 tensorflow/core/common_runtime/executor_test.cc create mode 100644 tensorflow/core/common_runtime/testlib_ops.cc delete mode 100644 tensorflow/core/distributed_runtime/executor_test.cc delete mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 5d63cd68ae..05b8423e15 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1108,6 +1108,7 @@ cc_library( ":shape_inference_testutil", ":tensor_testutil", ":test", + ":testlib_ops", "//tensorflow/cc:scope", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:ops_testutil", @@ -1115,6 +1116,18 @@ cc_library( ], ) +cc_library( + name = "testlib_ops", + testonly = 1, + srcs = ["common_runtime/testlib_ops.cc"], + linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + # This is a link-only library to provide a DirectSession # implementation of the Session interface. tf_cuda_library( @@ -3748,6 +3761,31 @@ tf_cc_test( ], ) +tf_cc_test( + name = "common_runtime_executor_test", + size = "small", + srcs = ["common_runtime/executor_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":protos_all_cc", + ":test", + ":test_main", + ":testlib", + "//tensorflow/core/kernels:array", + "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:math", + "//tensorflow/core/kernels:random_ops", + "//tensorflow/core/kernels:state", + ], +) + tf_cc_test( name = "common_runtime_function_test", size = "small", diff --git a/tensorflow/core/common_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc new file mode 100644 index 0000000000..e34224205b --- /dev/null +++ b/tensorflow/core/common_runtime/executor_test.cc @@ -0,0 +1,413 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <algorithm>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/common_runtime/step_stats_collector.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+class ExecutorTest : public ::testing::Test {
+ protected:
+  ExecutorTest()
+      : device_(DeviceFactory::NewDevice("CPU", {},
+                                         "/job:localhost/replica:0/task:0")),
+
+        step_stats_collector_(&step_stats_) {
+    SessionOptions options;
+    thread_pool_ = ComputePool(options);
+  }
+
+  ~ExecutorTest() override {
+    // There should always be exactly one Ref left on the Rendezvous
+    // when the test completes.
+    CHECK(rendez_->Unref());
+    delete exec_;
+    delete device_;
+  }
+
+  // Resets exec_ with a new executor based on the given graph.
+  void Create(std::unique_ptr<const Graph> graph) {
+    const int version = graph->versions().producer();
+    LocalExecutorParams params;
+    params.device = device_;
+    params.create_kernel = [this, version](const NodeDef& ndef,
+                                           OpKernel** kernel) {
+      return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
+    };
+    params.delete_kernel = [](OpKernel* kernel) {
+      DeleteNonCachedKernel(kernel);
+    };
+    delete exec_;
+    TF_CHECK_OK(NewLocalExecutor(params, std::move(graph), &exec_));
+    runner_ = [this](std::function<void()> fn) { thread_pool_->Schedule(fn); };
+    rendez_ = NewLocalRendezvous();
+  }
+
+  Status Run(Rendezvous* rendez) {
+    Executor::Args args;
+    args.rendezvous = rendez;
+    args.stats_collector = &step_stats_collector_;
+    args.runner = runner_;
+    return exec_->Run(args);
+  }
+
+  thread::ThreadPool* thread_pool_ = nullptr;
+  Device* device_ = nullptr;
+  Executor* exec_ = nullptr;
+  StepStatsCollector step_stats_collector_;
+  StepStats step_stats_;
+  Executor::Args::Runner runner_;
+  Rendezvous* rendez_ = nullptr;
+};
+
+// A float val -> Tensor
+Tensor V(const float val) {
+  Tensor tensor(DT_FLOAT, TensorShape({}));
+  tensor.scalar<float>()() = val;
+  return tensor;
+}
+
+// An int32 val -> Tensor
+Tensor VI(const int32 val) {
+  Tensor tensor(DT_INT32, TensorShape({}));
+  tensor.scalar<int32>()() = val;
+  return tensor;
+}
+
+// A bool val -> Tensor
+Tensor VB(const bool val) {
+  Tensor tensor(DT_BOOL, TensorShape({}));
+  tensor.scalar<bool>()() = val;
+  return tensor;
+}
+
+// A double val -> Tensor
+Tensor VD(const double val) {
+  Tensor tensor(DT_DOUBLE, TensorShape({}));
+  tensor.scalar<double>()() = val;
+  return tensor;
+}
+
+// Tensor -> a float val.
+float V(const Tensor& tensor) {
+  CHECK_EQ(tensor.dtype(), DT_FLOAT);
+  CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
+  return tensor.scalar<float>()();
+}
+
+static uint64 kIncarnation = 1;  // Used in the following tests.
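+
+// Helper for building the parsed Rendezvous keys used throughout these
+// tests: inputs are fed by Send()-ing on a key and outputs are read back
+// by Recv()-ing on one, using the fake device names ALICE and BOB below.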
+Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
+                          const string& receiver, const string& name) {
+  Rendezvous::ParsedKey result;
+  CHECK(
+      Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
+                                                 name, FrameAndIter(0, 0)),
+                           &result)
+          .ok());
+  return result;
+}
+
+#define ALICE "/job:j/replica:0/task:0/cpu:0"
+#define BOB "/job:j/replica:0/task:0/device:GPU:0"
+
+TEST_F(ExecutorTest, SimpleAdd) {
+  // c = a + b
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB);
+  auto in1 = test::graph::Recv(g.get(), "b", "float", ALICE, 1, BOB);
+  auto tmp = test::graph::Add(g.get(), in0, in1);
+  test::graph::Send(g.get(), tmp, "c", BOB, 1, ALICE);
+  Create(std::move(g));
+  Rendezvous::Args args;
+  TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
+                             false));  // in0 = 1.0
+  TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "b"), args, V(1.0),
+                             false));  // in1 = 1.0
+  TF_ASSERT_OK(Run(rendez_));
+  Tensor out = V(-1);
+  bool is_dead = false;
+  TF_ASSERT_OK(
+      rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
+  EXPECT_EQ(2.0, V(out));  // out = 1.0 + 1.0 = 2.0
+}
+
+TEST_F(ExecutorTest, SelfAdd) {
+  // v0 <- a
+  // v1 = v0 + v0
+  // v2 = v1 + v1
+  // ... ...
+  // v10 = v9 + v9
+  //
+  // b <- v10
+  // All nodes are executed by one thread.
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  auto v = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB);
+  const int N = 10;
+  for (int i = 1; i <= N; ++i) {
+    v = test::graph::Add(g.get(), v, v);
+  }
+  // out <- v10
+  test::graph::Send(g.get(), v, "b", BOB, 1, ALICE);
+  Create(std::move(g));
+  Rendezvous::Args args;
+  // a = 1.0
+  TF_ASSERT_OK(
+      rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false));
+  TF_ASSERT_OK(Run(rendez_));
+  Tensor out = V(-1);
+  bool is_dead = false;
+  TF_ASSERT_OK(
+      rendez_->Recv(Key(BOB, kIncarnation, ALICE, "b"), args, &out, &is_dead));
+  EXPECT_EQ(1024.0, V(out));  // b = v10 = 2*v9 = 4*v8 = ... = 1024*a = 1024.0
+}
+
+// Builds a graph which adds N copies of one input node "in". I.e.,
+//     a + a + a + ... + a
+// The returned graph is parenthesized randomly. E.g.,
+//     a + ((a + a) + a)
+//     (a + a) + (a + a)
+//     ((a + a) + a) + a
+// may all be generated.
+void BuildTree(int N, Graph* g) {
+  CHECK_GT(N, 1);
+  // A single input node "in".
+  auto in = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
+  std::vector<Node*> nodes;
+  int i = 0;
+  // Duplicate "in" N times. Each copy is named l0, l1, l2, ....
+  for (; i < N; ++i) {
+    nodes.push_back(test::graph::Identity(g, in, 0));
+  }
+  random::PhiloxRandom philox(testing::RandomSeed(), 17);
+  random::SimplePhilox rnd(&philox);
+  while (nodes.size() > 1) {
+    // Randomly pick two from nodes and add them. The resulting node
+    // is named like n10, n11, .... and is put back into "nodes".
+    int x = rnd.Uniform(nodes.size());
+    auto in0 = nodes[x];
+    nodes[x] = nodes.back();
+    nodes.resize(nodes.size() - 1);
+    x = rnd.Uniform(nodes.size());
+    auto in1 = nodes[x];
+    // node = in0 + in1.
+    nodes[x] = test::graph::Add(g, in0, in1);
+  }
+  // The final output node, sent back to ALICE as "b".
+  test::graph::Send(g, nodes.back(), "b", BOB, 1, ALICE);
+}
+
+TEST_F(ExecutorTest, RandomTree) {
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  BuildTree(4096, g.get());
+  Create(std::move(g));
+  Rendezvous::Args args;
+  TF_ASSERT_OK(
+      rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false));
+  TF_ASSERT_OK(Run(rendez_));
+  Tensor out = V(-1);
+  bool is_dead = false;
+  TF_ASSERT_OK(
+      rendez_->Recv(Key(BOB, kIncarnation, ALICE, "b"), args, &out, &is_dead));
+  EXPECT_EQ(4096.0, V(out));
+}
+
+void BuildConcurrentAddAssign(Graph* g) {
+  auto one = test::graph::Constant(g, V(1.0));
+  // A variable that holds one float.
+  auto var = test::graph::Var(g, DT_FLOAT, TensorShape({}));
+  // Initialize the variable with 1.0.
+  auto init = test::graph::Assign(g, var, one);
+  // Output
+  auto out = test::graph::Send(g, var, "out", ALICE, kIncarnation, BOB);
+  // Add many concurrent computations. Each does v = v + 1.
+  for (int i = 0; i < 1024; ++i) {
+    auto add = test::graph::Add(g, var, one);
+    g->AddControlEdge(init, add);  // Ensures it runs after init.
+    auto assign = test::graph::Assign(g, var, add);
+    g->AddControlEdge(assign, out);
+  }
+}
+
+#ifndef THREAD_SANITIZER
+TEST_F(ExecutorTest, ConcurrentAddAssign) {
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  BuildConcurrentAddAssign(g.get());
+  Create(std::move(g));
+  for (int iters = 0; iters < 16; ++iters) {
+    Rendezvous* rendez = NewLocalRendezvous();
+    TF_ASSERT_OK(Run(rendez));
+    Rendezvous::Args args;
+    Tensor out;
+    bool is_dead;
+    TF_ASSERT_OK(rendez->Recv(Key(ALICE, kIncarnation, BOB, "out"), args, &out,
+                              &is_dead));
+    VLOG(1) << "Get " << V(out);
+    EXPECT_LE(V(out), 1025.0);
+    rendez->Unref();
+  }
+}
+#endif
+
+TEST_F(ExecutorTest, SimpleSwitchLive) {
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB);
+  auto in1 = test::graph::Constant(g.get(), VB(false));
+  auto tmp = test::graph::Switch(g.get(), in0, in1);
+  test::graph::Send(g.get(), tmp, "c", BOB, 1, ALICE);
+  Create(std::move(g));
+  Rendezvous::Args args;
+  TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
+                             false));  // in0 = 1.0
+  TF_ASSERT_OK(Run(rendez_));
+  Tensor out = V(-1);
+  bool is_dead = false;
+  TF_ASSERT_OK(
+      rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
+  EXPECT_EQ(1.0, V(out));  // out = 1.0
+  EXPECT_FALSE(is_dead);
+}
+
+TEST_F(ExecutorTest, SimpleSwitchDead) {
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB);
+  auto in1 = test::graph::Constant(g.get(), VB(true));
+  auto tmp = test::graph::Switch(g.get(), in0, in1);
+  test::graph::Send(g.get(), tmp, "c", BOB, 1, ALICE);
+  Create(std::move(g));
+  Rendezvous::Args args;
+  TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
+                             false));  // in0 = 1.0
+  TF_ASSERT_OK(Run(rendez_));
+  Tensor out = V(-1);
+  bool is_dead = false;
+  TF_ASSERT_OK(
+      rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
+  EXPECT_TRUE(is_dead);
+}
+
+TEST_F(ExecutorTest, Abort) {
+  // e = a + b + c + d
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB);
+  auto in1 = test::graph::Recv(g.get(), "b", "float", ALICE, 1, BOB);
+  auto in2 = test::graph::Recv(g.get(), "c", "float", ALICE, 1, BOB);
+  auto in3 = test::graph::Recv(g.get(), "d", "float", ALICE, 1, BOB);
+  auto add0 = test::graph::Add(g.get(), in0, in1);
+  auto add1 = test::graph::Add(g.get(), in2, in3);
+  auto add2 = test::graph::Add(g.get(), add0, add1);
+  test::graph::Send(g.get(), add2, "e", BOB, 1, ALICE);
+  Create(std::move(g));
+
+  // Needs 4 inputs (recv). One of them is aborted.
+  rendez_->Ref();
+  SchedClosure([this]() {
+    Env::Default()->SleepForMicroseconds(100 * 1000);
+    Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"),
+                             Rendezvous::Args(), V(1.0), false);
+    rendez_->Unref();
+  });
+  rendez_->Ref();
+  SchedClosure([this]() {
+    Env::Default()->SleepForMicroseconds(100 * 1000);
+    Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "b"),
+                             Rendezvous::Args(), V(1.0), false);
+    rendez_->Unref();
+  });
+  rendez_->Ref();
+  SchedClosure([this]() {
+    Env::Default()->SleepForMicroseconds(100 * 1000);
+    Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "c"),
+                             Rendezvous::Args(), V(1.0), false);
+    rendez_->Unref();
+  });
+  rendez_->Ref();
+  SchedClosure([this]() {
+    Env::Default()->SleepForMicroseconds(100 * 1000);
+    rendez_->StartAbort(errors::Aborted(""));
+    rendez_->Unref();
+  });
+  EXPECT_TRUE(errors::IsAborted(Run(rendez_)));
+  Tensor out = V(-1);
+  bool is_dead = false;
+  EXPECT_TRUE(errors::IsAborted(rendez_->Recv(
+      Key(BOB, kIncarnation, ALICE, "c"), Rendezvous::Args(), &out, &is_dead)));
+  // At this point there can still be pending (albeit Aborted) Send
+  // closures holding Refs on rendez_. We need to wait for them, or
+  // else there can be a memory leak at termination.
+  while (!rendez_->RefCountIsOne())
+    ;
+}
+
+TEST_F(ExecutorTest, RecvInvalidDtype) {
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  // An input vector of type float of size 1.
+  auto one = test::graph::Recv(g.get(), "one", "float", ALICE, 1, BOB);
+  // A floating-point variable vector of size 1.
+  auto var = test::graph::Var(g.get(), DT_FLOAT, TensorShape({1}));
+  // Initialize the variable with the input.
+  auto init = test::graph::Assign(g.get(), var, one);
+  // Output
+  auto* two = test::graph::Send(g.get(), var, "two", BOB, 1, ALICE);
+  g->AddControlEdge(init, two);  // Ensures it runs after init.
+  Create(std::move(g));
+  Rendezvous* rendez = NewLocalRendezvous();
+  // Send a double instead of a float.
+  TF_ASSERT_OK(rendez->Send(Key(ALICE, 1, BOB, "one"), Rendezvous::Args(),
+                            VD(1.0), false));
+  // Fails due to invalid dtype.
+  EXPECT_TRUE(errors::IsInternal(Run(rendez)));
+  Tensor output;
+  bool is_dead;
+  EXPECT_TRUE(errors::IsInternal(rendez->Recv(
+      Key(BOB, 1, ALICE, "two"), Rendezvous::Args(), &output, &is_dead)));
+  rendez->Unref();
+}
+
+TEST_F(ExecutorTest, RecvInvalidRefDtype) {
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  // A var that always produces an invalid ref dtype.
+  auto var = test::graph::InvalidRefType(g.get(), DT_FLOAT, DT_DOUBLE);
+  test::graph::Send(g.get(), var, "out", BOB, 1, ALICE);
+  Create(std::move(g));
+  Rendezvous* rendez = NewLocalRendezvous();
+  EXPECT_TRUE(errors::IsInternal(Run(rendez)));
+  Tensor output;
+  bool is_dead;
+  EXPECT_TRUE(errors::IsInternal(rendez->Recv(
+      Key(BOB, 1, ALICE, "out"), Rendezvous::Args(), &output, &is_dead)));
+  rendez->Unref();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/testlib_ops.cc b/tensorflow/core/common_runtime/testlib_ops.cc
new file mode 100644
index 0000000000..a0139c3ee5
--- /dev/null
+++ b/tensorflow/core/common_runtime/testlib_ops.cc
@@ -0,0 +1,95 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+namespace test {
+
+// ErrorOp::Compute returns an error.
+REGISTER_OP("Error")
+    .Input("in: T")
+    .Output("out: T")
+    .Attr("T: type")
+    .Attr("message: string")
+    .SetShapeFn(shape_inference::UnknownShape);
+class ErrorOp : public OpKernel {
+ public:
+  explicit ErrorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("message", &errmsg_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    ctx->SetStatus(errors::Internal(errmsg_));
+  }
+
+ private:
+  string errmsg_;
+};
+REGISTER_KERNEL_BUILDER(Name("Error").Device(DEVICE_CPU), ErrorOp);
+
+REGISTER_OP("InvalidRefType")
+    .Output("out: Ref(TIn)")
+    .Attr("TIn: type")
+    .Attr("TOut: type")
+    .SetShapeFn(shape_inference::UnknownShape);
+class InvalidRefType : public OpKernel {
+ public:
+  explicit InvalidRefType(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("TOut", &dtout_));
+    output_ = Tensor(dtout_, TensorShape({}));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    ctx->set_output_ref(0, &mu_, &output_);
+  }
+
+ private:
+  DataType dtout_;
+  mutex mu_;
+  Tensor output_;
+};
+REGISTER_KERNEL_BUILDER(Name("InvalidRefType").Device(DEVICE_CPU),
+                        InvalidRefType);
+
+// DelayOp::AsyncCompute sleeps for "micros" microseconds and then returns
+// its input.
+REGISTER_OP("Delay")
+    .Input("in: T")
+    .Output("out: T")
+    .Attr("T: type")
+    .Attr("micros: int")
+    .SetShapeFn(shape_inference::UnchangedShape);
+class DelayOp : public AsyncOpKernel {
+ public:
+  explicit DelayOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("micros", &micros_));
+  }
+
+  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+    ctx->set_output(0, ctx->input(0));
+    ctx->env()->SchedClosureAfter(micros_, done);
+  }
+
+ private:
+  int64 micros_;
+};
+REGISTER_KERNEL_BUILDER(Name("Delay").Device(DEVICE_CPU), DelayOp);
+
+}  // namespace test
+}  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 18b7069dbe..ead698d787 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -561,17 +561,19 @@ tf_cc_test(
     ],
 )
 
-# TODO(mrry): Move executor_test.cc to ../common_runtime when once it no longer depends
-# on grpc_testlib.
-tf_cuda_cc_tests(
-    name = "executor_tests",
+tf_cuda_cc_test(
+    name = "master_test",
     size = "medium",
     srcs = [
-        "executor_test.cc",
-        #"master_test.cc", # TODO(b/27683709): Re-enable when not flaky.
+        "master_test.cc",
     ],
     linkstatic = tf_kernel_tests_linkstatic(),
-    tags = tf_cuda_tests_tags(),
+    tags = tf_cuda_tests_tags() + [
+        "manual",  # TODO(b/27683709): Re-enable when not flaky.
+ "notap", # TODO(b/27683709): Re-enable when not flaky. + "noguitar", # TODO(b/27683709): Re-enable when not flaky. + "nooss", # TODO(b/27683709): Re-enable when not flaky. + ], deps = [ ":master", ":remote_device", @@ -588,6 +590,7 @@ tf_cuda_cc_tests( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_master_service_impl", "//tensorflow/core/distributed_runtime/rpc:grpc_testlib", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", @@ -648,10 +651,10 @@ tf_cuda_cc_test( "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_session", - "//tensorflow/core/distributed_runtime/rpc:grpc_testlib_ops", "//tensorflow/core/kernels:aggregate_ops", "//tensorflow/core/kernels:array", ], diff --git a/tensorflow/core/distributed_runtime/executor_test.cc b/tensorflow/core/distributed_runtime/executor_test.cc deleted file mode 100644 index e34224205b..0000000000 --- a/tensorflow/core/distributed_runtime/executor_test.cc +++ /dev/null @@ -1,413 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/common_runtime/device_factory.h" -#include "tensorflow/core/common_runtime/executor.h" -#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" -#include "tensorflow/core/common_runtime/process_util.h" -#include "tensorflow/core/common_runtime/step_stats_collector.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/rendezvous.h" -#include "tensorflow/core/framework/step_stats.pb.h" -#include "tensorflow/core/framework/versions.pb.h" -#include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/random/simple_philox.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" -#include "tensorflow/core/platform/tracing.h" -#include "tensorflow/core/public/session_options.h" - -namespace tensorflow { - -class ExecutorTest : public ::testing::Test { - protected: - ExecutorTest() - : device_(DeviceFactory::NewDevice("CPU", {}, - "/job:localhost/replica:0/task:0")), - - step_stats_collector_(&step_stats_) { - SessionOptions options; - thread_pool_ = ComputePool(options); - } - - ~ExecutorTest() override { - // There should always be exactly one Ref left on the Rendezvous - // when the test completes. 
- CHECK(rendez_->Unref()); - delete exec_; - delete device_; - } - - // Resets executor_ with a new executor based on a graph 'gdef'. - void Create(std::unique_ptr graph) { - const int version = graph->versions().producer(); - LocalExecutorParams params; - params.device = device_; - params.create_kernel = [this, version](const NodeDef& ndef, - OpKernel** kernel) { - return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel); - }; - params.delete_kernel = [](OpKernel* kernel) { - DeleteNonCachedKernel(kernel); - }; - delete exec_; - TF_CHECK_OK(NewLocalExecutor(params, std::move(graph), &exec_)); - runner_ = [this](std::function fn) { thread_pool_->Schedule(fn); }; - rendez_ = NewLocalRendezvous(); - } - - Status Run(Rendezvous* rendez) { - Executor::Args args; - args.rendezvous = rendez; - args.stats_collector = &step_stats_collector_; - args.runner = runner_; - return exec_->Run(args); - } - - thread::ThreadPool* thread_pool_ = nullptr; - Device* device_ = nullptr; - Executor* exec_ = nullptr; - StepStatsCollector step_stats_collector_; - StepStats step_stats_; - Executor::Args::Runner runner_; - Rendezvous* rendez_ = nullptr; -}; - -// A float val -> Tensor -Tensor V(const float val) { - Tensor tensor(DT_FLOAT, TensorShape({})); - tensor.scalar()() = val; - return tensor; -} - -// A int32 val -> Tensor -Tensor VI(const int32 val) { - Tensor tensor(DT_INT32, TensorShape({})); - tensor.scalar()() = val; - return tensor; -} - -// A bool val -> Tensor -Tensor VB(const bool val) { - Tensor tensor(DT_BOOL, TensorShape({})); - tensor.scalar()() = val; - return tensor; -} - -// A double val -> Tensor -Tensor VD(const double val) { - Tensor tensor(DT_DOUBLE, TensorShape({})); - tensor.scalar()() = val; - return tensor; -} - -// Tensor -> a float val. -float V(const Tensor& tensor) { - CHECK_EQ(tensor.dtype(), DT_FLOAT); - CHECK(TensorShapeUtils::IsScalar(tensor.shape())); - return tensor.scalar()(); -} - -static uint64 kIncarnation = 1; // Uses in following tests. - -Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation, - const string& receiver, const string& name) { - Rendezvous::ParsedKey result; - CHECK( - Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver, - name, FrameAndIter(0, 0)), - &result) - .ok()); - return result; -} - -#define ALICE "/job:j/replica:0/task:0/cpu:0" -#define BOB "/job:j/replica:0/task:0/device:GPU:0" - -TEST_F(ExecutorTest, SimpleAdd) { - // c = a + b - std::unique_ptr g(new Graph(OpRegistry::Global())); - auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB); - auto in1 = test::graph::Recv(g.get(), "b", "float", ALICE, 1, BOB); - auto tmp = test::graph::Add(g.get(), in0, in1); - test::graph::Send(g.get(), tmp, "c", BOB, 1, ALICE); - Create(std::move(g)); - Rendezvous::Args args; - TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), - false)); // in0 = 1.0 - TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "b"), args, V(1.0), - false)); // in1 = 1.0 - TF_ASSERT_OK(Run(rendez_)); - Tensor out = V(-1); - bool is_dead = false; - TF_ASSERT_OK( - rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead)); - EXPECT_EQ(2.0, V(out)); // out = 1.0 + 1.0 = 2.0 -} - -TEST_F(ExecutorTest, SelfAdd) { - // v0 <- a - // v1 = v0 + v0 - // v2 = v1 + v1 - // ... ... - // v10 = v9 + v9 - // - // b <- v10 - // All nodes are executed by one thread. 
- std::unique_ptr g(new Graph(OpRegistry::Global())); - auto v = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB); - const int N = 10; - for (int i = 1; i <= N; ++i) { - v = test::graph::Add(g.get(), v, v); - } - // out <- v10 - test::graph::Send(g.get(), v, "b", BOB, 1, ALICE); - Create(std::move(g)); - Rendezvous::Args args; - // a = 1.0 - TF_ASSERT_OK( - rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false)); - TF_ASSERT_OK(Run(rendez_)); - Tensor out = V(-1); - bool is_dead = false; - TF_ASSERT_OK( - rendez_->Recv(Key(BOB, kIncarnation, ALICE, "b"), args, &out, &is_dead)); - EXPECT_EQ(1024.0, V(out)); // b=v10=2*v9=4*v8=...=1024*a=1024.0 -} - -// Builds a graph which adds N copies of one variable "in". I.e., -// a + a + a + ... + a -// The returned graph is parenthesized ramdonly. I.e., -// a + ((a + a) + a) -// (a + a) + (a + a) -// ((a + a) + a) + a -// are all possibly generated. -void BuildTree(int N, Graph* g) { - CHECK_GT(N, 1); - // A single input node "in". - auto in = test::graph::Recv(g, "a", "float", ALICE, 1, BOB); - std::vector nodes; - int i = 0; - // Duplicate "in" N times. Each copies is named as l0, l1, l2, .... - for (; i < N; ++i) { - nodes.push_back(test::graph::Identity(g, in, 0)); - } - random::PhiloxRandom philox(testing::RandomSeed(), 17); - random::SimplePhilox rnd(&philox); - while (nodes.size() > 1) { - // Randomly pick two from nodes and add them. The resulting node - // is named lik n10, n11, .... and is put back into "nodes". - int x = rnd.Uniform(nodes.size()); - auto in0 = nodes[x]; - nodes[x] = nodes.back(); - nodes.resize(nodes.size() - 1); - x = rnd.Uniform(nodes.size()); - auto in1 = nodes[x]; - // node = in0 + in1. - nodes[x] = test::graph::Add(g, in0, in1); - } - // The final output node "out". - test::graph::Send(g, nodes.back(), "b", BOB, 1, ALICE); -} - -TEST_F(ExecutorTest, RandomTree) { - std::unique_ptr g(new Graph(OpRegistry::Global())); - BuildTree(4096, g.get()); - Create(std::move(g)); - Rendezvous::Args args; - TF_ASSERT_OK( - rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false)); - TF_ASSERT_OK(Run(rendez_)); - Tensor out = V(-1); - bool is_dead = false; - TF_ASSERT_OK( - rendez_->Recv(Key(BOB, kIncarnation, ALICE, "b"), args, &out, &is_dead)); - EXPECT_EQ(4096.0, V(out)); -} - -void BuildConcurrentAddAssign(Graph* g) { - auto one = test::graph::Constant(g, V(1.0)); - // A variable holds one float. - auto var = test::graph::Var(g, DT_FLOAT, TensorShape({})); - // Initilize the variable with 1.0. - auto init = test::graph::Assign(g, var, one); - // Output - auto out = test::graph::Send(g, var, "out", ALICE, kIncarnation, BOB); - // Have many concurrent computation. Each does v = v + 1. - for (int i = 0; i < 1024; ++i) { - auto add = test::graph::Add(g, var, one); - g->AddControlEdge(init, add); // Ensures run after init. 
- auto assign = test::graph::Assign(g, var, add); - g->AddControlEdge(assign, out); - } -} - -#ifndef THREAD_SANITIZER -TEST_F(ExecutorTest, ConcurrentAddAssign) { - std::unique_ptr g(new Graph(OpRegistry::Global())); - BuildConcurrentAddAssign(g.get()); - Create(std::move(g)); - for (int iters = 0; iters < 16; ++iters) { - Rendezvous* rendez = NewLocalRendezvous(); - TF_ASSERT_OK(Run(rendez)); - Rendezvous::Args args; - Tensor out; - bool is_dead; - TF_ASSERT_OK(rendez->Recv(Key(ALICE, kIncarnation, BOB, "out"), args, &out, - &is_dead)); - VLOG(1) << "Get " << V(out); - EXPECT_LE(V(out), 1025.0); - rendez->Unref(); - } -} -#endif - -TEST_F(ExecutorTest, SimpleSwitchLive) { - std::unique_ptr g(new Graph(OpRegistry::Global())); - auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB); - auto in1 = test::graph::Constant(g.get(), VB(false)); - auto tmp = test::graph::Switch(g.get(), in0, in1); - test::graph::Send(g.get(), tmp, "c", BOB, 1, ALICE); - Create(std::move(g)); - Rendezvous::Args args; - TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), - false)); // in0 = 1.0 - TF_ASSERT_OK(Run(rendez_)); - Tensor out = V(-1); - bool is_dead = false; - TF_ASSERT_OK( - rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead)); - EXPECT_EQ(1.0, V(out)); // out = 1.0 - EXPECT_FALSE(is_dead); -} - -TEST_F(ExecutorTest, SimpleSwitchDead) { - std::unique_ptr g(new Graph(OpRegistry::Global())); - auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB); - auto in1 = test::graph::Constant(g.get(), VB(true)); - auto tmp = test::graph::Switch(g.get(), in0, in1); - test::graph::Send(g.get(), tmp, "c", BOB, 1, ALICE); - Create(std::move(g)); - Rendezvous::Args args; - TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), - false)); // in0 = 1.0 - TF_ASSERT_OK(Run(rendez_)); - Tensor out = V(-1); - bool is_dead = false; - TF_ASSERT_OK( - rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead)); - EXPECT_TRUE(is_dead); -} - -TEST_F(ExecutorTest, Abort) { - // e = a + b + c + d - std::unique_ptr g(new Graph(OpRegistry::Global())); - auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB); - auto in1 = test::graph::Recv(g.get(), "b", "float", ALICE, 1, BOB); - auto in2 = test::graph::Recv(g.get(), "c", "float", ALICE, 1, BOB); - auto in3 = test::graph::Recv(g.get(), "d", "float", ALICE, 1, BOB); - auto add0 = test::graph::Add(g.get(), in0, in1); - auto add1 = test::graph::Add(g.get(), in2, in3); - auto add2 = test::graph::Add(g.get(), add0, add1); - test::graph::Send(g.get(), add2, "e", BOB, 1, ALICE); - Create(std::move(g)); - - // Needs 4 inputs (recv). One of them is aborted. 
- rendez_->Ref(); - SchedClosure([this]() { - Env::Default()->SleepForMicroseconds(100 * 1000); - Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), - Rendezvous::Args(), V(1.0), false); - rendez_->Unref(); - }); - rendez_->Ref(); - SchedClosure([this]() { - Env::Default()->SleepForMicroseconds(100 * 1000); - Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "b"), - Rendezvous::Args(), V(1.0), false); - rendez_->Unref(); - }); - rendez_->Ref(); - SchedClosure([this]() { - Env::Default()->SleepForMicroseconds(100 * 1000); - Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "c"), - Rendezvous::Args(), V(1.0), false); - rendez_->Unref(); - }); - rendez_->Ref(); - SchedClosure([this]() { - Env::Default()->SleepForMicroseconds(100 * 1000); - rendez_->StartAbort(errors::Aborted("")); - rendez_->Unref(); - }); - EXPECT_TRUE(errors::IsAborted(Run(rendez_))); - Tensor out = V(-1); - bool is_dead = false; - EXPECT_TRUE(errors::IsAborted(rendez_->Recv( - Key(BOB, kIncarnation, ALICE, "c"), Rendezvous::Args(), &out, &is_dead))); - // At this point there can still be pending (albeit Aborted) Send - // closures holding Refs on rendez_. We need to wait for them, or - // else there can be a memory leak at termination. - while (!rendez_->RefCountIsOne()) - ; -} - -TEST_F(ExecutorTest, RecvInvalidDtype) { - std::unique_ptr g(new Graph(OpRegistry::Global())); - // An input vector of type float of size 1. - auto one = test::graph::Recv(g.get(), "one", "float", ALICE, 1, BOB); - // A floating point variable vector of size 1. - auto var = test::graph::Var(g.get(), DT_FLOAT, TensorShape({1})); - // Initialize the variable with input. - auto init = test::graph::Assign(g.get(), var, one); - // Output - auto* two = test::graph::Send(g.get(), var, "two", BOB, 1, ALICE); - g->AddControlEdge(init, two); // Ensures run after init. - Create(std::move(g)); - Rendezvous* rendez = NewLocalRendezvous(); - // Send a double instead of float. - TF_ASSERT_OK(rendez->Send(Key(ALICE, 1, BOB, "one"), Rendezvous::Args(), - VD(1.0), false)); - // Fails due to invalid dtype. - EXPECT_TRUE(errors::IsInternal(Run(rendez))); - Tensor output; - bool is_dead; - EXPECT_TRUE(errors::IsInternal(rendez->Recv( - Key(BOB, 1, ALICE, "two"), Rendezvous::Args(), &output, &is_dead))); - rendez->Unref(); -} - -TEST_F(ExecutorTest, RecvInvalidRefDtype) { - std::unique_ptr g(new Graph(OpRegistry::Global())); - // A var that always produces as invalid dtype. - auto var = test::graph::InvalidRefType(g.get(), DT_FLOAT, DT_DOUBLE); - test::graph::Send(g.get(), var, "out", BOB, 1, ALICE); - Create(std::move(g)); - Rendezvous* rendez = NewLocalRendezvous(); - EXPECT_TRUE(errors::IsInternal(Run(rendez))); - Tensor output; - bool is_dead; - EXPECT_TRUE(errors::IsInternal(rendez->Recv( - Key(BOB, 1, ALICE, "out"), Rendezvous::Args(), &output, &is_dead))); - rendez->Unref(); -} - -} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/master_test.cc b/tensorflow/core/distributed_runtime/master_test.cc index f2c1f3489c..0826a90860 100644 --- a/tensorflow/core/distributed_runtime/master_test.cc +++ b/tensorflow/core/distributed_runtime/master_test.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "grpc++/grpc++.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/framework/allocator.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/master.pb.h" -#include "tensorflow/core/protobuf/master_service.grpc.pb.h" namespace tensorflow { diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 40028ee241..4b2747f26d 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -314,18 +314,6 @@ tf_cc_binary( ], ) -tf_cuda_library( - name = "grpc_testlib_ops", - testonly = 1, - srcs = ["grpc_testlib_ops.cc"], - linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - ], - alwayslink = 1, -) - tf_cc_binary( name = "grpc_testlib_server", testonly = 1, @@ -334,11 +322,11 @@ tf_cc_binary( ], deps = [ ":grpc_server_lib", - ":grpc_testlib_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:testlib", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:cwise_op", @@ -362,12 +350,12 @@ tf_cuda_library( visibility = ["//tensorflow:__subpackages__"], deps = [ ":grpc_session", - ":grpc_testlib_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", + "//tensorflow/core:testlib", ], alwayslink = 1, ) diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc deleted file mode 100644 index 5597ee7a76..0000000000 --- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { -namespace test { - -// ErrorOp::Compute returns an error. 
-REGISTER_OP("Error").Input("in: T").Output("out: T").Attr("T: type").Attr( - "message: string"); -class ErrorOp : public OpKernel { - public: - explicit ErrorOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("message", &errmsg_)); - } - - void Compute(OpKernelContext* ctx) override { - ctx->SetStatus(errors::Internal(errmsg_)); - } - - private: - string errmsg_; -}; -REGISTER_KERNEL_BUILDER(Name("Error").Device(DEVICE_CPU), ErrorOp); - -REGISTER_OP("InvalidRefType") - .Output("out: Ref(TIn)") - .Attr("TIn: type") - .Attr("TOut: type"); -class InvalidRefType : public OpKernel { - public: - explicit InvalidRefType(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("TOut", &dtout_)); - output_ = Tensor(dtout_, TensorShape({})); - } - - void Compute(OpKernelContext* ctx) override { - ctx->set_output_ref(0, &mu_, &output_); - } - - private: - DataType dtout_; - mutex mu_; - Tensor output_; -}; -REGISTER_KERNEL_BUILDER(Name("InvalidRefType").Device(DEVICE_CPU), - InvalidRefType); - -// DelayOp::AsyncCompute sleeps for "micros"-econd and then returns -// its input. -REGISTER_OP("Delay").Input("in: T").Output("out: T").Attr("T: type").Attr( - "micros: int"); -class DelayOp : public AsyncOpKernel { - public: - explicit DelayOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("micros", µs_)); - } - - void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - ctx->set_output(0, ctx->input(0)); - ctx->env()->SchedClosureAfter(micros_, done); - } - - private: - int64 micros_; -}; -REGISTER_KERNEL_BUILDER(Name("Delay").Device(DEVICE_CPU), DelayOp); - -} // namespace test -} // namespace tensorflow -- cgit v1.2.3