diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-08-15 16:46:16 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-15 16:50:32 -0700 |
commit | 477d49c9eaafc5e1e1667d454ce5883956180713 (patch) | |
tree | 02ef20e777b6d686a6fa97327f761b40ff91c738 /tensorflow/core/common_runtime/function_test.cc | |
parent | 9ba0abc2f06fe2d09c96487d5170b2faec79d2a4 (diff) |
C++ API: run shape inference as nodes are constructed
Here's an example of the new generated code:
AddN::AddN(const ::tensorflow::Scope& scope, ::tensorflow::InputList inputs) {
if (!scope.ok()) return;
auto _inputs = ::tensorflow::ops::AsNodeOutList(scope, inputs);
if (!scope.ok()) return;
::tensorflow::Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("AddN");
auto builder = ::tensorflow::NodeBuilder(unique_name, "AddN")
.Input(_inputs)
;
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return;
scope.UpdateStatus(scope.DoShapeInference(ret));
this->sum = Output(ret, 0);
}
Enabling shape inference unfortunately broke many tests. I fixed some of them, but for others I introduced a Scope::DisabledShapeInferenceScope() static method that returns a scope that doesn't perform shape inference. Eventually we should fix the tests that use this and remove it.
PiperOrigin-RevId: 165378429
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r-- | tensorflow/core/common_runtime/function_test.cc | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index b1d32a984a..3ca4457b00 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -284,6 +284,7 @@ Output Call(Scope* scope, const string& op_name, const string& fn_name, Status status; Node* n = scope->graph()->AddNode(def, &status); TF_CHECK_OK(status); + TF_CHECK_OK(scope->DoShapeInference(n)); for (int i = 0; i < inputs.size(); ++i) { scope->graph()->AddEdge(inputs[i].node(), inputs[i].index(), n, i); } @@ -989,7 +990,7 @@ TEST(OptimizationTest, RemoveDeadNodes) { GraphDef expected; { - Scope s = Scope::NewRootScope(); + Scope s = Scope::DisabledShapeInferenceScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0); auto o = ops::Const(s.WithOpName("o"), 1); auto keep_me = ops::RandomUniform(s.WithOpName("keep_me"), {o}, DT_FLOAT); @@ -1070,7 +1071,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) { {{"y"}, "Add", {"a", "o"}, {{"T", T}}}}); { - Scope s = Scope::NewRootScope(); + Scope s = Scope::DisabledShapeInferenceScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0); auto o = ops::Const(s.WithOpName("o"), 1); auto a = ops::Square(s.WithOpName("a"), x); @@ -1087,7 +1088,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) { } { - Scope s = Scope::NewRootScope(); + Scope s = Scope::DisabledShapeInferenceScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0); auto o = ops::Const(s.WithOpName("o"), 1); auto a = ops::Square(s.WithOpName("a"), x); @@ -1137,7 +1138,7 @@ TEST(OptimizationTest, RemoveListArrayConverter) { {{"o", "o:sum"}}); { - Scope scope = Scope::NewRootScope(); + Scope scope = Scope::DisabledShapeInferenceScope(); auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0); auto zero = ops::Const(scope.WithOpName("zero"), 0); auto s = ops::Split(scope.WithOpName("s"), zero, i, 4); @@ -1222,7 +1223,7 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) { {{"o", "o:sum"}}); { - Scope s = Scope::NewRootScope(); + Scope s = Scope::DisabledShapeInferenceScope(); auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0); auto dummy = ops::Const(s.WithOpName("dummy"), 0); auto x = ops::_ListToArray(s.WithOpName("x").WithControlDependencies(dummy), |