diff options
author | 2018-01-03 15:10:58 -0800 | |
---|---|---|
committer | 2018-01-03 15:15:02 -0800 | |
commit | 5f0d3395d4c61000cf0cfb3dc681177147be938d (patch) | |
tree | 83185514a44c21a15356a4d46c499cb96faa49d5 | |
parent | 2f83be3379e28fb2732a9f22034e33dbfdf37c77 (diff) |
Fix tf.nn.fractional_max_pool output have same batch size when feed with different input batch size. Fixes #14985.
PiperOrigin-RevId: 180724096
-rw-r--r-- | tensorflow/core/kernels/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/kernels/fractional_avg_pool_op.cc | 94 | ||||
-rw-r--r-- | tensorflow/core/kernels/fractional_max_pool_op.cc | 104 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py | 31 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/fractional_max_pool_op_test.py | 31 |
6 files changed, 155 insertions, 108 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index ae39c4522d..a1b62179f5 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3370,6 +3370,7 @@ tf_kernel_library( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:nn_ops_op_lib", "//third_party/eigen3", ], diff --git a/tensorflow/core/kernels/fractional_avg_pool_op.cc b/tensorflow/core/kernels/fractional_avg_pool_op.cc index bfdb7b4a1e..47f4189c30 100644 --- a/tensorflow/core/kernels/fractional_avg_pool_op.cc +++ b/tensorflow/core/kernels/fractional_avg_pool_op.cc @@ -24,6 +24,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/util/guarded_philox_random.h" @@ -47,9 +48,20 @@ class FractionalAvgPoolOp : public OpKernel { errors::Unimplemented("Fractional average pooling is not yet " "supported on the batch nor channel dimension.")); OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_)); - pooling_region_generated_ = false; - // Initialize philox random generator. - OP_REQUIRES_OK(context, generator_.Init(context)); + OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_)); + OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_)); + if (deterministic_) { + // If both seeds are not set when deterministic_ is true, force set seeds. + if ((seed_ == 0) && (seed2_ == 0)) { + seed_ = random::New64(); + seed2_ = random::New64(); + } + } else { + OP_REQUIRES( + context, (seed_ == 0) && (seed2_ == 0), + errors::InvalidArgument( + "Both seed and seed2 should be 0 if deterministic is false.")); + } } void Compute(OpKernelContext* context) override { @@ -64,47 +76,35 @@ class FractionalAvgPoolOp : public OpKernel { OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims, errors::InvalidArgument("tensor_in must be 4-dimensional")); + std::vector<int> input_size(tensor_in_and_out_dims); + std::vector<int> output_size(tensor_in_and_out_dims); for (int i = 0; i < tensor_in_and_out_dims; ++i) { - input_size_.push_back(tensor_in.dim_size(i)); + input_size[i] = tensor_in.dim_size(i); } // Output size. for (int i = 0; i < tensor_in_and_out_dims; ++i) { - output_size_.push_back( - static_cast<int>(floor(input_size_[i] / pooling_ratio_[i]))); - DCHECK_GT(output_size_[i], 0); + output_size[i] = + static_cast<int>(floor(input_size[i] / pooling_ratio_[i])); + DCHECK_GT(output_size[i], 0); } // Generate pooling sequence. std::vector<int64> row_cum_seq; std::vector<int64> col_cum_seq; - if (deterministic_) { - if (pooling_region_generated_) { - row_cum_seq = row_cum_seq_; - col_cum_seq = col_cum_seq_; - } else { - row_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1], - &generator_, pseudo_random_); - col_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2], - &generator_, pseudo_random_); - mutex_lock lock(mu_); - row_cum_seq_ = row_cum_seq; - col_cum_seq_ = col_cum_seq; - pooling_region_generated_ = true; - } - } else { - row_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1], - &generator_, pseudo_random_); - col_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2], - &generator_, pseudo_random_); - } + GuardedPhiloxRandom generator; + generator.Init(seed_, seed2_); + row_cum_seq = GeneratePoolingSequence(input_size[1], output_size[1], + &generator, pseudo_random_); + col_cum_seq = GeneratePoolingSequence(input_size[2], output_size[2], + &generator, pseudo_random_); // Prepare output. Tensor* output_tensor = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output( - 0, TensorShape({output_size_[0], output_size_[1], - output_size_[2], output_size_[3]}), - &output_tensor)); + OP_REQUIRES_OK(context, context->allocate_output( + 0, + TensorShape({output_size[0], output_size[1], + output_size[2], output_size[3]}), + &output_tensor)); Tensor* output_row_seq_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output( @@ -116,12 +116,11 @@ class FractionalAvgPoolOp : public OpKernel { 2, TensorShape({static_cast<int64>(col_cum_seq.size())}), &output_col_seq_tensor)); - ConstEigenMatrixMap in_mat( - tensor_in.flat<T>().data(), input_size_[3], - input_size_[2] * input_size_[1] * input_size_[0]); + ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), input_size[3], + input_size[2] * input_size[1] * input_size[0]); - EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size_[3], - output_size_[2] * output_size_[1] * output_size_[0]); + EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size[3], + output_size[2] * output_size[1] * output_size[0]); // out_count corresponds to number of elements in each pooling cell. Eigen::Matrix<T, Eigen::Dynamic, 1> out_count(out_mat.cols()); @@ -146,9 +145,9 @@ class FractionalAvgPoolOp : public OpKernel { // 1: row / row // 2: col / col // 3: depth / channel - const int64 row_max = input_size_[1] - 1; - const int64 col_max = input_size_[2] - 1; - for (int64 b = 0; b < input_size_[0]; ++b) { + const int64 row_max = input_size[1] - 1; + const int64 col_max = input_size[2] - 1; + for (int64 b = 0; b < input_size[0]; ++b) { // row sequence. for (int64 hs = 0; hs < row_cum_seq.size() - 1; ++hs) { // row start and end. @@ -160,7 +159,7 @@ class FractionalAvgPoolOp : public OpKernel { // col sequence. for (int64 ws = 0; ws < col_cum_seq.size() - 1; ++ws) { const int64 out_offset = - (b * output_size_[1] + hs) * output_size_[2] + ws; + (b * output_size[1] + hs) * output_size[2] + ws; // col start and end. const int64 col_start = col_cum_seq[ws]; int64 col_end = @@ -169,7 +168,7 @@ class FractionalAvgPoolOp : public OpKernel { for (int64 h = row_start; h <= row_end; ++h) { for (int64 w = col_start; w <= col_end; ++w) { const int64 in_offset = - (b * input_size_[1] + h) * input_size_[2] + w; + (b * input_size[1] + h) * input_size[2] + w; out_mat.col(out_offset) += in_mat.col(in_offset); out_count(out_offset)++; } @@ -183,18 +182,11 @@ class FractionalAvgPoolOp : public OpKernel { private: bool deterministic_; - // meaningful only when deterministic_ is true. - mutex mu_; - std::vector<int64> row_cum_seq_; - std::vector<int64> col_cum_seq_; - bool pooling_region_generated_; - - std::vector<int32> input_size_; - std::vector<int32> output_size_; + int64 seed_; + int64 seed2_; std::vector<float> pooling_ratio_; bool pseudo_random_; bool overlapping_; - GuardedPhiloxRandom generator_; }; #define REGISTER_FRACTIONALAVGPOOL(type) \ diff --git a/tensorflow/core/kernels/fractional_max_pool_op.cc b/tensorflow/core/kernels/fractional_max_pool_op.cc index 33d73c8477..cf580adab2 100644 --- a/tensorflow/core/kernels/fractional_max_pool_op.cc +++ b/tensorflow/core/kernels/fractional_max_pool_op.cc @@ -24,6 +24,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/util/guarded_philox_random.h" @@ -50,9 +51,20 @@ class FractionalMaxPoolOp : public OpKernel { "supported on the batch nor channel dimension.")); OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_)); - pooling_region_generated_ = false; - // Initialize philox random generator. - OP_REQUIRES_OK(context, generator_.Init(context)); + OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_)); + OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_)); + if (deterministic_) { + // If both seeds are not set when deterministic_ is true, force set seeds. + if ((seed_ == 0) && (seed2_ == 0)) { + seed_ = random::New64(); + seed2_ = random::New64(); + } + } else { + OP_REQUIRES( + context, (seed_ == 0) && (seed2_ == 0), + errors::InvalidArgument( + "Both seed and seed2 should be 0 if deterministic is false.")); + } } void Compute(OpKernelContext* context) override { @@ -67,49 +79,37 @@ class FractionalMaxPoolOp : public OpKernel { OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims, errors::InvalidArgument("tensor_in must be 4-dimensional")); + std::vector<int> input_size(tensor_in_and_out_dims); + std::vector<int> output_size(tensor_in_and_out_dims); for (int i = 0; i < tensor_in_and_out_dims; ++i) { - input_size_.push_back(tensor_in.dim_size(i)); + input_size[i] = tensor_in.dim_size(i); } // Output size. for (int i = 0; i < tensor_in_and_out_dims; ++i) { // This must match the same logic in the shape function in // core/ops/nn_ops.cc. - output_size_.push_back( - static_cast<int>(floor(input_size_[i] / pooling_ratio_[i]))); - DCHECK_GT(output_size_[i], 0); + output_size[i] = + static_cast<int>(floor(input_size[i] / pooling_ratio_[i])); + DCHECK_GT(output_size[i], 0); } // Generate pooling sequence. std::vector<int64> height_cum_seq; std::vector<int64> width_cum_seq; - if (deterministic_) { - if (pooling_region_generated_) { - height_cum_seq = height_cum_seq_; - width_cum_seq = width_cum_seq_; - } else { - height_cum_seq = GeneratePoolingSequence( - input_size_[1], output_size_[1], &generator_, pseudo_random_); - width_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2], - &generator_, pseudo_random_); - mutex_lock lock(mu_); - height_cum_seq_ = height_cum_seq; - width_cum_seq_ = width_cum_seq; - pooling_region_generated_ = true; - } - } else { - height_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1], - &generator_, pseudo_random_); - width_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2], - &generator_, pseudo_random_); - } + GuardedPhiloxRandom generator; + generator.Init(seed_, seed2_); + height_cum_seq = GeneratePoolingSequence(input_size[1], output_size[1], + &generator, pseudo_random_); + width_cum_seq = GeneratePoolingSequence(input_size[2], output_size[2], + &generator, pseudo_random_); // Prepare output. Tensor* output_tensor = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output( - 0, TensorShape({output_size_[0], output_size_[1], - output_size_[2], output_size_[3]}), - &output_tensor)); + OP_REQUIRES_OK(context, context->allocate_output( + 0, + TensorShape({output_size[0], output_size[1], + output_size[2], output_size[3]}), + &output_tensor)); Tensor* output_height_seq_tensor = nullptr; OP_REQUIRES_OK( context, @@ -122,12 +122,11 @@ class FractionalMaxPoolOp : public OpKernel { 2, TensorShape({static_cast<int64>(width_cum_seq.size())}), &output_width_seq_tensor)); - ConstEigenMatrixMap in_mat( - tensor_in.flat<T>().data(), input_size_[3], - input_size_[2] * input_size_[1] * input_size_[0]); + ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), input_size[3], + input_size[2] * input_size[1] * input_size[0]); - EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size_[3], - output_size_[2] * output_size_[1] * output_size_[0]); + EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size[3], + output_size[2] * output_size[1] * output_size[0]); // Initializes the output tensor with MIN<T>. output_tensor->flat<T>().setConstant(Eigen::NumTraits<T>::lowest()); @@ -149,9 +148,9 @@ class FractionalMaxPoolOp : public OpKernel { // 1: height / row // 2: width / col // 3: depth / channel - const int64 height_max = input_size_[1] - 1; - const int64 width_max = input_size_[2] - 1; - for (int64 b = 0; b < input_size_[0]; ++b) { + const int64 height_max = input_size[1] - 1; + const int64 width_max = input_size[2] - 1; + for (int64 b = 0; b < input_size[0]; ++b) { // height sequence. for (int64 hs = 0; hs < height_cum_seq.size() - 1; ++hs) { // height start and end. @@ -163,7 +162,7 @@ class FractionalMaxPoolOp : public OpKernel { // width sequence. for (int64 ws = 0; ws < width_cum_seq.size() - 1; ++ws) { const int64 out_offset = - (b * output_size_[1] + hs) * output_size_[2] + ws; + (b * output_size[1] + hs) * output_size[2] + ws; // width start and end. const int64 width_start = width_cum_seq[ws]; int64 width_end = @@ -172,7 +171,7 @@ class FractionalMaxPoolOp : public OpKernel { for (int64 h = height_start; h <= height_end; ++h) { for (int64 w = width_start; w <= width_end; ++w) { const int64 in_offset = - (b * input_size_[1] + h) * input_size_[2] + w; + (b * input_size[1] + h) * input_size[2] + w; out_mat.col(out_offset) = out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset)); } @@ -184,18 +183,11 @@ class FractionalMaxPoolOp : public OpKernel { private: bool deterministic_; - // meaningful only when deterministic_ is true. - mutex mu_; - std::vector<int64> height_cum_seq_; - std::vector<int64> width_cum_seq_; - bool pooling_region_generated_; - - std::vector<int32> input_size_; - std::vector<int32> output_size_; + int64 seed_; + int64 seed2_; std::vector<float> pooling_ratio_; bool pseudo_random_; bool overlapping_; - GuardedPhiloxRandom generator_; }; #define REGISTER_FRACTIONALMAXPOOL(type) \ @@ -243,15 +235,13 @@ class FractionalMaxPoolGradOp : public OpKernel { // Just to make it similar to FractionalMaxPoolOp. constexpr int tensor_in_and_out_dims = 4; - std::vector<int64> input_size; - std::vector<int64> output_size; - input_size.reserve(tensor_in_and_out_dims); + std::vector<int64> input_size(tensor_in_and_out_dims); + std::vector<int64> output_size(tensor_in_and_out_dims); for (int i = 0; i < tensor_in_and_out_dims; ++i) { - input_size.push_back(tensor_in.dim_size(i)); + input_size[i] = tensor_in.dim_size(i); } - output_size.reserve(tensor_in_and_out_dims); for (int i = 0; i < tensor_in_and_out_dims; ++i) { - output_size.push_back(tensor_out.dim_size(i)); + output_size[i] = tensor_out.dim_size(i); } // --------- diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index d7403fe6ee..e5f65edd39 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -369,6 +369,7 @@ tf_py_test( srcs = ["fractional_avg_pool_op_test.py"], additional_deps = [ "//third_party/py/numpy", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:nn_grad", @@ -383,6 +384,7 @@ tf_py_test( srcs = ["fractional_max_pool_op_test.py"], additional_deps = [ "//third_party/py/numpy", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:nn_grad", diff --git a/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py index 48a51c8072..feec9934e4 100644 --- a/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py +++ b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py @@ -23,6 +23,8 @@ import math import numpy as np from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import nn_ops @@ -310,6 +312,35 @@ class FractionalAvgTest(test.TestCase): self._ValidateFractionalAvgPoolResult(rand_mat, [1, 2, 2, 1], pseudo_random, overlapping) + def testDifferentInputTensorShape(self): + """Runs the operation in one session with different input tensor shapes.""" + with self.test_session() as sess: + input_holder = array_ops.placeholder(dtypes.float32, + [None, None, None, 3]) + pooling_ratio = [1, 1.5, 1.5, 1] + pseudo_random = False + overlapping = False + p, r, c = nn_ops.fractional_avg_pool( + input_holder, + pooling_ratio, + pseudo_random, + overlapping, + deterministic=True, + seed=self._SEED, + seed2=self._SEED2) + # First run. + input_a = np.zeros([3, 32, 32, 3]) + actual, row_seq, col_seq = sess.run([p, r, c], {input_holder: input_a}) + expected = self._GetExpectedFractionalAvgPoolResult( + input_a, row_seq, col_seq, overlapping) + self.assertSequenceEqual(expected.shape, actual.shape) + # Second run. + input_b = np.zeros([4, 60, 60, 3]) + actual, row_seq, col_seq = sess.run([p, r, c], {input_holder: input_b}) + expected = self._GetExpectedFractionalAvgPoolResult( + input_b, row_seq, col_seq, overlapping) + self.assertSequenceEqual(expected.shape, actual.shape) + class FractionalAvgPoolGradTest(test.TestCase): """Tests for FractionalAvgPoolGrad. diff --git a/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py index d380c31de3..5983ae7759 100644 --- a/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py +++ b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py @@ -23,6 +23,8 @@ import math import numpy as np from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import nn_ops @@ -281,6 +283,35 @@ class FractionalMaxPoolTest(test.TestCase): self._ValidateFractionalMaxPoolResult(rand_mat, [1, 2, 2, 1], pseudo_random, overlapping) + def testDifferentInputTensorShape(self): + """Runs the operation in one session with different input tensor shapes.""" + with self.test_session() as sess: + input_holder = array_ops.placeholder(dtypes.float32, + [None, None, None, 3]) + pooling_ratio = [1, 1.5, 1.5, 1] + pseudo_random = False + overlapping = False + p, r, c = nn_ops.fractional_max_pool( + input_holder, + pooling_ratio, + pseudo_random, + overlapping, + deterministic=True, + seed=self._SEED, + seed2=self._SEED2) + # First run. + input_a = np.zeros([3, 32, 32, 3]) + actual, row_seq, col_seq = sess.run([p, r, c], {input_holder: input_a}) + expected = self._GetExpectedFractionalMaxPoolResult( + input_a, row_seq, col_seq, overlapping) + self.assertSequenceEqual(expected.shape, actual.shape) + # Second run. + input_b = np.zeros([4, 45, 45, 3]) + actual, row_seq, col_seq = sess.run([p, r, c], {input_holder: input_b}) + expected = self._GetExpectedFractionalMaxPoolResult( + input_b, row_seq, col_seq, overlapping) + self.assertSequenceEqual(expected.shape, actual.shape) + class FractionalMaxPoolGradTest(test.TestCase): """Tests for FractionalMaxPoolGrad. |