aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/tensor_shape.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/tensor_shape.h')
-rw-r--r--tensorflow/core/framework/tensor_shape.h9
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h
index bd80215849..e341ceddfb 100644
--- a/tensorflow/core/framework/tensor_shape.h
+++ b/tensorflow/core/framework/tensor_shape.h
@@ -143,6 +143,9 @@ class TensorShape {
void RecomputeNumElements();
+ void CheckDimsEqual(int NDIMS) const;
+ void CheckDimsAtLeast(int NDIMS) const;
+
// We use 16 bytes to represent a TensorShape. Because we need to
// be able to support full 64-bit dimension sizes and an arbitrary
// number of dimensions for a Tensor, but most tensor dimensions are
@@ -266,16 +269,14 @@ class TensorShapeUtils {
template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizes() const {
- CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS
- << " for a tensor of " << dims() << " dimensions";
+ CheckDimsEqual(NDIMS);
return AsEigenDSizesWithPadding<NDIMS>();
}
template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizesWithPadding()
const {
- CHECK_GE(NDIMS, dims()) << "Asking for tensor of " << NDIMS
- << " for a tensor of " << dims() << " dimensions";
+ CheckDimsAtLeast(NDIMS);
Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
for (int d = 0; d < dims(); d++) {
dsizes[d] = dim_size(d);