author     Geoffrey Irving <geoffreyi@google.com>  2016-06-16 15:24:57 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-06-16 16:32:46 -0700
commit     3d192c5a3c7cf7d130c1cc8d9fa5c304cb522029 (patch)
tree       3d19366d310dbe46f5517620d9edba5abcec183d
parent     69de6c42b90c7cc594dac0a8eae6b4b54a717027 (diff)
Make bias_add handle empty tensors
Change: 125115298
-rw-r--r--  tensorflow/core/kernels/bias_op.cc              | 53
-rw-r--r--  tensorflow/python/kernel_tests/bias_op_test.py  | 23
2 files changed, 52 insertions(+), 24 deletions(-)
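
For context, a minimal sketch of the case this commit enables: tf.nn.bias_add on an empty input now produces an empty output instead of failing inside the kernel. The sketch uses the graph/session API of this TensorFlow era and shapes borrowed from the testEmpty cases below; it is illustrative only, not part of the commit.

import numpy as np
import tensorflow as tf

# Input with zero rows but a nonzero channel dimension, plus a matching bias.
x = np.random.randn(0, 3).astype(np.float32)   # shape (0, 3)
b = np.random.randn(3).astype(np.float32)      # bias of length 3

with tf.Session():
    y = tf.nn.bias_add(x, b).eval()

print(y.shape)  # (0, 3): an empty result, no kernel crash
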
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc
index 9a5c961e61..46e12cff2a 100644
--- a/tensorflow/core/kernels/bias_op.cc
+++ b/tensorflow/core/kernels/bias_op.cc
@@ -76,6 +76,7 @@ class BiasOp<CPUDevice, T> : public BinaryOp<T> {
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
+ if (input.NumElements() == 0) return;
switch (input.shape().dims()) {
case 2:
@@ -202,18 +203,25 @@ class BiasGradOp<CPUDevice, T> : public OpKernel {
TensorShape output_shape{channel};
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
- Eigen::DSizes<int, 2> two_dims(batch * height * width, channel);
+ if (channel == 0) {
+ return; // Nothing to do
+ } else if (output_backprop.NumElements() == 0) {
+ // Eigen often crashes by design on empty tensors, but setZero is safe
+ output->template flat<T>().setZero();
+ } else {
+ Eigen::DSizes<int, 2> two_dims(batch * height * width, channel);
#ifdef EIGEN_HAS_INDEX_LIST
- Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
+ Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
#else
- Eigen::array<int, 1> reduction_axis = {0};
+ Eigen::array<int, 1> reduction_axis = {0};
#endif
- output->template flat<T>().device(context->eigen_device<CPUDevice>()) =
- output_backprop.flat<T>()
- .template cast<typename AccumulatorType<T>::type>()
- .reshape(two_dims)
- .sum(reduction_axis)
- .template cast<T>();
+ output->template flat<T>().device(context->eigen_device<CPUDevice>()) =
+ output_backprop.flat<T>()
+ .template cast<typename AccumulatorType<T>::type>()
+ .reshape(two_dims)
+ .sum(reduction_axis)
+ .template cast<T>();
+ }
}
private:
@@ -254,9 +262,6 @@ class BiasOp<GPUDevice, T> : public BinaryOp<T> {
OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
errors::InvalidArgument("Biases must be 1D: ",
bias.shape().DebugString()));
- Tensor* output = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, input.shape(), &output));
int32 batch, height, width, channel;
GetBiasValueDims(input, data_format_, &batch, &height, &width, &channel);
OP_REQUIRES(context, bias.shape().dim_size(0) == channel,
@@ -265,10 +270,15 @@ class BiasOp<GPUDevice, T> : public BinaryOp<T> {
"of the input tensor: ",
bias.shape().DebugString(), " vs. ", channel, " in ",
input.shape().DebugString()));
- BiasGPU<T>::compute(context->template eigen_device<Device>(),
- input.flat<T>().data(), bias.flat<T>().data(),
- output->flat<T>().data(), batch, width, height, channel,
- data_format_);
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+ if (input.NumElements() > 0) {
+ BiasGPU<T>::compute(context->template eigen_device<Device>(),
+ input.flat<T>().data(), bias.flat<T>().data(),
+ output->flat<T>().data(), batch, width, height,
+ channel, data_format_);
+ }
}
private:
@@ -314,15 +324,18 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
Tensor* output = nullptr;
TensorShape output_shape{channel};
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+ if (channel == 0) return;
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
perftools::gputools::DeviceMemoryBase output_ptr(
output->flat<T>().data(), output->NumElements() * sizeof(T));
stream->ThenMemZero(&output_ptr, output->NumElements() * sizeof(T));
- BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
- output_backprop.template flat<T>().data(),
- output->flat<T>().data(), batch, width, height,
- channel, data_format_);
+ if (output_backprop.NumElements() > 0) {
+ BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
+ output_backprop.template flat<T>().data(),
+ output->flat<T>().data(), batch, width, height,
+ channel, data_format_);
+ }
}
private:
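
A quick sanity check on the CPU gradient path above: reducing over an empty batch axis yields a zero vector of length channel, which is exactly what the setZero() fallback writes, so the two branches agree. A small numpy illustration (not part of the change):

import numpy as np

grad = np.zeros((0, 3))   # empty output_backprop with channel = 3
print(grad.sum(axis=0))   # [0. 0. 0.] -- matches the setZero() result
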
diff --git a/tensorflow/python/kernel_tests/bias_op_test.py b/tensorflow/python/kernel_tests/bias_op_test.py
index 1300934283..629875487a 100644
--- a/tensorflow/python/kernel_tests/bias_op_test.py
+++ b/tensorflow/python/kernel_tests/bias_op_test.py
@@ -56,11 +56,15 @@ class BiasAddTest(tf.test.TestCase):
tf_val = tf.nn.bias_add(np_inputs, np_bias).eval()
self.assertAllCloseAccordingToType(np_val, tf_val)
- def _NHWCToNCHW(self, np_value):
+ def _AtLeast3d(self, np_value):
# fill the input value to at least 3-dimension
if np_value.ndim < 3:
- np_value = np.reshape(np_value,
- (1,) * (3 - np_value.ndim) + np_value.shape)
+ return np.reshape(np_value, (1,) * (3 - np_value.ndim) + np_value.shape)
+ return np_value
+
+ def _NHWCToNCHW(self, np_value):
+ # fill the input value to at least 3-dimension
+ np_value = self._AtLeast3d(np_value)
# move the last dimension to third-to-last
np_dim = list(range(np_value.ndim))
np_dim_new = list(np_dim[0:-3]) + list(np_dim[-1:]) + list(np_dim[-3:-1])
@@ -79,7 +83,7 @@ class BiasAddTest(tf.test.TestCase):
with self.test_session(use_gpu=use_gpu):
tf_val = tf.nn.bias_add(np_inputs, np_bias, data_format="NCHW").eval()
tf_val = self._NCHWToNHWC(tf_val)
- self.assertAllCloseAccordingToType(np_val, tf_val)
+ self.assertAllCloseAccordingToType(self._AtLeast3d(np_val), tf_val)
def _testAll(self, np_inputs, np_bias):
self._testBias(np_inputs, np_bias, use_gpu=False)
@@ -163,6 +167,17 @@ class BiasAddTest(tf.test.TestCase):
bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
self._testGradient(np_input, bias, dtype, data_format, use_gpu)
+ def testEmpty(self):
+ np.random.seed(7)
+ for shape in (0, 0), (2, 0), (0, 2), (4, 3, 0), (4, 0, 3), (0, 4, 3):
+ self._testAll(np.random.randn(*shape), np.random.randn(shape[-1]))
+
+ def testEmptyGradient(self):
+ for data_format, use_gpu in GetTestConfigs():
+ for shape in (0, 0), (2, 0), (0, 2), (4, 3, 0), (4, 0, 3), (0, 4, 3):
+ self._testGradient(np.random.randn(*shape), np.random.randn(shape[-1]),
+ tf.float64, data_format, use_gpu)
+
if __name__ == "__main__":
tf.test.main()
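
The empty-gradient cases reduce to the same identity end to end: for an input with zero elements, the bias gradient is a zero vector of length channel. A sketch of the expected values (float64, as in testEmptyGradient; the actual test goes through the _testGradient helper, so the plumbing here is simplified):

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.randn(0, 2))           # empty input, channel = 2
b = tf.constant(np.random.randn(2))
y = tf.nn.bias_add(x, b)
grad_b = tf.gradients(tf.reduce_sum(y), b)[0]    # bias gradient of a scalar loss

with tf.Session():
    print(grad_b.eval())                         # [0. 0.]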