about summary refs log tree commit diff homepage
path: root/tensorflow/core/common_runtime/function_test.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-08-15 16:46:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-15 16:50:32 -0700
commit477d49c9eaafc5e1e1667d454ce5883956180713 (patch)
tree02ef20e777b6d686a6fa97327f761b40ff91c738 /tensorflow/core/common_runtime/function_test.cc
parent9ba0abc2f06fe2d09c96487d5170b2faec79d2a4 (diff)
C++ API: run shape inference as nodes are constructed
Here's an example of the new generated code: AddN::AddN(const ::tensorflow::Scope& scope, ::tensorflow::InputList inputs) { if (!scope.ok()) return; auto _inputs = ::tensorflow::ops::AsNodeOutList(scope, inputs); if (!scope.ok()) return; ::tensorflow::Node* ret; const auto unique_name = scope.GetUniqueNameForOp("AddN"); auto builder = ::tensorflow::NodeBuilder(unique_name, "AddN") .Input(_inputs) ; scope.UpdateBuilder(&builder); scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); if (!scope.ok()) return; scope.UpdateStatus(scope.DoShapeInference(ret)); this->sum = Output(ret, 0); } Enabling shape inference unfortunately broke many tests. I fixed some of them, but for others I introduced a Scope::DisabledShapeInferenceScope() static method that returns a scope that doesn't perform shape inference. Eventually we should fix the tests that use this and remove it. PiperOrigin-RevId: 165378429
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r--tensorflow/core/common_runtime/function_test.cc11
1 file changed, 6 insertions, 5 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index b1d32a984a..3ca4457b00 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -284,6 +284,7 @@ Output Call(Scope* scope, const string& op_name, const string& fn_name,
Status status;
Node* n = scope->graph()->AddNode(def, &status);
TF_CHECK_OK(status);
+ TF_CHECK_OK(scope->DoShapeInference(n));
for (int i = 0; i < inputs.size(); ++i) {
scope->graph()->AddEdge(inputs[i].node(), inputs[i].index(), n, i);
}
@@ -989,7 +990,7 @@ TEST(OptimizationTest, RemoveDeadNodes) {
GraphDef expected;
{
- Scope s = Scope::NewRootScope();
+ Scope s = Scope::DisabledShapeInferenceScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
auto o = ops::Const(s.WithOpName("o"), 1);
auto keep_me = ops::RandomUniform(s.WithOpName("keep_me"), {o}, DT_FLOAT);
@@ -1070,7 +1071,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
{{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
{
- Scope s = Scope::NewRootScope();
+ Scope s = Scope::DisabledShapeInferenceScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
auto o = ops::Const(s.WithOpName("o"), 1);
auto a = ops::Square(s.WithOpName("a"), x);
@@ -1087,7 +1088,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
}
{
- Scope s = Scope::NewRootScope();
+ Scope s = Scope::DisabledShapeInferenceScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
auto o = ops::Const(s.WithOpName("o"), 1);
auto a = ops::Square(s.WithOpName("a"), x);
@@ -1137,7 +1138,7 @@ TEST(OptimizationTest, RemoveListArrayConverter) {
{{"o", "o:sum"}});
{
- Scope scope = Scope::NewRootScope();
+ Scope scope = Scope::DisabledShapeInferenceScope();
auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
auto zero = ops::Const(scope.WithOpName("zero"), 0);
auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
@@ -1222,7 +1223,7 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
{{"o", "o:sum"}});
{
- Scope s = Scope::NewRootScope();
+ Scope s = Scope::DisabledShapeInferenceScope();
auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0);
auto dummy = ops::Const(s.WithOpName("dummy"), 0);
auto x = ops::_ListToArray(s.WithOpName("x").WithControlDependencies(dummy),