diff options
Diffstat (limited to 'tensorflow/core/framework/op_segment_test.cc')
-rw-r--r-- | tensorflow/core/framework/op_segment_test.cc | 142 |
1 files changed, 142 insertions, 0 deletions
diff --git a/tensorflow/core/framework/op_segment_test.cc b/tensorflow/core/framework/op_segment_test.cc new file mode 100644 index 0000000000..6297718df8 --- /dev/null +++ b/tensorflow/core/framework/op_segment_test.cc @@ -0,0 +1,142 @@ +#include "tensorflow/core/framework/op_segment.h" + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include <gtest/gtest.h> +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { + +class OpSegmentTest : public ::testing::Test { + protected: + DeviceBase device_; + std::vector<NodeDef> int32_nodedefs_; + std::vector<NodeDef> float_nodedefs_; + + OpSegmentTest() : device_(Env::Default()) { + RequireDefaultOps(); + for (int i = 0; i < 10; ++i) { + NodeDef def; + TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul") + .Input("x", 0, DT_INT32) + .Input("y", 0, DT_INT32) + .Finalize(&def)); + int32_nodedefs_.push_back(def); + TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul") + .Input("x", 0, DT_FLOAT) + .Input("y", 0, DT_FLOAT) + .Finalize(&def)); + float_nodedefs_.push_back(def); + } + } + + void ValidateOpAndTypes(OpKernel* op, const NodeDef& expected, DataType dt) { + ASSERT_NE(op, nullptr); + EXPECT_EQ(expected.DebugString(), op->def().DebugString()); + EXPECT_EQ(2, op->num_inputs()); + EXPECT_EQ(dt, op->input_type(0)); + EXPECT_EQ(dt, op->input_type(1)); + EXPECT_EQ(1, op->num_outputs()); + EXPECT_EQ(dt, op->output_type(0)); + } + + OpSegment::CreateKernelFn GetFn(const NodeDef* ndef) { + return [this, ndef](OpKernel** kernel) { + Status s; + auto created = + CreateOpKernel(DEVICE_CPU, &device_, cpu_allocator(), *ndef, &s); + if (s.ok()) { + *kernel = created.release(); + } + return s; + }; + } +}; + +TEST_F(OpSegmentTest, Basic) { + OpSegment opseg; + OpKernel* op; + + opseg.AddHold("A"); + opseg.AddHold("B"); + for (int i = 0; i < 10; ++i) { + // Register in session A. + auto* ndef = &float_nodedefs_[i]; + EXPECT_OK(opseg.FindOrCreate("A", ndef->name(), &op, GetFn(ndef))); + ValidateOpAndTypes(op, *ndef, DT_FLOAT); + + // Register in session B. + ndef = &int32_nodedefs_[i]; + EXPECT_OK(opseg.FindOrCreate("B", ndef->name(), &op, GetFn(ndef))); + ValidateOpAndTypes(op, *ndef, DT_INT32); + } + + auto reterr = [](OpKernel** kernel) { + return errors::Internal("Should not be called"); + }; + for (int i = 0; i < 10; ++i) { + // Lookup op in session A. + EXPECT_OK(opseg.FindOrCreate("A", strings::StrCat("op", i), &op, reterr)); + ValidateOpAndTypes(op, float_nodedefs_[i], DT_FLOAT); + + // Lookup op in session B. + EXPECT_OK(opseg.FindOrCreate("B", strings::StrCat("op", i), &op, reterr)); + ValidateOpAndTypes(op, int32_nodedefs_[i], DT_INT32); + } + + opseg.RemoveHold("A"); + opseg.RemoveHold("B"); +} + +TEST_F(OpSegmentTest, SessionNotFound) { + OpSegment opseg; + OpKernel* op; + NodeDef def = float_nodedefs_[0]; + Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def)); + EXPECT_TRUE(errors::IsNotFound(s)) << s; +} + +TEST_F(OpSegmentTest, CreateFailure) { + OpSegment opseg; + OpKernel* op; + NodeDef def = float_nodedefs_[0]; + def.set_op("nonexistop"); + opseg.AddHold("A"); + Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def)); + EXPECT_TRUE(errors::IsNotFound(s)) << s; + opseg.RemoveHold("A"); +} + +TEST_F(OpSegmentTest, AddRemoveHolds) { + OpSegment opseg; + OpKernel* op; + const auto& ndef = int32_nodedefs_[0]; + + // No op. + opseg.RemoveHold("null"); + + // Thread1 register the op and wants to ensure it alive. + opseg.AddHold("foo"); + EXPECT_OK(opseg.FindOrCreate("foo", ndef.name(), &op, GetFn(&ndef))); + + // Thread2 starts some execution needs "op" to be alive. + opseg.AddHold("foo"); + + // Thread1 clears session "foo". E.g., a master sends CleanupGraph + // before an execution finishes. + opseg.RemoveHold("foo"); + + // Thread2 should still be able to access "op". + ValidateOpAndTypes(op, ndef, DT_INT32); + + // Thread2 then remove its hold on "foo". + opseg.RemoveHold("foo"); +} + +} // namespace tensorflow |