aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-23 10:11:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-23 10:11:22 -0700
commitbb5fbbcb663c795dd7fc16e43a0eaaae53231fd9 (patch)
tree702f13489b23170ccf1149c01da4bdf501f5cf5a /tensorflow/core/ops
parent7d25d2d6c5db2269b6dba4cade6edaf7e8ddf6ba (diff)
parent33f57bd1311df97a25cd70784dfaafc8e44d07c4 (diff)
Merge pull request #21715 from hsgkim:volume_patches
PiperOrigin-RevId: 214177065
Diffstat (limited to 'tensorflow/core/ops')
-rw-r--r--tensorflow/core/ops/array_ops.cc110
1 files changed, 110 insertions, 0 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index c24950643f..442686c92a 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -2595,6 +2595,116 @@ REGISTER_OP("ExtractImagePatches")
// --------------------------------------------------------------------------
+// To enable rates, uncomment all lines commented below and use ksize_*_eff
+// as the second parameter of all GetWindowedOutputSizeVerbose calls instead
+// of ksize_*.
+REGISTER_OP("ExtractVolumePatches")
+ .Input("input: T")
+ .Output("patches: T")
+ .Attr("ksizes: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ /* .Attr("rates: list(int) >= 5") */
+ .Attr("T: realnumbertype")
+ .Attr(GetPaddingAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle input_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
+
+ std::vector<int32> ksizes;
+ TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
+ if (ksizes.size() != 5) {
+ return errors::InvalidArgument(
+ "ExtractVolumePatches requires the ksizes attribute to contain 5 "
+ "values, but got: ",
+ ksizes.size());
+ }
+
+ std::vector<int32> strides;
+ TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
+ if (strides.size() != 5) {
+ return errors::InvalidArgument(
+ "ExtractVolumePatches requires the stride attribute to contain 5 "
+ "values, but got: ",
+ strides.size());
+ }
+
+ /*
+ // TODO(hsgkim): Enable rates.
+ // See extract_volume_patches_op.cc for why rates are disabled now.
+
+ std::vector<int32> rates;
+ TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
+ if (rates.size() != 5) {
+ return errors::InvalidArgument(
+ "ExtractVolumePatches requires the rates attribute to contain 5 "
+ "values, but got: ",
+ rates.size());
+ }
+ */
+
+ int32 ksize_planes = ksizes[1];
+ int32 ksize_rows = ksizes[2];
+ int32 ksize_cols = ksizes[3];
+
+ int32 stride_planes = strides[1];
+ int32 stride_rows = strides[2];
+ int32 stride_cols = strides[3];
+
+ /*
+ int32 rate_planes = rates[1];
+ int32 rate_rows = rates[2];
+ int32 rate_cols = rates[3];
+
+ int32 ksize_planes_eff = ksize_planes +
+ (ksize_planes - 1) * (rate_planes - 1);
+ int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
+ int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
+ */
+
+ DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
+ DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
+ DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
+ DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
+ DimensionHandle output_depth_dim;
+ TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4),
+ ksize_planes * ksize_rows * ksize_cols,
+ &output_depth_dim));
+
+ if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) ||
+ !c->ValueKnown(in_cols_dim)) {
+ ShapeHandle output_shape =
+ c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
+ InferenceContext::kUnknownDim, output_depth_dim});
+ c->set_output(0, output_shape);
+ return Status::OK();
+ }
+ auto in_planes = c->Value(in_planes_dim);
+ auto in_rows = c->Value(in_rows_dim);
+ auto in_cols = c->Value(in_cols_dim);
+
+ Padding padding;
+ TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
+
+ int64 output_planes, output_rows, output_cols;
+ int64 padding_before, padding_after;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_planes, ksize_planes, stride_planes, padding, &output_planes,
+ &padding_before, &padding_after));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_rows, ksize_rows, stride_rows, padding, &output_rows,
+ &padding_before, &padding_after));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_cols, ksize_cols, stride_cols, padding, &output_cols,
+ &padding_before, &padding_after));
+ ShapeHandle output_shape =
+ c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
+ output_depth_dim});
+ c->set_output(0, output_shape);
+ return Status::OK();
+ });
+
+// --------------------------------------------------------------------------
+
REGISTER_OP("Bitcast")
.Input("input: T")
.Output("output: type")