aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-08-23 16:35:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-23 17:54:30 -0700
commitf6bc8cabbd3ac1fb3acc36d3edbdce672cae7d12 (patch)
treea8c81b0269e68f57606052ab5573743f86995be6 /tensorflow/core/framework/shape_inference.cc
parentade1672d60d861c58e1930e93a1b396b22e7a4d9 (diff)
Add shape_inference::ShapeHandle and shape_inference::DimensionHandle to
replace uses of const Shape* and const Dimension*. This change only adds a typedef and updates references. A later change will make DimensionHandle and ShapeHandle real types instead of typedefs (to further hide the pointer access). Change: 131118981
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r--tensorflow/core/framework/shape_inference.cc139
1 file changed, 69 insertions, 70 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index c6da445165..90b3a6a688 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -35,7 +35,7 @@ InferenceContext::InferenceContext(
PreInputInit(op_def, input_tensors);
for (const string& spec : input_shapes) {
- const Shape* shape;
+ ShapeHandle shape;
construction_status_.Update(MakeShapeFromString(spec, &shape));
if (!construction_status_.ok()) {
return;
@@ -55,7 +55,7 @@ InferenceContext::InferenceContext(
PreInputInit(op_def, input_tensors);
if (!construction_status_.ok()) return;
for (const TensorShapeProto& p : input_shapes) {
- const Shape* shape;
+ ShapeHandle shape;
construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
if (!construction_status_.ok()) {
return;
@@ -68,7 +68,7 @@ InferenceContext::InferenceContext(
InferenceContext::InferenceContext(
const NodeDef* node_def, const OpDef& op_def,
const std::vector<string>& input_shapes_string,
- const std::vector<const Shape*>& input_shapes,
+ const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors)
: node_def_(*CHECK_NOTNULL(node_def)) {
PreInputInit(op_def, input_tensors);
@@ -118,7 +118,7 @@ void InferenceContext::PostInputInit() {
requested_input_tensor_.resize(inputs_.size());
}
-bool InferenceContext::FullyDefined(const Shape* s) {
+bool InferenceContext::FullyDefined(ShapeHandle s) {
if (!RankKnown(s)) return false;
for (int i = 0; i < Rank(s); ++i) {
if (!ValueKnown(Dim(s, i))) return false;
@@ -126,7 +126,7 @@ bool InferenceContext::FullyDefined(const Shape* s) {
return true;
}
-const Dimension* InferenceContext::NumElements(const Shape* s) {
+DimensionHandle InferenceContext::NumElements(ShapeHandle s) {
const auto rank = Rank(s);
if (rank == kUnknownRank) return UnknownDim();
int64 size = 1;
@@ -138,7 +138,7 @@ const Dimension* InferenceContext::NumElements(const Shape* s) {
return MakeDim(size);
}
-string InferenceContext::DebugString(const Shape* s) {
+string InferenceContext::DebugString(ShapeHandle s) {
if (RankKnown(s)) {
std::vector<string> vals;
for (auto d : s->dims_) vals.push_back(DebugString(d));
@@ -148,19 +148,19 @@ string InferenceContext::DebugString(const Shape* s) {
}
}
-string InferenceContext::DebugString(const Dimension* d) {
+string InferenceContext::DebugString(DimensionHandle d) {
return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
}
-Status InferenceContext::WithRank(const Shape* shape, int32 rank,
- const Shape** out) {
+Status InferenceContext::WithRank(ShapeHandle shape, int32 rank,
+ ShapeHandle* out) {
const int32 existing = Rank(shape);
if (existing == rank) {
*out = shape;
return Status::OK();
}
if (existing == kUnknownRank) {
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
dims.reserve(rank);
for (int i = 0; i < rank; ++i) {
all_dims_.push_back(new Dimension());
@@ -175,8 +175,8 @@ Status InferenceContext::WithRank(const Shape* shape, int32 rank,
existing);
}
-Status InferenceContext::WithRankAtLeast(const Shape* shape, int32 rank,
- const Shape** out) {
+Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int32 rank,
+ ShapeHandle* out) {
const int32 existing = Rank(shape);
if (existing >= rank) {
*out = shape;
@@ -190,8 +190,8 @@ Status InferenceContext::WithRankAtLeast(const Shape* shape, int32 rank,
" but is rank ", existing);
}
-Status InferenceContext::WithRankAtMost(const Shape* shape, int32 rank,
- const Shape** out) {
+Status InferenceContext::WithRankAtMost(ShapeHandle shape, int32 rank,
+ ShapeHandle* out) {
const int32 existing = Rank(shape);
if (existing == kUnknownRank) {
return ReturnUnknownShape(out);
@@ -205,8 +205,8 @@ Status InferenceContext::WithRankAtMost(const Shape* shape, int32 rank,
" but is rank ", existing);
}
-Status InferenceContext::WithValue(const Dimension* dim, int64 value,
- const Dimension** out) {
+Status InferenceContext::WithValue(DimensionHandle dim, int64 value,
+ DimensionHandle* out) {
const int64 existing = Value(dim);
if (existing == value) {
*out = dim;
@@ -222,8 +222,8 @@ Status InferenceContext::WithValue(const Dimension* dim, int64 value,
existing);
}
-Status InferenceContext::Merge(const Dimension* d0, const Dimension* d1,
- const Dimension** out) {
+Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1,
+ DimensionHandle* out) {
if (d0 == d1 || !ValueKnown(d1)) {
*out = d0;
return Status::OK();
@@ -240,9 +240,9 @@ Status InferenceContext::Merge(const Dimension* d0, const Dimension* d1,
}
}
-Status InferenceContext::MergePrefix(const Shape* s, const Shape* prefix,
- const Shape** s_out,
- const Shape** prefix_out) {
+Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix,
+ ShapeHandle* s_out,
+ ShapeHandle* prefix_out) {
*s_out = *prefix_out = nullptr;
if (!RankKnown(prefix) || !RankKnown(s)) {
*s_out = s;
@@ -253,7 +253,7 @@ Status InferenceContext::MergePrefix(const Shape* s, const Shape* prefix,
TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s));
// Merge the prefix dims and create the new output shapes.
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
dims.resize(rank);
for (int i = 0; i < rank; ++i) {
TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i]));
@@ -264,8 +264,8 @@ Status InferenceContext::MergePrefix(const Shape* s, const Shape* prefix,
return Status::OK();
}
-Status InferenceContext::Merge(const Shape* s0, const Shape* s1,
- const Shape** out) {
+Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
+ ShapeHandle* out) {
if (s0 == s1 || !RankKnown(s1)) {
*out = s0;
return Status::OK();
@@ -309,7 +309,7 @@ Status InferenceContext::Merge(const Shape* s0, const Shape* s1,
}
// Merge dims.
- std::vector<const Dimension*> dims(rank, nullptr);
+ std::vector<DimensionHandle> dims(rank, nullptr);
for (int i = 0; i < rank; ++i) {
// Invariant for merge was checked earlier, so CHECK is ok.
TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i]));
@@ -317,13 +317,13 @@ Status InferenceContext::Merge(const Shape* s0, const Shape* s1,
return ReturnCreatedShape(dims, out);
}
-Status InferenceContext::Subshape(const Shape* s, int64 start,
- const Shape** out) {
+Status InferenceContext::Subshape(ShapeHandle s, int64 start,
+ ShapeHandle* out) {
return Subshape(s, start, std::numeric_limits<int64>::max() /* end */, out);
}
-Status InferenceContext::Subshape(const Shape* s, int64 start_in, int64 end_in,
- const Shape** out) {
+Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in,
+ ShapeHandle* out) {
int64 start = start_in;
int64 end = end_in;
const int32 rank = Rank(s);
@@ -362,7 +362,7 @@ Status InferenceContext::Subshape(const Shape* s, int64 start_in, int64 end_in,
end, " (computed from start ", start_in, " and end ", end_in,
" over shape with rank ", rank, ")");
}
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
dims.reserve(end - start);
for (int i = start; i < end; ++i) {
dims.push_back(Dim(s, i));
@@ -370,24 +370,23 @@ Status InferenceContext::Subshape(const Shape* s, int64 start_in, int64 end_in,
return ReturnCreatedShape(dims, out);
}
-Status InferenceContext::Concatenate(const Shape* s1, const Shape* s2,
- const Shape** out) {
+Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2,
+ ShapeHandle* out) {
if (!RankKnown(s1) || !RankKnown(s2)) {
return ReturnUnknownShape(out);
}
const int32 s1_rank = Rank(s1);
const int32 s2_rank = Rank(s2);
const int32 rank = s1_rank + s2_rank;
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
dims.reserve(rank);
for (int i = 0; i < s1_rank; ++i) dims.push_back(Dim(s1, i));
for (int i = 0; i < s2_rank; ++i) dims.push_back(Dim(s2, i));
return ReturnCreatedShape(dims, out);
}
-Status InferenceContext::ReplaceDim(const Shape* s, int dim_index_in,
- const Dimension* new_dim,
- const Shape** out) {
+Status InferenceContext::ReplaceDim(ShapeHandle s, int dim_index_in,
+ DimensionHandle new_dim, ShapeHandle* out) {
if (!RankKnown(s)) {
return ReturnUnknownShape(out);
}
@@ -401,20 +400,20 @@ Status InferenceContext::ReplaceDim(const Shape* s, int dim_index_in,
" for shape with ", s->dims_.size(),
" dimensions");
}
- std::vector<const Dimension*> dims(s->dims_);
+ std::vector<DimensionHandle> dims(s->dims_);
dims[dim_index] = new_dim;
return ReturnCreatedShape(dims, out);
}
-const Shape* InferenceContext::MakeShape(
- const std::vector<const Dimension*>& dims) {
+ShapeHandle InferenceContext::MakeShape(
+ const std::vector<DimensionHandle>& dims) {
all_shapes_.push_back(new Shape(dims));
return all_shapes_.back();
}
-const Shape* InferenceContext::MakeShape(
+ShapeHandle InferenceContext::MakeShape(
std::initializer_list<DimensionOrConstant> dims) {
- std::vector<const Dimension*> dims_actual;
+ std::vector<DimensionHandle> dims_actual;
dims_actual.reserve(dims.size());
for (const DimensionOrConstant& d : dims) {
dims_actual.push_back(MakeDim(d));
@@ -422,45 +421,45 @@ const Shape* InferenceContext::MakeShape(
return MakeShape(dims_actual);
}
-const Shape* InferenceContext::UnknownShape() {
+ShapeHandle InferenceContext::UnknownShape() {
all_shapes_.push_back(new Shape());
return all_shapes_.back();
}
-const Shape* InferenceContext::UnknownShapeOfRank(int32 rank) {
- std::vector<const Dimension*> dims(rank);
+ShapeHandle InferenceContext::UnknownShapeOfRank(int32 rank) {
+ std::vector<DimensionHandle> dims(rank);
for (int32 i = 0; i < rank; ++i) {
dims[i] = UnknownDim();
}
return MakeShape(dims);
}
-const Shape* InferenceContext::Scalar() { return MakeShape({}); }
+ShapeHandle InferenceContext::Scalar() { return MakeShape({}); }
-const Shape* InferenceContext::Vector(DimensionOrConstant dim) {
+ShapeHandle InferenceContext::Vector(DimensionOrConstant dim) {
return MakeShape({dim});
}
-const Shape* InferenceContext::Matrix(DimensionOrConstant dim1,
- DimensionOrConstant dim2) {
+ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
+ DimensionOrConstant dim2) {
return MakeShape({dim1, dim2});
}
Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
- const Shape** out) {
- const Shape* input_shape;
+ ShapeHandle* out) {
+ ShapeHandle input_shape;
TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
const Tensor* t = input_tensor(input_idx);
if (t == nullptr) {
// Shape tensor is not known, but if the shape of the shape tensor is then
// the right number of unknown dims can be created.
- const Dimension* shape_dim = Dim(input_shape, 0);
+ DimensionHandle shape_dim = Dim(input_shape, 0);
if (!ValueKnown(shape_dim)) {
return ReturnUnknownShape(out);
}
const auto num_dims = Value(shape_dim);
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim());
return ReturnCreatedShape(dims, out);
}
@@ -470,7 +469,7 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
return errors::InvalidArgument("Input tensor must be rank 1, but was rank ",
t->shape().dims());
}
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
if (t->dtype() == DataType::DT_INT32) {
auto flat_t = t->flat<int32>();
for (int i = 0; i < flat_t.size(); ++i) {
@@ -492,7 +491,7 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
}
Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
- const Shape** out) {
+ ShapeHandle* out) {
*out = nullptr;
TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto));
PartialTensorShape partial_shape(proto);
@@ -500,7 +499,7 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
return ReturnUnknownShape(out);
}
const int num_dims = partial_shape.dims();
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
dims.reserve(partial_shape.dims());
for (int i = 0; i < num_dims; ++i) {
// -1 is unknown in proto and in InferenceContext, so this size can be
@@ -511,7 +510,7 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
}
// Returns a new dimension whose value is given by a scalar input tensor.
-Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) {
+Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
const Tensor* t = input_tensor(idx);
if (t == nullptr) {
*out = UnknownDim();
@@ -539,8 +538,8 @@ Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) {
return Status::OK();
}
-Status InferenceContext::Divide(const Dimension* dividend, int64 divisor,
- const Dimension** out) {
+Status InferenceContext::Divide(DimensionHandle dividend, int64 divisor,
+ DimensionHandle* out) {
if (divisor == 1) {
*out = dividend;
} else if (!ValueKnown(dividend)) {
@@ -560,8 +559,8 @@ Status InferenceContext::Divide(const Dimension* dividend, int64 divisor,
return Status::OK();
}
-Status InferenceContext::Add(const Dimension* first, DimensionOrConstant second,
- const Dimension** out) {
+Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second,
+ DimensionHandle* out) {
const int64 first_value = Value(first);
const int64 second_value = Value(second);
// Special cases.
@@ -583,9 +582,9 @@ Status InferenceContext::Add(const Dimension* first, DimensionOrConstant second,
return Status::OK();
}
-Status InferenceContext::Subtract(const Dimension* first,
+Status InferenceContext::Subtract(DimensionHandle first,
DimensionOrConstant second,
- const Dimension** out) {
+ DimensionHandle* out) {
const int64 first_value = Value(first);
const int64 second_value = Value(second);
// Special cases.
@@ -606,9 +605,9 @@ Status InferenceContext::Subtract(const Dimension* first,
return Status::OK();
}
-Status InferenceContext::Multiply(const Dimension* first,
+Status InferenceContext::Multiply(DimensionHandle first,
DimensionOrConstant second,
- const Dimension** out) {
+ DimensionHandle* out) {
const int64 first_value = Value(first);
const int64 second_value = Value(second);
// Special cases.
@@ -635,8 +634,8 @@ Status InferenceContext::Multiply(const Dimension* first,
return Status::OK();
}
-Status InferenceContext::Min(const Dimension* first, DimensionOrConstant second,
- const Dimension** out) {
+Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second,
+ DimensionHandle* out) {
const int64 first_value = Value(first);
const int64 second_value = Value(second);
if (first_value == 0) {
@@ -655,8 +654,8 @@ Status InferenceContext::Min(const Dimension* first, DimensionOrConstant second,
return Status::OK();
}
-Status InferenceContext::Max(const Dimension* first, DimensionOrConstant second,
- const Dimension** out) {
+Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second,
+ DimensionHandle* out) {
const int64 first_value = Value(first);
const int64 second_value = Value(second);
if (first_value == kUnknownDim || second_value == kUnknownDim) {
@@ -672,13 +671,13 @@ Status InferenceContext::Max(const Dimension* first, DimensionOrConstant second,
}
Status InferenceContext::MakeShapeFromString(const string& spec,
- const Shape** output) {
+ ShapeHandle* output) {
if (spec == "?") {
*output = UnknownShape();
return Status::OK();
}
- std::vector<const Dimension*> dims;
+ std::vector<DimensionHandle> dims;
strings::Scanner scanner(spec);
scanner.OneLiteral("[");
while (scanner.Peek() != ']') {