diff options
author | 2017-08-15 16:46:16 -0700 | |
---|---|---|
committer | 2017-08-15 16:50:32 -0700 | |
commit | 477d49c9eaafc5e1e1667d454ce5883956180713 (patch) | |
tree | 02ef20e777b6d686a6fa97327f761b40ff91c738 /tensorflow/tools | |
parent | 9ba0abc2f06fe2d09c96487d5170b2faec79d2a4 (diff) |
C++ API: run shape inference as nodes are constructed
Here's an example of the new generated code:
AddN::AddN(const ::tensorflow::Scope& scope, ::tensorflow::InputList inputs) {
if (!scope.ok()) return;
auto _inputs = ::tensorflow::ops::AsNodeOutList(scope, inputs);
if (!scope.ok()) return;
::tensorflow::Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("AddN");
auto builder = ::tensorflow::NodeBuilder(unique_name, "AddN")
.Input(_inputs)
;
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return;
scope.UpdateStatus(scope.DoShapeInference(ret));
this->sum = Output(ret, 0);
}
Enabling shape inference unfortunately broke many tests. I fixed some of them, but for others I introduced a Scope::DisabledShapeInferenceScope() static method that returns a scope that doesn't perform shape inference. Eventually we should fix the tests that use this and remove it.
PiperOrigin-RevId: 165378429
Diffstat (limited to 'tensorflow/tools')
3 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc index 3ea7f512c6..5e4ab209e9 100644 --- a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc +++ b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc @@ -36,7 +36,7 @@ class FakeQuantizeTrainingTest : public ::testing::Test {}; // TODO(suharshs): Once we implement the fake_quantize_training transform // using the GTT, write proper tests of the transform here. TEST_F(FakeQuantizeTrainingTest, TransformOccurred) { - auto root = tensorflow::Scope::NewRootScope(); + auto root = tensorflow::Scope::DisabledShapeInferenceScope(); using namespace ::tensorflow::ops; // NOLINT(build/namespaces) Tensor a_data(DT_FLOAT, TensorShape()); diff --git a/tensorflow/tools/graph_transforms/quantize_weights_test.cc b/tensorflow/tools/graph_transforms/quantize_weights_test.cc index e1828831db..a58ce73453 100644 --- a/tensorflow/tools/graph_transforms/quantize_weights_test.cc +++ b/tensorflow/tools/graph_transforms/quantize_weights_test.cc @@ -40,7 +40,7 @@ class QuantizeWeightsTest : public ::testing::Test { const TensorShape& weight_shape, std::initializer_list<float> weight_values, GraphDef* original_graph_def) { - auto root = tensorflow::Scope::NewRootScope(); + auto root = tensorflow::Scope::DisabledShapeInferenceScope(); Tensor input_data(DT_FLOAT, input_shape); test::FillValues<float>(&input_data, input_values); diff --git a/tensorflow/tools/graph_transforms/transform_utils_test.cc b/tensorflow/tools/graph_transforms/transform_utils_test.cc index b5bc2d75fd..abb25916c5 100644 --- a/tensorflow/tools/graph_transforms/transform_utils_test.cc +++ b/tensorflow/tools/graph_transforms/transform_utils_test.cc @@ -622,7 +622,7 @@ class TransformUtilsTest : public ::testing::Test { } void TestRenameNodeInputsWithWildcard() { - auto root = tensorflow::Scope::NewRootScope(); + auto root = tensorflow::Scope::DisabledShapeInferenceScope(); using namespace ::tensorflow::ops; // NOLINT(build/namespaces) const int width = 10; |