aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops')
-rw-r--r--tensorflow/core/ops/array_ops.cc52
-rw-r--r--tensorflow/core/ops/dataset_ops.cc140
-rw-r--r--tensorflow/core/ops/manip_ops.cc13
-rw-r--r--tensorflow/core/ops/nn_ops.cc6
-rw-r--r--tensorflow/core/ops/random_ops.cc7
-rw-r--r--tensorflow/core/ops/string_ops.cc5
-rw-r--r--tensorflow/core/ops/training_ops.cc51
7 files changed, 253 insertions, 21 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 2a8b9f9bee..88fc03826a 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -429,6 +429,58 @@ REGISTER_OP("UnravelIndex")
.Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) { return Status::OK(); });
+REGISTER_OP("BroadcastTo")
+ .Input("input: T")
+ .Input("shape: Tidx")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle in = c->input(0);
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
+
+ if (!c->RankKnown(out)) {
+ // We have no information about the shape of the output.
+ c->set_output(0, out);
+ return Status::OK();
+ }
+
+ if (!c->RankKnown(in)) {
+ // We have no information about the shape of the input,
+ // nothing to do here.
+ c->set_output(0, out);
+ return Status::OK();
+ }
+ if (c->Rank(out) < c->Rank(in)) {
+ return errors::InvalidArgument("Cannot broadcast a tensor with shape ",
+ c->DebugString(in), " shape ",
+ c->DebugString(out));
+ }
+
+ int32 in_offset = c->Rank(out) - c->Rank(in);
+ for (int32 i = 0; i < c->Rank(out); ++i) {
+ DimensionHandle dim = c->Dim(out, i);
+ if (c->ValueKnown(dim)) {
+ // The first in_offset dimensions for input will be expanded with 1,
+ // so no check needed.
+ if (i >= in_offset) {
+ DimensionHandle in_dim = c->Dim(in, i - in_offset);
+ if (c->ValueKnown(in_dim)) {
+ if (c->Value(dim) % c->Value(in_dim) != 0) {
+ return errors::InvalidArgument(
+ "Cannot broadcast a tensor with shape ", c->DebugString(in),
+ " shape ", c->DebugString(out));
+ }
+ }
+ }
+ }
+ }
+
+ c->set_output(0, out);
+ return Status::OK();
+ });
+
// --------------------------------------------------------------------------
// TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph
// in the N == 1 case to remove the node.
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 67c6c58fe2..4ba3f15ef0 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -148,7 +148,11 @@ REGISTER_OP("BytesProducedStatsDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle tag_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("LatencyStatsDataset")
.Input("input_dataset: variant")
@@ -156,7 +160,11 @@ REGISTER_OP("LatencyStatsDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle tag_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("SetStatsAggregatorDataset")
.Input("input_dataset: variant")
@@ -206,7 +214,12 @@ REGISTER_OP("PrefetchDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // buffer_size should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("ScanDataset")
.Input("input_dataset: variant")
@@ -290,7 +303,12 @@ REGISTER_OP("BatchDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // batch_size should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
// TODO(mrry): move SlideDataset to contrib in the future.
REGISTER_OP("SlideDataset")
@@ -300,7 +318,13 @@ REGISTER_OP("SlideDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // window_size and stride should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("PaddedBatchDataset")
.Input("input_dataset: variant")
@@ -330,7 +354,14 @@ REGISTER_OP("DenseToSparseBatchDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // batch_size should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ // row_shape should be a 1-D vector.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("RangeDataset")
.Input("start: int64")
@@ -341,7 +372,14 @@ REGISTER_OP("RangeDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
// stateful to inhibit constant folding.
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // start, stop, and step should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("RandomDataset")
.Input("seed: int64")
@@ -351,7 +389,13 @@ REGISTER_OP("RandomDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
// stateful to inhibit constant folding.
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // seed and seed2 should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("ShuffleDataset")
.Input("input_dataset: variant")
@@ -362,7 +406,14 @@ REGISTER_OP("ShuffleDataset")
.Attr("reshuffle_each_iteration: bool = true")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // buffer_size, seed, and seed2 should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("ShuffleAndRepeatDataset")
.Input("input_dataset: variant")
@@ -373,7 +424,15 @@ REGISTER_OP("ShuffleAndRepeatDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // buffer_size, seed, seed2, and count should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("CacheDataset")
.Input("input_dataset: variant")
@@ -381,7 +440,12 @@ REGISTER_OP("CacheDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // filename should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("TextLineDataset")
.Input("filenames: string")
@@ -390,10 +454,16 @@ REGISTER_OP("TextLineDataset")
.Output("handle: variant")
.SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
// stateful to inhibit constant folding.
- .SetShapeFn(shape_inference::ScalarShape); // TODO(mrry): validate
- // that `filenames` is
- // a scalar or a
- // vector.
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // `filenames` must be a scalar or a vector.
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
+ // `compression_type` could only be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ // `buffer_size` could only be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("SqlDataset")
.Input("driver_name: string")
@@ -404,7 +474,14 @@ REGISTER_OP("SqlDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
// stateful to inhibit constant folding.
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // driver_name, data_source_name, and query should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("FixedLengthRecordDataset")
.Input("filenames: string")
@@ -415,7 +492,18 @@ REGISTER_OP("FixedLengthRecordDataset")
.Output("handle: variant")
.SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
// stateful to inhibit constant folding.
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // `filenames` must be a scalar or a vector.
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
+ // header_bytes, record_bytes, footer_bytes, buffer_size should be
+ // scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("TFRecordDataset")
.Input("filenames: string")
@@ -424,7 +512,16 @@ REGISTER_OP("TFRecordDataset")
.Output("handle: variant")
.SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
// stateful to inhibit constant folding.
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // `filenames` must be a scalar or a vector.
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
+ // `compression_type` could only be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ // `buffer_size` could only be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("Iterator")
.Output("handle: resource")
@@ -540,7 +637,12 @@ REGISTER_OP("PrependFromQueueAndPaddedBatchDataset")
// length of `output_types` is `N`, the `output_shapes` are
// (as far as possible to tell statically) compatible with `padded_shapes`,
// and that `padding_values` are all scalars.
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // batch_size should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("EnqueueInQueueDataset")
.Input("queue: variant")
diff --git a/tensorflow/core/ops/manip_ops.cc b/tensorflow/core/ops/manip_ops.cc
index 95b4774fe6..e180f3d5f6 100644
--- a/tensorflow/core/ops/manip_ops.cc
+++ b/tensorflow/core/ops/manip_ops.cc
@@ -28,6 +28,17 @@ REGISTER_OP("Roll")
.Attr("T: type")
.Attr("Tshift: {int32,int64}")
.Attr("Taxis: {int32,int64}")
- .SetShapeFn(shape_inference::UnchangedShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // The `input` must be 1-D or higher.
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused));
+ // The `shift` must be scalar or 1-D.
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &unused));
+ // The `axis` must be scalar or 1-D.
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
+ // Validate that `shift` is the same shape as `axis`.
+ TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->input(2), &unused));
+ return shape_inference::UnchangedShape(c);
+ });
} // namespace tensorflow
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 6dc3d9df31..bb46dafd42 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -1535,6 +1535,7 @@ REGISTER_OP("__MklDummyConv2DWithBias")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Attr("dilations: list(int) = [1, 1, 1, 1]")
+ .SetShapeFn(shape_inference::Conv2DShape)
.Doc(R"doc(
Dummy node that enables fusing Conv2D and BiasAdd operator for MKL. This node
does not perform anything. It is just created as an intermediate output of
@@ -1561,6 +1562,7 @@ REGISTER_OP("_MklConv2DWithBias")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Attr("dilations: list(int) = [1, 1, 1, 1]")
+ .SetShapeFn(shape_inference::Conv2DShape)
.Doc(R"doc(
MKL version of Conv2D and BiasAdd operator. Uses MKL DNN APIs to perform
2D convolution and add Bias to the output of convolution.
@@ -1683,6 +1685,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
+#ifdef INTEL_MKL_ML
REGISTER_OP("_MklConv2DWithBiasBackpropBias")
.Input("out_backprop: T")
.Input("mkl_out_backprop: uint8")
@@ -1699,6 +1702,7 @@ gradients of convolution with respect to the bias.
NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
+#endif
REGISTER_OP("_MklConv2DBackpropInput")
.Input("input_sizes: int32")
@@ -2156,6 +2160,7 @@ REGISTER_OP("_MklToTf")
.Output("output: T")
.Attr("T: {half, float, double}")
.Attr(GetConvnetDataFormatAttrString())
+ .SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
MKL operator to convert a tensor from MKL layout to TensorFlow layout.
@@ -2177,6 +2182,7 @@ REGISTER_OP("_MklInputConversion")
"T: {half, float, double, uint8, int8, uint16, int16, int32, int64, "
"complex64, complex128}")
.Attr(GetConvnetDataFormatAttrString())
+ .SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
MKL operator to process the inputs to an elementwise MKL op. Both inputs
need to be either in TF or in MKL format. This op is added before every
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc
index f6c668f5c9..416ce9c0d8 100644
--- a/tensorflow/core/ops/random_ops.cc
+++ b/tensorflow/core/ops/random_ops.cc
@@ -43,7 +43,12 @@ REGISTER_OP("RandomUniformInt")
.Attr("seed2: int = 0")
.Attr("Tout: {int32, int64}")
.Attr("T: {int32, int64}")
- .SetShapeFn(shape_inference::RandomShape);
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ return shape_inference::RandomShape(c);
+ });
REGISTER_OP("RandomStandardNormal")
.Input("shape: T")
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 05f216a83e..469f193cf4 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -123,6 +123,11 @@ REGISTER_OP("StringSplit")
return Status::OK();
});
+REGISTER_OP("StringStrip")
+ .Input("input: string")
+ .Output("output: string")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
REGISTER_OP("EncodeBase64")
.Input("input: string")
.Output("output: string")
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index 6ce9595fb6..dc7b588898 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -737,6 +737,57 @@ REGISTER_OP("ResourceApplyAdam")
return ApplyAdamShapeFn(c, false /* sparse */);
});
+static Status ApplyAdaMaxShapeFn(InferenceContext* c, bool sparse) {
+ ShapeHandle unused;
+ ShapeHandle s = ShapeOrHandleShape(c, 0); // var
+ TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // m
+ TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s)); // v
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // beta1_power
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // lr
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // beta1
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // beta2
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); // epsilon
+ TF_RETURN_IF_ERROR(
+ HandleGradAndIndicesInputs(c, sparse, 8 /* grad_idx */, &s));
+ if (c->num_outputs() > 0) {
+ c->set_output(0, s);
+ }
+ return Status::OK();
+}
+
+REGISTER_OP("ApplyAdaMax")
+ .Input("var: Ref(T)")
+ .Input("m: Ref(T)")
+ .Input("v: Ref(T)")
+ .Input("beta1_power: T")
+ .Input("lr: T")
+ .Input("beta1: T")
+ .Input("beta2: T")
+ .Input("epsilon: T")
+ .Input("grad: T")
+ .Output("out: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .SetShapeFn([](InferenceContext* c) {
+ return ApplyAdaMaxShapeFn(c, false /* sparse */);
+ });
+
+REGISTER_OP("ResourceApplyAdaMax")
+ .Input("var: resource")
+ .Input("m: resource")
+ .Input("v: resource")
+ .Input("beta1_power: T")
+ .Input("lr: T")
+ .Input("beta1: T")
+ .Input("beta2: T")
+ .Input("epsilon: T")
+ .Input("grad: T")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .SetShapeFn([](InferenceContext* c) {
+ return ApplyAdaMaxShapeFn(c, false /* sparse */);
+ });
+
static Status ApplyRMSPropShapeFn(InferenceContext* c, bool sparse) {
ShapeHandle unused;
ShapeHandle s = ShapeOrHandleShape(c, 0); // var