diff options
author | 2016-08-23 13:33:04 -0800 | |
---|---|---|
committer | 2016-08-23 14:46:32 -0700 | |
commit | 011402d8987e753acd54c6251a7edd2e2d8155ba (patch) | |
tree | d718a73c7e5d2f340f1ab07079d2bce711b30237 | |
parent | d0550db5736f484c12ac7f52dfaf2aa581d3170f (diff) |
Allow a python shape inference fn to delegate to the cpp shape
inference function. Enable this for MatMul and SparseMatMul.
Change: 131097313
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 19 | ||||
-rw-r--r-- | tensorflow/core/framework/shape_inference.h | 11 | ||||
-rw-r--r-- | tensorflow/python/BUILD | 14 | ||||
-rw-r--r-- | tensorflow/python/framework/common_shapes.py | 34 | ||||
-rw-r--r-- | tensorflow/python/framework/cpp_shape_inference.cc | 102 | ||||
-rw-r--r-- | tensorflow/python/framework/cpp_shape_inference.h | 48 | ||||
-rw-r--r-- | tensorflow/python/framework/cpp_shape_inference.i | 28 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/matmul_op_test.py | 28 | ||||
-rw-r--r-- | tensorflow/python/ops/math_ops.py | 4 | ||||
-rw-r--r-- | tensorflow/python/platform/base.i | 24 | ||||
-rw-r--r-- | tensorflow/python/tensorflow.i | 2 |
11 files changed, 308 insertions, 6 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 4300784ffe..c6da445165 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -49,6 +49,25 @@ InferenceContext::InferenceContext( InferenceContext::InferenceContext( const NodeDef* node_def, const OpDef& op_def, const std::vector<string>& input_shapes_string, + const std::vector<TensorShapeProto>& input_shapes, + const std::vector<const Tensor*>& input_tensors) + : node_def_(*CHECK_NOTNULL(node_def)) { + PreInputInit(op_def, input_tensors); + if (!construction_status_.ok()) return; + for (const TensorShapeProto& p : input_shapes) { + const Shape* shape; + construction_status_.Update(MakeShapeFromShapeProto(p, &shape)); + if (!construction_status_.ok()) { + return; + } + inputs_.push_back(shape); + } + PostInputInit(); +} + +InferenceContext::InferenceContext( + const NodeDef* node_def, const OpDef& op_def, + const std::vector<string>& input_shapes_string, const std::vector<const Shape*>& input_shapes, const std::vector<const Tensor*>& input_tensors) : node_def_(*CHECK_NOTNULL(node_def)) { diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index b08ffd369e..1d0d4ab471 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -100,6 +100,17 @@ class InferenceContext { const std::vector<const Shape*>& input_shapes, const std::vector<const Tensor*>& input_tensors); + // <input_tensors> is NULL-padded to be the same size as <input_shapes>. + // + // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext. + // + // TODO(cwhipkey): Remove 'input_shapes_string' once we can move the creation + // of Shapes from strings out of this class (or hide it). 
+ InferenceContext(const NodeDef* node_def, const OpDef& op_def, + const std::vector<string>& input_shapes_string, + const std::vector<TensorShapeProto>& input_shapes, + const std::vector<const Tensor*>& input_tensors); + // This is a temporary constructor used for initial testing. // // TODO(cwhipkey): remove this temporary constructor. diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 21825e1bc2..fa10426ca2 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -169,6 +169,18 @@ cc_library( ) cc_library( + name = "cpp_shape_inference", + srcs = ["framework/cpp_shape_inference.cc"], + hdrs = ["framework/cpp_shape_inference.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/c:tf_status_helper", + "//tensorflow/core:framework", + "//tensorflow/core:protos_cc", + ], +) + +cc_library( name = "python_op_gen_main", srcs = [ "framework/python_op_gen_main.cc", @@ -1125,6 +1137,7 @@ tf_py_wrap_cc( "client/net_lib.i", "client/quantize_training.i", "client/tf_session.i", + "framework/cpp_shape_inference.i", "framework/python_op_gen.i", "lib/core/py_func.i", "lib/core/strings.i", @@ -1142,6 +1155,7 @@ tf_py_wrap_cc( ":py_func_lib", ":py_record_reader_lib", ":py_record_writer_lib", + ":cpp_shape_inference", ":python_op_gen", ":tf_session_helper", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py index 3e034579c0..843317d391 100644 --- a/tensorflow/python/framework/common_shapes.py +++ b/tensorflow/python/framework/common_shapes.py @@ -19,6 +19,9 @@ from __future__ import print_function import six.moves +from tensorflow.core.framework import tensor_shape_pb2 +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.framework import errors from tensorflow.python.framework import tensor_shape @@ -586,3 +589,34 @@ def broadcast_shape(shape_x, shape_y): return tensor_shape.TensorShape(return_dims) 
+def call_cpp_shape_fn(op): + """A shape function that delegates to the registered C++ shape function. + + Args: + op: the node in the graph for which to compute output shapes. + + Returns: + A TensorShape list of the output shapes of the op, as computed using the + C++ shape inference function registered for the op. + + Raises: + ValueError: If the C++ shape function returned an error (e.g. because the + shapes of the inputs are of the wrong rank or otherwise incompatible + according to the shape function). + """ + node_def_str = op.node_def.SerializeToString() + input_shapes = [i.get_shape().as_proto().SerializeToString() for i in + op.inputs] + + try: + with errors.raise_exception_on_not_ok_status() as status: + output_shapes = pywrap_tensorflow.RunCppShapeInference( + node_def_str, input_shapes, status) + except errors.InvalidArgumentError as err: + raise ValueError(err.message) + + # Convert TensorShapeProto values in output_shapes. + return [ + tensor_shape.TensorShape(tensor_shape_pb2.TensorShapeProto.FromString(s)) + for s in output_shapes + ] diff --git a/tensorflow/python/framework/cpp_shape_inference.cc b/tensorflow/python/framework/cpp_shape_inference.cc new file mode 100644 index 0000000000..49ef244944 --- /dev/null +++ b/tensorflow/python/framework/cpp_shape_inference.cc @@ -0,0 +1,102 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/python/framework/cpp_shape_inference.h" + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { +namespace swig { +namespace { + +Status RunCppShapeInferenceImpl( + const string& serialized_node_def, + const std::vector<string>& input_serialized_shapes, + std::vector<string>* output_tensor_shape_protos) { + tensorflow::NodeDef node; + if (!node.ParseFromString(serialized_node_def)) { + return errors::InvalidArgument( + "Error parsing node_def during cpp shape inference"); + } + DCHECK_EQ(output_tensor_shape_protos->size(), 0); + + const OpRegistrationData* op_reg_data; + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(node.op(), &op_reg_data)); + + if (op_reg_data->shape_inference_fn == nullptr) { + return errors::InvalidArgument( + "No shape inference function exists for op '", node.op(), + "', did you forget to define it?"); + } + + std::vector<TensorShapeProto> input_shapes; + input_shapes.resize(input_serialized_shapes.size()); + for (int i = 0; i < input_serialized_shapes.size(); ++i) { + if (!input_shapes[i].ParseFromString(input_serialized_shapes[i])) { + return errors::InvalidArgument( + "Error parsing shape proto during cpp shape inference"); + } + } + + tensorflow::shape_inference::InferenceContext c( + &node, op_reg_data->op_def, {} /* input_shape_strings */, input_shapes, + {} /* input_tensors */); + TF_RETURN_IF_ERROR(c.construction_status()); + TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(&c)); + + // Convert output shapes. 
+ output_tensor_shape_protos->resize(c.num_outputs()); + TensorShapeProto out; + for (int i = 0; i < c.num_outputs(); ++i) { + const shape_inference::Shape* s = c.output(i); + out.Clear(); + if (c.RankKnown(s)) { + const int32 rank = c.Rank(s); + for (int i = 0; i < rank; ++i) { + const shape_inference::Dimension* d = c.Dim(s, i); + auto* out_dim = out.add_dim(); + if (c.ValueKnown(d)) { + out_dim->set_size(c.Value(d)); + } else { + out_dim->set_size(-1); + } + } + } else { + out.set_unknown_rank(true); + } + CHECK(out.AppendToString(&(*output_tensor_shape_protos)[i])); + } + return Status::OK(); +} + +} // namespace + +std::vector<string> RunCppShapeInference( + const string& serialized_node_def, + const std::vector<string>& input_serialized_shapes, TF_Status* out_status) { + std::vector<string> output_tensor_shape_protos; + tensorflow::Status status = + RunCppShapeInferenceImpl(serialized_node_def, input_serialized_shapes, + &output_tensor_shape_protos); + Set_TF_Status_from_Status(out_status, status); + return status.ok() ? output_tensor_shape_protos : std::vector<string>(); +} + +} // namespace swig +} // namespace tensorflow diff --git a/tensorflow/python/framework/cpp_shape_inference.h b/tensorflow/python/framework/cpp_shape_inference.h new file mode 100644 index 0000000000..f10ac8cd3b --- /dev/null +++ b/tensorflow/python/framework/cpp_shape_inference.h @@ -0,0 +1,48 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_ +#define THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_ + +#include <vector> +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace swig { + +// Calls the registered C++ shape inference function for <node> (a serialized +// NodeDef). +// Should not be called for shape functions that access input tensors; constant +// input tensor values are not made available, and so the inferred shapes will +// be less precise than they could be. +// +// Returns an error, or OK, in <out_status> according to whether the shape +// inference was successful. +// +// On success, returns the inferred output shapes (as serialized +// TensorShapeProtos), one entry per output of the node. +// On failure, returns an empty vector. +// +// This is temporary code to be used during the migration +// from python shape inference functions to C++ shape inference functions. +std::vector<string> RunCppShapeInference( + const string& serialized_node_def, + const std::vector<string>& input_serialized_shapes, TF_Status* out_status); + +} // namespace swig +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_ diff --git a/tensorflow/python/framework/cpp_shape_inference.i b/tensorflow/python/framework/cpp_shape_inference.i new file mode 100644 index 0000000000..7135f9380b --- /dev/null +++ b/tensorflow/python/framework/cpp_shape_inference.i @@ -0,0 +1,28 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +%include "tensorflow/python/platform/base.i" + +%{ +#include "tensorflow/python/framework/cpp_shape_inference.h" +%} + +%ignoreall; +%unignore tensorflow; +%unignore tensorflow::swig; +%unignore tensorflow::swig::RunCppShapeInference; +%include "tensorflow/python/framework/cpp_shape_inference.h" + +%unignoreall diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py index aefcde0b89..e6d3960980 100644 --- a/tensorflow/python/kernel_tests/matmul_op_test.py +++ b/tensorflow/python/kernel_tests/matmul_op_test.py @@ -283,11 +283,35 @@ class MatMulTest(tf.test.TestCase): b = tf.placeholder(tf.float32, [36, 2]) c = tf.placeholder(tf.float32, [37]) with self.assertRaisesRegexp( - ValueError, "Dimensions 37 and 36 are not compatible"): + ValueError, "Dimensions must be equal, but are 37 and 36"): tf.matmul(a, b) - with self.assertRaisesRegexp(ValueError, "must have rank 2"): + with self.assertRaisesRegexp(ValueError, "must be rank 2"): tf.matmul(a, c) + def testShapeInference(self): + """Tests common_shapes.call_cpp_shape_fn.""" + a = tf.constant([2] * 6, shape=[3, 2]) + b = tf.constant([2] * 2, shape=[2, 1]) + mm = tf.matmul(a, b) + self.assertEqual([3, 1], mm.get_shape()) + + # Transpose arguments are respected. + a = tf.constant([2] * 6, shape=[2, 3]) + b = tf.constant([2] * 2, shape=[1, 2]) + mm = tf.matmul(a, b, transpose_a=True, transpose_b=True) + self.assertEqual([3, 1], mm.get_shape()) + + # Unknown dims come through in output. 
+ a = tf.placeholder(np.float32) + b = tf.placeholder(np.float32) + mm = tf.matmul(a, b) + self.assertEqual([None, None], mm.get_shape().as_list()) + + a = tf.constant([1] * 6, shape=[2, 3]) + b = tf.constant([2] * 2, shape=[1, 2]) + with self.assertRaisesRegexp(ValueError, ".*must be equal.*"): + tf.matmul(a, b, transpose_a=False, transpose_b=True) + # TODO(zhifengc): Figures out how to test matmul gradients on GPU. class MatMulGradientTest(tf.test.TestCase): diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 2f7c09b7ac..b976e953f8 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1354,8 +1354,8 @@ def matmul(a, b, sparse_matmul = gen_math_ops._sparse_mat_mul batch_matmul = gen_math_ops._batch_mat_mul -ops.RegisterShape("MatMul")(common_shapes.matmul_shape) -ops.RegisterShape("SparseMatMul")(common_shapes.matmul_shape) +ops.RegisterShape("MatMul")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("SparseMatMul")(common_shapes.call_cpp_shape_fn) @ops.RegisterStatistics("MatMul", "flops") diff --git a/tensorflow/python/platform/base.i b/tensorflow/python/platform/base.i index df40491ed3..246446893c 100644 --- a/tensorflow/python/platform/base.i +++ b/tensorflow/python/platform/base.i @@ -17,6 +17,7 @@ limitations under the License. // %{ #include <memory> + #include <vector> #include "tensorflow/core/platform/types.h" using tensorflow::uint64; using tensorflow::string; @@ -75,6 +76,25 @@ limitations under the License. 
return PyUnicode_FromStringAndSize(s.data(), s.size()); #endif } + + template <class T> + bool tf_vector_input_helper(PyObject * seq, std::vector<T> * out, + bool (*convert)(PyObject*, T * const)) { + PyObject *item, *it = PyObject_GetIter(seq); + if (!it) return false; + while ((item = PyIter_Next(it))) { + T elem; + bool success = convert(item, &elem); + Py_DECREF(item); + if (!success) { + Py_DECREF(it); + return false; + } + if (out) out->push_back(elem); + } + Py_DECREF(it); + return static_cast<bool>(!PyErr_Occurred()); + } %} %typemap(in) string { @@ -112,7 +132,7 @@ limitations under the License. %define _LIST_OUTPUT_TYPEMAP(type, py_converter) %typemap(in) std::vector<type>(std::vector<type> temp) { - if (!vector_input_helper($input, &temp, _PyObjAs<type>)) { + if (!tf_vector_input_helper($input, &temp, _PyObjAs<type>)) { if (!PyErr_Occurred()) PyErr_SetString(PyExc_TypeError, "sequence(type) expected"); return NULL; @@ -121,7 +141,7 @@ limitations under the License. } %typemap(in) const std::vector<type>& (std::vector<type> temp), const std::vector<type>* (std::vector<type> temp) { - if (!vector_input_helper($input, &temp, _PyObjAs<type>)) { + if (!tf_vector_input_helper($input, &temp, _PyObjAs<type>)) { if (!PyErr_Occurred()) PyErr_SetString(PyExc_TypeError, "sequence(type) expected"); return NULL; diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index ef82a009f9..7a8fbf7201 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -35,3 +35,5 @@ limitations under the License. %include "tensorflow/python/training/server_lib.i" %include "tensorflow/python/framework/python_op_gen.i" + +%include "tensorflow/python/framework/cpp_shape_inference.i" |