diff options
Diffstat (limited to 'tensorflow/core')
58 files changed, 1402 insertions, 274 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 6a3ee3c1cb..900a0e11c4 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1242,6 +1242,7 @@ cc_library( srcs = [ "ops/math_grad.cc", "ops/random_grad.cc", + "ops/stateless_random_grad.cc", ], linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 visibility = ["//visibility:public"], diff --git a/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt index 4433693759..d158f4b502 100644 --- a/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt @@ -4,16 +4,23 @@ op { in_arg { name: "arguments" description: <<END - A list of tensors whose types are Targuments, corresponding to the inputs the - function should be mapped over. + A list of tensors whose types are `Targuments`, corresponding to the inputs + the function should be mapped over. +END + } + in_arg { + name: "captured_inputs" + description: <<END + A list of tensors whose types are `Tcaptured`, corresponding to the captured + inputs of the defun. END } out_arg { name: "output" description: <<END - A list of output tensors whose types are output_types and whose dimensions 0 - are the same as the dimensions 0 of the tensors in arguments, and whose - remaining dimensions correspond to those in output_shapes. + A list of output tensors whose types are `output_types` and whose dimensions + 0 are the same as the dimensions 0 of the tensors in `arguments`, and whose + remaining dimensions correspond to those in `output_shapes`. END } attr { @@ -21,6 +28,10 @@ END description: "A list of types." } attr { + name: "Tcaptured" + description: "A list of types." + } + attr { name: "output_types" description: "A list of types." } @@ -29,6 +40,6 @@ END description: "A list of shapes." } summary: <<END - Maps a function on the list of tensors unpacked from inputs on dimension 0. 
+ Maps a function on the list of tensors unpacked from arguments on dimension 0. END } diff --git a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt index 5246090ab3..fe0fcc9508 100644 --- a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt @@ -18,6 +18,16 @@ END Scalar defining the number of characters to include in each substring END } + attr { + name: "unit" + description: <<END +The unit that is used to create the substring. One of: `"BYTE"` (for +defining position and length by bytes) or `"UTF8_CHAR"` (for the UTF-8 +encoded Unicode code points). The default is `"BYTE"`. Results are undefined if +`unit=UTF8_CHAR` and the `input` strings do not contain structurally valid +UTF-8. +END + } out_arg { name: "output" description: <<END diff --git a/tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt b/tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt new file mode 100644 index 0000000000..3d937c745c --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListPushBackBatch" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt b/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt new file mode 100644 index 0000000000..44f25b5d93 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "EmptyTensorList" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt index 4778d7927c..4fb9ee56e9 100644 --- a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt @@ -1,10 +1,4 @@ op { graph_op_name: "Substr" - endpoint { - name: "strings.substr" - } - 
endpoint { - name: "substr" - deprecated: true - } + visibility: HIDDEN } diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt new file mode 100644 index 0000000000..45fc55e71e --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListConcatLists" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt new file mode 100644 index 0000000000..e1ad713e7f --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListElementShape" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt new file mode 100644 index 0000000000..4aaefba3c5 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListFromTensor" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt new file mode 100644 index 0000000000..aaf607d70e --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListGather" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt new file mode 100644 index 0000000000..3bb5f39cbc --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListGetItem" + visibility: HIDDEN +} diff --git 
a/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt new file mode 100644 index 0000000000..a04c20bb8a --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListLength" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt new file mode 100644 index 0000000000..9287162f22 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListPopBack" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt new file mode 100644 index 0000000000..da2bc11721 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListPushBack" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt new file mode 100644 index 0000000000..77e63747d5 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListReserve" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt new file mode 100644 index 0000000000..0015189d7f --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListScatter" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt new file mode 
100644 index 0000000000..4999ee7ad9 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListSetItem" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt new file mode 100644 index 0000000000..2dc7b2784b --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListStack" + visibility: HIDDEN +} diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc index 3b2dc6a050..7cb90de3c7 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc @@ -522,7 +522,6 @@ void CollectiveParamResolverLocal::CallInitInstanceSharedParams( InitInstanceSharedParams( gr, cp, ir, [this, ir, done](const Status& s) UNLOCK_FUNCTION(ir->out_mu) { - DCHECK(!ir->out_mu.try_lock()); DCHECK(ir->out_mu_available); ir->status.Update(s); ir->out_mu.unlock(); diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 419867ff58..e81e61b633 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -473,16 +473,16 @@ bool ReplaceTensorWithConstant( // 1) Do not replace another constant. // 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY // constraint, do not replace it. - // 3) If the size of the constant in bytes is too large (> + // 3) If the destination tensor is an int32 tensor, and has DEVICE_MEMORY + // constraint, do not replace it. + // 4) If the size of the constant in bytes is too large (> // max_constant_in_bytes), do not replace it. 
This prevents the size of the // Graph from growing too large. - // 4) If the constant op created does not have a kernel implementation + // 5) If the constant op created does not have a kernel implementation // for the device, do not use it. // TODO(keveman): Consider adding a new constant op that has a kernel // implementation for all types, but with HostMemory constraint on it's // output. - // 5) If the constant op for the device has different output memory type - // from the original op output memory type, do not replace it. if (tensor.first->IsConstant()) { return false; } @@ -497,7 +497,8 @@ bool ReplaceTensorWithConstant( return false; } bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32; - if (memory_type == HOST_MEMORY && !is_int32) { + if ((memory_type == HOST_MEMORY && !is_int32) || + (memory_type == DEVICE_MEMORY && is_int32)) { return false; } } @@ -535,23 +536,6 @@ bool ReplaceTensorWithConstant( if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) { return false; } - if (partition_device && device_type != DEVICE_CPU) { - MemoryType original_output_memory_type; - if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second, - &original_output_memory_type) - .ok()) { - return false; - } - MemoryType const_output_memory_type; - if (!MemoryTypeForOutput(device_type, graph, constant_node, 0, - &const_output_memory_type) - .ok()) { - return false; - } - if (original_output_memory_type != const_output_memory_type) { - return false; - } - } for (auto edge : edges_to_remove) { graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input()); graph->RemoveEdge(edge); diff --git a/tensorflow/core/common_runtime/eager/attr_builder.cc b/tensorflow/core/common_runtime/eager/attr_builder.cc index cf1cd4134e..5c8369de87 100644 --- a/tensorflow/core/common_runtime/eager/attr_builder.cc +++ b/tensorflow/core/common_runtime/eager/attr_builder.cc @@ -136,6 +136,22 @@ void AttrBuilder::FillAttrValueMap(AttrValueMap* m, 
m->insert(*it); } } + // For any attr-value pairs that exist in the op def (from op registry) but + // not `m`, fill them into `m`, so that we can run a TFE_Op without having to + // specify all the default attr values (e.g. for matmul, the `transpose_a` + // attr defaults to false). + const OpDef* op_def = nullptr; + Status s = OpDefForOp(op_name_.c_str(), &op_def); + // This is expected, if this op is a custom function, and is therefore not + // present in the op registry. + if (!s.ok()) return; + + DCHECK(op_def); + for (const auto& attr_def : op_def->attr()) { + if (attr_def.has_default_value() && !m->count(attr_def.name())) { + SetInAttrValueMap(m, attr_def.name(), attr_def.default_value()); + } + } } const NodeDef& AttrBuilder::BuildNodeDef() { diff --git a/tensorflow/core/common_runtime/eager/attr_builder.h b/tensorflow/core/common_runtime/eager/attr_builder.h index cbe6a1cb50..c114ea4ba0 100644 --- a/tensorflow/core/common_runtime/eager/attr_builder.h +++ b/tensorflow/core/common_runtime/eager/attr_builder.h @@ -110,6 +110,12 @@ class AttrBuilder { using AttrVec = tensorflow::gtl::InlinedVector<std::pair<StringPiece, T>, 2>; void MayBeInitializeNodeDef(); + // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as + // well as any default attr-value pairs from the associated op_def, if there + // is one. + // + // If `include_those_in_node_def` is true, also include any attr-value pairs + // from `node_def_`. 
void FillAttrValueMap(AttrValueMap* m, bool include_those_in_node_def) const; template <class T> diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 18420b60fd..f23cefb33d 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -70,7 +70,9 @@ EagerContext::EagerContext(const SessionOptions& opts, async_default_(async), log_memory_(LogMemory::IsEnabled()), env_(opts.env), - use_send_tensor_rpc_(false) { + use_send_tensor_rpc_(false), + pin_small_ops_to_cpu_(ReadBoolFromEnvVar( + "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", true)) { if (device_mgr_owned) { local_device_manager_.reset(device_mgr); local_unowned_device_manager_ = nullptr; diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 5ed6057ec6..15eeaa8066 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -202,6 +202,7 @@ class EagerContext { // EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used // instead (which in-turn use WorkerService.RecvTensor RPCs). 
bool UseSendTensorRPC() { return use_send_tensor_rpc_; } + bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; } private: void InitDeviceMapAndAsync(); @@ -293,6 +294,7 @@ class EagerContext { #endif bool use_send_tensor_rpc_; + const bool pin_small_ops_to_cpu_; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 1bc63616d0..a52f933d75 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -579,19 +579,23 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, return Status::OK(); #endif } -} // namespace -Status EagerExecute(EagerOperation* op, - gtl::InlinedVector<TensorHandle*, 2>* retvals, - int* num_retvals) { - // Ensure all resource-touching ops run in the device the resource is, - // regardless of anything else that has been specified. This is identical to - // the graph mode behavior. +// The Op device may be updated if: +// - A resource touching input is specified: all resource-touching ops run in +// the device the resource is, regardless of anything else that has been +// specified. This is identical to the graph mode behavior. +// +// - All op inputs are on the CPU, small (<64 elements) and integers +// (int32/int64). This can be disabled by setting the environment variable +// "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false". 
+Status MaybeUpdateOpDevice(EagerOperation* op) { EagerContext* ctx = op->EagerContext(); + bool device_set_for_resource_variable = false; + bool all_inputs_eligible_for_cpu_pinning = ctx->PinSmallOpsToCPU(); + for (int i = 0; i < op->Inputs().size(); ++i) { Device* input_op_device = nullptr; - auto status = op->Inputs()[i]->OpDevice(&input_op_device); - if (!status.ok()) return status; + TF_RETURN_IF_ERROR(op->Inputs()[i]->OpDevice(&input_op_device)); VLOG(2) << "for op " << op->Name() << " input " << i << " " << DataTypeString(op->Inputs()[i]->dtype) << " " << (input_op_device == nullptr ? "cpu" : input_op_device->name()) @@ -603,8 +607,53 @@ Status EagerExecute(EagerOperation* op, << d->name() << " because input #" << i << " is a resource in this device."; op->SetDevice(d); + + device_set_for_resource_variable = true; + all_inputs_eligible_for_cpu_pinning = false; + } else if (all_inputs_eligible_for_cpu_pinning) { + TensorHandle* handle = op->Inputs()[i]; + + // Input is on CPU. + if (input_op_device != nullptr && input_op_device != ctx->HostCPU()) { + all_inputs_eligible_for_cpu_pinning = false; + continue; + } + + if (handle->dtype != DataType::DT_INT32 && + handle->dtype != DataType::DT_INT64) { + all_inputs_eligible_for_cpu_pinning = false; + continue; + } + + int64 num_elements; + TF_RETURN_IF_ERROR(handle->NumElements(&num_elements)); + if (num_elements > 64) { + all_inputs_eligible_for_cpu_pinning = false; + } } } + + // Ops without inputs are usually ops that generate a tensor in some way and + // usually require being present on whatever device they are scheduled on + // - for e.g. VarHandleOp or _Recv). + // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for + // an op, but there is a GPU kernel? 
+ if (!op->Inputs().empty() && all_inputs_eligible_for_cpu_pinning) { + VLOG(1) << "Forcing op " << op->Name() + << " to be on the CPU since all input tensors have an " + "int32/int64 dtype, and are small (less than 64 elements)."; + op->SetDevice(ctx->HostCPU()); + } + + return Status::OK(); +} +} // namespace + +Status EagerExecute(EagerOperation* op, + gtl::InlinedVector<TensorHandle*, 2>* retvals, + int* num_retvals) { + TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op)); + bool op_is_local = IsLocal(op->EagerContext(), op->Device()); if (op_is_local) { diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc index a02084f223..9306386117 100644 --- a/tensorflow/core/common_runtime/lower_if_op.cc +++ b/tensorflow/core/common_runtime/lower_if_op.cc @@ -107,6 +107,8 @@ CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name, then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()), else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) { TF_CHECK_OK(if_op_->input_node(0, &pred_)); + then_call_builder_.Device(if_op_->requested_device()); + else_call_builder_.Device(if_op_->requested_device()); } Status CondBuilder::CreatePivotNodes() { @@ -117,15 +119,18 @@ Status CondBuilder::CreatePivotNodes() { NodeBuilder(NewName("switch_pred"), "Switch", graph_->op_registry()) .Input(NodeOut(pred_, 0)) .Input(NodeOut(pred_, 0)) + .Device(if_op_->requested_device()) .Finalize(graph_, &switch_pred)); control_predecessor_ = switch_pred; TF_RETURN_IF_ERROR( NodeBuilder(NewName("pivot_f"), "Identity", graph_->op_registry()) .Input(switch_pred, kElseBranch) + .Device(if_op_->requested_device()) .Finalize(graph_, &pivot_f_)); TF_RETURN_IF_ERROR( NodeBuilder(NewName("pivot_t"), "Identity", graph_->op_registry()) .Input(switch_pred, kThenBranch) + .Device(if_op_->requested_device()) .Finalize(graph_, &pivot_t_)); return Status::OK(); } @@ -140,6 +145,7 @@ Status CondBuilder::AddInput(Node* src, int 
src_output) { NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry()) .Input(src, src_output) .Input(pred_, 0) + .Device(if_op_->requested_device()) .Finalize(graph_, &input)); then_call_builder_.Input(input, kThenBranch); else_call_builder_.Input(input, kElseBranch); @@ -178,6 +184,7 @@ Status CondBuilder::AddOutputs() { TF_RETURN_IF_ERROR( NodeBuilder(graph_->NewName("merge"), "Merge", graph_->op_registry()) .Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)}) + .Device(if_op_->requested_device()) .Finalize(graph_, &merges[i])); outputs_[i] = NodeOut(merges[i], 0); } @@ -218,7 +225,7 @@ Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib, Status CondBuilder::BuildLoweredIfOutput() { // Build the identity node output. NodeBuilder ib(name_, "IdentityN"); - ib.Input(outputs_); + ib.Input(outputs_).Device(if_op_->requested_device()); return ib.Finalize(graph_, &lowered_if_output_); } diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 362092a6cf..db10f586bc 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -1340,6 +1340,8 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) { Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {}); Output g = ops::Shape(s.WithOpName("g"), c); Output h = ops::Fill(s.WithOpName("h"), g, zero); + Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1}); + Output j = ops::Sum(s.WithOpName("j"), a, zero_idx); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -1382,6 +1384,10 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) { ASSERT_EQ(2, shape_f.dim_size()); EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size()); EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size()); + + const auto shape_j = properties.GetOutputProperties("j").at(0).shape(); + ASSERT_EQ(1, shape_j.dim_size()); + EXPECT_EQ(shape_j.dim(0).size(), 
shape_a.dim(1).size()); } TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index ca5d3a6dfd..3d0d95bba7 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -616,28 +616,37 @@ Status ConstantFolding::MaterializeReductionIndices( // We can't do anything if we don't know the rank of the input. return Status::OK(); } - const int rank = input_prop.shape().dim_size(); - if (rank == 0) { + const int input_rank = input_prop.shape().dim_size(); + if (input_rank < 1) { // Unexpected graph, don't try to change it. return Status::OK(); } + const OpInfo::TensorProperties& reduction_indices_prop = input_props[1]; + DataType dtype = reduction_indices_prop.dtype(); + if (dtype != DT_INT32 && dtype != DT_INT64) { + return Status::OK(); + } + PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape()); + const int num_reduction_indices = reduction_indices_shape.num_elements(); + const std::vector<OpInfo::TensorProperties>& output_props = properties.GetOutputProperties(node->name()); if (output_props.size() != 1) { return Status::OK(); } - const bool keep_dims = - node->attr().count("keep_dims") && node->attr().at("keep_dims").b(); const OpInfo::TensorProperties& output_prop = output_props[0]; - PartialTensorShape output_shape(output_prop.shape()); - if (output_shape.num_elements() != 1) { - bool full_reduction = false; + const int output_rank = + output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size(); + + bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank; + if (!full_reduction) { + // A full reduction will generate a tensor of one of the shapes + // [], [1], [1, 1], [1, 1, ...]. 
Even if we do not know the number of + // elements in the output of the reduction, we may deduce it from reshape + // nodes following it. for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) { - if (!IsReshape(*fanout) && !keep_dims) { - // Depending on how it's setup, a full reduction will generate a tensor - // of shape [], [1], [1, 1], [1, 1, ...]. If keep_dims isn't true, we - // rely on the existence of a reshape node following the reduction to - // ensure that the fanout is fed a scalar of the right shape. + full_reduction = false; + if (!IsReshape(*fanout)) { return Status::OK(); } const std::vector<OpInfo::TensorProperties>& reshape_props = @@ -658,20 +667,15 @@ Status ConstantFolding::MaterializeReductionIndices( } } - const OpInfo::TensorProperties& reduction_prop = input_props[1]; - DataType dtype = reduction_prop.dtype(); - if (dtype != DT_INT32 && dtype != DT_INT64) { - return Status::OK(); - } - // We know it's a full reduction. We can generate the set of indices to - // reduce. + // We know it's a full reduction. We can generate the full set of indices to + // reduce as a constant node. 
string const_name = OptimizedNodeName(*node, "-reduction_indices"); if (node_map_->GetNode(const_name)) { return Status::OK(); } NodeDef* reduction_indices = graph_->add_node(); - Tensor value(dtype, TensorShape({rank})); - for (int i = 0; i < rank; ++i) { + Tensor value(dtype, TensorShape({input_rank})); + for (int i = 0; i < input_rank; ++i) { if (dtype == DT_INT32) { value.vec<int32>()(i) = i; } else { @@ -680,6 +684,7 @@ Status ConstantFolding::MaterializeReductionIndices( } TF_RETURN_IF_ERROR( CreateNodeDef(const_name, TensorValue(&value), reduction_indices)); + reduction_indices->set_device(node->device()); string ctrl_dep = AddControlDependency(node->input(1), graph_, node_map_.get()); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index b09360a2c2..fab01edfed 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -2591,58 +2591,100 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) { } TEST_F(ConstantFoldingTest, MaterializeReductionIndices) { - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output input = - ops::Placeholder(s.WithOpName("input"), DT_FLOAT, - ops::Placeholder::Shape(PartialTensorShape({-1, -1}))); - Output indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32); - Output sum = ops::Sum(s.WithOpName("sum"), input, indices); - Output size = ops::Const(s.WithOpName("size"), 1, {1}); - Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size); + for (bool use_reshape : {true, false}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = + ops::Placeholder(s.WithOpName("input"), DT_FLOAT, + ops::Placeholder::Shape(PartialTensorShape({-1, -1}))); + // If use_reshape is false, we need to now the number of indices to apply + // the rewrite. 
+ Output indices = ops::Placeholder( + s.WithOpName("indices"), DT_INT32, + ops::Placeholder::Shape(PartialTensorShape({use_reshape ? -1 : 2}))); + Output sum = ops::Sum(s.WithOpName("sum"), input, indices); + if (use_reshape) { + Output size = ops::Const(s.WithOpName("size"), 1, {1}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size); + } - GrapplerItem item; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - item.fetch.push_back("reshape"); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch.push_back(use_reshape ? "reshape" : "sum"); - auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4})); - Tensor indices_t(DT_INT32, TensorShape({2})); - indices_t.flat<int>()(0) = 0; - indices_t.flat<int>()(1) = 1; - auto tensors_expected = EvaluateNodes( - item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}}); - EXPECT_EQ(1, tensors_expected.size()); + auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4})); + Tensor indices_t(DT_INT32, TensorShape({2})); + indices_t.flat<int>()(0) = 0; + indices_t.flat<int>()(1) = 1; + auto tensors_expected = EvaluateNodes( + item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}}); + EXPECT_EQ(1, tensors_expected.size()); - ConstantFolding optimizer(nullptr /* cpu_device */); - GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + // Use aggressive mode to force the shape inference to propagate placeholder + // shapes. + ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, + nullptr /* cpu_device */); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); - // Run a second time to make sure the optimization is idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + // Run a second time to make sure the optimization is idempotent. 
+ item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); - int found = 0; - for (const auto& node : output.node()) { - if (node.name() == "ConstantFolding/sum-reduction_indices") { - ++found; - EXPECT_EQ("Const", node.op()); - EXPECT_EQ("^indices", node.input(0)); - EXPECT_EQ(2, TensorShape(node.attr().at("value").tensor().tensor_shape()) - .num_elements()); - } else if (node.name() == "sum") { - ++found; - EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1)); - } else if (node.name() == "indices") { - ++found; + int found = 0; + for (const auto& node : output.node()) { + if (node.name() == "ConstantFolding/sum-reduction_indices") { + ++found; + EXPECT_EQ("Const", node.op()); + EXPECT_EQ("^indices", node.input(0)); + EXPECT_EQ(2, + TensorShape(node.attr().at("value").tensor().tensor_shape()) + .num_elements()); + } else if (node.name() == "sum") { + ++found; + EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1)); + } else if (node.name() == "indices") { + ++found; + } } + EXPECT_EQ(3, found); + + auto tensors = EvaluateNodes(output, item.fetch, + {{"input", input_t}, {"indices", indices_t}}); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5); } - EXPECT_EQ(3, found); +} - auto tensors = EvaluateNodes(output, item.fetch, - {{"input", input_t}, {"indices", indices_t}}); - EXPECT_EQ(1, tensors.size()); - test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5); +TEST_F(ConstantFoldingTest, MaterializeReductionIndices_NotFullReduction) { + for (bool input_rank_known : {true, false}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = + (input_rank_known ? 
ops::Placeholder(s.WithOpName("input"), DT_FLOAT, + ops::Placeholder::Shape( + PartialTensorShape({-1, -1}))) + : ops::Placeholder(s.WithOpName("input"), DT_FLOAT)); + Output indices = + ops::Placeholder(s.WithOpName("indices"), DT_INT32, + ops::Placeholder::Shape( + PartialTensorShape({input_rank_known ? 1 : 2}))); + Output sum = ops::Sum(s.WithOpName("sum"), input, indices); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch.push_back("sum"); + + // Use aggressive mode to force the shape inference to propagate placeholder + // shapes. + ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, + nullptr /* cpu_device */); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + CompareGraphs(item.graph, output); + } } TEST_F(ConstantFoldingTest, LargeConstant) { diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 755af3361e..ee7c14e3ab 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -524,6 +524,7 @@ cc_library( deps = [ ":function_utils", ":graph_utils", + "//tensorflow/cc:ops", "@com_google_absl//absl/strings", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc index 9328a7ca99..a9254ed58b 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc @@ -44,7 +44,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, // Function inputs and outputs are the same as original, just // with different shapes. 
*vectorized_func->mutable_signature() = orig_func.signature(); - graph_utils::SetUniqueGraphFunctionName("vectorized_function", library, + graph_utils::SetUniqueGraphFunctionName("naively_vectorized_fn", library, vectorized_func); // Add MapDefun node @@ -67,6 +67,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, map_defun_node->add_input(input.name()); } (*map_defun_node->mutable_attr())["Targuments"] = t_args; + AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node); // Set return values to match output names string output_prefix = strings::StrCat(map_defun_node->name(), ":output:"); diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD index 37aa24b947..985d6c6c3a 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD +++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD @@ -13,9 +13,19 @@ VECTORIZER_DEPS = [ ] + tf_protos_all() cc_library( + name = "wrapped_tensor", + hdrs = ["wrapped_tensor.h"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + ], +) + +cc_library( name = "vectorizer", hdrs = ["vectorizer.h"], deps = [ + ":wrapped_tensor", "//tensorflow/core:core_cpu", "//tensorflow/core:lib", ] + tf_protos_all(), diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc index 3af6bab409..f445157531 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc @@ -19,13 +19,13 @@ limitations under the License. 
namespace tensorflow { namespace grappler { -namespace vectorization_utils { +namespace { class CastVectorizer : public Vectorizer { public: Status Vectorize(const Node& node, Graph* outer_scope, - std::vector<Port>* input_ports, - std::vector<Port>* output_ports) override { + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) override { Status s; if (node.num_inputs() != 1) { return errors::Internal("Cast op should only have one input."); @@ -35,15 +35,17 @@ class CastVectorizer : public Vectorizer { auto new_cast_node = outer_scope->AddNode(node.def(), &s); TF_RETURN_IF_ERROR(s); - // Add input and output mappings - input_ports->push_back({new_cast_node, 0}); - output_ports->push_back({new_cast_node, 0}); + outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, new_cast_node, + 0); + + // Add output mappings + outputs->push_back({new_cast_node, 0, true}); return Status::OK(); } }; REGISTER_VECTORIZER("Cast", CastVectorizer); -} // namespace vectorization_utils +} // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc index 74ce520ce1..f1ba741821 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc @@ -19,15 +19,15 @@ limitations under the License. 
namespace tensorflow { namespace grappler { -namespace vectorization_utils { +namespace { class UnpackVectorizer : public Vectorizer { public: Status Vectorize(const Node& node, Graph* outer_scope, - std::vector<Port>* input_ports, - std::vector<Port>* output_ports) override { + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) override { Status s; - if (node.num_inputs() != 1) { + if (node.num_inputs() != 1 || inputs.size() != 1) { return errors::Internal("Unpack op should only have one input."); } @@ -39,13 +39,13 @@ class UnpackVectorizer : public Vectorizer { int new_axis = node.def().attr().at("axis").i() + 1; new_unpack_node->AddAttr("axis", new_axis); - // Add the input mappings - input_ports->push_back({new_unpack_node, 0}); + outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, + new_unpack_node, 0); // Add the output mappings int num = node.def().attr().at("num").i(); for (int i = 0; i < num; ++i) { - output_ports->push_back({new_unpack_node, i}); + outputs->push_back({new_unpack_node, i, true}); } return Status::OK(); @@ -54,6 +54,6 @@ class UnpackVectorizer : public Vectorizer { REGISTER_VECTORIZER("Unpack", UnpackVectorizer); -} // namespace vectorization_utils +} // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h index 56eb88c95e..8d4676aae0 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h @@ -18,15 +18,12 @@ limitations under the License. 
#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { namespace grappler { -namespace vectorization_utils { - -// Describes a tensor with its operation Node and output position -typedef std::pair<Node*, int> Port; // Interface for vectorization of TensorFlow operations. See `CastVectorizer` // for an example. @@ -36,17 +33,17 @@ class Vectorizer { // Vectorizes an operation, `node`, by adding Node(s) to `outer_scope` // that produce the same vector output(s) as executing `node`'s op - // on elements of the vector inputs. The new Node(s) collectively have the + // on elements of `inputs`. The new Node(s) collectively have the // same number of input and output ports as the node being converted. - // Adds mappings for the new nodes' input and output ports to `inputs` and - // `outputs` respectively, where the i'th Port in inputs/outputs - // corresponds to the i'th input/output port of the node to be converted. + // Adds edges between the newly created nodes and nodes in `inputs`, and adds + // mappings to the new nodes' output ports to `outputs`, where the i'th + // value in `outputs` corresponds to the i'th output port of the node + // to be converted. 
virtual Status Vectorize(const Node& node, Graph* outer_scope, - std::vector<Port>* input_ports, - std::vector<Port>* output_ports) = 0; + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) = 0; }; -} // namespace vectorization_utils } // namespace grappler } // namespace tensorflow #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_ diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc index a6551e36ac..e1cf77a7d5 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc @@ -19,7 +19,6 @@ limitations under the License. namespace tensorflow { namespace grappler { -namespace vectorization_utils { VectorizerRegistry* VectorizerRegistry::Global() { static VectorizerRegistry* registry = new VectorizerRegistry; @@ -42,6 +41,5 @@ void VectorizerRegistry::Register(const string& op_type, vectorizers_.insert(std::pair<const string&, std::unique_ptr<Vectorizer>>( op_type, std::move(vectorizer))); } -} // namespace vectorization_utils } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h index 16159d47ca..ad54c74933 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h @@ -23,7 +23,6 @@ limitations under the License. namespace tensorflow { namespace grappler { -namespace vectorization_utils { // A global VectorizerRegistry is used to hold all the vectorizers. 
class VectorizerRegistry { @@ -59,16 +58,12 @@ class VectorizerRegistration { #define REGISTER_VECTORIZER_UNIQ_HELPER(ctr, op_type, vectorizer) \ REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) -#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \ - static ::tensorflow::grappler::vectorization_utils:: \ - vectorizer_registration::VectorizerRegistration \ - vectorizer_registration_##ctr( \ - op_type, \ - ::std::unique_ptr< \ - ::tensorflow::grappler::vectorization_utils::Vectorizer>( \ - new vectorizer())) +#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \ + static ::tensorflow::grappler::vectorizer_registration:: \ + VectorizerRegistration vectorizer_registration_##ctr( \ + op_type, ::std::unique_ptr<::tensorflow::grappler::Vectorizer>( \ + new vectorizer())) -} // namespace vectorization_utils } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc index 663ceba027..054aeb9a8f 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc @@ -20,13 +20,12 @@ limitations under the License. 
namespace tensorflow { namespace grappler { -namespace vectorization_utils { class TestVectorizer : public Vectorizer { public: Status Vectorize(const Node& node, Graph* outer_scope, - std::vector<Port>* inputs, - std::vector<Port>* outputs) override { + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) override { return Status::OK(); } }; @@ -43,10 +42,10 @@ TEST(TestVectorizer, TestTestVectorizer) { NodeDef node_def; Status s; Node* node = g.AddNode(node_def, &s); - std::vector<Port> inputs, outputs; - EXPECT_TRUE(vectorizer->Vectorize(*node, &g, &inputs, &outputs).ok()); + std::vector<WrappedTensor> inputs, outputs; + EXPECT_TRUE( + vectorizer->Vectorize(*node, &g, std::move(inputs), &outputs).ok()); } -} // namespace vectorization_utils } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h new file mode 100644 index 0000000000..4439b4ab4e --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_ + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +namespace grappler { + +// Represents a tensor that has been vectorized. +struct WrappedTensor { + Node* const node; + const int output_index; + + // Whether the tensor is stacked, i.e. represents the results of applying + // the operation on all slices of the input, where each row i of the + // tensor corresponds to the op's output on slice i of the input. False + // if the tensor is not stacked, i.e. represents the result of the op on + // a single slice of the input, where the result does not vary between + // slices. + bool stacked; + + WrappedTensor(Node* node, int output_index, bool stacked) + : node(node), output_index(output_index), stacked(stacked) {} +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_ diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index 2d6cf562b1..ba857ab5d9 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -14,10 +14,10 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h" -#include <memory> #include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" #include "absl/strings/str_join.h" +#include "tensorflow/cc/framework/ops.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/device_base.h" @@ -28,13 +28,13 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/optimizers/data/function_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/functions.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/scanner.h" namespace tensorflow { namespace grappler { @@ -132,7 +132,8 @@ class Vectorization { const NodeDef& map_defun_node, FunctionDef** result); private: - // Converts FunctionDefs to Graphs. + // Converts FunctionDefs to Graphs and adds mappings from + // arg nodes and unstacked nodes to the corresponding nodes in outer_scope_. Status Initialize(const FunctionDef& outer_scope, const NodeDef& map_defun_node); @@ -162,9 +163,30 @@ class Vectorization { // the conversion map. Status AddConversionMapping(Node* op_node); - // Maps a tensor to the corresponding vectorized tensor. 
For example, - {"Cast" Node*, 0} -> {"Vectorize/Cast" Node*, 0} - std::map<TensorDesc, TensorDesc> conversion_map_; + // Given a tensor t in `unstacked`, stacks it by doing the equivalent of + // tf.tile(tf.expand_dims(t, 0), [n, 1, 1, ...]) where n is dimension 0 of + // inputs to `map_defun_node_`. This stacked tensor will be compatible with + // the expected output shape of `map_defun_node_`. + // This is equivalent to the _stack function in python Pfor. + Status StackTensor(WrappedTensor* unstacked, TensorDesc* result); + + // Recursively looks for unstacked nodes in the `map_defun_fn_` graph by + // doing a depth-first search from the ret nodes. Lifts nodes that are + // unstacked (i.e. don't derive from arg nodes) into `outer_scope_` directly + // and adds mappings to `conversion_map_`. + Status AddUnstackedNodeMappings(); + + // Recursive helper for `AddUnstackedNodeMappings`, returns true if tensor + // is unstacked. + bool AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, Status* status); + + // Add mappings from `map_defun_fn_` arg nodes to `map_defun_node_` input + // nodes to `conversion_map_`. + Status AddArgNodeMappings(); + + // Maps a tensor to the corresponding WrappedTensor. For example, + // {"Cast" Node*, 0} -> WrappedTensor({"Vectorize/Cast" Node*, 0}, true) + std::map<TensorDesc, WrappedTensor> conversion_map_; // Unconvertible ret nodes std::set<Node*> unconvertible_; @@ -180,6 +202,10 @@ class Vectorization { std::unique_ptr<Graph> outer_scope_; std::unique_ptr<FunctionBody> map_defun_fn_; Node* map_defun_node_ = nullptr; // Owned by `outer_scope` + + // Caches the loop_len_node_ needed for tiling unstacked output. This + // corresponds to a vector with one element. 
+ Node* loop_len_node_ = nullptr; // Owned by `outer_scope` Status status_; }; @@ -197,34 +223,48 @@ Status Vectorization::AddConversionMapping(Node* op_node) { return errors::Unimplemented("No vectorizer registered for op: ", op_node->type_string()); } - std::vector<Port> input_ports, output_ports; - input_ports.reserve(op_node->num_inputs()); - output_ports.reserve(op_node->num_outputs()); - TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(), - &input_ports, &output_ports)); + std::vector<WrappedTensor> inputs, outputs; + inputs.reserve(op_node->num_inputs()); + outputs.reserve(op_node->num_outputs()); std::vector<const Edge*> input_edges; TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges)); - if (op_node->num_outputs() != output_ports.size() || - op_node->num_inputs() != input_ports.size() || - input_edges.size() != input_ports.size()) { - return errors::Internal("Vectorizer inputs/outputs don't match."); - } - - // Promote the inputs of the op to MapDefun outputs and connect the edges - // accordingly. + // The inputs for the node to be converted may already have been converted + // themselves. For those that are not, we promote them to MapDefun outputs. for (size_t i = 0; i < op_node->num_inputs(); ++i) { auto edge = input_edges[i]; - TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_, - {edge->src(), edge->src_output()})); - outer_scope_->AddEdge(map_defun_node_, map_defun_fn_->ret_nodes.size() - 1, - input_ports[i].first, input_ports[i].second); + if (auto found = gtl::FindOrNull(conversion_map_, + {edge->src(), edge->src_output()})) { + inputs.push_back(*found); + } else { + // TODO(rachelim): Handle the case where unconverted inputs are unstacked. + // We assume that all unconverted inputs will be stacked, since we + // converted all unstacked nodes in `Initialize`. However, it's actually + // possible that yet-unconverted nodes may produce unstacked outputs after + // they are vectorized. 
(For example, see the "Shape" converter in + // tensorflow/python/ops/parallel_for/pfor.py). If a vectorizer expects + // an unstacked input but receives a stacked one, vectorizer->Vectorize + // will return an error. + TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_, + {edge->src(), edge->src_output()})); + int output_index = map_defun_fn_->ret_nodes.size() - 1; + inputs.push_back({map_defun_node_, output_index, true}); + } + } + + TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(), + std::move(inputs), &outputs)); + + if (op_node->num_outputs() != outputs.size()) { + return errors::Internal( + "Number of vectorizer outputs does not match. Expected: ", + op_node->num_outputs(), " Actual: ", outputs.size()); } // Add output mappings. for (size_t i = 0; i < op_node->num_outputs(); ++i) { - conversion_map_.insert({{op_node, i}, std::move(output_ports[i])}); + conversion_map_.insert({{op_node, i}, outputs[i]}); } return Status::OK(); @@ -239,13 +279,22 @@ Status Vectorization::ConvertOutput(int output_position) { TensorDesc output({ret_edge->src(), ret_edge->src_output()}); TensorDesc converted_output; - if (auto found = gtl::FindOrNull(conversion_map_, output)) { - // It's possible the output already has a mapping, if it comes from a node - // that has already been converted. - converted_output = *found; - } else { + + // It's possible the output already has a mapping, if it comes from a node + // that has already been converted. + auto found = gtl::FindOrNull(conversion_map_, output); + if (!found) { TF_RETURN_IF_ERROR(AddConversionMapping(output.first)); - converted_output = conversion_map_.at(output); + found = &conversion_map_.at(output); + } + + if (found->stacked) { + converted_output = {found->node, found->output_index}; + } else { + // Some outputs may be unstacked if they don't derive from arg nodes + // (for example, if a function returns a constant). 
For these, we + // have to add extra nodes to tile it in the 0th dimension. + TF_RETURN_IF_ERROR(StackTensor(found, &converted_output)); } ReplaceEdgeSources({map_defun_node_, output_position}, converted_output, @@ -297,6 +346,7 @@ void Vectorization::VectorizeHelper() { map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types); } } + Status Vectorization::Initialize(const FunctionDef& outer_scope, const NodeDef& map_defun_node) { // Convert outer_scope and map_defun_fn to FunctionBodys so we can @@ -337,16 +387,183 @@ Status Vectorization::Initialize(const FunctionDef& outer_scope, } map_defun_node_ = outer_scope_->FindNodeId(node_id); - // Add mappings from map_defun_fn_ arg nodes to map_defun_node_ input nodes to - // the conversion map + TF_RETURN_IF_ERROR(AddArgNodeMappings()); + + TF_RETURN_IF_ERROR(AddUnstackedNodeMappings()); + loop_len_node_ = nullptr; + + return Status::OK(); +} + +// TODO(rachelim): It might be profitable to use the C++ API for this instead of +// NodeBuilder +Status Vectorization::StackTensor(WrappedTensor* unstacked, + TensorDesc* result) { + // Note that all these nodes are necessary as the size of the batch may not be + // constant. + if (unstacked->stacked) { + return errors::Internal("Can only stack unstacked tensor."); + } + + Graph* g = outer_scope_.get(); + auto node_builder = [](StringPiece op) { + return NodeBuilder(strings::StrCat("vectorized/stack/", op), op); + }; + + auto make_const = [&node_builder](const Input::Initializer& val, Graph* graph, + Node** result) { + TF_RETURN_IF_ERROR(val.status); + return node_builder("Const") + .Attr("value", val.tensor) + .Attr("dtype", val.tensor.dtype()) + .Finalize(graph, result); + }; + + // If loop_len_node_ hasn't been created yet, add the node and cache it. 
+ if (loop_len_node_ == nullptr) { + Node* input_node; + TF_RETURN_IF_ERROR(map_defun_node_->input_node(0, &input_node)); + + Node* shape_node; + TF_RETURN_IF_ERROR( + node_builder("Shape").Input(input_node).Finalize(g, &shape_node)); + + Node* const_vec_0; + TF_RETURN_IF_ERROR(make_const({0}, g, &const_vec_0)); + Node* const_vec_1; + TF_RETURN_IF_ERROR(make_const({1}, g, &const_vec_1)); + + Node* strided_slice_node; + TF_RETURN_IF_ERROR(node_builder("StridedSlice") + .Input(shape_node) // input + .Input(const_vec_0) // begin + .Input(const_vec_1) // end + .Input(const_vec_1) // strides + .Finalize(g, &strided_slice_node)); + + // Produces a vector of length 1 + TF_RETURN_IF_ERROR(node_builder("Reshape") + .Input(strided_slice_node) // tensor + .Input(const_vec_1) // shape + .Finalize(g, &loop_len_node_)); + } + + Node* ones_shape; + TF_RETURN_IF_ERROR(node_builder("Shape") + .Input(unstacked->node) // input + .Finalize(g, &ones_shape)); + + Node* ones; + TF_RETURN_IF_ERROR( + node_builder("OnesLike").Input(ones_shape).Finalize(g, &ones)); + + Node* const_0; + TF_RETURN_IF_ERROR(make_const(0, g, &const_0)); + + Node* multiples; + TF_RETURN_IF_ERROR(node_builder("Concat") + .Input(const_0) // concat_dim + .Input({{loop_len_node_, 0}, {ones, 0}}) // values + .Finalize(g, &multiples)); + + Node* expand_dims; + TF_RETURN_IF_ERROR(node_builder("ExpandDims") + .Input(unstacked->node) // input + .Input(const_0) // dim + .Finalize(g, &expand_dims)); + + TF_RETURN_IF_ERROR(node_builder("Tile") + .Input(expand_dims) // input + .Input(multiples) // multiples + .Finalize(g, &result->first)); + result->second = 0; + return Status::OK(); +} + +Status Vectorization::AddArgNodeMappings() { for (auto arg_node : map_defun_fn_->arg_nodes) { Node* input_node; TF_RETURN_IF_ERROR(map_defun_node_->input_node( arg_node->attrs().Find("index")->i(), &input_node)); - conversion_map_.insert({{arg_node, 0}, {input_node, 0}}); + conversion_map_.insert({{arg_node, 0}, {input_node, 0, true}}); + 
+ // Control inputs + conversion_map_.insert({{arg_node, Graph::kControlSlot}, + {input_node, Graph::kControlSlot, true}}); + } + return Status::OK(); +} + +bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, + Status* status) { + if (auto found = gtl::FindOrNull(conversion_map_, tensor)) { + return !found->stacked; + } + + if (tensor.first->op_def().is_stateful()) { + // We don't lift stateful nodes directly out of the MapDefun, since they may + // have to be executed N times. + return false; } + bool is_unstacked = true; + for (auto edge : tensor.first->in_edges()) { + // Ignore Source nodes. Note that these are also ignored in the + // GraphToFunctionDef conversion. + if (edge->src()->IsSource()) continue; + + // A node is unstacked if all of its inputs are unstacked + is_unstacked &= AddUnstackedNodeMappingsHelper( + {edge->src(), edge->src_output()}, status); + } + + if (!is_unstacked) { + return false; + } + + // If the node is unstacked, we copy it into outer_scope_ and + // add it to the map. Note that we don't clean up the nodes that are copied + // in map_defun_fn_, and rely on them being pruned out later. + Node* node = outer_scope_->AddNode(tensor.first->def(), status); + if (!status->ok()) return true; + + // Add input edges to nodes that should already have been lifted. + for (auto edge : tensor.first->in_edges()) { + // Ignore Source nodes. Note that these are also ignored in the + // GraphToFunctionDef conversion. 
+ if (edge->src()->IsSource()) continue; + + if (auto found = gtl::FindOrNull(conversion_map_, + {edge->src(), edge->src_output()})) { + outer_scope_->AddEdge(found->node, found->output_index, node, + edge->dst_input()); + } else { + status->Update(errors::Internal( + "Could not find input conversion even though we did depth first " + "conversion.")); + } + } + + // Add output mappings + for (int i = 0; i < tensor.first->num_outputs(); ++i) { + conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)}); + } + conversion_map_.insert({{tensor.first, Graph::kControlSlot}, + WrappedTensor(node, Graph::kControlSlot, false)}); + + return true; +} + +Status Vectorization::AddUnstackedNodeMappings() { + SetVector<Node*> unstacked_nodes; + Status s; + for (const auto& ret_node : map_defun_fn_->ret_nodes) { + const Edge* in_edge = nullptr; + TF_RETURN_IF_ERROR(ret_node->input_edge(0, &in_edge)); + AddUnstackedNodeMappingsHelper({in_edge->src(), in_edge->src_output()}, &s); + TF_RETURN_IF_ERROR(s); + } return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc index 1ff62217dd..a6020e36bb 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc @@ -55,6 +55,7 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs, func.set_name(function_name); NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn); graph_transforms::SetNodeAttr("Targuments", t_arguments, node); + graph_transforms::SetNodeAttr("Tcaptured", DataTypeVector(), node); graph_transforms::SetNodeAttr("output_types", output_types, node); graph_transforms::SetNodeAttr("output_shapes", output_shapes, node); graph_transforms::SetNodeAttr("f", func, node); @@ -142,6 +143,8 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { *lib.add_function() = 
outer; *lib.add_function() = inner; FunctionDef* vectorized; + Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized); + LOG(ERROR) << s; EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); @@ -670,6 +673,257 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) { cast_node.input(1) == control_input); } +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +------+ | | +// | | | | +// | | | | +// | | +------+ | | +// | | |Const | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +------+ | +// | | +// | +------+ | +// | |Const | | +// | +---+--+ | +// | | | +// | | | +// | | | +// | +---v--+ | +// | |Stack*| | +// | +---+--+ | +// | | | +// | | | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// *Not actually a Stack node, but does the equivalent. 
+// +TEST(VectorizeMapDefunTest, VectorizeConst) { + FunctionDef inner = FunctionDefHelper::Create( + "inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */}, + {/* nodes */ FunctionDefHelper::Const("Const", 2)}, + {{"ret0", "Const:output:0"}}); + FunctionDef outer = FunctionDefHelper::Create( + "outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"}, + {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}}, + inner.signature().name(), &outer); + + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + EXPECT_TRUE(function_utils::ContainsFunctionNodeWithOp("Const", *vectorized)); +} + +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +------+ | | +// | | | | +// | | | | +// | | +------+ | | +// | | |Const | | | +// | | +---+--+ | | +// | | | | | +// | | +---v--+ | | +// | | | Cast | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +------+ | +// | | +// | +------+ | +// | |Const | | +// | +---+--+ | +// | | | +// | +---v--+ | +// | | Cast | | +// | +---+--+ | +// | | | +// | +---v--+ | +// | |Stack*| | +// | +---+--+ | +// | | | +// | | | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// *Not actually a Stack node, but does the equivalent. 
+// +TEST(VectorizeMapDefunTest, VectorizeUnstackedOutput) { + FunctionDef inner = FunctionDefHelper::Create( + "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */}, + {/* nodes */ FunctionDefHelper::Const("Const", 2)}, + {{"ret0", "Cast:y:0"}}); + AddCastNode("Cast", {"Const:output:0"}, DT_INT32, DT_INT64, false, &inner); + + FunctionDef outer = FunctionDefHelper::Create( + "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"}, + {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}}, + inner.signature().name(), &outer); + + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + auto const_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Const", *vectorized)); + auto cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); + EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')), + const_node.name()); +} + +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +------+ | | +// | | | | +// | | +------+ +------+ | | +// | | |Const | |Const | | | +// | | +---+--+ +---+--+ | | +// | | : +---v--+ | | +// | | ::::::> Cast | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +------+ | +// | | +// | | +// | +------+ | +// | +------+ |Const | | +// | |Const | +---+--+ | +// | +---+--+ | | +// | : +---v--+ | +// | 
::::::> Cast | | +// | +---+--+ | +// | | | +// | +---v--+ | +// | +Stack*+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// *Not actually a Stack node, but does the equivalent. +// +TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) { + FunctionDef inner = FunctionDefHelper::Create( + "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */}, + {/* nodes */ FunctionDefHelper::Const("Const", 2), + FunctionDefHelper::Const("ConstDep", 3)}, + {{"ret0", "Cast:y:0"}}); + AddCastNode("Cast", {"Const:output:0", "^ConstDep"}, DT_INT32, DT_INT64, + false, &inner); + + FunctionDef outer = FunctionDefHelper::Create( + "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"}, + {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}}, + inner.signature().name(), &outer); + + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + + auto find_const = [vectorized](int val) -> const NodeDef* { + for (const auto& n : vectorized->node_def()) { + if (n.attr().at("value").tensor().int_val(0) == val) { + return &n; + } + } + return nullptr; + }; + + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + auto const_node = find_const(2); + auto const_dep_node = find_const(3); + auto cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); + EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')), + const_node->name()); + EXPECT_EQ(cast_node.input(1), strings::StrCat("^", const_dep_node->name())); +} + // TODO(rachelim): More test cases when we get around to implementing them: // [] A badly defined converter, e.g. 
doesn't produce nodes that have the // same number of outputs/inputs as the nodes to be converted diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 9439ab332c..3a920f26f3 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4458,7 +4458,12 @@ cc_library( name = "string_util", srcs = ["string_util.cc"], hdrs = ["string_util.h"], - deps = ["//tensorflow/core:lib"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@icu//:common", + ], ) STRING_DEPS = [ diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc index a04f150e71..9607e9444c 100644 --- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc +++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc @@ -171,16 +171,16 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel { static PartialTensorShape MostSpecificCompatibleShape( const PartialTensorShape& ts1, const PartialTensorShape& ts2) { - PartialTensorShape output_tensorshape; if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank()) - return output_tensorshape; + return PartialTensorShape(); + PartialTensorShape output_tensorshape({}); auto dims1 = ts1.dim_sizes(); auto dims2 = ts2.dim_sizes(); for (int d = 0; d < ts1.dims(); d++) { if (dims1[d] == dims2[d]) - output_tensorshape.Concatenate(dims1[d]); + output_tensorshape.AddDim(dims1[d]); else - output_tensorshape.Concatenate(-1); + output_tensorshape.AddDim(-1); } return output_tensorshape; } diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc index 6657f2b2b3..705b0393de 100644 --- a/tensorflow/core/kernels/data/map_defun_op.cc +++ b/tensorflow/core/kernels/data/map_defun_op.cc @@ -62,24 +62,6 @@ class MapDefunOp : public AsyncOpKernel { ~MapDefunOp() override {} - Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) { - 
// Validates inputs and gets the size of their leading dimension. - *batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1; - for (size_t i = 0; i < ctx->num_inputs(); ++i) { - if (ctx->input(i).dims() == 0) { - return errors::InvalidArgument( - "All inputs must have rank at least 1. Input ", i, - " has a rank of 0."); - } else if (ctx->input(i).dim_size(0) != *batch_size) { - return errors::InvalidArgument( - "All inputs must have the same dimension 0. Input ", i, - " has leading dimension ", ctx->input(i).dim_size(0), - ", while all previous inputs have leading dimension ", batch_size); - } - } - return Status::OK(); - } - void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { ComputeOptions* compute_opts = nullptr; @@ -150,8 +132,9 @@ class MapDefunOp : public AsyncOpKernel { // all calls to the function are complete. This struct also encapsulates // all the components that need to be passed to each MapFunctionCallFrame. - const std::vector<Tensor> args; + OpInputList args; const std::vector<TensorShape> arg_shapes; + OpInputList captured_inputs; const int64 batch_size; // Output of a compute call @@ -161,26 +144,31 @@ class MapDefunOp : public AsyncOpKernel { // Create a copy of output_shapes because every `Compute` may expect a // different output shape. - ComputeOptions(std::vector<Tensor> args, + ComputeOptions(OpInputList args, OpInputList captured_inputs, std::vector<TensorShape> arg_shapes, int64 batch_size, const std::vector<PartialTensorShape>& output_shapes_attr) - : args(std::move(args)), + : args(args), arg_shapes(std::move(arg_shapes)), + captured_inputs(captured_inputs), batch_size(batch_size), output_shapes(output_shapes_attr) {} }; // Get inputs to Compute and check that they are valid. Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) { - int64 batch_size = - ctx->input(0).dims() > 0 ? 
ctx->input(0).dim_size(0) : -1; + OpInputList arguments; + TF_RETURN_IF_ERROR(ctx->input_list("arguments", &arguments)); + OpInputList captured_inputs; + TF_RETURN_IF_ERROR(ctx->input_list("captured_inputs", &captured_inputs)); + + int64 batch_size = arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1; - for (size_t i = 0; i < ctx->num_inputs(); ++i) { - if (ctx->input(i).dims() == 0) { + for (size_t i = 0; i < arguments.size(); ++i) { + if (arguments[i].dims() == 0) { return errors::InvalidArgument( "All inputs must have rank at least 1. Input ", i, " has a rank of 0."); - } else if (ctx->input(i).dim_size(0) != batch_size) { + } else if (arguments[i].dim_size(0) != batch_size) { return errors::InvalidArgument( "All inputs must have the same dimension 0. Input ", i, " has leading dimension ", ctx->input(i).dim_size(0), @@ -188,19 +176,17 @@ class MapDefunOp : public AsyncOpKernel { } } - std::vector<Tensor> args; std::vector<TensorShape> arg_shapes; - args.reserve(ctx->num_inputs()); - arg_shapes.reserve(ctx->num_inputs()); + arg_shapes.reserve(arguments.size()); - for (size_t i = 0; i < ctx->num_inputs(); ++i) { - args.push_back(ctx->input(i)); - arg_shapes.push_back(ctx->input(i).shape()); + for (size_t i = 0; i < arguments.size(); ++i) { + arg_shapes.push_back(arguments[i].shape()); arg_shapes.at(i).RemoveDim(0); } - *compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes), - batch_size, output_shapes_); + *compute_opts = + new ComputeOptions(arguments, captured_inputs, std::move(arg_shapes), + batch_size, output_shapes_); return Status::OK(); } @@ -235,12 +221,21 @@ class MapDefunOp : public AsyncOpKernel { } Status GetArg(int index, Tensor* val) const override { - if (index < 0 || index >= compute_opts_->args.size()) { + if (index < 0 || index >= compute_opts_->args.size() + + compute_opts_->captured_inputs.size()) { return errors::InvalidArgument( "Mismatch in number of function inputs."); } + + if (index >= compute_opts_->args.size()) { 
+ // The function is calling for a captured input + *val = + compute_opts_->captured_inputs[index - compute_opts_->args.size()]; + return Status::OK(); + } + bool result = - val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1), + val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1), compute_opts_->arg_shapes.at(index)); if (!result) { return errors::Internal("GetArg failed."); @@ -248,7 +243,6 @@ class MapDefunOp : public AsyncOpKernel { // Ensure alignment *val = tensor::DeepCopy(*val); } - return Status::OK(); } diff --git a/tensorflow/core/kernels/dequantize_op.cc b/tensorflow/core/kernels/dequantize_op.cc index 42fbf95cd3..28940e0849 100644 --- a/tensorflow/core/kernels/dequantize_op.cc +++ b/tensorflow/core/kernels/dequantize_op.cc @@ -96,8 +96,6 @@ class DequantizeOp : public OpKernel { output); } } else if (mode_ == QUANTIZE_MODE_SCALED) { - // TODO(pauldonnelly): Update QuantizeAndDequantizeV2 and - // QuantizeAndDequantizeV3 to match this SCALED mode again. const float scale_factor = std::numeric_limits<T>::min() == 0 ? (max_range / std::numeric_limits<T>::max()) diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc index fdb4c84c46..3979e4b53a 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.cc +++ b/tensorflow/core/kernels/partitioned_function_ops.cc @@ -97,6 +97,13 @@ class PartitionedCallOp : public AsyncOpKernel { OP_REQUIRES_ASYNC(ctx, fbody != nullptr, errors::Internal("Could not find handle ", handle), done); + OP_REQUIRES_ASYNC( + ctx, args.size() == fbody->arg_nodes.size(), + errors::InvalidArgument( + "Wrong number of arguments to the op; function expects ", + fbody->arg_nodes.size(), " but PartitionedCall received ", + args.size()), + done); // We need to pass global op_registry as default_registry when creating // graph. So that graph optimization passes can lookup all possible ops // by name. 
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index 23d76986bf..678d675c4a 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -426,6 +426,12 @@ class AssignUpdateVariableOp : public OpKernel { // ADD if value's refcount was 1. mutex_lock ml(*variable->mu()); Tensor* var_tensor = variable->tensor(); + OP_REQUIRES(context, var_tensor->shape().IsSameSize(value.shape()), + errors::InvalidArgument("Cannot update variable with shape ", + var_tensor->shape().DebugString(), + " using a Tensor with shape ", + value.shape().DebugString(), + ", shapes must be equal.")); OP_REQUIRES_OK(context, PrepareToUpdateVariable<Device, T>(context, var_tensor)); functor::DenseUpdate<Device, T, Op> update_functor; diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc index 3a9803a052..92c73220d8 100644 --- a/tensorflow/core/kernels/string_util.cc +++ b/tensorflow/core/kernels/string_util.cc @@ -16,10 +16,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" -namespace { -inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; } -} // namespace - namespace tensorflow { // Sets unit value based on str. diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h index 390cf57702..d40e93ea33 100644 --- a/tensorflow/core/kernels/string_util.h +++ b/tensorflow/core/kernels/string_util.h @@ -30,6 +30,9 @@ enum class UnicodeEncoding { UTF8 }; // TODO(edloper): Add support for: UTF32_CHAR, etc. enum class CharUnit { BYTE, UTF8_CHAR }; +// Whether or not the given byte is the trailing byte of a UTF-8/16/32 char. +inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; } + // Sets `encoding` based on `str`. 
Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding); @@ -40,6 +43,47 @@ Status ParseCharUnit(const string& str, CharUnit* unit); // Result may be incorrect if the input string is not valid UTF-8. int32 UTF8StrLen(const string& string); +// Get the next UTF8 character position starting at the given position and +// skipping the given number of characters. Position is a byte offset, and +// should never be `null`. The function return true if successful. However, if +// the end of the string is reached before the requested characters, then the +// position will point to the end of string and this function will return false. +template <typename T> +bool ForwardNUTF8CharPositions(const StringPiece in, + const T num_utf8_chars_to_shift, T* pos) { + const size_t size = in.size(); + T utf8_chars_counted = 0; + while (utf8_chars_counted < num_utf8_chars_to_shift && *pos < size) { + // move forward one utf-8 character + do { + ++*pos; + } while (IsTrailByte(in[*pos]) && *pos < size); + ++utf8_chars_counted; + } + return utf8_chars_counted == num_utf8_chars_to_shift; +} + +// Get the previous UTF8 character position starting at the given position and +// skipping the given number of characters. Position is a byte offset with a +// positive value, relative to the beginning of the string, and should never be +// `null`. The function return true if successful. However, if the beginning of +// the string is reached before the requested character, then the position will +// point to the beginning of the string and this function will return false. 
+template <typename T> +bool BackNUTF8CharPositions(const StringPiece in, + const T num_utf8_chars_to_shift, T* pos) { + const size_t start = 0; + T utf8_chars_counted = 0; + while (utf8_chars_counted < num_utf8_chars_to_shift && (*pos > start)) { + // move back one utf-8 character + do { + --*pos; + } while (IsTrailByte(in[*pos]) && *pos > start); + ++utf8_chars_counted; + } + return utf8_chars_counted == num_utf8_chars_to_shift; +} + } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_ diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc index 07f1d6e767..93c427039d 100644 --- a/tensorflow/core/kernels/substr_op.cc +++ b/tensorflow/core/kernels/substr_op.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/string_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" @@ -37,7 +38,11 @@ namespace tensorflow { template <typename T> class SubstrOp : public OpKernel { public: - using OpKernel::OpKernel; + explicit SubstrOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string unit; + OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit)); + OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_)); + } void Compute(OpKernelContext* context) override { // Get inputs @@ -69,11 +74,23 @@ class SubstrOp : public OpKernel { tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()()); for (size_t i = 0; i < input_tensor.NumElements(); ++i) { StringPiece in(input(i)); - OP_REQUIRES( - context, FastBoundsCheck(std::abs(pos), in.size() + 1), - errors::InvalidArgument("pos ", pos, " out of range for string", - "b'", in, "' at index ", i)); - StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + T byte_pos = pos; + T byte_len = len; + switch (unit_) { + 
case CharUnit::UTF8_CHAR: + OP_REQUIRES( + context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string at index ", i)); + break; + case CharUnit::BYTE: + byte_pos = AdjustedPosIndex(byte_pos, in); + OP_REQUIRES( + context, FastBoundsCheck(byte_pos, in.size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index ", i)); + } + StringPiece sub_in = in.substr(byte_pos, byte_len); output(i).assign(sub_in.data(), sub_in.size()); } } else { @@ -84,11 +101,23 @@ class SubstrOp : public OpKernel { StringPiece in(input(i)); const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i)); const T len = tensorflow::internal::SubtleMustCopy(len_flat(i)); - OP_REQUIRES( - context, FastBoundsCheck(std::abs(pos), in.size() + 1), - errors::InvalidArgument("pos ", pos, " out of range for string", - "b'", in, "' at index ", i)); - StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + T byte_pos = pos; + T byte_len = len; + switch (unit_) { + case CharUnit::UTF8_CHAR: + OP_REQUIRES( + context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string at index ", i)); + break; + case CharUnit::BYTE: + byte_pos = AdjustedPosIndex(byte_pos, in); + OP_REQUIRES( + context, FastBoundsCheck(byte_pos, in.size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index ", i)); + } + StringPiece sub_in = in.substr(byte_pos, byte_len); output(i).assign(sub_in.data(), sub_in.size()); } } @@ -151,12 +180,24 @@ class SubstrOp : public OpKernel { StringPiece in(input_bcast(i)); const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i)); const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i)); - OP_REQUIRES( - context, - FastBoundsCheck(std::abs(pos), input_bcast(i).size() + 1), - errors::InvalidArgument("pos ", pos, " out of range for string", - "b'", in, "' 
at index ", i)); - StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + T byte_pos = pos; + T byte_len = len; + switch (unit_) { + case CharUnit::UTF8_CHAR: + OP_REQUIRES( + context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string at index ", i)); + break; + case CharUnit::BYTE: + byte_pos = AdjustedPosIndex(byte_pos, in); + OP_REQUIRES( + context, + FastBoundsCheck(byte_pos, input_bcast(i).size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index ", i)); + } + StringPiece sub_in = in.substr(byte_pos, byte_len); output(i).assign(sub_in.data(), sub_in.size()); } break; @@ -205,12 +246,24 @@ class SubstrOp : public OpKernel { tensorflow::internal::SubtleMustCopy(pos_bcast(i, j)); const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i, j)); - OP_REQUIRES( - context, FastBoundsCheck(std::abs(pos), in.size() + 1), - errors::InvalidArgument("pos ", pos, " out of range for ", - "string b'", in, "' at index (", i, - ", ", j, ")")); - StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + T byte_pos = pos; + T byte_len = len; + switch (unit_) { + case CharUnit::UTF8_CHAR: + OP_REQUIRES( + context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string at index ", i)); + break; + case CharUnit::BYTE: + byte_pos = AdjustedPosIndex(byte_pos, in); + OP_REQUIRES( + context, FastBoundsCheck(byte_pos, in.size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index (", + i, ", ", j, ")")); + } + StringPiece sub_in = in.substr(byte_pos, byte_len); output(i, j).assign(sub_in.data(), sub_in.size()); } } @@ -227,12 +280,73 @@ class SubstrOp : public OpKernel { private: // This adjusts the requested position. Note it does not perform any bound // checks. 
- T AdjustedPosIndex(const T pos_requested, const StringPiece s) { + static inline T AdjustedPosIndex(const T pos_requested, const StringPiece s) { if (pos_requested < 0) { return s.size() + pos_requested; } return pos_requested; } + + // Return true if successful; otherwise, return false if the `pos` argument + // is out of range in the string. + static inline bool UpdatePosAndLenForUtf8(const StringPiece in, T* pos, + T* len) { + if (*pos >= 0) { + return UpdatePositivePosAndLenForUtf8(in, *pos, *len, pos, len); + } else { + return UpdateNegativePosAndLenForUtf8(in, *pos, *len, pos, len); + } + } + + static bool UpdatePositivePosAndLenForUtf8(const StringPiece in, const T pos, + const T len, T* char_pos, + T* char_len) { + *char_pos = 0; + // Determine byte position of the substring start. + if (!ForwardNUTF8CharPositions(in, pos, char_pos)) { + return false; + } + // Determine position of the end of the substring. + // The length will be capped at the end of the string, and we ignore whether + // the string had enough characters to handle it or not. + *char_len = *char_pos; + ForwardNUTF8CharPositions(in, len, char_len); + // The length in bytes is the position end of the substring less the start. + *char_len = *char_len - *char_pos; + return true; + } + + // This function expects a negative position relative to the end of the + // string, but will update the character position to a positive number + // relative to the beginning of the string. + static bool UpdateNegativePosAndLenForUtf8(const StringPiece in, const T pos, + const T len, T* char_pos, + T* char_len) { + // Initially treat the length as position of the end of the substring. + *char_len = in.size(); + // This is the number of character to skip from the end of the string to + // arrive at the position where the substring should end. 
+ T utf8_chars_to_skip = -pos - len; + if (utf8_chars_to_skip < 0) { + utf8_chars_to_skip = 0; + } + // Find the byte position where the substring should end using the computed + // number of characters to skip. + if (!BackNUTF8CharPositions(in, utf8_chars_to_skip, char_len)) { + return false; + } + // Next, determine where the substring should begin. The number of chars to + // skip is the requested position minus the chars we've previously skipped. + *char_pos = *char_len; + if (!BackNUTF8CharPositions(in, -pos - utf8_chars_to_skip, char_pos)) { + return false; + } + // The length in bytes is the position end of the substring less the start. + *char_len = *char_len - *char_pos; + return true; + } + + CharUnit unit_ = CharUnit::BYTE; }; #define REGISTER_SUBSTR(type) \ diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc index 2e07050260..ea6b1ed500 100644 --- a/tensorflow/core/kernels/substr_op_test.cc +++ b/tensorflow/core/kernels/substr_op_test.cc @@ -42,7 +42,7 @@ limitations under the License. namespace tensorflow { // Test data from the TensorFlow README.md. 
-const char* lines[] = { +const char* ascii_lines[] = { "**TensorFlow** is an open source software library for numerical " "computation using data flow graphs.", "The graph nodes represent mathematical operations, while the graph edges " @@ -64,17 +64,76 @@ const char* lines[] = { "backwards compatibility guarantee like C++, Go, Java, JavaScript and " "Swift."}; +const char* unicode_lines[] = { + "TensorFlow\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe4\xbd\xbf\xe7\x94\xa8\xe6" + "\x95\xb0\xe6\x8d\xae\xe6\xb5\x81\xe5\x9b\xbe\xe8\xbf\x9b\xe8\xa1\x8c\xe6" + "\x95\xb0\xe5\x80\xbc\xe8\xae\xa1\xe7\xae\x97\xe7\x9a\x84\xe5\xbc\x80\xe6" + "\xba\x90\xe8\xbd\xaf\xe4\xbb\xb6\xe5\xba\x93\xe3\x80\x82", + "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\x8a\x82\xe7\x82\xb9\xe8\xa1\xa8\xe7\xa4\xba" + "\xe6\x95\xb0\xe5\xad\xa6\xe8\xbf\x90\xe7\xae\x97\xef\xbc\x8c\xe8\x80\x8c" + "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\xbe\xb9\xe7\xbc\x98\xe8\xa1\xa8\xe7\xa4\xba" + "\xe5\x9c\xa8\xe5\xae\x83\xe4\xbb\xac\xe4\xb9\x8b\xe9\x97\xb4\xe6\xb5\x81" + "\xe5\x8a\xa8\xe7\x9a\x84\xe5\xa4\x9a\xe7\xbb\xb4\xe6\x95\xb0\xe6\x8d\xae" + "\xe9\x98\xb5\xe5\x88\x97\xef\xbc\x88\xe5\xbc\xa0\xe9\x87\x8f\xef\xbc\x89" + "\xe3\x80\x82", + "\xe8\xbf\x99\xe7\xa7\x8d\xe7\x81\xb5\xe6\xb4\xbb\xe7\x9a\x84\xe4\xbd\x93" + "\xe7\xb3\xbb\xe7\xbb\x93\xe6\x9e\x84\xe4\xbd\xbf\xe6\x82\xa8\xe5\x8f\xaf" + "\xe4\xbb\xa5\xe5\xb0\x86\xe8\xae\xa1\xe7\xae\x97\xe9\x83\xa8\xe7\xbd\xb2" + "\xe5\x88\xb0\xe6\xa1\x8c\xe9\x9d\xa2\xef\xbc\x8c\xe6\x9c\x8d\xe5\x8a\xa1" + "\xe5\x99\xa8\xe6\x88\x96\xe7\xa7\xbb\xe5\x8a\xa8\xe8\xae\xbe\xe5\xa4\x87" + "\xe4\xb8\xad\xe7\x9a\x84\xe4\xb8\x80\xe4\xb8\xaa\xe6\x88\x96\xe5\xa4\x9a" + "\xe4\xb8\xaa CPU\xe6\x88\x96GPU\xef\xbc\x8c\xe8\x80\x8c\xe6\x97\xa0\xe9" + "\x9c\x80\xe9\x87\x8d\xe5\x86\x99\xe4\xbb\xa3\xe7\xa0\x81\xe3\x80\x82", + "TensorFlow\xe8\xbf\x98\xe5\x8c\x85\xe6\x8b\xac[TensorBoard]\xef\xbc\x88" + "https://www.tensorflow.org/guide/summaries_and_tensorboard\xef\xbc\x89\xef" + 
"\xbc\x8c\xe8\xbf\x99\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe6\x95\xb0\xe6" + "\x8d\xae\xe5\x8f\xaf\xe8\xa7\x86\xe5\x8c\x96\xe5\xb7\xa5\xe5\x85\xb7\xe5" + "\x8c\x85\xe3\x80\x82", + "TensorFlow\xe6\x9c\x80\xe5\x88\x9d\xe6\x98\xaf\xe7\x94\xb1\xe7\xa0\x94\xe7" + "\xa9\xb6\xe4\xba\xba\xe5\x91\x98\xe5\x92\x8c\xe5\xb7\xa5\xe7\xa8\x8b\xe5" + "\xb8\x88\xe5\x9c\xa8Google\xe6\x9c\xba\xe5\x99\xa8\xe6\x99\xba\xe8\x83\xbd" + "\xe7\xa0\x94\xe7\xa9\xb6\xe7\xbb\x84\xe7\xbb\x87\xe7\x9a\x84Google Brain" + "\xe5\x9b\xa2\xe9\x98\x9f\xe5\xbc\x80\xe5\x8f\x91\xe7\x9a\x84\xef\xbc\x8c" + "\xe7\x9b\xae\xe7\x9a\x84\xe6\x98\xaf\xe8\xbf\x9b\xe8\xa1\x8c\xe6\x9c\xba" + "\xe5\x99\xa8\xe5\xad\xa6\xe4\xb9\xa0\xe5\x92\x8c\xe6\xb7\xb1\xe5\xba\xa6" + "\xe7\xa5\x9e\xe7\xbb\x8f\xe7\xbd\x91\xe7\xbb\x9c\xe7\xa0\x94\xe7\xa9\xb6" + "\xe3\x80\x82", + "\xe8\xaf\xa5\xe7\xb3\xbb\xe7\xbb\x9f\xe8\xb6\xb3\xe4\xbb\xa5\xe9\x80\x82" + "\xe7\x94\xa8\xe4\xba\x8e\xe5\x90\x84\xe7\xa7\x8d\xe5\x85\xb6\xe4\xbb\x96" + "\xe9\xa2\x86\xe5\x9f\x9f\xe4\xb9\x9f\xe6\x98\xaf\xe5\xa6\x82\xe6\xad\xa4" + "\xe3\x80\x82", + "TensorFlow\xe6\x8f\x90\xe4\xbe\x9b\xe7\xa8\xb3\xe5\xae\x9a\xe7\x9a\x84" + "Python API\xe5\x92\x8c C API\xef\xbc\x8c\xe4\xbb\xa5\xe5\x8f\x8a\xe6\xb2" + "\xa1\xe6\x9c\x89 API\xe5\x90\x91\xe5\x90\x8e\xe5\x85\xbc\xe5\xae\xb9\xe6" + "\x80\xa7\xe4\xbf\x9d\xe8\xaf\x81\xef\xbc\x8c\xe5\xa6\x82 C ++\xef\xbc\x8c" + "Go\xef\xbc\x8cJava\xef\xbc\x8cJavaScript\xe5\x92\x8cSwift\xe3\x80\x82", +}; + +const char* const kByteUnit = "BYTE"; +const char* const kUTF8Unit = "UTF8_CHAR"; + Tensor GetTestTensor(int batch) { - const int sz = TF_ARRAYSIZE(lines); + const int sz = TF_ARRAYSIZE(ascii_lines); + Tensor t(DT_STRING, {batch}); + auto s = t.flat<string>(); + for (int i = 0; i < batch; ++i) { + s(i) = ascii_lines[i % sz]; + } + return t; +} + +Tensor GetTestUTF8Tensor(int batch) { + const int sz = TF_ARRAYSIZE(unicode_lines); Tensor t(DT_STRING, {batch}); auto s = t.flat<string>(); for (int i = 0; i < batch; ++i) { - s(i) = lines[i % 
sz]; + s(i) = unicode_lines[i % sz]; } return t; } -Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) { +Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len, + const char* const unit) { Graph* g = new Graph(OpRegistry::Global()); Tensor position(DT_INT32, TensorShape({})); position.flat<int32>().setConstant(pos); @@ -85,21 +144,46 @@ Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) { .Input(test::graph::Constant(g, input)) .Input(test::graph::Constant(g, position)) .Input(test::graph::Constant(g, length)) + .Attr("unit", unit) .Finalize(g, nullptr /* node */)); return g; } -void BM_Substr(int iters, int batch_size) { +void BM_SubstrByte(int iters, int batch_size) { testing::StopTiming(); testing::ItemsProcessed(static_cast<int64>(iters)); testing::UseRealTime(); Tensor input = GetTestTensor(batch_size); - Graph* g = SetupSubstrGraph(input, 3, 30); + Graph* g = SetupSubstrGraph(input, 3, 30, kByteUnit); + testing::StartTiming(); + test::Benchmark("cpu", g).Run(iters); +} + +void BM_SubstrUTF8(int iters, int batch_size) { + testing::StopTiming(); + testing::ItemsProcessed(static_cast<int64>(iters)); + testing::UseRealTime(); + Tensor input = GetTestUTF8Tensor(batch_size); + Graph* g = SetupSubstrGraph(input, 3, 30, kUTF8Unit); testing::StartTiming(); test::Benchmark("cpu", g).Run(iters); } -BENCHMARK(BM_Substr)->Arg(1)->Arg(8)->Arg(16)->Arg(32)->Arg(64)->Arg(128)->Arg( - 256); +BENCHMARK(BM_SubstrByte) + ->Arg(1) + ->Arg(8) + ->Arg(16) + ->Arg(32) + ->Arg(64) + ->Arg(128) + ->Arg(256); +BENCHMARK(BM_SubstrUTF8) + ->Arg(1) + ->Arg(8) + ->Arg(16) + ->Arg(32) + ->Arg(64) + ->Arg(128) + ->Arg(256); } // end namespace tensorflow diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 33f18ae13f..780c6f6448 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ 
-30567,6 +30567,52 @@ op { } } op { + name: "MapDefun" + input_arg { + name: "arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "captured_inputs" + type_list_attr: "Tcaptured" + } + output_arg { + name: "output" + type_list_attr: "output_types" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "Tcaptured" + type: "list(type)" + default_value { + list { + } + } + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "f" + type: "func" + } +} +op { name: "MapIncompleteSize" output_arg { name: "size" @@ -71844,6 +71890,48 @@ op { } } op { + name: "Substr" + input_arg { + name: "input" + type: DT_STRING + } + input_arg { + name: "pos" + type_attr: "T" + } + input_arg { + name: "len" + type_attr: "T" + } + output_arg { + name: "output" + type: DT_STRING + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "unit" + type: "string" + default_value { + s: "BYTE" + } + allowed_values { + list { + s: "BYTE" + s: "UTF8_CHAR" + } + } + } +} +op { name: "Sum" input_arg { name: "input" diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 889a6a4640..ec22eee874 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -903,14 +903,18 @@ REGISTER_OP("ModelDataset") REGISTER_OP("MapDefun") .Input("arguments: Targuments") + .Input("captured_inputs: Tcaptured") .Output("output: output_types") .Attr("Targuments: list(type) >= 1") + .Attr("Tcaptured: list(type) >= 0 = []") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") .Attr("f: func") .SetShapeFn([](shape_inference::InferenceContext* c) { std::vector<PartialTensorShape> output_shapes; 
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + DataTypeVector t_args; + TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args)); if (output_shapes.size() != c->num_outputs()) { return errors::InvalidArgument( "`output_shapes` must be the same length as `output_types` (", @@ -918,10 +922,11 @@ REGISTER_OP("MapDefun") } int64 dim_zero = -1; - for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) { + for (size_t i = 0; i < t_args.size(); ++i) { if (c->Rank(c->input(i)) == 0) { return errors::InvalidArgument( - "Inputs must have rank at least 1. Input ", i, " has rank of 0"); + "Arguments must have rank at least 1. Input ", i, + " has rank of 0."); } auto dim_handle = c->Dim(c->input(i), 0); if (c->ValueKnown(dim_handle)) { @@ -929,7 +934,7 @@ REGISTER_OP("MapDefun") dim_zero = c->Value(dim_handle); } else if (c->Value(dim_handle) != dim_zero) { return errors::InvalidArgument( - "Inputs must have the same dimension 0."); + "Arguments must have the same dimension 0."); } } } diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 3eff728f03..a9e5e7824d 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1437,7 +1437,24 @@ REGISTER_OP("Bincount") .Attr("T: {int32, int64, float32, float64}") .Output("bins: T") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->UnknownShapeOfRank(1)); + ShapeHandle unused; + // The input `size` must be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + + const Tensor* size_tensor = c->input_tensor(1); + if (size_tensor == nullptr) { + // Return unknown shape if size is not known. + c->set_output(0, c->UnknownShapeOfRank(1)); + return Status::OK(); + } + + // Return `[size]` shape if size is known. 
+ int32 size_val = size_tensor->scalar<int32>()(); + if (size_val < 0) { + return errors::InvalidArgument("size (", size_val, + ") must be non-negative"); + } + c->set_output(0, c->MakeShape({size_val})); return Status::OK(); }); diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index be4c3ed2b6..05379a7d69 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -559,4 +559,16 @@ TEST(MathOpsTest, QuantizedAdd_ShapeFn) { INFER_ERROR("must be rank 0", op, "?;?;?;?;[3];?"); INFER_ERROR("must be rank 0", op, "?;?;?;?;?;[4]"); } + +TEST(MathOpsTest, Bincount_ShapeFn) { + ShapeInferenceTestOp op("Bincount"); + + // size should be scalar. + INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?"); + + INFER_OK(op, "?;?;?", "[?]"); + INFER_OK(op, "?;[];?", "[?]"); + INFER_OK(op, "[?];[];?", "[?]"); + INFER_OK(op, "[?];[];[?]", "[?]"); +} } // end namespace tensorflow diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 0e58a9475d..0d8997c1bd 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -15262,6 +15262,10 @@ op { name: "arguments" type_list_attr: "Targuments" } + input_arg { + name: "captured_inputs" + type_list_attr: "Tcaptured" + } output_arg { name: "output" type_list_attr: "output_types" @@ -15273,6 +15277,15 @@ op { minimum: 1 } attr { + name: "Tcaptured" + type: "list(type)" + default_value { + list { + } + } + has_minimum: true + } + attr { name: "output_types" type: "list(type)" has_minimum: true @@ -33748,6 +33761,19 @@ op { } } } + attr { + name: "unit" + type: "string" + default_value { + s: "BYTE" + } + allowed_values { + list { + s: "BYTE" + s: "UTF8_CHAR" + } + } + } } op { name: "Sum" diff --git a/tensorflow/core/ops/stateless_random_grad.cc b/tensorflow/core/ops/stateless_random_grad.cc new file mode 100644 index 0000000000..331e1d0152 --- /dev/null +++ b/tensorflow/core/ops/stateless_random_grad.cc @@ -0,0 
+1,23 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/function.h" + +namespace tensorflow { +REGISTER_OP_NO_GRADIENT("StatelessRandomUniform"); +REGISTER_OP_NO_GRADIENT("StatelessRandomNormal"); +REGISTER_OP_NO_GRADIENT("StatelessTruncatedNormal"); +REGISTER_OP_NO_GRADIENT("StatelessMultinomial"); +} // end namespace tensorflow diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index b4fbde54d9..94d71a4113 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -223,6 +223,7 @@ REGISTER_OP("Substr") .Input("len: T") .Output("output: string") .Attr("T: {int32, int64}") + .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'") .SetShapeFn([](InferenceContext* c) { ShapeHandle pos_shape = c->input(1); ShapeHandle len_shape = c->input(2); |