/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

// Registers the `SetSize` op.
//
// Inputs are the three components of a sparse set (`set_indices`,
// `set_values`, `set_shape`); the single output is `size`. No static shape is
// inferred for the output: the shape function is `UnknownShape`.
REGISTER_OP("SetSize")
    .Input("set_indices: int64")
    .Input("set_values: T")
    .Input("set_shape: int64")
    .Attr("validate_indices: bool = true")
    .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}")
    .Output("size: int32")
    .SetShapeFn(shape_inference::UnknownShape);

// Registers the `DenseToDenseSetOperation` op.
//
// Both operands are dense tensors. The result is emitted as three outputs
// (`result_indices`, `result_values`, `result_shape`) whose static shapes are
// set to [?, rank], [?] and [rank] respectively, where `rank` is the operand
// rank when it can be determined and unknown otherwise.
REGISTER_OP("DenseToDenseSetOperation")
    .Input("set1: T")
    .Input("set2: T")
    .Attr("set_operation: string")
    .Attr("validate_indices: bool = true")
    .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}")
    .Output("result_indices: int64")
    .Output("result_values: T")
    .Output("result_shape: int64")
    .SetShapeFn([](InferenceContext* c) {
      if (c->num_inputs() != 2) {
        return errors::InvalidArgument("len(inputs) != 2.");
      }
      // The following should stay in sync with `ComputeDenseToDense` shape
      // assertions in kernels/set_kernels.cc.
      // Dimension n contains the set values to be compared, so ranks must be
      // >= 2, and the first n-1 dimensions of inputs and output must be
      // compatible.
      DimensionHandle output_rank;
      ShapeHandle input0_shape = c->input(0);
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(input0_shape, 2, &input0_shape));
      if (c->RankKnown(input0_shape)) {
        // Rank of the first operand is known: the second operand must have
        // exactly the same rank.
        const int32 input0_rank = c->Rank(input0_shape);
        ShapeHandle input1_shape = c->input(1);
        TF_RETURN_IF_ERROR(
            c->WithRank(input1_shape, input0_rank, &input1_shape));
        if (c->RankKnown(input1_shape)) {
          // If both ranks are specified, the first n-1 dims must be compatible.
          const int32 rank = c->Rank(input1_shape);
          ShapeHandle group0_shape;
          TF_RETURN_IF_ERROR(
              c->Subshape(input0_shape, 0, rank - 1, &group0_shape));
          ShapeHandle group1_shape;
          TF_RETURN_IF_ERROR(
              c->Subshape(input1_shape, 0, rank - 1, &group1_shape));
          // Only the success/failure of the merge matters; the merged shape
          // itself is not used.
          ShapeHandle unused_shape;
          TF_RETURN_IF_ERROR(
              c->Merge(group0_shape, group1_shape, &unused_shape));
        }
        output_rank = c->MakeDim(input0_rank);
      } else {
        // Rank of the first operand is unknown: fall back to whatever is
        // known about the second operand, which must still be rank >= 2.
        ShapeHandle input1_shape = c->input(1);
        TF_RETURN_IF_ERROR(c->WithRankAtLeast(input1_shape, 2, &input1_shape));
        if (c->RankKnown(input1_shape)) {
          output_rank = c->MakeDim(c->Rank(input1_shape));
        } else {
          output_rank = c->UnknownDim();
        }
      }

      // result_indices: [?, rank], result_values: [?], result_shape: [rank].
      c->set_output(0, c->Matrix(c->UnknownDim(), output_rank));
      c->set_output(1, c->Vector(c->UnknownDim()));
      c->set_output(2, c->Vector(output_rank));
      return Status::OK();
    });

