From f6bc8cabbd3ac1fb3acc36d3edbdce672cae7d12 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 23 Aug 2016 16:35:14 -0800 Subject: Add shape_inference::ShapeHandle and shape_inference::DimensionHandle to replace uses of const Shape* and const Dimension*. This change only adds a typedef and updates references. A later change will make DimensionHandle and ShapeHandle real types instead of typedefs (to further hide the pointer access). Change: 131118981 --- tensorflow/core/framework/shape_inference.cc | 139 +++++++++++++-------------- 1 file changed, 69 insertions(+), 70 deletions(-) (limited to 'tensorflow/core/framework/shape_inference.cc') diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index c6da445165..90b3a6a688 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -35,7 +35,7 @@ InferenceContext::InferenceContext( PreInputInit(op_def, input_tensors); for (const string& spec : input_shapes) { - const Shape* shape; + ShapeHandle shape; construction_status_.Update(MakeShapeFromString(spec, &shape)); if (!construction_status_.ok()) { return; @@ -55,7 +55,7 @@ InferenceContext::InferenceContext( PreInputInit(op_def, input_tensors); if (!construction_status_.ok()) return; for (const TensorShapeProto& p : input_shapes) { - const Shape* shape; + ShapeHandle shape; construction_status_.Update(MakeShapeFromShapeProto(p, &shape)); if (!construction_status_.ok()) { return; @@ -68,7 +68,7 @@ InferenceContext::InferenceContext( InferenceContext::InferenceContext( const NodeDef* node_def, const OpDef& op_def, const std::vector& input_shapes_string, - const std::vector& input_shapes, + const std::vector& input_shapes, const std::vector& input_tensors) : node_def_(*CHECK_NOTNULL(node_def)) { PreInputInit(op_def, input_tensors); @@ -118,7 +118,7 @@ void InferenceContext::PostInputInit() { requested_input_tensor_.resize(inputs_.size()); } -bool InferenceContext::FullyDefined(const Shape* s) { +bool InferenceContext::FullyDefined(ShapeHandle s) { if (!RankKnown(s)) return false; for (int i = 0; i < Rank(s); ++i) { if (!ValueKnown(Dim(s, i))) return false; @@ -126,7 +126,7 @@ bool InferenceContext::FullyDefined(const Shape* s) { return true; } -const Dimension* InferenceContext::NumElements(const Shape* s) { +DimensionHandle InferenceContext::NumElements(ShapeHandle s) { const auto rank = Rank(s); if (rank == kUnknownRank) return UnknownDim(); int64 size = 1; @@ -138,7 +138,7 @@ const Dimension* InferenceContext::NumElements(const Shape* s) { return MakeDim(size); } -string InferenceContext::DebugString(const Shape* s) { +string InferenceContext::DebugString(ShapeHandle s) { if (RankKnown(s)) { std::vector vals; for (auto d : s->dims_) vals.push_back(DebugString(d)); @@ -148,19 +148,19 @@ string InferenceContext::DebugString(const Shape* s) { } } -string InferenceContext::DebugString(const Dimension* d) { +string InferenceContext::DebugString(DimensionHandle d) { return ValueKnown(d) ? strings::StrCat(Value(d)) : "?"; } -Status InferenceContext::WithRank(const Shape* shape, int32 rank, - const Shape** out) { +Status InferenceContext::WithRank(ShapeHandle shape, int32 rank, + ShapeHandle* out) { const int32 existing = Rank(shape); if (existing == rank) { *out = shape; return Status::OK(); } if (existing == kUnknownRank) { - std::vector dims; + std::vector dims; dims.reserve(rank); for (int i = 0; i < rank; ++i) { all_dims_.push_back(new Dimension()); @@ -175,8 +175,8 @@ Status InferenceContext::WithRank(const Shape* shape, int32 rank, existing); } -Status InferenceContext::WithRankAtLeast(const Shape* shape, int32 rank, - const Shape** out) { +Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int32 rank, + ShapeHandle* out) { const int32 existing = Rank(shape); if (existing >= rank) { *out = shape; @@ -190,8 +190,8 @@ Status InferenceContext::WithRankAtLeast(const Shape* shape, int32 rank, " but is rank ", existing); } -Status InferenceContext::WithRankAtMost(const Shape* shape, int32 rank, - const Shape** out) { +Status InferenceContext::WithRankAtMost(ShapeHandle shape, int32 rank, + ShapeHandle* out) { const int32 existing = Rank(shape); if (existing == kUnknownRank) { return ReturnUnknownShape(out); @@ -205,8 +205,8 @@ Status InferenceContext::WithRankAtMost(const Shape* shape, int32 rank, " but is rank ", existing); } -Status InferenceContext::WithValue(const Dimension* dim, int64 value, - const Dimension** out) { +Status InferenceContext::WithValue(DimensionHandle dim, int64 value, + DimensionHandle* out) { const int64 existing = Value(dim); if (existing == value) { *out = dim; @@ -222,8 +222,8 @@ Status InferenceContext::WithValue(const Dimension* dim, int64 value, existing); } -Status InferenceContext::Merge(const Dimension* d0, const Dimension* d1, - const Dimension** out) { +Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1, + DimensionHandle* out) { if (d0 == d1 || !ValueKnown(d1)) { *out = d0; return Status::OK(); @@ -240,9 +240,9 @@ Status InferenceContext::Merge(const Dimension* d0, const Dimension* d1, } } -Status InferenceContext::MergePrefix(const Shape* s, const Shape* prefix, - const Shape** s_out, - const Shape** prefix_out) { +Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix, + ShapeHandle* s_out, + ShapeHandle* prefix_out) { *s_out = *prefix_out = nullptr; if (!RankKnown(prefix) || !RankKnown(s)) { *s_out = s; @@ -253,7 +253,7 @@ Status InferenceContext::MergePrefix(const Shape* s, const Shape* prefix, TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s)); // Merge the prefix dims and create the new output shapes. - std::vector dims; + std::vector dims; dims.resize(rank); for (int i = 0; i < rank; ++i) { TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i])); @@ -264,8 +264,8 @@ Status InferenceContext::MergePrefix(const Shape* s, const Shape* prefix, return Status::OK(); } -Status InferenceContext::Merge(const Shape* s0, const Shape* s1, - const Shape** out) { +Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1, + ShapeHandle* out) { if (s0 == s1 || !RankKnown(s1)) { *out = s0; return Status::OK(); @@ -309,7 +309,7 @@ Status InferenceContext::Merge(const Shape* s0, const Shape* s1, } // Merge dims. - std::vector dims(rank, nullptr); + std::vector dims(rank, nullptr); for (int i = 0; i < rank; ++i) { // Invariant for merge was checked earlier, so CHECK is ok. TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i])); @@ -317,13 +317,13 @@ Status InferenceContext::Merge(const Shape* s0, const Shape* s1, return ReturnCreatedShape(dims, out); } -Status InferenceContext::Subshape(const Shape* s, int64 start, - const Shape** out) { +Status InferenceContext::Subshape(ShapeHandle s, int64 start, + ShapeHandle* out) { return Subshape(s, start, std::numeric_limits::max() /* end */, out); } -Status InferenceContext::Subshape(const Shape* s, int64 start_in, int64 end_in, - const Shape** out) { +Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in, + ShapeHandle* out) { int64 start = start_in; int64 end = end_in; const int32 rank = Rank(s); @@ -362,7 +362,7 @@ Status InferenceContext::Subshape(const Shape* s, int64 start_in, int64 end_in, end, " (computed from start ", start_in, " and end ", end_in, " over shape with rank ", rank, ")"); } - std::vector dims; + std::vector dims; dims.reserve(end - start); for (int i = start; i < end; ++i) { dims.push_back(Dim(s, i)); @@ -370,24 +370,23 @@ Status InferenceContext::Subshape(const Shape* s, int64 start_in, int64 end_in, return ReturnCreatedShape(dims, out); } -Status InferenceContext::Concatenate(const Shape* s1, const Shape* s2, - const Shape** out) { +Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2, + ShapeHandle* out) { if (!RankKnown(s1) || !RankKnown(s2)) { return ReturnUnknownShape(out); } const int32 s1_rank = Rank(s1); const int32 s2_rank = Rank(s2); const int32 rank = s1_rank + s2_rank; - std::vector dims; + std::vector dims; dims.reserve(rank); for (int i = 0; i < s1_rank; ++i) dims.push_back(Dim(s1, i)); for (int i = 0; i < s2_rank; ++i) dims.push_back(Dim(s2, i)); return ReturnCreatedShape(dims, out); } -Status InferenceContext::ReplaceDim(const Shape* s, int dim_index_in, - const Dimension* new_dim, - const Shape** out) { +Status InferenceContext::ReplaceDim(ShapeHandle s, int dim_index_in, + DimensionHandle new_dim, ShapeHandle* out) { if (!RankKnown(s)) { return ReturnUnknownShape(out); } @@ -401,20 +400,20 @@ Status InferenceContext::ReplaceDim(const Shape* s, int dim_index_in, " for shape with ", s->dims_.size(), " dimensions"); } - std::vector dims(s->dims_); + std::vector dims(s->dims_); dims[dim_index] = new_dim; return ReturnCreatedShape(dims, out); } -const Shape* InferenceContext::MakeShape( - const std::vector& dims) { +ShapeHandle InferenceContext::MakeShape( + const std::vector& dims) { all_shapes_.push_back(new Shape(dims)); return all_shapes_.back(); } -const Shape* InferenceContext::MakeShape( +ShapeHandle InferenceContext::MakeShape( std::initializer_list dims) { - std::vector dims_actual; + std::vector dims_actual; dims_actual.reserve(dims.size()); for (const DimensionOrConstant& d : dims) { dims_actual.push_back(MakeDim(d)); @@ -422,45 +421,45 @@ const Shape* InferenceContext::MakeShape( return MakeShape(dims_actual); } -const Shape* InferenceContext::UnknownShape() { +ShapeHandle InferenceContext::UnknownShape() { all_shapes_.push_back(new Shape()); return all_shapes_.back(); } -const Shape* InferenceContext::UnknownShapeOfRank(int32 rank) { - std::vector dims(rank); +ShapeHandle InferenceContext::UnknownShapeOfRank(int32 rank) { + std::vector dims(rank); for (int32 i = 0; i < rank; ++i) { dims[i] = UnknownDim(); } return MakeShape(dims); } -const Shape* InferenceContext::Scalar() { return MakeShape({}); } +ShapeHandle InferenceContext::Scalar() { return MakeShape({}); } -const Shape* InferenceContext::Vector(DimensionOrConstant dim) { +ShapeHandle InferenceContext::Vector(DimensionOrConstant dim) { return MakeShape({dim}); } -const Shape* InferenceContext::Matrix(DimensionOrConstant dim1, - DimensionOrConstant dim2) { +ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1, + DimensionOrConstant dim2) { return MakeShape({dim1, dim2}); } Status InferenceContext::MakeShapeFromShapeTensor(int input_idx, - const Shape** out) { - const Shape* input_shape; + ShapeHandle* out) { + ShapeHandle input_shape; TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape)); const Tensor* t = input_tensor(input_idx); if (t == nullptr) { // Shape tensor is not known, but if the shape of the shape tensor is then // the right number of unknown dims can be created. - const Dimension* shape_dim = Dim(input_shape, 0); + DimensionHandle shape_dim = Dim(input_shape, 0); if (!ValueKnown(shape_dim)) { return ReturnUnknownShape(out); } const auto num_dims = Value(shape_dim); - std::vector dims; + std::vector dims; for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim()); return ReturnCreatedShape(dims, out); } @@ -470,7 +469,7 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx, return errors::InvalidArgument("Input tensor must be rank 1, but was rank ", t->shape().dims()); } - std::vector dims; + std::vector dims; if (t->dtype() == DataType::DT_INT32) { auto flat_t = t->flat(); for (int i = 0; i < flat_t.size(); ++i) { @@ -492,7 +491,7 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx, } Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto, - const Shape** out) { + ShapeHandle* out) { *out = nullptr; TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto)); PartialTensorShape partial_shape(proto); @@ -500,7 +499,7 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto, return ReturnUnknownShape(out); } const int num_dims = partial_shape.dims(); - std::vector dims; + std::vector dims; dims.reserve(partial_shape.dims()); for (int i = 0; i < num_dims; ++i) { // -1 is unknown in proto and in InferenceContext, so this size can be @@ -511,7 +510,7 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto, } // Returns a new dimension whose value is given by a scalar input tensor. -Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) { +Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) { const Tensor* t = input_tensor(idx); if (t == nullptr) { *out = UnknownDim(); @@ -539,8 +538,8 @@ Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) { return Status::OK(); } -Status InferenceContext::Divide(const Dimension* dividend, int64 divisor, - const Dimension** out) { +Status InferenceContext::Divide(DimensionHandle dividend, int64 divisor, + DimensionHandle* out) { if (divisor == 1) { *out = dividend; } else if (!ValueKnown(dividend)) { @@ -560,8 +559,8 @@ Status InferenceContext::Divide(const Dimension* dividend, int64 divisor, return Status::OK(); } -Status InferenceContext::Add(const Dimension* first, DimensionOrConstant second, - const Dimension** out) { +Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out) { const int64 first_value = Value(first); const int64 second_value = Value(second); // Special cases. @@ -583,9 +582,9 @@ Status InferenceContext::Add(const Dimension* first, DimensionOrConstant second, return Status::OK(); } -Status InferenceContext::Subtract(const Dimension* first, +Status InferenceContext::Subtract(DimensionHandle first, DimensionOrConstant second, - const Dimension** out) { + DimensionHandle* out) { const int64 first_value = Value(first); const int64 second_value = Value(second); // Special cases. @@ -606,9 +605,9 @@ Status InferenceContext::Subtract(const Dimension* first, return Status::OK(); } -Status InferenceContext::Multiply(const Dimension* first, +Status InferenceContext::Multiply(DimensionHandle first, DimensionOrConstant second, - const Dimension** out) { + DimensionHandle* out) { const int64 first_value = Value(first); const int64 second_value = Value(second); // Special cases. @@ -635,8 +634,8 @@ Status InferenceContext::Multiply(const Dimension* first, return Status::OK(); } -Status InferenceContext::Min(const Dimension* first, DimensionOrConstant second, - const Dimension** out) { +Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out) { const int64 first_value = Value(first); const int64 second_value = Value(second); if (first_value == 0) { @@ -655,8 +654,8 @@ Status InferenceContext::Min(const Dimension* first, DimensionOrConstant second, return Status::OK(); } -Status InferenceContext::Max(const Dimension* first, DimensionOrConstant second, - const Dimension** out) { +Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out) { const int64 first_value = Value(first); const int64 second_value = Value(second); if (first_value == kUnknownDim || second_value == kUnknownDim) { @@ -672,13 +671,13 @@ Status InferenceContext::Max(const Dimension* first, DimensionOrConstant second, } Status InferenceContext::MakeShapeFromString(const string& spec, - const Shape** output) { + ShapeHandle* output) { if (spec == "?") { *output = UnknownShape(); return Status::OK(); } - std::vector dims; + std::vector dims; strings::Scanner scanner(spec); scanner.OneLiteral("["); while (scanner.Peek() != ']') { -- cgit v1.2.3