diff options
Diffstat (limited to 'tensorflow/core/ops/array_ops.cc')
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 24 |
1 files changed, 23 insertions, 1 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index ef8ad7972c..1d11ec00ce 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -427,7 +427,19 @@ REGISTER_OP("UnravelIndex") .Input("dims: Tidx") .Output("output: Tidx") .Attr("Tidx: {int32, int64} = DT_INT32") - .SetShapeFn([](InferenceContext* c) { return Status::OK(); }); + .SetShapeFn([](InferenceContext* c) { + ShapeHandle indices = c->input(0); + ShapeHandle dims; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims)); + if (c->RankKnown(indices) && c->Rank(indices) == 0) { + c->set_output(0, c->Vector(c->Dim(dims, 0))); + } else if (c->RankKnown(indices)) { + c->set_output(0, c->Matrix(c->Dim(dims, 0), c->NumElements(indices))); + } else { + c->set_output(0, c->UnknownShape()); + } + return Status::OK(); + }); REGISTER_OP("BroadcastTo") .Input("input: T") @@ -690,6 +702,16 @@ REGISTER_OP("Const") return Status::OK(); }); +// Returns a constant tensor on the host. Useful for writing C++ tests +// and benchmarks which run on GPU but require arguments pinned to the host. +// Used by test::graph::HostConstant. +// value: Attr `value` is the tensor to return. +REGISTER_OP("HostConst") + .Output("output: dtype") + .Attr("value: tensor") + .Attr("dtype: type") + .SetShapeFn(shape_inference::UnknownShape); + // -------------------------------------------------------------------------- // TODO(mgubin): Update the doc when the freeze_graph script supports converting // into memmapped format. |