diff options
author | Geoffrey Irving <geoffreyi@google.com> | 2017-05-17 09:23:54 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-17 09:27:36 -0700 |
commit | 73882f257ffb1bc9e1a828571c085d080b1d9266 (patch) | |
tree | 8adcefa226f95d6c6ce067ee45528d76794e55fb /tensorflow/core/common_runtime/function_test.cc | |
parent | 9a47c258c9c2286ae2c14a0da6458055f3b691d3 (diff) |
Automated g4 rollback of changelist 156251356
PiperOrigin-RevId: 156315860
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r-- | tensorflow/core/common_runtime/function_test.cc | 40 |
1 files changed, 28 insertions, 12 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index dfa1ed8a7e..e27fc3898d 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { +namespace { typedef FunctionDefHelper FDH; @@ -58,13 +59,29 @@ void HasError(const Status& s, const string& substr) { << s << ", expected substring " << substr; } +// A helper class to make AttrSlice from initializer lists +class Attrs { + public: + Attrs(const std::initializer_list< // NOLINT(runtime/explicit) + std::pair<string, FunctionDefHelper::AttrValueWrapper>>& attrs) { + for (const auto& aval : attrs) { + map_.insert({aval.first, aval.second.proto}); + } + } + + operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) + + private: + AttrValueMap map_; +}; + class FunctionTest : public ::testing::Test { protected: FunctionTest() : device_(DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0")) {} - void Create(const FunctionDef& fdef, InstantiateAttrValueSlice attrs) { + void Create(const FunctionDef& fdef, Attrs attrs) { exec_ = nullptr; InstantiationResult result; TF_CHECK_OK(InstantiateFunction(fdef, attrs, GetOpSig, &result)); @@ -151,8 +168,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { fdef_lib_ = lib_def_->ToProto(); } - Status Run(const string& name, InstantiateAttrValueSlice attrs, - const std::vector<Tensor>& args, std::vector<Tensor*> rets) { + Status Run(const string& name, Attrs attrs, const std::vector<Tensor>& args, + std::vector<Tensor*> rets) { FunctionLibraryRuntime::Handle handle; Status status = lib_->Instantiate(name, attrs, &handle); if (!status.ok()) { @@ -188,8 +205,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return Status::OK(); } - std::unique_ptr<Graph> GetFuncBody(const string& name, - InstantiateAttrValueSlice attrs) { + std::unique_ptr<Graph> GetFuncBody(const string& name, Attrs attrs) { FunctionLibraryRuntime::Handle handle; Status status = lib_->Instantiate(name, attrs, &handle); if (!status.ok()) { @@ -203,8 +219,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return ret; } - std::unique_ptr<Graph> GetGradBody(const string& func, - InstantiateAttrValueSlice attrs) { + std::unique_ptr<Graph> GetGradBody(const string& func, Attrs attrs) { FunctionLibraryRuntime::Handle handle; Status status = lib_->Instantiate(func, attrs, &handle); if (!status.ok()) { @@ -615,13 +630,14 @@ TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) { // Instantiating "XTimesTwo" should fail. FunctionLibraryRuntime::Handle handle; - HasError(lib_->Instantiate("XTimesTwo", {{"T", DT_FLOAT}}, &handle), + HasError(lib_->Instantiate("XTimesTwo", Attrs({{"T", DT_FLOAT}}), &handle), "Not found: type attr not found"); // But XTimesFour and XTimes16 instantiation should succeed. Only // when they run, they fail because XTimesTwo is bad. - TF_CHECK_OK(lib_->Instantiate("XTimesFour", {{"T", DT_FLOAT}}, &handle)); - TF_CHECK_OK(lib_->Instantiate("XTimes16", {{"T", DT_FLOAT}}, &handle)); + TF_CHECK_OK( + lib_->Instantiate("XTimesFour", Attrs({{"T", DT_FLOAT}}), &handle)); + TF_CHECK_OK(lib_->Instantiate("XTimes16", Attrs({{"T", DT_FLOAT}}), &handle)); auto x = test::AsTensor<float>({1, 2, 3, 4}); Tensor y; @@ -928,8 +944,7 @@ bool DoNothing(Graph* g) { return false; } GraphDef Optimize(const std::function<bool(Graph* g)>& pass, const FunctionDef& fdef) { InstantiationResult result; - InstantiateAttrValueMap empty; - TF_CHECK_OK(InstantiateFunction(fdef, empty, GetOpSig, &result)); + TF_CHECK_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); std::unique_ptr<Graph> g(new Graph(OpRegistry::Global())); GraphConstructorOptions opts; opts.allow_internal_ops = true; @@ -1248,4 +1263,5 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) { TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func)); } +} // end namespace } // end namespace tensorflow |