aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/array_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/array_ops.cc')
-rw-r--r--tensorflow/core/ops/array_ops.cc24
1 files changed, 23 insertions, 1 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index ef8ad7972c..1d11ec00ce 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -427,7 +427,19 @@ REGISTER_OP("UnravelIndex")
.Input("dims: Tidx")
.Output("output: Tidx")
.Attr("Tidx: {int32, int64} = DT_INT32")
- .SetShapeFn([](InferenceContext* c) { return Status::OK(); });
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle indices = c->input(0);
+ ShapeHandle dims;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
+ if (c->RankKnown(indices) && c->Rank(indices) == 0) {
+ c->set_output(0, c->Vector(c->Dim(dims, 0)));
+ } else if (c->RankKnown(indices)) {
+ c->set_output(0, c->Matrix(c->Dim(dims, 0), c->NumElements(indices)));
+ } else {
+ c->set_output(0, c->UnknownShape());
+ }
+ return Status::OK();
+ });
REGISTER_OP("BroadcastTo")
.Input("input: T")
@@ -690,6 +702,16 @@ REGISTER_OP("Const")
return Status::OK();
});
+// Returns a constant tensor on the host. Useful for writing C++ tests
+// and benchmarks which run on GPU but require arguments pinned to the host.
+// Used by test::graph::HostConstant.
+// value: Attr `value` is the tensor to return.
+REGISTER_OP("HostConst")
+ .Output("output: dtype")
+ .Attr("value: tensor")
+ .Attr("dtype: type")
+ .SetShapeFn(shape_inference::UnknownShape);
+
// --------------------------------------------------------------------------
// TODO(mgubin): Update the doc when the freeze_graph script supports converting
// into memmapped format.