aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms/transform_utils_test.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-08-15 16:46:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-15 16:50:32 -0700
commit477d49c9eaafc5e1e1667d454ce5883956180713 (patch)
tree02ef20e777b6d686a6fa97327f761b40ff91c738 /tensorflow/tools/graph_transforms/transform_utils_test.cc
parent9ba0abc2f06fe2d09c96487d5170b2faec79d2a4 (diff)
C++ API: run shape inference as nodes are constructed
Here's an example of the new generated code: AddN::AddN(const ::tensorflow::Scope& scope, ::tensorflow::InputList inputs) { if (!scope.ok()) return; auto _inputs = ::tensorflow::ops::AsNodeOutList(scope, inputs); if (!scope.ok()) return; ::tensorflow::Node* ret; const auto unique_name = scope.GetUniqueNameForOp("AddN"); auto builder = ::tensorflow::NodeBuilder(unique_name, "AddN") .Input(_inputs) ; scope.UpdateBuilder(&builder); scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); if (!scope.ok()) return; scope.UpdateStatus(scope.DoShapeInference(ret)); this->sum = Output(ret, 0); } Enabling shape inference unfortunately broke many tests. I fixed some of them, but for others I introduced a Scope::DisabledShapeInferenceScope() static method that returns a scope that doesn't perform shape inference. Eventually we should fix the tests that use this and remove it. PiperOrigin-RevId: 165378429
Diffstat (limited to 'tensorflow/tools/graph_transforms/transform_utils_test.cc')
-rw-r--r--tensorflow/tools/graph_transforms/transform_utils_test.cc2
1 files changed, 1 insertions, 1 deletions
diff --git a/tensorflow/tools/graph_transforms/transform_utils_test.cc b/tensorflow/tools/graph_transforms/transform_utils_test.cc
index b5bc2d75fd..abb25916c5 100644
--- a/tensorflow/tools/graph_transforms/transform_utils_test.cc
+++ b/tensorflow/tools/graph_transforms/transform_utils_test.cc
@@ -622,7 +622,7 @@ class TransformUtilsTest : public ::testing::Test {
}
void TestRenameNodeInputsWithWildcard() {
- auto root = tensorflow::Scope::NewRootScope();
+ auto root = tensorflow::Scope::DisabledShapeInferenceScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 10;