diff options
Diffstat (limited to 'tensorflow/core/ops')
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 52 | ||||
-rw-r--r-- | tensorflow/core/ops/dataset_ops.cc | 140 | ||||
-rw-r--r-- | tensorflow/core/ops/manip_ops.cc | 13 | ||||
-rw-r--r-- | tensorflow/core/ops/nn_ops.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/ops/random_ops.cc | 7 | ||||
-rw-r--r-- | tensorflow/core/ops/string_ops.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/ops/training_ops.cc | 51 |
7 files changed, 253 insertions, 21 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 2a8b9f9bee..88fc03826a 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -429,6 +429,58 @@ REGISTER_OP("UnravelIndex") .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { return Status::OK(); }); +REGISTER_OP("BroadcastTo") + .Input("input: T") + .Input("shape: Tidx") + .Output("output: T") + .Attr("T: type") + .Attr("Tidx: {int32, int64} = DT_INT32") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle in = c->input(0); + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out)); + + if (!c->RankKnown(out)) { + // We have no information about the shape of the output. + c->set_output(0, out); + return Status::OK(); + } + + if (!c->RankKnown(in)) { + // We have no information about the shape of the input, + // nothing to do here. + c->set_output(0, out); + return Status::OK(); + } + if (c->Rank(out) < c->Rank(in)) { + return errors::InvalidArgument("Cannot broadcast a tensor with shape ", + c->DebugString(in), " shape ", + c->DebugString(out)); + } + + int32 in_offset = c->Rank(out) - c->Rank(in); + for (int32 i = 0; i < c->Rank(out); ++i) { + DimensionHandle dim = c->Dim(out, i); + if (c->ValueKnown(dim)) { + // The first in_offset dimensions for input will be expanded with 1, + // so no check needed. + if (i >= in_offset) { + DimensionHandle in_dim = c->Dim(in, i - in_offset); + if (c->ValueKnown(in_dim)) { + if (c->Value(dim) % c->Value(in_dim) != 0) { + return errors::InvalidArgument( + "Cannot broadcast a tensor with shape ", c->DebugString(in), + " shape ", c->DebugString(out)); + } + } + } + } + } + + c->set_output(0, out); + return Status::OK(); + }); + // -------------------------------------------------------------------------- // TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph // in the N == 1 case to remove the node. diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 67c6c58fe2..4ba3f15ef0 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -148,7 +148,11 @@ REGISTER_OP("BytesProducedStatsDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle tag_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("LatencyStatsDataset") .Input("input_dataset: variant") @@ -156,7 +160,11 @@ REGISTER_OP("LatencyStatsDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle tag_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("SetStatsAggregatorDataset") .Input("input_dataset: variant") @@ -206,7 +214,12 @@ REGISTER_OP("PrefetchDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // buffer_size should be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("ScanDataset") .Input("input_dataset: variant") @@ -290,7 +303,12 @@ REGISTER_OP("BatchDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // batch_size should be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + return shape_inference::ScalarShape(c); + }); // TODO(mrry): move SlideDataset to contrib in the future. REGISTER_OP("SlideDataset") @@ -300,7 +318,13 @@ REGISTER_OP("SlideDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // window_size and stride should be scalars. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("PaddedBatchDataset") .Input("input_dataset: variant") @@ -330,7 +354,14 @@ REGISTER_OP("DenseToSparseBatchDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // batch_size should be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + // row_shape should be a 1-D vector. + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("RangeDataset") .Input("start: int64") @@ -341,7 +372,14 @@ REGISTER_OP("RangeDataset") .Attr("output_shapes: list(shape) >= 1") .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // start, stop, and step should be scalars. + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("RandomDataset") .Input("seed: int64") @@ -351,7 +389,13 @@ REGISTER_OP("RandomDataset") .Attr("output_shapes: list(shape) >= 1") .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // buffer_size, seed, and seed2 should be scalars. + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("ShuffleDataset") .Input("input_dataset: variant") @@ -362,7 +406,14 @@ REGISTER_OP("ShuffleDataset") .Attr("reshuffle_each_iteration: bool = true") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // buffer_size, seed, and seed2 should be scalars. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("ShuffleAndRepeatDataset") .Input("input_dataset: variant") @@ -373,7 +424,15 @@ REGISTER_OP("ShuffleAndRepeatDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // buffer_size, seed, seed2, and count should be scalars. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("CacheDataset") .Input("input_dataset: variant") @@ -381,7 +440,12 @@ REGISTER_OP("CacheDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // filename should be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("TextLineDataset") .Input("filenames: string") @@ -390,10 +454,16 @@ REGISTER_OP("TextLineDataset") .Output("handle: variant") .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape); // TODO(mrry): validate - // that `filenames` is - // a scalar or a - // vector. + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // `filenames` must be a scalar or a vector. + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); + return shape_inference::ScalarShape(c); + // `compression_type` could only be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + // `buffer_size` could only be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + }); REGISTER_OP("SqlDataset") .Input("driver_name: string") @@ -404,7 +474,14 @@ REGISTER_OP("SqlDataset") .Attr("output_shapes: list(shape) >= 1") .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // driver_name, data_source_name, and query should be scalars. + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("FixedLengthRecordDataset") .Input("filenames: string") @@ -415,7 +492,18 @@ REGISTER_OP("FixedLengthRecordDataset") .Output("handle: variant") .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // `filenames` must be a scalar or a vector. + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); + // header_bytes, record_bytes, footer_bytes, buffer_size should be + // scalars. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("TFRecordDataset") .Input("filenames: string") @@ -424,7 +512,16 @@ REGISTER_OP("TFRecordDataset") .Output("handle: variant") .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // `filenames` must be a scalar or a vector. + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); + // `compression_type` could only be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + // `buffer_size` could only be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("Iterator") .Output("handle: resource") @@ -540,7 +637,12 @@ REGISTER_OP("PrependFromQueueAndPaddedBatchDataset") // length of `output_types` is `N`, the `output_shapes` are // (as far as possible to tell statically) compatible with `padded_shapes`, // and that `padding_values` are all scalars. - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // batch_size should be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("EnqueueInQueueDataset") .Input("queue: variant") diff --git a/tensorflow/core/ops/manip_ops.cc b/tensorflow/core/ops/manip_ops.cc index 95b4774fe6..e180f3d5f6 100644 --- a/tensorflow/core/ops/manip_ops.cc +++ b/tensorflow/core/ops/manip_ops.cc @@ -28,6 +28,17 @@ REGISTER_OP("Roll") .Attr("T: type") .Attr("Tshift: {int32,int64}") .Attr("Taxis: {int32,int64}") - .SetShapeFn(shape_inference::UnchangedShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // The `input` must be 1-D or higher + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused)); + // The `shift` must be scalar or 1-D. + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &unused)); + // The `axis` must be scalar or 1-D. + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); + // Validate 'shift' is the same shape as axis'. + TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->input(2), &unused)); + return shape_inference::UnchangedShape(c); + }); } // namespace tensorflow diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 6dc3d9df31..bb46dafd42 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -1535,6 +1535,7 @@ REGISTER_OP("__MklDummyConv2DWithBias") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn(shape_inference::Conv2DShape) .Doc(R"doc( Dummy node that enables fusing Conv2D and BiasAdd operator for MKL. This node does not perform anything. It is just created as an intermediate output of @@ -1561,6 +1562,7 @@ REGISTER_OP("_MklConv2DWithBias") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn(shape_inference::Conv2DShape) .Doc(R"doc( MKL version of Conv2D and BiasAdd operator. Uses MKL DNN APIs to perform 2D convolution and add Bias to the output of convolution. @@ -1683,6 +1685,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); +#ifdef INTEL_MKL_ML REGISTER_OP("_MklConv2DWithBiasBackpropBias") .Input("out_backprop: T") .Input("mkl_out_backprop: uint8") @@ -1699,6 +1702,7 @@ gradients of convolution with respect to the bias. NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); +#endif REGISTER_OP("_MklConv2DBackpropInput") .Input("input_sizes: int32") @@ -2156,6 +2160,7 @@ REGISTER_OP("_MklToTf") .Output("output: T") .Attr("T: {half, float, double}") .Attr(GetConvnetDataFormatAttrString()) + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( MKL operator to convert a tensor from MKL layout to TensorFlow layout. @@ -2177,6 +2182,7 @@ REGISTER_OP("_MklInputConversion") "T: {half, float, double, uint8, int8, uint16, int16, int32, int64, " "complex64, complex128}") .Attr(GetConvnetDataFormatAttrString()) + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( MKL operator to process the inputs to an elementwise MKL op. Both inputs need to be either in TF or in MKL format. This op is added before every diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc index f6c668f5c9..416ce9c0d8 100644 --- a/tensorflow/core/ops/random_ops.cc +++ b/tensorflow/core/ops/random_ops.cc @@ -43,7 +43,12 @@ REGISTER_OP("RandomUniformInt") .Attr("seed2: int = 0") .Attr("Tout: {int32, int64}") .Attr("T: {int32, int64}") - .SetShapeFn(shape_inference::RandomShape); + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + return shape_inference::RandomShape(c); + }); REGISTER_OP("RandomStandardNormal") .Input("shape: T") diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index 05f216a83e..469f193cf4 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -123,6 +123,11 @@ REGISTER_OP("StringSplit") return Status::OK(); }); +REGISTER_OP("StringStrip") + .Input("input: string") + .Output("output: string") + .SetShapeFn(shape_inference::UnchangedShape); + REGISTER_OP("EncodeBase64") .Input("input: string") .Output("output: string") diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc index 6ce9595fb6..dc7b588898 100644 --- a/tensorflow/core/ops/training_ops.cc +++ b/tensorflow/core/ops/training_ops.cc @@ -737,6 +737,57 @@ REGISTER_OP("ResourceApplyAdam") return ApplyAdamShapeFn(c, false /* sparse */); }); +static Status ApplyAdaMaxShapeFn(InferenceContext* c, bool sparse) { + ShapeHandle unused; + ShapeHandle s = ShapeOrHandleShape(c, 0); // var + TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // m + TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s)); // v + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // beta1_power + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // lr + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // beta1 + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // beta2 + TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); // epsilon + TF_RETURN_IF_ERROR( + HandleGradAndIndicesInputs(c, sparse, 8 /* grad_idx */, &s)); + if (c->num_outputs() > 0) { + c->set_output(0, s); + } + return Status::OK(); +} + +REGISTER_OP("ApplyAdaMax") + .Input("var: Ref(T)") + .Input("m: Ref(T)") + .Input("v: Ref(T)") + .Input("beta1_power: T") + .Input("lr: T") + .Input("beta1: T") + .Input("beta2: T") + .Input("epsilon: T") + .Input("grad: T") + .Output("out: Ref(T)") + .Attr("T: numbertype") + .Attr("use_locking: bool = false") + .SetShapeFn([](InferenceContext* c) { + return ApplyAdaMaxShapeFn(c, false /* sparse */); + }); + +REGISTER_OP("ResourceApplyAdaMax") + .Input("var: resource") + .Input("m: resource") + .Input("v: resource") + .Input("beta1_power: T") + .Input("lr: T") + .Input("beta1: T") + .Input("beta2: T") + .Input("epsilon: T") + .Input("grad: T") + .Attr("T: numbertype") + .Attr("use_locking: bool = false") + .SetShapeFn([](InferenceContext* c) { + return ApplyAdaMaxShapeFn(c, false /* sparse */); + }); + static Status ApplyRMSPropShapeFn(InferenceContext* c, bool sparse) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var |