aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/op_segment_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/op_segment_test.cc')
-rw-r--r--tensorflow/core/framework/op_segment_test.cc142
1 files changed, 142 insertions, 0 deletions
diff --git a/tensorflow/core/framework/op_segment_test.cc b/tensorflow/core/framework/op_segment_test.cc
new file mode 100644
index 0000000000..6297718df8
--- /dev/null
+++ b/tensorflow/core/framework/op_segment_test.cc
@@ -0,0 +1,142 @@
+#include "tensorflow/core/framework/op_segment.h"
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+
+class OpSegmentTest : public ::testing::Test {
+ protected:
+ DeviceBase device_;
+ std::vector<NodeDef> int32_nodedefs_;
+ std::vector<NodeDef> float_nodedefs_;
+
+ OpSegmentTest() : device_(Env::Default()) {
+ RequireDefaultOps();
+ for (int i = 0; i < 10; ++i) {
+ NodeDef def;
+ TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul")
+ .Input("x", 0, DT_INT32)
+ .Input("y", 0, DT_INT32)
+ .Finalize(&def));
+ int32_nodedefs_.push_back(def);
+ TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul")
+ .Input("x", 0, DT_FLOAT)
+ .Input("y", 0, DT_FLOAT)
+ .Finalize(&def));
+ float_nodedefs_.push_back(def);
+ }
+ }
+
+ void ValidateOpAndTypes(OpKernel* op, const NodeDef& expected, DataType dt) {
+ ASSERT_NE(op, nullptr);
+ EXPECT_EQ(expected.DebugString(), op->def().DebugString());
+ EXPECT_EQ(2, op->num_inputs());
+ EXPECT_EQ(dt, op->input_type(0));
+ EXPECT_EQ(dt, op->input_type(1));
+ EXPECT_EQ(1, op->num_outputs());
+ EXPECT_EQ(dt, op->output_type(0));
+ }
+
+ OpSegment::CreateKernelFn GetFn(const NodeDef* ndef) {
+ return [this, ndef](OpKernel** kernel) {
+ Status s;
+ auto created =
+ CreateOpKernel(DEVICE_CPU, &device_, cpu_allocator(), *ndef, &s);
+ if (s.ok()) {
+ *kernel = created.release();
+ }
+ return s;
+ };
+ }
+};
+
+TEST_F(OpSegmentTest, Basic) {
+ OpSegment opseg;
+ OpKernel* op;
+
+ opseg.AddHold("A");
+ opseg.AddHold("B");
+ for (int i = 0; i < 10; ++i) {
+ // Register in session A.
+ auto* ndef = &float_nodedefs_[i];
+ EXPECT_OK(opseg.FindOrCreate("A", ndef->name(), &op, GetFn(ndef)));
+ ValidateOpAndTypes(op, *ndef, DT_FLOAT);
+
+ // Register in session B.
+ ndef = &int32_nodedefs_[i];
+ EXPECT_OK(opseg.FindOrCreate("B", ndef->name(), &op, GetFn(ndef)));
+ ValidateOpAndTypes(op, *ndef, DT_INT32);
+ }
+
+ auto reterr = [](OpKernel** kernel) {
+ return errors::Internal("Should not be called");
+ };
+ for (int i = 0; i < 10; ++i) {
+ // Lookup op in session A.
+ EXPECT_OK(opseg.FindOrCreate("A", strings::StrCat("op", i), &op, reterr));
+ ValidateOpAndTypes(op, float_nodedefs_[i], DT_FLOAT);
+
+ // Lookup op in session B.
+ EXPECT_OK(opseg.FindOrCreate("B", strings::StrCat("op", i), &op, reterr));
+ ValidateOpAndTypes(op, int32_nodedefs_[i], DT_INT32);
+ }
+
+ opseg.RemoveHold("A");
+ opseg.RemoveHold("B");
+}
+
+TEST_F(OpSegmentTest, SessionNotFound) {
+ OpSegment opseg;
+ OpKernel* op;
+ NodeDef def = float_nodedefs_[0];
+ Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def));
+ EXPECT_TRUE(errors::IsNotFound(s)) << s;
+}
+
+TEST_F(OpSegmentTest, CreateFailure) {
+ OpSegment opseg;
+ OpKernel* op;
+ NodeDef def = float_nodedefs_[0];
+ def.set_op("nonexistop");
+ opseg.AddHold("A");
+ Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def));
+ EXPECT_TRUE(errors::IsNotFound(s)) << s;
+ opseg.RemoveHold("A");
+}
+
+TEST_F(OpSegmentTest, AddRemoveHolds) {
+ OpSegment opseg;
+ OpKernel* op;
+ const auto& ndef = int32_nodedefs_[0];
+
+ // No op.
+ opseg.RemoveHold("null");
+
+ // Thread1 register the op and wants to ensure it alive.
+ opseg.AddHold("foo");
+ EXPECT_OK(opseg.FindOrCreate("foo", ndef.name(), &op, GetFn(&ndef)));
+
+ // Thread2 starts some execution needs "op" to be alive.
+ opseg.AddHold("foo");
+
+ // Thread1 clears session "foo". E.g., a master sends CleanupGraph
+ // before an execution finishes.
+ opseg.RemoveHold("foo");
+
+ // Thread2 should still be able to access "op".
+ ValidateOpAndTypes(op, ndef, DT_INT32);
+
+ // Thread2 then remove its hold on "foo".
+ opseg.RemoveHold("foo");
+}
+
+} // namespace tensorflow