aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/periodic_resample
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2018-01-24 10:02:35 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-24 10:06:06 -0800
commitd9f93c42a50b1f1401d9c186eac0ae8dc9093c3b (patch)
tree178d1a692f56580c266139642b5a1d0d155c477e /tensorflow/contrib/periodic_resample
parent7b62a71e2d46c148df7d5704972f4592bc5e0f1b (diff)
Merge changes from github.
PiperOrigin-RevId: 183100142
Diffstat (limited to 'tensorflow/contrib/periodic_resample')
-rw-r--r--tensorflow/contrib/periodic_resample/BUILD18
-rw-r--r--tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h11
-rw-r--r--tensorflow/contrib/periodic_resample/ops/array_ops.cc42
-rw-r--r--tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py16
4 files changed, 72 insertions, 15 deletions
diff --git a/tensorflow/contrib/periodic_resample/BUILD b/tensorflow/contrib/periodic_resample/BUILD
index 71582f9c9a..bd9078ae76 100644
--- a/tensorflow/contrib/periodic_resample/BUILD
+++ b/tensorflow/contrib/periodic_resample/BUILD
@@ -6,6 +6,7 @@ exports_files(["LICENSE"])
load(
"//tensorflow:tensorflow.bzl",
+ "py_test",
"tf_gen_op_libs",
"tf_custom_op_library",
"tf_custom_op_py_library",
@@ -64,11 +65,28 @@ py_library(
"python/__init__.py",
],
srcs_version = "PY2AND3",
+ tags = [
+ "notap",
+ ],
deps = [
":periodic_resample_op_py",
],
)
+py_test(
+ name = "periodic_resample_op_test",
+ srcs = ["python/kernel_tests/periodic_resample_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "notap",
+ ],
+ deps = [
+ ":init_py",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
# py_library(
# name = "periodic_resample_op_py",
# srcs = ["python/ops/periodic_resample_op.py"],
diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
index bef21f7a5c..ba410f025d 100644
--- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
+++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
@@ -100,6 +100,8 @@ template <class InputDataT,
desired_shape.size(), "."));
bool found = false;
+ const auto& input_tensor_shape = input_tensor.shape();
+
for (int i = 0; i < rank; ++i) {
// if (desired_shape(i) < 1) {
if (desired_shape[i] < 1) {
@@ -111,6 +113,15 @@ template <class InputDataT,
adjustable_dimension = i;
found = true;
} else {
+ OP_REQUIRES(
+ context, desired_shape[i] >= input_tensor_shape.dim_size(i),
+ tensorflow::errors::InvalidArgument(
+ "periodic_resample expects the size of non-adjustable "
+ "dimensions be at least as large as size of input tensor."
+ " Dimension ", i, " input tensor has size ",
+ input_tensor_shape.dim_size(i), ", desired shape has size ",
+ desired_shape[i], "."));
+
// target_dimensions[i] = desired_shape(i);
target_dimensions[i] = desired_shape[i];
new_sliced_size *= target_dimensions[i];
diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops.cc b/tensorflow/contrib/periodic_resample/ops/array_ops.cc
index c90fc06c7f..82bd796956 100644
--- a/tensorflow/contrib/periodic_resample/ops/array_ops.cc
+++ b/tensorflow/contrib/periodic_resample/ops/array_ops.cc
@@ -34,26 +34,40 @@ This function implements a slightly more generic version of the subpixel
convolutions found in this [paper](https://arxiv.org/abs/1609.05158).
The formula for computing the elements in the `output` tensor is as follows:
+
`T` = `values` tensor of rank `R`
+
`S` = desired `shape` of output tensor (vector of length `R`)
+
`P` = `output` tensor of rank `R`
- \((T_1,\ldots,T_R)\) = shape(`T`)
- \([S_1,\ldots,S_q,\ldots,S_R]\) = elements of vector `S`
- A single element in `S` is left unspecified (denoted \(S_q=-1\)).
- Let \(f_i\) denote the (possibly non-integer) factor that relates the original
- dimension to the desired dimensions, \(S_i=f_i T_i\), for \(i\neq q\) where
- \(f_i>0\).
+ \\((T_1,\\ldots,T_R)\\) = shape(`T`)
+
+ \\([S_1,\\ldots,S_q,\\ldots,S_R]\\) = elements of vector `S`
+
+ A single element in `S` is left unspecified (denoted \\(S_q=-1\\)).
+
+ Let \\(f_i\\) denote the (possibly non-integer) factor that relates the original
+ dimension to the desired dimensions, \\(S_i=f_i T_i\\), for \\(i\\neq q\\) where
+ \\(f_i>0\\).
+
Define the following:
- \(g_i=\lceil f_i\rceil\)
- \(t=\prod_i T_i\)
- \(s=\prod_{i\neq q} S_i\)
- \(S_q\) can then be defined as by \(S_q=\lfloor t/s\rfloor\).
+
+ \\(g_i=\\lceil f_i\\rceil\\)
+
+ \\(t=\\prod_i T_i\\)
+
+ \\(s=\\prod_{i\\neq q} S_i\\)
+
+ \\(S_q\\) can then be defined by \\(S_q=\\lfloor t/s\\rfloor\\).
The elements of the resulting tensor are defined as
- \(P_{s_1,\ldots,s_R}=T_{h_1,\ldots,h_q,\ldots,h_R}\).
- The \(h_i\) (\(i\neq q\)) are defined by \(h_i=\lfloor s_i/g_i\rfloor\).
- \(h_q=S_q\sum_{j\neq q}^{q-1}G_j \mathrm{mod}(s_j,g_j) + s_q\), where
- \(G_j=\prod_{i}^{j-1}g_i\) (\(G_0=1\)).
+
+ \\(P_{s_1,\\ldots,s_R}=T_{h_1,\\ldots,h_q,\\ldots,h_R}\\).
+
+ The \\(h_i\\) (\\(i\\neq q\\)) are defined by \\(h_i=\\lfloor s_i/g_i\\rfloor\\).
+
+ \\(h_q=S_q\\sum_{j\\neq q}^{q-1}G_j \\mathrm{mod}(s_j,g_j) + s_q\\), where
+ \\(G_j=\\prod_{i}^{j-1}g_i\\) (\\(G_0=1\\)).
One drawback of this method is that whenever the output dimensions are slightly
less than integer multiples of the input dimensions, many of the tensor elements
diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
index 1d727870f6..30a2077570 100644
--- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
+++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
@@ -19,8 +19,9 @@ from __future__ import division
from __future__ import print_function
import numpy
-import tensorflow
+
from tensorflow.contrib.periodic_resample import periodic_resample
+from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -96,6 +97,19 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
result = periodic_resample(input_tensor, desired_shape).eval()
self.assertAllEqual(result, output_tensor)
+ def testPeriodicResampleErrors(self):
+ input_tensor = numpy.zeros(shape=[1, 2, 2, 4])
+ with self.test_session():
+ variables.global_variables_initializer().run()
+ with self.assertRaisesWithPredicateMatch(
+ errors_impl.InvalidArgumentError,
+ 'Dimension 3 input tensor has size 4, desired shape has size 1'):
+ periodic_resample(input_tensor, [None, 4, 4, 1]).eval()
+ with self.assertRaisesWithPredicateMatch(
+ errors_impl.InvalidArgumentError,
+ '4, to be the same as the length of the desired shape, 3'):
+ periodic_resample(input_tensor, [None, 4, 4]).eval()
+
if __name__ == "__main__":
googletest.main()