diff options
Diffstat (limited to 'tensorflow/core/framework/tensor_shape.h')
-rw-r--r-- | tensorflow/core/framework/tensor_shape.h | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h index bd80215849..e341ceddfb 100644 --- a/tensorflow/core/framework/tensor_shape.h +++ b/tensorflow/core/framework/tensor_shape.h @@ -143,6 +143,9 @@ class TensorShape { void RecomputeNumElements(); + void CheckDimsEqual(int NDIMS) const; + void CheckDimsAtLeast(int NDIMS) const; + // We use 16 bytes to represent a TensorShape. Because we need to // be able to support full 64-bit dimension sizes and an arbitrary // number of dimensions for a Tensor, but most tensor dimensions are @@ -266,16 +269,14 @@ class TensorShapeUtils { template <int NDIMS> Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizes() const { - CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS - << " for a tensor of " << dims() << " dimensions"; + CheckDimsEqual(NDIMS); return AsEigenDSizesWithPadding<NDIMS>(); } template <int NDIMS> Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizesWithPadding() const { - CHECK_GE(NDIMS, dims()) << "Asking for tensor of " << NDIMS - << " for a tensor of " << dims() << " dimensions"; + CheckDimsAtLeast(NDIMS); Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes; for (int d = 0; d < dims(); d++) { dsizes[d] = dim_size(d); |