aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-08-23 13:33:04 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-23 14:46:32 -0700
commit011402d8987e753acd54c6251a7edd2e2d8155ba (patch)
treed718a73c7e5d2f340f1ab07079d2bce711b30237
parentd0550db5736f484c12ac7f52dfaf2aa581d3170f (diff)
Allow a python shape inference fn to delegate to the cpp shape
inference function. Enable this for MatMul and SparseMatMul. Change: 131097313
-rw-r--r--tensorflow/core/framework/shape_inference.cc19
-rw-r--r--tensorflow/core/framework/shape_inference.h11
-rw-r--r--tensorflow/python/BUILD14
-rw-r--r--tensorflow/python/framework/common_shapes.py34
-rw-r--r--tensorflow/python/framework/cpp_shape_inference.cc102
-rw-r--r--tensorflow/python/framework/cpp_shape_inference.h48
-rw-r--r--tensorflow/python/framework/cpp_shape_inference.i28
-rw-r--r--tensorflow/python/kernel_tests/matmul_op_test.py28
-rw-r--r--tensorflow/python/ops/math_ops.py4
-rw-r--r--tensorflow/python/platform/base.i24
-rw-r--r--tensorflow/python/tensorflow.i2
11 files changed, 308 insertions, 6 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 4300784ffe..c6da445165 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -49,6 +49,25 @@ InferenceContext::InferenceContext(
InferenceContext::InferenceContext(
const NodeDef* node_def, const OpDef& op_def,
const std::vector<string>& input_shapes_string,
+ const std::vector<TensorShapeProto>& input_shapes,
+ const std::vector<const Tensor*>& input_tensors)
+ : node_def_(*CHECK_NOTNULL(node_def)) {
+ PreInputInit(op_def, input_tensors);
+ if (!construction_status_.ok()) return;
+ for (const TensorShapeProto& p : input_shapes) {
+ const Shape* shape;
+ construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
+ if (!construction_status_.ok()) {
+ return;
+ }
+ inputs_.push_back(shape);
+ }
+ PostInputInit();
+}
+
+InferenceContext::InferenceContext(
+ const NodeDef* node_def, const OpDef& op_def,
+ const std::vector<string>& input_shapes_string,
const std::vector<const Shape*>& input_shapes,
const std::vector<const Tensor*>& input_tensors)
: node_def_(*CHECK_NOTNULL(node_def)) {
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index b08ffd369e..1d0d4ab471 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -100,6 +100,17 @@ class InferenceContext {
const std::vector<const Shape*>& input_shapes,
const std::vector<const Tensor*>& input_tensors);
+ // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
+ //
+ // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
+ //
+ // TODO(cwhipkey): Remove 'input_shapes_string' once we can move the creation
+ // of Shapes from strings out of this class (or hide it).
+ InferenceContext(const NodeDef* node_def, const OpDef& op_def,
+ const std::vector<string>& input_shapes_string,
+ const std::vector<TensorShapeProto>& input_shapes,
+ const std::vector<const Tensor*>& input_tensors);
+
// This is a temporary constructor used for initial testing.
//
// TODO(cwhipkey): remove this temporary constructor.
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 21825e1bc2..fa10426ca2 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -169,6 +169,18 @@ cc_library(
)
cc_library(
+ name = "cpp_shape_inference",
+ srcs = ["framework/cpp_shape_inference.cc"],
+ hdrs = ["framework/cpp_shape_inference.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/c:tf_status_helper",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:protos_cc",
+ ],
+)
+
+cc_library(
name = "python_op_gen_main",
srcs = [
"framework/python_op_gen_main.cc",
@@ -1125,6 +1137,7 @@ tf_py_wrap_cc(
"client/net_lib.i",
"client/quantize_training.i",
"client/tf_session.i",
+ "framework/cpp_shape_inference.i",
"framework/python_op_gen.i",
"lib/core/py_func.i",
"lib/core/strings.i",
@@ -1142,6 +1155,7 @@ tf_py_wrap_cc(
":py_func_lib",
":py_record_reader_lib",
":py_record_writer_lib",
+ ":cpp_shape_inference",
":python_op_gen",
":tf_session_helper",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py
index 3e034579c0..843317d391 100644
--- a/tensorflow/python/framework/common_shapes.py
+++ b/tensorflow/python/framework/common_shapes.py
@@ -19,6 +19,9 @@ from __future__ import print_function
import six.moves
+from tensorflow.core.framework import tensor_shape_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
@@ -586,3 +589,34 @@ def broadcast_shape(shape_x, shape_y):
return tensor_shape.TensorShape(return_dims)
+def call_cpp_shape_fn(op):
+ """A shape function that delegates to the registered C++ shape function.
+
+ Args:
+ op: the node in the graph for which to compute output shapes.
+
+ Returns:
+ A TensorShape list of the output shapes of the op, as computed using the
+ C++ shape inference function registered for the op.
+
+ Raises:
+ ValueError: If the C++ shape function returned an error (e.g. because the
+ shapes of the inputs are of the wrong rank or otherwise incompatible
+ according to the shape function).
+ """
+ node_def_str = op.node_def.SerializeToString()
+ input_shapes = [i.get_shape().as_proto().SerializeToString() for i in
+ op.inputs]
+
+ try:
+ with errors.raise_exception_on_not_ok_status() as status:
+ output_shapes = pywrap_tensorflow.RunCppShapeInference(
+ node_def_str, input_shapes, status)
+ except errors.InvalidArgumentError as err:
+ raise ValueError(err.message)
+
+ # Convert TensorShapeProto values in output_shapes.
+ return [
+ tensor_shape.TensorShape(tensor_shape_pb2.TensorShapeProto.FromString(s))
+ for s in output_shapes
+ ]
diff --git a/tensorflow/python/framework/cpp_shape_inference.cc b/tensorflow/python/framework/cpp_shape_inference.cc
new file mode 100644
index 0000000000..49ef244944
--- /dev/null
+++ b/tensorflow/python/framework/cpp_shape_inference.cc
@@ -0,0 +1,102 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/python/framework/cpp_shape_inference.h"
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace swig {
+namespace {
+
+Status RunCppShapeInferenceImpl(
+ const string& serialized_node_def,
+ const std::vector<string>& input_serialized_shapes,
+ std::vector<string>* output_tensor_shape_protos) {
+ tensorflow::NodeDef node;
+ if (!node.ParseFromString(serialized_node_def)) {
+ return errors::InvalidArgument(
+ "Error parsing node_def during cpp shape inference");
+ }
+ DCHECK_EQ(output_tensor_shape_protos->size(), 0);
+
+ const OpRegistrationData* op_reg_data;
+ TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(node.op(), &op_reg_data));
+
+ if (op_reg_data->shape_inference_fn == nullptr) {
+ return errors::InvalidArgument(
+ "No shape inference function exists for op '", node.op(),
+ "', did you forget to define it?");
+ }
+
+ std::vector<TensorShapeProto> input_shapes;
+ input_shapes.resize(input_serialized_shapes.size());
+ for (int i = 0; i < input_serialized_shapes.size(); ++i) {
+ if (!input_shapes[i].ParseFromString(input_serialized_shapes[i])) {
+ return errors::InvalidArgument(
+ "Error parsing shape proto during cpp shape inference");
+ }
+ }
+
+ tensorflow::shape_inference::InferenceContext c(
+ &node, op_reg_data->op_def, {} /* input_shape_strings */, input_shapes,
+ {} /* input_tensors */);
+ TF_RETURN_IF_ERROR(c.construction_status());
+ TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(&c));
+
+ // Convert output shapes.
+ output_tensor_shape_protos->resize(c.num_outputs());
+ TensorShapeProto out;
+ for (int i = 0; i < c.num_outputs(); ++i) {
+ const shape_inference::Shape* s = c.output(i);
+ out.Clear();
+ if (c.RankKnown(s)) {
+ const int32 rank = c.Rank(s);
+ for (int i = 0; i < rank; ++i) {
+ const shape_inference::Dimension* d = c.Dim(s, i);
+ auto* out_dim = out.add_dim();
+ if (c.ValueKnown(d)) {
+ out_dim->set_size(c.Value(d));
+ } else {
+ out_dim->set_size(-1);
+ }
+ }
+ } else {
+ out.set_unknown_rank(true);
+ }
+ CHECK(out.AppendToString(&(*output_tensor_shape_protos)[i]));
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+std::vector<string> RunCppShapeInference(
+ const string& serialized_node_def,
+ const std::vector<string>& input_serialized_shapes, TF_Status* out_status) {
+ std::vector<string> output_tensor_shape_protos;
+ tensorflow::Status status =
+ RunCppShapeInferenceImpl(serialized_node_def, input_serialized_shapes,
+ &output_tensor_shape_protos);
+ Set_TF_Status_from_Status(out_status, status);
+ return status.ok() ? output_tensor_shape_protos : std::vector<string>();
+}
+
+} // namespace swig
+} // namespace tensorflow
diff --git a/tensorflow/python/framework/cpp_shape_inference.h b/tensorflow/python/framework/cpp_shape_inference.h
new file mode 100644
index 0000000000..f10ac8cd3b
--- /dev/null
+++ b/tensorflow/python/framework/cpp_shape_inference.h
@@ -0,0 +1,48 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_
+#define THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_
+
+#include <vector>
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace swig {
+
+// Calls the registered C++ shape inference function for <node> (a serialized
+// NodeDef).
+// Should not be called for shape functions that access input tensors; constant
+// input tensor values are not made available, and so the inferred shapes will
+// be less precise than they could be.
+//
+// Returns an error, or OK, in <out_status> according to whether the shape
+// inference was successful.
+//
+// On success, <*output_shapes> is populated with the inferred output shapes (as
+// serialized TensorShapeProtos).
+// <*output_shapes> must be empty when this function is called.
+//
+// This is temporary code to be used during the migration
+// from python shape inference functions to C++ shape inference functions.
+std::vector<string> RunCppShapeInference(
+ const string& serialized_node_def,
+ const std::vector<string>& input_serialized_shapes, TF_Status* out_status);
+
+} // namespace swig
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_
diff --git a/tensorflow/python/framework/cpp_shape_inference.i b/tensorflow/python/framework/cpp_shape_inference.i
new file mode 100644
index 0000000000..7135f9380b
--- /dev/null
+++ b/tensorflow/python/framework/cpp_shape_inference.i
@@ -0,0 +1,28 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+%include "tensorflow/python/platform/base.i"
+
+%{
+#include "tensorflow/python/framework/cpp_shape_inference.h"
+%}
+
+%ignoreall;
+%unignore tensorflow;
+%unignore tensorflow::swig;
+%unignore tensorflow::swig::RunCppShapeInference;
+%include "tensorflow/python/framework/cpp_shape_inference.h"
+
+%unignoreall
diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py
index aefcde0b89..e6d3960980 100644
--- a/tensorflow/python/kernel_tests/matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/matmul_op_test.py
@@ -283,11 +283,35 @@ class MatMulTest(tf.test.TestCase):
b = tf.placeholder(tf.float32, [36, 2])
c = tf.placeholder(tf.float32, [37])
with self.assertRaisesRegexp(
- ValueError, "Dimensions 37 and 36 are not compatible"):
+ ValueError, "Dimensions must be equal, but are 37 and 36"):
tf.matmul(a, b)
- with self.assertRaisesRegexp(ValueError, "must have rank 2"):
+ with self.assertRaisesRegexp(ValueError, "must be rank 2"):
tf.matmul(a, c)
+ def testShapeInference(self):
+ """Tests common_shapes.call_cpp_shape_fn."""
+ a = tf.constant([2] * 6, shape=[3, 2])
+ b = tf.constant([2] * 2, shape=[2, 1])
+ mm = tf.matmul(a, b)
+ self.assertEqual([3, 1], mm.get_shape())
+
+ # Transpose arguments are respected.
+ a = tf.constant([2] * 6, shape=[2, 3])
+ b = tf.constant([2] * 2, shape=[1, 2])
+ mm = tf.matmul(a, b, transpose_a=True, transpose_b=True)
+ self.assertEqual([3, 1], mm.get_shape())
+
+ # Unknown dims come through in output.
+ a = tf.placeholder(np.float32)
+ b = tf.placeholder(np.float32)
+ mm = tf.matmul(a, b)
+ self.assertEqual([None, None], mm.get_shape().as_list())
+
+ a = tf.constant([1] * 6, shape=[2, 3])
+ b = tf.constant([2] * 2, shape=[1, 2])
+ with self.assertRaisesRegexp(ValueError, ".*must be equal.*"):
+ tf.matmul(a, b, transpose_a=False, transpose_b=True)
+
# TODO(zhifengc): Figures out how to test matmul gradients on GPU.
class MatMulGradientTest(tf.test.TestCase):
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 2f7c09b7ac..b976e953f8 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1354,8 +1354,8 @@ def matmul(a, b,
sparse_matmul = gen_math_ops._sparse_mat_mul
batch_matmul = gen_math_ops._batch_mat_mul
-ops.RegisterShape("MatMul")(common_shapes.matmul_shape)
-ops.RegisterShape("SparseMatMul")(common_shapes.matmul_shape)
+ops.RegisterShape("MatMul")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseMatMul")(common_shapes.call_cpp_shape_fn)
@ops.RegisterStatistics("MatMul", "flops")
diff --git a/tensorflow/python/platform/base.i b/tensorflow/python/platform/base.i
index df40491ed3..246446893c 100644
--- a/tensorflow/python/platform/base.i
+++ b/tensorflow/python/platform/base.i
@@ -17,6 +17,7 @@ limitations under the License.
//
%{
#include <memory>
+ #include <vector>
#include "tensorflow/core/platform/types.h"
using tensorflow::uint64;
using tensorflow::string;
@@ -75,6 +76,25 @@ limitations under the License.
return PyUnicode_FromStringAndSize(s.data(), s.size());
#endif
}
+
+ template <class T>
+ bool tf_vector_input_helper(PyObject * seq, std::vector<T> * out,
+ bool (*convert)(PyObject*, T * const)) {
+ PyObject *item, *it = PyObject_GetIter(seq);
+ if (!it) return false;
+ while ((item = PyIter_Next(it))) {
+ T elem;
+ bool success = convert(item, &elem);
+ Py_DECREF(item);
+ if (!success) {
+ Py_DECREF(it);
+ return false;
+ }
+ if (out) out->push_back(elem);
+ }
+ Py_DECREF(it);
+ return static_cast<bool>(!PyErr_Occurred());
+ }
%}
%typemap(in) string {
@@ -112,7 +132,7 @@ limitations under the License.
%define _LIST_OUTPUT_TYPEMAP(type, py_converter)
%typemap(in) std::vector<type>(std::vector<type> temp) {
- if (!vector_input_helper($input, &temp, _PyObjAs<type>)) {
+ if (!tf_vector_input_helper($input, &temp, _PyObjAs<type>)) {
if (!PyErr_Occurred())
PyErr_SetString(PyExc_TypeError, "sequence(type) expected");
return NULL;
@@ -121,7 +141,7 @@ limitations under the License.
}
%typemap(in) const std::vector<type>& (std::vector<type> temp),
const std::vector<type>* (std::vector<type> temp) {
- if (!vector_input_helper($input, &temp, _PyObjAs<type>)) {
+ if (!tf_vector_input_helper($input, &temp, _PyObjAs<type>)) {
if (!PyErr_Occurred())
PyErr_SetString(PyExc_TypeError, "sequence(type) expected");
return NULL;
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index ef82a009f9..7a8fbf7201 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -35,3 +35,5 @@ limitations under the License.
%include "tensorflow/python/training/server_lib.i"
%include "tensorflow/python/framework/python_op_gen.i"
+
+%include "tensorflow/python/framework/cpp_shape_inference.i"