aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-08-15 16:46:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-15 16:50:32 -0700
commit477d49c9eaafc5e1e1667d454ce5883956180713 (patch)
tree02ef20e777b6d686a6fa97327f761b40ff91c738 /tensorflow/tools
parent9ba0abc2f06fe2d09c96487d5170b2faec79d2a4 (diff)
C++ API: run shape inference as nodes are constructed
Here's an example of the new generated code: AddN::AddN(const ::tensorflow::Scope& scope, ::tensorflow::InputList inputs) { if (!scope.ok()) return; auto _inputs = ::tensorflow::ops::AsNodeOutList(scope, inputs); if (!scope.ok()) return; ::tensorflow::Node* ret; const auto unique_name = scope.GetUniqueNameForOp("AddN"); auto builder = ::tensorflow::NodeBuilder(unique_name, "AddN") .Input(_inputs) ; scope.UpdateBuilder(&builder); scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); if (!scope.ok()) return; scope.UpdateStatus(scope.DoShapeInference(ret)); this->sum = Output(ret, 0); } Enabling shape inference unfortunately broke many tests. I fixed some of them, but for others I introduced a Scope::DisabledShapeInferenceScope() static method that returns a scope that doesn't perform shape inference. Eventually we should fix the tests that use this and remove it. PiperOrigin-RevId: 165378429
Diffstat (limited to 'tensorflow/tools')
-rw-r--r--tensorflow/tools/graph_transforms/fake_quantize_training_test.cc2
-rw-r--r--tensorflow/tools/graph_transforms/quantize_weights_test.cc2
-rw-r--r--tensorflow/tools/graph_transforms/transform_utils_test.cc2
3 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc
index 3ea7f512c6..5e4ab209e9 100644
--- a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc
+++ b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc
@@ -36,7 +36,7 @@ class FakeQuantizeTrainingTest : public ::testing::Test {};
// TODO(suharshs): Once we implement the fake_quantize_training transform
// using the GTT, write proper tests of the transform here.
TEST_F(FakeQuantizeTrainingTest, TransformOccurred) {
- auto root = tensorflow::Scope::NewRootScope();
+ auto root = tensorflow::Scope::DisabledShapeInferenceScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Tensor a_data(DT_FLOAT, TensorShape());
diff --git a/tensorflow/tools/graph_transforms/quantize_weights_test.cc b/tensorflow/tools/graph_transforms/quantize_weights_test.cc
index e1828831db..a58ce73453 100644
--- a/tensorflow/tools/graph_transforms/quantize_weights_test.cc
+++ b/tensorflow/tools/graph_transforms/quantize_weights_test.cc
@@ -40,7 +40,7 @@ class QuantizeWeightsTest : public ::testing::Test {
const TensorShape& weight_shape,
std::initializer_list<float> weight_values,
GraphDef* original_graph_def) {
- auto root = tensorflow::Scope::NewRootScope();
+ auto root = tensorflow::Scope::DisabledShapeInferenceScope();
Tensor input_data(DT_FLOAT, input_shape);
test::FillValues<float>(&input_data, input_values);
diff --git a/tensorflow/tools/graph_transforms/transform_utils_test.cc b/tensorflow/tools/graph_transforms/transform_utils_test.cc
index b5bc2d75fd..abb25916c5 100644
--- a/tensorflow/tools/graph_transforms/transform_utils_test.cc
+++ b/tensorflow/tools/graph_transforms/transform_utils_test.cc
@@ -622,7 +622,7 @@ class TransformUtilsTest : public ::testing::Test {
}
void TestRenameNodeInputsWithWildcard() {
- auto root = tensorflow::Scope::NewRootScope();
+ auto root = tensorflow::Scope::DisabledShapeInferenceScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 10;