aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/resource_mgr_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/resource_mgr_test.cc')
-rw-r--r--tensorflow/core/framework/resource_mgr_test.cc173
1 files changed, 173 insertions, 0 deletions
diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc
new file mode 100644
index 0000000000..9f7ce3dde3
--- /dev/null
+++ b/tensorflow/core/framework/resource_mgr_test.cc
@@ -0,0 +1,173 @@
+#include "tensorflow/core/framework/resource_mgr.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+
+class Resource : public ResourceBase {
+ public:
+ explicit Resource(const string& label) : label_(label) {}
+ ~Resource() override {}
+
+ string DebugString() { return strings::StrCat("R/", label_); }
+
+ private:
+ string label_;
+};
+
+class Other : public ResourceBase {
+ public:
+ explicit Other(const string& label) : label_(label) {}
+ ~Other() override {}
+
+ string DebugString() { return strings::StrCat("O/", label_); }
+
+ private:
+ string label_;
+};
+
+template <typename T>
+string Find(const ResourceMgr& rm, const string& container,
+ const string& name) {
+ T* r;
+ TF_CHECK_OK(rm.Lookup(container, name, &r));
+ const string ret = r->DebugString();
+ r->Unref();
+ return ret;
+}
+
+template <typename T>
+string LookupOrCreate(ResourceMgr* rm, const string& container,
+ const string& name, const string& label) {
+ T* r;
+ TF_CHECK_OK(rm->LookupOrCreate<T>(container, name, &r, [&label](T** ret) {
+ *ret = new T(label);
+ return Status::OK();
+ }));
+ const string ret = r->DebugString();
+ r->Unref();
+ return ret;
+}
+
+static void HasError(const Status& s, const string& substr) {
+ EXPECT_TRUE(StringPiece(s.ToString()).contains(substr))
+ << s << ", expected substring " << substr;
+}
+
+template <typename T>
+Status FindErr(const ResourceMgr& rm, const string& container,
+ const string& name) {
+ T* r;
+ Status s = rm.Lookup(container, name, &r);
+ CHECK(!s.ok());
+ return s;
+}
+
+TEST(ResourceMgrTest, Basic) {
+ ResourceMgr rm;
+ TF_CHECK_OK(rm.Create("foo", "bar", new Resource("cat")));
+ TF_CHECK_OK(rm.Create("foo", "baz", new Resource("dog")));
+ TF_CHECK_OK(rm.Create("foo", "bar", new Other("tiger")));
+
+ // Expected to fail.
+ HasError(rm.Create("foo", "bar", new Resource("kitty")),
+ "Already exists: Resource foo/bar");
+
+ // Expected to be found.
+ EXPECT_EQ("R/cat", Find<Resource>(rm, "foo", "bar"));
+ EXPECT_EQ("R/dog", Find<Resource>(rm, "foo", "baz"));
+ EXPECT_EQ("O/tiger", Find<Other>(rm, "foo", "bar"));
+
+ // Expected to be not found.
+ HasError(FindErr<Resource>(rm, "bar", "foo"), "Not found: Container bar");
+ HasError(FindErr<Resource>(rm, "foo", "xxx"), "Not found: Resource foo/xxx");
+ HasError(FindErr<Other>(rm, "foo", "baz"), "Not found: Resource foo/baz");
+
+ // Delete foo/bar/Resource.
+ TF_CHECK_OK(rm.Delete<Resource>("foo", "bar"));
+ HasError(FindErr<Resource>(rm, "foo", "bar"), "Not found: Resource foo/bar");
+
+ TF_CHECK_OK(rm.Create("foo", "bar", new Resource("kitty")));
+ EXPECT_EQ("R/kitty", Find<Resource>(rm, "foo", "bar"));
+
+ // Drop the whole container foo.
+ TF_CHECK_OK(rm.Cleanup("foo"));
+ HasError(FindErr<Resource>(rm, "foo", "bar"), "Not found: Container foo");
+}
+
+TEST(ResourceMgr, CreateOrLookup) {
+ ResourceMgr rm;
+ EXPECT_EQ("R/cat", LookupOrCreate<Resource>(&rm, "foo", "bar", "cat"));
+ EXPECT_EQ("R/cat", LookupOrCreate<Resource>(&rm, "foo", "bar", "dog"));
+ EXPECT_EQ("R/cat", Find<Resource>(rm, "foo", "bar"));
+
+ EXPECT_EQ("O/tiger", LookupOrCreate<Other>(&rm, "foo", "bar", "tiger"));
+ EXPECT_EQ("O/tiger", LookupOrCreate<Other>(&rm, "foo", "bar", "lion"));
+ TF_CHECK_OK(rm.Delete<Other>("foo", "bar"));
+ HasError(FindErr<Other>(rm, "foo", "bar"), "Not found: Resource foo/bar");
+}
+
+Status ComputePolicy(const string& attr_container,
+ const string& attr_shared_name,
+ bool use_node_name_as_default, string* result) {
+ ContainerInfo cinfo;
+ ResourceMgr rmgr;
+ NodeDef ndef;
+ ndef.set_name("foo");
+ if (attr_container != "none") {
+ AddNodeAttr("container", attr_container, &ndef);
+ }
+ if (attr_shared_name != "none") {
+ AddNodeAttr("shared_name", attr_shared_name, &ndef);
+ }
+ TF_RETURN_IF_ERROR(cinfo.Init(&rmgr, ndef, use_node_name_as_default));
+ *result = cinfo.DebugString();
+ return Status::OK();
+}
+
+string Policy(const string& attr_container, const string& attr_shared_name,
+ bool use_node_name_as_default) {
+ string ret;
+ TF_CHECK_OK(ComputePolicy(attr_container, attr_shared_name,
+ use_node_name_as_default, &ret));
+ return ret;
+}
+
+TEST(ContainerInfo, Basic) {
+ // Correct cases.
+ EXPECT_EQ(Policy("", "", false), "[localhost,_0_foo,private]");
+ EXPECT_EQ(Policy("", "", true), "[localhost,foo,public]");
+ EXPECT_EQ(Policy("", "bar", false), "[localhost,bar,public]");
+ EXPECT_EQ(Policy("", "bar", true), "[localhost,bar,public]");
+ EXPECT_EQ(Policy("cat", "", false), "[cat,_1_foo,private]");
+ EXPECT_EQ(Policy("cat", "", true), "[cat,foo,public]");
+ EXPECT_EQ(Policy("cat", "bar", false), "[cat,bar,public]");
+ EXPECT_EQ(Policy("cat", "bar", true), "[cat,bar,public]");
+}
+
+Status WrongPolicy(const string& attr_container, const string& attr_shared_name,
+ bool use_node_name_as_default) {
+ string dbg;
+ auto s = ComputePolicy(attr_container, attr_shared_name,
+ use_node_name_as_default, &dbg);
+ CHECK(!s.ok());
+ return s;
+}
+
+TEST(ContainerInfo, Error) {
+ // Missing attribute.
+ HasError(WrongPolicy("none", "", false), "No attr");
+ HasError(WrongPolicy("", "none", false), "No attr");
+ HasError(WrongPolicy("none", "none", false), "No attr");
+
+ // Invalid container.
+ HasError(WrongPolicy("12$%", "", false), "container contains invalid char");
+
+ // Invalid shared name.
+ HasError(WrongPolicy("", "_foo", false), "shared_name cannot start with '_'");
+}
+
+} // end namespace tensorflow