diff options
author | 2017-07-05 10:10:03 -0700 | |
---|---|---|
committer | 2017-07-05 10:14:25 -0700 | |
commit | 19220809123803af8fdcbbe628ee9f80fb4521c8 (patch) | |
tree | 674109a5d013d044c31c5b265307f4a7d862b2c2 /tensorflow/compiler/tf2xla/shape_util.cc | |
parent | 94d52acdc0087d5829f220c4d46eea67e0d30305 (diff) |
[TF:XLA] Refactor XLAShapeToTensorShape so it returns an error Status if passed an XLA tuple shape, rather than CHECK-failing.
PiperOrigin-RevId: 160971216
Diffstat (limited to 'tensorflow/compiler/tf2xla/shape_util.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/shape_util.cc | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index f5ecb51a5b..9d1992205b 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -24,12 +24,18 @@ limitations under the License. namespace tensorflow { // Convert an XLA Shape into the equivalent TensorFlow shape. -TensorShape XLAShapeToTensorShape(const xla::Shape& shape) { - TensorShape tensor_shape; +Status XLAShapeToTensorShape(const xla::Shape& shape, + TensorShape* tensor_shape) { + if (xla::ShapeUtil::IsTuple(shape)) { + return errors::InvalidArgument("XLA shape ", + xla::ShapeUtil::HumanString(shape), + " cannot be converted to a TensorShape"); + } + *tensor_shape = TensorShape(); for (int i = 0; i < xla::ShapeUtil::Rank(shape); ++i) { - tensor_shape.AddDim(shape.dimensions(i)); + tensor_shape->AddDim(shape.dimensions(i)); } - return tensor_shape; + return Status::OK(); } // Convert a TensorShape into the equivalent XLA Shape proto. |