diff options
author | 2018-04-06 21:00:42 -0700 | |
---|---|---|
committer | 2018-04-06 21:03:10 -0700 | |
commit | 273495dc2c957402f832cae31a438e550db2b7f0 (patch) | |
tree | 98691c91e0af5a5a7464ca0f2645b434160710fb /tensorflow/core/framework/shape_inference.cc | |
parent | 7f97f1bf69765be51b9f79f5134eb44736d216eb (diff) |
Improvements to ResourceVariable + Variant code.
* Works in graph + eager modes
* Fixed shape inference
* Updated shape inference + refiner + constant eval code to support static shape tensor of `-1` meaning unknown shape.
* Gather and Scatter for Variants now properly supported.
* Variable copy-on-write for Variants now does a more shallow copy (as Variants are not expected to be updated "in-place" inside a variable; instead Variants will be updated via read-update-write inside a CriticalSection)
PiperOrigin-RevId: 191975898
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 78 |
1 files changed, 75 insertions, 3 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 54ecaa5dd4..cc1ec47a83 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -726,6 +726,24 @@ ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1, return MakeShape({dim1, dim2}); } +Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape( + int input_idx, ShapeHandle* out) { + ShapeHandle input_shape; + TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape)); + + requested_input_tensor_as_partial_shape_[input_idx] = true; + if (input_idx < input_tensors_as_shapes_.size() && + input_tensors_as_shapes_[input_idx].IsSet() && + RankKnown(input_tensors_as_shapes_[input_idx])) { + *out = input_tensors_as_shapes_[input_idx]; + return Status::OK(); + } + + return InternalMakeShapeFromTensor( + true /* treat_unknown_scalar_tensor_as_unknown_shape */, + input_tensor(input_idx), input_shape, out); +} + Status InferenceContext::MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out) { ShapeHandle input_shape; @@ -739,13 +757,31 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx, return Status::OK(); } - return MakeShapeFromTensor(input_tensor(input_idx), input_shape, out); + return InternalMakeShapeFromTensor( + false /* treat_unknown_scalar_tensor_as_unknown_shape */, + input_tensor(input_idx), input_shape, out); } Status InferenceContext::MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape, ShapeHandle* out) { + return InternalMakeShapeFromTensor( + false /* treat_unknown_scalar_tensor_as_unknown_shape */, t, tensor_shape, + out); +} + +Status InferenceContext::InternalMakeShapeFromTensor( + bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t, + ShapeHandle tensor_shape, ShapeHandle* out) { + // Only callers who have set + if (!treat_unknown_scalar_tensor_as_unknown_shape) { + TF_RETURN_IF_ERROR(WithRank(tensor_shape, 1, &tensor_shape)); + } if (t == nullptr) { + // This is guarded by the check above. + if (Rank(tensor_shape) == 0) { + return ReturnUnknownShape(out); + } // Shape tensor is not known, but if the shape of the shape tensor is then // the right number of unknown dims can be created. DimensionHandle shape_dim = Dim(tensor_shape, 0); @@ -759,10 +795,46 @@ Status InferenceContext::MakeShapeFromTensor(const Tensor* t, return ReturnCreatedShape(dims, out); } + if (t->shape().dims() == 0) { + if (t->dtype() == DataType::DT_INT32) { + auto flat_t = t->scalar<int32>(); + if (flat_t() != -1) { + *out = nullptr; + return errors::InvalidArgument( + "Input tensor must be rank 1, or if its rank 0 it must have value " + "-1 " + "(representing an unknown shape). Saw value: ", + flat_t()); + } + return ReturnUnknownShape(out); + } else if (t->dtype() == DataType::DT_INT64) { + auto flat_t = t->scalar<int64>(); + if (flat_t() != -1) { + *out = nullptr; + return errors::InvalidArgument( + "Input tensor must be rank 1, or if its rank 0 it must have value " + "-1 " + "(representing an unknown shape). Saw value: ", + flat_t()); + } + return ReturnUnknownShape(out); + } else { + *out = nullptr; + return errors::InvalidArgument( + "Input tensor must be int32 or int64, but was ", + DataTypeString(t->dtype())); + } + } + if (t->shape().dims() != 1) { *out = nullptr; - return errors::InvalidArgument("Input tensor must be rank 1, but was rank ", - t->shape().dims()); + return errors::InvalidArgument( + "Input tensor must be rank 1, but was rank ", t->shape().dims(), ".", + ((t->shape().dims() == 0) + ? "If it is rank 0 rank 0 it must have statically known value -1 " + "(representing an unknown shape). " + : " "), + "Saw tensor shape ", t->shape().DebugString()); } std::vector<DimensionHandle> dims; if (t->dtype() == DataType::DT_INT32) { |