aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-08-23 16:35:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-23 17:54:30 -0700
commitf6bc8cabbd3ac1fb3acc36d3edbdce672cae7d12 (patch)
treea8c81b0269e68f57606052ab5573743f86995be6 /tensorflow/core
parentade1672d60d861c58e1930e93a1b396b22e7a4d9 (diff)
Add shape_inference::ShapeHandle and shape_inference::DimensionHandle to
replace uses of const Shape* and const Dimension*. This change only adds a typedef and updates references. A later change will make DimensionHandle and ShapeHandle real types instead of typedefs (to further hide the pointer access). Change: 131118981
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc176
-rw-r--r--tensorflow/core/framework/common_shape_fns.h8
-rw-r--r--tensorflow/core/framework/common_shape_fns_test.cc32
-rw-r--r--tensorflow/core/framework/shape_inference.cc139
-rw-r--r--tensorflow/core/framework/shape_inference.h163
-rw-r--r--tensorflow/core/framework/shape_inference_test.cc58
-rw-r--r--tensorflow/core/framework/shape_inference_testutil.cc10
-rw-r--r--tensorflow/core/ops/array_ops.cc262
-rw-r--r--tensorflow/core/ops/candidate_sampling_ops.cc16
-rw-r--r--tensorflow/core/ops/control_flow_ops.cc16
-rw-r--r--tensorflow/core/ops/ctc_ops.cc32
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc92
-rw-r--r--tensorflow/core/ops/image_ops.cc56
-rw-r--r--tensorflow/core/ops/io_ops.cc28
-rw-r--r--tensorflow/core/ops/linalg_ops.cc107
-rw-r--r--tensorflow/core/ops/math_ops.cc98
-rw-r--r--tensorflow/core/ops/nn_ops.cc78
-rw-r--r--tensorflow/core/ops/parsing_ops.cc22
-rw-r--r--tensorflow/core/ops/random_ops.cc14
-rw-r--r--tensorflow/core/ops/sparse_ops.cc72
-rw-r--r--tensorflow/core/ops/state_ops.cc16
-rw-r--r--tensorflow/core/ops/string_ops.cc8
-rw-r--r--tensorflow/core/ops/training_ops.cc54
23 files changed, 779 insertions, 778 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 090d9804e2..c345d3c742 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -75,22 +75,22 @@ Status UnchangedShape(shape_inference::InferenceContext* c) {
}
Status MatMulShape(shape_inference::InferenceContext* c) {
- const Shape* a;
+ ShapeHandle a;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));
- const Shape* b;
+ ShapeHandle b;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));
bool transpose_a, transpose_b;
TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
- const Dimension* output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0);
- const Dimension* output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1);
+ DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0);
+ DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1);
// Validate that the inner shapes are compatible.
- const Dimension* inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1);
- const Dimension* inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0);
- const Dimension* merged;
+ DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1);
+ DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0);
+ DimensionHandle merged;
TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged));
c->set_output(0, c->Matrix(output_rows, output_cols));
@@ -98,7 +98,7 @@ Status MatMulShape(shape_inference::InferenceContext* c) {
}
Status BiasAddShape(shape_inference::InferenceContext* c) {
- const Shape* input_shape;
+ ShapeHandle input_shape;
// Fetch the data_format attribute, which may not exist.
string data_format;
@@ -110,9 +110,9 @@ Status BiasAddShape(shape_inference::InferenceContext* c) {
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
}
- const Shape* bias_shape;
+ ShapeHandle bias_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape));
- const Dimension* bias_dim = c->Dim(bias_shape, 0);
+ DimensionHandle bias_dim = c->Dim(bias_shape, 0);
// If rank unknown, return unknown shape.
if (!c->RankKnown(input_shape)) {
@@ -122,32 +122,32 @@ Status BiasAddShape(shape_inference::InferenceContext* c) {
// Output has the same shape as the input, and matches the length of
// the bias in its bias dimension.
- const Shape* output_shape;
+ ShapeHandle output_shape;
if (s.ok() && data_format == "NCHW") {
// Merge the length of bias_shape into the third to last dimension
- const Shape* first;
+ ShapeHandle first;
TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -3, &first));
- const Shape* last;
+ ShapeHandle last;
TF_RETURN_IF_ERROR(c->Subshape(input_shape, -2, &last));
- const Dimension* input_bias_dim = c->Dim(input_shape, -3);
- const Dimension* merged_bias_dim;
+ DimensionHandle input_bias_dim = c->Dim(input_shape, -3);
+ DimensionHandle merged_bias_dim;
TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
- const Shape* merged_bias = c->Vector(merged_bias_dim);
+ ShapeHandle merged_bias = c->Vector(merged_bias_dim);
- const Shape* temp;
+ ShapeHandle temp;
TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp));
TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape));
} else {
- const Shape* all_but_bias;
+ ShapeHandle all_but_bias;
TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias));
- const Dimension* input_bias_dim = c->Dim(input_shape, -1);
- const Dimension* merged_bias_dim;
+ DimensionHandle input_bias_dim = c->Dim(input_shape, -1);
+ DimensionHandle merged_bias_dim;
TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
- const Shape* merged_bias = c->Vector(merged_bias_dim);
+ ShapeHandle merged_bias = c->Vector(merged_bias_dim);
TF_RETURN_IF_ERROR(
c->Concatenate(all_but_bias, merged_bias, &output_shape));
}
@@ -157,7 +157,7 @@ Status BiasAddShape(shape_inference::InferenceContext* c) {
}
Status BiasAddGradShape(shape_inference::InferenceContext* c) {
- const Shape* input_shape;
+ ShapeHandle input_shape;
// Fetch the data_format attribute, which may not exist.
string data_format;
Status s = c->GetAttr("data_format", &data_format);
@@ -174,9 +174,9 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) {
}
Status Conv2DShape(shape_inference::InferenceContext* c) {
- const Shape* input_shape;
+ ShapeHandle input_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
- const Shape* filter_shape;
+ ShapeHandle filter_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));
string data_format;
@@ -205,12 +205,12 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
stride_cols = strides[2];
}
- const Dimension* batch_size_dim = c->Dim(input_shape, 0);
- const Dimension* in_rows_dim = c->Dim(input_shape, 1);
- const Dimension* in_cols_dim = c->Dim(input_shape, 2);
- const Dimension* filter_rows_dim = c->Dim(filter_shape, 0);
- const Dimension* filter_cols_dim = c->Dim(filter_shape, 1);
- const Dimension* output_depth_dim = c->Dim(filter_shape, 3);
+ DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
+ DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
+ DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
+ DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
+ DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
+ DimensionHandle output_depth_dim = c->Dim(filter_shape, 3);
// At the moment we need to know the values of several fields.
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
@@ -223,7 +223,7 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
auto filter_rows = c->Value(filter_rows_dim);
auto filter_cols = c->Value(filter_cols_dim);
- const Dimension* unused;
+ DimensionHandle unused;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(input_shape, 3), c->Dim(filter_shape, 2), &unused));
@@ -239,7 +239,7 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
in_cols, filter_cols, stride_cols, padding, &output_cols, &padding_before,
&padding_after));
- const Shape* output_shape;
+ ShapeHandle output_shape;
if (data_format == "NCHW") {
output_shape = c->MakeShape(
{batch_size_dim, output_depth_dim, output_rows, output_cols});
@@ -253,9 +253,9 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
}
Status Conv3DShape(shape_inference::InferenceContext* c) {
- const Shape* input_shape;
+ ShapeHandle input_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
- const Shape* filter_shape;
+ ShapeHandle filter_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape));
std::vector<int32> strides;
@@ -270,15 +270,15 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
int32 stride_rows = strides[2];
int32 stride_cols = strides[3];
- const Dimension* batch_size_dim = c->Dim(input_shape, 0);
- const Dimension* in_planes_dim = c->Dim(input_shape, 1);
- const Dimension* in_rows_dim = c->Dim(input_shape, 2);
- const Dimension* in_cols_dim = c->Dim(input_shape, 3);
+ DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
+ DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
+ DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
+ DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
- const Dimension* filter_planes_dim = c->Dim(filter_shape, 0);
- const Dimension* filter_rows_dim = c->Dim(filter_shape, 1);
- const Dimension* filter_cols_dim = c->Dim(filter_shape, 2);
- const Dimension* output_depth_dim = c->Dim(filter_shape, 4);
+ DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0);
+ DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1);
+ DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2);
+ DimensionHandle output_depth_dim = c->Dim(filter_shape, 4);
// At the moment we need to know the values of several fields.
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_planes_dim, "in_planes"));
@@ -295,7 +295,7 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
auto filter_rows = c->Value(filter_rows_dim);
auto filter_cols = c->Value(filter_cols_dim);
- const Dimension* unused;
+ DimensionHandle unused;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(input_shape, 4), c->Dim(filter_shape, 3), &unused));
@@ -314,7 +314,7 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
in_cols, filter_cols, stride_cols, padding, &output_cols, &padding_before,
&padding_after));
- const Shape* output_shape =
+ ShapeHandle output_shape =
c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
output_depth_dim});
c->set_output(0, output_shape);
@@ -322,9 +322,9 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
}
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
- const Shape* input_shape;
+ ShapeHandle input_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
- const Shape* filter_shape;
+ ShapeHandle filter_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));
std::vector<int32> strides;
@@ -337,13 +337,13 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
strides.size());
}
- const Dimension* batch_size_dim = c->Dim(input_shape, 0);
- const Dimension* in_rows_dim = c->Dim(input_shape, 1);
- const Dimension* in_cols_dim = c->Dim(input_shape, 2);
- const Dimension* filter_rows_dim = c->Dim(filter_shape, 0);
- const Dimension* filter_cols_dim = c->Dim(filter_shape, 1);
- const Dimension* input_depth = c->Dim(filter_shape, 2);
- const Dimension* depth_multiplier = c->Dim(filter_shape, 3);
+ DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
+ DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
+ DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
+ DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
+ DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
+ DimensionHandle input_depth = c->Dim(filter_shape, 2);
+ DimensionHandle depth_multiplier = c->Dim(filter_shape, 3);
// At the moment we need to know the values of several fields.
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
@@ -357,7 +357,7 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth));
- const Dimension* output_depth;
+ DimensionHandle output_depth;
TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth));
const int32 stride_rows = strides[1];
@@ -383,14 +383,14 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
in_cols, filter_cols, stride_cols, padding, &output_cols, &padding_before,
&padding_after));
- const Shape* output_shape =
+ ShapeHandle output_shape =
c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
c->set_output(0, output_shape);
return Status::OK();
}
Status AvgPoolShape(shape_inference::InferenceContext* c) {
- const Shape* input_shape;
+ ShapeHandle input_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
string data_format;
@@ -432,10 +432,10 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) {
kernel_cols = kernel_sizes[2];
}
- const Dimension* batch_size_dim = c->Dim(input_shape, 0);
- const Dimension* in_rows_dim = c->Dim(input_shape, 1);
- const Dimension* in_cols_dim = c->Dim(input_shape, 2);
- const Dimension* output_depth_dim = c->Dim(input_shape, 3);
+ DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
+ DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
+ DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
+ DimensionHandle output_depth_dim = c->Dim(input_shape, 3);
// At the moment we need to know the values of several fields.
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
@@ -459,7 +459,7 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) {
in_cols, kernel_cols, stride_cols, padding, &output_cols, &padding_before,
&padding_after));
- const Shape* output_shape;
+ ShapeHandle output_shape;
if (data_format == "NCHW") {
output_shape = c->MakeShape(
{batch_size_dim, output_depth_dim, output_rows, output_cols});
@@ -473,7 +473,7 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) {
}
Status MaxPoolShape(shape_inference::InferenceContext* c) {
- const Shape* input_shape;
+ ShapeHandle input_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
string data_format;
@@ -519,10 +519,10 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) {
kernel_depth = kernel_sizes[3];
}
- const Dimension* batch_size_dim = c->Dim(input_shape, 0);
- const Dimension* in_rows_dim = c->Dim(input_shape, 1);
- const Dimension* in_cols_dim = c->Dim(input_shape, 2);
- const Dimension* in_depth_dim = c->Dim(input_shape, 3);
+ DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
+ DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
+ DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
+ DimensionHandle in_depth_dim = c->Dim(input_shape, 3);
// At the moment we need to know the values of several fields.
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
@@ -551,7 +551,7 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) {
in_depth, kernel_depth, stride_depth, padding, &output_depth,
&padding_before, &padding_after));
- const Shape* output_shape =
+ ShapeHandle output_shape =
c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
if (data_format == "NCHW") {
@@ -566,7 +566,7 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) {
}
Status Pool3DShape(shape_inference::InferenceContext* c) {
- const Shape* input_shape;
+ ShapeHandle input_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
std::vector<int32> strides;
@@ -596,11 +596,11 @@ Status Pool3DShape(shape_inference::InferenceContext* c) {
kernel_rows = kernel_sizes[2];
kernel_cols = kernel_sizes[3];
- const Dimension* batch_size_dim = c->Dim(input_shape, 0);
- const Dimension* in_planes_dim = c->Dim(input_shape, 1);
- const Dimension* in_rows_dim = c->Dim(input_shape, 2);
- const Dimension* in_cols_dim = c->Dim(input_shape, 3);
- const Dimension* output_depth_dim = c->Dim(input_shape, 4);
+ DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
+ DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
+ DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
+ DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
+ DimensionHandle output_depth_dim = c->Dim(input_shape, 4);
// At the moment we need to know the values of several fields.
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_planes_dim, "in_planes"));
@@ -629,7 +629,7 @@ Status Pool3DShape(shape_inference::InferenceContext* c) {
in_cols, kernel_cols, stride_cols, padding, &output_cols, &padding_before,
&padding_after));
- const Shape* output_shape =
+ ShapeHandle output_shape =
c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
output_depth_dim});
@@ -645,9 +645,9 @@ Status UnknownShape(shape_inference::InferenceContext* c) {
}
Status ReductionShape(InferenceContext* c) {
- const Shape* input = c->input(0);
+ ShapeHandle input = c->input(0);
- const Shape* indices;
+ ShapeHandle indices;
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
const Tensor* reduction_indices_t = c->input_tensor(1);
@@ -679,7 +679,7 @@ Status ReductionShape(InferenceContext* c) {
wrapped_index += input_rank;
}
- const Dimension* reduce_dim = c->Dim(input, wrapped_index);
+ DimensionHandle reduce_dim = c->Dim(input, wrapped_index);
if (c->ValueKnown(reduce_dim) && c->Value(reduce_dim) == 0) {
return errors::InvalidArgument("Cannot reduce dimension ",
reduction_index, " with size 0");
@@ -688,7 +688,7 @@ Status ReductionShape(InferenceContext* c) {
true_indices.insert(wrapped_index);
}
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
bool reduce_all = reduction_indices_t->NumElements() == 0;
for (int i = 0; i < input_rank; ++i) {
if (reduce_all || true_indices.count(i) > 0) {
@@ -705,7 +705,7 @@ Status ReductionShape(InferenceContext* c) {
}
Status ConcatShape(InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
const Tensor* concat_dim_t = c->input_tensor(0);
@@ -730,7 +730,7 @@ Status ConcatShape(InferenceContext* c) {
"Can't concatenate scalars (use tf.pack instead)");
}
// Build result of <rank> different unknown dims.
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim());
c->set_output(0, c->MakeShape(dims));
return Status::OK();
@@ -744,22 +744,22 @@ Status ConcatShape(InferenceContext* c) {
concat_dim);
}
- const Shape* output_before;
- const Shape* output_after;
+ ShapeHandle output_before;
+ ShapeHandle output_after;
- const Shape* input = c->input(c->num_inputs() - 1);
+ ShapeHandle input = c->input(c->num_inputs() - 1);
TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, concat_dim + 1, &input));
TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before));
- const Dimension* output_middle = c->Dim(input, concat_dim);
+ DimensionHandle output_middle = c->Dim(input, concat_dim);
TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after));
for (int i = c->num_inputs() - 2; i > 0; --i) {
- const Shape* before;
- const Shape* after;
+ ShapeHandle before;
+ ShapeHandle after;
input = c->input(i);
TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, concat_dim + 1, &input));
TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before));
- const Dimension* middle = c->Dim(input, concat_dim);
+ DimensionHandle middle = c->Dim(input, concat_dim);
TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after));
TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before));
@@ -767,7 +767,7 @@ Status ConcatShape(InferenceContext* c) {
TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after));
}
- const Shape* s;
+ ShapeHandle s;
TF_RETURN_IF_ERROR(
c->Concatenate(output_before, c->Vector(output_middle), &s));
TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s));
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 71692cec1e..b828b23dfe 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -102,7 +102,7 @@ Status UnchangedShape(shape_inference::InferenceContext* c);
// Transfers shape of input(0) to output(0), after asserting its rank is <rank>.
inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c,
int32 rank) {
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out));
c->set_output(0, out);
return Status::OK();
@@ -111,7 +111,7 @@ inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c,
// Transfers shape of input(0) to output(0), after asserting its rank >= <rank>.
inline Status UnchangedShapeWithRankAtLeast(
shape_inference::InferenceContext* c, int32 rank) {
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out));
c->set_output(0, out);
return Status::OK();
@@ -120,7 +120,7 @@ inline Status UnchangedShapeWithRankAtLeast(
// Transfers shape of input(0) to output(0), after asserting its rank <= <rank>.
inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c,
int32 rank) {
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out));
c->set_output(0, out);
return Status::OK();
@@ -139,7 +139,7 @@ inline Status ScalarShape(shape_inference::InferenceContext* c) {
// Shape function for binary ops where both inputs and the output match.
inline Status MergeBothInputsShapeFn(InferenceContext* c) {
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out));
c->set_output(0, out);
return Status::OK();
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index e12a33adf4..45968fda73 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -58,14 +58,14 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) {
{
InferenceContext c(&def, op_def, {"[]"}, {});
TF_EXPECT_OK(ScalarShape(&c));
- const Shape* output = c.output(0);
+ ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
}
{
InferenceContext c(&def, op_def, {"[1,23,4,4,2]"}, {});
TF_EXPECT_OK(ScalarShape(&c));
- const Shape* output = c.output(0);
+ ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
}
}
@@ -92,7 +92,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
InferenceContext c(&def, op_def, {"[2,3]", "[3,4]"}, {});
TF_EXPECT_OK(MatMulShape(&c));
- const Shape* output = c.output(0);
+ ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
EXPECT_EQ(4, c.Value(c.Dim(output, 1)));
}
@@ -101,7 +101,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
// Unknown inner dimension for one
InferenceContext c(&def, op_def, {"[2,?]", "[3,4]"}, {});
TF_EXPECT_OK(MatMulShape(&c));
- const Shape* output = c.output(0);
+ ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
EXPECT_EQ(4, c.Value(c.Dim(output, 1)));
}
@@ -119,7 +119,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
// Unknown outer dimension
InferenceContext c(&def, op_def, {"[2,3]", "[3,?]"}, {});
TF_EXPECT_OK(MatMulShape(&c));
- const Shape* output = c.output(0);
+ ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
EXPECT_FALSE(c.ValueKnown(c.Dim(output, 1)));
}
@@ -154,7 +154,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
InferenceContext c(&def, op_def, {"[3,2]", "[3,4]"}, {});
auto s = MatMulShape(&c);
- const Shape* output = c.output(0);
+ ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
EXPECT_EQ(4, c.Value(c.Dim(output, 1)));
}
@@ -171,7 +171,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
InferenceContext c(&def, op_def, {"[2,3]", "[4,3]"}, {});
auto s = MatMulShape(&c);
- const Shape* output = c.output(0);
+ ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
EXPECT_EQ(4, c.Value(c.Dim(output, 1)));
}
@@ -195,7 +195,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
InferenceContext c(&def, op_def, {"[2,10]", "[10]"}, {});
TF_EXPECT_OK(BiasAddShape(&c));
- const Shape* output = c.output(0);
+ ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
EXPECT_EQ(10, c.Value(c.Dim(output, 1)));
}
@@ -204,7 +204,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
// Unknown ranks.
InferenceContext c(&def, op_def, {"?", "?"}, {});
TF_EXPECT_OK(BiasAddShape(&c));
- const Shape* output = c.output(0);
+ ShapeHandle output = c.output(0);
EXPECT_FALSE(c.RankKnown(output));
}
@@ -212,7 +212,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
// Rank > 2
InferenceContext c(&def, op_def, {"[4,3,4,2,15]", "[15]"}, {});
TF_EXPECT_OK(BiasAddShape(&c));
- const Shape* output = c.output(0);
+ ShapeHandle output = c.output(0);
EXPECT_EQ("[4,3,4,2,15]", c.DebugString(output));
}
@@ -225,7 +225,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Finalize(&def));
InferenceContext c(&def, op_def, {"[2,3,4,5]", "[3]"}, {});
TF_EXPECT_OK(BiasAddShape(&c));
- const Shape* output = c.output(0);
+ ShapeHandle output = c.output(0);
EXPECT_EQ("[2,3,4,5]", c.DebugString(output));
}
@@ -238,7 +238,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Finalize(&def));
InferenceContext c(&def, op_def, {"[8,6,4,2,3,4,5]", "[3]"}, {});
TF_EXPECT_OK(BiasAddShape(&c));
- const Shape* output = c.output(0);
+ ShapeHandle output = c.output(0);
EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output));
}
@@ -277,7 +277,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
InferenceContext c(&def, op_def, {"[2,10]"}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
- const Shape* output = c.output(0);
+ ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
}
@@ -285,7 +285,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
// Rank > 2
InferenceContext c(&def, op_def, {"[5,7,2,10]"}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
- const Shape* output = c.output(0);
+ ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
}
@@ -297,7 +297,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Finalize(&def));
InferenceContext c(&def, op_def, {"[2,3,4,5]"}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
- const Shape* output = c.output(0);
+ ShapeHandle output = c.output(0);
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
}
@@ -309,7 +309,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Finalize(&def));
InferenceContext c(&def, op_def, {"[8,6,4,2,3,4,5]"}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
- const Shape* output = c.output(0);
+ ShapeHandle output = c.output(0);
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
}
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index c6da445165..90b3a6a688 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -35,7 +35,7 @@ InferenceContext::InferenceContext(
PreInputInit(op_def, input_tensors);
for (const string& spec : input_shapes) {
- const Shape* shape;
+ ShapeHandle shape;
construction_status_.Update(MakeShapeFromString(spec, &shape));
if (!construction_status_.ok()) {
return;
@@ -55,7 +55,7 @@ InferenceContext::InferenceContext(
PreInputInit(op_def, input_tensors);
if (!construction_status_.ok()) return;
for (const TensorShapeProto& p : input_shapes) {
- const Shape* shape;
+ ShapeHandle shape;
construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
if (!construction_status_.ok()) {
return;
@@ -68,7 +68,7 @@ InferenceContext::InferenceContext(
InferenceContext::InferenceContext(
const NodeDef* node_def, const OpDef& op_def,
const std::vector<string>& input_shapes_string,
- const std::vector<const Shape*>& input_shapes,
+ const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors)
: node_def_(*CHECK_NOTNULL(node_def)) {
PreInputInit(op_def, input_tensors);
@@ -118,7 +118,7 @@ void InferenceContext::PostInputInit() {
requested_input_tensor_.resize(inputs_.size());
}
-bool InferenceContext::FullyDefined(const Shape* s) {
+bool InferenceContext::FullyDefined(ShapeHandle s) {
if (!RankKnown(s)) return false;
for (int i = 0; i < Rank(s); ++i) {
if (!ValueKnown(Dim(s, i))) return false;
@@ -126,7 +126,7 @@ bool InferenceContext::FullyDefined(const Shape* s) {
return true;
}
-const Dimension* InferenceContext::NumElements(const Shape* s) {
+DimensionHandle InferenceContext::NumElements(ShapeHandle s) {
const auto rank = Rank(s);
if (rank == kUnknownRank) return UnknownDim();
int64 size = 1;
@@ -138,7 +138,7 @@ const Dimension* InferenceContext::NumElements(const Shape* s) {
return MakeDim(size);
}
-string InferenceContext::DebugString(const Shape* s) {
+string InferenceContext::DebugString(ShapeHandle s) {
if (RankKnown(s)) {
std::vector<string> vals;
for (auto d : s->dims_) vals.push_back(DebugString(d));
@@ -148,19 +148,19 @@ string InferenceContext::DebugString(const Shape* s) {
}
}
-string InferenceContext::DebugString(const Dimension* d) {
+string InferenceContext::DebugString(DimensionHandle d) {
return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
}
-Status InferenceContext::WithRank(const Shape* shape, int32 rank,
- const Shape** out) {
+Status InferenceContext::WithRank(ShapeHandle shape, int32 rank,
+ ShapeHandle* out) {
const int32 existing = Rank(shape);
if (existing == rank) {
*out = shape;
return Status::OK();
}
if (existing == kUnknownRank) {
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
dims.reserve(rank);
for (int i = 0; i < rank; ++i) {
all_dims_.push_back(new Dimension());
@@ -175,8 +175,8 @@ Status InferenceContext::WithRank(const Shape* shape, int32 rank,
existing);
}
-Status InferenceContext::WithRankAtLeast(const Shape* shape, int32 rank,
- const Shape** out) {
+Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int32 rank,
+ ShapeHandle* out) {
const int32 existing = Rank(shape);
if (existing >= rank) {
*out = shape;
@@ -190,8 +190,8 @@ Status InferenceContext::WithRankAtLeast(const Shape* shape, int32 rank,
" but is rank ", existing);
}
-Status InferenceContext::WithRankAtMost(const Shape* shape, int32 rank,
- const Shape** out) {
+Status InferenceContext::WithRankAtMost(ShapeHandle shape, int32 rank,
+ ShapeHandle* out) {
const int32 existing = Rank(shape);
if (existing == kUnknownRank) {
return ReturnUnknownShape(out);
@@ -205,8 +205,8 @@ Status InferenceContext::WithRankAtMost(const Shape* shape, int32 rank,
" but is rank ", existing);
}
-Status InferenceContext::WithValue(const Dimension* dim, int64 value,
- const Dimension** out) {
+Status InferenceContext::WithValue(DimensionHandle dim, int64 value,
+ DimensionHandle* out) {
const int64 existing = Value(dim);
if (existing == value) {
*out = dim;
@@ -222,8 +222,8 @@ Status InferenceContext::WithValue(const Dimension* dim, int64 value,
existing);
}
-Status InferenceContext::Merge(const Dimension* d0, const Dimension* d1,
- const Dimension** out) {
+Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1,
+ DimensionHandle* out) {
if (d0 == d1 || !ValueKnown(d1)) {
*out = d0;
return Status::OK();
@@ -240,9 +240,9 @@ Status InferenceContext::Merge(const Dimension* d0, const Dimension* d1,
}
}
-Status InferenceContext::MergePrefix(const Shape* s, const Shape* prefix,
- const Shape** s_out,
- const Shape** prefix_out) {
+Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix,
+ ShapeHandle* s_out,
+ ShapeHandle* prefix_out) {
*s_out = *prefix_out = nullptr;
if (!RankKnown(prefix) || !RankKnown(s)) {
*s_out = s;
@@ -253,7 +253,7 @@ Status InferenceContext::MergePrefix(const Shape* s, const Shape* prefix,
TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s));
// Merge the prefix dims and create the new output shapes.
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
dims.resize(rank);
for (int i = 0; i < rank; ++i) {
TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i]));
@@ -264,8 +264,8 @@ Status InferenceContext::MergePrefix(const Shape* s, const Shape* prefix,
return Status::OK();
}
-Status InferenceContext::Merge(const Shape* s0, const Shape* s1,
- const Shape** out) {
+Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
+ ShapeHandle* out) {
if (s0 == s1 || !RankKnown(s1)) {
*out = s0;
return Status::OK();
@@ -309,7 +309,7 @@ Status InferenceContext::Merge(const Shape* s0, const Shape* s1,
}
// Merge dims.
- std::vector<const Dimension*> dims(rank, nullptr);
+ std::vector<DimensionHandle> dims(rank, nullptr);
for (int i = 0; i < rank; ++i) {
// Invariant for merge was checked earlier, so CHECK is ok.
TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i]));
@@ -317,13 +317,13 @@ Status InferenceContext::Merge(const Shape* s0, const Shape* s1,
return ReturnCreatedShape(dims, out);
}
-Status InferenceContext::Subshape(const Shape* s, int64 start,
- const Shape** out) {
+Status InferenceContext::Subshape(ShapeHandle s, int64 start,
+ ShapeHandle* out) {
return Subshape(s, start, std::numeric_limits<int64>::max() /* end */, out);
}
-Status InferenceContext::Subshape(const Shape* s, int64 start_in, int64 end_in,
- const Shape** out) {
+Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in,
+ ShapeHandle* out) {
int64 start = start_in;
int64 end = end_in;
const int32 rank = Rank(s);
@@ -362,7 +362,7 @@ Status InferenceContext::Subshape(const Shape* s, int64 start_in, int64 end_in,
end, " (computed from start ", start_in, " and end ", end_in,
" over shape with rank ", rank, ")");
}
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
dims.reserve(end - start);
for (int i = start; i < end; ++i) {
dims.push_back(Dim(s, i));
@@ -370,24 +370,23 @@ Status InferenceContext::Subshape(const Shape* s, int64 start_in, int64 end_in,
return ReturnCreatedShape(dims, out);
}
-Status InferenceContext::Concatenate(const Shape* s1, const Shape* s2,
- const Shape** out) {
+Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2,
+ ShapeHandle* out) {
if (!RankKnown(s1) || !RankKnown(s2)) {
return ReturnUnknownShape(out);
}
const int32 s1_rank = Rank(s1);
const int32 s2_rank = Rank(s2);
const int32 rank = s1_rank + s2_rank;
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
dims.reserve(rank);
for (int i = 0; i < s1_rank; ++i) dims.push_back(Dim(s1, i));
for (int i = 0; i < s2_rank; ++i) dims.push_back(Dim(s2, i));
return ReturnCreatedShape(dims, out);
}
-Status InferenceContext::ReplaceDim(const Shape* s, int dim_index_in,
- const Dimension* new_dim,
- const Shape** out) {
+Status InferenceContext::ReplaceDim(ShapeHandle s, int dim_index_in,
+ DimensionHandle new_dim, ShapeHandle* out) {
if (!RankKnown(s)) {
return ReturnUnknownShape(out);
}
@@ -401,20 +400,20 @@ Status InferenceContext::ReplaceDim(const Shape* s, int dim_index_in,
" for shape with ", s->dims_.size(),
" dimensions");
}
- std::vector<const Dimension*> dims(s->dims_);
+ std::vector<DimensionHandle> dims(s->dims_);
dims[dim_index] = new_dim;
return ReturnCreatedShape(dims, out);
}
-const Shape* InferenceContext::MakeShape(
- const std::vector<const Dimension*>& dims) {
+ShapeHandle InferenceContext::MakeShape(
+ const std::vector<DimensionHandle>& dims) {
all_shapes_.push_back(new Shape(dims));
return all_shapes_.back();
}
-const Shape* InferenceContext::MakeShape(
+ShapeHandle InferenceContext::MakeShape(
std::initializer_list<DimensionOrConstant> dims) {
- std::vector<const Dimension*> dims_actual;
+ std::vector<DimensionHandle> dims_actual;
dims_actual.reserve(dims.size());
for (const DimensionOrConstant& d : dims) {
dims_actual.push_back(MakeDim(d));
@@ -422,45 +421,45 @@ const Shape* InferenceContext::MakeShape(
return MakeShape(dims_actual);
}
-const Shape* InferenceContext::UnknownShape() {
+ShapeHandle InferenceContext::UnknownShape() {
all_shapes_.push_back(new Shape());
return all_shapes_.back();
}
-const Shape* InferenceContext::UnknownShapeOfRank(int32 rank) {
- std::vector<const Dimension*> dims(rank);
+ShapeHandle InferenceContext::UnknownShapeOfRank(int32 rank) {
+ std::vector<DimensionHandle> dims(rank);
for (int32 i = 0; i < rank; ++i) {
dims[i] = UnknownDim();
}
return MakeShape(dims);
}
-const Shape* InferenceContext::Scalar() { return MakeShape({}); }
+ShapeHandle InferenceContext::Scalar() { return MakeShape({}); }
-const Shape* InferenceContext::Vector(DimensionOrConstant dim) {
+ShapeHandle InferenceContext::Vector(DimensionOrConstant dim) {
return MakeShape({dim});
}
-const Shape* InferenceContext::Matrix(DimensionOrConstant dim1,
- DimensionOrConstant dim2) {
+ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
+ DimensionOrConstant dim2) {
return MakeShape({dim1, dim2});
}
Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
- const Shape** out) {
- const Shape* input_shape;
+ ShapeHandle* out) {
+ ShapeHandle input_shape;
TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
const Tensor* t = input_tensor(input_idx);
if (t == nullptr) {
// Shape tensor is not known, but if the shape of the shape tensor is then
// the right number of unknown dims can be created.
- const Dimension* shape_dim = Dim(input_shape, 0);
+ DimensionHandle shape_dim = Dim(input_shape, 0);
if (!ValueKnown(shape_dim)) {
return ReturnUnknownShape(out);
}
const auto num_dims = Value(shape_dim);
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim());
return ReturnCreatedShape(dims, out);
}
@@ -470,7 +469,7 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
return errors::InvalidArgument("Input tensor must be rank 1, but was rank ",
t->shape().dims());
}
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
if (t->dtype() == DataType::DT_INT32) {
auto flat_t = t->flat<int32>();
for (int i = 0; i < flat_t.size(); ++i) {
@@ -492,7 +491,7 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
}
Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
- const Shape** out) {
+ ShapeHandle* out) {
*out = nullptr;
TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto));
PartialTensorShape partial_shape(proto);
@@ -500,7 +499,7 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
return ReturnUnknownShape(out);
}
const int num_dims = partial_shape.dims();
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
dims.reserve(partial_shape.dims());
for (int i = 0; i < num_dims; ++i) {
// -1 is unknown in proto and in InferenceContext, so this size can be
@@ -511,7 +510,7 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
}
// Returns a new dimension whose value is given by a scalar input tensor.
-Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) {
+Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
const Tensor* t = input_tensor(idx);
if (t == nullptr) {
*out = UnknownDim();
@@ -539,8 +538,8 @@ Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) {
return Status::OK();
}
-Status InferenceContext::Divide(const Dimension* dividend, int64 divisor,
- const Dimension** out) {
+Status InferenceContext::Divide(DimensionHandle dividend, int64 divisor,
+ DimensionHandle* out) {
if (divisor == 1) {
*out = dividend;
} else if (!ValueKnown(dividend)) {
@@ -560,8 +559,8 @@ Status InferenceContext::Divide(const Dimension* dividend, int64 divisor,
return Status::OK();
}
-Status InferenceContext::Add(const Dimension* first, DimensionOrConstant second,
- const Dimension** out) {
+Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second,
+ DimensionHandle* out) {
const int64 first_value = Value(first);
const int64 second_value = Value(second);
// Special cases.
@@ -583,9 +582,9 @@ Status InferenceContext::Add(const Dimension* first, DimensionOrConstant second,
return Status::OK();
}
-Status InferenceContext::Subtract(const Dimension* first,
+Status InferenceContext::Subtract(DimensionHandle first,
DimensionOrConstant second,
- const Dimension** out) {
+ DimensionHandle* out) {
const int64 first_value = Value(first);
const int64 second_value = Value(second);
// Special cases.
@@ -606,9 +605,9 @@ Status InferenceContext::Subtract(const Dimension* first,
return Status::OK();
}
-Status InferenceContext::Multiply(const Dimension* first,
+Status InferenceContext::Multiply(DimensionHandle first,
DimensionOrConstant second,
- const Dimension** out) {
+ DimensionHandle* out) {
const int64 first_value = Value(first);
const int64 second_value = Value(second);
// Special cases.
@@ -635,8 +634,8 @@ Status InferenceContext::Multiply(const Dimension* first,
return Status::OK();
}
-Status InferenceContext::Min(const Dimension* first, DimensionOrConstant second,
- const Dimension** out) {
+Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second,
+ DimensionHandle* out) {
const int64 first_value = Value(first);
const int64 second_value = Value(second);
if (first_value == 0) {
@@ -655,8 +654,8 @@ Status InferenceContext::Min(const Dimension* first, DimensionOrConstant second,
return Status::OK();
}
-Status InferenceContext::Max(const Dimension* first, DimensionOrConstant second,
- const Dimension** out) {
+Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second,
+ DimensionHandle* out) {
const int64 first_value = Value(first);
const int64 second_value = Value(second);
if (first_value == kUnknownDim || second_value == kUnknownDim) {
@@ -672,13 +671,13 @@ Status InferenceContext::Max(const Dimension* first, DimensionOrConstant second,
}
Status InferenceContext::MakeShapeFromString(const string& spec,
- const Shape** output) {
+ ShapeHandle* output) {
if (spec == "?") {
*output = UnknownShape();
return Status::OK();
}
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
strings::Scanner scanner(spec);
scanner.OneLiteral("[");
while (scanner.Peek() != ']') {
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 1d0d4ab471..6f561f467a 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -43,38 +43,42 @@ class Dimension {
TF_DISALLOW_COPY_AND_ASSIGN(Dimension);
};
+typedef const Dimension* DimensionHandle;
+
// Shape rank and dimensions are accessed through InferenceContext.
class Shape {
private:
Shape();
- Shape(const std::vector<const Dimension*>& dims);
+ Shape(const std::vector<DimensionHandle>& dims);
~Shape() {}
const int32 rank_;
- const std::vector<const Dimension*> dims_;
+ const std::vector<DimensionHandle> dims_;
friend class InferenceContext;
TF_DISALLOW_COPY_AND_ASSIGN(Shape);
};
-// Struct used to allow functions to take const Dimension* or a dimension value.
+// Struct used to allow functions to take DimensionHandle or a dimension value.
// Not meant to be constructed directly.
struct DimensionOrConstant {
public:
// Intentionally not explicit.
- DimensionOrConstant(const Dimension* dim);
+ DimensionOrConstant(DimensionHandle dim);
// val must be non-negative or InferenceContext::kUnknownDim.
DimensionOrConstant(int64 val);
// dim takes precedence. If dim != nullptr, val is ignored.
- const Dimension* dim;
+ DimensionHandle dim;
int64 val;
private:
DimensionOrConstant();
};
+typedef const Shape* ShapeHandle;
+
// Note: This is experimental support for op shape inference in C++. Shape
// inference functions are not ready to be implemented yet.
//
@@ -97,7 +101,7 @@ class InferenceContext {
// creation of Shapes from strings out of this class (or hide it).
InferenceContext(const NodeDef* node_def, const OpDef& op_def,
const std::vector<string>& input_shapes_string,
- const std::vector<const Shape*>& input_shapes,
+ const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors);
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
@@ -131,7 +135,7 @@ class InferenceContext {
~InferenceContext();
- const Shape* input(int idx) const { return inputs_[idx]; }
+ ShapeHandle input(int idx) const { return inputs_[idx]; }
int num_inputs() const { return inputs_.size(); }
// Returns the input tensor at index <idx>, or nullptr if the input tensor is
@@ -151,13 +155,13 @@ class InferenceContext {
input_tensors_ = input_tensors;
}
- void set_output(int idx, const Shape* shape) { outputs_[idx] = shape; }
+ void set_output(int idx, ShapeHandle shape) { outputs_[idx] = shape; }
int num_outputs() const { return outputs_.size(); }
- const Shape* output(int idx) { return outputs_[idx]; }
+ ShapeHandle output(int idx) { return outputs_[idx]; }
// idx can be negative for an offset from end of dimensions.
// idx must be in the range [-1 * s.rank, s.rank).
- const Dimension* Dim(const Shape* s, int32 idx) {
+ DimensionHandle Dim(ShapeHandle s, int32 idx) {
if (s->rank_ == kUnknownRank) {
return UnknownDim();
}
@@ -166,8 +170,8 @@ class InferenceContext {
}
return s->dims_[idx];
}
- int32 Rank(const Shape* s) { return s->rank_; }
- bool RankKnown(const Shape* s) { return Rank(s) != kUnknownRank; }
+ int32 Rank(ShapeHandle s) { return s->rank_; }
+ bool RankKnown(ShapeHandle s) { return Rank(s) != kUnknownRank; }
inline int64 Value(DimensionOrConstant d) {
return d.dim ? d.dim->value_ : d.val;
}
@@ -176,111 +180,111 @@ class InferenceContext {
}
// Returns true if the rank and all dimensions of the Shape are known.
- bool FullyDefined(const Shape* s);
+ bool FullyDefined(ShapeHandle s);
// Returns the total number of elements, or an unknown dimension for an
// incomplete shape.
- const Dimension* NumElements(const Shape* s);
+ DimensionHandle NumElements(ShapeHandle s);
- string DebugString(const Shape* s);
- string DebugString(const Dimension* d);
+ string DebugString(ShapeHandle s);
+ string DebugString(DimensionHandle d);
// If <shape> has rank <rank>, or its rank is unknown, return OK and return
// the shape with asserted rank in <*out>. Otherwise return an error.
//
// Note that <*out> may be set to <shape>.
- Status WithRank(const Shape* shape, int32 rank,
- const Shape** out) TF_MUST_USE_RESULT;
- Status WithRankAtLeast(const Shape* shape, int32 rank,
- const Shape** out) TF_MUST_USE_RESULT;
- Status WithRankAtMost(const Shape* shape, int32 rank,
- const Shape** out) TF_MUST_USE_RESULT;
+ Status WithRank(ShapeHandle shape, int32 rank,
+ ShapeHandle* out) TF_MUST_USE_RESULT;
+ Status WithRankAtLeast(ShapeHandle shape, int32 rank,
+ ShapeHandle* out) TF_MUST_USE_RESULT;
+ Status WithRankAtMost(ShapeHandle shape, int32 rank,
+ ShapeHandle* out) TF_MUST_USE_RESULT;
// If <dim> has value <value>, or its value is unknown, returns OK and returns
// the dimension with asserted value in <*out>. Otherwise returns an error.
//
// Note that <*out> may be set to <dim>.
- Status WithValue(const Dimension* dim, int64 value,
- const Dimension** out) TF_MUST_USE_RESULT;
+ Status WithValue(DimensionHandle dim, int64 value,
+ DimensionHandle* out) TF_MUST_USE_RESULT;
// Merges <in0> and <in1> and returns the merged shape in <*out>. If <in0> and
// <in1> are incompatible in rank, or in the value of any dimension, returns
// an error.
//
// Note that <*out> may be set to <in0> or <in1>.
- Status Merge(const Shape* in0, const Shape* in1,
- const Shape** out) TF_MUST_USE_RESULT;
+ Status Merge(ShapeHandle in0, ShapeHandle in1,
+ ShapeHandle* out) TF_MUST_USE_RESULT;
// Asserts that <s>'s rank >= <prefix>'s rank, and the first
// <prefix.rank> dimensions of <s> are compatible with the dimensions of
// <prefix>.
// Returns the merged results in <*s_out> and <*prefix_out>.
- Status MergePrefix(const Shape* s, const Shape* prefix, const Shape** s_out,
- const Shape** prefix_out) TF_MUST_USE_RESULT;
+ Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out,
+ ShapeHandle* prefix_out) TF_MUST_USE_RESULT;
// Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0>
// and <d1> have incompatible values, returns an error.
//
// Note that <*out> may be set to <d0> or <d1>.
- Status Merge(const Dimension* d0, const Dimension* d1,
- const Dimension** out) TF_MUST_USE_RESULT;
+ Status Merge(DimensionHandle d0, DimensionHandle d1,
+ DimensionHandle* out) TF_MUST_USE_RESULT;
// Returns in <*out> a sub-shape of <s> with dimensions [start:].
// <start> can be negative to index from the end of the shape. If <start> >
// rank of <s>, then an empty subshape is returned.
// Returns an error if the rank of <s> is < <start>.
- Status Subshape(const Shape* s, int64 start,
- const Shape** out) TF_MUST_USE_RESULT;
+ Status Subshape(ShapeHandle s, int64 start,
+ ShapeHandle* out) TF_MUST_USE_RESULT;
// Returns in <*out> a sub-shape of <s>, with dimensions [start:end].
// <start> and <end> can be negative, to index from the end of the shape.
// <start> and <end> are set to the rank of <s> if > rank of <s>.
// Returns an error if the rank of <s> is insufficient.
- Status Subshape(const Shape* s, int64 start, int64 end,
- const Shape** out) TF_MUST_USE_RESULT;
+ Status Subshape(ShapeHandle s, int64 start, int64 end,
+ ShapeHandle* out) TF_MUST_USE_RESULT;
// Returns in <*out> the result of appending the dimensions of <s2> to those
// of <s1>.
- Status Concatenate(const Shape* s1, const Shape* s2,
- const Shape** out) TF_MUST_USE_RESULT;
+ Status Concatenate(ShapeHandle s1, ShapeHandle s2,
+ ShapeHandle* out) TF_MUST_USE_RESULT;
// Returns in <out> the shape from replacing <s.dim[dim_index]> with
// <new_dim>.
- Status ReplaceDim(const Shape* s, int dim_index, const Dimension* new_dim,
- const Shape** out) TF_MUST_USE_RESULT;
+ Status ReplaceDim(ShapeHandle s, int dim_index, DimensionHandle new_dim,
+ ShapeHandle* out) TF_MUST_USE_RESULT;
// Returns a new shape with the given dims. The returned value is owned by
// this context.
- const Shape* MakeShape(const std::vector<const Dimension*>& dims);
- const Shape* MakeShape(std::initializer_list<DimensionOrConstant> dims);
+ ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
+ ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims);
// Returns a new unknown shape.
- const Shape* UnknownShape();
+ ShapeHandle UnknownShape();
// Returns a shape with specified rank but unknown dims.
- const Shape* UnknownShapeOfRank(int32 rank);
+ ShapeHandle UnknownShapeOfRank(int32 rank);
// Returns a new shape of zero dimensions.
- const Shape* Scalar();
+ ShapeHandle Scalar();
// Returns a new shape of one dimension.
- const Shape* Vector(DimensionOrConstant dim);
+ ShapeHandle Vector(DimensionOrConstant dim);
// Returns a new shape of two dimensions.
- const Shape* Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2);
+ ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2);
// Returns in <out> a new shape whose dimension sizes come from input tensor
// <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor. If
// the input tensor is NULL, then an unknown shape is returned.
- Status MakeShapeFromShapeTensor(int input_idx, const Shape** out);
+ Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);
// Returns in <out> a new shape corresponding to <proto>.
Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
- const Shape** out);
+ ShapeHandle* out);
// Returns a new dimension of the given size. The returned value is owned by
// this context.
- inline const Dimension* MakeDim(DimensionOrConstant d) {
+ inline DimensionHandle MakeDim(DimensionOrConstant d) {
if (d.dim) {
return d.dim;
} else {
@@ -288,12 +292,12 @@ class InferenceContext {
return all_dims_.back();
}
}
- inline const Dimension* UnknownDim() { return MakeDim(kUnknownDim); }
+ inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }
// Returns a new dimension whose value is given by a scalar input tensor.
// The input tensor must be in host memory, since it is dereferenced to get
// the value.
- Status MakeDimForScalarInput(int idx, const Dimension** out);
+ Status MakeDimForScalarInput(int idx, DimensionHandle* out);
// Look up the attr for the NodeDef being evaluated with name attr_name and
// set *value to its value. If no attr with attr_name is found in def(), or
@@ -304,37 +308,36 @@ class InferenceContext {
// Returns in <out> the result of dividing <dividend> by <divisor>.
// Returns an error if <divisor> is not positive or does not evenly
// divide <dividend>.
- Status Divide(const Dimension* dividend, int64 divisor,
- const Dimension** out);
+ Status Divide(DimensionHandle dividend, int64 divisor, DimensionHandle* out);
// Returns in <out> the sum of <first> and <second>.
- Status Add(const Dimension* first, DimensionOrConstant second,
- const Dimension** out);
+ Status Add(DimensionHandle first, DimensionOrConstant second,
+ DimensionHandle* out);
// Returns in <out> the dimension that is <first> minus <second>.
- Status Subtract(const Dimension* first, DimensionOrConstant second,
- const Dimension** out);
+ Status Subtract(DimensionHandle first, DimensionOrConstant second,
+ DimensionHandle* out);
// Returns in <out> the product of <first> and <second>.
- Status Multiply(const Dimension* first, DimensionOrConstant second,
- const Dimension** out);
+ Status Multiply(DimensionHandle first, DimensionOrConstant second,
+ DimensionHandle* out);
// Returns in <out> the minimum of <first> and <second>. If either <first> or
// <second> is zero the results is zero. Otherwise, if either <first> or
// <second> is unknown the results is unknown.
- Status Min(const Dimension* first, DimensionOrConstant second,
- const Dimension** out);
+ Status Min(DimensionHandle first, DimensionOrConstant second,
+ DimensionHandle* out);
// Returns in <out> the maximum of <first> and <second>. If either <first> or
// <second> is unknown the results is unknown.
- Status Max(const Dimension* first, DimensionOrConstant second,
- const Dimension** out);
+ Status Max(DimensionHandle first, DimensionOrConstant second,
+ DimensionHandle* out);
Status construction_status() const { return construction_status_; }
// Validates that 'dim' has a known value, and prints an error
// message containing 'name' if validation fails.
- Status ValidateKnownDim(const Dimension* dim, const char* name) {
+ Status ValidateKnownDim(DimensionHandle dim, const char* name) {
if (!ValueKnown(dim)) {
return errors::InvalidArgument("Cannot infer shape because dimension ",
name, " is not known.");
@@ -344,19 +347,19 @@ class InferenceContext {
// Validates the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
- Status ValidateSparseTensor(const Shape* indices_shape,
- const Shape* values_shape,
- const Shape* shape_shape) {
+ Status ValidateSparseTensor(ShapeHandle indices_shape,
+ ShapeHandle values_shape,
+ ShapeHandle shape_shape) {
// Validate ranks.
- const Shape* unused_shape;
+ ShapeHandle unused_shape;
TF_RETURN_IF_ERROR(WithRank(indices_shape, 2, &unused_shape));
TF_RETURN_IF_ERROR(WithRank(values_shape, 1, &unused_shape));
TF_RETURN_IF_ERROR(WithRank(shape_shape, 1, &unused_shape));
// Number of elements in indices and values must match.
- const Dimension* num_index_elements_dim = Dim(indices_shape, 0);
+ DimensionHandle num_index_elements_dim = Dim(indices_shape, 0);
if (ValueKnown(num_index_elements_dim)) {
- const Dimension* num_values_elements_dim = Dim(values_shape, 0);
+ DimensionHandle num_values_elements_dim = Dim(values_shape, 0);
if (ValueKnown(num_values_elements_dim)) {
int64 num_index_elements = Value(num_index_elements_dim);
int64 num_values_elements = Value(num_values_elements_dim);
@@ -369,9 +372,9 @@ class InferenceContext {
}
// Rank embedded in indices must match shape.
- const Dimension* index_rank_dim = Dim(indices_shape, 1);
+ DimensionHandle index_rank_dim = Dim(indices_shape, 1);
if (ValueKnown(index_rank_dim)) {
- const Dimension* shape_rank_dim = Dim(shape_shape, 0);
+ DimensionHandle shape_rank_dim = Dim(shape_shape, 0);
if (ValueKnown(shape_rank_dim)) {
int64 index_rank = Value(index_rank_dim);
int32 shape_rank = Value(shape_rank_dim);
@@ -394,16 +397,16 @@ class InferenceContext {
void PostInputInit();
// Returns a shape from 'shape_string'.
- Status MakeShapeFromString(const string& shape_string, const Shape** output);
+ Status MakeShapeFromString(const string& shape_string, ShapeHandle* output);
- const Dimension* GetDimension(const DimensionOrConstant& d);
+ DimensionHandle GetDimension(const DimensionOrConstant& d);
- Status ReturnUnknownShape(const Shape** out) {
+ Status ReturnUnknownShape(ShapeHandle* out) {
*out = UnknownShape();
return Status::OK();
}
- Status ReturnCreatedShape(const std::vector<const Dimension*>& dims,
- const Shape** out) {
+ Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims,
+ ShapeHandle* out) {
*out = MakeShape(dims);
return Status::OK();
}
@@ -412,10 +415,10 @@ class InferenceContext {
std::vector<Dimension*> all_dims_; // values are owned.
// inputs_ and outputs_ refer to values from all_shapes_.
- std::vector<const Shape*> inputs_;
+ std::vector<ShapeHandle> inputs_;
std::vector<const Tensor*> input_tensors_;
std::vector<bool> requested_input_tensor_;
- std::vector<const Shape*> outputs_;
+ std::vector<ShapeHandle> outputs_;
const NodeDef& node_def_;
NameRangeMap input_name_map_;
@@ -440,10 +443,10 @@ inline Dimension::Dimension(int64 value) : value_(value) {
}
inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
-inline Shape::Shape(const std::vector<const Dimension*>& dims)
+inline Shape::Shape(const std::vector<DimensionHandle>& dims)
: rank_(dims.size()), dims_(dims) {}
-inline DimensionOrConstant::DimensionOrConstant(const Dimension* dim)
+inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim)
: dim(dim) {
DCHECK(dim != nullptr) << "Internal error: Got nullptr for Dimension.";
}
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index 7be573d249..7b6c9407c4 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -109,8 +109,8 @@ TEST(ShapeInferenceTest, WithRank) {
auto in0 = c.input(0);
auto in1 = c.input(1);
- const Shape* s1 = nullptr;
- const Shape* s2 = nullptr;
+ ShapeHandle s1 = nullptr;
+ ShapeHandle s2 = nullptr;
// WithRank on a shape with unknown dimensionality always succeeds.
EXPECT_TRUE(c.WithRank(in0, 1, &s1).ok());
@@ -147,8 +147,8 @@ TEST(ShapeInferenceTest, WithRankAtMost) {
auto in0 = c.input(0);
auto in1 = c.input(1);
- const Shape* s1 = nullptr;
- const Shape* s2 = nullptr;
+ ShapeHandle s1 = nullptr;
+ ShapeHandle s2 = nullptr;
// WithRankAtMost on a shape with unknown dimensionality always succeeds.
EXPECT_TRUE(c.WithRankAtMost(in0, 1, &s1).ok());
@@ -182,8 +182,8 @@ TEST(ShapeInferenceTest, WithRankAtLeast) {
auto in0 = c.input(0);
auto in1 = c.input(1);
- const Shape* s1 = nullptr;
- const Shape* s2 = nullptr;
+ ShapeHandle s1 = nullptr;
+ ShapeHandle s2 = nullptr;
// WithRankAtLeast on a shape with unknown dimensionality always succeeds.
EXPECT_TRUE(c.WithRankAtLeast(in0, 1, &s1).ok());
@@ -217,8 +217,8 @@ TEST(ShapeInferenceTest, WithValue) {
auto d0 = c.Dim(c.input(0), 0);
auto d1 = c.Dim(c.input(0), 1);
- const Dimension* out1 = nullptr;
- const Dimension* out2 = nullptr;
+ DimensionHandle out1 = nullptr;
+ DimensionHandle out2 = nullptr;
// WithValue on a dimension with unknown value always succeeds.
EXPECT_TRUE(c.WithValue(d1, 1, &out1).ok());
@@ -259,7 +259,7 @@ TEST(ShapeInferenceTest, MergeDim) {
auto d2_b = c.Dim(c.input(0), 2);
auto d1 = c.Dim(c.input(0), 3);
auto d_unknown_b = c.Dim(c.input(0), 4);
- const Dimension* out = nullptr;
+ DimensionHandle out = nullptr;
// Merging anything with unknown returns the same pointer.
EXPECT_TRUE(c.Merge(d2, d_unknown, &out).ok());
@@ -302,7 +302,7 @@ TEST(ShapeInferenceTest, MergeShape) {
auto s_1_3 = c.input(4);
auto s_unknown_b = c.input(5);
auto s_1 = c.input(6);
- const Shape* out = nullptr;
+ ShapeHandle out = nullptr;
// Merging any shape with unknown returns the shape.
EXPECT_TRUE(c.Merge(s_unknown, s_1_2, &out).ok());
@@ -359,8 +359,8 @@ TEST(ShapeInferenceTest, MergePrefix) {
auto s_1_u_3 = c.input(2);
auto s_2_4 = c.input(3);
- const Shape* s_out = nullptr;
- const Shape* s_prefix_out = nullptr;
+ ShapeHandle s_out = nullptr;
+ ShapeHandle s_prefix_out = nullptr;
// Merging with unknown returns the inputs.
EXPECT_TRUE(c.MergePrefix(s_unknown, s_u_2, &s_out, &s_prefix_out).ok());
@@ -399,8 +399,8 @@ TEST(ShapeInferenceTest, Subshape) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(2, 2), {"[1,2,3,?,5]", "?"}, {});
- const Shape* unknown = c.input(1);
- const Shape* out;
+ ShapeHandle unknown = c.input(1);
+ ShapeHandle out;
EXPECT_TRUE(c.Subshape(unknown, 0, &out).ok());
EXPECT_EQ("?", c.DebugString(out));
EXPECT_TRUE(out == unknown);
@@ -412,7 +412,7 @@ TEST(ShapeInferenceTest, Subshape) {
EXPECT_TRUE(out != unknown);
const int kFullRank = 5;
- const Shape* out_arr[4];
+ ShapeHandle out_arr[4];
auto in0 = c.input(0);
EXPECT_TRUE(c.Subshape(in0, 0, &out).ok());
EXPECT_EQ("[1,2,3,?,5]", c.DebugString(out));
@@ -472,8 +472,8 @@ TEST(ShapeInferenceTest, Concatenate) {
auto in0 = c.input(0);
auto in1 = c.input(1);
- const Shape* unknown = c.input(2);
- const Shape* out;
+ ShapeHandle unknown = c.input(2);
+ ShapeHandle out;
EXPECT_TRUE(c.Concatenate(unknown, unknown, &out).ok());
EXPECT_EQ("?", c.DebugString(out));
EXPECT_TRUE(out != unknown);
@@ -499,7 +499,7 @@ TEST(ShapeInferenceTest, ReplaceDim) {
auto in = c.input(0);
auto unknown = c.input(1);
- const Shape* replaced;
+ ShapeHandle replaced;
EXPECT_TRUE(c.ReplaceDim(in, 0, c.Dim(in, 1), &replaced).ok());
EXPECT_EQ("[2,2,3]", c.DebugString(replaced));
EXPECT_TRUE(c.ReplaceDim(in, 2, c.Dim(in, 1), &replaced).ok());
@@ -527,7 +527,7 @@ TEST(ShapeInferenceTest, MakeShape) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {"[1,2,3,?,5]"}, {});
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
auto in0 = c.input(0);
const int rank = c.Rank(in0);
for (int i = 0; i < rank; ++i) {
@@ -608,7 +608,7 @@ TEST(ShapeInferenceTest, MakeShapeFromShapeTensor) {
auto create = [](Tensor* t) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 0), {"?"}, {t});
- const Shape* out;
+ ShapeHandle out;
Status s = c.MakeShapeFromShapeTensor(0, &out);
if (s.ok()) {
return c.DebugString(out);
@@ -643,7 +643,7 @@ TEST(ShapeInferenceTest, MakeShapeFromShapeTensor) {
{
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 0), {"[1,?]"}, {nullptr});
- const Shape* out;
+ ShapeHandle out;
EXPECT_EQ("Shape must be rank 1 but is rank 2",
c.MakeShapeFromShapeTensor(0, &out).error_message());
}
@@ -655,7 +655,7 @@ TEST(ShapeInferenceTest, MakeShapeFromShapeProto) {
TensorShapeProto proto;
// With a set unknown rank.
- const Shape* out;
+ ShapeHandle out;
proto.set_unknown_rank(true);
EXPECT_TRUE(c.MakeShapeFromShapeProto(proto, &out).ok());
EXPECT_EQ("?", c.DebugString(out));
@@ -733,7 +733,7 @@ TEST(ShapeInferenceTest, MakeDimForScalarInput) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(2, 2), {"[]", "[]"}, {&t1, &t2});
- const Dimension* d;
+ DimensionHandle d;
EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
EXPECT_EQ("20", c.DebugString(d));
@@ -776,7 +776,7 @@ TEST(ShapeInferenceTest, Divide) {
auto d_unknown = c.Dim(s, 1);
// Dividing unknown by non-1 gives new unknown.
- const Dimension* out;
+ DimensionHandle out;
EXPECT_TRUE(c.Divide(d_unknown, 2, &out).ok());
EXPECT_EQ("?", c.DebugString(out));
EXPECT_TRUE(out != d_unknown);
@@ -808,7 +808,7 @@ TEST(ShapeInferenceTest, Add) {
auto d_0 = c.Dim(s, 2);
// Adding non-zero to unknown gives new unknown.
- const Dimension* out;
+ DimensionHandle out;
EXPECT_TRUE(c.Add(d_unknown, 1, &out).ok());
EXPECT_EQ("?", c.DebugString(out));
EXPECT_TRUE(out != d_unknown);
@@ -858,7 +858,7 @@ TEST(ShapeInferenceTest, Subtract) {
auto d_5 = c.Dim(s, 3);
// Subtracting non-zero from unknown gives new unknown.
- const Dimension* out;
+ DimensionHandle out;
EXPECT_TRUE(c.Subtract(d_unknown, 1, &out).ok());
EXPECT_EQ("?", c.DebugString(out));
EXPECT_TRUE(out != d_unknown);
@@ -906,7 +906,7 @@ TEST(ShapeInferenceTest, Multiply) {
auto d_1 = c.Dim(s, 3);
// Multiplying non-zero to unknown gives new unknown.
- const Dimension* out;
+ DimensionHandle out;
EXPECT_TRUE(c.Multiply(d_unknown, 2, &out).ok());
EXPECT_EQ("?", c.DebugString(out));
@@ -979,7 +979,7 @@ TEST(ShapeInferenceTest, Min) {
auto d_0 = c.Dim(s, 3);
// Minimum involving zero and unknown returns zero.
- const Dimension* out;
+ DimensionHandle out;
EXPECT_TRUE(c.Min(d_0, d_unknown, &out).ok());
EXPECT_EQ(d_0, out);
EXPECT_TRUE(c.Min(d_unknown, d_0, &out).ok());
@@ -1026,7 +1026,7 @@ TEST(ShapeInferenceTest, Max) {
auto d_unknown = c.Dim(s, 2);
// Maximum involving unknowns gives new unknown.
- const Dimension* out;
+ DimensionHandle out;
EXPECT_TRUE(c.Max(d_unknown, d_unknown, &out).ok());
EXPECT_EQ("?", c.DebugString(out));
EXPECT_TRUE(c.Max(d_unknown, 1, &out).ok());
diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc
index 60a9cb101f..28d2197629 100644
--- a/tensorflow/core/framework/shape_inference_testutil.cc
+++ b/tensorflow/core/framework/shape_inference_testutil.cc
@@ -25,8 +25,8 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::Dimension;
-using shape_inference::Shape;
+using shape_inference::DimensionHandle;
+using shape_inference::ShapeHandle;
using errors::Unknown;
Status InferShapes(ShapeInferenceTestOp op, const string& ins,
@@ -48,9 +48,9 @@ Status InferShapes(ShapeInferenceTestOp op, const string& ins,
TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(&c));
const int num_outputs = c.num_outputs();
- std::unordered_map<const Dimension*, std::pair<int, int>>
+ std::unordered_map<DimensionHandle, std::pair<int, int>>
dim_to_input_and_dim_idx;
- std::unordered_map<const Shape*, int> shape_to_input_idx;
+ std::unordered_map<ShapeHandle, int> shape_to_input_idx;
for (int i = 0; i < c.num_inputs(); ++i) {
auto in = c.input(i);
shape_to_input_idx[in] = i;
@@ -120,7 +120,7 @@ Status InferShapes(ShapeInferenceTestOp op, const string& ins,
for (int j = 0; j < expected_dims.size(); ++j) {
err_prefix = strings::StrCat("Output dim ", i, ",", j);
StringPiece expected_dim(expected_dims[j]);
- const Dimension* out_dim = c.Dim(out, j);
+ DimensionHandle out_dim = c.Dim(out, j);
std::pair<int, int> in_dim_idx = gtl::FindWithDefault(
dim_to_input_and_dim_idx, out_dim, std::make_pair(-1, -1));
if (expected_dim == "?") {
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 528d6407d4..051e0c1302 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -21,9 +21,9 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::Dimension;
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
namespace {
@@ -41,14 +41,14 @@ Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack,
Status PadShapeFn(InferenceContext* c) {
// Paddings is a matrix of [input_rank, 2].
- const Shape* paddings;
+ ShapeHandle paddings;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
- const Dimension* unused;
+ DimensionHandle unused;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(paddings, 1), 2, &unused));
// n_dim and input.rank are equivalent.
- const Shape* input = c->input(0);
- const Dimension* n_dim = c->Dim(paddings, 0);
+ ShapeHandle input = c->input(0);
+ DimensionHandle n_dim = c->Dim(paddings, 0);
if (c->ValueKnown(n_dim)) {
TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(n_dim), &input));
} else if (c->RankKnown(input)) {
@@ -75,7 +75,7 @@ Status PadShapeFn(InferenceContext* c) {
// paddings_t is known.
auto paddings_data = paddings_t->matrix<int32>();
- std::vector<const Dimension*> dims(num_dims);
+ std::vector<DimensionHandle> dims(num_dims);
for (int i = 0; i < num_dims; ++i) {
const int32 pad0 = paddings_data(i, 0);
const int32 pad1 = paddings_data(i, 1);
@@ -98,7 +98,7 @@ REGISTER_OP("Pack")
.Attr("axis: int = 0")
.SetShapeFn([](InferenceContext* c) {
// Validate shapes of all inputs are compatible
- const Shape* cur = c->input(c->num_inputs() - 1);
+ ShapeHandle cur = c->input(c->num_inputs() - 1);
for (int i = c->num_inputs() - 2; i >= 0; --i) {
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
"From merging shape ", i,
@@ -116,7 +116,7 @@ REGISTER_OP("Pack")
// Copy all dimensions over, inserting a dimension of value #inputs
// at <axis>.
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
int index = 0;
while (index < axis) dims.push_back(c->Dim(cur, index++));
dims.push_back(c->MakeDim(c->num_inputs()));
@@ -162,8 +162,8 @@ REGISTER_OP("Unpack")
.Attr("T: type")
.Attr("axis: int = 0")
.SetShapeFn([](InferenceContext* c) {
- const Shape* s = c->input(0);
- const Shape* out;
+ ShapeHandle s = c->input(0);
+ ShapeHandle out;
if (c->RankKnown(s)) {
// Determine the axis that will be removed, converting from negative
// axes to a positive point per negative indexing rules.
@@ -172,12 +172,12 @@ REGISTER_OP("Unpack")
TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank, &axis));
// The axis dim matches the number of outputs.
- const Dimension* unused;
+ DimensionHandle unused;
TF_RETURN_IF_ERROR(
c->WithValue(c->Dim(s, axis), c->num_outputs(), &unused));
// Copy all dimensions, removing the <axis> dimension.
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
for (int i = 0; i < rank; ++i) {
if (i != axis) dims.push_back(c->Dim(s, i));
}
@@ -272,11 +272,11 @@ REGISTER_OP("Split")
.Attr("num_split: int >= 1")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- const Dimension* split_dimension;
+ DimensionHandle split_dimension;
TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(0, &split_dimension));
int num_split = c->num_outputs();
- const Shape* input = c->input(1);
- const Shape* out;
+ ShapeHandle input = c->input(1);
+ ShapeHandle out;
if (!c->ValueKnown(split_dimension)) {
if (c->RankKnown(input)) {
out = c->UnknownShapeOfRank(c->Rank(input));
@@ -286,7 +286,7 @@ REGISTER_OP("Split")
} else {
int64 split_dim = c->Value(split_dimension);
TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
- const Dimension* split_dim_size;
+ DimensionHandle split_dim_size;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
c->Divide(c->Dim(input, split_dim), num_split, &split_dim_size),
"Number of ways to split should evenly divide the split dimension");
@@ -319,7 +319,7 @@ REGISTER_OP("Const")
TF_RETURN_IF_ERROR(c->GetAttr("value", &proto));
TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape()));
TensorShape shape(proto->tensor_shape());
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
for (int i = 0; i < shape.dims(); ++i) {
dims.push_back(c->MakeDim(shape.dim_size(i)));
}
@@ -345,7 +345,7 @@ REGISTER_OP("ImmutableConst")
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_from_attr));
TensorShapeProto shape_proto;
shape_from_attr.AsProto(&shape_proto);
- const Shape* output_shape;
+ ShapeHandle output_shape;
TF_RETURN_IF_ERROR(
c->MakeShapeFromShapeProto(shape_proto, &output_shape));
c->set_output(0, output_shape);
@@ -381,10 +381,10 @@ REGISTER_OP("Diag")
.Output("output: T")
.Attr("T: {float, double, int32, int64, complex64}")
.SetShapeFn([](InferenceContext* c) {
- const Shape* in = c->input(0);
+ ShapeHandle in = c->input(0);
TF_RETURN_IF_ERROR(c->WithRankAtMost(in, 3, &in));
// Output shape is original concatenated with itself.
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->Concatenate(in, in, &out));
c->set_output(0, out);
return Status::OK();
@@ -419,7 +419,7 @@ REGISTER_OP("DiagPart")
.Output("diagonal: T")
.Attr("T: {float, double, int32, int64, complex64}")
.SetShapeFn([](InferenceContext* c) {
- const Shape* in = c->input(0);
+ ShapeHandle in = c->input(0);
if (!c->RankKnown(in)) {
c->set_output(0, c->UnknownShape());
return Status::OK();
@@ -433,7 +433,7 @@ REGISTER_OP("DiagPart")
const int32 mid = rank / 2;
// output dim[i] is the merge of in.dim[i] and in.dim[i+mid].
- std::vector<const Dimension*> dims(mid);
+ std::vector<DimensionHandle> dims(mid);
for (int i = 0; i < mid; ++i) {
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(in, i), c->Dim(in, i + mid), &dims[i]));
@@ -474,14 +474,14 @@ REGISTER_OP("BatchMatrixDiag")
.Output("output: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* in;
+ ShapeHandle in;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &in));
if (!c->RankKnown(in)) {
c->set_output(0, c->UnknownShape());
return Status::OK();
}
const int32 rank = c->Rank(in);
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(
c->Concatenate(in, c->Vector(c->Dim(in, rank - 1)), &out));
c->set_output(0, out);
@@ -528,17 +528,17 @@ REGISTER_OP("BatchMatrixSetDiag")
.Output("output: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input;
- const Shape* diag;
+ ShapeHandle input;
+ ShapeHandle diag;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag));
- const Dimension* square_dim;
+ DimensionHandle square_dim;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(input, -2), c->Dim(input, -1), &square_dim));
TF_RETURN_IF_ERROR(c->Merge(square_dim, c->Dim(diag, -1), &square_dim));
- const Shape* output;
+ ShapeHandle output;
TF_RETURN_IF_ERROR(c->Concatenate(diag, c->Vector(square_dim), &output));
TF_RETURN_IF_ERROR(c->Merge(input, output, &output));
@@ -573,7 +573,7 @@ REGISTER_OP("BatchMatrixDiagPart")
.Output("diagonal: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* in;
+ ShapeHandle in;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &in));
if (!c->RankKnown(in)) {
c->set_output(0, c->UnknownShape());
@@ -581,12 +581,12 @@ REGISTER_OP("BatchMatrixDiagPart")
}
const int32 rank = c->Rank(in);
// Last two dims must match.
- const Dimension* unused;
+ DimensionHandle unused;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(in, rank - 1), c->Dim(in, rank - 2), &unused));
// Output shape has all dims but last of input.
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
for (int i = 0; i < rank - 1; ++i) dims.push_back(c->Dim(in, i));
c->set_output(0, c->MakeShape(dims));
return Status::OK();
@@ -695,10 +695,10 @@ REGISTER_OP("Reverse")
"T: {uint8, int8, int32, int64, bool, half, float, double, complex64, "
"complex128}")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input = c->input(0);
- const Shape* dims;
+ ShapeHandle input = c->input(0);
+ ShapeHandle dims;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
- const Dimension* dims_dim = c->Dim(dims, 0);
+ DimensionHandle dims_dim = c->Dim(dims, 0);
if (c->ValueKnown(dims_dim)) {
TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(dims_dim), &input));
}
@@ -795,7 +795,7 @@ REGISTER_OP("EditDistance")
auto h_values = hypothesis_shape_t->flat<int64>();
auto t_values = truth_shape_t->flat<int64>();
- std::vector<const Dimension*> dims(hypothesis_shape_t->NumElements() - 1);
+ std::vector<DimensionHandle> dims(hypothesis_shape_t->NumElements() - 1);
for (int i = 0; i < dims.size(); ++i) {
dims[i] = c->MakeDim(std::max(h_values(i), t_values(i)));
}
@@ -869,7 +869,7 @@ REGISTER_OP("Fill")
.Output("output: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
c->set_output(0, out);
return Status::OK();
@@ -900,12 +900,12 @@ REGISTER_OP("Gather")
.Attr("Tparams: type")
.Attr("Tindices: {int32,int64}")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused));
- const Shape* params_subshape;
+ ShapeHandle params_subshape;
TF_RETURN_IF_ERROR(c->Subshape(c->input(0), 1, &params_subshape));
- const Shape* indices_shape = c->input(1);
- const Shape* out;
+ ShapeHandle indices_shape = c->input(1);
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, params_subshape, &out));
c->set_output(0, out);
return Status::OK();
@@ -941,10 +941,10 @@ REGISTER_OP("GatherNd")
.Attr("Tparams: type")
.Attr("Tindices: {int32,int64}")
.SetShapeFn([](InferenceContext* c) {
- const Shape* params = c->input(0);
- const Shape* indices;
+ ShapeHandle params = c->input(0);
+ ShapeHandle indices;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices));
- const Dimension* r_dim = c->Dim(indices, -1);
+ DimensionHandle r_dim = c->Dim(indices, -1);
if (!c->RankKnown(params) || !c->ValueKnown(r_dim)) {
c->set_output(0, c->UnknownShape());
@@ -959,11 +959,11 @@ REGISTER_OP("GatherNd")
}
// Remove r_dim from indices to get output.
- const Shape* indices_slice;
- const Shape* params_slice;
+ ShapeHandle indices_slice;
+ ShapeHandle params_slice;
TF_RETURN_IF_ERROR(c->Subshape(indices, 0, -1, &indices_slice));
TF_RETURN_IF_ERROR(c->Subshape(params, c->Value(r_dim), &params_slice));
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->Concatenate(indices_slice, params_slice, &out));
c->set_output(0, out);
return Status::OK();
@@ -1131,8 +1131,8 @@ REGISTER_OP("Reshape")
.Output("output: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* in = c->input(0);
- const Shape* out;
+ ShapeHandle in = c->input(0);
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
// If the rank and all dimensions of the input tensor are known, we may
@@ -1141,8 +1141,8 @@ REGISTER_OP("Reshape")
// dimension information.
// Additionally, if the rank of the out shape is unknown we have no shape
// information to go off of.
- const Dimension* num_in_elems = c->NumElements(in);
- const Dimension* num_out_elems = c->NumElements(out);
+ DimensionHandle num_in_elems = c->NumElements(in);
+ DimensionHandle num_out_elems = c->NumElements(out);
if (!c->ValueKnown(num_in_elems) || !c->RankKnown(out)) {
// Do nothing. We have no shape information to infer from so we directly
// return out as our shape.
@@ -1159,9 +1159,9 @@ REGISTER_OP("Reshape")
// If we don't know the number of output elements, we can infer
// the missing dimension.
int32 unknown_idx = -1;
- const Dimension* known_elems = c->MakeDim(1);
+ DimensionHandle known_elems = c->MakeDim(1);
for (int32 i = 0; i < c->Rank(out); ++i) {
- const Dimension* dim = c->Dim(out, i);
+ DimensionHandle dim = c->Dim(out, i);
if (!c->ValueKnown(dim)) {
if (unknown_idx >= 0) {
return errors::InvalidArgument(
@@ -1173,7 +1173,7 @@ REGISTER_OP("Reshape")
TF_RETURN_IF_ERROR(c->Multiply(known_elems, dim, &known_elems));
}
}
- const Dimension* inferred_dim;
+ DimensionHandle inferred_dim;
TF_RETURN_IF_ERROR(
c->Divide(num_in_elems, c->Value(known_elems), &inferred_dim));
TF_RETURN_IF_ERROR(c->ReplaceDim(out, unknown_idx, inferred_dim, &out));
@@ -1250,7 +1250,7 @@ REGISTER_OP("InvertPermutation")
.Input("x: int32")
.Output("y: int32")
.SetShapeFn([](InferenceContext* c) {
- const Shape* x;
+ ShapeHandle x;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &x));
c->set_output(0, x);
return Status::OK();
@@ -1285,10 +1285,10 @@ REGISTER_OP("Transpose")
.Output("y: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input = c->input(0);
- const Shape* perm_shape = c->input(1);
+ ShapeHandle input = c->input(0);
+ ShapeHandle perm_shape = c->input(1);
const Tensor* perm = c->input_tensor(1);
- const Dimension* perm_elems = c->NumElements(perm_shape);
+ DimensionHandle perm_elems = c->NumElements(perm_shape);
// If we don't have rank information on the input or value information on
// perm we can't return any shape information, otherwise we have enough
// information to at least find the rank of the output.
@@ -1307,7 +1307,7 @@ REGISTER_OP("Transpose")
} else {
rank = perm->NumElements();
}
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
dims.resize(rank);
TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
// Ensure that perm is a vector and has rank elements.
@@ -1423,7 +1423,7 @@ namespace {
Status ShapeShapeFn(InferenceContext* c) {
for (int i = 0; i < c->num_inputs(); ++i) {
- const Dimension* dim;
+ DimensionHandle dim;
if (c->RankKnown(c->input(i))) {
dim = c->MakeDim(c->Rank(c->input(i)));
} else {
@@ -1477,8 +1477,8 @@ REGISTER_OP("ReverseSequence")
.Attr("batch_dim: int = 0")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input = c->input(0);
- const Shape* seq_lens_shape;
+ ShapeHandle input = c->input(0);
+ ShapeHandle seq_lens_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seq_lens_shape));
int64 seq_dim;
@@ -1501,12 +1501,12 @@ REGISTER_OP("ReverseSequence")
seq_dim, " vs. ", input_rank);
}
- const Dimension* batch_dim_dim = c->Dim(input, batch_dim);
+ DimensionHandle batch_dim_dim = c->Dim(input, batch_dim);
TF_RETURN_IF_ERROR(
c->Merge(batch_dim_dim, c->Dim(seq_lens_shape, 0), &batch_dim_dim));
// Replace batch_dim of input with batch_size
- const Shape* output_shape;
+ ShapeHandle output_shape;
TF_RETURN_IF_ERROR(
c->ReplaceDim(input, batch_dim, batch_dim_dim, &output_shape));
c->set_output(0, output_shape);
@@ -1627,11 +1627,11 @@ namespace {
template <typename T>
Status SliceHelper(InferenceContext* c, const Tensor* begin_t,
- const Tensor* sizes_t, std::vector<const Dimension*>* dims) {
+ const Tensor* sizes_t, std::vector<DimensionHandle>* dims) {
auto begin_vec = begin_t->vec<T>();
auto sizes_vec = sizes_t->vec<T>();
for (int i = 0; i < sizes_t->NumElements(); ++i) {
- const Dimension* dim = c->Dim(c->input(0), i);
+ DimensionHandle dim = c->Dim(c->input(0), i);
if (sizes_vec(i) != -1) {
if (c->ValueKnown(dim)) {
auto dim_val = c->Value(dim);
@@ -1664,7 +1664,7 @@ Status SliceHelper(InferenceContext* c, const Tensor* begin_t,
dims->emplace_back(c->MakeDim(sizes_vec(i)));
} else {
- const Dimension* result;
+ DimensionHandle result;
TF_RETURN_IF_ERROR(c->Subtract(dim, begin_vec(i), &result));
dims->emplace_back(result);
}
@@ -1684,16 +1684,16 @@ REGISTER_OP("Slice")
.Attr("T: type")
.Attr("Index: {int32,int64}")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input = c->input(0);
- const Shape* begin_shape;
+ ShapeHandle input = c->input(0);
+ ShapeHandle begin_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
- const Shape* sizes_shape;
+ ShapeHandle sizes_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));
// Merge to check compatibility of begin and sizes tensors.
TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));
- const Dimension* ndims = c->Dim(begin_shape, 0);
+ DimensionHandle ndims = c->Dim(begin_shape, 0);
if (c->ValueKnown(ndims)) {
TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
}
@@ -1702,7 +1702,7 @@ REGISTER_OP("Slice")
const Tensor* sizes_t = c->input_tensor(2);
if (sizes_t != nullptr && begin_t != nullptr) {
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
// If the begin and sizes tensors are available, then
// we can be precise about the shape of the output.
if (begin_t->dtype() == DT_INT64) {
@@ -1831,10 +1831,10 @@ REGISTER_OP("Tile")
.Output("output: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input;
- const Shape* multiples;
+ ShapeHandle input;
+ ShapeHandle multiples;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &multiples));
- const Dimension* multiples_dim0 = c->Dim(multiples, 0);
+ DimensionHandle multiples_dim0 = c->Dim(multiples, 0);
if (!c->ValueKnown(multiples_dim0)) {
// Length of multiples vector unknown, so output is unknown.
//
@@ -1856,7 +1856,7 @@ REGISTER_OP("Tile")
// Multiply each input dimension by its corresponding value
// from the multiples tensor.
auto multiples_data = multiples_t->vec<int32>();
- std::vector<const Dimension*> dims(rank);
+ std::vector<DimensionHandle> dims(rank);
for (int i = 0; i < rank; ++i) {
const int32 multiple = multiples_data(i);
TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input, i), multiple, &dims[i]));
@@ -1945,7 +1945,7 @@ REGISTER_OP("BroadcastGradientArgs")
.Output("r1: int32")
.SetShapeFn([](InferenceContext* c) {
// TODO(mrry): Implement constant_value for BroadcastGradientArgs?
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
@@ -2049,9 +2049,9 @@ REGISTER_OP("MirrorPadGrad")
.Attr("T: type")
.Attr(GetMirrorPadModeAttrString())
.SetShapeFn([](InferenceContext* c) {
- const Shape* paddings;
+ ShapeHandle paddings;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
- const Dimension* pad_0 = c->Dim(paddings, 0);
+ DimensionHandle pad_0 = c->Dim(paddings, 0);
if (!c->ValueKnown(pad_0)) {
// We don't know the rank of the output since the first
// padding dimension is unknown.
@@ -2060,7 +2060,7 @@ REGISTER_OP("MirrorPadGrad")
}
int64 input_rank = c->Value(pad_0);
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), input_rank, &input));
TF_RETURN_IF_ERROR(
c->Merge(paddings, c->Matrix(input_rank, 2), &paddings));
@@ -2075,7 +2075,7 @@ REGISTER_OP("MirrorPadGrad")
}
auto paddings_data = paddings_t->matrix<int32>();
- std::vector<const Dimension*> dims(input_rank);
+ std::vector<DimensionHandle> dims(input_rank);
for (int i = 0; i < input_rank; ++i) {
const int64 pad0 = static_cast<int64>(paddings_data(i, 0));
const int64 pad1 = static_cast<int64>(paddings_data(i, 1));
@@ -2137,7 +2137,7 @@ REGISTER_OP("Placeholder")
TensorShapeProto shape_proto;
shape.AsProto(&shape_proto);
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out));
c->set_output(0, out);
return Status::OK();
@@ -2162,17 +2162,17 @@ REGISTER_OP("PlaceholderWithDefault")
.Attr("dtype: type")
.Attr("shape: shape")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input = c->input(0);
+ ShapeHandle input = c->input(0);
PartialTensorShape shape;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
TensorShapeProto shape_proto;
shape.AsProto(&shape_proto);
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out));
// We merge for compatibility checking, but return the output,
// since output_shape may be less precise than input_shape.
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->Merge(input, out, &unused));
c->set_output(0, out);
return Status::OK();
@@ -2193,8 +2193,8 @@ REGISTER_OP("ExpandDims")
.Output("output: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input = c->input(0);
- const Shape* expand_dim;
+ ShapeHandle input = c->input(0);
+ ShapeHandle expand_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &expand_dim));
const Tensor* dim_t = c->input_tensor(1);
@@ -2207,11 +2207,11 @@ REGISTER_OP("ExpandDims")
which_dim += c->Rank(input) + 1;
}
- const Shape* end;
+ ShapeHandle end;
TF_RETURN_IF_ERROR(c->Subshape(input, which_dim, &end));
// Build output as start + 1 + end.
- const Shape* output;
+ ShapeHandle output;
TF_RETURN_IF_ERROR(c->Subshape(input, 0, which_dim, &output));
TF_RETURN_IF_ERROR(c->Concatenate(output, c->Vector(1), &output));
TF_RETURN_IF_ERROR(c->Concatenate(output, end, &output));
@@ -2265,7 +2265,7 @@ REGISTER_OP("Squeeze")
.Attr("T: type")
.Attr("squeeze_dims: list(int) >= 0 = []")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input = c->input(0);
+ ShapeHandle input = c->input(0);
if (!c->RankKnown(input)) {
// Input shape unknown.
return shape_inference::UnknownShape(c);
@@ -2287,7 +2287,7 @@ REGISTER_OP("Squeeze")
}
}
- std::vector<const Dimension*> result_shape;
+ std::vector<DimensionHandle> result_shape;
for (int i = 0; i < input_rank; ++i) {
// True if squeeze_dims contains an entry to squeeze this
// dimension.
@@ -2295,7 +2295,7 @@ REGISTER_OP("Squeeze")
std::find(squeeze_dims.begin(), squeeze_dims.end(), i) !=
squeeze_dims.end();
- const Dimension* dim = c->Dim(input, i);
+ DimensionHandle dim = c->Dim(input, i);
if (!c->ValueKnown(dim)) {
// Assume that the squeezed dimension will be 1 at runtime.
@@ -2362,11 +2362,11 @@ REGISTER_OP("ListDiff")
.Output("idx: int32")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
// TODO(mrry): Indicate that the length falls within an interval?
- const Shape* out = c->Vector(InferenceContext::kUnknownDim);
+ ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
c->set_output(0, out);
c->set_output(1, out);
return Status::OK();
@@ -2410,14 +2410,14 @@ REGISTER_OP("SpaceToBatch")
.Attr("T: type")
.Attr("block_size: int >= 2")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
- const Shape* paddings;
+ ShapeHandle paddings;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
- const Dimension* pad0_dim = c->Dim(paddings, 0);
- const Dimension* pad1_dim = c->Dim(paddings, 1);
+ DimensionHandle pad0_dim = c->Dim(paddings, 0);
+ DimensionHandle pad1_dim = c->Dim(paddings, 1);
if (!c->ValueKnown(pad0_dim) || !c->ValueKnown(pad1_dim)) {
return shape_inference::UnknownShape(c);
@@ -2433,8 +2433,8 @@ REGISTER_OP("SpaceToBatch")
int32 block_size;
TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
- const Dimension* output_height;
- const Dimension* output_width;
+ DimensionHandle output_height;
+ DimensionHandle output_width;
const Tensor* paddings_t = c->input_tensor(1);
if (paddings_t == nullptr) {
@@ -2457,7 +2457,7 @@ REGISTER_OP("SpaceToBatch")
c->Add(c->Dim(input, 2), pad_left + pad_right, &output_width));
}
- const Dimension* batch;
+ DimensionHandle batch;
TF_RETURN_IF_ERROR(
c->Multiply(c->Dim(input, 0), block_size * block_size, &batch));
@@ -2575,14 +2575,14 @@ REGISTER_OP("BatchToSpace")
.Attr("T: type")
.Attr("block_size: int >= 2")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
- const Shape* crops;
+ ShapeHandle crops;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &crops));
- const Dimension* crops0_dim = c->Dim(crops, 0);
- const Dimension* crops1_dim = c->Dim(crops, 1);
+ DimensionHandle crops0_dim = c->Dim(crops, 0);
+ DimensionHandle crops1_dim = c->Dim(crops, 1);
if (!c->ValueKnown(crops0_dim) || !c->ValueKnown(crops1_dim)) {
return shape_inference::UnknownShape(c);
@@ -2598,13 +2598,13 @@ REGISTER_OP("BatchToSpace")
int32 block_size;
TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
- const Dimension* batch;
+ DimensionHandle batch;
// Will return an error if does not evenly divide
TF_RETURN_IF_ERROR(
c->Divide(c->Dim(input, 0), block_size * block_size, &batch));
- const Dimension* output_height;
- const Dimension* output_width;
+ DimensionHandle output_height;
+ DimensionHandle output_width;
const Tensor* crops_t = c->input_tensor(1);
if (crops_t == nullptr) {
@@ -2733,15 +2733,15 @@ REGISTER_OP("SpaceToDepth")
.Attr("T: type")
.Attr("block_size: int >= 2")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
int32 block_size;
TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
- const Dimension* output_height;
- const Dimension* output_width;
- const Dimension* output_depth;
+ DimensionHandle output_height;
+ DimensionHandle output_width;
+ DimensionHandle output_depth;
// Will return an error if does not evenly divide
TF_RETURN_IF_ERROR(
c->Divide(c->Dim(input, 1), block_size, &output_height));
@@ -2840,15 +2840,15 @@ REGISTER_OP("DepthToSpace")
.Attr("T: type")
.Attr("block_size: int >= 2")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
int32 block_size;
TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
- const Dimension* output_height;
- const Dimension* output_width;
- const Dimension* output_depth;
+ DimensionHandle output_height;
+ DimensionHandle output_width;
+ DimensionHandle output_depth;
TF_RETURN_IF_ERROR(
c->Multiply(c->Dim(input, 1), block_size, &output_height));
TF_RETURN_IF_ERROR(
@@ -2956,7 +2956,7 @@ REGISTER_OP("ExtractImagePatches")
.Attr("T: realnumbertype")
.Attr(GetPaddingAttrString())
.SetShapeFn([](InferenceContext* c) {
- const Shape* input_shape;
+ ShapeHandle input_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
std::vector<int32> ksizes;
@@ -2998,10 +2998,10 @@ REGISTER_OP("ExtractImagePatches")
int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
- const Dimension* batch_size_dim = c->Dim(input_shape, 0);
- const Dimension* in_rows_dim = c->Dim(input_shape, 1);
- const Dimension* in_cols_dim = c->Dim(input_shape, 2);
- const Dimension* output_depth_dim = c->Dim(input_shape, 3);
+ DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
+ DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
+ DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
+ DimensionHandle output_depth_dim = c->Dim(input_shape, 3);
// At the moment we need to know the values of several fields.
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
@@ -3020,7 +3020,7 @@ REGISTER_OP("ExtractImagePatches")
TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
in_cols, ksize_cols_eff, stride_cols, padding, &output_cols,
&padding_before, &padding_after));
- const Shape* output_shape = c->MakeShape(
+ ShapeHandle output_shape = c->MakeShape(
{batch_size_dim, output_rows, output_cols, output_depth_dim});
c->set_output(0, output_shape);
return Status::OK();
@@ -3057,7 +3057,7 @@ REGISTER_OP("Bitcast")
.Attr("T: numbertype")
.Attr("type: numbertype")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input = c->input(0);
+ ShapeHandle input = c->input(0);
if (!c->RankKnown(input)) {
// Input shape unknown.
return shape_inference::UnknownShape(c);
@@ -3079,7 +3079,7 @@ REGISTER_OP("Bitcast")
"one of the type sizes is zero.");
}
- const Shape* new_shape;
+ ShapeHandle new_shape;
if (input_type_size == output_type_size) {
// No change in size.
new_shape = input;
@@ -3087,7 +3087,7 @@ REGISTER_OP("Bitcast")
TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 1, &new_shape));
int64 divisor_val = output_type_size / input_type_size;
- const Dimension* last_dim = c->Dim(new_shape, -1);
+ DimensionHandle last_dim = c->Dim(new_shape, -1);
if (!c->ValueKnown(last_dim) || c->Value(last_dim) == divisor_val) {
TF_RETURN_IF_ERROR(c->Subshape(new_shape, 0, -1, &new_shape));
} else {
@@ -3098,7 +3098,7 @@ REGISTER_OP("Bitcast")
} else {
// Input type size is larger than output type size.
int64 divisor_val = input_type_size / output_type_size;
- const Shape* extension = c->Vector(divisor_val);
+ ShapeHandle extension = c->Vector(divisor_val);
TF_RETURN_IF_ERROR(c->Concatenate(input, extension, &new_shape));
}
@@ -3136,10 +3136,10 @@ REGISTER_OP("OneHot")
TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
if (axis < -1) return errors::InvalidArgument("axis must be >= -1");
- const Dimension* depth;
+ DimensionHandle depth;
TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &depth));
- const Shape* indices = c->input(0);
+ ShapeHandle indices = c->input(0);
if (!c->RankKnown(indices)) return shape_inference::UnknownShape(c);
int32 new_rank = c->Rank(indices) + 1;
@@ -3147,9 +3147,9 @@ REGISTER_OP("OneHot")
// C++ returns negative values from % if the dividend is negative.
int32 depth_index = (axis + new_rank) % new_rank;
// Out shape is indices[0:depth_index] + [depth] + indices[depth_index:].
- const Shape* front;
- const Shape* back;
- const Shape* out;
+ ShapeHandle front;
+ ShapeHandle back;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->Subshape(indices, 0, depth_index, &front));
TF_RETURN_IF_ERROR(c->Subshape(indices, depth_index, &back));
TF_RETURN_IF_ERROR(c->Concatenate(front, c->Vector(depth), &front));
diff --git a/tensorflow/core/ops/candidate_sampling_ops.cc b/tensorflow/core/ops/candidate_sampling_ops.cc
index 556090231f..037c393574 100644
--- a/tensorflow/core/ops/candidate_sampling_ops.cc
+++ b/tensorflow/core/ops/candidate_sampling_ops.cc
@@ -18,9 +18,9 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::Dimension;
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
namespace {
@@ -30,11 +30,11 @@ Status CandidateSamplerShapeFn(InferenceContext* c) {
int64 num_true;
TF_RETURN_IF_ERROR(c->GetAttr("num_true", &num_true));
- const Shape* true_classes_shape;
+ ShapeHandle true_classes_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &true_classes_shape));
- const Dimension* batch_size = c->Dim(true_classes_shape, 0);
+ DimensionHandle batch_size = c->Dim(true_classes_shape, 0);
- const Shape* num_sampled_v = c->Vector(num_sampled);
+ ShapeHandle num_sampled_v = c->Vector(num_sampled);
c->set_output(0, num_sampled_v);
c->set_output(1, c->Matrix(batch_size, num_true));
c->set_output(2, num_sampled_v);
@@ -378,14 +378,14 @@ REGISTER_OP("ComputeAccidentalHits")
TF_RETURN_IF_ERROR(c->GetAttr("num_true", &num_true));
// Validate true_classes.
- const Shape* true_classes;
+ ShapeHandle true_classes;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &true_classes));
- const Dimension* unused;
+ DimensionHandle unused;
TF_RETURN_IF_ERROR(
c->WithValue(c->Dim(true_classes, 1), num_true, &unused));
// All three outputs are the same shape.
- const Shape* v = c->Vector(InferenceContext::kUnknownDim);
+ ShapeHandle v = c->Vector(InferenceContext::kUnknownDim);
c->set_output(0, v);
c->set_output(1, v);
c->set_output(2, v);
diff --git a/tensorflow/core/ops/control_flow_ops.cc b/tensorflow/core/ops/control_flow_ops.cc
index 3b1b7c63d3..3214017939 100644
--- a/tensorflow/core/ops/control_flow_ops.cc
+++ b/tensorflow/core/ops/control_flow_ops.cc
@@ -20,14 +20,14 @@ limitations under the License.
namespace tensorflow {
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
// --------------------------------------------------------------------------
namespace {
Status SwitchShape(InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
- const Shape* out = c->input(0);
+ ShapeHandle out = c->input(0);
c->set_output(0, out);
c->set_output(1, out);
return Status::OK();
@@ -85,16 +85,16 @@ REGISTER_OP("RefSelect")
.Attr("T: type")
.Attr("N: int >= 1")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
- const Shape* first_input = c->input(1);
+ ShapeHandle first_input = c->input(1);
if (!c->FullyDefined(first_input)) {
c->set_output(0, c->UnknownShape());
return Status::OK();
}
// If any inputs aren't fully defined or don't match, we return unknown.
for (int i = 2; i < c->num_inputs(); ++i) {
- const Shape* input = c->input(i);
+ ShapeHandle input = c->input(i);
if (!c->FullyDefined(input) ||
!c->Merge(first_input, input, &unused).ok()) {
c->set_output(0, c->UnknownShape());
@@ -115,13 +115,13 @@ output: The forwarded tensor.
// --------------------------------------------------------------------------
namespace {
Status MergeShape(InferenceContext* c) {
- const Shape* out = c->input(0);
+ ShapeHandle out = c->input(0);
if (!c->RankKnown(out)) {
out = c->UnknownShape();
} else {
int32 rank = c->Rank(out);
for (int i = 1; i < c->num_inputs(); ++i) {
- const Shape* input = c->input(i);
+ ShapeHandle input = c->input(i);
if (c->Rank(input) != rank) {
out = c->UnknownShape();
break;
diff --git a/tensorflow/core/ops/ctc_ops.cc b/tensorflow/core/ops/ctc_ops.cc
index 7e2313ea3a..0b58a8d817 100644
--- a/tensorflow/core/ops/ctc_ops.cc
+++ b/tensorflow/core/ops/ctc_ops.cc
@@ -18,9 +18,9 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::Dimension;
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
// CTC is Connectionist Temporal Classification. See util/ctc/ for details.
@@ -34,23 +34,23 @@ REGISTER_OP("CTCLoss")
.Output("loss: float")
.Output("gradient: float")
.SetShapeFn([](InferenceContext* c) {
- const Shape* inputs;
- const Shape* labels_indices;
- const Shape* labels_values;
- const Shape* sequence_length;
+ ShapeHandle inputs;
+ ShapeHandle labels_indices;
+ ShapeHandle labels_values;
+ ShapeHandle sequence_length;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &labels_indices));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &labels_values));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &sequence_length));
- const Dimension* unused;
+ DimensionHandle unused;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(labels_indices, 0),
c->Dim(labels_values, 0), &unused));
// Get batch size from inputs and sequence_length, and update inputs
// with the merged batch_size since it is returned.
- const Dimension* batch_size;
+ DimensionHandle batch_size;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));
TF_RETURN_IF_ERROR(c->ReplaceDim(inputs, 1, batch_size, &inputs));
@@ -89,18 +89,18 @@ REGISTER_OP("CTCGreedyDecoder")
.Output("decoded_shape: int64")
.Output("log_probability: float")
.SetShapeFn([](InferenceContext* c) {
- const Shape* inputs;
- const Shape* sequence_length;
+ ShapeHandle inputs;
+ ShapeHandle sequence_length;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length));
// Get batch size from inputs and sequence_length.
- const Dimension* batch_size;
+ DimensionHandle batch_size;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));
- const Dimension* total_decoded_outputs = c->UnknownDim();
+ DimensionHandle total_decoded_outputs = c->UnknownDim();
c->set_output(0, c->Matrix(total_decoded_outputs, 2));
c->set_output(1, c->Vector(total_decoded_outputs));
c->set_output(2, c->Vector(2));
@@ -144,14 +144,14 @@ REGISTER_OP("CTCBeamSearchDecoder")
.Output("decoded_shape: top_paths * int64")
.Output("log_probability: float")
.SetShapeFn([](InferenceContext* c) {
- const Shape* inputs;
- const Shape* sequence_length;
+ ShapeHandle inputs;
+ ShapeHandle sequence_length;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length));
// Get batch size from inputs and sequence_length.
- const Dimension* batch_size;
+ DimensionHandle batch_size;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));
@@ -166,7 +166,7 @@ REGISTER_OP("CTCBeamSearchDecoder")
for (int i = 0; i < top_paths; ++i) { // decoded_values
c->set_output(out_idx++, c->Vector(InferenceContext::kUnknownDim));
}
- const Shape* shape_v = c->Vector(2);
+ ShapeHandle shape_v = c->Vector(2);
for (int i = 0; i < top_paths; ++i) { // decoded_shape
c->set_output(out_idx++, shape_v);
}
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index 63af3e28b4..724d83b7f0 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -20,9 +20,9 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::Dimension;
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
// --------------------------------------------------------------------------
@@ -36,8 +36,8 @@ REGISTER_OP("DynamicPartition")
int64 num_partitions;
TF_RETURN_IF_ERROR(c->GetAttr("num_partitions", &num_partitions));
- const Shape* data_shape = c->input(0);
- const Shape* partitions_shape = c->input(1);
+ ShapeHandle data_shape = c->input(0);
+ ShapeHandle partitions_shape = c->input(1);
if (!c->RankKnown(partitions_shape)) {
return shape_inference::UnknownShape(c);
@@ -46,17 +46,17 @@ REGISTER_OP("DynamicPartition")
const int64 rank = c->Rank(partitions_shape);
// data shape must start with partitions_shape
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(
c->MergePrefix(data_shape, partitions_shape, &unused, &unused));
// The partition shape is dynamic in the 0th dimension, and matches
// data_shape in the remaining dimensions.
- const Shape* unknown_dim0 = c->MakeShape({c->UnknownDim()});
+ ShapeHandle unknown_dim0 = c->MakeShape({c->UnknownDim()});
- const Shape* data_suffix_shape;
+ ShapeHandle data_suffix_shape;
TF_RETURN_IF_ERROR(c->Subshape(data_shape, rank, &data_suffix_shape));
- const Shape* result_shape;
+ ShapeHandle result_shape;
TF_RETURN_IF_ERROR(
c->Concatenate(unknown_dim0, data_suffix_shape, &result_shape));
@@ -115,10 +115,10 @@ REGISTER_OP("DynamicStitch")
int64 num_partitions;
TF_RETURN_IF_ERROR(c->GetAttr("N", &num_partitions));
- const Shape* extra_shape = c->UnknownShape();
+ ShapeHandle extra_shape = c->UnknownShape();
for (int i = 0; i < num_partitions; ++i) {
- const Shape* indices_shape = c->input(i);
- const Shape* data_shape = c->input(i + num_partitions);
+ ShapeHandle indices_shape = c->input(i);
+ ShapeHandle data_shape = c->input(i + num_partitions);
if (!c->RankKnown(indices_shape)) {
continue;
}
@@ -126,17 +126,17 @@ REGISTER_OP("DynamicStitch")
const int64 indices_rank = c->Rank(indices_shape);
// Assert that data_shape starts with indices_shape.
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(
c->MergePrefix(data_shape, indices_shape, &unused, &unused));
// The rest belongs to output.
- const Shape* rest;
+ ShapeHandle rest;
TF_RETURN_IF_ERROR(c->Subshape(data_shape, indices_rank, &rest));
TF_RETURN_IF_ERROR(c->Merge(extra_shape, rest, &extra_shape));
}
- const Shape* output_shape = c->Vector(c->UnknownDim());
+ ShapeHandle output_shape = c->Vector(c->UnknownDim());
TF_RETURN_IF_ERROR(
c->Concatenate(output_shape, extra_shape, &output_shape));
c->set_output(0, output_shape);
@@ -547,7 +547,7 @@ REGISTER_OP("TensorArray")
.Output("handle: Ref(string)")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
c->set_output(0, c->Vector(2));
return Status::OK();
@@ -576,8 +576,8 @@ REGISTER_OP("TensorArrayGrad")
.Attr("source: string")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
- const Dimension* unused_dim;
+ ShapeHandle unused;
+ DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
c->set_output(0, c->Vector(2));
@@ -637,8 +637,8 @@ REGISTER_OP("TensorArrayWrite")
.Output("flow_out: float")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
- const Dimension* unused_dim;
+ ShapeHandle unused;
+ DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
@@ -662,8 +662,8 @@ REGISTER_OP("TensorArrayRead")
.Output("value: dtype")
.Attr("dtype: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
- const Dimension* unused_dim;
+ ShapeHandle unused;
+ DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
@@ -686,8 +686,8 @@ REGISTER_OP("TensorArrayPack")
.Attr("dtype: type")
.Attr("element_shape: shape = { unknown_rank: true }")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
- const Dimension* unused_dim;
+ ShapeHandle unused;
+ DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
@@ -715,8 +715,8 @@ REGISTER_OP("TensorArrayUnpack")
.Output("flow_out: float")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
- const Dimension* unused_dim;
+ ShapeHandle unused;
+ DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
@@ -739,8 +739,8 @@ REGISTER_OP("TensorArrayConcat")
.Attr("dtype: type")
.Attr("element_shape_except0: shape = { unknown_rank: true }")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
- const Dimension* unused_dim;
+ ShapeHandle unused;
+ DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
@@ -785,8 +785,8 @@ REGISTER_OP("TensorArraySplit")
.Output("flow_out: float")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
- const Dimension* unused_dim;
+ ShapeHandle unused;
+ DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
@@ -827,8 +827,8 @@ REGISTER_OP("TensorArraySize")
.Input("flow_in: float")
.Output("size: int32")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
- const Dimension* unused_dim;
+ ShapeHandle unused;
+ DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
return shape_inference::ScalarShape(c);
@@ -844,8 +844,8 @@ size: The current size of the TensorArray.
REGISTER_OP("TensorArrayClose")
.Input("handle: Ref(string)")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
- const Dimension* unused_dim;
+ ShapeHandle unused;
+ DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
return Status::OK();
@@ -900,9 +900,9 @@ REGISTER_OP("BarrierInsertMany")
.Attr("T: type")
.Attr("component_index: int")
.SetShapeFn([](InferenceContext* c) {
- const Shape* keys = c->input(1);
- const Shape* values = c->input(2);
- const Shape* unused;
+ ShapeHandle keys = c->input(1);
+ ShapeHandle values = c->input(2);
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(keys, 1, &keys));
TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
@@ -1016,7 +1016,7 @@ REGISTER_OP("LookupTableFind")
.Attr("Tin: type")
.Attr("Tout: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
c->set_output(0, c->UnknownShape());
@@ -1044,7 +1044,7 @@ REGISTER_OP("LookupTableInsert")
.Attr("Tin: type")
.Attr("Tout: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->input(2), &unused));
return Status::OK();
@@ -1064,7 +1064,7 @@ REGISTER_OP("LookupTableSize")
.Input("table_handle: Ref(string)")
.Output("size: int64")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
c->set_output(0, c->Scalar());
return Status::OK();
@@ -1083,12 +1083,12 @@ REGISTER_OP("LookupTableExport")
.Attr("Tkeys: type")
.Attr("Tvalues: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
- const Shape* values = c->UnknownShape();
+ ShapeHandle values = c->UnknownShape();
TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
- const Shape* keys = c->Vector(c->Dim(values, 0));
+ ShapeHandle keys = c->Vector(c->Dim(values, 0));
c->set_output(0, keys);
c->set_output(1, values);
return Status::OK();
@@ -1108,7 +1108,7 @@ REGISTER_OP("LookupTableImport")
.Attr("Tin: type")
.Attr("Tout: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->input(2), &unused));
return Status::OK();
@@ -1211,9 +1211,9 @@ REGISTER_OP("InitializeTable")
.Attr("Tkey: type")
.Attr("Tval: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
- const Shape* keys;
+ ShapeHandle keys;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
return Status::OK();
@@ -1234,7 +1234,7 @@ REGISTER_OP("InitializeTableFromTextFile")
.Attr("vocab_size: int >= -1 = -1")
.Attr("delimiter: string = '\t'")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
return Status::OK();
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 5a55493517..7eb380798c 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -19,26 +19,26 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::Dimension;
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
namespace {
// Sets output[0] to shape [batch_dim,height,width,channel_dim], where
// height and width come from the size_tensor.
-Status SetOutputToSizedImage(InferenceContext* c, const Dimension* batch_dim,
- int size_input_idx, const Dimension* channel_dim) {
+Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
+ int size_input_idx, DimensionHandle channel_dim) {
// Verify shape of size input.
- const Shape* size;
+ ShapeHandle size;
TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size));
- const Dimension* unused;
+ DimensionHandle unused;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused));
// Get size values from the size tensor.
const Tensor* size_tensor = c->input_tensor(size_input_idx);
- const Dimension* width;
- const Dimension* height;
+ DimensionHandle width;
+ DimensionHandle height;
if (size_tensor == nullptr) {
width = c->UnknownDim();
height = c->UnknownDim();
@@ -51,16 +51,16 @@ Status SetOutputToSizedImage(InferenceContext* c, const Dimension* batch_dim,
}
Status ResizeShapeFn(InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
return SetOutputToSizedImage(c, c->Dim(input, 0), 1 /* size_input_idx */,
c->Dim(input, 3));
}
Status DecodeImageShapeFn(InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
- const Dimension* channels_dim;
+ DimensionHandle channels_dim;
int32 channels;
Status s = c->GetAttr("channels", &channels);
if (s.ok()) {
@@ -79,20 +79,20 @@ Status DecodeImageShapeFn(InferenceContext* c) {
}
Status EncodeImageShapeFn(InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &unused));
c->set_output(0, c->Scalar());
return Status::OK();
}
Status ColorspaceShapeFn(InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
// The last dimension value is always 3.
- const Dimension* last_dim;
+ DimensionHandle last_dim;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input, -1), 3, &last_dim));
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->ReplaceDim(input, -1, last_dim, &out));
c->set_output(0, out);
@@ -224,10 +224,10 @@ REGISTER_OP("ResizeNearestNeighborGrad")
.Attr("T: {uint8, int8, int32, half, float, double}")
.Attr("align_corners: bool = false")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
- const Shape* unused;
- const Dimension* unused_dim;
+ ShapeHandle unused;
+ DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(unused, 0), 2, &unused_dim));
const Tensor* size = c->input_tensor(1);
@@ -665,15 +665,15 @@ REGISTER_OP("ExtractGlimpse")
.Attr("normalized: bool = true")
.Attr("uniform_noise: bool = true")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
- const Shape* offsets;
+ ShapeHandle offsets;
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &offsets));
- const Dimension* batch_dim;
+ DimensionHandle batch_dim;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim));
- const Dimension* unused;
+ DimensionHandle unused;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(offsets, 1), 2, &unused));
return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */,
@@ -734,20 +734,20 @@ REGISTER_OP("CropAndResize")
.Attr("extrapolation_value: float = 0")
.SetShapeFn([](InferenceContext* c) {
// Get inputs and validate ranks.
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
- const Shape* boxes;
+ ShapeHandle boxes;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &boxes));
- const Shape* box_ind;
+ ShapeHandle box_ind;
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &box_ind));
// boxes[0] and box_ind[0] are both num_boxes.
- const Dimension* num_boxes_dim;
+ DimensionHandle num_boxes_dim;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(boxes, 0), c->Dim(box_ind, 0), &num_boxes_dim));
// boxes.dim(1) is 4.
- const Dimension* unused;
+ DimensionHandle unused;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
return SetOutputToSizedImage(c, num_boxes_dim, 3 /* size_input_idx */,
@@ -797,7 +797,7 @@ REGISTER_OP("CropAndResizeGradImage")
.Attr("T: {float, half, double}")
.Attr("method: {'bilinear'} = 'bilinear'")
.SetShapeFn([](InferenceContext* c) {
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(3, &out));
TF_RETURN_IF_ERROR(c->WithRank(out, 4, &out));
c->set_output(0, out);
diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc
index 1d528660cf..83e4a83897 100644
--- a/tensorflow/core/ops/io_ops.cc
+++ b/tensorflow/core/ops/io_ops.cc
@@ -19,14 +19,14 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::Dimension;
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
namespace {
Status ScalarInputsAndOutputs(InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
for (int i = 0; i < c->num_inputs(); ++i) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
}
@@ -44,9 +44,9 @@ REGISTER_OP("Save")
.Input("data: T")
.Attr("T: list(type)")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
- const Shape* s;
- const Dimension* unused_dim;
+ ShapeHandle unused;
+ ShapeHandle s;
+ DimensionHandle unused_dim;
// Validate filename.
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
@@ -79,9 +79,9 @@ REGISTER_OP("SaveSlices")
.Input("data: T")
.Attr("T: list(type)")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
- const Shape* s;
- const Dimension* unused_dim;
+ ShapeHandle unused;
+ ShapeHandle s;
+ DimensionHandle unused_dim;
// Validate filename.
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
@@ -136,7 +136,7 @@ REGISTER_OP("Restore")
.Attr("dt: type")
.Attr("preferred_shard: int = -1")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
c->set_output(0, c->UnknownShape());
@@ -180,7 +180,7 @@ REGISTER_OP("RestoreSlice")
.Attr("dt: type")
.Attr("preferred_shard: int = -1")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
@@ -353,11 +353,11 @@ REGISTER_OP("ReaderReadUpTo")
.Output("keys: string")
.Output("values: string")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
- const Shape* out = c->Vector(InferenceContext::kUnknownDim);
+ ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
c->set_output(0, out);
c->set_output(1, out);
return Status::OK();
@@ -451,7 +451,7 @@ REGISTER_OP("MatchingFiles")
.Input("pattern: string")
.Output("filenames: string")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
return Status::OK();
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index 54b8e22b7e..d7582bd4cb 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -18,62 +18,61 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::Dimension;
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
namespace {
// Return in <out> the result of making <s> a square matrix.
-Status MakeSquareMatrix(InferenceContext* c, const Shape* s,
- const Shape** out) {
+Status MakeSquareMatrix(InferenceContext* c, ShapeHandle s, ShapeHandle* out) {
TF_RETURN_IF_ERROR(c->WithRank(s, 2, &s));
- const Dimension* d;
+ DimensionHandle d;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(s, 0), c->Dim(s, 1), &d));
*out = c->Matrix(d, d);
return Status::OK();
}
Status UnchangedSquareShapeFn(InferenceContext* c) {
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(MakeSquareMatrix(c, c->input(0), &out));
c->set_output(0, out);
return Status::OK();
}
// Return in <out> the result of making the end of <s> a square matrix.
-Status MakeBatchSquareMatrix(InferenceContext* c, const Shape* input,
- const Shape** out) {
- const Shape* s;
+Status MakeBatchSquareMatrix(InferenceContext* c, ShapeHandle input,
+ ShapeHandle* out) {
+ ShapeHandle s;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &s));
- const Dimension* d;
+ DimensionHandle d;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(s, -2), c->Dim(s, -1), &d));
- const Shape* batch_shape;
+ ShapeHandle batch_shape;
TF_RETURN_IF_ERROR(c->Subshape(s, 0, -2, &batch_shape));
TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(d, d), out));
return Status::OK();
}
Status BatchUnchangedSquareShapeFn(InferenceContext* c) {
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &out));
c->set_output(0, out);
return Status::OK();
}
Status SquareMatrixSolveShapeFn(InferenceContext* c) {
- const Shape* lhs;
- const Shape* rhs;
+ ShapeHandle lhs;
+ ShapeHandle rhs;
TF_RETURN_IF_ERROR(MakeSquareMatrix(c, c->input(0), &lhs));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &rhs));
// lhs and rhs have the same number of rows. Make a new output
// shape that uses rows to replace rhs.dim[0].
- const Dimension* rows;
+ DimensionHandle rows;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, 0), c->Dim(rhs, 0), &rows));
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->ReplaceDim(rhs, 0, rows, &out));
c->set_output(0, out);
return Status::OK();
@@ -82,8 +81,8 @@ Status SquareMatrixSolveShapeFn(InferenceContext* c) {
// Inputs are [...,M,N] and [...,M,K]. Output is [...,N,K].
// If <square>, then input is [...,M,M].
Status BatchMatrixSolveShapeFn(InferenceContext* c, bool square) {
- const Shape* lhs;
- const Shape* rhs;
+ ShapeHandle lhs;
+ ShapeHandle rhs;
if (square) {
TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
} else {
@@ -92,44 +91,44 @@ Status BatchMatrixSolveShapeFn(InferenceContext* c, bool square) {
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));
// Make the common batch subshape between the two dimensions.
- const Shape* lhs_batch_shape;
- const Shape* batch_shape;
+ ShapeHandle lhs_batch_shape;
+ ShapeHandle batch_shape;
TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &batch_shape));
TF_RETURN_IF_ERROR(c->Merge(lhs_batch_shape, batch_shape, &batch_shape));
// lhs and rhs have the same value for m.
- const Dimension* m;
+ DimensionHandle m;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -2), c->Dim(rhs, -2), &m));
- const Dimension* n = c->Dim(lhs, -1);
+ DimensionHandle n = c->Dim(lhs, -1);
if (square) {
TF_RETURN_IF_ERROR(c->Merge(m, n, &n));
}
// Build final shape (batch_shape + n + k) in <out>.
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &out));
TF_RETURN_IF_ERROR(c->Concatenate(out, c->Vector(c->Dim(rhs, -1)), &out));
c->set_output(0, out);
return Status::OK();
}
-Status BatchSvdShapeHelperFn(InferenceContext* c, const Shape* input) {
- const Dimension* m = c->Dim(input, -2);
- const Dimension* n = c->Dim(input, -1);
- const Dimension* p;
+Status BatchSvdShapeHelperFn(InferenceContext* c, ShapeHandle input) {
+ DimensionHandle m = c->Dim(input, -2);
+ DimensionHandle n = c->Dim(input, -1);
+ DimensionHandle p;
TF_RETURN_IF_ERROR(c->Min(m, n, &p));
- const Shape* batch_shape;
+ ShapeHandle batch_shape;
TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
- const Shape* e_shape;
+ ShapeHandle e_shape;
TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(p), &e_shape));
c->set_output(0, e_shape);
bool compute_uv;
TF_RETURN_IF_ERROR(c->GetAttr("compute_uv", &compute_uv));
if (compute_uv) {
- const Shape* u_shape;
- const Shape* v_shape;
+ ShapeHandle u_shape;
+ ShapeHandle v_shape;
bool full_matrices;
TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
if (full_matrices) {
@@ -159,7 +158,7 @@ Status BatchSvdShapeHelperFn(InferenceContext* c, const Shape* input) {
// [M,P]; [N,P], if compute_uv is true and full_matrices is false,
// where P = min(M,N).
Status SvdShapeFn(InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
return BatchSvdShapeHelperFn(c, input);
}
@@ -171,7 +170,7 @@ Status SvdShapeFn(InferenceContext* c) {
// [...,M,P]; [...,N,P], if compute_uv is true and full_matrices is false,
// where P = min(M,N).
Status BatchSvdShapeFn(InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
return BatchSvdShapeHelperFn(c, input);
}
@@ -180,9 +179,9 @@ Status BatchSvdShapeFn(InferenceContext* c) {
// [N];[0], if compute_v is false,
// [N];[N,N], if compute_v is true.
Status SelfAdjointEigV2ShapeFn(InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(MakeSquareMatrix(c, c->input(0), &input));
- const Dimension* n;
+ DimensionHandle n;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, 0), c->Dim(input, 1), &n));
c->set_output(0, c->Vector(n));
bool compute_v;
@@ -199,19 +198,19 @@ Status SelfAdjointEigV2ShapeFn(InferenceContext* c) {
// [...,N];[0], if compute_v is false,
// [...,N];[...,N,N], if compute_v is true.
Status BatchSelfAdjointEigV2ShapeFn(InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));
- const Dimension* n;
+ DimensionHandle n;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n));
- const Shape* batch_shape;
+ ShapeHandle batch_shape;
TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
- const Shape* e_shape;
+ ShapeHandle e_shape;
TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &e_shape));
c->set_output(0, e_shape);
bool compute_v;
TF_RETURN_IF_ERROR(c->GetAttr("compute_v", &compute_v));
if (compute_v) {
- const Shape* v_shape;
+ ShapeHandle v_shape;
TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape));
c->set_output(1, v_shape);
} else {
@@ -227,7 +226,7 @@ REGISTER_OP("MatrixDeterminant")
.Output("output: T")
.Attr("T: {float, double}")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(MakeSquareMatrix(c, c->input(0), &input));
c->set_output(0, c->Scalar());
return Status::OK();
@@ -244,14 +243,14 @@ REGISTER_OP("BatchMatrixDeterminant")
.Output("output: T")
.Attr("T: {float, double}")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
- const Dimension* unused;
+ DimensionHandle unused;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused));
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
c->set_output(0, out);
return Status::OK();
@@ -395,11 +394,11 @@ REGISTER_OP("SelfAdjointEig")
.Attr("T: {double, float}")
.Deprecated(11, "Use SelfAdjointEigV2 instead.")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(MakeSquareMatrix(c, c->input(0), &input));
- const Dimension* d = c->Dim(input, 0);
- const Dimension* d_plus_1;
+ DimensionHandle d = c->Dim(input, 0);
+ DimensionHandle d_plus_1;
TF_RETURN_IF_ERROR(c->Add(d, 1, &d_plus_1));
c->set_output(0, c->Matrix(d_plus_1, d));
return Status::OK();
@@ -423,14 +422,14 @@ REGISTER_OP("BatchSelfAdjointEig")
.Attr("T: {double, float}")
.Deprecated(11, "Use BatchSelfAdjointEigV2 instead.")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));
- const Dimension* d = c->Dim(input, -1);
- const Dimension* d_plus_1;
+ DimensionHandle d = c->Dim(input, -1);
+ DimensionHandle d_plus_1;
TF_RETURN_IF_ERROR(c->Add(d, 1, &d_plus_1));
- const Shape* s;
+ ShapeHandle s;
TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s));
TF_RETURN_IF_ERROR(c->Concatenate(s, c->Matrix(d_plus_1, d), &s));
c->set_output(0, s);
@@ -627,13 +626,13 @@ REGISTER_OP("MatrixSolveLs")
.Attr("T: {double, float}")
.Attr("fast: bool = True")
.SetShapeFn([](InferenceContext* c) {
- const Shape* lhs;
- const Shape* rhs;
+ ShapeHandle lhs;
+ ShapeHandle rhs;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &lhs));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &rhs));
// The matrix and right-hand side must have the same number of rows.
- const Dimension* unused;
+ DimensionHandle unused;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, 0), c->Dim(rhs, 0), &unused));
c->set_output(0, c->Matrix(c->Dim(lhs, 1), c->Dim(rhs, 1)));
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 6b5c450f0e..4d42ab4b47 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -20,9 +20,9 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::Dimension;
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
REGISTER_OP("AddN")
.Input("inputs: N * T")
@@ -32,7 +32,7 @@ REGISTER_OP("AddN")
.SetIsCommutative()
.SetIsAggregate()
.SetShapeFn([](InferenceContext* c) {
- const Shape* cur = c->input(c->num_inputs() - 1);
+ ShapeHandle cur = c->input(c->num_inputs() - 1);
for (int i = c->num_inputs() - 2; i >= 0; --i) {
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
"From merging shape ", i,
@@ -51,8 +51,8 @@ namespace {
// Shape inference function for binary operators that broadcast their inputs.
Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
- const Shape* shape_x = c->input(0);
- const Shape* shape_y = c->input(1);
+ ShapeHandle shape_x = c->input(0);
+ ShapeHandle shape_y = c->input(1);
if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
c->set_output(0, c->UnknownShape());
return Status::OK();
@@ -64,8 +64,8 @@ Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
// To compute the broadcast dimensions, we zip together shape_x and shape_y
// and
// pad with 1 to make them the same length.
- std::vector<const Dimension*> dims;
- const Dimension* dim_one = rank_x == rank_y ? nullptr : c->MakeDim(1);
+ std::vector<DimensionHandle> dims;
+ DimensionHandle dim_one = rank_x == rank_y ? nullptr : c->MakeDim(1);
for (int i = 0; i < rank_out; ++i) {
const auto* dim_x = i < (rank_out - rank_x)
? dim_one
@@ -103,7 +103,7 @@ Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
dims.push_back(dim_x);
}
} else {
- const Dimension* dim;
+ DimensionHandle dim;
TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim));
dims.push_back(dim);
}
@@ -125,8 +125,8 @@ REGISTER_OP("BatchMatMul")
.Attr("adj_x: bool = false")
.Attr("adj_y: bool = false")
.SetShapeFn([](InferenceContext* c) {
- const Shape* a_shape;
- const Shape* b_shape;
+ ShapeHandle a_shape;
+ ShapeHandle b_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &a_shape));
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 3, &b_shape));
@@ -135,23 +135,23 @@ REGISTER_OP("BatchMatMul")
bool adj_y;
TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
- const Dimension* output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
- const Dimension* output_cols = c->Dim(b_shape, adj_y ? -2 : -1);
+ DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
+ DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);
// Batch dims match between inputs.
- const Shape* a_batch_dims;
- const Shape* b_batch_dims;
- const Shape* batch_dims;
+ ShapeHandle a_batch_dims;
+ ShapeHandle b_batch_dims;
+ ShapeHandle batch_dims;
TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));
// Assert inner dims match.
- const Dimension* unused;
+ DimensionHandle unused;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
c->Dim(b_shape, adj_y ? -1 : -2), &unused));
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->Concatenate(
batch_dims, c->Matrix(output_rows, output_cols), &out));
c->set_output(0, out);
@@ -814,8 +814,8 @@ REGISTER_OP("Select")
.Output("output: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* cond = c->input(0);
- const Shape* data = c->input(1);
+ ShapeHandle cond = c->input(0);
+ ShapeHandle data = c->input(1);
TF_RETURN_IF_ERROR(c->Merge(data, c->input(2), &data));
// Validate condition's shape if possible.
@@ -830,12 +830,12 @@ REGISTER_OP("Select")
if (c->Rank(cond) == 1) {
// Must be a vector whose first dimension matches first dimension
// of the data vectors.
- const Dimension* merged_dim;
+ DimensionHandle merged_dim;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(data, 0), c->Dim(cond, 0), &merged_dim));
if (merged_dim != c->Dim(data, 0)) {
// Merging used the cond dim. Update data to refer to it.
- std::vector<const Dimension*> dims{merged_dim};
+ std::vector<DimensionHandle> dims{merged_dim};
for (int i = 1; i < data_rank; ++i) {
dims.push_back(c->Dim(data, i));
}
@@ -1075,10 +1075,10 @@ output: The reduced tensor.
namespace {
Status ArgOpShape(shape_inference::InferenceContext* c) {
- const Shape* dimension_shape;
+ ShapeHandle dimension_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &dimension_shape));
- const Shape* input_shape = c->input(0);
+ ShapeHandle input_shape = c->input(0);
if (!c->RankKnown(input_shape)) {
return shape_inference::UnknownShape(c);
}
@@ -1094,7 +1094,7 @@ Status ArgOpShape(shape_inference::InferenceContext* c) {
// We don't know the value of the dimension, but we
// know the rank of the input, so return the correct
// rank with unknown dimensions.
- std::vector<const Dimension*> dims(input_rank - 1);
+ std::vector<DimensionHandle> dims(input_rank - 1);
for (int i = 0; i < dims.size(); ++i) {
dims[i] = c->UnknownDim();
}
@@ -1112,7 +1112,7 @@ Status ArgOpShape(shape_inference::InferenceContext* c) {
}
// Return the input shape without the dimension being reduced.
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
for (int i = 0; i < input_rank; ++i) {
if (dimension_val != i) {
dims.emplace_back(c->Dim(input_shape, i));
@@ -1153,15 +1153,15 @@ dimension: int32, 0 <= dimension < rank(input). Describes which dimension
namespace {
Status SegmentReductionShapeFn(InferenceContext* c) {
- const Shape* data_shape;
- const Shape* segment_ids_shape;
+ ShapeHandle data_shape;
+ ShapeHandle segment_ids_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &segment_ids_shape));
- const Shape* subshape;
+ ShapeHandle subshape;
TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(
c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
c->set_output(0, out);
@@ -1169,23 +1169,23 @@ Status SegmentReductionShapeFn(InferenceContext* c) {
}
Status SparseSegmentReductionShapeFn(InferenceContext* c) {
- const Shape* data_shape;
+ ShapeHandle data_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
- const Shape* indices_shape;
+ ShapeHandle indices_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
- const Shape* segment_ids_shape;
+ ShapeHandle segment_ids_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));
// indices and segment_ids should merge cleanly.
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));
- const Shape* subshape;
+ ShapeHandle subshape;
TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(
c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
c->set_output(0, out);
@@ -1193,24 +1193,24 @@ Status SparseSegmentReductionShapeFn(InferenceContext* c) {
}
Status SparseSegmentReductionGradShapeFn(InferenceContext* c) {
- const Shape* data_shape;
+ ShapeHandle data_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
- const Shape* indices_shape;
+ ShapeHandle indices_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
// indices and segment_ids should merge cleanly.
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->Merge(c->input(2), indices_shape, &unused));
// output_dim0 should be a scalar
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
- const Shape* subshape;
+ ShapeHandle subshape;
TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
const Tensor* dim0 = c->input_tensor(3);
- const Shape* dim0_shape;
+ ShapeHandle dim0_shape;
if (dim0 == nullptr) {
// We don't have the value at inference time, so the output
// shape is unknown.
@@ -1224,7 +1224,7 @@ Status SparseSegmentReductionGradShapeFn(InferenceContext* c) {
dim0_shape = c->Vector(dim0_value);
}
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->Concatenate(dim0_shape, subshape, &out));
c->set_output(0, out);
return Status::OK();
@@ -1384,12 +1384,12 @@ REGISTER_OP("UnsortedSegmentSum")
.Attr("T: numbertype")
.Attr("Tindices: {int32,int64}")
.SetShapeFn([](InferenceContext* c) {
- const Shape* s_data = c->input(0);
- const Shape* s_segment_ids = c->input(1);
- const Shape* s_num_segments = c->input(2);
+ ShapeHandle s_data = c->input(0);
+ ShapeHandle s_segment_ids = c->input(1);
+ ShapeHandle s_num_segments = c->input(2);
TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
- const Shape* out;
+ ShapeHandle out;
// Leading dimensions of data must be compatible with dimensions of
// <s_segment_ids>.
@@ -1398,11 +1398,11 @@ REGISTER_OP("UnsortedSegmentSum")
c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
// Get the value of the num_segments input tensor.
- const Dimension* num_segments_dim;
+ DimensionHandle num_segments_dim;
TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
// Output is {segment_id_rank} + s_data[segment_id_rank:].
- const Shape* s_data_suffix;
+ ShapeHandle s_data_suffix;
TF_RETURN_IF_ERROR(
c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
TF_RETURN_IF_ERROR(
@@ -1629,7 +1629,7 @@ REGISTER_OP("Range")
.Input("delta: int32")
.Output("output: int32")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
" for 'start'");
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
@@ -1685,7 +1685,7 @@ REGISTER_OP("LinSpace")
.Output("output: T")
.Attr("T: {float, double}")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
" for 'start'");
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 38556f0e35..affbd26966 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -22,9 +22,9 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::Dimension;
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
namespace {
@@ -33,7 +33,7 @@ namespace {
// unknown dims.
Status InputTensorShapeOrUnknown(InferenceContext* c, int input_idx,
int ndims) {
- const Shape* out;
+ ShapeHandle out;
const Tensor* input = c->input_tensor(input_idx);
if (input == nullptr) {
out = c->UnknownShapeOfRank(ndims);
@@ -122,17 +122,17 @@ REGISTER_OP("BatchNormWithGlobalNormalization")
.Attr("scale_after_normalization: bool")
.Deprecated(9, "Use tf.nn.batch_normalization()")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
- const Dimension* last_dim = c->Dim(input, 3);
+ DimensionHandle last_dim = c->Dim(input, 3);
for (int i = 1; i < 5; ++i) { // covers m, v, beta, gamma
- const Shape* vec;
+ ShapeHandle vec;
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
}
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out));
c->set_output(0, out);
return Status::OK();
@@ -175,23 +175,23 @@ REGISTER_OP("BatchNormWithGlobalNormalizationGrad")
.Attr("scale_after_normalization: bool")
.Deprecated(9, "Use tf.nn.batch_normalization()")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
TF_RETURN_IF_ERROR(
c->Merge(input, c->input(4), &input)); // with backprop
- const Dimension* last_dim = c->Dim(input, 3);
+ DimensionHandle last_dim = c->Dim(input, 3);
for (int i = 1; i < 4; ++i) { // covers m, v, gamma
- const Shape* vec;
+ ShapeHandle vec;
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
}
- const Shape* dx;
+ ShapeHandle dx;
TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &dx));
c->set_output(0, dx);
- const Shape* vector_shape = c->Vector(last_dim);
+ ShapeHandle vector_shape = c->Vector(last_dim);
c->set_output(1, vector_shape);
c->set_output(2, vector_shape);
c->set_output(3, vector_shape);
@@ -586,7 +586,7 @@ REGISTER_OP("Conv3DBackpropFilter")
.Attr(GetPaddingAttrString())
.Deprecated(10, "Use Conv3DBackpropFilterV2")
.SetShapeFn([](InferenceContext* c) {
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &out));
c->set_output(0, out);
return Status::OK();
@@ -614,7 +614,7 @@ REGISTER_OP("Conv3DBackpropInputV2")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.SetShapeFn([](InferenceContext* c) {
- const Shape* s;
+ ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
c->set_output(0, s);
@@ -645,7 +645,7 @@ REGISTER_OP("Conv3DBackpropFilterV2")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.SetShapeFn([](InferenceContext* c) {
- const Shape* s;
+ ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
c->set_output(0, s);
@@ -698,7 +698,7 @@ REGISTER_OP("AvgPool3DGrad")
.Attr(GetPaddingAttrString())
.Attr("T: numbertype")
.SetShapeFn([](InferenceContext* c) {
- const Shape* s;
+ ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
c->set_output(0, s);
@@ -829,7 +829,7 @@ REGISTER_OP("LRNGrad")
.Attr("beta: float = 0.5")
.Attr("T: {float, half} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
- const Shape* s;
+ ShapeHandle s;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s)); // input_grads
TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // input_image
TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // output_image
@@ -975,9 +975,9 @@ REGISTER_OP("Dilation2D")
.Attr("rates: list(int) >= 4")
.Attr(GetPaddingAttrString())
.SetShapeFn([](InferenceContext* c) {
- const Shape* input_shape;
+ ShapeHandle input_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
- const Shape* filter_shape;
+ ShapeHandle filter_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &filter_shape));
std::vector<int32> strides;
@@ -1004,14 +1004,14 @@ REGISTER_OP("Dilation2D")
int32 rate_rows = rates[1];
int32 rate_cols = rates[2];
- const Dimension* batch_size_dim = c->Dim(input_shape, 0);
- const Dimension* in_rows_dim = c->Dim(input_shape, 1);
- const Dimension* in_cols_dim = c->Dim(input_shape, 2);
- const Dimension* filter_rows_dim = c->Dim(filter_shape, 0);
- const Dimension* filter_cols_dim = c->Dim(filter_shape, 1);
- const Dimension* output_depth_dim = c->Dim(filter_shape, 2);
+ DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
+ DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
+ DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
+ DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
+ DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
+ DimensionHandle output_depth_dim = c->Dim(filter_shape, 2);
- const Dimension* unused;
+ DimensionHandle unused;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(input_shape, 3), output_depth_dim, &unused));
@@ -1040,7 +1040,7 @@ REGISTER_OP("Dilation2D")
in_cols, filter_cols_eff, stride_cols, padding, &output_cols,
&padding_before, &padding_after));
- const Shape* output_shape = c->MakeShape(
+ ShapeHandle output_shape = c->MakeShape(
{batch_size_dim, output_rows, output_cols, output_depth_dim});
c->set_output(0, output_shape);
return Status::OK();
@@ -1305,11 +1305,11 @@ REGISTER_OP("SoftmaxCrossEntropyWithLogits")
.Output("backprop: T")
.Attr("T: {half, float, double}")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
TF_RETURN_IF_ERROR(c->Merge(input, c->input(1), &input));
- const Dimension* batch_size = c->Dim(input, 0);
+ DimensionHandle batch_size = c->Dim(input, 0);
c->set_output(0, c->Vector(batch_size));
c->set_output(1, input);
return Status::OK();
@@ -1335,12 +1335,12 @@ REGISTER_OP("SparseSoftmaxCrossEntropyWithLogits")
.Attr("T: {half, float, double}")
.Attr("Tlabels: {int32, int64} = DT_INT64")
.SetShapeFn([](InferenceContext* c) {
- const Shape* features;
- const Shape* labels;
+ ShapeHandle features;
+ ShapeHandle labels;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &features));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &labels));
- const Dimension* batch_size;
+ DimensionHandle batch_size;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(features, 0), c->Dim(labels, 0), &batch_size));
TF_RETURN_IF_ERROR(c->ReplaceDim(features, 0, batch_size, &features));
@@ -1375,11 +1375,11 @@ REGISTER_OP("InTopK")
.Attr("k: int")
.Attr("T: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
- const Shape* predictions;
- const Shape* targets;
+ ShapeHandle predictions;
+ ShapeHandle targets;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
- const Dimension* batch_size;
+ DimensionHandle batch_size;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
c->set_output(0, c->Vector(batch_size));
@@ -1413,11 +1413,11 @@ precision: Computed Precision at `k` as a `bool Tensor`.
namespace {
Status TopKShapeFn(InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
// Get the k value, either from input tensor or attribute.
- const Dimension* k_dim;
+ DimensionHandle k_dim;
if (c->num_inputs() >= 2) {
TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &k_dim));
} else {
@@ -1429,7 +1429,7 @@ Status TopKShapeFn(InferenceContext* c) {
k_dim = c->MakeDim(k);
}
- const Dimension* last_dim = c->Dim(input, -1);
+ DimensionHandle last_dim = c->Dim(input, -1);
if (c->ValueKnown(last_dim) && c->ValueKnown(k_dim) &&
c->Value(last_dim) < c->Value(k_dim)) {
return errors::InvalidArgument("input must have last dimension >= k = ",
@@ -1438,7 +1438,7 @@ Status TopKShapeFn(InferenceContext* c) {
}
// Replace last_dim with k_dim.
- const Shape* s;
+ ShapeHandle s;
TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s));
TF_RETURN_IF_ERROR(c->Concatenate(s, c->Vector(k_dim), &s));
c->set_output(0, s);
diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc
index 1b1af7d68a..0dc743ee3b 100644
--- a/tensorflow/core/ops/parsing_ops.cc
+++ b/tensorflow/core/ops/parsing_ops.cc
@@ -20,9 +20,9 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::Dimension;
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
REGISTER_OP("DecodeRaw")
.Input("bytes: string")
@@ -31,7 +31,7 @@ REGISTER_OP("DecodeRaw")
.Attr("little_endian: bool = true")
.SetShapeFn([](InferenceContext* c) {
// Note: last dimension is data dependent.
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->Concatenate(
c->input(0), c->Vector(InferenceContext::kUnknownDim), &out));
c->set_output(0, out);
@@ -68,9 +68,9 @@ REGISTER_OP("ParseExample")
ParseSingleExampleAttrs attrs;
TF_RETURN_IF_ERROR(attrs.Init(c));
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input));
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); // names
// Output sparse_indices, sparse_values, and sparse_shapes.
@@ -89,7 +89,7 @@ REGISTER_OP("ParseExample")
TensorShapeProto shape_proto;
for (int i = 0; i < attrs.num_dense; ++i) {
attrs.dense_shapes[i].AsProto(&shape_proto);
- const Shape* dense;
+ ShapeHandle dense;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &dense));
TF_RETURN_IF_ERROR(c->Concatenate(input, dense, &dense));
c->set_output(output_idx++, dense);
@@ -161,11 +161,11 @@ REGISTER_OP("ParseSingleSequenceExample")
.Attr("feature_list_sparse_types: list({float,int64,string}) >= 0 = []")
.Attr("feature_list_dense_shapes: list(shape) >= 0 = []")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
ParseSingleSequenceExampleAttrs attrs;
TF_RETURN_IF_ERROR(attrs.Init(c));
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &input));
// feature_list_dense_missing_assumed_empty
@@ -189,7 +189,7 @@ REGISTER_OP("ParseSingleSequenceExample")
TensorShapeProto shape_proto;
for (int i = 0; i < attrs.num_context_dense; ++i) {
attrs.context_dense_shapes[i].AsProto(&shape_proto);
- const Shape* s;
+ ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &s));
c->set_output(output_idx++, s);
}
@@ -209,7 +209,7 @@ REGISTER_OP("ParseSingleSequenceExample")
// Output feature_list_dense_shapes.
for (int i = 0; i < attrs.num_feature_list_dense; ++i) {
attrs.feature_list_dense_shapes[i].AsProto(&shape_proto);
- const Shape* s;
+ ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &s));
TF_RETURN_IF_ERROR(
c->Concatenate(c->Vector(InferenceContext::kUnknownDim), s, &s));
@@ -312,7 +312,7 @@ REGISTER_OP("DecodeCSV")
.SetShapeFn([](InferenceContext* c) {
// Validate the record_defaults inputs.
for (int i = 1; i < c->num_inputs(); ++i) {
- const Shape* v;
+ ShapeHandle v;
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &v));
if (c->Value(c->Dim(v, 0)) > 1) {
return errors::InvalidArgument(
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc
index 5d648a6a7e..776523f33f 100644
--- a/tensorflow/core/ops/random_ops.cc
+++ b/tensorflow/core/ops/random_ops.cc
@@ -19,14 +19,14 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::Dimension;
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
namespace {
Status RandomShape(InferenceContext* c) {
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
c->set_output(0, out);
return Status::OK();
@@ -217,9 +217,9 @@ REGISTER_OP("Multinomial")
.Attr("seed2: int = 0")
.Attr("T: realnumbertype")
.SetShapeFn([](InferenceContext* c) {
- const Shape* logits_shape;
- const Shape* unused;
- const Dimension* num_samples;
+ ShapeHandle logits_shape;
+ ShapeHandle unused;
+ DimensionHandle num_samples;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &logits_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &num_samples));
@@ -249,7 +249,7 @@ REGISTER_OP("RandomGamma")
.Attr("S: {int32, int64}")
.Attr("T: {half, float, double}")
.SetShapeFn([](InferenceContext* c) {
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out));
c->set_output(0, out);
diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc
index 17d5983d76..f4544062b7 100644
--- a/tensorflow/core/ops/sparse_ops.cc
+++ b/tensorflow/core/ops/sparse_ops.cc
@@ -19,14 +19,14 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::Dimension;
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
namespace {
Status SparseSparseMinOrMaxShapeFn(InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); // a_indices
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); // a_values
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); // a_shape
@@ -50,8 +50,8 @@ REGISTER_OP("SparseAddGrad")
.Output("b_val_grad: T")
.Attr("T: numbertype")
.SetShapeFn([](InferenceContext* c) {
- const Shape* a_indices;
- const Shape* b_indices;
+ ShapeHandle a_indices;
+ ShapeHandle b_indices;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &a_indices));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &b_indices));
c->set_output(0, c->Vector(c->Dim(a_indices, 0)));
@@ -92,7 +92,7 @@ REGISTER_OP("SparseAdd")
.Attr("T: numbertype")
.Attr("Treal: realnumbertype")
.SetShapeFn([](InferenceContext* c) {
- const Shape* a_shape;
+ ShapeHandle a_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &a_shape));
c->set_output(
0, c->Matrix(InferenceContext::kUnknownDim, c->Dim(a_shape, 0)));
@@ -137,10 +137,10 @@ REGISTER_OP("SparseTensorDenseMatMul")
.Attr("adjoint_a: bool = false")
.Attr("adjoint_b: bool = false")
.SetShapeFn([](InferenceContext* c) {
- const Dimension* unused_dim;
- const Shape* unused;
- const Shape* b;
- const Shape* a_shape;
+ DimensionHandle unused_dim;
+ ShapeHandle unused;
+ ShapeHandle b;
+ ShapeHandle a_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); // a_indices
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); // a_values
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &a_shape));
@@ -152,7 +152,7 @@ REGISTER_OP("SparseTensorDenseMatMul")
// TODO(zongheng): 1) incorporate adjoint_a. 2) When both attrs are
// considered, check the inner dimensions match.
- const Dimension* output_right = c->Dim(b, adjoint_b ? 0 : 1);
+ DimensionHandle output_right = c->Dim(b, adjoint_b ? 0 : 1);
c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, output_right));
return Status::OK();
})
@@ -186,7 +186,7 @@ REGISTER_OP("SerializeSparse")
.Attr("T: type")
.Output("serialized_sparse: string")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
@@ -208,7 +208,7 @@ REGISTER_OP("SerializeManySparse")
.Attr("T: type")
.Output("serialized_sparse: string")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
@@ -239,9 +239,9 @@ REGISTER_OP("DeserializeManySparse")
.Output("sparse_shape: int64")
.SetShapeFn([](InferenceContext* c) {
// serialized sparse is [?,3] matrix.
- const Shape* serialized_sparse;
+ ShapeHandle serialized_sparse;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &serialized_sparse));
- const Dimension* unused;
+ DimensionHandle unused;
TF_RETURN_IF_ERROR(
c->WithValue(c->Dim(serialized_sparse, 1), 3, &unused));
@@ -311,7 +311,7 @@ REGISTER_OP("SparseToDense")
.Output("dense: T")
.Attr("Tindices: {int32, int64}")
.SetShapeFn([](InferenceContext* c) {
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
c->set_output(0, out);
return Status::OK();
@@ -363,23 +363,23 @@ REGISTER_OP("SparseConcat")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
// These accumulates the sum.
- const Dimension* output_row_count = c->MakeDim(0ll);
+ DimensionHandle output_row_count = c->MakeDim(0ll);
// These are only merged.
- const Dimension* output_ind_cols = c->UnknownDim();
- const Shape* output_shape = c->UnknownShape();
+ DimensionHandle output_ind_cols = c->UnknownDim();
+ ShapeHandle output_shape = c->UnknownShape();
const int n = c->num_inputs() / 3;
for (int i = 0; i < n; i++) {
- const Shape* ind;
+ ShapeHandle ind;
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &ind));
- const Shape* val;
+ ShapeHandle val;
TF_RETURN_IF_ERROR(c->WithRank(c->input(i + n), 1, &val));
- const Shape* shape;
+ ShapeHandle shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2 * n), 1, &shape));
// Add to output_ind_rows.
- const Dimension* num_dim;
+ DimensionHandle num_dim;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(ind, 0), c->Dim(val, 0), &num_dim));
TF_RETURN_IF_ERROR(
c->Add(output_row_count, num_dim, &output_row_count));
@@ -460,11 +460,11 @@ REGISTER_OP("SparseSplit")
.Attr("num_split: int >= 1")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input_shape = c->input(3);
- const Shape* output_indices =
+ ShapeHandle input_shape = c->input(3);
+ ShapeHandle output_indices =
c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
- const Shape* output_values = c->Vector(InferenceContext::kUnknownDim);
- const Shape* output_shape = input_shape;
+ ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
+ ShapeHandle output_shape = input_shape;
// Copy the outputs into the output ranges.
int num_splits = c->num_outputs() / 3;
@@ -520,9 +520,9 @@ REGISTER_OP("SparseReorder")
.Output("output_values: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- const Shape* indices;
- const Shape* values;
- const Shape* unused;
+ ShapeHandle indices;
+ ShapeHandle values;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values));
@@ -560,9 +560,9 @@ REGISTER_OP("SparseReshape")
.Output("output_indices: int64")
.Output("output_shape: int64")
.SetShapeFn([](InferenceContext* c) {
- const Shape* indices;
- const Shape* unused;
- const Shape* new_shape;
+ ShapeHandle indices;
+ ShapeHandle unused;
+ ShapeHandle new_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
@@ -670,7 +670,7 @@ output: `R-K`-D. The reduced Tensor.
.Output("output: T") \
.Attr("T: numbertype") \
.SetShapeFn([](InferenceContext* c) { \
- const Shape* input; \
+ ShapeHandle input; \
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &input)); \
c->set_output(0, c->Vector(c->Dim(input, 0))); \
return Status::OK(); \
@@ -737,8 +737,8 @@ REGISTER_OP("SparseSoftmax")
.Output("output: T")
.Attr("T: {float, double}")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unused;
- const Shape* values;
+ ShapeHandle unused;
+ ShapeHandle values;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); // sp_indices
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values)); // sp_values
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
index 684e86a00d..cc0c652107 100644
--- a/tensorflow/core/ops/state_ops.cc
+++ b/tensorflow/core/ops/state_ops.cc
@@ -19,7 +19,7 @@ limitations under the License.
namespace tensorflow {
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
REGISTER_OP("Variable")
.Output("ref: Ref(dtype)")
@@ -69,7 +69,7 @@ REGISTER_OP("TemporaryVariable")
.SetShapeFn([](InferenceContext* c) {
TensorShapeProto shape_proto;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_proto));
- const Shape* output;
+ ShapeHandle output;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &output));
c->set_output(0, output);
return Status::OK();
@@ -201,12 +201,12 @@ output_ref:= Same as "ref". Returned as a convenience for operations that want
namespace {
Status ScatterUpdateShape(InferenceContext* c) {
- const Shape* var_shape = c->input(0);
- const Shape* indices_shape = c->input(1);
+ ShapeHandle var_shape = c->input(0);
+ ShapeHandle indices_shape = c->input(1);
- const Shape* unused_updates_shape;
- const Shape* concat;
- const Shape* var_subshape;
+ ShapeHandle unused_updates_shape;
+ ShapeHandle concat;
+ ShapeHandle var_subshape;
TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape));
@@ -354,7 +354,7 @@ REGISTER_OP("CountUpTo")
.Attr("limit: int")
.Attr("T: {int32, int64}")
.SetShapeFn([](InferenceContext* c) {
- const Shape* output;
+ ShapeHandle output;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &output));
c->set_output(0, output);
return Status::OK();
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 3b9f96e496..dd4cb12f5d 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -19,9 +19,9 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::Dimension;
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
REGISTER_OP("StringToHashBucketFast")
.Input("input: string")
@@ -180,7 +180,7 @@ REGISTER_OP("StringJoin")
// Merge the non-scalars to find the output shape.
// Don't merge inputs with unknown rank, as they can actually be scalars
// or the output shape.
- const Shape* out = c->UnknownShape();
+ ShapeHandle out = c->UnknownShape();
for (int i = 0; i < c->num_inputs(); ++i) {
if (c->RankKnown(c->input(i)) && c->Rank(c->input(i)) != 0) {
TF_RETURN_IF_ERROR(c->Merge(out, c->input(i), &out));
@@ -206,7 +206,7 @@ REGISTER_OP("StringSplit")
.Output("values: string")
.Output("shape: int64")
.SetShapeFn([](InferenceContext* c) {
- const Shape* unsed_shape;
+ ShapeHandle unsed_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unsed_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unsed_shape));
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index 04f3058651..ab82617a13 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -18,28 +18,28 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::Dimension;
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
// Handle the gradient and, if <sparse>, indices inputs.
// <s> is an input+output parameter, containing the current known input shape to
// the gradient.
static Status HandleGradAndIndicesInputs(InferenceContext* c, bool sparse,
- int grad_idx, const Shape** s) {
- const Shape* grad = c->input(grad_idx);
+ int grad_idx, ShapeHandle* s) {
+ ShapeHandle grad = c->input(grad_idx);
if (!sparse) {
TF_RETURN_IF_ERROR(c->Merge(*s, grad, s));
return Status::OK();
}
// Indices is a vector where indices.dim[0].rank == grad[0].rank.
- const Shape* indices;
+ ShapeHandle indices;
TF_RETURN_IF_ERROR(c->WithRank(c->input(grad_idx + 1), 1, &indices));
- const Dimension* unused;
+ DimensionHandle unused;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(indices, 0), c->Dim(grad, 0), &unused));
// Trailing part of grad matches *s.
- const Shape* grad_subshape;
+ ShapeHandle grad_subshape;
TF_RETURN_IF_ERROR(c->Subshape(grad, 1, &grad_subshape));
TF_RETURN_IF_ERROR(c->Merge(*s, grad_subshape, s));
@@ -47,8 +47,8 @@ static Status HandleGradAndIndicesInputs(InferenceContext* c, bool sparse,
}
static Status ApplyGradientDescentShapeFn(InferenceContext* c) {
- const Shape* unused;
- const Shape* s = c->input(0); // var
+ ShapeHandle unused;
+ ShapeHandle s = c->input(0); // var
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); // alpha
TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // delta
c->set_output(0, s);
@@ -76,8 +76,8 @@ use_locking: If `True`, the subtraction will be protected by a lock;
static Status ApplyProximalGradientDescentShapeFn(InferenceContext* c,
bool sparse) {
- const Shape* unused;
- const Shape* s = c->input(0); // var
+ ShapeHandle unused;
+ ShapeHandle s = c->input(0); // var
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); // alpha
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // l1
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // l2
@@ -146,8 +146,8 @@ use_locking: If True, the subtraction will be protected by a lock;
otherwise the behavior is undefined, but may exhibit less contention.
)doc");
static Status ApplyAdadeltaShapeFn(InferenceContext* c, bool sparse) {
- const Shape* unused;
- const Shape* s = c->input(0); // var
+ ShapeHandle unused;
+ ShapeHandle s = c->input(0); // var
TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // accum
TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // accum update
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // lr
@@ -224,8 +224,8 @@ a lock; otherwise the behavior is undefined, but may exhibit less contention.
)doc");
static Status ApplyAdagradShapeFn(InferenceContext* c, bool sparse) {
- const Shape* unused;
- const Shape* s = c->input(0); // var
+ ShapeHandle unused;
+ ShapeHandle s = c->input(0); // var
TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // accum
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr
TF_RETURN_IF_ERROR(
@@ -261,8 +261,8 @@ use_locking: If `True`, updating of the var and accum tensors will be protected
contention.
)doc");
static Status ApplyProximalAdagradShapeFn(InferenceContext* c, bool sparse) {
- const Shape* unused;
- const Shape* s = c->input(0); // var
+ ShapeHandle unused;
+ ShapeHandle s = c->input(0); // var
TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // accum
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // l1
@@ -335,8 +335,8 @@ use_locking: If `True`, updating of the var and accum tensors will be protected
)doc");
static Status ApplyAdagradDAShapeFn(InferenceContext* c, bool sparse) {
- const Shape* unused;
- const Shape* s = c->input(0); // var
+ ShapeHandle unused;
+ ShapeHandle s = c->input(0); // var
TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // grad_accumulator
TF_RETURN_IF_ERROR(
c->Merge(s, c->input(2), &s)); // gradient_squared_accumulator
@@ -453,8 +453,8 @@ a lock; otherwise the behavior is undefined, but may exhibit less contention.
)doc");
static Status ApplyFtrlShapeFn(InferenceContext* c, bool sparse) {
- const Shape* unused;
- const Shape* s = c->input(0); // var
+ ShapeHandle unused;
+ ShapeHandle s = c->input(0); // var
TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // accum
TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // linear
TF_RETURN_IF_ERROR(
@@ -549,8 +549,8 @@ use_locking: If `True`, updating of the var and accum tensors will be protected
)doc");
static Status ApplyMomentumShapeFn(InferenceContext* c, bool sparse) {
- const Shape* unused;
- const Shape* s = c->input(0); // var
+ ShapeHandle unused;
+ ShapeHandle s = c->input(0); // var
TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // accum
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr
TF_RETURN_IF_ERROR(
@@ -635,8 +635,8 @@ var - lr * momentum * accum.
)doc");
static Status ApplyAdamShapeFn(InferenceContext* c, bool sparse) {
- const Shape* unused;
- const Shape* s = c->input(0); // var
+ ShapeHandle unused;
+ ShapeHandle s = c->input(0); // var
TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // m
TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // v
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // beta1_power
@@ -693,8 +693,8 @@ use_locking: If `True`, updating of the var, m, and v tensors will be protected
)doc");
static Status ApplyRMSPropShapeFn(InferenceContext* c, bool sparse) {
- const Shape* unused;
- const Shape* s = c->input(0); // var
+ ShapeHandle unused;
+ ShapeHandle s = c->input(0); // var
TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // ms
TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // mom
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // lr