diff options
author | 2016-11-01 09:13:26 -0800 | |
---|---|---|
committer | 2016-11-01 10:19:48 -0700 | |
commit | 6d4956edb02c5a2d0a9bdeb919a4c4c0e55fd882 (patch) | |
tree | b4fec91515b41649be5588797cf4d9aca02e30f9 | |
parent | a964d5f3c8a6ad924379a19e9924afae7c428d62 (diff) |
Allow passing constant_value_as_shape from call_cpp_shape_fn to the C++ shape
inference. Delegate to C++ shape function for Reshape.
Fix Reshape to handle inferring an unknown dim when the product of the known
elements is 0.
Change: 137837591
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 13 | ||||
-rw-r--r-- | tensorflow/core/framework/shape_inference.h | 2 | ||||
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/ops/array_ops_test.cc | 8 | ||||
-rw-r--r-- | tensorflow/python/framework/common_shapes.py | 29 | ||||
-rw-r--r-- | tensorflow/python/framework/cpp_shape_inference.cc | 32 | ||||
-rw-r--r-- | tensorflow/python/framework/cpp_shape_inference.h | 4 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/reshape_op_test.py | 2 | ||||
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 4 |
9 files changed, 72 insertions, 24 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 4aa32f6a84..f6475e0736 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -31,11 +31,20 @@ InferenceContext::InferenceContext( const NodeDef* node_def, const OpDef& op_def, const std::vector<TensorShapeProto>& input_shapes, const std::vector<const Tensor*>& input_tensors, - const std::vector<ShapeHandle>& input_tensors_as_shapes, + const std::vector<TensorShapeProto>& input_tensors_as_shapes, const std::vector<TensorShapeProto>& input_handle_shapes, const std::vector<DataType>& input_handle_dtypes) : node_def_(*CHECK_NOTNULL(node_def)) { - PreInputInit(op_def, input_tensors, input_tensors_as_shapes); + std::vector<ShapeHandle> input_tensors_as_shape_handles; + for (const TensorShapeProto& p : input_tensors_as_shapes) { + ShapeHandle shape; + construction_status_.Update(MakeShapeFromShapeProto(p, &shape)); + if (!construction_status_.ok()) { + return; + } + input_tensors_as_shape_handles.push_back(shape); + } + PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles); if (!construction_status_.ok()) return; for (const TensorShapeProto& p : input_shapes) { ShapeHandle shape; diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index e02490efd9..1a8107ef00 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -164,7 +164,7 @@ class InferenceContext { InferenceContext(const NodeDef* node_def, const OpDef& op_def, const std::vector<TensorShapeProto>& input_shapes, const std::vector<const Tensor*>& input_tensors, - const std::vector<ShapeHandle>& input_tensors_as_shapes, + const std::vector<TensorShapeProto>& input_tensors_as_shapes, const std::vector<TensorShapeProto>& input_handle_shapes, const std::vector<DataType>& input_handle_dtypes); diff --git a/tensorflow/core/ops/array_ops.cc 
b/tensorflow/core/ops/array_ops.cc index a7fb3375c8..0d3e6fa94a 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -151,7 +151,7 @@ Status SetOutputShapeForReshape(InferenceContext* c) { TF_RETURN_IF_ERROR(c->Multiply(known_elems, dim, &known_elems)); } } - if (!too_many_unknown) { + if (!too_many_unknown && c->Value(known_elems) != 0) { DimensionHandle inferred_dim; TF_RETURN_IF_ERROR(c->Divide(num_in_elems, c->Value(known_elems), true /* evenly_divisible */, &inferred_dim)); diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index 7f7861384c..4e10d72816 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -718,6 +718,14 @@ TEST(ArrayOpsTest, Reshape_ShapeFn) { INFER_ERROR( "Cannot reshape a tensor with 2 elements to shape [] (1 elements)", op, "[1,2];[0]"); + + // Reshaping a tensor with no elements. + new_shape = test::AsTensor<int32>({-1}); + INFER_OK(op, "[0];[1]", "[0]"); + new_shape = test::AsTensor<int32>({-1, 6}); + INFER_OK(op, "[0,2];[1]", "[0,6]"); + new_shape = test::AsTensor<int32>({0, -1}); + INFER_OK(op, "[0,2];[1]", "[0,?]"); } TEST(ArrayOpsTest, QuantizedReshape_ShapeFn) { diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py index 09afe56b19..c8867b4b1b 100644 --- a/tensorflow/python/framework/common_shapes.py +++ b/tensorflow/python/framework/common_shapes.py @@ -552,7 +552,9 @@ def broadcast_shape(shape_x, shape_y): return tensor_shape.TensorShape(return_dims) -def call_cpp_shape_fn(op, input_tensors_needed=None, +def call_cpp_shape_fn(op, + input_tensors_needed=None, + input_tensors_as_shapes_needed=None, debug_python_shape_fn=None): """A shape function that delegates to the registered C++ shape function. @@ -560,6 +562,8 @@ def call_cpp_shape_fn(op, input_tensors_needed=None, op: the node in the graph for which to compute output shapes. 
input_tensors_needed: a list of input tensor indices for which to compute the input tensor's value and pass to the C++ shape function. + input_tensors_as_shapes_needed: a list of input tensor indices for which to + compute the constant_value_as_shape and pass to the C++ shape function. debug_python_shape_fn: For testing only during migration to using call_cpp_shape_fn. Do not submit calls that set this, as the comparison is slow. If non-None, the python shape function; @@ -594,16 +598,25 @@ def call_cpp_shape_fn(op, input_tensors_needed=None, input_tensors = [None for i in input_shapes] if input_tensors_needed: for idx in input_tensors_needed: - input_tensors[idx] = tensor_util.constant_value(op.inputs[idx]) - if input_tensors[idx] is not None: - input_tensors[idx] = np.asarray(input_tensors[idx]) + v = tensor_util.constant_value(op.inputs[idx]) + if v is not None: + input_tensors[idx] = np.asarray(v) + + serialized_unknown_shape = ( + tensor_shape.TensorShape(None).as_proto().SerializeToString()) + arr = [serialized_unknown_shape for i in input_shapes] + if input_tensors_as_shapes_needed: + for idx in input_tensors_as_shapes_needed: + s = tensor_util.constant_value_as_shape(op.inputs[idx]) + if s is not None: + arr[idx] = s.as_proto().SerializeToString() + input_tensors_as_shapes = arr try: with errors.raise_exception_on_not_ok_status() as status: - output_shapes = pywrap_tensorflow.RunCppShapeInference(node_def_str, - input_shapes, - input_tensors, - status) + output_shapes = pywrap_tensorflow.RunCppShapeInference( + node_def_str, input_shapes, input_tensors, input_tensors_as_shapes, + status) except errors.InvalidArgumentError as err: raise ValueError(err.message) diff --git a/tensorflow/python/framework/cpp_shape_inference.cc b/tensorflow/python/framework/cpp_shape_inference.cc index 57b85e8118..7620b52f9f 100644 --- a/tensorflow/python/framework/cpp_shape_inference.cc +++ b/tensorflow/python/framework/cpp_shape_inference.cc @@ -50,6 +50,7 @@ Status 
RunCppShapeInferenceImpl( const string& serialized_node_def, const std::vector<string>& input_serialized_shapes, const std::vector<PyObject*>& input_constant_tensor_values, + const std::vector<string>& input_constant_tensor_as_shape_values, std::vector<string>* output_tensor_shape_protos) { tensorflow::NodeDef node; if (!node.ParseFromString(serialized_node_def)) { @@ -87,10 +88,9 @@ Status RunCppShapeInferenceImpl( } // Convert input tensor values; - const int num_input_tensors = input_constant_tensor_values.size(); - std::vector<Tensor> input_tensor_values(num_input_tensors); + std::vector<Tensor> input_tensor_values(input_constant_tensor_values.size()); std::vector<const Tensor*> input_tensors; - for (int i = 0; i < num_input_tensors; ++i) { + for (int i = 0; i < input_constant_tensor_values.size(); ++i) { auto* py_val = input_constant_tensor_values[i]; if (py_val == Py_None) { input_tensors.push_back(nullptr); @@ -101,11 +101,21 @@ Status RunCppShapeInferenceImpl( } } + // Convert input tensor-as-shape values; + std::vector<TensorShapeProto> input_tensor_as_shapes_protos( + input_constant_tensor_as_shape_values.size()); + for (int i = 0; i < input_constant_tensor_as_shape_values.size(); ++i) { + if (!input_tensor_as_shapes_protos[i].ParseFromString( + input_constant_tensor_as_shape_values[i])) { + return errors::InvalidArgument( + "Error parsing shape proto during cpp shape inference"); + } + } + // Run shape inference. 
tensorflow::shape_inference::InferenceContext c( &node, op_reg_data->op_def, input_shapes, input_tensors, - {} /* input_tensors_as_shapes */, input_handle_shapes, - input_handle_dtypes); + input_tensor_as_shapes_protos, input_handle_shapes, input_handle_dtypes); TF_RETURN_IF_ERROR(c.construction_status()); TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn)); @@ -130,16 +140,17 @@ Status RunCppShapeInferenceImpl( std::vector<string> RunCppShapeInference( const string& serialized_node_def, const std::vector<string>& input_serialized_shapes, - PyObject* input_constant_tensor_values, TF_Status* out_status) { + PyObject* input_constant_tensor_values, + const std::vector<string>& input_constant_tensor_as_shape_values, + TF_Status* out_status) { if (!PyList_Check(input_constant_tensor_values)) { TF_SetStatus(out_status, TF_INVALID_ARGUMENT, "Invalid python value"); return std::vector<string>(); } std::vector<PyObject*> input_constant_tensor_values_v; - int num_input_constant_tensor_values = - PyList_Size(input_constant_tensor_values); - for (int i = 0; i < num_input_constant_tensor_values; ++i) { + int cnt = PyList_Size(input_constant_tensor_values); + for (int i = 0; i < cnt; ++i) { input_constant_tensor_values_v.push_back( PyList_GetItem(input_constant_tensor_values, i)); } @@ -147,7 +158,8 @@ std::vector<string> RunCppShapeInference( std::vector<string> output_tensor_shape_protos; tensorflow::Status status = RunCppShapeInferenceImpl( serialized_node_def, input_serialized_shapes, - input_constant_tensor_values_v, &output_tensor_shape_protos); + input_constant_tensor_values_v, input_constant_tensor_as_shape_values, + &output_tensor_shape_protos); Set_TF_Status_from_Status(out_status, status); return status.ok() ? 
output_tensor_shape_protos : std::vector<string>(); diff --git a/tensorflow/python/framework/cpp_shape_inference.h b/tensorflow/python/framework/cpp_shape_inference.h index f91af8e1a8..b489382993 100644 --- a/tensorflow/python/framework/cpp_shape_inference.h +++ b/tensorflow/python/framework/cpp_shape_inference.h @@ -44,7 +44,9 @@ namespace swig { std::vector<string> RunCppShapeInference( const string& serialized_node_def, const std::vector<string>& input_serialized_shapes, - PyObject* input_constant_tensor_values, TF_Status* out_status); + PyObject* input_constant_tensor_values, + const std::vector<string>& input_constant_tensor_as_shape_values, + TF_Status* out_status); } // namespace swig } // namespace tensorflow diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py index a68f722244..8e62be107b 100644 --- a/tensorflow/python/kernel_tests/reshape_op_test.py +++ b/tensorflow/python/kernel_tests/reshape_op_test.py @@ -100,7 +100,7 @@ class ReshapeTest(tf.test.TestCase): def testErrors(self): y = tf.constant(0.0, shape=[23, 29, 31]) - with self.assertRaisesRegexp(ValueError, "isn't divisible by 17"): + with self.assertRaisesRegexp(ValueError, "must be evenly divisible by 17"): tf.reshape(y, [17, -1]) z = tf.constant(0.0, shape=[32, 128]) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 6474d54f66..533b5fcb98 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1774,6 +1774,10 @@ ops.RegisterShape("Bitcast")(common_shapes.call_cpp_shape_fn) @ops.RegisterShape("Reshape") +def _DelegateReshapeShape(op): + return common_shapes.call_cpp_shape_fn(op, input_tensors_as_shapes_needed=[1]) + + def _ReshapeShape(op): """Shape function for Reshape op.""" input_shape = op.inputs[0].get_shape() |