author    A. Unique TensorFlower <gardener@tensorflow.org> 2018-01-03 15:10:58 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-01-03 15:15:02 -0800
commit    5f0d3395d4c61000cf0cfb3dc681177147be938d (patch)
tree      83185514a44c21a15356a4d46c499cb96faa49d5
parent    2f83be3379e28fb2732a9f22034e33dbfdf37c77 (diff)
Fix tf.nn.fractional_max_pool so that its output batch size matches the input batch size when the same op is fed inputs with different batch sizes. Fixes #14985.
PiperOrigin-RevId: 180724096
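
Why the bug happened: with deterministic=True, both kernels cached the pooling sequences after the first Compute() call, and input_size_/output_size_ were member vectors that only ever held the first run's dimensions at indices 0-3, so every later run allocated an output with the first run's batch size. The fix makes the sizes and pooling sequences local to Compute(), seeding a fresh GuardedPhiloxRandom from seed/seed2 read once in the constructor (with random::New64() filling in unset seeds when deterministic=True). A minimal TF 1.x-style repro sketch of the reported behavior (variable names are illustrative, not from the patch):

import numpy as np
import tensorflow as tf  # assumes the TF 1.x graph/session API

# One fractional_max_pool op over a placeholder with an unknown batch size.
x = tf.placeholder(tf.float32, [None, None, None, 3])
p, row_seq, col_seq = tf.nn.fractional_max_pool(
    x, pooling_ratio=[1, 1.5, 1.5, 1], deterministic=True, seed=1, seed2=2)

with tf.Session() as sess:
    # First run: batch size 3.
    out_a = sess.run(p, {x: np.zeros([3, 32, 32, 3], np.float32)})
    # Second run: batch size 4. Per issue #14985, before this change the
    # kernel reused the shapes cached on the first run, so out_b came back
    # with batch size 3; after this change it is 4.
    out_b = sess.run(p, {x: np.zeros([4, 32, 32, 3], np.float32)})
    print(out_a.shape[0], out_b.shape[0])  # expect: 3 4
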
-rw-r--r--  tensorflow/core/kernels/BUILD  1
-rw-r--r--  tensorflow/core/kernels/fractional_avg_pool_op.cc  94
-rw-r--r--  tensorflow/core/kernels/fractional_max_pool_op.cc  104
-rw-r--r--  tensorflow/python/kernel_tests/BUILD  2
-rw-r--r--  tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py  31
-rw-r--r--  tensorflow/python/kernel_tests/fractional_max_pool_op_test.py  31
6 files changed, 155 insertions, 108 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index ae39c4522d..a1b62179f5 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3370,6 +3370,7 @@ tf_kernel_library(
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
"//third_party/eigen3",
],
diff --git a/tensorflow/core/kernels/fractional_avg_pool_op.cc b/tensorflow/core/kernels/fractional_avg_pool_op.cc
index bfdb7b4a1e..47f4189c30 100644
--- a/tensorflow/core/kernels/fractional_avg_pool_op.cc
+++ b/tensorflow/core/kernels/fractional_avg_pool_op.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/guarded_philox_random.h"
@@ -47,9 +48,20 @@ class FractionalAvgPoolOp : public OpKernel {
errors::Unimplemented("Fractional average pooling is not yet "
"supported on the batch nor channel dimension."));
OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_));
- pooling_region_generated_ = false;
- // Initialize philox random generator.
- OP_REQUIRES_OK(context, generator_.Init(context));
+ OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
+ OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
+ if (deterministic_) {
+ // If both seeds are not set when deterministic_ is true, force set seeds.
+ if ((seed_ == 0) && (seed2_ == 0)) {
+ seed_ = random::New64();
+ seed2_ = random::New64();
+ }
+ } else {
+ OP_REQUIRES(
+ context, (seed_ == 0) && (seed2_ == 0),
+ errors::InvalidArgument(
+ "Both seed and seed2 should be 0 if deterministic is false."));
+ }
}
void Compute(OpKernelContext* context) override {
@@ -64,47 +76,35 @@ class FractionalAvgPoolOp : public OpKernel {
OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims,
errors::InvalidArgument("tensor_in must be 4-dimensional"));
+ std::vector<int> input_size(tensor_in_and_out_dims);
+ std::vector<int> output_size(tensor_in_and_out_dims);
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
- input_size_.push_back(tensor_in.dim_size(i));
+ input_size[i] = tensor_in.dim_size(i);
}
// Output size.
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
- output_size_.push_back(
- static_cast<int>(floor(input_size_[i] / pooling_ratio_[i])));
- DCHECK_GT(output_size_[i], 0);
+ output_size[i] =
+ static_cast<int>(floor(input_size[i] / pooling_ratio_[i]));
+ DCHECK_GT(output_size[i], 0);
}
// Generate pooling sequence.
std::vector<int64> row_cum_seq;
std::vector<int64> col_cum_seq;
- if (deterministic_) {
- if (pooling_region_generated_) {
- row_cum_seq = row_cum_seq_;
- col_cum_seq = col_cum_seq_;
- } else {
- row_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1],
- &generator_, pseudo_random_);
- col_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
- &generator_, pseudo_random_);
- mutex_lock lock(mu_);
- row_cum_seq_ = row_cum_seq;
- col_cum_seq_ = col_cum_seq;
- pooling_region_generated_ = true;
- }
- } else {
- row_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1],
- &generator_, pseudo_random_);
- col_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
- &generator_, pseudo_random_);
- }
+ GuardedPhiloxRandom generator;
+ generator.Init(seed_, seed2_);
+ row_cum_seq = GeneratePoolingSequence(input_size[1], output_size[1],
+ &generator, pseudo_random_);
+ col_cum_seq = GeneratePoolingSequence(input_size[2], output_size[2],
+ &generator, pseudo_random_);
// Prepare output.
Tensor* output_tensor = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(
- 0, TensorShape({output_size_[0], output_size_[1],
- output_size_[2], output_size_[3]}),
- &output_tensor));
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0,
+ TensorShape({output_size[0], output_size[1],
+ output_size[2], output_size[3]}),
+ &output_tensor));
Tensor* output_row_seq_tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(
@@ -116,12 +116,11 @@ class FractionalAvgPoolOp : public OpKernel {
2, TensorShape({static_cast<int64>(col_cum_seq.size())}),
&output_col_seq_tensor));
- ConstEigenMatrixMap in_mat(
- tensor_in.flat<T>().data(), input_size_[3],
- input_size_[2] * input_size_[1] * input_size_[0]);
+ ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), input_size[3],
+ input_size[2] * input_size[1] * input_size[0]);
- EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size_[3],
- output_size_[2] * output_size_[1] * output_size_[0]);
+ EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size[3],
+ output_size[2] * output_size[1] * output_size[0]);
// out_count corresponds to number of elements in each pooling cell.
Eigen::Matrix<T, Eigen::Dynamic, 1> out_count(out_mat.cols());
@@ -146,9 +145,9 @@ class FractionalAvgPoolOp : public OpKernel {
// 1: row / row
// 2: col / col
// 3: depth / channel
- const int64 row_max = input_size_[1] - 1;
- const int64 col_max = input_size_[2] - 1;
- for (int64 b = 0; b < input_size_[0]; ++b) {
+ const int64 row_max = input_size[1] - 1;
+ const int64 col_max = input_size[2] - 1;
+ for (int64 b = 0; b < input_size[0]; ++b) {
// row sequence.
for (int64 hs = 0; hs < row_cum_seq.size() - 1; ++hs) {
// row start and end.
@@ -160,7 +159,7 @@ class FractionalAvgPoolOp : public OpKernel {
// col sequence.
for (int64 ws = 0; ws < col_cum_seq.size() - 1; ++ws) {
const int64 out_offset =
- (b * output_size_[1] + hs) * output_size_[2] + ws;
+ (b * output_size[1] + hs) * output_size[2] + ws;
// col start and end.
const int64 col_start = col_cum_seq[ws];
int64 col_end =
@@ -169,7 +168,7 @@ class FractionalAvgPoolOp : public OpKernel {
for (int64 h = row_start; h <= row_end; ++h) {
for (int64 w = col_start; w <= col_end; ++w) {
const int64 in_offset =
- (b * input_size_[1] + h) * input_size_[2] + w;
+ (b * input_size[1] + h) * input_size[2] + w;
out_mat.col(out_offset) += in_mat.col(in_offset);
out_count(out_offset)++;
}
@@ -183,18 +182,11 @@ class FractionalAvgPoolOp : public OpKernel {
private:
bool deterministic_;
- // meaningful only when deterministic_ is true.
- mutex mu_;
- std::vector<int64> row_cum_seq_;
- std::vector<int64> col_cum_seq_;
- bool pooling_region_generated_;
-
- std::vector<int32> input_size_;
- std::vector<int32> output_size_;
+ int64 seed_;
+ int64 seed2_;
std::vector<float> pooling_ratio_;
bool pseudo_random_;
bool overlapping_;
- GuardedPhiloxRandom generator_;
};
#define REGISTER_FRACTIONALAVGPOOL(type) \
diff --git a/tensorflow/core/kernels/fractional_max_pool_op.cc b/tensorflow/core/kernels/fractional_max_pool_op.cc
index 33d73c8477..cf580adab2 100644
--- a/tensorflow/core/kernels/fractional_max_pool_op.cc
+++ b/tensorflow/core/kernels/fractional_max_pool_op.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/guarded_philox_random.h"
@@ -50,9 +51,20 @@ class FractionalMaxPoolOp : public OpKernel {
"supported on the batch nor channel dimension."));
OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_));
- pooling_region_generated_ = false;
- // Initialize philox random generator.
- OP_REQUIRES_OK(context, generator_.Init(context));
+ OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
+ OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
+ if (deterministic_) {
+ // If both seeds are not set when deterministic_ is true, force set seeds.
+ if ((seed_ == 0) && (seed2_ == 0)) {
+ seed_ = random::New64();
+ seed2_ = random::New64();
+ }
+ } else {
+ OP_REQUIRES(
+ context, (seed_ == 0) && (seed2_ == 0),
+ errors::InvalidArgument(
+ "Both seed and seed2 should be 0 if deterministic is false."));
+ }
}
void Compute(OpKernelContext* context) override {
@@ -67,49 +79,37 @@ class FractionalMaxPoolOp : public OpKernel {
OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims,
errors::InvalidArgument("tensor_in must be 4-dimensional"));
+ std::vector<int> input_size(tensor_in_and_out_dims);
+ std::vector<int> output_size(tensor_in_and_out_dims);
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
- input_size_.push_back(tensor_in.dim_size(i));
+ input_size[i] = tensor_in.dim_size(i);
}
// Output size.
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
// This must match the same logic in the shape function in
// core/ops/nn_ops.cc.
- output_size_.push_back(
- static_cast<int>(floor(input_size_[i] / pooling_ratio_[i])));
- DCHECK_GT(output_size_[i], 0);
+ output_size[i] =
+ static_cast<int>(floor(input_size[i] / pooling_ratio_[i]));
+ DCHECK_GT(output_size[i], 0);
}
// Generate pooling sequence.
std::vector<int64> height_cum_seq;
std::vector<int64> width_cum_seq;
- if (deterministic_) {
- if (pooling_region_generated_) {
- height_cum_seq = height_cum_seq_;
- width_cum_seq = width_cum_seq_;
- } else {
- height_cum_seq = GeneratePoolingSequence(
- input_size_[1], output_size_[1], &generator_, pseudo_random_);
- width_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
- &generator_, pseudo_random_);
- mutex_lock lock(mu_);
- height_cum_seq_ = height_cum_seq;
- width_cum_seq_ = width_cum_seq;
- pooling_region_generated_ = true;
- }
- } else {
- height_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1],
- &generator_, pseudo_random_);
- width_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
- &generator_, pseudo_random_);
- }
+ GuardedPhiloxRandom generator;
+ generator.Init(seed_, seed2_);
+ height_cum_seq = GeneratePoolingSequence(input_size[1], output_size[1],
+ &generator, pseudo_random_);
+ width_cum_seq = GeneratePoolingSequence(input_size[2], output_size[2],
+ &generator, pseudo_random_);
// Prepare output.
Tensor* output_tensor = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(
- 0, TensorShape({output_size_[0], output_size_[1],
- output_size_[2], output_size_[3]}),
- &output_tensor));
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0,
+ TensorShape({output_size[0], output_size[1],
+ output_size[2], output_size[3]}),
+ &output_tensor));
Tensor* output_height_seq_tensor = nullptr;
OP_REQUIRES_OK(
context,
@@ -122,12 +122,11 @@ class FractionalMaxPoolOp : public OpKernel {
2, TensorShape({static_cast<int64>(width_cum_seq.size())}),
&output_width_seq_tensor));
- ConstEigenMatrixMap in_mat(
- tensor_in.flat<T>().data(), input_size_[3],
- input_size_[2] * input_size_[1] * input_size_[0]);
+ ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), input_size[3],
+ input_size[2] * input_size[1] * input_size[0]);
- EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size_[3],
- output_size_[2] * output_size_[1] * output_size_[0]);
+ EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size[3],
+ output_size[2] * output_size[1] * output_size[0]);
// Initializes the output tensor with MIN<T>.
output_tensor->flat<T>().setConstant(Eigen::NumTraits<T>::lowest());
@@ -149,9 +148,9 @@ class FractionalMaxPoolOp : public OpKernel {
// 1: height / row
// 2: width / col
// 3: depth / channel
- const int64 height_max = input_size_[1] - 1;
- const int64 width_max = input_size_[2] - 1;
- for (int64 b = 0; b < input_size_[0]; ++b) {
+ const int64 height_max = input_size[1] - 1;
+ const int64 width_max = input_size[2] - 1;
+ for (int64 b = 0; b < input_size[0]; ++b) {
// height sequence.
for (int64 hs = 0; hs < height_cum_seq.size() - 1; ++hs) {
// height start and end.
@@ -163,7 +162,7 @@ class FractionalMaxPoolOp : public OpKernel {
// width sequence.
for (int64 ws = 0; ws < width_cum_seq.size() - 1; ++ws) {
const int64 out_offset =
- (b * output_size_[1] + hs) * output_size_[2] + ws;
+ (b * output_size[1] + hs) * output_size[2] + ws;
// width start and end.
const int64 width_start = width_cum_seq[ws];
int64 width_end =
@@ -172,7 +171,7 @@ class FractionalMaxPoolOp : public OpKernel {
for (int64 h = height_start; h <= height_end; ++h) {
for (int64 w = width_start; w <= width_end; ++w) {
const int64 in_offset =
- (b * input_size_[1] + h) * input_size_[2] + w;
+ (b * input_size[1] + h) * input_size[2] + w;
out_mat.col(out_offset) =
out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset));
}
@@ -184,18 +183,11 @@ class FractionalMaxPoolOp : public OpKernel {
private:
bool deterministic_;
- // meaningful only when deterministic_ is true.
- mutex mu_;
- std::vector<int64> height_cum_seq_;
- std::vector<int64> width_cum_seq_;
- bool pooling_region_generated_;
-
- std::vector<int32> input_size_;
- std::vector<int32> output_size_;
+ int64 seed_;
+ int64 seed2_;
std::vector<float> pooling_ratio_;
bool pseudo_random_;
bool overlapping_;
- GuardedPhiloxRandom generator_;
};
#define REGISTER_FRACTIONALMAXPOOL(type) \
@@ -243,15 +235,13 @@ class FractionalMaxPoolGradOp : public OpKernel {
// Just to make it similar to FractionalMaxPoolOp.
constexpr int tensor_in_and_out_dims = 4;
- std::vector<int64> input_size;
- std::vector<int64> output_size;
- input_size.reserve(tensor_in_and_out_dims);
+ std::vector<int64> input_size(tensor_in_and_out_dims);
+ std::vector<int64> output_size(tensor_in_and_out_dims);
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
- input_size.push_back(tensor_in.dim_size(i));
+ input_size[i] = tensor_in.dim_size(i);
}
- output_size.reserve(tensor_in_and_out_dims);
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
- output_size.push_back(tensor_out.dim_size(i));
+ output_size[i] = tensor_out.dim_size(i);
}
// ---------
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index d7403fe6ee..e5f65edd39 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -369,6 +369,7 @@ tf_py_test(
srcs = ["fractional_avg_pool_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:nn_grad",
@@ -383,6 +384,7 @@ tf_py_test(
srcs = ["fractional_max_pool_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:nn_grad",
diff --git a/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
index 48a51c8072..feec9934e4 100644
--- a/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
+++ b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
@@ -23,6 +23,8 @@ import math
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_ops
@@ -310,6 +312,35 @@ class FractionalAvgTest(test.TestCase):
self._ValidateFractionalAvgPoolResult(rand_mat, [1, 2, 2, 1], pseudo_random,
overlapping)
+ def testDifferentInputTensorShape(self):
+ """Runs the operation in one session with different input tensor shapes."""
+ with self.test_session() as sess:
+ input_holder = array_ops.placeholder(dtypes.float32,
+ [None, None, None, 3])
+ pooling_ratio = [1, 1.5, 1.5, 1]
+ pseudo_random = False
+ overlapping = False
+ p, r, c = nn_ops.fractional_avg_pool(
+ input_holder,
+ pooling_ratio,
+ pseudo_random,
+ overlapping,
+ deterministic=True,
+ seed=self._SEED,
+ seed2=self._SEED2)
+ # First run.
+ input_a = np.zeros([3, 32, 32, 3])
+ actual, row_seq, col_seq = sess.run([p, r, c], {input_holder: input_a})
+ expected = self._GetExpectedFractionalAvgPoolResult(
+ input_a, row_seq, col_seq, overlapping)
+ self.assertSequenceEqual(expected.shape, actual.shape)
+ # Second run.
+ input_b = np.zeros([4, 60, 60, 3])
+ actual, row_seq, col_seq = sess.run([p, r, c], {input_holder: input_b})
+ expected = self._GetExpectedFractionalAvgPoolResult(
+ input_b, row_seq, col_seq, overlapping)
+ self.assertSequenceEqual(expected.shape, actual.shape)
+
class FractionalAvgPoolGradTest(test.TestCase):
"""Tests for FractionalAvgPoolGrad.
diff --git a/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
index d380c31de3..5983ae7759 100644
--- a/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
+++ b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
@@ -23,6 +23,8 @@ import math
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_ops
@@ -281,6 +283,35 @@ class FractionalMaxPoolTest(test.TestCase):
self._ValidateFractionalMaxPoolResult(rand_mat, [1, 2, 2, 1], pseudo_random,
overlapping)
+ def testDifferentInputTensorShape(self):
+ """Runs the operation in one session with different input tensor shapes."""
+ with self.test_session() as sess:
+ input_holder = array_ops.placeholder(dtypes.float32,
+ [None, None, None, 3])
+ pooling_ratio = [1, 1.5, 1.5, 1]
+ pseudo_random = False
+ overlapping = False
+ p, r, c = nn_ops.fractional_max_pool(
+ input_holder,
+ pooling_ratio,
+ pseudo_random,
+ overlapping,
+ deterministic=True,
+ seed=self._SEED,
+ seed2=self._SEED2)
+ # First run.
+ input_a = np.zeros([3, 32, 32, 3])
+ actual, row_seq, col_seq = sess.run([p, r, c], {input_holder: input_a})
+ expected = self._GetExpectedFractionalMaxPoolResult(
+ input_a, row_seq, col_seq, overlapping)
+ self.assertSequenceEqual(expected.shape, actual.shape)
+ # Second run.
+ input_b = np.zeros([4, 45, 45, 3])
+ actual, row_seq, col_seq = sess.run([p, r, c], {input_holder: input_b})
+ expected = self._GetExpectedFractionalMaxPoolResult(
+ input_b, row_seq, col_seq, overlapping)
+ self.assertSequenceEqual(expected.shape, actual.shape)
+
class FractionalMaxPoolGradTest(test.TestCase):
"""Tests for FractionalMaxPoolGrad.