diff options
author | 2017-08-15 16:46:16 -0700 | |
---|---|---|
committer | 2017-08-15 16:50:32 -0700 | |
commit | 477d49c9eaafc5e1e1667d454ce5883956180713 (patch) | |
tree | 02ef20e777b6d686a6fa97327f761b40ff91c738 /tensorflow/tools/graph_transforms/transform_utils_test.cc | |
parent | 9ba0abc2f06fe2d09c96487d5170b2faec79d2a4 (diff) |
C++ API: run shape inference as nodes are constructed
Here's an example of the new generated code:
AddN::AddN(const ::tensorflow::Scope& scope, ::tensorflow::InputList inputs) {
if (!scope.ok()) return;
auto _inputs = ::tensorflow::ops::AsNodeOutList(scope, inputs);
if (!scope.ok()) return;
::tensorflow::Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("AddN");
auto builder = ::tensorflow::NodeBuilder(unique_name, "AddN")
.Input(_inputs)
;
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return;
scope.UpdateStatus(scope.DoShapeInference(ret));
this->sum = Output(ret, 0);
}
Enabling shape inference unfortunately broke many tests. I fixed some of them, but for others I introduced a Scope::DisabledShapeInferenceScope() static method that returns a scope that doesn't perform shape inference. Eventually we should fix the tests that use this and remove it.
PiperOrigin-RevId: 165378429
Diffstat (limited to 'tensorflow/tools/graph_transforms/transform_utils_test.cc')
-rw-r--r-- | tensorflow/tools/graph_transforms/transform_utils_test.cc | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/tensorflow/tools/graph_transforms/transform_utils_test.cc b/tensorflow/tools/graph_transforms/transform_utils_test.cc index b5bc2d75fd..abb25916c5 100644 --- a/tensorflow/tools/graph_transforms/transform_utils_test.cc +++ b/tensorflow/tools/graph_transforms/transform_utils_test.cc @@ -622,7 +622,7 @@ class TransformUtilsTest : public ::testing::Test { } void TestRenameNodeInputsWithWildcard() { - auto root = tensorflow::Scope::NewRootScope(); + auto root = tensorflow::Scope::DisabledShapeInferenceScope(); using namespace ::tensorflow::ops; // NOLINT(build/namespaces) const int width = 10; |