aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-01 09:13:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-01 10:19:48 -0700
commit6d4956edb02c5a2d0a9bdeb919a4c4c0e55fd882 (patch)
treeb4fec91515b41649be5588797cf4d9aca02e30f9
parenta964d5f3c8a6ad924379a19e9924afae7c428d62 (diff)
Allow passing constant_value_as_shape from call_cpp_shape_fn to the C++ shape
inference. Delegate to C++ shape function for Reshape. Fix reshape to handle attempting to infer unknown dim when product of known elements is 0. Change: 137837591
-rw-r--r--tensorflow/core/framework/shape_inference.cc13
-rw-r--r--tensorflow/core/framework/shape_inference.h2
-rw-r--r--tensorflow/core/ops/array_ops.cc2
-rw-r--r--tensorflow/core/ops/array_ops_test.cc8
-rw-r--r--tensorflow/python/framework/common_shapes.py29
-rw-r--r--tensorflow/python/framework/cpp_shape_inference.cc32
-rw-r--r--tensorflow/python/framework/cpp_shape_inference.h4
-rw-r--r--tensorflow/python/kernel_tests/reshape_op_test.py2
-rw-r--r--tensorflow/python/ops/array_ops.py4
9 files changed, 72 insertions, 24 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 4aa32f6a84..f6475e0736 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -31,11 +31,20 @@ InferenceContext::InferenceContext(
const NodeDef* node_def, const OpDef& op_def,
const std::vector<TensorShapeProto>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
- const std::vector<ShapeHandle>& input_tensors_as_shapes,
+ const std::vector<TensorShapeProto>& input_tensors_as_shapes,
const std::vector<TensorShapeProto>& input_handle_shapes,
const std::vector<DataType>& input_handle_dtypes)
: node_def_(*CHECK_NOTNULL(node_def)) {
- PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
+ std::vector<ShapeHandle> input_tensors_as_shape_handles;
+ for (const TensorShapeProto& p : input_tensors_as_shapes) {
+ ShapeHandle shape;
+ construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
+ if (!construction_status_.ok()) {
+ return;
+ }
+ input_tensors_as_shape_handles.push_back(shape);
+ }
+ PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
if (!construction_status_.ok()) return;
for (const TensorShapeProto& p : input_shapes) {
ShapeHandle shape;
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index e02490efd9..1a8107ef00 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -164,7 +164,7 @@ class InferenceContext {
InferenceContext(const NodeDef* node_def, const OpDef& op_def,
const std::vector<TensorShapeProto>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
- const std::vector<ShapeHandle>& input_tensors_as_shapes,
+ const std::vector<TensorShapeProto>& input_tensors_as_shapes,
const std::vector<TensorShapeProto>& input_handle_shapes,
const std::vector<DataType>& input_handle_dtypes);
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index a7fb3375c8..0d3e6fa94a 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -151,7 +151,7 @@ Status SetOutputShapeForReshape(InferenceContext* c) {
TF_RETURN_IF_ERROR(c->Multiply(known_elems, dim, &known_elems));
}
}
- if (!too_many_unknown) {
+ if (!too_many_unknown && c->Value(known_elems) != 0) {
DimensionHandle inferred_dim;
TF_RETURN_IF_ERROR(c->Divide(num_in_elems, c->Value(known_elems),
true /* evenly_divisible */, &inferred_dim));
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index 7f7861384c..4e10d72816 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -718,6 +718,14 @@ TEST(ArrayOpsTest, Reshape_ShapeFn) {
INFER_ERROR(
"Cannot reshape a tensor with 2 elements to shape [] (1 elements)", op,
"[1,2];[0]");
+
+ // Reshaping a tensor with no elements.
+ new_shape = test::AsTensor<int32>({-1});
+ INFER_OK(op, "[0];[1]", "[0]");
+ new_shape = test::AsTensor<int32>({-1, 6});
+ INFER_OK(op, "[0,2];[1]", "[0,6]");
+ new_shape = test::AsTensor<int32>({0, -1});
+ INFER_OK(op, "[0,2];[1]", "[0,?]");
}
TEST(ArrayOpsTest, QuantizedReshape_ShapeFn) {
diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py
index 09afe56b19..c8867b4b1b 100644
--- a/tensorflow/python/framework/common_shapes.py
+++ b/tensorflow/python/framework/common_shapes.py
@@ -552,7 +552,9 @@ def broadcast_shape(shape_x, shape_y):
return tensor_shape.TensorShape(return_dims)
-def call_cpp_shape_fn(op, input_tensors_needed=None,
+def call_cpp_shape_fn(op,
+ input_tensors_needed=None,
+ input_tensors_as_shapes_needed=None,
debug_python_shape_fn=None):
"""A shape function that delegates to the registered C++ shape function.
@@ -560,6 +562,8 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
op: the node in the graph for which to compute output shapes.
input_tensors_needed: a list of input tensor indices for which to compute
the input tensor's value and pass to the C++ shape function.
+ input_tensors_as_shapes_needed: a list of input tensor indices for which to
+ compute the constant_value_as_shape and pass to the C++ shape function.
debug_python_shape_fn: For testing only during migration to using
call_cpp_shape_fn. Do not submit calls that set this,
as the comparison is slow. If non-None, the python shape function;
@@ -594,16 +598,25 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
input_tensors = [None for i in input_shapes]
if input_tensors_needed:
for idx in input_tensors_needed:
- input_tensors[idx] = tensor_util.constant_value(op.inputs[idx])
- if input_tensors[idx] is not None:
- input_tensors[idx] = np.asarray(input_tensors[idx])
+ v = tensor_util.constant_value(op.inputs[idx])
+ if v is not None:
+ input_tensors[idx] = np.asarray(v)
+
+ serialized_unknown_shape = (
+ tensor_shape.TensorShape(None).as_proto().SerializeToString())
+ arr = [serialized_unknown_shape for i in input_shapes]
+ if input_tensors_as_shapes_needed:
+ for idx in input_tensors_as_shapes_needed:
+ s = tensor_util.constant_value_as_shape(op.inputs[idx])
+ if s is not None:
+ arr[idx] = s.as_proto().SerializeToString()
+ input_tensors_as_shapes = arr
try:
with errors.raise_exception_on_not_ok_status() as status:
- output_shapes = pywrap_tensorflow.RunCppShapeInference(node_def_str,
- input_shapes,
- input_tensors,
- status)
+ output_shapes = pywrap_tensorflow.RunCppShapeInference(
+ node_def_str, input_shapes, input_tensors, input_tensors_as_shapes,
+ status)
except errors.InvalidArgumentError as err:
raise ValueError(err.message)
diff --git a/tensorflow/python/framework/cpp_shape_inference.cc b/tensorflow/python/framework/cpp_shape_inference.cc
index 57b85e8118..7620b52f9f 100644
--- a/tensorflow/python/framework/cpp_shape_inference.cc
+++ b/tensorflow/python/framework/cpp_shape_inference.cc
@@ -50,6 +50,7 @@ Status RunCppShapeInferenceImpl(
const string& serialized_node_def,
const std::vector<string>& input_serialized_shapes,
const std::vector<PyObject*>& input_constant_tensor_values,
+ const std::vector<string>& input_constant_tensor_as_shape_values,
std::vector<string>* output_tensor_shape_protos) {
tensorflow::NodeDef node;
if (!node.ParseFromString(serialized_node_def)) {
@@ -87,10 +88,9 @@ Status RunCppShapeInferenceImpl(
}
// Convert input tensor values;
- const int num_input_tensors = input_constant_tensor_values.size();
- std::vector<Tensor> input_tensor_values(num_input_tensors);
+ std::vector<Tensor> input_tensor_values(input_constant_tensor_values.size());
std::vector<const Tensor*> input_tensors;
- for (int i = 0; i < num_input_tensors; ++i) {
+ for (int i = 0; i < input_constant_tensor_values.size(); ++i) {
auto* py_val = input_constant_tensor_values[i];
if (py_val == Py_None) {
input_tensors.push_back(nullptr);
@@ -101,11 +101,21 @@ Status RunCppShapeInferenceImpl(
}
}
+ // Convert input tensor-as-shape values;
+ std::vector<TensorShapeProto> input_tensor_as_shapes_protos(
+ input_constant_tensor_as_shape_values.size());
+ for (int i = 0; i < input_constant_tensor_as_shape_values.size(); ++i) {
+ if (!input_tensor_as_shapes_protos[i].ParseFromString(
+ input_constant_tensor_as_shape_values[i])) {
+ return errors::InvalidArgument(
+ "Error parsing shape proto during cpp shape inference");
+ }
+ }
+
// Run shape inference.
tensorflow::shape_inference::InferenceContext c(
&node, op_reg_data->op_def, input_shapes, input_tensors,
- {} /* input_tensors_as_shapes */, input_handle_shapes,
- input_handle_dtypes);
+ input_tensor_as_shapes_protos, input_handle_shapes, input_handle_dtypes);
TF_RETURN_IF_ERROR(c.construction_status());
TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
@@ -130,16 +140,17 @@ Status RunCppShapeInferenceImpl(
std::vector<string> RunCppShapeInference(
const string& serialized_node_def,
const std::vector<string>& input_serialized_shapes,
- PyObject* input_constant_tensor_values, TF_Status* out_status) {
+ PyObject* input_constant_tensor_values,
+ const std::vector<string>& input_constant_tensor_as_shape_values,
+ TF_Status* out_status) {
if (!PyList_Check(input_constant_tensor_values)) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT, "Invalid python value");
return std::vector<string>();
}
std::vector<PyObject*> input_constant_tensor_values_v;
- int num_input_constant_tensor_values =
- PyList_Size(input_constant_tensor_values);
- for (int i = 0; i < num_input_constant_tensor_values; ++i) {
+ int cnt = PyList_Size(input_constant_tensor_values);
+ for (int i = 0; i < cnt; ++i) {
input_constant_tensor_values_v.push_back(
PyList_GetItem(input_constant_tensor_values, i));
}
@@ -147,7 +158,8 @@ std::vector<string> RunCppShapeInference(
std::vector<string> output_tensor_shape_protos;
tensorflow::Status status = RunCppShapeInferenceImpl(
serialized_node_def, input_serialized_shapes,
- input_constant_tensor_values_v, &output_tensor_shape_protos);
+ input_constant_tensor_values_v, input_constant_tensor_as_shape_values,
+ &output_tensor_shape_protos);
Set_TF_Status_from_Status(out_status, status);
return status.ok() ? output_tensor_shape_protos : std::vector<string>();
diff --git a/tensorflow/python/framework/cpp_shape_inference.h b/tensorflow/python/framework/cpp_shape_inference.h
index f91af8e1a8..b489382993 100644
--- a/tensorflow/python/framework/cpp_shape_inference.h
+++ b/tensorflow/python/framework/cpp_shape_inference.h
@@ -44,7 +44,9 @@ namespace swig {
std::vector<string> RunCppShapeInference(
const string& serialized_node_def,
const std::vector<string>& input_serialized_shapes,
- PyObject* input_constant_tensor_values, TF_Status* out_status);
+ PyObject* input_constant_tensor_values,
+ const std::vector<string>& input_constant_tensor_as_shape_values,
+ TF_Status* out_status);
} // namespace swig
} // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py
index a68f722244..8e62be107b 100644
--- a/tensorflow/python/kernel_tests/reshape_op_test.py
+++ b/tensorflow/python/kernel_tests/reshape_op_test.py
@@ -100,7 +100,7 @@ class ReshapeTest(tf.test.TestCase):
def testErrors(self):
y = tf.constant(0.0, shape=[23, 29, 31])
- with self.assertRaisesRegexp(ValueError, "isn't divisible by 17"):
+ with self.assertRaisesRegexp(ValueError, "must be evenly divisible by 17"):
tf.reshape(y, [17, -1])
z = tf.constant(0.0, shape=[32, 128])
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 6474d54f66..533b5fcb98 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1774,6 +1774,10 @@ ops.RegisterShape("Bitcast")(common_shapes.call_cpp_shape_fn)
@ops.RegisterShape("Reshape")
+def _DelegateReshapeShape(op):
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_as_shapes_needed=[1])
+
+
def _ReshapeShape(op):
"""Shape function for Reshape op."""
input_shape = op.inputs[0].get_shape()