aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2018-04-06 21:00:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-06 21:03:10 -0700
commit273495dc2c957402f832cae31a438e550db2b7f0 (patch)
tree98691c91e0af5a5a7464ca0f2645b434160710fb /tensorflow/core/framework/shape_inference.cc
parent7f97f1bf69765be51b9f79f5134eb44736d216eb (diff)
Improvements to ResourceVariable + Variant code.
* Works in graph + eager modes * Fixed shape inference * Updated shape inference + refiner + constant eval code to support static shape tensor of `-1` meaning unknown shape. * Gather and Scatter for Variants now properly supported. * Variable copy-on-write for Variants now does a more shallow copy (as Variants are not expected to be updated "in-place" inside a variable; instead Variants will be updated via read-update-write inside a CriticalSection) PiperOrigin-RevId: 191975898
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r--tensorflow/core/framework/shape_inference.cc78
1 files changed, 75 insertions, 3 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 54ecaa5dd4..cc1ec47a83 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -726,6 +726,24 @@ ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
return MakeShape({dim1, dim2});
}
+Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
+ int input_idx, ShapeHandle* out) {
+ ShapeHandle input_shape;
+ TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));
+
+ requested_input_tensor_as_partial_shape_[input_idx] = true;
+ if (input_idx < input_tensors_as_shapes_.size() &&
+ input_tensors_as_shapes_[input_idx].IsSet() &&
+ RankKnown(input_tensors_as_shapes_[input_idx])) {
+ *out = input_tensors_as_shapes_[input_idx];
+ return Status::OK();
+ }
+
+ return InternalMakeShapeFromTensor(
+ true /* treat_unknown_scalar_tensor_as_unknown_shape */,
+ input_tensor(input_idx), input_shape, out);
+}
+
Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
ShapeHandle* out) {
ShapeHandle input_shape;
@@ -739,13 +757,31 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
return Status::OK();
}
- return MakeShapeFromTensor(input_tensor(input_idx), input_shape, out);
+ return InternalMakeShapeFromTensor(
+ false /* treat_unknown_scalar_tensor_as_unknown_shape */,
+ input_tensor(input_idx), input_shape, out);
}
Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
ShapeHandle tensor_shape,
ShapeHandle* out) {
+ return InternalMakeShapeFromTensor(
+ false /* treat_unknown_scalar_tensor_as_unknown_shape */, t, tensor_shape,
+ out);
+}
+
+Status InferenceContext::InternalMakeShapeFromTensor(
+ bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
+ ShapeHandle tensor_shape, ShapeHandle* out) {
+ // Only callers who have set
+ if (!treat_unknown_scalar_tensor_as_unknown_shape) {
+ TF_RETURN_IF_ERROR(WithRank(tensor_shape, 1, &tensor_shape));
+ }
if (t == nullptr) {
+ // This is guarded by the check above.
+ if (Rank(tensor_shape) == 0) {
+ return ReturnUnknownShape(out);
+ }
// Shape tensor is not known, but if the shape of the shape tensor is then
// the right number of unknown dims can be created.
DimensionHandle shape_dim = Dim(tensor_shape, 0);
@@ -759,10 +795,46 @@ Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
return ReturnCreatedShape(dims, out);
}
+ if (t->shape().dims() == 0) {
+ if (t->dtype() == DataType::DT_INT32) {
+ auto flat_t = t->scalar<int32>();
+ if (flat_t() != -1) {
+ *out = nullptr;
+ return errors::InvalidArgument(
+ "Input tensor must be rank 1, or if its rank 0 it must have value "
+ "-1 "
+ "(representing an unknown shape). Saw value: ",
+ flat_t());
+ }
+ return ReturnUnknownShape(out);
+ } else if (t->dtype() == DataType::DT_INT64) {
+ auto flat_t = t->scalar<int64>();
+ if (flat_t() != -1) {
+ *out = nullptr;
+ return errors::InvalidArgument(
+ "Input tensor must be rank 1, or if its rank 0 it must have value "
+ "-1 "
+ "(representing an unknown shape). Saw value: ",
+ flat_t());
+ }
+ return ReturnUnknownShape(out);
+ } else {
+ *out = nullptr;
+ return errors::InvalidArgument(
+ "Input tensor must be int32 or int64, but was ",
+ DataTypeString(t->dtype()));
+ }
+ }
+
if (t->shape().dims() != 1) {
*out = nullptr;
- return errors::InvalidArgument("Input tensor must be rank 1, but was rank ",
- t->shape().dims());
+ return errors::InvalidArgument(
+ "Input tensor must be rank 1, but was rank ", t->shape().dims(), ".",
+ ((t->shape().dims() == 0)
+ ? "If it is rank 0 rank 0 it must have statically known value -1 "
+ "(representing an unknown shape). "
+ : " "),
+ "Saw tensor shape ", t->shape().DebugString());
}
std::vector<DimensionHandle> dims;
if (t->dtype() == DataType::DT_INT32) {