diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-23 10:11:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-23 10:11:22 -0700 |
commit | bb5fbbcb663c795dd7fc16e43a0eaaae53231fd9 (patch) | |
tree | 702f13489b23170ccf1149c01da4bdf501f5cf5a /tensorflow/core/ops | |
parent | 7d25d2d6c5db2269b6dba4cade6edaf7e8ddf6ba (diff) | |
parent | 33f57bd1311df97a25cd70784dfaafc8e44d07c4 (diff) |
Merge pull request #21715 from hsgkim:volume_patches
PiperOrigin-RevId: 214177065
Diffstat (limited to 'tensorflow/core/ops')
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index c24950643f..442686c92a 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -2595,6 +2595,116 @@ REGISTER_OP("ExtractImagePatches") // -------------------------------------------------------------------------- +// To enable rates, uncomment all lines commented below and use ksize_*_eff +// as the second parameter of all GetWindowedOutputSizeVerbose calls instead +// of ksize_*. +REGISTER_OP("ExtractVolumePatches") + .Input("input: T") + .Output("patches: T") + .Attr("ksizes: list(int) >= 5") + .Attr("strides: list(int) >= 5") + /* .Attr("rates: list(int) >= 5") */ + .Attr("T: realnumbertype") + .Attr(GetPaddingAttrString()) + .SetShapeFn([](InferenceContext* c) { + ShapeHandle input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); + + std::vector<int32> ksizes; + TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes)); + if (ksizes.size() != 5) { + return errors::InvalidArgument( + "ExtractVolumePatches requires the ksizes attribute to contain 5 " + "values, but got: ", + ksizes.size()); + } + + std::vector<int32> strides; + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + if (strides.size() != 5) { + return errors::InvalidArgument( + "ExtractVolumePatches requires the stride attribute to contain 5 " + "values, but got: ", + strides.size()); + } + + /* + // TODO(hsgkim): Enable rates. + // See extract_volume_patches_op.cc for why rates are disabled now. + + std::vector<int32> rates; + TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates)); + if (rates.size() != 5) { + return errors::InvalidArgument( + "ExtractVolumePatches requires the rates attribute to contain 5 " + "values, but got: ", + rates.size()); + } + */ + + int32 ksize_planes = ksizes[1]; + int32 ksize_rows = ksizes[2]; + int32 ksize_cols = ksizes[3]; + + int32 stride_planes = strides[1]; + int32 stride_rows = strides[2]; + int32 stride_cols = strides[3]; + + /* + int32 rate_planes = rates[1]; + int32 rate_rows = rates[2]; + int32 rate_cols = rates[3]; + + int32 ksize_planes_eff = ksize_planes + + (ksize_planes - 1) * (rate_planes - 1); + int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1); + int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1); + */ + + DimensionHandle batch_size_dim = c->Dim(input_shape, 0); + DimensionHandle in_planes_dim = c->Dim(input_shape, 1); + DimensionHandle in_rows_dim = c->Dim(input_shape, 2); + DimensionHandle in_cols_dim = c->Dim(input_shape, 3); + DimensionHandle output_depth_dim; + TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4), + ksize_planes * ksize_rows * ksize_cols, + &output_depth_dim)); + + if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) || + !c->ValueKnown(in_cols_dim)) { + ShapeHandle output_shape = + c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim, output_depth_dim}); + c->set_output(0, output_shape); + return Status::OK(); + } + auto in_planes = c->Value(in_planes_dim); + auto in_rows = c->Value(in_rows_dim); + auto in_cols = c->Value(in_cols_dim); + + Padding padding; + TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); + + int64 output_planes, output_rows, output_cols; + int64 padding_before, padding_after; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_planes, ksize_planes, stride_planes, padding, &output_planes, + &padding_before, &padding_after)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_rows, ksize_rows, stride_rows, padding, &output_rows, + &padding_before, &padding_after)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_cols, ksize_cols, stride_cols, padding, &output_cols, + &padding_before, &padding_after)); + ShapeHandle output_shape = + c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols, + output_depth_dim}); + c->set_output(0, output_shape); + return Status::OK(); + }); + +// -------------------------------------------------------------------------- + REGISTER_OP("Bitcast") .Input("input: T") .Output("output: type") |