aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/BUILD1
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt23
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Substr.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Substr.pbtxt8
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt4
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.cc1
-rw-r--r--tensorflow/core/common_runtime/constant_folding.cc28
-rw-r--r--tensorflow/core/common_runtime/eager/attr_builder.cc16
-rw-r--r--tensorflow/core/common_runtime/eager/attr_builder.h6
-rw-r--r--tensorflow/core/common_runtime/eager/context.cc4
-rw-r--r--tensorflow/core/common_runtime/eager/context.h2
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc67
-rw-r--r--tensorflow/core/common_runtime/lower_if_op.cc9
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc6
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc47
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc130
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD1
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/BUILD10
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc16
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc16
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h19
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h15
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc11
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h44
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc283
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc254
-rw-r--r--tensorflow/core/kernels/BUILD7
-rw-r--r--tensorflow/core/kernels/data/concatenate_dataset_op.cc8
-rw-r--r--tensorflow/core/kernels/data/map_defun_op.cc68
-rw-r--r--tensorflow/core/kernels/dequantize_op.cc2
-rw-r--r--tensorflow/core/kernels/partitioned_function_ops.cc7
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc6
-rw-r--r--tensorflow/core/kernels/string_util.cc4
-rw-r--r--tensorflow/core/kernels/string_util.h44
-rw-r--r--tensorflow/core/kernels/substr_op.cc162
-rw-r--r--tensorflow/core/kernels/substr_op_test.cc100
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt88
-rw-r--r--tensorflow/core/ops/dataset_ops.cc11
-rw-r--r--tensorflow/core/ops/math_ops.cc19
-rw-r--r--tensorflow/core/ops/math_ops_test.cc12
-rw-r--r--tensorflow/core/ops/ops.pbtxt26
-rw-r--r--tensorflow/core/ops/stateless_random_grad.cc23
-rw-r--r--tensorflow/core/ops/string_ops.cc1
58 files changed, 1402 insertions, 274 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 6a3ee3c1cb..900a0e11c4 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1242,6 +1242,7 @@ cc_library(
srcs = [
"ops/math_grad.cc",
"ops/random_grad.cc",
+ "ops/stateless_random_grad.cc",
],
linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669
visibility = ["//visibility:public"],
diff --git a/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt
index 4433693759..d158f4b502 100644
--- a/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt
@@ -4,16 +4,23 @@ op {
in_arg {
name: "arguments"
description: <<END
- A list of tensors whose types are Targuments, corresponding to the inputs the
- function should be mapped over.
+ A list of tensors whose types are `Targuments`, corresponding to the inputs
+ the function should be mapped over.
+END
+ }
+ in_arg {
+ name: "captured_inputs"
+ description: <<END
+ A list of tensors whose types are `Tcaptured`, corresponding to the captured
+ inputs of the defun.
END
}
out_arg {
name: "output"
description: <<END
- A list of output tensors whose types are output_types and whose dimensions 0
- are the same as the dimensions 0 of the tensors in arguments, and whose
- remaining dimensions correspond to those in output_shapes.
+ A list of output tensors whose types are `output_types` and whose dimensions
+ 0 are the same as the dimensions 0 of the tensors in `arguments`, and whose
+ remaining dimensions correspond to those in `output_shapes`.
END
}
attr {
@@ -21,6 +28,10 @@ END
description: "A list of types."
}
attr {
+ name: "Tcaptured"
+ description: "A list of types."
+ }
+ attr {
name: "output_types"
description: "A list of types."
}
@@ -29,6 +40,6 @@ END
description: "A list of shapes."
}
summary: <<END
- Maps a function on the list of tensors unpacked from inputs on dimension 0.
+ Maps a function on the list of tensors unpacked from arguments on dimension 0.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
index 5246090ab3..fe0fcc9508 100644
--- a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
@@ -18,6 +18,16 @@ END
Scalar defining the number of characters to include in each substring
END
}
+ attr {
+ name: "unit"
+ description: <<END
+The unit that is used to create the substring. One of: `"BYTE"` (for
+defining position and length by bytes) or `"UTF8_CHAR"` (for the UTF-8
+encoded Unicode code points). The default is `"BYTE"`. Results are undefined if
+`unit=UTF8_CHAR` and the `input` strings do not contain structurally valid
+UTF-8.
+END
+ }
out_arg {
name: "output"
description: <<END
diff --git a/tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt b/tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt
new file mode 100644
index 0000000000..3d937c745c
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListPushBackBatch"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt b/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt
new file mode 100644
index 0000000000..44f25b5d93
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "EmptyTensorList"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
index 4778d7927c..4fb9ee56e9 100644
--- a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
@@ -1,10 +1,4 @@
op {
graph_op_name: "Substr"
- endpoint {
- name: "strings.substr"
- }
- endpoint {
- name: "substr"
- deprecated: true
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt
new file mode 100644
index 0000000000..45fc55e71e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListConcatLists"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt
new file mode 100644
index 0000000000..e1ad713e7f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListElementShape"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt
new file mode 100644
index 0000000000..4aaefba3c5
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListFromTensor"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt
new file mode 100644
index 0000000000..aaf607d70e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListGather"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt
new file mode 100644
index 0000000000..3bb5f39cbc
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListGetItem"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt
new file mode 100644
index 0000000000..a04c20bb8a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListLength"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt
new file mode 100644
index 0000000000..9287162f22
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListPopBack"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt
new file mode 100644
index 0000000000..da2bc11721
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListPushBack"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt
new file mode 100644
index 0000000000..77e63747d5
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListReserve"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt
new file mode 100644
index 0000000000..0015189d7f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListScatter"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt
new file mode 100644
index 0000000000..4999ee7ad9
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListSetItem"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt
new file mode 100644
index 0000000000..2dc7b2784b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListStack"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index 3b2dc6a050..7cb90de3c7 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -522,7 +522,6 @@ void CollectiveParamResolverLocal::CallInitInstanceSharedParams(
InitInstanceSharedParams(
gr, cp, ir,
[this, ir, done](const Status& s) UNLOCK_FUNCTION(ir->out_mu) {
- DCHECK(!ir->out_mu.try_lock());
DCHECK(ir->out_mu_available);
ir->status.Update(s);
ir->out_mu.unlock();
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 419867ff58..e81e61b633 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -473,16 +473,16 @@ bool ReplaceTensorWithConstant(
// 1) Do not replace another constant.
// 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
// constraint, do not replace it.
- // 3) If the size of the constant in bytes is too large (>
+ // 3) If the destination tensor is an int32 tensor, and has DEVICE_MEMORY
+ // constraint, do not replace it.
+ // 4) If the size of the constant in bytes is too large (>
// max_constant_in_bytes), do not replace it. This prevents the size of the
// Graph from growing too large.
- // 4) If the constant op created does not have a kernel implementation
+ // 5) If the constant op created does not have a kernel implementation
// for the device, do not use it.
// TODO(keveman): Consider adding a new constant op that has a kernel
// implementation for all types, but with HostMemory constraint on it's
// output.
- // 5) If the constant op for the device has different output memory type
- // from the original op output memory type, do not replace it.
if (tensor.first->IsConstant()) {
return false;
}
@@ -497,7 +497,8 @@ bool ReplaceTensorWithConstant(
return false;
}
bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32;
- if (memory_type == HOST_MEMORY && !is_int32) {
+ if ((memory_type == HOST_MEMORY && !is_int32) ||
+ (memory_type == DEVICE_MEMORY && is_int32)) {
return false;
}
}
@@ -535,23 +536,6 @@ bool ReplaceTensorWithConstant(
if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
return false;
}
- if (partition_device && device_type != DEVICE_CPU) {
- MemoryType original_output_memory_type;
- if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
- &original_output_memory_type)
- .ok()) {
- return false;
- }
- MemoryType const_output_memory_type;
- if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
- &const_output_memory_type)
- .ok()) {
- return false;
- }
- if (original_output_memory_type != const_output_memory_type) {
- return false;
- }
- }
for (auto edge : edges_to_remove) {
graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
graph->RemoveEdge(edge);
diff --git a/tensorflow/core/common_runtime/eager/attr_builder.cc b/tensorflow/core/common_runtime/eager/attr_builder.cc
index cf1cd4134e..5c8369de87 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.cc
+++ b/tensorflow/core/common_runtime/eager/attr_builder.cc
@@ -136,6 +136,22 @@ void AttrBuilder::FillAttrValueMap(AttrValueMap* m,
m->insert(*it);
}
}
+ // For any attr-value pairs that exist in the op def (from op registry) but
+ // not `m`, fill them into `m`, so that we can run a TFE_Op without having to
+ // specify all the default attr values (e.g. for matmul, the `transpose_a`
+ // attr defaults to false).
+ const OpDef* op_def = nullptr;
+ Status s = OpDefForOp(op_name_.c_str(), &op_def);
+ // This is expected, if this op is a custom function, and is therefore not
+ // present in the op registry.
+ if (!s.ok()) return;
+
+ DCHECK(op_def);
+ for (const auto& attr_def : op_def->attr()) {
+ if (attr_def.has_default_value() && !m->count(attr_def.name())) {
+ SetInAttrValueMap(m, attr_def.name(), attr_def.default_value());
+ }
+ }
}
const NodeDef& AttrBuilder::BuildNodeDef() {
diff --git a/tensorflow/core/common_runtime/eager/attr_builder.h b/tensorflow/core/common_runtime/eager/attr_builder.h
index cbe6a1cb50..c114ea4ba0 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.h
+++ b/tensorflow/core/common_runtime/eager/attr_builder.h
@@ -110,6 +110,12 @@ class AttrBuilder {
using AttrVec = tensorflow::gtl::InlinedVector<std::pair<StringPiece, T>, 2>;
void MayBeInitializeNodeDef();
+ // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as
+ // well as any default attr-value pairs from the associated op_def, if there
+ // is one.
+ //
+ // If `include_those_in_node_def` is true, also include any attr-value pairs
+ // from `node_def_`.
void FillAttrValueMap(AttrValueMap* m, bool include_those_in_node_def) const;
template <class T>
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 18420b60fd..f23cefb33d 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -70,7 +70,9 @@ EagerContext::EagerContext(const SessionOptions& opts,
async_default_(async),
log_memory_(LogMemory::IsEnabled()),
env_(opts.env),
- use_send_tensor_rpc_(false) {
+ use_send_tensor_rpc_(false),
+ pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
+ "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", true)) {
if (device_mgr_owned) {
local_device_manager_.reset(device_mgr);
local_unowned_device_manager_ = nullptr;
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 5ed6057ec6..15eeaa8066 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -202,6 +202,7 @@ class EagerContext {
// EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used
// instead (which in-turn use WorkerService.RecvTensor RPCs).
bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
+ bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; }
private:
void InitDeviceMapAndAsync();
@@ -293,6 +294,7 @@ class EagerContext {
#endif
bool use_send_tensor_rpc_;
+ const bool pin_small_ops_to_cpu_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 1bc63616d0..a52f933d75 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -579,19 +579,23 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
return Status::OK();
#endif
}
-} // namespace
-Status EagerExecute(EagerOperation* op,
- gtl::InlinedVector<TensorHandle*, 2>* retvals,
- int* num_retvals) {
- // Ensure all resource-touching ops run in the device the resource is,
- // regardless of anything else that has been specified. This is identical to
- // the graph mode behavior.
+// The Op device may be updated if:
+// - A resource touching input is specified: all resource-touching ops run in
+// the device the resource is, regardless of anything else that has been
+// specified. This is identical to the graph mode behavior.
+//
+// - All op inputs are on the CPU, small (<64 elements) and integers
+// (int32/int64). This can be disabled by setting the environment variable
+// "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
+Status MaybeUpdateOpDevice(EagerOperation* op) {
EagerContext* ctx = op->EagerContext();
+ bool device_set_for_resource_variable = false;
+ bool all_inputs_eligible_for_cpu_pinning = ctx->PinSmallOpsToCPU();
+
for (int i = 0; i < op->Inputs().size(); ++i) {
Device* input_op_device = nullptr;
- auto status = op->Inputs()[i]->OpDevice(&input_op_device);
- if (!status.ok()) return status;
+ TF_RETURN_IF_ERROR(op->Inputs()[i]->OpDevice(&input_op_device));
VLOG(2) << "for op " << op->Name() << " input " << i << " "
<< DataTypeString(op->Inputs()[i]->dtype) << " "
<< (input_op_device == nullptr ? "cpu" : input_op_device->name())
@@ -603,8 +607,53 @@ Status EagerExecute(EagerOperation* op,
<< d->name() << " because input #" << i
<< " is a resource in this device.";
op->SetDevice(d);
+
+ device_set_for_resource_variable = true;
+ all_inputs_eligible_for_cpu_pinning = false;
+ } else if (all_inputs_eligible_for_cpu_pinning) {
+ TensorHandle* handle = op->Inputs()[i];
+
+ // Input is on CPU.
+ if (input_op_device != nullptr && input_op_device != ctx->HostCPU()) {
+ all_inputs_eligible_for_cpu_pinning = false;
+ continue;
+ }
+
+ if (handle->dtype != DataType::DT_INT32 &&
+ handle->dtype != DataType::DT_INT64) {
+ all_inputs_eligible_for_cpu_pinning = false;
+ continue;
+ }
+
+ int64 num_elements;
+ TF_RETURN_IF_ERROR(handle->NumElements(&num_elements));
+ if (num_elements > 64) {
+ all_inputs_eligible_for_cpu_pinning = false;
+ }
}
}
+
+ // Ops without inputs are usually ops that generate a tensor in some way and
+ // usually require being present on whatever device they are scheduled on
+ // - for e.g. VarHandleOp or _Recv).
+ // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for
+ // an op, but there is a GPU kernel?
+ if (!op->Inputs().empty() && all_inputs_eligible_for_cpu_pinning) {
+ VLOG(1) << "Forcing op " << op->Name()
+ << " to be on the CPU since all input tensors have an "
+ "int32/int64 dtype, and are small (less than 64 elements).";
+ op->SetDevice(ctx->HostCPU());
+ }
+
+ return Status::OK();
+}
+} // namespace
+
+Status EagerExecute(EagerOperation* op,
+ gtl::InlinedVector<TensorHandle*, 2>* retvals,
+ int* num_retvals) {
+ TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));
+
bool op_is_local = IsLocal(op->EagerContext(), op->Device());
if (op_is_local) {
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index a02084f223..9306386117 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -107,6 +107,8 @@ CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name,
then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()),
else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) {
TF_CHECK_OK(if_op_->input_node(0, &pred_));
+ then_call_builder_.Device(if_op_->requested_device());
+ else_call_builder_.Device(if_op_->requested_device());
}
Status CondBuilder::CreatePivotNodes() {
@@ -117,15 +119,18 @@ Status CondBuilder::CreatePivotNodes() {
NodeBuilder(NewName("switch_pred"), "Switch", graph_->op_registry())
.Input(NodeOut(pred_, 0))
.Input(NodeOut(pred_, 0))
+ .Device(if_op_->requested_device())
.Finalize(graph_, &switch_pred));
control_predecessor_ = switch_pred;
TF_RETURN_IF_ERROR(
NodeBuilder(NewName("pivot_f"), "Identity", graph_->op_registry())
.Input(switch_pred, kElseBranch)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &pivot_f_));
TF_RETURN_IF_ERROR(
NodeBuilder(NewName("pivot_t"), "Identity", graph_->op_registry())
.Input(switch_pred, kThenBranch)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &pivot_t_));
return Status::OK();
}
@@ -140,6 +145,7 @@ Status CondBuilder::AddInput(Node* src, int src_output) {
NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry())
.Input(src, src_output)
.Input(pred_, 0)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &input));
then_call_builder_.Input(input, kThenBranch);
else_call_builder_.Input(input, kElseBranch);
@@ -178,6 +184,7 @@ Status CondBuilder::AddOutputs() {
TF_RETURN_IF_ERROR(
NodeBuilder(graph_->NewName("merge"), "Merge", graph_->op_registry())
.Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)})
+ .Device(if_op_->requested_device())
.Finalize(graph_, &merges[i]));
outputs_[i] = NodeOut(merges[i], 0);
}
@@ -218,7 +225,7 @@ Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib,
Status CondBuilder::BuildLoweredIfOutput() {
// Build the identity node output.
NodeBuilder ib(name_, "IdentityN");
- ib.Input(outputs_);
+ ib.Input(outputs_).Device(if_op_->requested_device());
return ib.Finalize(graph_, &lowered_if_output_);
}
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 362092a6cf..db10f586bc 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -1340,6 +1340,8 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) {
Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
Output g = ops::Shape(s.WithOpName("g"), c);
Output h = ops::Fill(s.WithOpName("h"), g, zero);
+ Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1});
+ Output j = ops::Sum(s.WithOpName("j"), a, zero_idx);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -1382,6 +1384,10 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) {
ASSERT_EQ(2, shape_f.dim_size());
EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size());
EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size());
+
+ const auto shape_j = properties.GetOutputProperties("j").at(0).shape();
+ ASSERT_EQ(1, shape_j.dim_size());
+ EXPECT_EQ(shape_j.dim(0).size(), shape_a.dim(1).size());
}
TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index ca5d3a6dfd..3d0d95bba7 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -616,28 +616,37 @@ Status ConstantFolding::MaterializeReductionIndices(
// We can't do anything if we don't know the rank of the input.
return Status::OK();
}
- const int rank = input_prop.shape().dim_size();
- if (rank == 0) {
+ const int input_rank = input_prop.shape().dim_size();
+ if (input_rank < 1) {
// Unexpected graph, don't try to change it.
return Status::OK();
}
+ const OpInfo::TensorProperties& reduction_indices_prop = input_props[1];
+ DataType dtype = reduction_indices_prop.dtype();
+ if (dtype != DT_INT32 && dtype != DT_INT64) {
+ return Status::OK();
+ }
+ PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape());
+ const int num_reduction_indices = reduction_indices_shape.num_elements();
+
const std::vector<OpInfo::TensorProperties>& output_props =
properties.GetOutputProperties(node->name());
if (output_props.size() != 1) {
return Status::OK();
}
- const bool keep_dims =
- node->attr().count("keep_dims") && node->attr().at("keep_dims").b();
const OpInfo::TensorProperties& output_prop = output_props[0];
- PartialTensorShape output_shape(output_prop.shape());
- if (output_shape.num_elements() != 1) {
- bool full_reduction = false;
+ const int output_rank =
+ output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size();
+
+ bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank;
+ if (!full_reduction) {
+ // A full reduction will generate a tensor of one of the shapes
+ // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of
+ // elements in the output of the reduction, we may deduce it from reshape
+ // nodes following it.
for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) {
- if (!IsReshape(*fanout) && !keep_dims) {
- // Depending on how it's setup, a full reduction will generate a tensor
- // of shape [], [1], [1, 1], [1, 1, ...]. If keep_dims isn't true, we
- // rely on the existence of a reshape node following the reduction to
- // ensure that the fanout is fed a scalar of the right shape.
+ full_reduction = false;
+ if (!IsReshape(*fanout)) {
return Status::OK();
}
const std::vector<OpInfo::TensorProperties>& reshape_props =
@@ -658,20 +667,15 @@ Status ConstantFolding::MaterializeReductionIndices(
}
}
- const OpInfo::TensorProperties& reduction_prop = input_props[1];
- DataType dtype = reduction_prop.dtype();
- if (dtype != DT_INT32 && dtype != DT_INT64) {
- return Status::OK();
- }
- // We know it's a full reduction. We can generate the set of indices to
- // reduce.
+ // We know it's a full reduction. We can generate the full set of indices to
+ // reduce as a constant node.
string const_name = OptimizedNodeName(*node, "-reduction_indices");
if (node_map_->GetNode(const_name)) {
return Status::OK();
}
NodeDef* reduction_indices = graph_->add_node();
- Tensor value(dtype, TensorShape({rank}));
- for (int i = 0; i < rank; ++i) {
+ Tensor value(dtype, TensorShape({input_rank}));
+ for (int i = 0; i < input_rank; ++i) {
if (dtype == DT_INT32) {
value.vec<int32>()(i) = i;
} else {
@@ -680,6 +684,7 @@ Status ConstantFolding::MaterializeReductionIndices(
}
TF_RETURN_IF_ERROR(
CreateNodeDef(const_name, TensorValue(&value), reduction_indices));
+
reduction_indices->set_device(node->device());
string ctrl_dep =
AddControlDependency(node->input(1), graph_, node_map_.get());
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index b09360a2c2..fab01edfed 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -2591,58 +2591,100 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) {
}
TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output input =
- ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
- ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
- Output indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
- Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
- Output size = ops::Const(s.WithOpName("size"), 1, {1});
- Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
+ for (bool use_reshape : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input =
+ ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
+ ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
+ // If use_reshape is false, we need to now the number of indices to apply
+ // the rewrite.
+ Output indices = ops::Placeholder(
+ s.WithOpName("indices"), DT_INT32,
+ ops::Placeholder::Shape(PartialTensorShape({use_reshape ? -1 : 2})));
+ Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
+ if (use_reshape) {
+ Output size = ops::Const(s.WithOpName("size"), 1, {1});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
+ }
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- item.fetch.push_back("reshape");
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch.push_back(use_reshape ? "reshape" : "sum");
- auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
- Tensor indices_t(DT_INT32, TensorShape({2}));
- indices_t.flat<int>()(0) = 0;
- indices_t.flat<int>()(1) = 1;
- auto tensors_expected = EvaluateNodes(
- item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
- EXPECT_EQ(1, tensors_expected.size());
+ auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
+ Tensor indices_t(DT_INT32, TensorShape({2}));
+ indices_t.flat<int>()(0) = 0;
+ indices_t.flat<int>()(1) = 1;
+ auto tensors_expected = EvaluateNodes(
+ item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
+ EXPECT_EQ(1, tensors_expected.size());
- ConstantFolding optimizer(nullptr /* cpu_device */);
- GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ // Use aggressive mode to force the shape inference to propagate placeholder
+ // shapes.
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
- // Run a second time to make sure the optimization is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ // Run a second time to make sure the optimization is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
- int found = 0;
- for (const auto& node : output.node()) {
- if (node.name() == "ConstantFolding/sum-reduction_indices") {
- ++found;
- EXPECT_EQ("Const", node.op());
- EXPECT_EQ("^indices", node.input(0));
- EXPECT_EQ(2, TensorShape(node.attr().at("value").tensor().tensor_shape())
- .num_elements());
- } else if (node.name() == "sum") {
- ++found;
- EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
- } else if (node.name() == "indices") {
- ++found;
+ int found = 0;
+ for (const auto& node : output.node()) {
+ if (node.name() == "ConstantFolding/sum-reduction_indices") {
+ ++found;
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^indices", node.input(0));
+ EXPECT_EQ(2,
+ TensorShape(node.attr().at("value").tensor().tensor_shape())
+ .num_elements());
+ } else if (node.name() == "sum") {
+ ++found;
+ EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
+ } else if (node.name() == "indices") {
+ ++found;
+ }
}
+ EXPECT_EQ(3, found);
+
+ auto tensors = EvaluateNodes(output, item.fetch,
+ {{"input", input_t}, {"indices", indices_t}});
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
}
- EXPECT_EQ(3, found);
+}
- auto tensors = EvaluateNodes(output, item.fetch,
- {{"input", input_t}, {"indices", indices_t}});
- EXPECT_EQ(1, tensors.size());
- test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
+TEST_F(ConstantFoldingTest, MaterializeReductionIndices_NotFullReduction) {
+ for (bool input_rank_known : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input =
+ (input_rank_known ? ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
+ ops::Placeholder::Shape(
+ PartialTensorShape({-1, -1})))
+ : ops::Placeholder(s.WithOpName("input"), DT_FLOAT));
+ Output indices =
+ ops::Placeholder(s.WithOpName("indices"), DT_INT32,
+ ops::Placeholder::Shape(
+ PartialTensorShape({input_rank_known ? 1 : 2})));
+ Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch.push_back("sum");
+
+ // Use aggressive mode to force the shape inference to propagate placeholder
+ // shapes.
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ CompareGraphs(item.graph, output);
+ }
}
TEST_F(ConstantFoldingTest, LargeConstant) {
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 755af3361e..ee7c14e3ab 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -524,6 +524,7 @@ cc_library(
deps = [
":function_utils",
":graph_utils",
+ "//tensorflow/cc:ops",
"@com_google_absl//absl/strings",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index 9328a7ca99..a9254ed58b 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -44,7 +44,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
// Function inputs and outputs are the same as original, just
// with different shapes.
*vectorized_func->mutable_signature() = orig_func.signature();
- graph_utils::SetUniqueGraphFunctionName("vectorized_function", library,
+ graph_utils::SetUniqueGraphFunctionName("naively_vectorized_fn", library,
vectorized_func);
// Add MapDefun node
@@ -67,6 +67,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
map_defun_node->add_input(input.name());
}
(*map_defun_node->mutable_attr())["Targuments"] = t_args;
+ AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node);
// Set return values to match output names
string output_prefix = strings::StrCat(map_defun_node->name(), ":output:");
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
index 37aa24b947..985d6c6c3a 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
@@ -13,9 +13,19 @@ VECTORIZER_DEPS = [
] + tf_protos_all()
cc_library(
+ name = "wrapped_tensor",
+ hdrs = ["wrapped_tensor.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "vectorizer",
hdrs = ["vectorizer.h"],
deps = [
+ ":wrapped_tensor",
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
] + tf_protos_all(),
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
index 3af6bab409..f445157531 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
@@ -19,13 +19,13 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
+namespace {
class CastVectorizer : public Vectorizer {
public:
Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* input_ports,
- std::vector<Port>* output_ports) override {
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
Status s;
if (node.num_inputs() != 1) {
return errors::Internal("Cast op should only have one input.");
@@ -35,15 +35,17 @@ class CastVectorizer : public Vectorizer {
auto new_cast_node = outer_scope->AddNode(node.def(), &s);
TF_RETURN_IF_ERROR(s);
- // Add input and output mappings
- input_ports->push_back({new_cast_node, 0});
- output_ports->push_back({new_cast_node, 0});
+ outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, new_cast_node,
+ 0);
+
+ // Add output mappings
+ outputs->push_back({new_cast_node, 0, true});
return Status::OK();
}
};
REGISTER_VECTORIZER("Cast", CastVectorizer);
-} // namespace vectorization_utils
+} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
index 74ce520ce1..f1ba741821 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
@@ -19,15 +19,15 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
+namespace {
class UnpackVectorizer : public Vectorizer {
public:
Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* input_ports,
- std::vector<Port>* output_ports) override {
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
Status s;
- if (node.num_inputs() != 1) {
+ if (node.num_inputs() != 1 || inputs.size() != 1) {
return errors::Internal("Unpack op should only have one input.");
}
@@ -39,13 +39,13 @@ class UnpackVectorizer : public Vectorizer {
int new_axis = node.def().attr().at("axis").i() + 1;
new_unpack_node->AddAttr("axis", new_axis);
- // Add the input mappings
- input_ports->push_back({new_unpack_node, 0});
+ outer_scope->AddEdge(inputs[0].node, inputs[0].output_index,
+ new_unpack_node, 0);
// Add the output mappings
int num = node.def().attr().at("num").i();
for (int i = 0; i < num; ++i) {
- output_ports->push_back({new_unpack_node, i});
+ outputs->push_back({new_unpack_node, i, true});
}
return Status::OK();
@@ -54,6 +54,6 @@ class UnpackVectorizer : public Vectorizer {
REGISTER_VECTORIZER("Unpack", UnpackVectorizer);
-} // namespace vectorization_utils
+} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
index 56eb88c95e..8d4676aae0 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
@@ -18,15 +18,12 @@ limitations under the License.
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
-
-// Describes a tensor with its operation Node and output position
-typedef std::pair<Node*, int> Port;
// Interface for vectorization of TensorFlow operations. See `CastVectorizer`
// for an example.
@@ -36,17 +33,17 @@ class Vectorizer {
// Vectorizes an operation, `node`, by adding Node(s) to `outer_scope`
// that produce the same vector output(s) as executing `node`'s op
- // on elements of the vector inputs. The new Node(s) collectively have the
+ // on elements of `inputs`. The new Node(s) collectively have the
// same number of input and output ports as the node being converted.
- // Adds mappings for the new nodes' input and output ports to `inputs` and
- // `outputs` respectively, where the i'th Port in inputs/outputs
- // corresponds to the i'th input/output port of the node to be converted.
+ // Adds edges between the newly created nodes and nodes in `inputs`, and adds
+ // mappings to the new nodes' output ports to `outputs`, where the i'th
+ // value in `outputs` corresponds to the i'th output port of the node
+ // to be converted.
virtual Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* input_ports,
- std::vector<Port>* output_ports) = 0;
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) = 0;
};
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
index a6551e36ac..e1cf77a7d5 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
@@ -19,7 +19,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
VectorizerRegistry* VectorizerRegistry::Global() {
static VectorizerRegistry* registry = new VectorizerRegistry;
@@ -42,6 +41,5 @@ void VectorizerRegistry::Register(const string& op_type,
vectorizers_.insert(std::pair<const string&, std::unique_ptr<Vectorizer>>(
op_type, std::move(vectorizer)));
}
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
index 16159d47ca..ad54c74933 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
@@ -23,7 +23,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
// A global VectorizerRegistry is used to hold all the vectorizers.
class VectorizerRegistry {
@@ -59,16 +58,12 @@ class VectorizerRegistration {
#define REGISTER_VECTORIZER_UNIQ_HELPER(ctr, op_type, vectorizer) \
REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer)
-#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
- static ::tensorflow::grappler::vectorization_utils:: \
- vectorizer_registration::VectorizerRegistration \
- vectorizer_registration_##ctr( \
- op_type, \
- ::std::unique_ptr< \
- ::tensorflow::grappler::vectorization_utils::Vectorizer>( \
- new vectorizer()))
+#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
+ static ::tensorflow::grappler::vectorizer_registration:: \
+ VectorizerRegistration vectorizer_registration_##ctr( \
+ op_type, ::std::unique_ptr<::tensorflow::grappler::Vectorizer>( \
+ new vectorizer()))
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
index 663ceba027..054aeb9a8f 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
@@ -20,13 +20,12 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
class TestVectorizer : public Vectorizer {
public:
Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* inputs,
- std::vector<Port>* outputs) override {
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
return Status::OK();
}
};
@@ -43,10 +42,10 @@ TEST(TestVectorizer, TestTestVectorizer) {
NodeDef node_def;
Status s;
Node* node = g.AddNode(node_def, &s);
- std::vector<Port> inputs, outputs;
- EXPECT_TRUE(vectorizer->Vectorize(*node, &g, &inputs, &outputs).ok());
+ std::vector<WrappedTensor> inputs, outputs;
+ EXPECT_TRUE(
+ vectorizer->Vectorize(*node, &g, std::move(inputs), &outputs).ok());
}
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h
new file mode 100644
index 0000000000..4439b4ab4e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Represents a tensor that has been vectorized.
+struct WrappedTensor {
+ Node* const node;
+ const int output_index;
+
+ // Whether the tensor is stacked, i.e. represents the results of applying
+ // the operation on all slices of the input, where each row i of the
+ // tensor corresponds to the op's output on slice i of the input. False
+ // if the tensor is not stacked, i.e. represents the result of the op on
+ // a single slice of the input, where the result does not vary between
+ // slices.
+ bool stacked;
+
+ WrappedTensor(Node* node, int output_index, bool stacked)
+ : node(node), output_index(output_index), stacked(stacked) {}
+};
+
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index 2d6cf562b1..ba857ab5d9 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -14,10 +14,10 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
-#include <memory>
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
@@ -28,13 +28,13 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/strings/scanner.h"
namespace tensorflow {
namespace grappler {
@@ -132,7 +132,8 @@ class Vectorization {
const NodeDef& map_defun_node, FunctionDef** result);
private:
- // Converts FunctionDefs to Graphs.
+ // Converts FunctionDefs to Graphs and adds mappings from
+ // arg nodes and unstacked nodes to the corresponding nodes in outer_scope_.
Status Initialize(const FunctionDef& outer_scope,
const NodeDef& map_defun_node);
@@ -162,9 +163,30 @@ class Vectorization {
// the conversion map.
Status AddConversionMapping(Node* op_node);
- // Maps a tensor to the corresponding vectorized tensor. For example,
- // {"Cast" Node*, 0} -> {"Vectorize/Cast" Node*, 0}
- std::map<TensorDesc, TensorDesc> conversion_map_;
+ // Given a tensor t in `unstacked`, stacks it by doing the equivalent of
+ // tf.tile(tf.expand_dims(t, 0), [n, 1, 1, ...]) where n is dimension 0 of
+ // inputs to `map_defun_node_`. This stacked tensor will be compatible with
+ // the expected output shape of `map_defun_node_`.
+ // This is equivalent to the _stack function in python Pfor.
+ Status StackTensor(WrappedTensor* unstacked, TensorDesc* result);
+
+ // Recursively looks for unstacked nodes in the `map_defun_fn_` graph by
+ // doing a depth-first search from the ret nodes. Lifts nodes that are
+ // unstacked (i.e. don't derive from arg nodes) into `outer_scope_` directly
+ // and add mappings to `conversion_map_`.
+ Status AddUnstackedNodeMappings();
+
+ // Recursive helper for `AddUnstackedNodeMappings`, returns true if tensor
+ // is unstacked.
+ bool AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, Status* status);
+
+ // Add mappings from `map_defun_fn_` arg nodes to `map_defun_node_` input
+ // nodes to `conversion_map_`.
+ Status AddArgNodeMappings();
+
+ // Maps a tensor to the corresponding WrappedTensor. For example,
+ // {"Cast" Node*, 0} -> WrappedTensor({"Vectorize/Cast" Node*, 0}, true)
+ std::map<TensorDesc, WrappedTensor> conversion_map_;
// Unconvertible ret nodes
std::set<Node*> unconvertible_;
@@ -180,6 +202,10 @@ class Vectorization {
std::unique_ptr<Graph> outer_scope_;
std::unique_ptr<FunctionBody> map_defun_fn_;
Node* map_defun_node_ = nullptr; // Owned by `outer_scope`
+
+ // Caches the loop_len_node_ needed for tiling unstacked output. This
+ // corresponds to a vector with one element.
+ Node* loop_len_node_ = nullptr; // Owned by `outer_scope`
Status status_;
};
@@ -197,34 +223,48 @@ Status Vectorization::AddConversionMapping(Node* op_node) {
return errors::Unimplemented("No vectorizer registered for op: ",
op_node->type_string());
}
- std::vector<Port> input_ports, output_ports;
- input_ports.reserve(op_node->num_inputs());
- output_ports.reserve(op_node->num_outputs());
- TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
- &input_ports, &output_ports));
+ std::vector<WrappedTensor> inputs, outputs;
+ inputs.reserve(op_node->num_inputs());
+ outputs.reserve(op_node->num_outputs());
std::vector<const Edge*> input_edges;
TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));
- if (op_node->num_outputs() != output_ports.size() ||
- op_node->num_inputs() != input_ports.size() ||
- input_edges.size() != input_ports.size()) {
- return errors::Internal("Vectorizer inputs/outputs don't match.");
- }
-
- // Promote the inputs of the op to MapDefun outputs and connect the edges
- // accordingly.
+ // The inputs for the node to be converted may already have been converted
+ // themselves. For those that are not, we promote them to MapDefun outputs.
for (size_t i = 0; i < op_node->num_inputs(); ++i) {
auto edge = input_edges[i];
- TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
- {edge->src(), edge->src_output()}));
- outer_scope_->AddEdge(map_defun_node_, map_defun_fn_->ret_nodes.size() - 1,
- input_ports[i].first, input_ports[i].second);
+ if (auto found = gtl::FindOrNull(conversion_map_,
+ {edge->src(), edge->src_output()})) {
+ inputs.push_back(*found);
+ } else {
+ // TODO(rachelim): Handle the case where unconverted inputs are unstacked.
+ // We assume that all unconverted inputs will be stacked, since we
+ // converted all unstacked nodes in `Initialize`. However, it's actually
+ // possible that yet-unconverted nodes may produce unstacked outputs after
+ // they are vectorized. (For example, see the "Shape" converter in
+ // tensorflow/python/ops/parallel_for/pfor.py). If a vectorizer expects
+ // an unstacked input but receives a stacked one, vectorizer->Vectorize
+ // will return an error.
+ TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
+ {edge->src(), edge->src_output()}));
+ int output_index = map_defun_fn_->ret_nodes.size() - 1;
+ inputs.push_back({map_defun_node_, output_index, true});
+ }
+ }
+
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
+ std::move(inputs), &outputs));
+
+ if (op_node->num_outputs() != outputs.size()) {
+ return errors::Internal(
+ "Number of vectorizer outputs does not match. Expected: ",
+ op_node->num_outputs(), " Actual: ", outputs.size());
}
// Add output mappings.
for (size_t i = 0; i < op_node->num_outputs(); ++i) {
- conversion_map_.insert({{op_node, i}, std::move(output_ports[i])});
+ conversion_map_.insert({{op_node, i}, outputs[i]});
}
return Status::OK();
@@ -239,13 +279,22 @@ Status Vectorization::ConvertOutput(int output_position) {
TensorDesc output({ret_edge->src(), ret_edge->src_output()});
TensorDesc converted_output;
- if (auto found = gtl::FindOrNull(conversion_map_, output)) {
- // It's possible the output already has a mapping, if it comes from a node
- // that has already been converted.
- converted_output = *found;
- } else {
+
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ auto found = gtl::FindOrNull(conversion_map_, output);
+ if (!found) {
TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
- converted_output = conversion_map_.at(output);
+ found = &conversion_map_.at(output);
+ }
+
+ if (found->stacked) {
+ converted_output = {found->node, found->output_index};
+ } else {
+ // Some outputs may be unstacked if they don't derive from arg nodes
+ // (for example, if a function returns a constant). For these, we
+ // have to add extra nodes to tile it in the 0th dimension.
+ TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
}
ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
@@ -297,6 +346,7 @@ void Vectorization::VectorizeHelper() {
map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types);
}
}
+
Status Vectorization::Initialize(const FunctionDef& outer_scope,
const NodeDef& map_defun_node) {
// Convert outer_scope and map_defun_fn to FunctionBodys so we can
@@ -337,16 +387,183 @@ Status Vectorization::Initialize(const FunctionDef& outer_scope,
}
map_defun_node_ = outer_scope_->FindNodeId(node_id);
- // Add mappings from map_defun_fn_ arg nodes to map_defun_node_ input nodes to
- // the conversion map
+ TF_RETURN_IF_ERROR(AddArgNodeMappings());
+
+ TF_RETURN_IF_ERROR(AddUnstackedNodeMappings());
+ loop_len_node_ = nullptr;
+
+ return Status::OK();
+}
+
+// TODO(rachelim): It might be profitable to use the C++ API for this instead of
+// NodeBuilder
+Status Vectorization::StackTensor(WrappedTensor* unstacked,
+ TensorDesc* result) {
+ // Note that all these nodes are necessary as the size of the batch may not be
+ // constant.
+ if (unstacked->stacked) {
+ return errors::Internal("Can only stack unstacked tensor.");
+ }
+
+ Graph* g = outer_scope_.get();
+ auto node_builder = [](StringPiece op) {
+ return NodeBuilder(strings::StrCat("vectorized/stack/", op), op);
+ };
+
+ auto make_const = [&node_builder](const Input::Initializer& val, Graph* graph,
+ Node** result) {
+ TF_RETURN_IF_ERROR(val.status);
+ return node_builder("Const")
+ .Attr("value", val.tensor)
+ .Attr("dtype", val.tensor.dtype())
+ .Finalize(graph, result);
+ };
+
+ // If loop_len_node_ hasn't been created yet, add the node and cache it.
+ if (loop_len_node_ == nullptr) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(map_defun_node_->input_node(0, &input_node));
+
+ Node* shape_node;
+ TF_RETURN_IF_ERROR(
+ node_builder("Shape").Input(input_node).Finalize(g, &shape_node));
+
+ Node* const_vec_0;
+ TF_RETURN_IF_ERROR(make_const({0}, g, &const_vec_0));
+ Node* const_vec_1;
+ TF_RETURN_IF_ERROR(make_const({1}, g, &const_vec_1));
+
+ Node* strided_slice_node;
+ TF_RETURN_IF_ERROR(node_builder("StridedSlice")
+ .Input(shape_node) // input
+ .Input(const_vec_0) // begin
+ .Input(const_vec_1) // end
+ .Input(const_vec_1) // strides
+ .Finalize(g, &strided_slice_node));
+
+ // Produces a vector of length 1
+ TF_RETURN_IF_ERROR(node_builder("Reshape")
+ .Input(strided_slice_node) // tensor
+ .Input(const_vec_1) // shape
+ .Finalize(g, &loop_len_node_));
+ }
+
+ Node* ones_shape;
+ TF_RETURN_IF_ERROR(node_builder("Shape")
+ .Input(unstacked->node) // input
+ .Finalize(g, &ones_shape));
+
+ Node* ones;
+ TF_RETURN_IF_ERROR(
+ node_builder("OnesLike").Input(ones_shape).Finalize(g, &ones));
+
+ Node* const_0;
+ TF_RETURN_IF_ERROR(make_const(0, g, &const_0));
+
+ Node* multiples;
+ TF_RETURN_IF_ERROR(node_builder("Concat")
+ .Input(const_0) // concat_dim
+ .Input({{loop_len_node_, 0}, {ones, 0}}) // values
+ .Finalize(g, &multiples));
+
+ Node* expand_dims;
+ TF_RETURN_IF_ERROR(node_builder("ExpandDims")
+ .Input(unstacked->node) // input
+ .Input(const_0) // dim
+ .Finalize(g, &expand_dims));
+
+ TF_RETURN_IF_ERROR(node_builder("Tile")
+ .Input(expand_dims) // input
+ .Input(multiples) // multiples
+ .Finalize(g, &result->first));
+ result->second = 0;
+ return Status::OK();
+}
+
+Status Vectorization::AddArgNodeMappings() {
for (auto arg_node : map_defun_fn_->arg_nodes) {
Node* input_node;
TF_RETURN_IF_ERROR(map_defun_node_->input_node(
arg_node->attrs().Find("index")->i(), &input_node));
- conversion_map_.insert({{arg_node, 0}, {input_node, 0}});
+ conversion_map_.insert({{arg_node, 0}, {input_node, 0, true}});
+
+ // Control inputs
+ conversion_map_.insert({{arg_node, Graph::kControlSlot},
+ {input_node, Graph::kControlSlot, true}});
+ }
+ return Status::OK();
+}
+
+bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor,
+ Status* status) {
+ if (auto found = gtl::FindOrNull(conversion_map_, tensor)) {
+ return !found->stacked;
+ }
+
+ if (tensor.first->op_def().is_stateful()) {
+ // We don't lift stateful nodes directly out of the MapDefun, since they may
+ // have to be executed N times.
+ return false;
}
+ bool is_unstacked = true;
+ for (auto edge : tensor.first->in_edges()) {
+ // Ignore Source nodes. Note that these are also ignored in the
+ // GraphToFunctionDef conversion.
+ if (edge->src()->IsSource()) continue;
+
+ // A node is unstacked if all of its inputs are unstacked
+ is_unstacked &= AddUnstackedNodeMappingsHelper(
+ {edge->src(), edge->src_output()}, status);
+ }
+
+ if (!is_unstacked) {
+ return false;
+ }
+
+ // If the node is unstacked, we copy it into outer_scope_ and
+ // add it to the map. Note that we don't clean up the nodes that are copied
+ // in map_defun_fn_, and rely on them being pruned out later.
+ Node* node = outer_scope_->AddNode(tensor.first->def(), status);
+ if (!status->ok()) return true;
+
+ // Add input edges to nodes that should already have been lifted.
+ for (auto edge : tensor.first->in_edges()) {
+ // Ignore Source nodes. Note that these are also ignored in the
+ // GraphToFunctionDef conversion.
+ if (edge->src()->IsSource()) continue;
+
+ if (auto found = gtl::FindOrNull(conversion_map_,
+ {edge->src(), edge->src_output()})) {
+ outer_scope_->AddEdge(found->node, found->output_index, node,
+ edge->dst_input());
+ } else {
+ status->Update(errors::Internal(
+ "Could not find input conversion even though we did depth first "
+ "conversion."));
+ }
+ }
+
+ // Add output mappings
+ for (int i = 0; i < tensor.first->num_outputs(); ++i) {
+ conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)});
+ }
+ conversion_map_.insert({{tensor.first, Graph::kControlSlot},
+ WrappedTensor(node, Graph::kControlSlot, false)});
+
+ return true;
+}
+
+Status Vectorization::AddUnstackedNodeMappings() {
+ SetVector<Node*> unstacked_nodes;
+ Status s;
+ for (const auto& ret_node : map_defun_fn_->ret_nodes) {
+ const Edge* in_edge = nullptr;
+ TF_RETURN_IF_ERROR(ret_node->input_edge(0, &in_edge));
+ AddUnstackedNodeMappingsHelper({in_edge->src(), in_edge->src_output()}, &s);
+ TF_RETURN_IF_ERROR(s);
+ }
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
index 1ff62217dd..a6020e36bb 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -55,6 +55,7 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
func.set_name(function_name);
NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn);
graph_transforms::SetNodeAttr("Targuments", t_arguments, node);
+ graph_transforms::SetNodeAttr("Tcaptured", DataTypeVector(), node);
graph_transforms::SetNodeAttr("output_types", output_types, node);
graph_transforms::SetNodeAttr("output_shapes", output_shapes, node);
graph_transforms::SetNodeAttr("f", func, node);
@@ -142,6 +143,8 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
+  Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized);
+  EXPECT_TRUE(s.ok()) << s;
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
@@ -670,6 +673,257 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
cast_node.input(1) == control_input);
}
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | | |
+// | | +------+ | |
+// | | |Const | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | +------+ |
+// | |Const | |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// | |Stack*| |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeConst) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2)},
+ {{"ret0", "Const:output:0"}});
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ EXPECT_TRUE(function_utils::ContainsFunctionNodeWithOp("Const", *vectorized));
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | | |
+// | | +------+ | |
+// | | |Const | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | +------+ |
+// | |Const | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | |Stack*| |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeUnstackedOutput) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2)},
+ {{"ret0", "Cast:y:0"}});
+ AddCastNode("Cast", {"Const:output:0"}, DT_INT32, DT_INT64, false, &inner);
+
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ auto const_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Const", *vectorized));
+ auto cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
+ EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')),
+ const_node.name());
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | +------+ +------+ | |
+// | | |Const | |Const | | |
+// | | +---+--+ +---+--+ | |
+// | | : +---v--+ | |
+// | | ::::::> Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | |
+// | +------+ |
+// | +------+ |Const | |
+// | |Const | +---+--+ |
+// | +---+--+ | |
+// | : +---v--+ |
+// | ::::::> Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +Stack*+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2),
+ FunctionDefHelper::Const("ConstDep", 3)},
+ {{"ret0", "Cast:y:0"}});
+ AddCastNode("Cast", {"Const:output:0", "^ConstDep"}, DT_INT32, DT_INT64,
+ false, &inner);
+
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+
+ auto find_const = [vectorized](int val) -> const NodeDef* {
+ for (const auto& n : vectorized->node_def()) {
+ if (n.attr().at("value").tensor().int_val(0) == val) {
+ return &n;
+ }
+ }
+ return nullptr;
+ };
+
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ auto const_node = find_const(2);
+ auto const_dep_node = find_const(3);
+ auto cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
+ EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')),
+ const_node->name());
+ EXPECT_EQ(cast_node.input(1), strings::StrCat("^", const_dep_node->name()));
+}
+
// TODO(rachelim): More test cases when we get around to implementing them:
// [] A badly defined converter, e.g. doesn't produce nodes that have the
// same number of outputs/inputs as the nodes to be converted
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 9439ab332c..3a920f26f3 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4458,7 +4458,12 @@ cc_library(
name = "string_util",
srcs = ["string_util.cc"],
hdrs = ["string_util.h"],
- deps = ["//tensorflow/core:lib"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@icu//:common",
+ ],
)
STRING_DEPS = [
diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
index a04f150e71..9607e9444c 100644
--- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc
+++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
@@ -171,16 +171,16 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
static PartialTensorShape MostSpecificCompatibleShape(
const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
- PartialTensorShape output_tensorshape;
if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
- return output_tensorshape;
+ return PartialTensorShape();
+ PartialTensorShape output_tensorshape({});
auto dims1 = ts1.dim_sizes();
auto dims2 = ts2.dim_sizes();
for (int d = 0; d < ts1.dims(); d++) {
if (dims1[d] == dims2[d])
- output_tensorshape.Concatenate(dims1[d]);
+ output_tensorshape.AddDim(dims1[d]);
else
- output_tensorshape.Concatenate(-1);
+ output_tensorshape.AddDim(-1);
}
return output_tensorshape;
}
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc
index 6657f2b2b3..705b0393de 100644
--- a/tensorflow/core/kernels/data/map_defun_op.cc
+++ b/tensorflow/core/kernels/data/map_defun_op.cc
@@ -62,24 +62,6 @@ class MapDefunOp : public AsyncOpKernel {
~MapDefunOp() override {}
- Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) {
- // Validates inputs and gets the size of their leading dimension.
- *batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
- for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- if (ctx->input(i).dims() == 0) {
- return errors::InvalidArgument(
- "All inputs must have rank at least 1. Input ", i,
- " has a rank of 0.");
- } else if (ctx->input(i).dim_size(0) != *batch_size) {
- return errors::InvalidArgument(
- "All inputs must have the same dimension 0. Input ", i,
- " has leading dimension ", ctx->input(i).dim_size(0),
- ", while all previous inputs have leading dimension ", batch_size);
- }
- }
- return Status::OK();
- }
-
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
ComputeOptions* compute_opts = nullptr;
@@ -150,8 +132,9 @@ class MapDefunOp : public AsyncOpKernel {
// all calls to the function are complete. This struct also encapsulates
// all the components that need to be passed to each MapFunctionCallFrame.
- const std::vector<Tensor> args;
+ OpInputList args;
const std::vector<TensorShape> arg_shapes;
+ OpInputList captured_inputs;
const int64 batch_size;
// Output of a compute call
@@ -161,26 +144,31 @@ class MapDefunOp : public AsyncOpKernel {
// Create a copy of output_shapes because every `Compute` may expect a
// different output shape.
- ComputeOptions(std::vector<Tensor> args,
+ ComputeOptions(OpInputList args, OpInputList captured_inputs,
std::vector<TensorShape> arg_shapes, int64 batch_size,
const std::vector<PartialTensorShape>& output_shapes_attr)
- : args(std::move(args)),
+ : args(args),
arg_shapes(std::move(arg_shapes)),
+ captured_inputs(captured_inputs),
batch_size(batch_size),
output_shapes(output_shapes_attr) {}
};
// Get inputs to Compute and check that they are valid.
Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) {
- int64 batch_size =
- ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
+ OpInputList arguments;
+ TF_RETURN_IF_ERROR(ctx->input_list("arguments", &arguments));
+ OpInputList captured_inputs;
+ TF_RETURN_IF_ERROR(ctx->input_list("captured_inputs", &captured_inputs));
+
+ int64 batch_size = arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1;
- for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- if (ctx->input(i).dims() == 0) {
+ for (size_t i = 0; i < arguments.size(); ++i) {
+ if (arguments[i].dims() == 0) {
return errors::InvalidArgument(
"All inputs must have rank at least 1. Input ", i,
" has a rank of 0.");
- } else if (ctx->input(i).dim_size(0) != batch_size) {
+ } else if (arguments[i].dim_size(0) != batch_size) {
return errors::InvalidArgument(
"All inputs must have the same dimension 0. Input ", i,
" has leading dimension ", ctx->input(i).dim_size(0),
@@ -188,19 +176,17 @@ class MapDefunOp : public AsyncOpKernel {
}
}
- std::vector<Tensor> args;
std::vector<TensorShape> arg_shapes;
- args.reserve(ctx->num_inputs());
- arg_shapes.reserve(ctx->num_inputs());
+ arg_shapes.reserve(arguments.size());
- for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- args.push_back(ctx->input(i));
- arg_shapes.push_back(ctx->input(i).shape());
+ for (size_t i = 0; i < arguments.size(); ++i) {
+ arg_shapes.push_back(arguments[i].shape());
arg_shapes.at(i).RemoveDim(0);
}
- *compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes),
- batch_size, output_shapes_);
+ *compute_opts =
+ new ComputeOptions(arguments, captured_inputs, std::move(arg_shapes),
+ batch_size, output_shapes_);
return Status::OK();
}
@@ -235,12 +221,21 @@ class MapDefunOp : public AsyncOpKernel {
}
Status GetArg(int index, Tensor* val) const override {
- if (index < 0 || index >= compute_opts_->args.size()) {
+ if (index < 0 || index >= compute_opts_->args.size() +
+ compute_opts_->captured_inputs.size()) {
return errors::InvalidArgument(
"Mismatch in number of function inputs.");
}
+
+ if (index >= compute_opts_->args.size()) {
+ // The function is calling for a captured input
+ *val =
+ compute_opts_->captured_inputs[index - compute_opts_->args.size()];
+ return Status::OK();
+ }
+
bool result =
- val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1),
+ val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1),
compute_opts_->arg_shapes.at(index));
if (!result) {
return errors::Internal("GetArg failed.");
@@ -248,7 +243,6 @@ class MapDefunOp : public AsyncOpKernel {
// Ensure alignment
*val = tensor::DeepCopy(*val);
}
-
return Status::OK();
}
diff --git a/tensorflow/core/kernels/dequantize_op.cc b/tensorflow/core/kernels/dequantize_op.cc
index 42fbf95cd3..28940e0849 100644
--- a/tensorflow/core/kernels/dequantize_op.cc
+++ b/tensorflow/core/kernels/dequantize_op.cc
@@ -96,8 +96,6 @@ class DequantizeOp : public OpKernel {
output);
}
} else if (mode_ == QUANTIZE_MODE_SCALED) {
- // TODO(pauldonnelly): Update QuantizeAndDequantizeV2 and
- // QuantizeAndDequantizeV3 to match this SCALED mode again.
const float scale_factor =
std::numeric_limits<T>::min() == 0
? (max_range / std::numeric_limits<T>::max())
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index fdb4c84c46..3979e4b53a 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -97,6 +97,13 @@ class PartitionedCallOp : public AsyncOpKernel {
OP_REQUIRES_ASYNC(ctx, fbody != nullptr,
errors::Internal("Could not find handle ", handle),
done);
+ OP_REQUIRES_ASYNC(
+ ctx, args.size() == fbody->arg_nodes.size(),
+ errors::InvalidArgument(
+ "Wrong number of arguments to the op; function expects ",
+ fbody->arg_nodes.size(), " but PartitionedCall received ",
+ args.size()),
+ done);
// We need to pass global op_registry as default_registry when creating
// graph. So that graph optimization passes can lookup all possible ops
// by name.
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 23d76986bf..678d675c4a 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -426,6 +426,12 @@ class AssignUpdateVariableOp : public OpKernel {
// ADD if value's refcount was 1.
mutex_lock ml(*variable->mu());
Tensor* var_tensor = variable->tensor();
+ OP_REQUIRES(context, var_tensor->shape().IsSameSize(value.shape()),
+ errors::InvalidArgument("Cannot update variable with shape ",
+ var_tensor->shape().DebugString(),
+ " using a Tensor with shape ",
+ value.shape().DebugString(),
+ ", shapes must be equal."));
OP_REQUIRES_OK(context,
PrepareToUpdateVariable<Device, T>(context, var_tensor));
functor::DenseUpdate<Device, T, Op> update_functor;
diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc
index 3a9803a052..92c73220d8 100644
--- a/tensorflow/core/kernels/string_util.cc
+++ b/tensorflow/core/kernels/string_util.cc
@@ -16,10 +16,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
-namespace {
-inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; }
-} // namespace
-
namespace tensorflow {
// Sets unit value based on str.
diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h
index 390cf57702..d40e93ea33 100644
--- a/tensorflow/core/kernels/string_util.h
+++ b/tensorflow/core/kernels/string_util.h
@@ -30,6 +30,9 @@ enum class UnicodeEncoding { UTF8 };
// TODO(edloper): Add support for: UTF32_CHAR, etc.
enum class CharUnit { BYTE, UTF8_CHAR };
+// Whether or not the given byte is the trailing byte of a UTF-8/16/32 char.
+inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; }
+
// Sets `encoding` based on `str`.
Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding);
@@ -40,6 +43,47 @@ Status ParseCharUnit(const string& str, CharUnit* unit);
// Result may be incorrect if the input string is not valid UTF-8.
int32 UTF8StrLen(const string& string);
+// Get the next UTF8 character position starting at the given position and
+// skipping the given number of characters. Position is a byte offset, and
+// should never be `null`. The function return true if successful. However, if
+// the end of the string is reached before the requested characters, then the
+// position will point to the end of string and this function will return false.
+template <typename T>
+bool ForwardNUTF8CharPositions(const StringPiece in,
+ const T num_utf8_chars_to_shift, T* pos) {
+ const size_t size = in.size();
+ T utf8_chars_counted = 0;
+ while (utf8_chars_counted < num_utf8_chars_to_shift && *pos < size) {
+ // move forward one utf-8 character
+    do {
+      ++*pos;
+    } while (*pos < size && IsTrailByte(in[*pos]));
+ ++utf8_chars_counted;
+ }
+ return utf8_chars_counted == num_utf8_chars_to_shift;
+}
+
+// Get the previous UTF8 character position starting at the given position and
+// skipping the given number of characters. Position is a byte offset with a
+// positive value, relative to the beginning of the string, and should never be
+// `null`. The function return true if successful. However, if the beginning of
+// the string is reached before the requested character, then the position will
+// point to the beginning of the string and this function will return false.
+template <typename T>
+bool BackNUTF8CharPositions(const StringPiece in,
+ const T num_utf8_chars_to_shift, T* pos) {
+ const size_t start = 0;
+ T utf8_chars_counted = 0;
+ while (utf8_chars_counted < num_utf8_chars_to_shift && (*pos > start)) {
+ // move back one utf-8 character
+ do {
+ --*pos;
+ } while (IsTrailByte(in[*pos]) && *pos > start);
+ ++utf8_chars_counted;
+ }
+ return utf8_chars_counted == num_utf8_chars_to_shift;
+}
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
index 07f1d6e767..93c427039d 100644
--- a/tensorflow/core/kernels/substr_op.cc
+++ b/tensorflow/core/kernels/substr_op.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/string_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
@@ -37,7 +38,11 @@ namespace tensorflow {
template <typename T>
class SubstrOp : public OpKernel {
public:
- using OpKernel::OpKernel;
+ explicit SubstrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string unit;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit));
+ OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_));
+ }
void Compute(OpKernelContext* context) override {
// Get inputs
@@ -69,11 +74,23 @@ class SubstrOp : public OpKernel {
tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
StringPiece in(input(i));
- OP_REQUIRES(
- context, FastBoundsCheck(std::abs(pos), in.size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for string",
- "b'", in, "' at index ", i));
- StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context, FastBoundsCheck(byte_pos, in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index ", i));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
output(i).assign(sub_in.data(), sub_in.size());
}
} else {
@@ -84,11 +101,23 @@ class SubstrOp : public OpKernel {
StringPiece in(input(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
- OP_REQUIRES(
- context, FastBoundsCheck(std::abs(pos), in.size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for string",
- "b'", in, "' at index ", i));
- StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context, FastBoundsCheck(byte_pos, in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index ", i));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
output(i).assign(sub_in.data(), sub_in.size());
}
}
@@ -151,12 +180,24 @@ class SubstrOp : public OpKernel {
StringPiece in(input_bcast(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
- OP_REQUIRES(
- context,
- FastBoundsCheck(std::abs(pos), input_bcast(i).size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for string",
- "b'", in, "' at index ", i));
- StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context,
+ FastBoundsCheck(byte_pos, input_bcast(i).size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index ", i));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
output(i).assign(sub_in.data(), sub_in.size());
}
break;
@@ -205,12 +246,24 @@ class SubstrOp : public OpKernel {
tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
const T len =
tensorflow::internal::SubtleMustCopy(len_bcast(i, j));
- OP_REQUIRES(
- context, FastBoundsCheck(std::abs(pos), in.size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for ",
- "string b'", in, "' at index (", i,
- ", ", j, ")"));
- StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context, FastBoundsCheck(byte_pos, in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index (",
+ i, ", ", j, ")"));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
output(i, j).assign(sub_in.data(), sub_in.size());
}
}
@@ -227,12 +280,73 @@ class SubstrOp : public OpKernel {
private:
// This adjusts the requested position. Note it does not perform any bound
// checks.
- T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
+ static inline T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
if (pos_requested < 0) {
return s.size() + pos_requested;
}
return pos_requested;
}
+
+ // Return true if successful; otherwise, return false if the `pos` argument
+ // is out of range in the string.
+ static inline bool UpdatePosAndLenForUtf8(const StringPiece in, T* pos,
+ T* len) {
+ if (*pos >= 0) {
+ return UpdatePositivePosAndLenForUtf8(in, *pos, *len, pos, len);
+ } else {
+ return UpdateNegativePosAndLenForUtf8(in, *pos, *len, pos, len);
+ }
+ }
+
+ static bool UpdatePositivePosAndLenForUtf8(const StringPiece in, const T pos,
+ const T len, T* char_pos,
+ T* char_len) {
+ *char_pos = 0;
+ // Determine byte position of the substring start.
+ if (!ForwardNUTF8CharPositions(in, pos, char_pos)) {
+ return false;
+ }
+ // Determine position of the end of the substring.
+ // The length will be capped at the end of the string, and we ignore whether
+ // the string had enough characters to handle it or not.
+ *char_len = *char_pos;
+ ForwardNUTF8CharPositions(in, len, char_len);
+ // The length in bytes is the position end of the substring less the start.
+ *char_len = *char_len - *char_pos;
+ return true;
+ }
+
+ // This function expects a negative position relative to the end of the
+ // string, but will update the character position to a positive number
+ // relative to the beginning of the string.
+ static bool UpdateNegativePosAndLenForUtf8(const StringPiece in, const T pos,
+ const T len, T* char_pos,
+ T* char_len) {
+ // Initially treat the length as position of the end of the substring.
+ *char_len = in.size();
+ // This is the number of character to skip from the end of the string to
+ // arrive at the position where the substring should end.
+ T utf8_chars_to_skip = -pos - len;
+ if (utf8_chars_to_skip < 0) {
+ utf8_chars_to_skip = 0;
+ }
+ // Find the byte position where the substring should end using the computed
+ // number of characters to skip.
+ if (!BackNUTF8CharPositions(in, utf8_chars_to_skip, char_len)) {
+ return false;
+ }
+ // Next, determine where the substring should begin. The number of chars to
+ // skip is the requested position minus the chars we've previously skipped.
+ *char_pos = *char_len;
+ if (!BackNUTF8CharPositions(in, -pos - utf8_chars_to_skip, char_pos)) {
+ return false;
+ }
+ // The length in bytes is the position end of the substring less the start.
+ *char_len = *char_len - *char_pos;
+ return true;
+ }
+
+ CharUnit unit_ = CharUnit::BYTE;
};
#define REGISTER_SUBSTR(type) \
diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc
index 2e07050260..ea6b1ed500 100644
--- a/tensorflow/core/kernels/substr_op_test.cc
+++ b/tensorflow/core/kernels/substr_op_test.cc
@@ -42,7 +42,7 @@ limitations under the License.
namespace tensorflow {
// Test data from the TensorFlow README.md.
-const char* lines[] = {
+const char* ascii_lines[] = {
"**TensorFlow** is an open source software library for numerical "
"computation using data flow graphs.",
"The graph nodes represent mathematical operations, while the graph edges "
@@ -64,17 +64,76 @@ const char* lines[] = {
"backwards compatibility guarantee like C++, Go, Java, JavaScript and "
"Swift."};
+const char* unicode_lines[] = {
+ "TensorFlow\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe4\xbd\xbf\xe7\x94\xa8\xe6"
+ "\x95\xb0\xe6\x8d\xae\xe6\xb5\x81\xe5\x9b\xbe\xe8\xbf\x9b\xe8\xa1\x8c\xe6"
+ "\x95\xb0\xe5\x80\xbc\xe8\xae\xa1\xe7\xae\x97\xe7\x9a\x84\xe5\xbc\x80\xe6"
+ "\xba\x90\xe8\xbd\xaf\xe4\xbb\xb6\xe5\xba\x93\xe3\x80\x82",
+ "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\x8a\x82\xe7\x82\xb9\xe8\xa1\xa8\xe7\xa4\xba"
+ "\xe6\x95\xb0\xe5\xad\xa6\xe8\xbf\x90\xe7\xae\x97\xef\xbc\x8c\xe8\x80\x8c"
+ "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\xbe\xb9\xe7\xbc\x98\xe8\xa1\xa8\xe7\xa4\xba"
+ "\xe5\x9c\xa8\xe5\xae\x83\xe4\xbb\xac\xe4\xb9\x8b\xe9\x97\xb4\xe6\xb5\x81"
+ "\xe5\x8a\xa8\xe7\x9a\x84\xe5\xa4\x9a\xe7\xbb\xb4\xe6\x95\xb0\xe6\x8d\xae"
+ "\xe9\x98\xb5\xe5\x88\x97\xef\xbc\x88\xe5\xbc\xa0\xe9\x87\x8f\xef\xbc\x89"
+ "\xe3\x80\x82",
+ "\xe8\xbf\x99\xe7\xa7\x8d\xe7\x81\xb5\xe6\xb4\xbb\xe7\x9a\x84\xe4\xbd\x93"
+ "\xe7\xb3\xbb\xe7\xbb\x93\xe6\x9e\x84\xe4\xbd\xbf\xe6\x82\xa8\xe5\x8f\xaf"
+ "\xe4\xbb\xa5\xe5\xb0\x86\xe8\xae\xa1\xe7\xae\x97\xe9\x83\xa8\xe7\xbd\xb2"
+ "\xe5\x88\xb0\xe6\xa1\x8c\xe9\x9d\xa2\xef\xbc\x8c\xe6\x9c\x8d\xe5\x8a\xa1"
+ "\xe5\x99\xa8\xe6\x88\x96\xe7\xa7\xbb\xe5\x8a\xa8\xe8\xae\xbe\xe5\xa4\x87"
+ "\xe4\xb8\xad\xe7\x9a\x84\xe4\xb8\x80\xe4\xb8\xaa\xe6\x88\x96\xe5\xa4\x9a"
+ "\xe4\xb8\xaa CPU\xe6\x88\x96GPU\xef\xbc\x8c\xe8\x80\x8c\xe6\x97\xa0\xe9"
+ "\x9c\x80\xe9\x87\x8d\xe5\x86\x99\xe4\xbb\xa3\xe7\xa0\x81\xe3\x80\x82",
+ "TensorFlow\xe8\xbf\x98\xe5\x8c\x85\xe6\x8b\xac[TensorBoard]\xef\xbc\x88"
+ "https://www.tensorflow.org/guide/summaries_and_tensorboard\xef\xbc\x89\xef"
+ "\xbc\x8c\xe8\xbf\x99\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe6\x95\xb0\xe6"
+ "\x8d\xae\xe5\x8f\xaf\xe8\xa7\x86\xe5\x8c\x96\xe5\xb7\xa5\xe5\x85\xb7\xe5"
+ "\x8c\x85\xe3\x80\x82",
+ "TensorFlow\xe6\x9c\x80\xe5\x88\x9d\xe6\x98\xaf\xe7\x94\xb1\xe7\xa0\x94\xe7"
+ "\xa9\xb6\xe4\xba\xba\xe5\x91\x98\xe5\x92\x8c\xe5\xb7\xa5\xe7\xa8\x8b\xe5"
+ "\xb8\x88\xe5\x9c\xa8Google\xe6\x9c\xba\xe5\x99\xa8\xe6\x99\xba\xe8\x83\xbd"
+ "\xe7\xa0\x94\xe7\xa9\xb6\xe7\xbb\x84\xe7\xbb\x87\xe7\x9a\x84Google Brain"
+ "\xe5\x9b\xa2\xe9\x98\x9f\xe5\xbc\x80\xe5\x8f\x91\xe7\x9a\x84\xef\xbc\x8c"
+ "\xe7\x9b\xae\xe7\x9a\x84\xe6\x98\xaf\xe8\xbf\x9b\xe8\xa1\x8c\xe6\x9c\xba"
+ "\xe5\x99\xa8\xe5\xad\xa6\xe4\xb9\xa0\xe5\x92\x8c\xe6\xb7\xb1\xe5\xba\xa6"
+ "\xe7\xa5\x9e\xe7\xbb\x8f\xe7\xbd\x91\xe7\xbb\x9c\xe7\xa0\x94\xe7\xa9\xb6"
+ "\xe3\x80\x82",
+ "\xe8\xaf\xa5\xe7\xb3\xbb\xe7\xbb\x9f\xe8\xb6\xb3\xe4\xbb\xa5\xe9\x80\x82"
+ "\xe7\x94\xa8\xe4\xba\x8e\xe5\x90\x84\xe7\xa7\x8d\xe5\x85\xb6\xe4\xbb\x96"
+ "\xe9\xa2\x86\xe5\x9f\x9f\xe4\xb9\x9f\xe6\x98\xaf\xe5\xa6\x82\xe6\xad\xa4"
+ "\xe3\x80\x82",
+ "TensorFlow\xe6\x8f\x90\xe4\xbe\x9b\xe7\xa8\xb3\xe5\xae\x9a\xe7\x9a\x84"
+ "Python API\xe5\x92\x8c C API\xef\xbc\x8c\xe4\xbb\xa5\xe5\x8f\x8a\xe6\xb2"
+ "\xa1\xe6\x9c\x89 API\xe5\x90\x91\xe5\x90\x8e\xe5\x85\xbc\xe5\xae\xb9\xe6"
+ "\x80\xa7\xe4\xbf\x9d\xe8\xaf\x81\xef\xbc\x8c\xe5\xa6\x82 C ++\xef\xbc\x8c"
+ "Go\xef\xbc\x8cJava\xef\xbc\x8cJavaScript\xe5\x92\x8cSwift\xe3\x80\x82",
+};
+
+const char* const kByteUnit = "BYTE";
+const char* const kUTF8Unit = "UTF8_CHAR";
+
Tensor GetTestTensor(int batch) {
- const int sz = TF_ARRAYSIZE(lines);
+ const int sz = TF_ARRAYSIZE(ascii_lines);
+ Tensor t(DT_STRING, {batch});
+ auto s = t.flat<string>();
+ for (int i = 0; i < batch; ++i) {
+ s(i) = ascii_lines[i % sz];
+ }
+ return t;
+}
+
+Tensor GetTestUTF8Tensor(int batch) {
+ const int sz = TF_ARRAYSIZE(unicode_lines);
Tensor t(DT_STRING, {batch});
auto s = t.flat<string>();
for (int i = 0; i < batch; ++i) {
- s(i) = lines[i % sz];
+ s(i) = unicode_lines[i % sz];
}
return t;
}
-Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) {
+Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len,
+ const char* const unit) {
Graph* g = new Graph(OpRegistry::Global());
Tensor position(DT_INT32, TensorShape({}));
position.flat<int32>().setConstant(pos);
@@ -85,21 +144,46 @@ Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) {
.Input(test::graph::Constant(g, input))
.Input(test::graph::Constant(g, position))
.Input(test::graph::Constant(g, length))
+ .Attr("unit", unit)
.Finalize(g, nullptr /* node */));
return g;
}
-void BM_Substr(int iters, int batch_size) {
+void BM_SubstrByte(int iters, int batch_size) {
testing::StopTiming();
testing::ItemsProcessed(static_cast<int64>(iters));
testing::UseRealTime();
Tensor input = GetTestTensor(batch_size);
- Graph* g = SetupSubstrGraph(input, 3, 30);
+ Graph* g = SetupSubstrGraph(input, 3, 30, kByteUnit);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+void BM_SubstrUTF8(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestUTF8Tensor(batch_size);
+ Graph* g = SetupSubstrGraph(input, 3, 30, kUTF8Unit);
testing::StartTiming();
test::Benchmark("cpu", g).Run(iters);
}
-BENCHMARK(BM_Substr)->Arg(1)->Arg(8)->Arg(16)->Arg(32)->Arg(64)->Arg(128)->Arg(
- 256);
+BENCHMARK(BM_SubstrByte)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+BENCHMARK(BM_SubstrUTF8)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 33f18ae13f..780c6f6448 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -30567,6 +30567,52 @@ op {
}
}
op {
+ name: "MapDefun"
+ input_arg {
+ name: "arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "captured_inputs"
+ type_list_attr: "Tcaptured"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Tcaptured"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+}
+op {
name: "MapIncompleteSize"
output_arg {
name: "size"
@@ -71844,6 +71890,48 @@ op {
}
}
op {
+ name: "Substr"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "pos"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "len"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "unit"
+ type: "string"
+ default_value {
+ s: "BYTE"
+ }
+ allowed_values {
+ list {
+ s: "BYTE"
+ s: "UTF8_CHAR"
+ }
+ }
+ }
+}
+op {
name: "Sum"
input_arg {
name: "input"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 889a6a4640..ec22eee874 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -903,14 +903,18 @@ REGISTER_OP("ModelDataset")
REGISTER_OP("MapDefun")
.Input("arguments: Targuments")
+ .Input("captured_inputs: Tcaptured")
.Output("output: output_types")
.Attr("Targuments: list(type) >= 1")
+ .Attr("Tcaptured: list(type) >= 0 = []")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("f: func")
.SetShapeFn([](shape_inference::InferenceContext* c) {
std::vector<PartialTensorShape> output_shapes;
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ DataTypeVector t_args;
+ TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args));
if (output_shapes.size() != c->num_outputs()) {
return errors::InvalidArgument(
"`output_shapes` must be the same length as `output_types` (",
@@ -918,10 +922,11 @@ REGISTER_OP("MapDefun")
}
int64 dim_zero = -1;
- for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) {
+ for (size_t i = 0; i < t_args.size(); ++i) {
if (c->Rank(c->input(i)) == 0) {
return errors::InvalidArgument(
- "Inputs must have rank at least 1. Input ", i, " has rank of 0");
+ "Arguments must have rank at least 1. Input ", i,
+ " has rank of 0.");
}
auto dim_handle = c->Dim(c->input(i), 0);
if (c->ValueKnown(dim_handle)) {
@@ -929,7 +934,7 @@ REGISTER_OP("MapDefun")
dim_zero = c->Value(dim_handle);
} else if (c->Value(dim_handle) != dim_zero) {
return errors::InvalidArgument(
- "Inputs must have the same dimension 0.");
+ "Arguments must have the same dimension 0.");
}
}
}
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 3eff728f03..a9e5e7824d 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1437,7 +1437,24 @@ REGISTER_OP("Bincount")
.Attr("T: {int32, int64, float32, float64}")
.Output("bins: T")
.SetShapeFn([](InferenceContext* c) {
- c->set_output(0, c->UnknownShapeOfRank(1));
+ ShapeHandle unused;
+ // The input `size` must be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+
+ const Tensor* size_tensor = c->input_tensor(1);
+ if (size_tensor == nullptr) {
+ // Return unknown shape if size is not known.
+ c->set_output(0, c->UnknownShapeOfRank(1));
+ return Status::OK();
+ }
+
+ // Return `[size]` shape if size is known.
+ int32 size_val = size_tensor->scalar<int32>()();
+ if (size_val < 0) {
+ return errors::InvalidArgument("size (", size_val,
+ ") must be non-negative");
+ }
+ c->set_output(0, c->MakeShape({size_val}));
return Status::OK();
});
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index be4c3ed2b6..05379a7d69 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -559,4 +559,16 @@ TEST(MathOpsTest, QuantizedAdd_ShapeFn) {
INFER_ERROR("must be rank 0", op, "?;?;?;?;[3];?");
INFER_ERROR("must be rank 0", op, "?;?;?;?;?;[4]");
}
+
+TEST(MathOpsTest, Bincount_ShapeFn) {
+ ShapeInferenceTestOp op("Bincount");
+
+ // size should be scalar.
+ INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?");
+
+ INFER_OK(op, "?;?;?", "[?]");
+ INFER_OK(op, "?;[];?", "[?]");
+ INFER_OK(op, "[?];[];?", "[?]");
+ INFER_OK(op, "[?];[];[?]", "[?]");
+}
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 0e58a9475d..0d8997c1bd 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -15262,6 +15262,10 @@ op {
name: "arguments"
type_list_attr: "Targuments"
}
+ input_arg {
+ name: "captured_inputs"
+ type_list_attr: "Tcaptured"
+ }
output_arg {
name: "output"
type_list_attr: "output_types"
@@ -15273,6 +15277,15 @@ op {
minimum: 1
}
attr {
+ name: "Tcaptured"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
name: "output_types"
type: "list(type)"
has_minimum: true
@@ -33748,6 +33761,19 @@ op {
}
}
}
+ attr {
+ name: "unit"
+ type: "string"
+ default_value {
+ s: "BYTE"
+ }
+ allowed_values {
+ list {
+ s: "BYTE"
+ s: "UTF8_CHAR"
+ }
+ }
+ }
}
op {
name: "Sum"
diff --git a/tensorflow/core/ops/stateless_random_grad.cc b/tensorflow/core/ops/stateless_random_grad.cc
new file mode 100644
index 0000000000..331e1d0152
--- /dev/null
+++ b/tensorflow/core/ops/stateless_random_grad.cc
@@ -0,0 +1,23 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/function.h"
+
+namespace tensorflow {
+REGISTER_OP_NO_GRADIENT("StatelessRandomUniform");
+REGISTER_OP_NO_GRADIENT("StatelessRandomNormal");
+REGISTER_OP_NO_GRADIENT("StatelessTruncatedNormal");
+REGISTER_OP_NO_GRADIENT("StatelessMultinomial");
+} // end namespace tensorflow
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index b4fbde54d9..94d71a4113 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -223,6 +223,7 @@ REGISTER_OP("Substr")
.Input("len: T")
.Output("output: string")
.Attr("T: {int32, int64}")
+ .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle pos_shape = c->input(1);
ShapeHandle len_shape = c->input(2);