aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-30 10:44:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-30 10:51:22 -0700
commitcef680b5320f85d155d6e16c607021e7182c5df6 (patch)
treeb300406b2483f11ce50fe13380b5b438836ca1ae /tensorflow
parente8ac0b48f443879d9e3d516b0b3a151978128423 (diff)
Enable shape inference on functions in grappler.
PiperOrigin-RevId: 173914941
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/common_runtime/shape_refiner.h4
-rw-r--r--tensorflow/core/graph/graph_constructor.cc7
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc3
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc30
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_testdata/simple_function.pbtxt111
5 files changed, 152 insertions, 3 deletions
diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h
index d1288d671e..570b4db163 100644
--- a/tensorflow/core/common_runtime/shape_refiner.h
+++ b/tensorflow/core/common_runtime/shape_refiner.h
@@ -164,6 +164,10 @@ class ShapeRefiner {
function_library_ = lib;
}
+ bool function_shape_inference_supported() const {
+ return function_library_ != nullptr;
+ }
+
// Call this to keep nested shapes information for user-defined functions:
// nested inferences will be available on the ExtendedInferenceContext for
// each function node, forming a tree of shape inferences corresponding to the
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 9432775ff3..8fe4f535fb 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -846,9 +846,10 @@ Status GraphConstructor::Convert() {
}
}
- // TODO(skyewm): remove conditional when b/35715995 ("Functions lack shape
- // inference") is resolved.
- if (g_->flib_def().Find(node_def->name()) == nullptr) {
+ // Function shape inference is supported on an opt-in basis per
+ // ShapeRefiner.
+ if (refiner_->function_shape_inference_supported() ||
+ g_->flib_def().Find(node_def->name()) == nullptr) {
TF_RETURN_IF_ERROR(ValidateShape(node));
}
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index f62a21ace5..e9cb2ee09d 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -195,9 +195,12 @@ Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
Status GraphProperties::InferStatically() {
Graph graph(OpRegistry::Global());
+ FunctionLibraryDefinition function_library(graph.op_registry(),
+ item_.graph.library());
ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
shape_refiner.set_require_shape_inference_fns(false);
shape_refiner.set_disable_constant_propagation(true);
+ shape_refiner.set_function_library_for_shape_inference(&function_library);
ImportGraphDefOptions options;
Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);
TF_RETURN_IF_ERROR(s);
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 975ec31b14..134db5ec5a 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -703,6 +703,36 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
EXPECT_EQ("float: [128,256]", PropToString(prop));
}
+TEST_F(GraphPropertiesTest, FunctionStaticShapeInference) {
+ // Test graph produced in python using:
+ /*
+ @function.Defun(*[tf.float32] * 2, noinline=True)
+ def MyAdd(x, y):
+ return tf.add(x,y)
+
+ with tf.Graph().as_default():
+ x = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ y = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ z = MyAdd(x, y)
+ z = MyAdd(x, z)
+ */
+ // Check that the shape of the second MyAdd node propagates
+ // correctly.
+ GrapplerItem item;
+ string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
+ "simple_function.pbtxt");
+ TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically());
+ const auto props = properties.GetOutputProperties("MyAdd_55e046a8_1");
+ const OpInfo::TensorProperties& prop = props[0];
+ EXPECT_EQ(DT_FLOAT, prop.dtype());
+ EXPECT_FALSE(prop.shape().unknown_rank());
+ EXPECT_EQ(2, prop.shape().dim_size());
+ EXPECT_EQ(1, prop.shape().dim(0).size());
+ EXPECT_EQ(2, prop.shape().dim(1).size());
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/simple_function.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/simple_function.pbtxt
new file mode 100644
index 0000000000..86b67f2049
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/simple_function.pbtxt
@@ -0,0 +1,111 @@
+node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "Const_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "MyAdd_55e046a8"
+ op: "MyAdd_55e046a8"
+ input: "Const"
+ input: "Const_1"
+}
+node {
+ name: "MyAdd_55e046a8_1"
+ op: "MyAdd_55e046a8"
+ input: "Const"
+ input: "MyAdd_55e046a8"
+}
+library {
+ function {
+ signature {
+ name: "MyAdd_55e046a8"
+ input_arg {
+ name: "x"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "y"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "Add"
+ type: DT_FLOAT
+ }
+ }
+ node_def {
+ name: "Add"
+ op: "Add"
+ input: "x"
+ input: "y"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ ret {
+ key: "Add"
+ value: "Add:z:0"
+ }
+ attr {
+ key: "_noinline"
+ value {
+ b: true
+ }
+ }
+ }
+}
+versions {
+ producer: 24
+ min_consumer: 12
+}