diff options
author | 2017-10-30 10:44:36 -0700 | |
---|---|---|
committer | 2017-10-30 10:51:22 -0700 | |
commit | cef680b5320f85d155d6e16c607021e7182c5df6 (patch) | |
tree | b300406b2483f11ce50fe13380b5b438836ca1ae /tensorflow | |
parent | e8ac0b48f443879d9e3d516b0b3a151978128423 (diff) |
Enable shape inference on functions in grappler.
PiperOrigin-RevId: 173914941
Diffstat (limited to 'tensorflow')
5 files changed, 152 insertions, 3 deletions
diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index d1288d671e..570b4db163 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -164,6 +164,10 @@ class ShapeRefiner { function_library_ = lib; } + bool function_shape_inference_supported() const { + return function_library_ != nullptr; + } + // Call this to keep nested shapes information for user-defined functions: // nested inferences will be available on the ExtendedInferenceContext for // each function node, forming a tree of shape inferences corresponding to the diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 9432775ff3..8fe4f535fb 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -846,9 +846,10 @@ Status GraphConstructor::Convert() { } } - // TODO(skyewm): remove conditional when b/35715995 ("Functions lack shape - // inference") is resolved. - if (g_->flib_def().Find(node_def->name()) == nullptr) { + // Function shape inference is supported on an opt-in basis per + // ShapeRefiner. + if (refiner_->function_shape_inference_supported() || + g_->flib_def().Find(node_def->name()) == nullptr) { TF_RETURN_IF_ERROR(ValidateShape(node)); } diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index f62a21ace5..e9cb2ee09d 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -195,9 +195,12 @@ Status GraphProperties::RelaxEnqueueShapesAndMergeTypes( Status GraphProperties::InferStatically() { Graph graph(OpRegistry::Global()); + FunctionLibraryDefinition function_library(graph.op_registry(), + item_.graph.library()); ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); shape_refiner.set_require_shape_inference_fns(false); shape_refiner.set_disable_constant_propagation(true); + shape_refiner.set_function_library_for_shape_inference(&function_library); ImportGraphDefOptions options; Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner); TF_RETURN_IF_ERROR(s); diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 975ec31b14..134db5ec5a 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -703,6 +703,36 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) { EXPECT_EQ("float: [128,256]", PropToString(prop)); } +TEST_F(GraphPropertiesTest, FunctionStaticShapeInference) { + // Test graph produced in python using: + /* + @function.Defun(*[tf.float32] * 2, noinline=True) + def MyAdd(x, y): + return tf.add(x,y) + + with tf.Graph().as_default(): + x = tf.constant(2.0, shape=[1, 2], dtype=tf.float32) + y = tf.constant(2.0, shape=[1, 2], dtype=tf.float32) + z = MyAdd(x, y) + z = MyAdd(x, z) + */ + // Check that the shape of the second MyAdd node propagates + // correctly. + GrapplerItem item; + string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath, + "simple_function.pbtxt"); + TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically()); + const auto props = properties.GetOutputProperties("MyAdd_55e046a8_1"); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT, prop.dtype()); + EXPECT_FALSE(prop.shape().unknown_rank()); + EXPECT_EQ(2, prop.shape().dim_size()); + EXPECT_EQ(1, prop.shape().dim(0).size()); + EXPECT_EQ(2, prop.shape().dim(1).size()); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/simple_function.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/simple_function.pbtxt new file mode 100644 index 0000000000..86b67f2049 --- /dev/null +++ b/tensorflow/core/grappler/costs/graph_properties_testdata/simple_function.pbtxt @@ -0,0 +1,111 @@ +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 2 + } + } + float_val: 2.0 + } + } + } +} +node { + name: "Const_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 2 + } + } + float_val: 2.0 + } + } + } +} +node { + name: "MyAdd_55e046a8" + op: "MyAdd_55e046a8" + input: "Const" + input: "Const_1" +} +node { + name: "MyAdd_55e046a8_1" + op: "MyAdd_55e046a8" + input: "Const" + input: "MyAdd_55e046a8" +} +library { + function { + signature { + name: "MyAdd_55e046a8" + input_arg { + name: "x" + type: DT_FLOAT + } + input_arg { + name: "y" + type: DT_FLOAT + } + output_arg { + name: "Add" + type: DT_FLOAT + } + } + node_def { + name: "Add" + op: "Add" + input: "x" + input: "y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + ret { + key: "Add" + value: "Add:z:0" + } + attr { + key: "_noinline" + value { + b: true + } + } + } +} +versions { + producer: 24 + min_consumer: 12 +} |