diff options
author | Geoffrey Irving <geoffreyi@google.com> | 2017-05-16 17:01:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-16 17:05:15 -0700 |
commit | 43db5c623f748b6f9704e9e9be5a5a11fa2a4c1a (patch) | |
tree | 985844ec8f6653f36e38592f9700dcaba66d94f2 /tensorflow/core/common_runtime/function_test.cc | |
parent | 7ab0c2eff12ea79648f6717dae8558d6669e5c27 (diff) |
Automated g4 rollback of changelist 156244933
PiperOrigin-RevId: 156251356
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r-- | tensorflow/core/common_runtime/function_test.cc | 40 |
1 files changed, 12 insertions, 28 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index e27fc3898d..dfa1ed8a7e 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -40,7 +40,6 @@ limitations under the License. #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { -namespace { typedef FunctionDefHelper FDH; @@ -59,29 +58,13 @@ void HasError(const Status& s, const string& substr) { << s << ", expected substring " << substr; } -// A helper class to make AttrSlice from initializer lists -class Attrs { - public: - Attrs(const std::initializer_list< // NOLINT(runtime/explicit) - std::pair<string, FunctionDefHelper::AttrValueWrapper>>& attrs) { - for (const auto& aval : attrs) { - map_.insert({aval.first, aval.second.proto}); - } - } - - operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) - - private: - AttrValueMap map_; -}; - class FunctionTest : public ::testing::Test { protected: FunctionTest() : device_(DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0")) {} - void Create(const FunctionDef& fdef, Attrs attrs) { + void Create(const FunctionDef& fdef, InstantiateAttrValueSlice attrs) { exec_ = nullptr; InstantiationResult result; TF_CHECK_OK(InstantiateFunction(fdef, attrs, GetOpSig, &result)); @@ -168,8 +151,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { fdef_lib_ = lib_def_->ToProto(); } - Status Run(const string& name, Attrs attrs, const std::vector<Tensor>& args, - std::vector<Tensor*> rets) { + Status Run(const string& name, InstantiateAttrValueSlice attrs, + const std::vector<Tensor>& args, std::vector<Tensor*> rets) { FunctionLibraryRuntime::Handle handle; Status status = lib_->Instantiate(name, attrs, &handle); if (!status.ok()) { @@ -205,7 +188,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return Status::OK(); } - std::unique_ptr<Graph> GetFuncBody(const string& name, Attrs attrs) { + std::unique_ptr<Graph> GetFuncBody(const string& name, + InstantiateAttrValueSlice attrs) { FunctionLibraryRuntime::Handle handle; Status status = lib_->Instantiate(name, attrs, &handle); if (!status.ok()) { @@ -219,7 +203,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return ret; } - std::unique_ptr<Graph> GetGradBody(const string& func, Attrs attrs) { + std::unique_ptr<Graph> GetGradBody(const string& func, + InstantiateAttrValueSlice attrs) { FunctionLibraryRuntime::Handle handle; Status status = lib_->Instantiate(func, attrs, &handle); if (!status.ok()) { @@ -630,14 +615,13 @@ TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) { // Instantiating "XTimesTwo" should fail. FunctionLibraryRuntime::Handle handle; - HasError(lib_->Instantiate("XTimesTwo", Attrs({{"T", DT_FLOAT}}), &handle), + HasError(lib_->Instantiate("XTimesTwo", {{"T", DT_FLOAT}}, &handle), "Not found: type attr not found"); // But XTimesFour and XTimes16 instantiation should succeed. Only // when they run, they fail because XTimesTwo is bad. - TF_CHECK_OK( - lib_->Instantiate("XTimesFour", Attrs({{"T", DT_FLOAT}}), &handle)); - TF_CHECK_OK(lib_->Instantiate("XTimes16", Attrs({{"T", DT_FLOAT}}), &handle)); + TF_CHECK_OK(lib_->Instantiate("XTimesFour", {{"T", DT_FLOAT}}, &handle)); + TF_CHECK_OK(lib_->Instantiate("XTimes16", {{"T", DT_FLOAT}}, &handle)); auto x = test::AsTensor<float>({1, 2, 3, 4}); Tensor y; @@ -944,7 +928,8 @@ bool DoNothing(Graph* g) { return false; } GraphDef Optimize(const std::function<bool(Graph* g)>& pass, const FunctionDef& fdef) { InstantiationResult result; - TF_CHECK_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); + InstantiateAttrValueMap empty; + TF_CHECK_OK(InstantiateFunction(fdef, empty, GetOpSig, &result)); std::unique_ptr<Graph> g(new Graph(OpRegistry::Global())); GraphConstructorOptions opts; opts.allow_internal_ops = true; @@ -1263,5 +1248,4 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) { TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func)); } -} // end namespace } // end namespace tensorflow |