aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/constant_op.cc11
-rw-r--r--tensorflow/core/kernels/conv_grad_ops.cc36
-rw-r--r--tensorflow/core/kernels/edit_distance_op.cc13
-rw-r--r--tensorflow/core/kernels/random_op.cc12
-rw-r--r--tensorflow/core/kernels/sparse_to_dense_op.cc9
-rw-r--r--tensorflow/core/public/tensor_shape.h15
-rw-r--r--tensorflow/python/framework/tensor_shape.py2
-rw-r--r--tensorflow/python/kernel_tests/constant_op_test.py27
8 files changed, 80 insertions, 45 deletions
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index e33751b6ac..7f6214e75e 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -153,13 +153,12 @@ class FillOp : public OpKernel {
errors::InvalidArgument("dims[", i, "] = ", dims(i),
" must be nonnegative."));
}
+ TensorShape shape;
+ OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
+ reinterpret_cast<const int32*>(dims.data()),
+ dims.size(), &shape));
Tensor* out = nullptr;
- OP_REQUIRES_OK(
- context,
- context->allocate_output(
- 0, TensorShapeUtils::MakeShape(
- reinterpret_cast<const int32*>(dims.data()), dims.size()),
- &out));
+ OP_REQUIRES_OK(context, context->allocate_output(0, shape, &out));
functor::FillFunctor<Device, T> functor;
functor(context->eigen_device<Device>(), out->flat<T>(),
Tvalue.scalar<T>());
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index fffafa2db4..ac0ed0a404 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -264,16 +264,18 @@ typedef Eigen::GpuDevice GPUDevice;
<< ", strides = " << strides[1]
namespace {
-TensorShape VectorToShape(const TTypes<int32>::ConstVec& sizes) {
- TensorShape shape;
-
+Status VectorToShape(const TTypes<int32>::ConstVec& sizes, TensorShape* out) {
using Index = TTypes<int32>::ConstVec::Index;
const Index dims = sizes.size();
for (Index i = 0; i < dims; ++i) {
- shape.AddDim(sizes(i));
+ if (sizes(i) >= 0) {
+ out->AddDim(sizes(i));
+ } else {
+ return errors::InvalidArgument("Dimension ", sizes(i), " must be >= 0");
+ }
}
- return shape;
+ return Status::OK();
}
} // namespace
@@ -309,7 +311,9 @@ class Conv2DFastBackpropInputOp : public OpKernel {
errors::InvalidArgument(
"Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
input_sizes.dims()));
- TensorShape input_shape = VectorToShape(input_sizes.vec<int32>());
+ TensorShape input_shape;
+ OP_REQUIRES_OK(context,
+ VectorToShape(input_sizes.vec<int32>(), &input_shape));
const TensorShape& filter_shape = filter.shape();
EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropInput");
@@ -359,7 +363,9 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
errors::InvalidArgument(
"Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
input_sizes.dims()));
- TensorShape input_shape = VectorToShape(input_sizes.vec<int32>());
+ TensorShape input_shape;
+ OP_REQUIRES_OK(context,
+ VectorToShape(input_sizes.vec<int32>(), &input_shape));
const TensorShape& filter_shape = filter.shape();
EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropInput");
@@ -566,7 +572,9 @@ class Conv2DFastBackpropFilterOp : public OpKernel {
"Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
filter_sizes.dims()));
const TensorShape& input_shape = input.shape();
- TensorShape filter_shape = VectorToShape(filter_sizes.vec<int32>());
+ TensorShape filter_shape;
+ OP_REQUIRES_OK(context,
+ VectorToShape(filter_sizes.vec<int32>(), &filter_shape));
EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropFilter");
Tensor* filter_backprop = nullptr;
@@ -618,7 +626,9 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
"not ",
filter_sizes.dims()));
const TensorShape& input_shape = input.shape();
- TensorShape filter_shape = VectorToShape(filter_sizes.vec<int32>());
+ TensorShape filter_shape;
+ OP_REQUIRES_OK(context,
+ VectorToShape(filter_sizes.vec<int32>(), &filter_shape));
EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DCustomBackpropFilter");
Tensor* filter_backprop;
@@ -790,7 +800,9 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
errors::InvalidArgument(
"Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
input_sizes.dims()));
- TensorShape input_shape = VectorToShape(input_sizes.vec<int32>());
+ TensorShape input_shape;
+ OP_REQUIRES_OK(context,
+ VectorToShape(input_sizes.vec<int32>(), &input_shape));
const TensorShape& filter_shape = filter.shape();
EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropInput");
@@ -1067,7 +1079,9 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
"Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
filter_sizes.dims()));
const TensorShape& input_shape = input.shape();
- TensorShape filter_shape = VectorToShape(filter_sizes.vec<int32>());
+ TensorShape filter_shape;
+ OP_REQUIRES_OK(context,
+ VectorToShape(filter_sizes.vec<int32>(), &filter_shape));
EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropFilter");
Tensor* filter_backprop = nullptr;
diff --git a/tensorflow/core/kernels/edit_distance_op.cc b/tensorflow/core/kernels/edit_distance_op.cc
index 16fdc25445..85b546dc7f 100644
--- a/tensorflow/core/kernels/edit_distance_op.cc
+++ b/tensorflow/core/kernels/edit_distance_op.cc
@@ -118,10 +118,15 @@ class EditDistanceOp : public OpKernel {
*hypothesis_shape, *truth_indices, *truth_values,
*truth_shape));
- TensorShape hypothesis_st_shape = TensorShapeUtils::MakeShape(
- hypothesis_shape->vec<int64>().data(), hypothesis_shape->NumElements());
- TensorShape truth_st_shape = TensorShapeUtils::MakeShape(
- truth_shape->vec<int64>().data(), truth_shape->NumElements());
+ TensorShape hypothesis_st_shape;
+ OP_REQUIRES_OK(
+ ctx, TensorShapeUtils::MakeShape(hypothesis_shape->vec<int64>().data(),
+ hypothesis_shape->NumElements(),
+ &hypothesis_st_shape));
+ TensorShape truth_st_shape;
+ OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
+ truth_shape->vec<int64>().data(),
+ truth_shape->NumElements(), &truth_st_shape));
// Assume indices are sorted in row-major order.
std::vector<int64> sorted_order(truth_st_shape.dims());
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index c0f8f77d1d..bb1566b4e8 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -187,12 +187,16 @@ static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape,
}
if (shape.dtype() == DataType::DT_INT32) {
auto vec = shape.flat<int32>();
- TF_RETURN_IF_ERROR(ctx->allocate_output(
- index, TensorShapeUtils::MakeShape(vec.data(), vec.size()), output));
+ TensorShape tensor_shape;
+ TF_RETURN_IF_ERROR(
+ TensorShapeUtils::MakeShape(vec.data(), vec.size(), &tensor_shape));
+ TF_RETURN_IF_ERROR(ctx->allocate_output(index, tensor_shape, output));
} else if (shape.dtype() == DataType::DT_INT64) {
auto vec = shape.flat<int64>();
- TF_RETURN_IF_ERROR(ctx->allocate_output(
- index, TensorShapeUtils::MakeShape(vec.data(), vec.size()), output));
+ TensorShape tensor_shape;
+ TF_RETURN_IF_ERROR(
+ TensorShapeUtils::MakeShape(vec.data(), vec.size(), &tensor_shape));
+ TF_RETURN_IF_ERROR(ctx->allocate_output(index, tensor_shape, output));
} else {
return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
}
diff --git a/tensorflow/core/kernels/sparse_to_dense_op.cc b/tensorflow/core/kernels/sparse_to_dense_op.cc
index 9f33aee371..3de5132049 100644
--- a/tensorflow/core/kernels/sparse_to_dense_op.cc
+++ b/tensorflow/core/kernels/sparse_to_dense_op.cc
@@ -81,11 +81,12 @@ class SparseToDense : public OpKernel {
errors::InvalidArgument("default_value should be a scalar."));
auto output_shape_vec = output_shape.flat<Index>();
+ TensorShape output_tensor_shape;
+ OP_REQUIRES_OK(c, TensorShapeUtils::MakeShape(output_shape_vec.data(),
+ output_shape_vec.size(),
+ &output_tensor_shape));
Tensor* output = nullptr;
- OP_REQUIRES_OK(c, c->allocate_output(0, TensorShapeUtils::MakeShape(
- output_shape_vec.data(),
- output_shape_vec.size()),
- &output));
+ OP_REQUIRES_OK(c, c->allocate_output(0, output_tensor_shape, &output));
TensorShape ix_shape({num_elems, num_dims});
Tensor indices_shaped(DT_INT64, ix_shape);
diff --git a/tensorflow/core/public/tensor_shape.h b/tensorflow/core/public/tensor_shape.h
index 0d057eb7ac..0759185d09 100644
--- a/tensorflow/core/public/tensor_shape.h
+++ b/tensorflow/core/public/tensor_shape.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -210,10 +211,16 @@ class TensorShapeUtils {
/// \brief Returns a `TensorShape` whose dimensions are
/// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
template <typename T>
- static TensorShape MakeShape(const T* dims, int n) {
- TensorShape shape;
- for (int i = 0; i < n; ++i) shape.AddDim(dims[i]);
- return shape;
+ static Status MakeShape(const T* dims, int n, TensorShape* out) {
+ *out = TensorShape();
+ for (int i = 0; i < n; ++i) {
+ if (dims[i] >= 0) {
+ out->AddDim(dims[i]);
+ } else {
+ return errors::InvalidArgument("Dimension ", dims[i], " must be >= 0");
+ }
+ }
+ return Status::OK();
}
static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes) {
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index c692e66577..fe4482740a 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -32,6 +32,8 @@ class Dimension(object):
self._value = None
else:
self._value = int(value)
+ if self._value < 0:
+ raise ValueError("Dimension %d must be >= 0" % self._value)
def __repr__(self):
return "Dimension(%s)" % repr(self._value)
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 5bc7f2b37d..253244b412 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -465,6 +465,19 @@ class FillTest(tf.test.TestCase):
tf_ans = tf.fill([2, 3], np_ans[0][0], name="fill").eval()
self.assertAllEqual(np_ans, tf_ans)
+ def testFillNegative(self):
+ with self.test_session():
+ for shape in (-1,), (2, -1), (-1, 2):
+ with self.assertRaises(ValueError):
+ tf.fill(shape, 7)
+
+ # Using a placeholder so this won't be caught in Python.
+ dims = tf.placeholder(tf.int32)
+ fill_t = tf.fill(dims, 3.0)
+ for shape in (-1,), (2, -1), (-1, 2):
+ with self.assertRaises(tf.errors.InvalidArgumentError):
+ fill_t.eval({dims: shape})
+
def testShapeFunctionEdgeCases(self):
# Non-vector dimensions.
with self.assertRaises(ValueError):
@@ -531,19 +544,9 @@ class PlaceholderTest(tf.test.TestCase):
d = tf.mul(p, c)
self.assertEqual(10, d.eval(feed_dict={p: 2}))
- def testFillNegative(self):
- with self.test_session():
- for shape in (-1,), (2, -1), (-1, 2):
- with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
- " must be nonnegative"):
- tf.fill(shape, 7).eval()
-
def testBadShape(self):
- with self.test_session():
- a = tf.placeholder(tf.float32, shape=(-1, 10))
- s = tf.shape(a)
- with self.assertRaisesOpError(r"Shape \[-1,10\] has negative dimensions"):
- s.eval()
+ with self.assertRaises(ValueError):
+ tf.placeholder(tf.float32, shape=(-1, 10))
def testTensorStr(self):
a = tf.placeholder(tf.float32, name="a")