// Registers the `DenseToSparseSetOperation` op.
//
// The first operand is a dense tensor; the second is given by its sparse
// components (`set2_indices`, `set2_values`, `set2_shape`). Output shapes
// follow the same [?, rank] / [?] / [rank] layout as the other set operations
// in this file.
REGISTER_OP("DenseToSparseSetOperation")
    .Input("set1: T")
    .Input("set2_indices: int64")
    .Input("set2_values: T")
    .Input("set2_shape: int64")
    .Attr("set_operation: string")
    .Attr("validate_indices: bool = true")
    .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}")
    .Output("result_indices: int64")
    .Output("result_values: T")
    .Output("result_shape: int64")
    .SetShapeFn([](InferenceContext* c) {
      if (c->num_inputs() != 4) {
        return errors::InvalidArgument("len(inputs) != 4.");
      }
      // The following should stay in sync with `ComputeDenseToSparse` shape
      // assertions in kernels/set_kernels.cc.
      // Ranks must be compatible, and be >= 2.
      ShapeHandle input1_shape_shape = c->input(3);
      TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
          c, c->input(1), c->input(2), input1_shape_shape));

      // Dimension 0 of the `set2_shape` input is used as the sparse operand's
      // rank.
      DimensionHandle input1_rank_dim = c->Dim(input1_shape_shape, 0);

      DimensionHandle output_rank_dim;
      ShapeHandle input0_shape = c->input(0);
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(input0_shape, 2, &input0_shape));
      if (c->RankKnown(input0_shape)) {
        // The dense rank is already known to be >= 2; requiring the sparse
        // rank to equal it enforces the same lower bound on the sparse side.
        const int32 input0_rank = c->Rank(input0_shape);
        TF_RETURN_IF_ERROR(
            c->WithValue(input1_rank_dim, input0_rank, &input1_rank_dim));
        output_rank_dim = c->MakeDim(input0_rank);
      } else if (c->ValueKnown(input1_rank_dim)) {
        // The dense rank is unknown, so the equality check above did not run.
        // Enforce the documented `rank >= 2` requirement on the sparse operand
        // explicitly, matching the check in `SparseToSparseSetOperation`.
        const int64 input1_rank = c->Value(input1_rank_dim);
        if (input1_rank < 2) {
          return errors::InvalidArgument("Input 1, expected rank >= 2, got ",
                                         input1_rank, ".");
        }
        output_rank_dim = input1_rank_dim;
      } else {
        output_rank_dim = c->UnknownDim();
      }

      // result_indices: [?, rank], result_values: [?], result_shape: [rank].
      c->set_output(0, c->Matrix(c->UnknownDim(), output_rank_dim));
      c->set_output(1, c->Vector(c->UnknownDim()));
      c->set_output(2, c->Vector(output_rank_dim));
      return Status::OK();
    });

// Registers the `SparseToSparseSetOperation` op.
//
// Both operands are given by their sparse components (indices, values,
// shape). Output shapes follow the same [?, rank] / [?] / [rank] layout as the
// other set operations in this file.
REGISTER_OP("SparseToSparseSetOperation")
    .Input("set1_indices: int64")
    .Input("set1_values: T")
    .Input("set1_shape: int64")
    .Input("set2_indices: int64")
    .Input("set2_values: T")
    .Input("set2_shape: int64")
    .Attr("set_operation: string")
    .Attr("validate_indices: bool = true")
    .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}")
    .Output("result_indices: int64")
    .Output("result_values: T")
    .Output("result_shape: int64")
    .SetShapeFn([](InferenceContext* c) {
      if (c->num_inputs() != 6) {
        return errors::InvalidArgument("len(inputs) != 6.");
      }
      // The following should stay in sync with `ComputeSparseToSparse` shape
      // assertions in kernels/set_kernels.cc.
      // Ranks must be compatible, and be >= 2.
      ShapeHandle input0_shape_shape = c->input(2);
      ShapeHandle input1_shape_shape = c->input(5);
      TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
          c, c->input(0), c->input(1), input0_shape_shape));
      TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
          c, c->input(3), c->input(4), input1_shape_shape));

      // Dimension 0 of each `*_shape` input is used as that operand's rank.
      DimensionHandle input0_rank_dim = c->Dim(input0_shape_shape, 0);
      DimensionHandle input1_rank_dim = c->Dim(input1_shape_shape, 0);
      DimensionHandle output_rank_dim;
      if (c->ValueKnown(input0_rank_dim)) {
        const int64 input0_rank = c->Value(input0_rank_dim);
        if (input0_rank < 2) {
          return errors::InvalidArgument("Input 0, expected rank >= 2, got ",
                                         input0_rank, ".");
        }
        // The second operand's rank must agree with the first one's.
        TF_RETURN_IF_ERROR(
            c->WithValue(input1_rank_dim, input0_rank, &input1_rank_dim));
        output_rank_dim = input0_rank_dim;
      } else if (c->ValueKnown(input1_rank_dim)) {
        const int64 input1_rank = c->Value(input1_rank_dim);
        if (input1_rank < 2) {
          return errors::InvalidArgument("Input 1, expected rank >= 2, got ",
                                         input1_rank, ".");
        }
        output_rank_dim = input1_rank_dim;
      } else {
        output_rank_dim = c->UnknownDim();
      }

      // result_indices: [?, rank], result_values: [?], result_shape: [rank].
      c->set_output(0, c->Matrix(c->UnknownDim(), output_rank_dim));
      c->set_output(1, c->Vector(c->UnknownDim()));
      c->set_output(2, c->Vector(output_rank_dim));
      return Status::OK();
    });

}  // namespace tensorflow