/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/shape_inference.h"

#include "tensorflow/core/framework/node_def.pb_text.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/str_util.h"

namespace tensorflow {
namespace shape_inference {

constexpr int32 InferenceContext::kUnknownRank;
constexpr int64 InferenceContext::kUnknownDim;

InferenceContext::InferenceContext(
    int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
    const std::vector<TensorShapeProto>& input_shapes,
    const std::vector<const Tensor*>& input_tensors,
    const std::vector<TensorShapeProto>& input_tensors_as_shapes,
    const std::vector<
        std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
        input_handle_shapes_and_types)
    : graph_def_version_(graph_def_version),
      node_def_(CHECK_NOTNULL(node_def)) {
  std::vector<ShapeHandle> input_tensors_as_shape_handles;
  input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
  for (const TensorShapeProto& p : input_tensors_as_shapes) {
    ShapeHandle shape;
    construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
    if (!construction_status_.ok()) {
      return;
    }
    input_tensors_as_shape_handles.push_back(shape);
  }
  PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
  if (!construction_status_.ok()) return;
  inputs_.reserve(input_shapes.size());
  for (const TensorShapeProto& p : input_shapes) {
    ShapeHandle shape;
    construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
    if (!construction_status_.ok()) {
      return;
    }
    inputs_.push_back(shape);
  }
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
      input_shapes.size());
  for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
    const auto& v = input_handle_shapes_and_types[i];
    if (v == nullptr) {
      continue;
    }
    handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
    auto& new_v = *handle_data[i];
    for (int j = 0; j < v->size(); ++j) {
      const auto& p = (*v)[j];
      construction_status_.Update(
          MakeShapeFromShapeProto(p.first, &new_v[j].shape));
      if (!construction_status_.ok()) {
        return;
      }
      new_v[j].dtype = p.second;
    }
  }
  PostInputInit(std::move(handle_data));
}

// Same as above, but with PartialTensorShape instead of TensorShapeProto.
InferenceContext::InferenceContext(
    int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
    const std::vector<PartialTensorShape>& input_shapes,
    const std::vector<const Tensor*>& input_tensors,
    const std::vector<PartialTensorShape>& input_tensors_as_shapes,
    const std::vector<std::unique_ptr<
        std::vector<std::pair<PartialTensorShape, DataType>>>>&
        input_handle_shapes_and_types)
    : graph_def_version_(graph_def_version),
      node_def_(CHECK_NOTNULL(node_def)) {
  std::vector<ShapeHandle> input_tensors_as_shape_handles;
  input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
  for (const PartialTensorShape& p : input_tensors_as_shapes) {
    ShapeHandle shape;
    construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
    if (!construction_status_.ok()) {
      return;
    }
    input_tensors_as_shape_handles.push_back(shape);
  }
  PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
  if (!construction_status_.ok()) return;
  inputs_.reserve(input_shapes.size());
  for (const PartialTensorShape& p : input_shapes) {
    ShapeHandle shape;
    construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
    if (!construction_status_.ok()) {
      return;
    }
    inputs_.push_back(shape);
  }
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
      input_shapes.size());
  for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
    const auto& v = input_handle_shapes_and_types[i];
    if (v == nullptr) {
      continue;
    }
    handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
    auto& new_v = *handle_data[i];
    for (int j = 0; j < v->size(); ++j) {
      const auto& p = (*v)[j];
      construction_status_.Update(
          MakeShapeFromPartialTensorShape(p.first, &new_v[j].shape));
      if (!construction_status_.ok()) {
        return;
      }
      new_v[j].dtype = p.second;
    }
  }
  PostInputInit(std::move(handle_data));
}

InferenceContext::InferenceContext(
    int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
    const std::vector<ShapeHandle>& input_shapes,
    const std::vector<const Tensor*>& input_tensors,
    const std::vector<ShapeHandle>& input_tensors_as_shapes,
    std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
        input_handle_shapes_and_types)
    : graph_def_version_(graph_def_version),
      node_def_(CHECK_NOTNULL(node_def)) {
  PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
  if (!construction_status_.ok()) return;
  inputs_ = input_shapes;
  PostInputInit(std::move(input_handle_shapes_and_types));
}

InferenceContext::~InferenceContext() {}

Status InferenceContext::Run(
    const std::function<Status(shape_inference::InferenceContext* c)>& fn) {
  ForgetMerges();
  Status s = fn(this);
  if (!s.ok()) {
    ForgetMerges();
    return AttachContext(s);
  }
#ifndef NDEBUG
  for (int i = 0; i < num_outputs(); ++i) {
    DCHECK(output(i).IsSet()) << i << " for " << node_def_->name()
                              << " of type " << node_def_->op();
  }
#endif  // NDEBUG
  return s;
}
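// Example (hypothetical; not part of this file's API surface): Run evaluates
// a registered shape function against this context. A typical shape function
// passed to Run constrains input ranks and sets the outputs:
//
//   Status MyShapeFn(InferenceContext* c) {  // name is illustrative only
//     ShapeHandle in;
//     TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &in));
//     c->set_output(0, in);
//     return Status::OK();
//   }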
Status InferenceContext::set_output(StringPiece output_name,
                                    const std::vector<ShapeHandle>& shapes) {
  auto result = output_name_map_.find(output_name);
  if (result == output_name_map_.end()) {
    return errors::InvalidArgument("Unknown output name: ", output_name);
  } else {
    const int start = result->second.first;
    const int size = result->second.second - start;
    if (size != shapes.size()) {
      return errors::InvalidArgument("Must have exactly ", size, " shapes.");
    }
    for (int i = 0; i < size; ++i) {
      outputs_[i + start] = shapes[i];
    }
  }
  return Status::OK();
}

Status InferenceContext::input(StringPiece input_name,
                               std::vector<ShapeHandle>* output) const {
  const auto result = input_name_map_.find(input_name);
  if (result == input_name_map_.end()) {
    return errors::InvalidArgument("Unknown input name: ", input_name);
  } else {
    output->clear();
    for (int i = result->second.first; i < result->second.second; ++i) {
      output->push_back(inputs_[i]);
    }
  }
  return Status::OK();
}

Status InferenceContext::output(StringPiece output_name,
                                std::vector<ShapeHandle>* output) const {
  const auto result = output_name_map_.find(output_name);
  if (result == output_name_map_.end()) {
    return errors::InvalidArgument("Unknown output name: ", output_name);
  } else {
    output->clear();
    for (int i = result->second.first; i < result->second.second; ++i) {
      output->push_back(outputs_[i]);
    }
  }
  return Status::OK();
}

string InferenceContext::op() const { return node_def_->op(); }

void InferenceContext::PreInputInit(
    const OpDef& op_def, const std::vector<const Tensor*>& input_tensors,
    const std::vector<ShapeHandle>& input_tensors_as_shapes) {
  input_tensors_ = input_tensors;
  input_tensors_as_shapes_ = input_tensors_as_shapes;

  construction_status_ = NameRangesForNode(*node_def_, op_def,
                                           &input_name_map_, &output_name_map_);
  if (!construction_status_.ok()) return;

  int num_outputs = 0;
  for (const auto& e : output_name_map_) {
    num_outputs = std::max(num_outputs, e.second.second);
  }
  outputs_.assign(num_outputs, nullptr);
  output_handle_shapes_and_types_.resize(num_outputs);
}

Status InferenceContext::ExpandOutputs(int new_output_size) {
  if (new_output_size < outputs_.size()) {
    return errors::InvalidArgument("Trying to reduce number of outputs of op.");
  }
  outputs_.resize(new_output_size, nullptr);
  output_handle_shapes_and_types_.resize(new_output_size);
  return Status::OK();
}

void InferenceContext::PostInputInit(
    std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
        input_handle_data) {
  int num_inputs_from_node_def = 0;
  for (const auto& e : input_name_map_) {
    num_inputs_from_node_def =
        std::max(num_inputs_from_node_def, e.second.second);
  }

  // Allow passing empty shapes/dtypes to avoid changing every single test.
  if (input_handle_data.empty()) {
    input_handle_shapes_and_types_.resize(inputs_.size());
  } else {
    if (input_handle_data.size() != inputs_.size()) {
      construction_status_ = errors::InvalidArgument(
          "Wrong number of handle shapes passed; expected ", inputs_.size(),
          " got ", input_handle_data.size());
      return;
    }
    input_handle_shapes_and_types_ = std::move(input_handle_data);
  }

  if (inputs_.size() != num_inputs_from_node_def) {
    construction_status_ = errors::InvalidArgument(
        "Wrong number of inputs passed: ", inputs_.size(), " while ",
        num_inputs_from_node_def, " expected based on NodeDef");
    return;
  }

  CHECK_LE(input_tensors_.size(), inputs_.size());
  input_tensors_.resize(inputs_.size());
  requested_input_tensor_.resize(inputs_.size());
  requested_input_tensor_as_partial_shape_.resize(inputs_.size());
}

void InferenceContext::ShapeHandleToProto(ShapeHandle handle,
                                          TensorShapeProto* proto) {
  if (!RankKnown(handle)) {
    proto->set_unknown_rank(true);
    return;
  }

  for (int32 i = 0; i < Rank(handle); ++i) {
    DimensionHandle dim = Dim(handle, i);
    auto* dim_shape = proto->add_dim();
    if (ValueKnown(dim)) {
      dim_shape->set_size(Value(dim));
    } else {
      dim_shape->set_size(-1);
    }
  }
}

bool InferenceContext::FullyDefined(ShapeHandle s) {
  if (!RankKnown(s)) return false;
  for (int i = 0; i < Rank(s); ++i) {
    if (!ValueKnown(Dim(s, i))) return false;
  }
  return true;
}

DimensionHandle InferenceContext::NumElements(ShapeHandle s) {
  const auto rank = Rank(s);
  if (rank == kUnknownRank) return UnknownDim();
  bool found_unknown = false;
  int64 size = 1;
  for (int i = 0; i < rank; ++i) {
    int64 dim_val = Value(Dim(s, i));
    if (dim_val == kUnknownDim) {
      found_unknown = true;
    } else if (dim_val == 0) {
      return MakeDim(0);
    } else {
      size *= dim_val;
    }
  }
  if (found_unknown) {
    return UnknownDim();
  } else {
    return MakeDim(size);
  }
}
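// Example (hypothetical values): for a shape s = [2, ?, 3], FullyDefined(s)
// is false and NumElements(s) is an unknown dimension; for s = [2, 4, 3],
// NumElements(s) is a dimension with value 24. A dimension of size 0
// short-circuits NumElements to 0 even when other dimensions are unknown.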
string InferenceContext::DebugString(ShapeHandle s) {
  if (RankKnown(s)) {
    std::vector<string> vals;
    for (auto d : s->dims_) vals.push_back(DebugString(d));
    return strings::StrCat("[", str_util::Join(vals, ","), "]");
  } else {
    return "?";
  }
}

string InferenceContext::DebugString(DimensionHandle d) {
  return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
}

string InferenceContext::DebugString() const {
  return strings::StrCat("InferenceContext for node: ",
                         ProtoDebugString(*node_def_));
}

string InferenceContext::DebugString(const ShapeAndType& shape_and_type) {
  return strings::StrCat(DebugString(shape_and_type.shape), ":",
                         DataTypeString(shape_and_type.dtype));
}

string InferenceContext::DebugString(
    gtl::ArraySlice<ShapeAndType> shape_and_types) {
  std::vector<string> pieces;
  for (const ShapeAndType& s : shape_and_types) {
    pieces.push_back(DebugString(s));
  }
  return strings::StrCat("[", str_util::Join(pieces, ","), "]");
}

Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
                                  ShapeHandle* out) {
  if (rank > kint32max) {
    return errors::InvalidArgument("Rank cannot exceed kint32max");
  }
  const int32 existing = Rank(shape);
  if (existing == rank) {
    *out = shape;
    return Status::OK();
  }
  if (existing == kUnknownRank) {
    std::vector<DimensionHandle> dims;
    dims.reserve(rank);
    for (int i = 0; i < rank; ++i) {
      dims.push_back(UnknownDim());
    }
    ShapeHandle shp = shape_manager_.MakeShape(dims);
    return Merge(shape, shp, out);
  }
  *out = nullptr;
  return errors::InvalidArgument("Shape must be rank ", rank, " but is rank ",
                                 existing);
}

Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank,
                                         ShapeHandle* out) {
  if (rank > kint32max) {
    return errors::InvalidArgument("Rank cannot exceed kint32max");
  }
  const int32 existing = Rank(shape);
  if (existing >= rank || existing == kUnknownRank) {
    *out = shape;
    return Status::OK();
  }
  *out = nullptr;
  return errors::InvalidArgument("Shape must be at least rank ", rank,
                                 " but is rank ", existing);
}

Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank,
                                        ShapeHandle* out) {
  if (rank > kint32max) {
    return errors::InvalidArgument("Rank cannot exceed kint32max");
  }
  const int32 existing = Rank(shape);
  if (existing <= rank || existing == kUnknownRank) {
    *out = shape;
    return Status::OK();
  }
  *out = nullptr;
  return errors::InvalidArgument("Shape must be at most rank ", rank,
                                 " but is rank ", existing);
}

Status InferenceContext::WithValue(DimensionHandle dim, int64 value,
                                   DimensionHandle* out) {
  const int64 existing = Value(dim);
  if (existing == value) {
    *out = dim;
    return Status::OK();
  }
  if (existing == kUnknownDim) {
    DimensionHandle d = MakeDim(value);
    return Merge(dim, d, out);
  }
  *out = nullptr;
  return errors::InvalidArgument("Dimension must be ", value, " but is ",
                                 existing);
}
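// Example (hypothetical): the WithRank*/WithValue helpers are how shape
// functions assert constraints while refining unknowns:
//
//   ShapeHandle s;
//   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &s));
//   DimensionHandle d;
//   TF_RETURN_IF_ERROR(c->WithValue(c->Dim(s, 0), 32, &d));
//
// If input 0 has unknown rank, WithRankAtLeast passes it through unchanged,
// while WithRank materializes a shape of the requested rank whose dims are
// all unknown; a known but conflicting rank or value yields InvalidArgument.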
void InferenceContext::Relax(DimensionHandle d_old, DimensionHandle d_new,
                             DimensionHandle* out) {
  if (d_old.SameHandle(d_new)) {
    *out = d_old;
  } else if (!ValueKnown(d_old) && !ValueKnown(d_new)) {
    // The node will be fed by the dimension d_new instead of d_old: any
    // equality assertion between d_old and other input dimensions on this
    // node may not be true anymore, so forget them all.
    ForgetMerges();
    // Return the new shape handle to force the relaxation to propagate to the
    // fanout of the context.
    *out = d_new;
  } else if (!ValueKnown(d_new)) {
    ForgetMerges();
    *out = d_new;
  } else if (Value(d_old) == Value(d_new)) {
    // Return the old shape handle. This will stop the relaxation in the
    // fanout of the context.
    *out = d_old;
  } else {
    // Return a new handle that encodes a different unknown dim.
    ForgetMerges();
    *out = UnknownDim();
  }
}

Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1,
                               DimensionHandle* out) {
  if (d0.SameHandle(d1)) {
    *out = d0;
    return Status::OK();
  } else if (!ValueKnown(d1)) {
    *out = d0;
    merged_dims_.emplace_back(d0, d1);
    return Status::OK();
  } else if (!ValueKnown(d0)) {
    *out = d1;
    merged_dims_.emplace_back(d0, d1);
    return Status::OK();
  } else if (Value(d0) == Value(d1)) {
    *out = d0;
    return Status::OK();
  } else {
    *out = nullptr;
    return errors::InvalidArgument("Dimensions must be equal, but are ",
                                   Value(d0), " and ", Value(d1));
  }
}

Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix,
                                     ShapeHandle* s_out,
                                     ShapeHandle* prefix_out) {
  *s_out = *prefix_out = nullptr;
  if (!RankKnown(prefix) || !RankKnown(s)) {
    *s_out = s;
    *prefix_out = prefix;
    return Status::OK();
  }
  const int32 rank = Rank(prefix);
  TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s));

  // Merge the prefix dims and create the new output shapes.
  const int32 rank_s = Rank(s);
  std::vector<DimensionHandle> dims;
  dims.reserve(std::max(rank, rank_s));
  dims.resize(rank);
  for (int i = 0; i < rank; ++i) {
    TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i]));
  }
  *prefix_out = MakeShape(dims);
  for (int i = rank; i < rank_s; ++i) dims.push_back(Dim(s, i));
  *s_out = MakeShape(dims);
  return Status::OK();
}

void InferenceContext::Relax(ShapeHandle s_old, ShapeHandle s_new,
                             ShapeHandle* out) {
  if (s_old.SameHandle(s_new)) {
    *out = s_old;
    return;
  } else if (!RankKnown(s_new) || !s_old.IsSet()) {
    ForgetMerges();
    *out = s_new;
    return;
  }

  const int32 rank = Rank(s_old);
  if (rank != Rank(s_new)) {
    ForgetMerges();
    *out = UnknownShape();
    return;
  }

  bool return_s_old = true;
  for (int i = 0; i < rank; ++i) {
    auto d0 = Dim(s_old, i);
    auto d1 = Dim(s_new, i);
    if (d0.SameHandle(d1)) continue;

    auto v0 = Value(d0);
    auto v1 = Value(d1);
    if (v0 == kUnknownDim || v1 == kUnknownDim || v0 != v1) {
      return_s_old = false;
      break;
    }
  }
  if (return_s_old) {
    *out = s_old;
    return;
  }

  // Relax dims.
  std::vector<DimensionHandle> dims(rank);
  for (int i = 0; i < rank; ++i) {
    Relax(Dim(s_old, i), Dim(s_new, i), &dims[i]);
  }
  ForgetMerges();
  *out = MakeShape(dims);
}
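// Note the asymmetry between Relax and Merge: Relax generalizes, so relaxing
// [2, ?] with [2, 3] keeps [2, ?], while Merge specializes and fails on
// conflicts, so merging [2, ?] with [2, 3] yields [2, 3] and merging [2, 4]
// with [2, 3] returns an InvalidArgument error.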
Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
                               ShapeHandle* out) {
  if (s0.SameHandle(s1)) {
    *out = s0;
    return Status::OK();
  } else if (!RankKnown(s1)) {
    *out = s0;
    merged_shapes_.emplace_back(s0, s1);
    return Status::OK();
  } else if (!RankKnown(s0)) {
    *out = s1;
    merged_shapes_.emplace_back(s0, s1);
    return Status::OK();
  }

  const int32 rank = Rank(s0);
  if (rank != Rank(s1)) {
    *out = nullptr;
    return errors::InvalidArgument("Shapes must be equal rank, but are ", rank,
                                   " and ", Rank(s1));
  }

  bool return_s0 = true;
  bool return_s1 = true;
  for (int i = 0; i < rank; ++i) {
    auto d0 = Dim(s0, i);
    auto d1 = Dim(s1, i);
    if (d0.SameHandle(d1)) continue;

    auto v0 = Value(d0);
    auto v1 = Value(d1);
    if (v0 == kUnknownDim) {
      if (v1 != kUnknownDim) {
        return_s0 = false;
      }
    } else if (v1 == kUnknownDim) {
      return_s1 = false;
    } else if (v0 != v1) {
      *out = nullptr;
      return errors::InvalidArgument(
          "Dimension ", i, " in both shapes must be equal, but are ",
          Value(d0), " and ", Value(d1), ". Shapes are ", DebugString(s0),
          " and ", DebugString(s1), ".");
    }
  }

  merged_shapes_.emplace_back(s0, s1);

  if (return_s0 || return_s1) {
    *out = return_s0 ? s0 : s1;
    return Status::OK();
  }

  // Merge dims.
  std::vector<DimensionHandle> dims(rank, nullptr);
  for (int i = 0; i < rank; ++i) {
    // Invariant for merge was checked earlier, so CHECK is ok.
    TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i]));
  }

  Status s = ReturnCreatedShape(dims, out);
  if (s.ok()) {
    // Merge the new shape with s0. Since s0 and s1 are merged, this implies
    // that s1 and out are also merged.
    merged_shapes_.emplace_back(s0, *out);
  }
  return s;
}

Status InferenceContext::Subshape(ShapeHandle s, int64 start,
                                  ShapeHandle* out) {
  return Subshape(s, start, std::numeric_limits<int64>::max() /* end */, out);
}

Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
                                  ShapeHandle* out) {
  return Subshape(s, start, end, 1 /* stride */, out);
}

Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
                                  int64 stride, ShapeHandle* out) {
  int64 start_in = start;
  int64 end_in = end;

  const int32 rank = Rank(s);
  if (start == 0 && stride == 1 &&
      ((RankKnown(s) && end >= rank) ||
       end == std::numeric_limits<int64>::max())) {
    *out = s;
    return Status::OK();
  }
  if (!RankKnown(s)) {
    return ReturnUnknownShape(out);
  }

  if (start > rank) start = rank;
  if (end > rank) end = rank;

  if (stride < 0 && start == rank) --start;

  if (start < 0) {
    start = rank + start;
    if (start < 0) {
      *out = nullptr;
      return errors::InvalidArgument("Subshape start out of bounds: ",
                                     start_in, ", for shape with rank ", rank);
    }
  }

  if (end < 0) {
    end = rank + end;
    if (end < 0) {
      *out = nullptr;
      return errors::InvalidArgument("Subshape end out of bounds: ", end_in,
                                     ", for shape with rank ", rank);
    }
  }
  if (stride > 0 && start > end) {
    *out = nullptr;
    return errors::InvalidArgument(
        "Subshape must have computed start <= end, but is ", start, " and ",
        end, " (computed from start ", start_in, " and end ", end_in,
        " over shape with rank ", rank, ")");
  } else if (stride < 0 && start < end) {
    *out = nullptr;
    return errors::InvalidArgument(
        "Subshape must have computed start >= end since stride is negative, "
        "but is ",
        start, " and ", end, " (computed from start ", start_in, " and end ",
        end_in, " over shape with rank ", rank, " and stride ", stride, ")");
  }

  std::vector<DimensionHandle> dims;
  for (int i = start; stride > 0 ? i < end : i > end; i += stride) {
    dims.push_back(Dim(s, i));
  }
  return ReturnCreatedShape(dims, out);
}
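// Example (hypothetical shapes): Subshape follows Python slicing conventions,
// including negative indices and strides. With s = [a, b, c, d]:
//
//   TF_RETURN_IF_ERROR(c->Subshape(s, 1, &out));          // [b, c, d]
//   TF_RETURN_IF_ERROR(c->Subshape(s, 1, 3, &out));       // [b, c]
//   TF_RETURN_IF_ERROR(c->Subshape(s, 0, -1, &out));      // [a, b, c]
//   TF_RETURN_IF_ERROR(c->Subshape(s, -1, 0, -1, &out));  // [d, c, b]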
Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2,
                                     ShapeHandle* out) {
  if (!RankKnown(s1) || !RankKnown(s2)) {
    return ReturnUnknownShape(out);
  }
  const int32 s1_rank = Rank(s1);
  const int32 s2_rank = Rank(s2);
  const int32 rank = s1_rank + s2_rank;
  std::vector<DimensionHandle> dims;
  dims.reserve(rank);
  for (int i = 0; i < s1_rank; ++i) dims.push_back(Dim(s1, i));
  for (int i = 0; i < s2_rank; ++i) dims.push_back(Dim(s2, i));
  return ReturnCreatedShape(dims, out);
}

Status InferenceContext::ReplaceDim(ShapeHandle s, int64 dim_index_in,
                                    DimensionHandle new_dim,
                                    ShapeHandle* out) {
  if (!RankKnown(s)) {
    return ReturnUnknownShape(out);
  }
  int64 dim_index = dim_index_in;
  if (dim_index < 0) {
    dim_index = s->dims_.size() + dim_index;
  }
  if (!FastBoundsCheck(dim_index, s->dims_.size())) {
    *out = nullptr;
    return errors::InvalidArgument("Out of range dim_index ", dim_index_in,
                                   " for shape with ", s->dims_.size(),
                                   " dimensions");
  }
  std::vector<DimensionHandle> dims(s->dims_);
  dims[dim_index] = new_dim;
  return ReturnCreatedShape(dims, out);
}

ShapeHandle InferenceContext::MakeShape(
    const std::vector<DimensionHandle>& dims) {
  return shape_manager_.MakeShape(dims);
}

ShapeHandle InferenceContext::MakeShape(
    std::initializer_list<DimensionOrConstant> dims) {
  std::vector<DimensionHandle> dims_actual;
  dims_actual.reserve(dims.size());
  for (const DimensionOrConstant& d : dims) {
    dims_actual.push_back(MakeDim(d));
  }
  return shape_manager_.MakeShape(dims_actual);
}

ShapeHandle InferenceContext::UnknownShape() {
  return shape_manager_.UnknownShape();
}

ShapeHandle InferenceContext::UnknownShapeOfRank(int64 rank) {
  CHECK_LE(rank, kint32max) << "rank must be less than kint32max";
  if (rank == kUnknownRank) {
    return UnknownShape();
  }
  CHECK_GE(rank, 0) << "rank must not be negative";
  std::vector<DimensionHandle> dims(rank);
  for (int32 i = 0; i < rank; ++i) {
    dims[i] = UnknownDim();
  }
  return MakeShape(dims);
}

ShapeHandle InferenceContext::Scalar() { return MakeShape({}); }

ShapeHandle InferenceContext::Vector(DimensionOrConstant dim) {
  return MakeShape({dim});
}

ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
                                     DimensionOrConstant dim2) {
  return MakeShape({dim1, dim2});
}
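// Example (hypothetical): the construction helpers above compose with
// MakeDim/UnknownDim to build shapes directly:
//
//   ShapeHandle m = c->Matrix(c->UnknownDim(), 64);  // [?, 64]
//   ShapeHandle v = c->Vector(128);                  // [128]
//   ShapeHandle s = c->Scalar();                     // []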
Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
    int input_idx, ShapeHandle* out) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));

  requested_input_tensor_as_partial_shape_[input_idx] = true;
  if (input_idx < input_tensors_as_shapes_.size() &&
      input_tensors_as_shapes_[input_idx].IsSet() &&
      RankKnown(input_tensors_as_shapes_[input_idx])) {
    *out = input_tensors_as_shapes_[input_idx];
    return Status::OK();
  }

  return InternalMakeShapeFromTensor(
      true /* treat_unknown_scalar_tensor_as_unknown_shape */,
      input_tensor(input_idx), input_shape, out);
}

Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
                                                  ShapeHandle* out) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));

  requested_input_tensor_as_partial_shape_[input_idx] = true;
  if (input_idx < input_tensors_as_shapes_.size() &&
      input_tensors_as_shapes_[input_idx].IsSet() &&
      RankKnown(input_tensors_as_shapes_[input_idx])) {
    *out = input_tensors_as_shapes_[input_idx];
    return Status::OK();
  }

  return InternalMakeShapeFromTensor(
      false /* treat_unknown_scalar_tensor_as_unknown_shape */,
      input_tensor(input_idx), input_shape, out);
}

Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
                                             ShapeHandle tensor_shape,
                                             ShapeHandle* out) {
  return InternalMakeShapeFromTensor(
      false /* treat_unknown_scalar_tensor_as_unknown_shape */, t,
      tensor_shape, out);
}

Status InferenceContext::InternalMakeShapeFromTensor(
    bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
    ShapeHandle tensor_shape, ShapeHandle* out) {
  // Only callers that set treat_unknown_scalar_tensor_as_unknown_shape may
  // pass a scalar shape tensor; all other callers must pass a vector.
  if (!treat_unknown_scalar_tensor_as_unknown_shape) {
    TF_RETURN_IF_ERROR(WithRank(tensor_shape, 1, &tensor_shape));
  }

  if (t == nullptr) {
    // Rank 0 is only reachable here when the scalar-as-unknown-shape
    // treatment was requested; this is guarded by the check above.
    if (Rank(tensor_shape) == 0) {
      return ReturnUnknownShape(out);
    }
    // The shape tensor's value is not known, but if the shape of the shape
    // tensor is, then the right number of unknown dims can be created.
    DimensionHandle shape_dim = Dim(tensor_shape, 0);
    if (!ValueKnown(shape_dim)) {
      return ReturnUnknownShape(out);
    }
    const auto num_dims = Value(shape_dim);
    std::vector<DimensionHandle> dims;
    dims.reserve(num_dims);
    for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim());
    return ReturnCreatedShape(dims, out);
  }

  if (t->shape().dims() == 0) {
    if (t->dtype() == DataType::DT_INT32) {
      auto flat_t = t->scalar<int32>();
      if (flat_t() != -1) {
        *out = nullptr;
        return errors::InvalidArgument(
            "Input tensor must be rank 1, or if it is rank 0 it must have "
            "value -1 (representing an unknown shape). Saw value: ",
            flat_t());
      }
      return ReturnUnknownShape(out);
    } else if (t->dtype() == DataType::DT_INT64) {
      auto flat_t = t->scalar<int64>();
      if (flat_t() != -1) {
        *out = nullptr;
        return errors::InvalidArgument(
            "Input tensor must be rank 1, or if it is rank 0 it must have "
            "value -1 (representing an unknown shape). Saw value: ",
            flat_t());
      }
      return ReturnUnknownShape(out);
    } else {
      *out = nullptr;
      return errors::InvalidArgument(
          "Input tensor must be int32 or int64, but was ",
          DataTypeString(t->dtype()));
    }
  }

  if (t->shape().dims() != 1) {
    *out = nullptr;
    return errors::InvalidArgument(
        "Input tensor must be rank 1, but was rank ", t->shape().dims(), ".",
        ((t->shape().dims() == 0)
             ? "If it is rank 0 it must have statically known value -1 "
               "(representing an unknown shape). "
             : " "),
        "Saw tensor shape ", t->shape().DebugString());
  }
  std::vector<DimensionHandle> dims;
  if (t->dtype() == DataType::DT_INT32) {
    auto flat_t = t->flat<int32>();
    for (int i = 0; i < flat_t.size(); ++i) {
      const int32 val = flat_t(i);
      if (val < -1) {
        return errors::InvalidArgument(
            "Invalid value in tensor used for shape: ", val);
      }
      // -1 will become an unknown dim.
      dims.push_back(MakeDim(val));
    }
  } else if (t->dtype() == DataType::DT_INT64) {
    auto flat_t = t->flat<int64>();
    for (int i = 0; i < flat_t.size(); ++i) {
      const int64 val = flat_t(i);
      if (val < -1) {
        return errors::InvalidArgument(
            "Invalid value in tensor used for shape: ", val);
      }
      // -1 will become an unknown dim.
      dims.push_back(MakeDim(val));
    }
  } else {
    *out = nullptr;
    return errors::InvalidArgument(
        "Input tensor must be int32 or int64, but was ",
        DataTypeString(t->dtype()));
  }

  return ReturnCreatedShape(dims, out);
}

Status InferenceContext::MakeShapeFromPartialTensorShape(
    const PartialTensorShape& partial_shape, ShapeHandle* out) {
  *out = nullptr;
  if (partial_shape.dims() == -1) {
    return ReturnUnknownShape(out);
  }
  const int num_dims = partial_shape.dims();
  std::vector<DimensionHandle> dims(num_dims);
  for (int i = 0; i < num_dims; ++i) {
    // -1 is unknown in PartialTensorShape and in InferenceContext, so this
    // size can be passed directly to MakeDim.
    dims[i] = MakeDim(partial_shape.dim_size(i));
  }
  return ReturnCreatedShape(dims, out);
}

Status InferenceContext::MakeShapeFromTensorShape(const TensorShape& shape,
                                                  ShapeHandle* out) {
  return MakeShapeFromPartialTensorShape(
      PartialTensorShape(shape.dim_sizes()), out);
}

Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
                                                 ShapeHandle* out) {
  *out = nullptr;
  TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto));
  PartialTensorShape partial_shape(proto);
  return MakeShapeFromPartialTensorShape(partial_shape, out);
}

Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
  // Caller must ensure that <t> is not NULL.
  const int rank = t->dims();
  if (rank != 0) {
    return errors::InvalidArgument("Input must be scalar but has rank ", rank);
  }

  if (t->dtype() == DT_INT32) {
    *val = t->scalar<int32>()();
    return Status::OK();
  } else if (t->dtype() == DT_INT64) {
    *val = t->scalar<int64>()();
    return Status::OK();
  } else {
    return errors::InvalidArgument("Scalar input must be int32 or int64.");
  }
}

// Returns a new dimension whose value is given by a scalar input tensor.
Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
  int64 val;
  const Tensor* t = input_tensor(idx);
  if (t == nullptr) {
    *out = UnknownDim();
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
  if (val < 0) {
    return errors::InvalidArgument("Dimension size, given by scalar input ",
                                   idx, ", must be non-negative but is ", val);
  }
  *out = MakeDim(val);
  return Status::OK();
}

Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing(
    int idx, int input_rank, DimensionHandle* out) {
  int64 val;
  const Tensor* t = input_tensor(idx);
  if (t == nullptr) {
    *out = UnknownDim();
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
  if (val < 0) {
    if (input_rank < 0) {
      *out = UnknownDim();
      return Status::OK();
    } else if (val + input_rank < 0) {
      return errors::InvalidArgument("Dimension size, given by scalar input ",
                                     val, " must be in range [-", input_rank,
                                     ", ", input_rank, ")");
    } else {
      val += input_rank;
    }
  } else if (input_rank >= 0 && val >= input_rank) {
    return errors::InvalidArgument("Dimension size, given by scalar input ",
                                   val, " must be in range [-", input_rank,
                                   ", ", input_rank, ")");
  }
  *out = MakeDim(val);
  return Status::OK();
}

Status InferenceContext::Divide(DimensionHandle dividend,
                                DimensionOrConstant divisor,
                                bool evenly_divisible, DimensionHandle* out) {
  const int64 divisor_value = Value(divisor);
  if (divisor_value == 1) {
    *out = dividend;
  } else if (!ValueKnown(dividend) ||
             (divisor.dim.IsSet() && !ValueKnown(divisor.dim))) {
    *out = UnknownDim();
  } else {
    const int64 v = Value(dividend);
    if (divisor_value <= 0) {
      return errors::InvalidArgument("Divisor must be positive but is ",
                                     divisor_value);
    }
    if (evenly_divisible && (v % divisor_value) != 0) {
      return errors::InvalidArgument(
          "Dimension size must be evenly divisible by ", divisor_value,
          " but is ", v);
    }
    *out = MakeDim(v / divisor_value);
  }
  return Status::OK();
}
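// Example (hypothetical): dimension arithmetic propagates unknown values.
// With d known to be 10:
//
//   c->Divide(d, 2, true /* evenly_divisible */, &out);  // out = 5
//   c->Divide(d, 3, true /* evenly_divisible */, &out);  // error: 10 % 3 != 0
//   c->Add(c->UnknownDim(), 1, &out);                    // out = ?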
Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second,
                             DimensionHandle* out) {
  const int64 first_value = Value(first);
  const int64 second_value = Value(second);
  // Special cases.
  if (first_value == 0) {
    *out = MakeDim(second);
  } else if (second_value == 0) {
    *out = first;
  } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
    *out = UnknownDim();
  } else {
    // Invariant: Both values are known and positive. At run time we may still
    // see a pair of values whose sum cannot be stored in an int64; the check
    // below reports that as an error. Use unsigned addition to avoid the
    // undefined behavior of signed overflow.
    const int64 sum = static_cast<uint64>(first_value) + second_value;
    if (sum < 0) {
      return errors::InvalidArgument("Dimension size overflow from adding ",
                                     first_value, " and ", second_value);
    }
    *out = MakeDim(sum);
  }
  return Status::OK();
}

Status InferenceContext::Subtract(DimensionHandle first,
                                  DimensionOrConstant second,
                                  DimensionHandle* out) {
  const int64 first_value = Value(first);
  const int64 second_value = Value(second);
  // Special cases.
  if (second_value == 0) {
    *out = first;
  } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
    *out = UnknownDim();
  } else {
    // Invariant: Both values are known, first_value is non-negative, and
    // second_value is positive.
    if (first_value < second_value) {
      return errors::InvalidArgument(
          "Negative dimension size caused by subtracting ", second_value,
          " from ", first_value);
    }
    *out = MakeDim(first_value - second_value);
  }
  return Status::OK();
}

Status InferenceContext::Multiply(DimensionHandle first,
                                  DimensionOrConstant second,
                                  DimensionHandle* out) {
  const int64 first_value = Value(first);
  const int64 second_value = Value(second);
  // Special cases.
  if (first_value == 0) {
    *out = first;
  } else if (second_value == 0) {
    *out = MakeDim(second);
  } else if (first_value == 1) {
    *out = MakeDim(second);
  } else if (second_value == 1) {
    *out = first;
  } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
    *out = UnknownDim();
  } else {
    // Invariant: Both values are known and greater than 1.
    const int64 product = first_value * second_value;
    if (product < 0) {
      return errors::InvalidArgument(
          "Negative dimension size caused by overflow when multiplying ",
          first_value, " and ", second_value);
    }
    *out = MakeDim(product);
  }
  return Status::OK();
}

Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second,
                             DimensionHandle* out) {
  const int64 first_value = Value(first);
  const int64 second_value = Value(second);
  if (first_value == 0) {
    *out = first;
  } else if (second_value == 0) {
    *out = MakeDim(second);
  } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
    *out = UnknownDim();
  } else {
    if (first_value <= second_value) {
      *out = first;
    } else {
      *out = MakeDim(second);
    }
  }
  return Status::OK();
}

Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second,
                             DimensionHandle* out) {
  const int64 first_value = Value(first);
  const int64 second_value = Value(second);
  if (first_value == kUnknownDim || second_value == kUnknownDim) {
    *out = UnknownDim();
  } else {
    if (first_value >= second_value) {
      *out = first;
    } else {
      *out = MakeDim(second);
    }
  }
  return Status::OK();
}

Status InferenceContext::AttachContext(const Status& status) {
  std::vector<string> input_shapes;
  input_shapes.reserve(inputs_.size());
  for (const ShapeHandle& input_shape : inputs_) {
    input_shapes.emplace_back(DebugString(input_shape));
  }

  // Add information about the input tensors and partial tensor shapes used.
  std::vector<string> input_from_tensors_str;
  std::vector<string> input_from_tensors_as_shape_str;
  input_from_tensors_as_shape_str.reserve(inputs_.size());
  for (int i = 0; i < inputs_.size(); ++i) {
    if (requested_input_tensor_as_partial_shape_[i] &&
        i < input_tensors_as_shapes_.size() &&
        input_tensors_as_shapes_[i].IsSet() &&
        RankKnown(input_tensors_as_shapes_[i])) {
      input_from_tensors_as_shape_str.push_back(strings::StrCat(
          "input[", i, "] = ", DebugString(input_tensors_as_shapes_[i])));
    } else if (requested_input_tensor_[i] && i < input_tensors_.size() &&
               input_tensors_[i] != nullptr) {
      input_from_tensors_str.push_back(strings::StrCat(
          "input[", i, "] = <",
          input_tensors_[i]->SummarizeValue(256 /* max_values */), ">"));
    }
  }

  string error_context = strings::StrCat(
      " for '", node_def_->name(), "' (op: '", node_def_->op(),
      "') with input shapes: ", str_util::Join(input_shapes, ", "));
  if (!input_from_tensors_str.empty()) {
    strings::StrAppend(&error_context, " and with computed input tensors: ",
                       str_util::Join(input_from_tensors_str, ", "));
  }
  if (!input_from_tensors_as_shape_str.empty()) {
    strings::StrAppend(&error_context,
                       " and with input tensors computed as partial shapes: ",
                       str_util::Join(input_from_tensors_as_shape_str, ","));
  }

  strings::StrAppend(&error_context, ".");
  return Status(status.code(),
                strings::StrCat(status.error_message(), error_context));
}
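// The Merge*/Relax* helpers below track the shapes and types of the values
// accessible through resource handles (e.g. variables): the Merge variants
// specialize the recorded shapes, falling back to the existing shape when a
// merge fails, while the Relax variants generalize them. Both families
// return true only when something was actually refined.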
bool InferenceContext::MergeHandleShapesAndTypes(
    const std::vector<ShapeAndType>& shapes_and_types,
    std::vector<ShapeAndType>* to_update) {
  if (shapes_and_types.size() != to_update->size()) {
    return false;
  }
  std::vector<ShapeAndType> new_values(shapes_and_types.size());
  bool refined = false;
  for (int i = 0; i < shapes_and_types.size(); ++i) {
    const ShapeAndType& existing = (*to_update)[i];
    if (shapes_and_types[i].dtype == existing.dtype) {
      new_values[i].dtype = existing.dtype;
    } else {
      if (existing.dtype != DT_INVALID) {
        return false;
      } else {
        new_values[i].dtype = shapes_and_types[i].dtype;
        refined = true;
      }
    }
    if (!Merge(existing.shape, shapes_and_types[i].shape,
               &new_values[i].shape)
             .ok()) {
      // Merge failed; ignore the new value.
      new_values[i].shape = existing.shape;
    }
    if (!existing.shape.SameHandle(new_values[i].shape)) {
      refined = true;
    }
  }
  if (!refined) {
    return false;
  }
  for (int i = 0; i < new_values.size(); ++i) {
    (*to_update)[i] = new_values[i];
  }
  return true;
}

bool InferenceContext::MergeOutputHandleShapesAndTypes(
    int idx, const std::vector<ShapeAndType>& shapes_and_types) {
  if (output_handle_shapes_and_types_[idx] == nullptr) {
    output_handle_shapes_and_types_[idx].reset(
        new std::vector<ShapeAndType>(shapes_and_types));
    return true;
  }
  return MergeHandleShapesAndTypes(shapes_and_types,
                                   output_handle_shapes_and_types_[idx].get());
}

bool InferenceContext::MergeInputHandleShapesAndTypes(
    int idx, const std::vector<ShapeAndType>& shapes_and_types) {
  if (input_handle_shapes_and_types_[idx] == nullptr) {
    input_handle_shapes_and_types_[idx].reset(
        new std::vector<ShapeAndType>(shapes_and_types));
    return true;
  }
  return MergeHandleShapesAndTypes(shapes_and_types,
                                   input_handle_shapes_and_types_[idx].get());
}

bool InferenceContext::RelaxHandleShapesAndMergeTypes(
    const std::vector<ShapeAndType>& shapes_and_types,
    std::vector<ShapeAndType>* to_update) {
  if (shapes_and_types.size() != to_update->size()) {
    return false;
  }
  std::vector<ShapeAndType> new_values(shapes_and_types.size());
  bool refined = false;
  for (int i = 0; i < shapes_and_types.size(); ++i) {
    const ShapeAndType& existing = (*to_update)[i];
    if (shapes_and_types[i].dtype == existing.dtype) {
      new_values[i].dtype = existing.dtype;
    } else {
      if (existing.dtype != DT_INVALID) {
        return false;
      } else {
        new_values[i].dtype = shapes_and_types[i].dtype;
        refined = true;
      }
    }
    Relax(existing.shape, shapes_and_types[i].shape, &new_values[i].shape);
    if (!existing.shape.SameHandle(new_values[i].shape)) {
      refined = true;
    }
  }
  if (!refined) {
    return false;
  }
  to_update->swap(new_values);
  return true;
}

bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes(
    int idx, const std::vector<ShapeAndType>& shapes_and_types) {
  if (output_handle_shapes_and_types_[idx] == nullptr) {
    output_handle_shapes_and_types_[idx].reset(
        new std::vector<ShapeAndType>(shapes_and_types));
    return true;
  }
  return RelaxHandleShapesAndMergeTypes(
      shapes_and_types, output_handle_shapes_and_types_[idx].get());
}

bool InferenceContext::RelaxInputHandleShapesAndMergeTypes(
    int idx, const std::vector<ShapeAndType>& shapes_and_types) {
  if (input_handle_shapes_and_types_[idx] == nullptr) {
    input_handle_shapes_and_types_[idx].reset(
        new std::vector<ShapeAndType>(shapes_and_types));
    return true;
  }
  return RelaxHandleShapesAndMergeTypes(
      shapes_and_types, input_handle_shapes_and_types_[idx].get());
}

// -----------------------------------------------------------------------------
// ShapeManager
// -----------------------------------------------------------------------------

InferenceContext::ShapeManager::ShapeManager() {}
InferenceContext::ShapeManager::~ShapeManager() {
  for (auto* s : all_shapes_) delete s;
  for (auto* d : all_dims_) delete d;
}

ShapeHandle InferenceContext::ShapeManager::MakeShape(
    const std::vector<DimensionHandle>& dims) {
  all_shapes_.push_back(new Shape(dims));
  return all_shapes_.back();
}

ShapeHandle InferenceContext::ShapeManager::UnknownShape() {
  all_shapes_.push_back(new Shape());
  return all_shapes_.back();
}

}  // namespace shape_inference
}  // namespace tensorflow