aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/c/eager/tape.h118
-rw-r--r--tensorflow/c/python_api.cc7
-rw-r--r--tensorflow/c/python_api.h13
-rw-r--r--tensorflow/compiler/aot/tests/BUILD15
-rw-r--r--tensorflow/compiler/aot/tests/make_test_graphs.py8
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt13
-rw-r--r--tensorflow/compiler/aot/tests/tfcompile_test.cc25
-rw-r--r--tensorflow/compiler/aot/tfcompile.bzl1
-rw-r--r--tensorflow/compiler/jit/xla_cpu_device.cc6
-rw-r--r--tensorflow/compiler/jit/xla_device.cc10
-rw-r--r--tensorflow/compiler/jit/xla_device.h12
-rw-r--r--tensorflow/compiler/jit/xla_gpu_device.cc6
-rw-r--r--tensorflow/compiler/tests/BUILD15
-rw-r--r--tensorflow/compiler/tests/build_defs.bzl4
-rw-r--r--tensorflow/compiler/tests/gather_test.py14
-rw-r--r--tensorflow/compiler/tests/quantized_ops_test.py48
-rw-r--r--tensorflow/compiler/tests/random_ops_test.py3
-rw-r--r--tensorflow/compiler/tests/xla_test.py13
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_registry.h15
-rw-r--r--tensorflow/compiler/xla/BUILD1
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc6
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h3
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py19
-rw-r--r--tensorflow/compiler/xla/python/xla_client_test.py24
-rw-r--r--tensorflow/compiler/xla/service/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc1
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.cc42
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.h25
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc31
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc83
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.h13
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc24
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc1
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer.cc4
-rw-r--r--tensorflow/compiler/xla/service/stream_pool.cc10
-rw-r--r--tensorflow/compiler/xla/service/stream_pool_test.cc34
-rw-r--r--tensorflow/compiler/xla/shape_util.h4
-rw-r--r--tensorflow/compiler/xla/tests/BUILD22
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base.cc78
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base.h63
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc158
-rw-r--r--tensorflow/compiler/xla/tests/slice_test.cc1
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc7
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py25
-rw-r--r--tensorflow/contrib/lite/experimental/c/BUILD10
-rw-r--r--tensorflow/contrib/lite/g3doc/_book.yaml67
-rw-r--r--tensorflow/contrib/lite/g3doc/_index.yaml220
-rw-r--r--tensorflow/contrib/lite/g3doc/_project.yaml4
-rw-r--r--tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml6
-rw-r--r--tensorflow/contrib/lite/g3doc/devguide.md9
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.pngbin0 -> 10942 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.pngbin0 -> 578440 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.pngbin0 -> 7764 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.pngbin0 -> 16308 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.pngbin0 -> 20159 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.pngbin0 -> 35371 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.pngbin0 -> 12002 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.pngbin0 -> 25868 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.pngbin0 -> 7839 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.pngbin0 -> 27152 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.pngbin0 -> 17783 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.pngbin0 -> 17249 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/index.md2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD3
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h60
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h74
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h3
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc4
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc11
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc117
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc3
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc81
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc80
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc17
-rw-r--r--tensorflow/contrib/lite/toco/model.h1
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc9
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_context.py4
-rw-r--r--tensorflow/core/BUILD11
-rw-r--r--tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt19
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt38
-rw-r--r--tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt4
-rw-r--r--tensorflow/core/common_runtime/bfc_allocator.h2
-rw-r--r--tensorflow/core/common_runtime/device.h4
-rw-r--r--tensorflow/core/common_runtime/executor.cc6
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc37
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h3
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc56
-rw-r--r--tensorflow/core/example/feature_util.h5
-rw-r--r--tensorflow/core/framework/cancellation.cc10
-rw-r--r--tensorflow/core/framework/cancellation.h9
-rw-r--r--tensorflow/core/framework/cancellation_test.cc52
-rw-r--r--tensorflow/core/framework/device_base.h3
-rw-r--r--tensorflow/core/framework/tensor.cc112
-rw-r--r--tensorflow/core/framework/tensor.h2
-rw-r--r--tensorflow/core/framework/tensor_test.cc57
-rw-r--r--tensorflow/core/graph/testlib.h2
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD76
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_utils.cc196
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_utils.h108
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_utils_test.cc164
-rw-r--r--tensorflow/core/grappler/optimizers/data/fusion_utils.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc5
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc82
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h26
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils_test.cc82
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.cc5
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc341
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.h90
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc600
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc2
-rw-r--r--tensorflow/core/kernels/BUILD27
-rw-r--r--tensorflow/core/kernels/logging_ops.cc57
-rw-r--r--tensorflow/core/kernels/logging_ops_test.cc22
-rw-r--r--tensorflow/core/kernels/multinomial_op.cc2
-rw-r--r--tensorflow/core/kernels/queue_base.h4
-rw-r--r--tensorflow/core/kernels/reduction_ops_sum.cc10
-rw-r--r--tensorflow/core/kernels/string_format_op.cc65
-rw-r--r--tensorflow/core/kernels/string_format_op_test.cc66
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt61
-rw-r--r--tensorflow/core/ops/cudnn_rnn_ops.cc9
-rw-r--r--tensorflow/core/ops/cudnn_rnn_ops_test.cc11
-rw-r--r--tensorflow/core/ops/logging_ops.cc19
-rw-r--r--tensorflow/core/ops/ops.pbtxt61
-rw-r--r--tensorflow/core/ops/string_ops.cc27
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc3
-rw-r--r--tensorflow/core/platform/default/cord.h5
-rw-r--r--tensorflow/core/platform/file_system.h3
-rw-r--r--tensorflow/core/util/sparse/sparse_tensor.h14
-rw-r--r--tensorflow/go/op/wrappers.go926
-rw-r--r--tensorflow/python/BUILD2
-rw-r--r--tensorflow/python/client/tf_session.i4
-rw-r--r--tensorflow/python/compat/compat.py2
-rw-r--r--tensorflow/python/eager/BUILD2
-rw-r--r--tensorflow/python/eager/backprop.py5
-rw-r--r--tensorflow/python/eager/backprop_test.py12
-rw-r--r--tensorflow/python/eager/function.py43
-rw-r--r--tensorflow/python/eager/function_test.py47
-rw-r--r--tensorflow/python/eager/imperative_grad.py5
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc458
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py10
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py35
-rw-r--r--tensorflow/python/framework/function.py9
-rw-r--r--tensorflow/python/framework/ops.py4
-rw-r--r--tensorflow/python/framework/test_util.py60
-rw-r--r--tensorflow/python/keras/callbacks.py16
-rw-r--r--tensorflow/python/keras/engine/training_distributed.py39
-rw-r--r--tensorflow/python/kernel_tests/BUILD13
-rw-r--r--tensorflow/python/kernel_tests/logging_ops_test.py313
-rw-r--r--tensorflow/python/kernel_tests/string_format_op_test.py384
-rw-r--r--tensorflow/python/ops/control_flow_ops.py7
-rw-r--r--tensorflow/python/ops/image_ops_impl.py40
-rw-r--r--tensorflow/python/ops/logging_ops.py260
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py2
-rw-r--r--tensorflow/python/ops/string_ops.py84
-rw-r--r--tensorflow/stream_executor/device_description.h6
-rw-r--r--tensorflow/stream_executor/plugin_registry.h2
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.h11
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt4
165 files changed, 5841 insertions, 1492 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 49990b6249..41b5b8ff36 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -29,15 +29,8 @@ limitations under the License.
namespace tensorflow {
namespace eager {
-// Information about a tensor.
-struct TapeTensor {
- int64 id; // Expected to be unique in the lifetime of this process.
- DataType dtype;
- TensorShape shape;
-};
-
// Represents an entry in the tape.
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
struct OpTapeEntry {
string op_type;
std::vector<TapeTensor> output_tensor_info;
@@ -57,8 +50,8 @@ struct OpTapeEntry {
using TensorTape = gtl::FlatMap<int64, int64>;
// Map from operation-id to tape entry.
-template <typename BackwardFunction>
-using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>;
+template <typename BackwardFunction, typename TapeTensor>
+using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction, TapeTensor>>;
// Operations the tape needs to perform on tensors to do backpropagation. Named
// "vspace" because a subset of these are related to a vector space, such as
@@ -79,7 +72,7 @@ using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>;
// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
// specialization, which is blocked by quite a few things needing to loop back
// into python now.
-template <typename Gradient, typename BackwardFunction>
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class VSpace {
public:
virtual ~VSpace() {}
@@ -93,10 +86,10 @@ class VSpace {
gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;
// Returns a tensor of the right shape and dtype filled with zeros.
- virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0;
+ virtual Gradient* Zeros(const TapeTensor& tensor) const = 0;
// Returns a Tensor which is filled with ones and like the input.
- virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0;
+ virtual Gradient* Ones(const TapeTensor& tensor) const = 0;
// Calls the passed-in backward function.
virtual Status CallBackwardFunction(
@@ -114,7 +107,7 @@ class VSpace {
// Traces the execution of operations, doing eager garbage collection, and
// exporting a full trace so other code can do backpropagation. Not thread-safe.
-template <typename Gradient, typename BackwardFunction>
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class GradientTape {
public:
// If `persistent` is true, GradientTape will not eagerly delete backward
@@ -134,7 +127,7 @@ class GradientTape {
void Watch(int64 tensor_id);
void RecordOperation(
- const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+ const string& op_type, std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
@@ -146,17 +139,18 @@ class GradientTape {
// once) and produces the gradient of the target tensors with respect to the
// source tensors. The output gradients are used if not empty and not
// null. The result is populated with one tensor per target element.
- Status ComputeGradient(const VSpace<Gradient, BackwardFunction>& vspace,
- gtl::ArraySlice<int64> target_tensor_ids,
- gtl::ArraySlice<int64> source_tensor_id,
- gtl::ArraySlice<Gradient*> output_gradients,
- std::vector<Gradient*>* result);
+ Status ComputeGradient(
+ const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+ gtl::ArraySlice<int64> target_tensor_ids,
+ gtl::ArraySlice<int64> source_tensor_id,
+ gtl::ArraySlice<Gradient*> output_gradients,
+ std::vector<Gradient*>* result);
bool IsPersistent() const { return persistent_; }
private:
TensorTape tensor_tape_;
- OpTape<BackwardFunction> op_tape_;
+ OpTape<BackwardFunction, TapeTensor> op_tape_;
int64 next_op_id_{0};
// Map from tensor id to number of remaining usages (i.e. how many entries in
@@ -186,8 +180,8 @@ inline bool IsDtypeTrainable(DataType dtype) {
}
}
-template <typename Gradient, typename BackwardFunction>
-bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes) {
CHECK_EQ(tensor_ids.size(), dtypes.size());
@@ -201,14 +195,15 @@ bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
return false;
}
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
+ int64 tensor_id) {
tensor_tape_.emplace(tensor_id, -1);
}
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::RecordOperation(
- const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
+ const string& op_type, std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
@@ -229,16 +224,18 @@ void GradientTape<Gradient, BackwardFunction>::RecordOperation(
for (const TapeTensor& o : output_tensors) {
// Note: the tensor can have already been watched and hence be in the tape,
// so we cannot check that we're inserting it here.
- tensor_tape_[o.id] = op_id;
- tensor_usage_[o.id] = 1;
+ tensor_tape_[o.GetID()] = op_id;
+ tensor_usage_[o.GetID()] = 1;
tensors.push_back(o);
}
- op_tape_[op_id] = OpTapeEntry<BackwardFunction>{
- op_type, tensors, ids, backward_function, backward_function_deleter};
+ op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{
+ op_type, std::move(tensors), ids, backward_function,
+ backward_function_deleter};
}
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::DeleteTrace(
+ int64 tensor_id) {
auto it = tensor_usage_.find(tensor_id);
if (it == tensor_usage_.end()) {
return;
@@ -261,7 +258,7 @@ void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
auto op_it = op_tape_.find(op_id);
CHECK(op_it != op_tape_.end());
for (const auto& output : op_it->second.output_tensor_info) {
- if (tensor_usage_.find(output.id) != tensor_usage_.end()) {
+ if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) {
// Found a usage for an output, so cannot delete the op.
return;
}
@@ -304,9 +301,9 @@ void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
namespace {
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
struct BackpropInitialState {
- OpTape<BackwardFunction> op_tape;
+ OpTape<BackwardFunction, TapeTensor> op_tape;
// Map from tensor ID to how many references still exist for this tensor in
// the tape.
@@ -322,17 +319,17 @@ struct BackpropInitialState {
// If `persistent_tape` is false, op_tape is cleared and backwards functions
// not needed for gradient computation are deleted. Backwards functions that
// are needed, are copied and returned in BackpropInitialState.
-template <typename BackwardFunction>
-BackpropInitialState<BackwardFunction> PrepareBackprop(
+template <typename BackwardFunction, typename TapeTensor>
+BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop(
gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
- OpTape<BackwardFunction>* op_tape, const gtl::FlatSet<int64>& sources_set,
- bool persistent_tape) {
+ OpTape<BackwardFunction, TapeTensor>* op_tape,
+ const gtl::FlatSet<int64>& sources_set, bool persistent_tape) {
std::vector<int64> tensor_stack;
tensor_stack.reserve(target.size());
for (auto t : target) {
tensor_stack.push_back(t);
}
- BackpropInitialState<BackwardFunction> result;
+ BackpropInitialState<BackwardFunction, TapeTensor> result;
while (!tensor_stack.empty()) {
int64 tensor_id = tensor_stack.back();
tensor_stack.pop_back();
@@ -383,9 +380,9 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
return result;
}
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
std::vector<int64> InitialStack(
- const OpTape<BackwardFunction>& op_tape,
+ const OpTape<BackwardFunction, TapeTensor>& op_tape,
const gtl::FlatMap<int64, int64>& op_missing_tensor) {
std::vector<int64> result;
for (auto& op_entry : op_tape) {
@@ -396,13 +393,13 @@ std::vector<int64> InitialStack(
return result;
}
-template <typename Gradient, typename BackwardFunction>
-Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
- gtl::ArraySlice<int64> target_tensor_ids,
- gtl::ArraySlice<Gradient*> output_gradients,
- const TensorTape& tensor_tape,
- const OpTape<BackwardFunction>& op_tape,
- gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+Status InitialGradients(
+ const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+ gtl::ArraySlice<int64> target_tensor_ids,
+ gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
+ const OpTape<BackwardFunction, TapeTensor>& op_tape,
+ gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
for (int i = 0; i < target_tensor_ids.size(); ++i) {
const int64 id = target_tensor_ids[i];
if (output_gradients.empty() || output_gradients[i] == nullptr) {
@@ -416,11 +413,10 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
}
bool found = false;
for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
- if (op_it->second.output_tensor_info[j].id == id) {
+ if (op_it->second.output_tensor_info[j].GetID() == id) {
found = true;
(*result)[id].push_back(
- vspace.Ones(op_it->second.output_tensor_info[j].shape,
- op_it->second.output_tensor_info[j].dtype));
+ vspace.Ones(op_it->second.output_tensor_info[j]));
break;
}
}
@@ -469,16 +465,16 @@ gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() {
constexpr int kMinAggregateCount = 4;
constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
-template <typename Gradient, typename BackwardFunction>
-Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
- const VSpace<Gradient, BackwardFunction>& vspace,
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
+ const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
gtl::ArraySlice<int64> target_tensor_ids,
gtl::ArraySlice<int64> source_tensor_ids,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) {
gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(),
source_tensor_ids.end());
- BackpropInitialState<BackwardFunction> state = PrepareBackprop(
+ BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
std::vector<int64> op_stack =
InitialStack(state.op_tape, state.op_missing_tensor);
@@ -522,7 +518,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
out_gradients.reserve(trace.output_tensor_info.size());
bool any_gradient_nonzero = false;
for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
- const int64 id = trace.output_tensor_info[i].id;
+ const int64 id = trace.output_tensor_info[i].GetID();
auto grad_it = gradients.find(id);
if (grad_it == gradients.end()) {
auto func_name_it =
@@ -531,9 +527,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
func_name_it->second.find(i) != func_name_it->second.end()) {
out_gradients.push_back(nullptr);
} else {
- out_gradients.push_back(
- vspace.Zeros(trace.output_tensor_info[i].shape,
- trace.output_tensor_info[i].dtype));
+ out_gradients.push_back(vspace.Zeros(trace.output_tensor_info[i]));
}
} else {
any_gradient_nonzero = true;
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index 8486b585c8..247236b760 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -110,7 +110,7 @@ void ExtendSession(TF_Session* session, TF_Status* status) {
session->extend_before_run = false;
}
-std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
+std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
Node* node = &output.oper->node;
CppShapeInferenceResult::HandleData handle_data;
handle_data.set_is_set(true);
@@ -135,9 +135,8 @@ std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
return result;
}
-void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
- const void* proto, size_t proto_len,
- TF_Status* status) {
+void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
+ size_t proto_len, TF_Status* status) {
tensorflow::CppShapeInferenceResult::HandleData handle_data;
if (!handle_data.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h
index 4bcb5bde62..5cce84020b 100644
--- a/tensorflow/c/python_api.h
+++ b/tensorflow/c/python_api.h
@@ -54,16 +54,17 @@ void SetRequireShapeInferenceFns(TF_Graph* graph, bool require);
void ExtendSession(TF_Session* session, TF_Status* status);
// Returns the serialized CppShapeInferenceResult::HandleData proto for
-// `output` if its a resource tensor, or otherwise returns the empty string.
-std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output);
+// `output` if its a resource or variant tensor, or otherwise returns the empty
+// string.
+std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output);
// Sets `output` based on `proto`, which should be a serialized
-// CppShapeInferenceResult::HandleData proto.
+// CppShapeInferenceResult::HandleData proto. `output` should be a resource
+// or variant tensor.
// NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string
// because I couldn't get SWIG to work otherwise.
-void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
- const void* proto, size_t proto_len,
- TF_Status* status);
+void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
+ size_t proto_len, TF_Status* status);
} // namespace tensorflow
#endif // TENSORFLOW_C_PYTHON_API_H_
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 7a0932d44d..10fa33ab5e 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -25,6 +25,7 @@ test_suite(
":test_graph_tfmatmul_test",
":test_graph_tfmatmulandadd_test",
":test_graph_tfsplits_test",
+ ":test_graph_tftop_k_test",
":tfcompile_test",
],
)
@@ -42,6 +43,7 @@ py_binary(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
"//tensorflow/python:platform",
"//tensorflow/python:session",
"//tensorflow/python:training",
@@ -66,6 +68,7 @@ genrule(
"test_graph_tfmatmul.pb",
"test_graph_tfmatmulandadd.pb",
"test_graph_tfsplits.pb",
+ "test_graph_tftop_k.pb",
],
# Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
# GPUs which might be present. This is important because builds may run
@@ -208,6 +211,17 @@ tf_library(
],
)
+tf_library(
+ name = "test_graph_tftop_k",
+ testonly = 1,
+ config = "test_graph_tftop_k.config.pbtxt",
+ cpp_class = "TopKComp",
+ graph = "test_graph_tftop_k.pb",
+ tags = [
+ "manual",
+ ],
+)
+
tf_cc_test(
name = "tfcompile_test",
srcs = ["tfcompile_test.cc"],
@@ -226,6 +240,7 @@ tf_cc_test(
":test_graph_tfmatmulandadd",
":test_graph_tfmatmulandadd_with_profiling",
":test_graph_tfsplits",
+ ":test_graph_tftop_k",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto",
diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py
index 9ec7df163b..de135d7a23 100644
--- a/tensorflow/compiler/aot/tests/make_test_graphs.py
+++ b/tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
from tensorflow.python.training import saver as saver_lib
@@ -142,6 +143,12 @@ def tfsplits(_):
array_ops.identity(y, name='result')
+def tftop_k(_):
+ x = array_ops.placeholder(dtypes.int32, shape=[5], name='x')
+ output = nn_ops.top_k(x, 2, name='values')
+ array_ops.identity(output[1], name='indices')
+
+
def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
@@ -163,6 +170,7 @@ def main(_):
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
write_graph(tfsplits, FLAGS.out_dir)
+ write_graph(tftop_k, FLAGS.out_dir)
if __name__ == '__main__':
diff --git a/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt
new file mode 100644
index 0000000000..6b4ac2d7cb
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt
@@ -0,0 +1,13 @@
+# Text form of tensorflow.tf2xla.Config proto.
+feed {
+ id { node_name: "x" }
+ shape {
+ dim { size: 5 }
+ }
+}
+fetch {
+ id { node_name: "values" }
+}
+fetch {
+ id { node_name: "indices" }
+}
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index 7ac90fb8a9..f10852c785 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -448,6 +449,30 @@ TEST(TFCompileTest, Splits) {
EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4);
}
+TEST(TFCompileTest, TopK) {
+ Eigen::ThreadPool tp(1);
+ Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
+
+ TopKComp fn;
+
+ fn.set_thread_pool(&device);
+ // x = [4, 1, 4, 4, 3]
+ fn.arg0(0) = 4;
+ fn.arg0(1) = 1;
+ fn.arg0(2) = 4;
+ fn.arg0(3) = 4;
+ fn.arg0(4) = 3;
+
+ EXPECT_TRUE(fn.Run());
+ EXPECT_EQ(fn.error_msg(), "");
+ const int32 expected_values[] = {4, 4};
+ const int32 expected_indices[] = {0, 2};
+ EXPECT_EQ(expected_values[0], fn.result0(0));
+ EXPECT_EQ(expected_values[1], fn.result0(1));
+ EXPECT_EQ(expected_indices[0], fn.result1(0));
+ EXPECT_EQ(expected_indices[1], fn.result1(1));
+}
+
TEST(TFCompileTest, AssertEqAndReturnDiff) {
// Assert is converted into a no-op in XLA, so there is no failure even if the
// two args are different.
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 792b7fe14a..859c84bb91 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -273,6 +273,7 @@ def tf_library(
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
+ "//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort",
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index 1afc305abe..e26fa27b31 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -65,9 +65,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
// Kernel registrations
-constexpr std::array<DataType, 9> kAllXlaCpuTypes = {
- {DT_UINT8, DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 12> kAllXlaCpuTypes = {
+ {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
+ DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 51797def04..32fce2bf94 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -434,6 +434,16 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
return status;
}
+void XlaDevice::SetRequiresSyncOnCompletion(bool sync_on_completion) {
+ mutex_lock lock(mu_);
+ sync_on_completion_ = sync_on_completion;
+}
+
+bool XlaDevice::RequiresSyncOnCompletion() const {
+ mutex_lock lock(mu_);
+ return sync_on_completion_;
+}
+
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
const char* jit_device) {
// Any op assigned to the device that isn't rewritten by the graph rewriter
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index 92891ffa8c..0f06b3fc80 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -151,6 +151,12 @@ class XlaDevice : public LocalDevice {
// information for GPU and TPU devices.
Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);
+ // Instructs this XlaDevice to return 'sync_on_completion' for
+ // RequiresSyncOnCompletion().
+ void SetRequiresSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_);
+
+ bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);
+
private:
xla::LocalClient* client() const;
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
@@ -165,7 +171,7 @@ class XlaDevice : public LocalDevice {
static Status GetMetadataFromDevice(DeviceBase* device,
const XlaDevice::Metadata** metadata);
- mutex mu_;
+ mutable mutex mu_;
// The metadata of this XlaDevice.
const Metadata xla_metadata_;
// Which hardware device in the client's platform this XlaDevice controls.
@@ -207,6 +213,10 @@ class XlaDevice : public LocalDevice {
// Thread pool used for running closures
std::unique_ptr<thread::ThreadPool> thread_pool_;
+
+ // True if the device requires XlaDevice::Sync to be called on completion
+ // regardless of status.
+ bool sync_on_completion_ GUARDED_BY(mu_) = false;
};
// Builds OpKernel registrations on 'device' for the JIT operators
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index 4cf556524d..c386984930 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -74,9 +74,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
// Kernel registrations
-constexpr std::array<DataType, 10> kAllXlaGpuTypes = {
- {DT_UINT8, DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
+constexpr std::array<DataType, 13> kAllXlaGpuTypes = {
+ {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
+ DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 97ed554171..3cf74fa788 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -978,7 +978,7 @@ tf_xla_py_test(
name = "gather_test",
size = "medium",
srcs = ["gather_test.py"],
- tags = ["noasan"], # times out, http://b/78599043
+ tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@@ -1198,6 +1198,19 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "quantized_ops_test",
+ size = "small",
+ srcs = ["quantized_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "xla_ops_test",
size = "medium",
srcs = ["xla_ops_test.py"],
diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl
index b8583c9bdb..1d3979b21b 100644
--- a/tensorflow/compiler/tests/build_defs.bzl
+++ b/tensorflow/compiler/tests/build_defs.bzl
@@ -62,12 +62,12 @@ def tf_xla_py_test(
if backend == "cpu":
backend_args += [
"--test_device=XLA_CPU",
- "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_INT8,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
+ "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
]
elif backend == "gpu":
backend_args += [
"--test_device=XLA_GPU",
- "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_INT8,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16",
+ "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16",
]
backend_tags += tf_cuda_tests_tags()
elif backend in plugins:
diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py
index 089d95daab..a38e1edafe 100644
--- a/tensorflow/compiler/tests/gather_test.py
+++ b/tensorflow/compiler/tests/gather_test.py
@@ -51,7 +51,7 @@ class GatherTest(xla_test.XLATestCase):
indices_tf = constant_op.constant(indices)
gather_t = array_ops.gather(params, indices_tf)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- np_val = params_np[indices]
+ np_val = constant_op.constant(params_np[indices])
self.assertAllEqual(np_val, gather_val)
def testScalar2D(self):
@@ -65,7 +65,8 @@ class GatherTest(xla_test.XLATestCase):
indices = constant_op.constant(2)
gather_t = array_ops.gather(params, indices, axis=axis)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- expected = np.take(params_np, 2, axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, 2, axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testSimpleTwoD32(self):
@@ -80,7 +81,8 @@ class GatherTest(xla_test.XLATestCase):
indices = constant_op.constant([0, 1, 0, 2])
gather_t = array_ops.gather(params, indices, axis=axis)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- expected = np.take(params_np, [0, 1, 0, 2], axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, [0, 1, 0, 2], axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testSimpleTwoD32_Int64Indices(self):
@@ -103,7 +105,8 @@ class GatherTest(xla_test.XLATestCase):
params: params_np,
indices: indices_np
})
- expected = np.take(params_np, [0, 1, 0, 2], axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, [0, 1, 0, 2], axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testHigherRank(self):
@@ -119,7 +122,8 @@ class GatherTest(xla_test.XLATestCase):
tf_indices = constant_op.constant(indices, dtype=dtypes.int32)
gather = array_ops.gather(tf_params, tf_indices, axis=axis)
gather_value = sess.run(gather, feed_dict={tf_params: params})
- gather_np = np.take(params, indices, axis=axis)
+ gather_np = constant_op.constant(
+ np.take(params, indices, axis=axis), dtype)
self.assertAllEqual(gather_np, gather_value)
def testIndicesWithDifferentDimensions(self):
diff --git a/tensorflow/compiler/tests/quantized_ops_test.py b/tensorflow/compiler/tests/quantized_ops_test.py
new file mode 100644
index 0000000000..80c338513b
--- /dev/null
+++ b/tensorflow/compiler/tests/quantized_ops_test.py
@@ -0,0 +1,48 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for quantized operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+class QuantizedOpsTest(xla_test.XLATestCase):
+
+ # Verify that quantized types can be clustered by XLA.
+ def testQuantizedTypeRoundtrip(self):
+ with self.cached_session() as session:
+ for dtype in self.quantized_tf_types:
+ in_values = np.array([1, 2, 3, 4, 5, 6])
+ expected = [[1, 2], [3, 4], [5, 6]]
+ with self.test_scope():
+ p = array_ops.placeholder(dtype=dtypes.int32)
+ x = math_ops.cast(p, dtype)
+ x = array_ops.reshape(x, [3, 2])
+
+ value = session.run(x, {p: in_values})
+ self.assertAllEqual(value, expected)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index c423fa5004..36ef6ed5fe 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -76,7 +76,8 @@ class RandomOpsTest(xla_test.XLATestCase):
for dtype in self._random_types():
# TODO (b/112272078): enable bfloat16 for CPU and GPU when the bug is
# fixed.
- if (self.device in ["XLA_GPU", "XLA_CPU"]) and (dtype == dtypes.bfloat16):
+ if (self.device in ["XLA_GPU", "XLA_CPU"
+ ]) and (dtype in [dtypes.bfloat16, dtypes.half]):
continue
with self.cached_session() as sess:
with self.test_scope():
diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py
index df5c81243a..98a41981cf 100644
--- a/tensorflow/compiler/tests/xla_test.py
+++ b/tensorflow/compiler/tests/xla_test.py
@@ -97,9 +97,16 @@ class XLATestCase(test.TestCase):
])
self._numeric_tf_types = set(
self.int_tf_types | self._float_tf_types | self.complex_tf_types)
-
- self._all_types = set(
- [dtype.as_numpy_dtype for dtype in self._all_tf_types])
+ self.quantized_tf_types = set(
+ dtype for dtype in self._all_tf_types if dtype.is_quantized)
+
+ # Quantized types don't have a numpy equivalent, include them in
+ # all_tf_types but not in all_types.
+ # TODO(b/115960798): Parametrize tests on TF types instead of numpy types
+ # and remove all_types.
+ self._all_types = set(dtype.as_numpy_dtype
+ for dtype in self._all_tf_types
+ if not dtype.is_quantized)
self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types])
self.signed_int_types = set(dtype.as_numpy_dtype
for dtype in self.int_tf_types
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index a4b624820a..4b2c2bacd6 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -51,13 +51,14 @@ constexpr std::array<DataType, 11> kNumericTypes = {
{DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF,
DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}};
-constexpr std::array<DataType, 11> kCpuAllTypes = {
- {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF,
- DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
-
-constexpr std::array<DataType, 12> kGpuAllTypes = {
- {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF,
- DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
+constexpr std::array<DataType, 14> kCpuAllTypes = {
+ {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
+ DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
+
+constexpr std::array<DataType, 15> kGpuAllTypes = {
+ {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
+ DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL,
+ DT_BFLOAT16}};
// Class that manages registrations of operators and devices for the XLA JIT.
// Not thread-safe.
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index ef70c1f8ac..cc7390c6e6 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -245,6 +245,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 9da5dc0d2d..cd5fd33029 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -469,9 +469,11 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated(
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding,
- lhs_dilation, rhs_dilation, dimension_numbers);
+ lhs_dilation, rhs_dilation, dimension_numbers,
+ feature_group_count);
}
LocalOp LocalComputationBuilder::ConvertElementType(
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 1d5dfe5911..2166bb6721 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -248,7 +248,8 @@ class LocalComputationBuilder {
absl::Span<const std::pair<int64, int64> > padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count);
LocalOp ConvertElementType(const LocalOp& operand,
PrimitiveType new_element_type);
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index fa4366ff07..bb303c5678 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -1109,7 +1109,7 @@ class ComputationBuilder(object):
dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
return self._client.DotGeneral(lhs, rhs, dimension_numbers)
- def Conv(self, lhs, rhs, window_strides, padding):
+ def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1):
"""Enqueues a Conv operation onto the computation.
Args:
@@ -1117,6 +1117,7 @@ class ComputationBuilder(object):
rhs: LocalOp for the rank N+2 array of kernel weights.
window_strides: length-N array-like of integer kernel strides.
padding: PaddingType representing either 'SAME' or 'VALID' padding.
+ feature_group_count: number of feature groups for grouped convolution.
Returns: a LocalOp representing the Conv operation.
"""
@@ -1125,10 +1126,11 @@ class ComputationBuilder(object):
self.GetShape(rhs).dimensions()[2:], window_strides)
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (),
- (), dimension_numbers)
+ (), dimension_numbers,
+ feature_group_count)
def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding,
- lhs_dilation, rhs_dilation):
+ lhs_dilation, rhs_dilation, feature_group_count=1):
"""Enqueues a ConvWithGeneralPadding operation onto the computation.
Args:
@@ -1138,6 +1140,7 @@ class ComputationBuilder(object):
padding: length-N array-like of pairs of integers of (low, high) padding.
lhs_dilation: length-N array-like of dilation factors.
rhs_dilation: length-N array-like of dilation factors.
+ feature_group_count: number of feature groups for grouped convolution.
Returns:
A ComputationdataHandle representing the added ConvWithGeneralPadding op.
@@ -1145,7 +1148,8 @@ class ComputationBuilder(object):
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation,
- dimension_numbers)
+ dimension_numbers,
+ feature_group_count)
def _GetConvDimensionNumbers(self, num_spatial_dims):
"""Create ConvolutionDimensionNumbers proto for convolutions."""
@@ -1163,7 +1167,8 @@ class ComputationBuilder(object):
return dimension_numbers
def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation,
- rhs_dilation, dimension_numbers):
+ rhs_dilation, dimension_numbers,
+ feature_group_count=1):
"""Enqueues a ConvGeneralDilated operation onto the computation.
Args:
@@ -1190,6 +1195,7 @@ class ComputationBuilder(object):
labels appear in the rhs_spec string, so that window_strides[0] is
matched with the dimension corresponding to the first character
appearing in rhs_spec that is not 'I' or 'O'.
+ feature_group_count: number of feature groups for grouped convolution.
Returns: a LocalOp representing the ConvGenralDilated operation.
"""
@@ -1215,7 +1221,8 @@ class ComputationBuilder(object):
key=lambda i: rhs_spec.index(out_spec[i])))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation,
- dimension_numbers)
+ dimension_numbers,
+ feature_group_count)
def Sort(self, operand, dimension=-1):
"""Enqueues a sort operation onto the computation."""
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index fd98e19457..82103f0313 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -661,6 +661,30 @@ class SingleOpTest(LocalComputationTest):
[40., 50., 0.]]]])
self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2)))
+ def testConvGeneralDilatedGroupedConvolutionF32(self):
+ c = self._NewComputation()
+ a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
+ lhs = a(1, 2, 2, 3)
+ rhs = a(2, 1, 1, 2) * 10
+ strides = [1, 1]
+ pads = [(1, 0), (0, 1)]
+ lhs_dilation = (2, 1)
+ rhs_dilation = (1, 1)
+ dimension_numbers = ("NCHW", "OIHW", "NCHW")
+ feature_group_count = 2
+ c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs),
+ strides, pads, lhs_dilation, rhs_dilation,
+ dimension_numbers, feature_group_count)
+ result = np.array([[[[0., 0., 0.],
+ [10., 20., 0.],
+ [0., 0., 0.],
+ [40., 50., 0.]],
+ [[0., 0., 0.],
+ [330., 380., 160.],
+ [0., 0., 0.],
+ [480., 530., 220.]]]])
+ self._ExecuteAndCompareClose(c, expected=result)
+
def testBooleanNot(self):
c = self._NewComputation()
arr = NumpyArrayBool([True, False, True])
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 4b183b4350..2bc50c70cf 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2605,7 +2605,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index b3e4fab727..bf627986a5 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -631,7 +631,7 @@ cc_library(
copts = runtime_copts(),
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/core:lib",
+ "//tensorflow/core:framework_lite",
"//third_party/eigen3",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
index cef5420f00..e0e7deb98e 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
@@ -23,7 +23,6 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/dynamic_annotations.h"
-#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 3a23ac1d63..85f3682a5a 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -29,21 +29,51 @@ limitations under the License.
namespace xla {
namespace gpu {
-using se::dnn::AlgorithmDesc;
+ConvolutionThunk::ConvolutionThunk(
+ const HloCustomCallInstruction* cudnn_call,
+ std::vector<BufferAllocation::Slice> operand_slices,
+ BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice,
+ BufferAllocation::Slice tuple_result_slice)
+ : Thunk(Kind::kConvolution, cudnn_call),
+ cudnn_call_(cudnn_call),
+ operand_buffers_(std::move(operand_slices)),
+ result_buffer_(result_slice),
+ scratch_buffer_(scratch_slice),
+ tuple_result_buffer_(tuple_result_slice) {}
Status ConvolutionThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
CudnnConvParams params;
+ TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params));
+
+ switch (params.kind) {
+ case CudnnConvKind::kForward:
+ params.input_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[0]);
+ params.filter_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[1]);
+ params.output_buf = buffer_allocations.GetDeviceAddress(result_buffer_);
+ break;
+ case CudnnConvKind::kBackwardInput:
+ params.input_buf = buffer_allocations.GetDeviceAddress(result_buffer_);
+ params.filter_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[1]);
+ params.output_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[0]);
+ break;
+ case CudnnConvKind::kBackwardFilter:
+ params.input_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[0]);
+ params.filter_buf = buffer_allocations.GetDeviceAddress(result_buffer_);
+ params.output_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[1]);
+ break;
+ }
- params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_);
- params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_);
- params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_);
se::DeviceMemoryBase scratch =
buffer_allocations.GetDeviceAddress(scratch_buffer_);
- TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params));
-
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream));
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index d7d1f91fba..f53bc54198 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -42,24 +42,12 @@ class ConvolutionThunk : public Thunk {
// Constructs a thunk for launching a DNN convolution. When run, it will
// write a tuple (result, scratch_memory) into `tuple_result_buffer`.
//
- // Note that "output" here doesn't refer to the output from running this
- // thunk, but rather to the "output" of a hypothetical forward convolution
- // that corresponds to this input+filter+output triple. That is, the result
- // generated by this thunk is "output" for forward convs, "input" for
- // backward-input convs, and "filter" for backward-filter convs.
+ // operand_slices should be in the same order as cudnn_call->operands().
ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
- BufferAllocation::Slice input_slice,
- BufferAllocation::Slice filter_slice,
- BufferAllocation::Slice output_slice,
+ std::vector<BufferAllocation::Slice> operand_slices,
+ BufferAllocation::Slice result_slice,
BufferAllocation::Slice scratch_slice,
- BufferAllocation::Slice tuple_result_slice)
- : Thunk(Kind::kConvolution, cudnn_call),
- cudnn_call_(cudnn_call),
- input_buffer_(std::move(input_slice)),
- filter_buffer_(std::move(filter_slice)),
- output_buffer_(std::move(output_slice)),
- scratch_buffer_(std::move(scratch_slice)),
- tuple_result_buffer_(std::move(tuple_result_slice)) {}
+ BufferAllocation::Slice tuple_result_slice);
ConvolutionThunk(const ConvolutionThunk&) = delete;
ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
@@ -71,9 +59,8 @@ class ConvolutionThunk : public Thunk {
private:
const HloCustomCallInstruction* cudnn_call_;
- BufferAllocation::Slice input_buffer_;
- BufferAllocation::Slice filter_buffer_;
- BufferAllocation::Slice output_buffer_;
+ std::vector<BufferAllocation::Slice> operand_buffers_;
+ BufferAllocation::Slice result_buffer_;
BufferAllocation::Slice scratch_buffer_;
BufferAllocation::Slice tuple_result_buffer_;
};
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index b669881026..c792dd2ddb 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -465,35 +465,18 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
if (IsCustomCallToDnnConvolution(*custom_call)) {
const auto& assn = ir_emitter_context_->buffer_assignment();
- auto lhs_slice = GetAllocationSlice(*custom_call->operand(0));
- auto rhs_slice = GetAllocationSlice(*custom_call->operand(1));
+ std::vector<BufferAllocation::Slice> operand_slices;
+ operand_slices.reserve(custom_call->operand_count());
+ for (const auto* operand : custom_call->operands()) {
+ operand_slices.push_back(GetAllocationSlice(*operand));
+ }
auto tuple_result_slice = GetAllocationSlice(*custom_call);
auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
- const auto& target = custom_call->custom_call_target();
- BufferAllocation::Slice input_slice, filter_slice, output_slice;
-
- if (target == kCudnnConvForwardCallTarget) {
- input_slice = lhs_slice;
- filter_slice = rhs_slice;
- output_slice = conv_result_slice;
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- input_slice = conv_result_slice;
- filter_slice = rhs_slice;
- output_slice = lhs_slice;
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- input_slice = lhs_slice;
- filter_slice = conv_result_slice;
- output_slice = rhs_slice;
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << custom_call->custom_call_target();
- }
-
thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>(
- Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice,
- output_slice, scratch_slice, tuple_result_slice));
+ Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices),
+ conv_result_slice, scratch_slice, tuple_result_slice));
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 3bc2d13781..735804e827 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -63,6 +63,7 @@ class HloModule {
// tests). The versioned handle is used by the service in the compilation
// cache. A default configuration is created for this module.
explicit HloModule(const string& name, const HloModuleConfig& config);
+ virtual ~HloModule() {}
// Adds an entry computation to the module. A module can only have one entry
// computation. Returns a pointer to the newly added computation.
@@ -87,6 +88,7 @@ class HloModule {
const std::unordered_map<HloComputation*, HloComputation*>& replacements);
const string& name() const { return name_; }
+ void set_name(string name) { name_ = std::move(name); }
// Returns a deep copy of this module including all computations.
std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
@@ -255,7 +257,7 @@ class HloModule {
std::unique_ptr<HloComputation> computation, bool is_entry,
bool uniquify_identifiers);
- const string name_;
+ string name_;
HloModuleConfig config_;
HloComputation* entry_computation_ = nullptr;
std::vector<std::unique_ptr<HloComputation>> computations_;
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 11caa89c54..37197b273b 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -64,14 +64,11 @@ class HloParser {
public:
using LocTy = HloLexer::LocTy;
- explicit HloParser(absl::string_view str, const HloModuleConfig& config)
- : lexer_(str), config_(config) {}
+ explicit HloParser(absl::string_view str) : lexer_(str) {}
- // Runs the parser. Returns false if an error occurred.
- bool Run();
-
- // Returns the parsed HloModule.
- std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }
+ // Runs the parser and constructs the resulting HLO in the given (empty)
+ // HloModule. Returns false if an error occurred.
+ bool Run(HloModule* module);
// Returns the error information.
string GetError() const { return StrJoin(error_, "\n"); }
@@ -98,8 +95,8 @@ class HloParser {
const string& name, const optional<Shape>& shape = nullopt);
// ParseXXX returns false if an error occurred.
- bool ParseHloModule();
- bool ParseComputations();
+ bool ParseHloModule(HloModule* module);
+ bool ParseComputations(HloModule* module);
bool ParseComputation(HloComputation** entry_computation);
bool ParseInstructionList(HloComputation::Builder* builder,
string* root_name);
@@ -293,9 +290,7 @@ class HloParser {
computation_pool_;
HloLexer lexer_;
- std::unique_ptr<HloModule> module_;
std::vector<std::unique_ptr<HloComputation>> computations_;
- const HloModuleConfig config_;
std::vector<string> error_;
// Function that gets invoked when we try to resolve an instruction
@@ -349,9 +344,9 @@ bool HloParser::TokenError(absl::string_view msg) {
return Error(lexer_.GetLoc(), msg);
}
-bool HloParser::Run() {
+bool HloParser::Run(HloModule* module) {
lexer_.Lex();
- return ParseHloModule();
+ return ParseHloModule(module);
}
std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
@@ -366,7 +361,7 @@ std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
}
// ::= 'HloModule' name computations
-bool HloParser::ParseHloModule() {
+bool HloParser::ParseHloModule(HloModule* module) {
if (lexer_.GetKind() != TokKind::kw_HloModule) {
return TokenError("expects HloModule");
}
@@ -385,22 +380,20 @@ bool HloParser::ParseHloModule() {
return false;
}
- module_ = absl::make_unique<HloModule>(name, config_);
-
- if (!ParseComputations()) {
+ module->set_name(name);
+ if (!ParseComputations(module)) {
return false;
}
if (is_scheduled.has_value() && *is_scheduled) {
- TF_CHECK_OK(
- module_->set_schedule(ScheduleFromInstructionOrder(module_.get())));
+ TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
}
return true;
}
// computations ::= (computation)+
-bool HloParser::ParseComputations() {
+bool HloParser::ParseComputations(HloModule* module) {
HloComputation* entry_computation = nullptr;
do {
if (!ParseComputation(&entry_computation)) {
@@ -416,21 +409,20 @@ bool HloParser::ParseComputations() {
if ((entry_computation != nullptr &&
computations_[i].get() != entry_computation) ||
(entry_computation == nullptr && i != computations_.size() - 1)) {
- module_->AddEmbeddedComputation(std::move(computations_[i]));
+ module->AddEmbeddedComputation(std::move(computations_[i]));
continue;
}
- auto computation =
- module_->AddEntryComputation(std::move(computations_[i]));
+ auto computation = module->AddEntryComputation(std::move(computations_[i]));
// The parameters and result layouts were set to default layout. Here we
// set the layouts to what the hlo text says.
for (int p = 0; p < computation->num_parameters(); p++) {
const Shape& param_shape = computation->parameter_instruction(p)->shape();
- TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_parameter_layout(p)
->CopyLayoutFromShape(param_shape));
}
const Shape& result_shape = computation->root_instruction()->shape();
- TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_result_layout()
->CopyLayoutFromShape(result_shape));
}
@@ -3247,53 +3239,62 @@ Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder,
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
absl::string_view str, const HloModuleConfig& config) {
- HloParser parser(str, config);
- if (!parser.Run()) {
+ auto module = absl::make_unique<HloModule>(/*name=*/"", config);
+ HloParser parser(str);
+ if (!parser.Run(module.get())) {
return InvalidArgument("Syntax error:\n%s", parser.GetError());
}
- return parser.ConsumeHloModule();
+ return std::move(module);
}
StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) {
- HloModuleConfig config;
- return ParseHloString(str, config);
+ auto module = absl::make_unique<HloModule>(/*name=*/"", HloModuleConfig());
+ HloParser parser(str);
+ if (!parser.Run(module.get())) {
+ return InvalidArgument("Syntax error:\n%s", parser.GetError());
+ }
+ return std::move(module);
+}
+
+Status ParseHloString(absl::string_view str, HloModule* module) {
+ TF_RET_CHECK(module->computation_count() == 0);
+ HloParser parser(str);
+ if (!parser.Run(module)) {
+ return InvalidArgument("Syntax error:\n%s", parser.GetError());
+ }
+ return Status::OK();
}
StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
absl::string_view str, absl::string_view name) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
auto builder = absl::make_unique<HloComputation::Builder>(string(name));
string root_name;
TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name));
std::unique_ptr<HloComputation> computation = builder->Build();
- auto module = absl::make_unique<HloModule>(string(name), config);
+ auto module = absl::make_unique<HloModule>(string(name), HloModuleConfig());
module->AddEntryComputation(std::move(computation));
return std::move(module);
}
StatusOr<HloSharding> ParseSharding(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseShardingOnly();
}
StatusOr<Window> ParseWindow(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseWindowOnly();
}
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseConvolutionDimensionNumbersOnly();
}
StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParsePaddingConfigOnly();
}
diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index 1882a184da..3696035514 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -30,18 +30,23 @@ namespace xla {
// For details about the syntax accepted by this parser, see
// g3doc/hlo_parser.md.
-// The api of the hlo parser. Given a string in the HloModule::ToString()
-// format, parses the string and creates a HloModule with the given config.
+// Given a string in the HloModule::ToString() format, parses the string and
+// creates a HloModule with the given config.
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
absl::string_view str, const HloModuleConfig& config);
+// Given a string in the HloModule::ToString() format, parses the string and
+// builds the HloModule in place at the given module pointer. 'module' must
+// point to an empty module (no computations).
+Status ParseHloString(absl::string_view str, HloModule* module);
+
// Parses the text for a single HLO operation into an HLO module with a function
// that runs that operation (with the same parameters) as its entry computation.
StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
absl::string_view str, absl::string_view name = "single_op");
-// The api of the hlo parser. Given a string in the HloModule::ToString()
-// format, parses the string and creates a HloModule with default config.
+// Given a string in the HloModule::ToString() format, parses the string and
+// creates a HloModule with default config.
StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str);
// Parses the result of HloSharding::ToString(), e.g. "{replicated}".
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
index e16b4d4c0a..ee8cb12b23 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
@@ -19,21 +19,21 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
-class HloPassPipelineTest : public HloTestBase {
+class HloPassPipelineTest : public HloVerifiedTestBase {
protected:
StatusOr<HloModuleGroup> ParseModuleGroup(
absl::Span<const string> hlo_strings) {
HloModuleGroup group(TestName());
for (const string& hlo_string : hlo_strings) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
- ParseHloString(hlo_string));
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
group.push_back(std::move(module));
}
return std::move(group);
@@ -106,8 +106,8 @@ ENTRY main {
ROOT foo = f32[] multiply(a, b)
}
)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
HloPassPipeline pipeline(TestName());
pipeline.AddPass<FooToBarModulePass>();
@@ -129,8 +129,8 @@ ENTRY main {
ROOT blahblah = f32[] multiply(a, b)
}
)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
HloPassPipeline pipeline(TestName());
pipeline.AddPass<FooToBarModulePass>();
@@ -191,8 +191,8 @@ ENTRY main {
ROOT foo = f32[] multiply(a, b)
}
)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
{
// Run a pipeline with just the invariant checker. It should not fail
// because there is no 'bar' instruction in the module.
@@ -243,8 +243,8 @@ ENTRY main {
ROOT foo = f32[] multiply(a, b)
}
)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
HloPassPipeline pipeline(TestName());
pipeline.AddPass<BazToQuxModuleGroupPass>();
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 50f39cbcb5..6eb6658904 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1057,6 +1057,7 @@ Status VerifySendsAndRecvs(const HloModule& module) {
} // namespace
StatusOr<bool> HloVerifier::Run(HloModule* module) {
+ TF_RET_CHECK(!module->name().empty());
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc
index bd8fb17a23..ac2f79674f 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer.cc
@@ -39,8 +39,10 @@ NameUniquer::NameUniquer(const string& separator) {
}
/*static*/ string NameUniquer::GetSanitizedName(const string& name) {
+ if (name.empty()) {
+ return "";
+ }
string result = name;
- CHECK(!result.empty()) << "name should not be empty";
char c = static_cast<unsigned char>(result[0]);
if (!isalpha(c) && c != '_') {
result[0] = '_';
diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc
index 5d1cd1c442..ec09dff924 100644
--- a/tensorflow/compiler/xla/service/stream_pool.cc
+++ b/tensorflow/compiler/xla/service/stream_pool.cc
@@ -28,8 +28,14 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
// Re-use an existing stream from the pool.
stream = std::move(streams_.back());
streams_.pop_back();
- VLOG(1) << stream->DebugStreamPointers()
- << " StreamPool reusing existing stream";
+ if (stream->ok()) {
+ VLOG(1) << stream->DebugStreamPointers()
+ << " StreamPool reusing existing stream";
+ } else {
+ VLOG(1) << stream->DebugStreamPointers()
+ << " stream was not ok, StreamPool deleting";
+ stream = nullptr;
+ }
}
}
diff --git a/tensorflow/compiler/xla/service/stream_pool_test.cc b/tensorflow/compiler/xla/service/stream_pool_test.cc
index aaf5c37b0d..92f47579d3 100644
--- a/tensorflow/compiler/xla/service/stream_pool_test.cc
+++ b/tensorflow/compiler/xla/service/stream_pool_test.cc
@@ -132,5 +132,39 @@ TEST_F(StreamPoolTest, BadStreamDiscarded) {
EXPECT_EQ(stream2_ptr, stream3_ptr);
}
+TEST_F(StreamPoolTest, BadStreamAfterReturnDiscarded) {
+ std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
+ StreamPool pool;
+
+ // Borrow a stream.
+ StreamPool::Ptr stream1 = pool.BorrowStream(executor.get());
+ EXPECT_TRUE(stream1->ok());
+
+ // Return the stream, but hold a handle to it.
+ se::Stream* stream1_ptr = stream1.get();
+ stream1 = nullptr;
+
+ // Now stream1 is back in the pool, force an error on the stream. Here we call
+ // a method that requires DNN support, which we know the Host platform doesn't
+ // support.
+ stream1_ptr->ThenDepthConcatenate({}, {}, nullptr);
+ EXPECT_FALSE(stream1_ptr->ok());
+
+ // Borrow stream2.
+ StreamPool::Ptr stream2 = pool.BorrowStream(executor.get());
+ EXPECT_TRUE(stream2->ok());
+
+ // The underlying streams should be different. They would have been
+ // the same, but since we forced an error on stream1, it cannot be
+ // put back into the pool. Sadly we can't just check:
+ // EXPECT_NE(stream1_ptr, stream2_ptr);
+ //
+ // The above should hold logically, but it may fail if the new
+ // stream instance allocated for stream2 happens to reside in the
+ // same memory address as stream1, which has been deleted.
+ //
+ // The check that stream2->ok() serves as a good-enough check.
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 623ae39de8..d8bb27beae 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <initializer_list>
#include <string>
+#include "absl/base/macros.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
@@ -479,8 +480,7 @@ class ShapeUtil {
// Shorthand for testing whether a shape is of a given element type and
// sequence of dimensions.
- //
- // DEPRECATED: Use Equal() instead.
+ ABSL_DEPRECATED("Use Equal() instead.")
static bool ShapeIs(const Shape& shape, PrimitiveType element_type,
std::initializer_list<int64> dimensions);
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index b49db029e2..fd3e3bfa94 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -154,11 +154,31 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/core:lib",
- "//tensorflow/core:test",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
],
)
+tf_cc_test(
+ name = "hlo_verified_test_base_test",
+ srcs = ["hlo_verified_test_base_test.cc"],
+ deps = [
+ ":hlo_test_base",
+ ":hlo_verified_test_base",
+ ":test_macros_cpu",
+ ":test_utils",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/service:hlo_verifier",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
tf_cc_binary(
name = "local_client_aot_test_helper",
srcs = ["local_client_aot_test_helper.cc"],
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index 8f86c528d0..8bd0a729b7 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -21,64 +21,68 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/test.h"
namespace xla {
-HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
- bool allow_mixed_precision)
- : HloTestBase(
- /*verifier_layout_sensitive=*/layout_sensitive,
- /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {}
-
-HloVerifiedTestBase::~HloVerifiedTestBase() {
- // We can't call the ASSERT or EXPECT test macros in destructors, so we
- // perform HLO verification in TearDown, and use the CHECK here to ensure
- // users don't accidentally override the verification.
- CHECK(tear_down_called_)
- << "TearDown was never called; subclasses of HloVerifiedTestBase that "
- << "override TearDown must call the superclass TearDown.";
-}
-
-void HloVerifiedTestBase::TearDown() {
- EXPECT_FALSE(tear_down_called_)
- << "TearDown called more than once; it should be called exactly once.";
- tear_down_called_ = true;
- if (module_) {
- VerifyModule(module_.get());
+Status VerifiedHloModule::Verify() {
+ if (computation_count() == 0) {
+ // The computation was never built. Nothing to verify.
+ return Status::OK();
}
- for (int i = 0; i < modules_.size(); ++i) {
- VerifyModule(modules_.at(i).get());
- }
- HloTestBase::TearDown();
+ return verifier_.Run(this).status();
}
-void HloVerifiedTestBase::VerifyModule(HloModule* module) {
- xla::StatusOr<bool> mutated = verifier().Run(module);
- if (!mutated.ok()) {
- ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
- } else {
- EXPECT_FALSE(mutated.ValueOrDie())
- << "HloVerifier should never mutate the HloModule";
+void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
+ Status status = Verify();
+ if (!status.ok()) {
+ ADD_FAILURE() << "HloVerifier failed on module " << name()
+ << (message.empty() ? "" : absl::StrCat(" (", message, ")"))
+ << ": " << status;
}
}
+HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
+ bool allow_mixed_precision)
+ : HloTestBase(
+ /*verifier_layout_sensitive=*/layout_sensitive,
+ /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision),
+ verifier_layout_sensitive_(layout_sensitive),
+ allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {}
+
HloModule& HloVerifiedTestBase::module() {
if (!module_) {
- module_ = HloTestBase::CreateNewModule();
+ module_ = CreateNewVerifiedModule(TestName());
}
return *module_;
}
HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) {
- modules_.emplace_back(HloTestBase::CreateNewModule());
+ modules_.emplace_back(CreateNewVerifiedModule(name));
return modules_.back().get();
}
void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config) {
CHECK(!module_) << "Called ParseModule when test already has a module.";
- TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config));
- VerifyModule(module_.get());
+ module_ = CreateNewVerifiedModule(TestName());
+ TF_CHECK_OK(ParseHloString(hlo_text, module_.get()));
+ module_->VerifyOrAddFailure("after parsing");
}
+
+StatusOr<std::unique_ptr<VerifiedHloModule>>
+HloVerifiedTestBase::ParseAndReturnVerifiedModule(
+ absl::string_view hlo_text, const HloModuleConfig& config) {
+ auto module = CreateNewVerifiedModule(TestName());
+ TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
+ TF_RETURN_IF_ERROR(module->Verify());
+ return std::move(module);
+}
+
+std::unique_ptr<VerifiedHloModule> HloVerifiedTestBase::CreateNewVerifiedModule(
+ const string& name) {
+ return absl::make_unique<VerifiedHloModule>(
+ name, GetModuleConfigForTest(), verifier_layout_sensitive_,
+ allow_mixed_precision_in_hlo_verifier_);
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
index 8fbc4fa753..388a99bb36 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
@@ -20,53 +20,84 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/base/macros.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
-// A base class for HLO tests that stores a default HloModule, and automatically
-// performs verification on that module on tear-down.
+// An HLO module derived class which verifies itself on destruction. This class
+// is intended to be used in unit tests. Any verification errors are raised via
+// ADD_FAILURE.
+class VerifiedHloModule : public HloModule {
+ public:
+ VerifiedHloModule(const string& name, const HloModuleConfig& config,
+ bool verifier_layout_sensitive,
+ bool allow_mixed_precision_in_hlo_verifier)
+ : HloModule(name, config),
+ verifier_(verifier_layout_sensitive,
+ allow_mixed_precision_in_hlo_verifier) {}
+
+ ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
+
+ // Verifies the module using HloVerifier and returns the status.
+ Status Verify();
+
+ // Verifies the module and flags any error with ADD_FAILURE. 'message' is
+ // included in the failure message.
+ void VerifyOrAddFailure(const string& message);
+
+ private:
+ HloVerifier verifier_;
+};
+
+// A base class for HLO tests that stores a default VerifiedHloModule.
class HloVerifiedTestBase : public HloTestBase {
protected:
- explicit HloVerifiedTestBase(bool layout_sensitive = false,
- bool allow_mixed_precision = false);
- ~HloVerifiedTestBase() override;
+ HloVerifiedTestBase(bool layout_sensitive = false,
+ bool allow_mixed_precision = false);
// Constructs a default shape verifier.
std::unique_ptr<ShapeVerifier> MakeShapeVerifier();
- // Performs verification on the default HloModule returned by module().
- // Automatically called by the testing framework for each test.
- //
- // REQUIRED: subclasses that override TearDown() must call this explicitly.
- void TearDown() override;
-
// Returns the default HloModule, lazily creating it if necessary via
// HloTestBase::CreateNewModule().
+ ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.")
HloModule& module();
+
+ ABSL_DEPRECATED("Use ParseAndReturnVerifiedModule() instead.")
void ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config = HloModuleConfig());
+ // Parses the given string and returns module as a VerifiedHloModule.
+ StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
+ absl::string_view hlo_text,
+ const HloModuleConfig& config = HloModuleConfig());
+
// Creates a new module for a test, and stores it in modules_ so it can be
// verified. Intentionally hides HloTestBase::CreateNewModule, to prevent
// creation of unverified modules.
+ ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.")
HloModule* CreateNewModule(const string& name = TestName());
- private:
- void VerifyModule(HloModule* module);
+ // Creates and returns a verified HLO module with the given name.
+ std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
+ const string& name = TestName());
+ private:
// It is confusing to store modules created by module() and CreateNewModule()
// in different fields, but it allows us to migrate tests to
// HloVerifiedTestBase more easily, so it's a win because we can verify more
// modules. See b/80488902.
//
// Lazily populated. Access via module().
- std::unique_ptr<HloModule> module_;
+ std::unique_ptr<VerifiedHloModule> module_;
+
// Populated by calls to CreateNewModule.
- std::vector<std::unique_ptr<HloModule>> modules_;
+ std::vector<std::unique_ptr<VerifiedHloModule>> modules_;
- bool tear_down_called_ = false;
+ bool verifier_layout_sensitive_;
+ bool allow_mixed_precision_in_hlo_verifier_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc
new file mode 100644
index 0000000000..5c0263e811
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+// This class includes unit tests which are expected to fail because invalid HLO
+// modules are intentionally built. Unfortunately, Tensorflow doesn't appear to
+// include the necessary gunit parts to test this test machinery (needs the
+// macro EXPECT_NONFATAL_FAILURE). The disabled tests can be run with the
+// disabled tests enabled and failures can be manually compared against
+// expectations.
+class HloVerifiedTestBaseTest : public HloVerifiedTestBase {};
+
+XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) {
+ // Test shouldn't fail if no module is created at all.
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, GoodLazilyCreatedModule) {
+ // Use module() to lazily create an empty module, build it up, and verify no
+ // failures.
+ HloModule& hlo_module = module();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ hlo_module.AddEntryComputation(builder.Build());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadLazilyCreatedModule) {
+ // Use module() to lazily create an empty module and build up an invalid
+ // module.
+ HloModule& hlo_module = module();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ hlo_module.AddEntryComputation(builder.Build());
+
+ *hlo_module.entry_computation()->root_instruction()->mutable_shape() =
+ ShapeUtil::MakeShape(PRED, {1, 2, 3});
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewModule) {
+ // Call CreateNewModule and build up a valid module.
+ HloModule* module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ module->AddEntryComputation(builder.Build());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewModule) {
+ // Call CreateNewModule and build up a invalid module.
+ HloModule* module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ module->AddEntryComputation(builder.Build());
+
+ *module->entry_computation()->root_instruction()->mutable_shape() =
+ ShapeUtil::MakeShape(PRED, {1, 2, 3});
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndVerifyModuleGood) {
+ const char* const hlo_string = R"(
+HloModule ParseAndVerifyModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+)";
+
+ ParseAndVerifyModule(hlo_string);
+ EXPECT_EQ(module().entry_computation()->instruction_count(), 3);
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ EXPECT_EQ(module->entry_computation()->instruction_count(), 3);
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+
+RANDOM GARBAGE
+)";
+
+ ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleBad
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[1234] add(x,y)
+}
+)";
+
+ ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index a40c2d7de6..2cc33ab096 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -412,6 +412,7 @@ INSTANTIATE_TEST_CASE_P(
R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{7, 11}}, {{0, 1}}}, //
R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{1, 0}}}, //
R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{0, 1}}}, //
+ R2Spec{8672, 512, {{8, 0}}, {{8672, 512}}, {{542, 1}}, {{1, 0}}}, //
R2Spec{
511, 513, {{129, 300}}, {{400, 500}}, {{101, 129}}, {{1, 0}}}, //
R2Spec{
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index 51e0c2e431..af7006bff2 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -579,13 +579,6 @@ class BuildSparseInequalitySplitsOp : public OpKernel {
const int end_index =
partition_boundaries[non_empty_partitions[root_idx]][j + 1]
.start_index;
- CHECK(bucket_ids_and_dimensions(start_index, 1) ==
- bucket_ids_and_dimensions(end_index - 1, 1))
- << "For bucket " << bucket_ids_and_dimensions(start_index, 0)
- << " the dimension was "
- << bucket_ids_and_dimensions(start_index, 1) << " and for "
- << bucket_ids_and_dimensions(end_index - 1, 0) << " "
- << bucket_ids_and_dimensions(end_index - 1, 1);
if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id) {
// 0-dimension case which has a first bucket for catch all feature.
CHECK(bucket_ids_and_dimensions(start_index, 1) == 0)
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
index 5a667485be..c59d3682d4 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
@@ -413,6 +413,31 @@ class CudnnRNNTestParamsSize(TensorFlowTestCase):
self._testOneLSTMParamsSize(num_layers, num_units, input_size,
direction)
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testLSTMParamsSizeShape(self):
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ constant_op.constant([4]), 200, 200,
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ params_size = model.params_size()
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ 4, constant_op.constant([200]), 200,
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ params_size = model.params_size()
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ 4, 200, constant_op.constant([200]),
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ params_size = model.params_size()
+
class CudnnRNNTestInference(TensorFlowTestCase):
diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD
index 835fc2595e..52e71619de 100644
--- a/tensorflow/contrib/lite/experimental/c/BUILD
+++ b/tensorflow/contrib/lite/experimental/c/BUILD
@@ -1,5 +1,12 @@
package(default_visibility = ["//visibility:private"])
+package_group(
+ name = "experimental",
+ packages = [
+ "//tensorflow/contrib/lite/experimental/...",
+ ],
+)
+
licenses(["notice"]) # Apache 2.0
load(
@@ -51,6 +58,9 @@ cc_library(
srcs = ["c_api.cc"],
hdrs = ["c_api.h"],
copts = tflite_copts(),
+ visibility = [
+ ":experimental",
+ ],
deps = [
":c_api_internal",
"//tensorflow/contrib/lite:context",
diff --git a/tensorflow/contrib/lite/g3doc/_book.yaml b/tensorflow/contrib/lite/g3doc/_book.yaml
index 1dffe30790..6f56e3139f 100644
--- a/tensorflow/contrib/lite/g3doc/_book.yaml
+++ b/tensorflow/contrib/lite/g3doc/_book.yaml
@@ -14,46 +14,49 @@ upper_tabs:
- name: Guide
contents:
- title: Overview
- path: /mobile/overview
- - title: Developer Guide
- path: /mobile/devguide
- - title: Android Demo App
- path: /mobile/demo_android
- - title: iOS Demo App
- path: /mobile/demo_ios
+ path: /lite/overview
+ - title: Developer guide
+ path: /lite/devguide
+ - title: Android demo app
+ path: /lite/demo_android
+ - title: iOS demo app
+ path: /lite/demo_ios
- title: Performance
- path: /mobile/performance
+ path: /lite/performance
- break: True
- title: TensorFlow Lite APIs
- path: /mobile/apis
+ path: /lite/apis
- title: Custom operators
- path: /mobile/custom_operators
- - title: TensorFlow Lite Ops Versioning
- path: /mobile/ops_versioning
- - title: TensorFlow Lite Compatibility Guide
- path: /mobile/tf_ops_compatibility
- - title: List of Hosted Models
- path: /mobile/models
+ path: /lite/custom_operators
+ - title: TensorFlow Lite ops versioning
+ path: /lite/ops_versioning
+ - title: TensorFlow Lite compatibility guide
+ path: /lite/tf_ops_compatibility
+ - title: List of hosted models
+ path: /lite/models
- title: TensorFlow Lite for iOS
- path: /mobile/ios
+ path: /lite/ios
- title: TensorFlow Lite for Raspberry Pi
- path: /mobile/rpi
+ path: /lite/rpi
- - heading: TF Mobile
+ - title: TF Mobile
+ style: accordion
status: deprecated
- - title: Overview
- path: /mobile/tfmobile/
- - title: Building TensorFlow on Android
- path: /mobile/tfmobile/android_build
- - title: Building TensorFlow on IOS
- path: /mobile/tfmobile/ios_build
- - title: Integrating TensorFlow libraries
- path: /mobile/tfmobile/linking_libs
- - title: Preparing models for mobile deployment
- path: /mobile/tfmobile/prepare_models
- - title: Optimizing for mobile
- path: /mobile/tfmobile/optimizing
+ section:
+ - title: Overview
+ path: /lite/tfmobile/
+ - title: Building TensorFlow on Android
+ path: /lite/tfmobile/android_build
+ - title: Building TensorFlow on IOS
+ path: /lite/tfmobile/ios_build
+ - title: Integrating TensorFlow libraries
+ path: /lite/tfmobile/linking_libs
+ - title: Preparing models for mobile deployment
+ path: /lite/tfmobile/prepare_models
+ - title: Optimizing for mobile
+ path: /lite/tfmobile/optimizing
- name: API
contents:
- - include: /mobile/api_docs/python/_toc.yaml
+ - title: API
+ path: /api_docs/python/tf/contrib/lite
diff --git a/tensorflow/contrib/lite/g3doc/_index.yaml b/tensorflow/contrib/lite/g3doc/_index.yaml
index b3f21e21ac..bc66cc5dc1 100644
--- a/tensorflow/contrib/lite/g3doc/_index.yaml
+++ b/tensorflow/contrib/lite/g3doc/_index.yaml
@@ -1,60 +1,209 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
+project_path: /lite/_project.yaml
+book_path: /lite/_book.yaml
description: <!--no description-->
landing_page:
+ custom_css_path: /site-assets/css/style.css
rows:
- - heading: TensorFlow Lite is a lightweight solution for mobile and embedded devices.
+ - heading: TensorFlow Lite is for mobile and embedded devices.
+ description: >
+ <p style="max-width: 75%;">
+ TensorFlow Lite is the official solution for running machine learning
+ models on mobile and embedded devices. It enables on&#8209;device machine
+ learning inference with low latency and a small binary size on Android,
+ iOS, and other operating systems.
+ </p>
+ <style>
+ .tfo-landing-row-heading {
+ padding-top: 0 !important;
+ }
+ .tfo-landing-row-heading h2 {
+ margin-top: 0 !important;
+ }
+ .tfo-landing-row-heading-list ol, .tfo-landing-row-heading-list ul {
+ margin-top: 0;
+ }
+ </style>
+
+ - classname: tfo-landing-row-heading tfo-landing-row-heading-list
+ heading: Many benefits
+ description: >
+ On-device ML inference is difficult because of the many constraints—TensorFlow Lite can solve these:
items:
- - classname: devsite-landing-row-50
- description: >
- TensorFlow Lite is TensorFlow’s lightweight solution for mobile and
- embedded devices. It enables on-device machine learning inference with
- low latency and a small binary size. TensorFlow Lite also supports
- hardware acceleration with the
- <a href='https://developer.android.com/ndk/guides/neuralnetworks/index.html'>Android Neural Networks API</a>.
- list:
- - heading: Key point 1
+ - list:
+ - heading: Performance
+ description: >
+ TF Lite is fast with no noticeable accuracy loss—see the <a href="./performance">metrics</a>.
+ icon:
+ icon_name: lens
+ foreground: theme
+ - heading: Portability
description: >
- [high-level overview]
+ <a href="https://developer.android.com/ndk/guides/neuralnetworks/" class="external">Android</a>,
+ iOS, and more specialized IoT devices.
icon:
- icon_name: chevron_right
+ icon_name: lens
foreground: theme
- background: grey
- - heading: Key point 2
+ - list:
+ - heading: Low latency
description: >
- [high-level overview]
+ Optimized float- and fixed-point CPU kernels, op&#8209;fusing, and more.
icon:
- icon_name: chevron_right
+ icon_name: lens
foreground: theme
- background: grey
- - heading: Key point 3
+ - heading: Acceleration
description: >
- [high-level overview]
+ Integration with GPU and internal/external accelerators.
icon:
- icon_name: chevron_right
+ icon_name: lens
foreground: theme
- background: grey
- code_block: |
- <pre class = "prettyprint">
- $ toco --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
- --input_format=TENSORFLOW_GRAPHDEF \
- --output_format=TFLITE \
- --output_file=/tmp/mobilenet_v1_1.0_224.tflite \
- --inference_type=FLOAT \
- --input_type=FLOAT \
- --input_arrays=input \
- --output_arrays=MobilenetV1/Predictions/Reshape_1 \
- --input_shapes=1,224,224,3
- </pre>
+ - list:
+ - heading: Small model size
+ description: >
+ Controlled dependencies, <a href="https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3" class="external">quantization</a>,
+ and op&nbsp;registration.
+ icon:
+ icon_name: lens
+ foreground: theme
+ - heading: Tooling
+ description: >
+ Conversion, compression, benchmarking, power-consumption, and more.
+ icon:
+ icon_name: lens
+ foreground: theme
+
+ - classname: devsite-landing-row-logos tfo-landing-row-heading
+ heading: Companies using TensorFlow Lite
+ items:
+ - custom_image:
+ path: ./images/landing-page/photos_logo.png
+ path: https://www.photos.google.com
+ - custom_image:
+ path: ./images/landing-page/gboard_logo.png
+ path: https://play.google.com/store/apps/details?id=com.google.android.inputmethod.latin&hl=en_US
+ - custom_image:
+ path: ./images/landing-page/gmail_logo.png
+ path: https://www.google.com/gmail/
+ - custom_image:
+ path: ./images/landing-page/assistant_logo.png
+ path: https://assistant.google.com/
+
+ - classname: devsite-landing-row-logos
+ items:
+ - custom_image:
+ path: ./images/landing-page/vsco_logo.png
+ path: https://vsco.co
+ - custom_image:
+ path: ./images/landing-page/shazam_logo.png
+ path: https://www.shazam.com/
+ - custom_image:
+ path: ./images/landing-page/nest_logo.png
+ path: https://nest.com/
+ - custom_image:
+ path: ./images/landing-page/loseit_logo.png
+ path: https://www.loseit.com/
+
+ - classname: devsite-landing-row-no-image-background devsite-landing-row-67
+ background: grey
+ items:
+ - description: >
+ <em>“TensorFlow Lite helped us introduce machine learning and AI into our
+ app in an easy and streamlined way. We could reduce the size of our
+ models while keeping the accuracy high. This helped us create an amazing
+ fishing experience for our users by allowing them to identify any fish
+ species with just a photo.”</em>
+ image_path: ./images/landing-page/fishbrain_logo_big.png
+
+ - heading: How it works
+ items:
+ - heading: Build
+ icon:
+ icon_name: build
+ description: >
+ Build a new model or retrain an existing one, such as using transfer learning.
+ buttons:
+ - label: Read the developer guide
+ path: /lite/devguide
+ classname: button button-primary tfo-button-primary
+ - heading: Convert
+ icon:
+ icon_name: autorenew
+ description: >
+ Convert a TensorFlow model into a compressed flat buffer with the
+ TensorFlow Lite Optimizing Converter (TOCO).
+ buttons:
+ - label: Read the TOCO guide
+ path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/python_api.md
+ classname: button button-primary tfo-button-primary
+ - heading: Deploy
+ icon:
+ icon_name: bolt
+ description: >
+ Take the compressed <code>.tflite</code> file and load it into a mobile
+ or embedded device.<br/>
+ See the <a href="#build-your-first-tensorflow-lite-app">tutorials below</a> to build an app.
+
+ - heading: Build your first TensorFlow Lite app
+ background: grey
+ items:
+ - classname: tfo-landing-row-item-inset-white
+ heading: Get started
+ description: >
+ <ul>
+ <li>Beginner: <a href="https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/" class="external">TensorFlow for Poets</a></li>
+ <li>Beginner: <a href="https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2-tflite/" class="external">TensorFlow for Poets 2: Android</a></li>
+ <li>Beginner: <a href="https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2-ios/" class="external">TensorFlow for Poets 2: iOS </a></li>
+ <li>Intermediate: <a href="https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193" class="external">Object detection tutorial</a>
+ </ul>
+ - classname: tfo-landing-row-item-inset-white
+ heading: Share your TensorFlow Lite story
+ description: >
+ We love to hear what you're working on—it may even get highlighted on
+ our social media! <a href="https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss" class="external">Tell us</a>.
+
+ - classname: devsite-landing-row-no-image-background devsite-landing-row-67
+ items:
+ - description: >
+ <p>
+ <em>“The release of TensorFlow Lite has allowed us to deploy an engaging
+ real-time experience to our users that eliminates the requirement
+ for a data connection. TensorFlow Lite’s ability to compress and
+ optimize the TensorFlow graph for mobile deployment has been
+ transformative in expanding the capabilities of Snap It.</em>
+ </p>
+ <p>
+ <em>Through TensorFlow Lite, our users can now enjoy a state of the
+ art, computer-vision-based food logging experience without worrying
+ about signal strength. We look forward to future collaborations
+ with the TensorFlow Lite team.”</em>
+ </p>
+ image_path: ./images/landing-page/loseit_logo_big.png
- classname: devsite-landing-row-cards
+ background: grey
+ heading: Updates
items:
+ - heading: Introducing the Model Optimization Toolkit
+ image_path: /ecosystem/images/tf-logo-card-16x9.png
+ path: https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3
+ buttons:
+ - label: Read on TensorFlow blog
+ path: https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3
+ - heading: East Africa Cassava App
+ image_path: ./images/landing-page/detect_crop_disease_in_africa.png
+ path: https://heartbeat.fritz.ai/community-spotlight-nuru-a-mobile-app-by-plantvillage-to-detect-crop-disease-in-africa-28d142bf63d5
+ buttons:
+ - label: Read more
+ path: https://heartbeat.fritz.ai/community-spotlight-nuru-a-mobile-app-by-plantvillage-to-detect-crop-disease-in-africa-28d142bf63d5
- heading: Using TensorFlow Lite on Android
image_path: /ecosystem/images/tf-logo-card-16x9.png
path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d
buttons:
- label: Read on TensorFlow blog
path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d
+
+ - classname: devsite-landing-row-cards
+ background: grey
+ items:
- heading: TensorFlow Lite at the Dev Summit
youtube_id: FAMfy7izB6A
buttons:
@@ -66,3 +215,4 @@ landing_page:
buttons:
- label: View on GitHub
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite
+ - classname: devsite-landing-row-item-hidden
diff --git a/tensorflow/contrib/lite/g3doc/_project.yaml b/tensorflow/contrib/lite/g3doc/_project.yaml
index b39666516b..d48d07be04 100644
--- a/tensorflow/contrib/lite/g3doc/_project.yaml
+++ b/tensorflow/contrib/lite/g3doc/_project.yaml
@@ -1,6 +1,6 @@
name: TensorFlow Lite
-breadcrumb_name: Mobile
-home_url: /mobile/
+breadcrumb_name: TensorFlow Lite
+home_url: /lite/
parent_project_metadata_path: /_project.yaml
description: >
TensorFlow Lite is a lightweight solution for mobile and embedded devices.
diff --git a/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml b/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml
deleted file mode 100644
index 1e1c44c692..0000000000
--- a/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-# Automatically generated file; please do not edit
-toc:
- - title: TensorFlow Lite
- section:
- - title: Overview
- path: /mobile/api_docs/python/
diff --git a/tensorflow/contrib/lite/g3doc/devguide.md b/tensorflow/contrib/lite/g3doc/devguide.md
index 90e7915c52..0eed516000 100644
--- a/tensorflow/contrib/lite/g3doc/devguide.md
+++ b/tensorflow/contrib/lite/g3doc/devguide.md
@@ -1,5 +1,4 @@
-
-# Developer Guide
+# TF Lite Developer Guide
Using a TensorFlow Lite model in your mobile app requires multiple
considerations: you must choose a pre-trained or custom model, convert the model
@@ -55,7 +54,7 @@ both floating point and quantized inference.
### Train a custom model
A developer may choose to train a custom model using Tensorflow (see the
-[TensorFlow tutorials](../../tutorials/) for examples of building and training
+[TensorFlow tutorials](../tutorials/) for examples of building and training
models). If you have already written a model, the first step is to export this
to a `tf.GraphDef` file. This is required because some formats do not store the
model structure outside the code, and we must communicate with other parts of the
@@ -205,7 +204,7 @@ The open source Android demo app uses the JNI interface and is available
[on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app).
You can also download a
[prebuilt APK](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk).
-See the <a href="../demo_android.md">Android demo</a> guide for details.
+See the <a href="./demo_android.md">Android demo</a> guide for details.
The <a href="./android_build.md">Android mobile</a> guide has instructions for
installing TensorFlow on Android and setting up `bazel` and Android Studio.
@@ -214,7 +213,7 @@ installing TensorFlow on Android and setting up `bazel` and Android Studio.
To integrate a TensorFlow model in an iOS app, see the
[TensorFlow Lite for iOS](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md)
-guide and <a href="../demo_ios.md">iOS demo</a> guide.
+guide and <a href="./demo_ios.md">iOS demo</a> guide.
#### Core ML support
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png
new file mode 100644
index 0000000000..ced0872ab2
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png b/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png
new file mode 100644
index 0000000000..45b3b4f6fe
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png
new file mode 100644
index 0000000000..bc1bf6e1e7
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png
new file mode 100644
index 0000000000..d76fca86a9
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png
new file mode 100644
index 0000000000..f1a93ab763
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png
new file mode 100644
index 0000000000..21aa2c84ea
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png
new file mode 100644
index 0000000000..b6b3d14df9
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png
new file mode 100644
index 0000000000..b3e46d4bd8
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png
new file mode 100644
index 0000000000..35bfd97373
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png
new file mode 100644
index 0000000000..4333426dfe
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png
new file mode 100644
index 0000000000..6ec412c75c
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png
new file mode 100644
index 0000000000..f408f9024b
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/index.md b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
index d003bb2f38..49ad35d4e6 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/index.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
@@ -4,7 +4,7 @@
TensorFlow was designed to be a good deep learning solution for mobile
platforms. Currently we have two solutions for deploying machine learning
applications on mobile and embedded devices: TensorFlow for Mobile and
-<a href="../index.md">TensorFlow Lite</a>.
+<a href="../../lite">TensorFlow Lite</a>.
## TensorFlow Lite versus TensorFlow Mobile
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index a6fd4ac2dd..195474e7fd 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -43,6 +43,7 @@ cc_library(
"compatibility.h",
"types.h",
],
+ deps = ["@com_google_absl//absl/base:core_headers"],
)
config_setting(
@@ -458,7 +459,7 @@ cc_library(
],
copts = NEON_FLAGS_IF_APPLICABLE,
deps = [
- "//tensorflow/contrib/lite/kernels:activation_functor",
+ "@com_google_absl//absl/base:core_headers",
"//tensorflow/contrib/lite/c:c_api_internal",
"@arm_neon_2_x86_sse",
"@gemmlowp",
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 6a7e664e85..1a2d45166a 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -3804,11 +3804,11 @@ inline void LstmCell(
uint8* concat_temp_data_uint8,
const RuntimeShape& unextended_activ_temp_shape,
int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
+ gemmlowp::ScopedProfilingLabel label(
+ "LstmCell/quantized (8bit external, 16bit internal)");
int32 weights_zero_point = params.weights_zero_point;
int32 accum_multiplier = params.accum_multiplier;
int accum_shift = params.accum_shift;
- gemmlowp::ScopedProfilingLabel label(
- "LstmCell/quantized (8bit external, 16bit internal)");
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
@@ -5063,8 +5063,7 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
LogSoftmax(params, input_shape, input_data, output_shape, output_data);
}
-inline void Logistic(const LogisticParams& params,
- const RuntimeShape& input_shape, const float* input_data,
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic");
auto input_map = MapAsVector(input_data, input_shape);
@@ -5073,13 +5072,13 @@ inline void Logistic(const LogisticParams& params,
input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op<float>());
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- LogisticParams params;
- // No params currently needed by float Logistic.
- Logistic(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Logistic(input_shape, input_data, output_shape, output_data);
}
inline void Logistic(const LogisticParams& params,
@@ -5315,22 +5314,21 @@ inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
Logistic(params, input_shape, input_data, output_shape, output_data);
}
-inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
- const float* input_data, const RuntimeShape& output_shape,
- float* output_data) {
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Tanh");
auto input_map = MapAsVector(input_data, input_shape);
auto output_map = MapAsVector(output_data, output_shape);
output_map.array() = input_map.array().tanh();
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- TanhParams params;
- // Currently no params needed for float Tanh.
- Tanh(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Tanh(input_shape, input_data, output_shape, output_data);
}
inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
@@ -6385,6 +6383,16 @@ void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
output_map.array() = input1_map.array().min(min_value);
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape,
@@ -6396,6 +6404,16 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
output_map.array() = input1_map.array().max(max_value);
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
const RuntimeShape& input_shape, const T* input_data,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 76fa1944bc..bb1d30b216 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -1916,7 +1916,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
const float* input2_data,
const RuntimeShape& output_shape,
float* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/float");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/float");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -1957,7 +1957,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
const uint8* input2_data,
const RuntimeShape& output_shape,
uint8* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/uint8");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/uint8");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -2021,7 +2021,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
const int32* input2_data,
const RuntimeShape& output_shape,
int32* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/int32");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/int32");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -2061,7 +2061,7 @@ void BroadcastSub4DSlow(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const T* input1_data,
const RuntimeShape& input2_shape, const T* input2_data,
const RuntimeShape& output_shape, T* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/templated");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/templated");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -3637,8 +3637,7 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
LogSoftmax(params, input_shape, input_data, output_shape, output_data);
}
-inline void Logistic(const LogisticParams& params,
- const RuntimeShape& input_shape, const float* input_data,
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -3649,13 +3648,13 @@ inline void Logistic(const LogisticParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- LogisticParams params;
- // No params currently needed by float Logistic.
- Logistic(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Logistic(input_shape, input_data, output_shape, output_data);
}
inline void Logistic(const LogisticParams& params,
@@ -3741,9 +3740,8 @@ inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
Logistic(params, input_shape, input_data, output_shape, output_data);
}
-inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
- const float* input_data, const RuntimeShape& output_shape,
- float* output_data) {
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
@@ -3753,13 +3751,13 @@ inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- TanhParams params;
- // Currently no params needed for float Tanh.
- Tanh(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Tanh(input_shape, input_data, output_shape, output_data);
}
inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
@@ -4735,6 +4733,16 @@ void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
}
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape,
@@ -4747,6 +4755,16 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
}
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T, typename Op>
void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
const T* input1_data,
@@ -4822,6 +4840,16 @@ void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
std::greater<T1>());
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T1, typename T2, typename T3>
+inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
+ const RuntimeShape& input2_shape, const T3* input2_data,
+ const RuntimeShape& output_shape, T2* output_data) {
+ // Drop shape of second input: not needed.
+ ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
void Transpose(const TransposeParams& params,
const RuntimeShape& unextended_input_shape, const T* input_data,
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index b70a87d0dc..3e0308721e 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <cstring>
#include <iterator>
+#include "absl/base/macros.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
namespace tflite {
@@ -424,7 +425,7 @@ inline int FlatSize(const Dims<N>& dims) {
return flat_size;
}
-// Deprecated. Prefer FlatSize.
+ABSL_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
return FlatSize(dims);
}
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 6e35799c35..2f4b663a28 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -158,7 +158,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D());
AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D());
AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D());
- AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D());
+ AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D(),
+ /* min_version */ 1,
+ /* max_version */ 2);
AddBuiltin(BuiltinOperator_SVDF, Register_SVDF());
AddBuiltin(BuiltinOperator_RNN, Register_RNN());
AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 3a534300ae..3d1eb3978c 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -470,6 +470,17 @@ void ConvertDepthwiseConvOperator(const Model& model,
strides.mutable_list()->add_i(src_op.stride_height);
strides.mutable_list()->add_i(src_op.stride_width);
strides.mutable_list()->add_i(1);
+ // TODO(b/): To return a working TF GraphDef, we should be returning the
+ // correct SpaceToBatchNd and BatchToSpaceND operation before and after the
+ // conv since TF doesn't support dilations.
+ if ((src_op.dilation_width_factor != 1) ||
+ (src_op.dilation_height_factor != 1)) {
+ auto& dilations = (*dc2d_op->mutable_attr())["dilations"];
+ dilations.mutable_list()->add_i(1);
+ dilations.mutable_list()->add_i(src_op.dilation_height_factor);
+ dilations.mutable_list()->add_i(src_op.dilation_width_factor);
+ dilations.mutable_list()->add_i(1);
+ }
string padding;
if (src_op.padding.type == PaddingType::kSame) {
padding = "SAME";
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index fdd0632451..4d213b3f9c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -133,7 +133,6 @@ DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs)
DECLARE_GRAPH_TRANSFORMATION(MergeReshapeIntoPrecedingTranspose)
DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu)
-DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv)
DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
DECLARE_GRAPH_TRANSFORMATION(MoveBinaryOperatorBeforeReshape)
DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants)
@@ -266,6 +265,17 @@ class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation {
bool has_default_ranges_flag_ = false;
};
+class IdentifyDilatedConv : public GraphTransformation {
+ public:
+ bool Run(Model* model, std::size_t op_index) override;
+ const char* Name() const override { return "IdentifyDilatedConv"; }
+ bool identify_depthwise_conv() const { return identify_depthwise_conv_; }
+ void set_identify_depthwise_conv(bool val) { identify_depthwise_conv_ = val; }
+
+ private:
+ bool identify_depthwise_conv_ = true;
+};
+
#undef DECLARE_GRAPH_TRANSFORMATION
} // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
index d49857cfc2..aac77eb39e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
@@ -53,50 +53,11 @@ namespace toco {
// thrown in just for the extra headache. Padding adapts non-conforming input
// sizes, and can be discarded. The bias is necessary, so is kept.
-bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
- const auto it = model->operators.begin() + op_index;
- auto* stb_op = it->get();
-
- // 1. IDENTIFY OPERATORS
- // ***************************************************************************
- // SpaceToBatch Op.
- if (stb_op->type != OperatorType::kSpaceToBatchND) {
- return false;
- }
- if (stb_op->inputs.size() != 3) {
- return false;
- }
- CHECK_EQ(stb_op->outputs.size(), 1);
- // Extract the dilation factor from Input[1] of SpaceToBatch
- // TODO(mjmatthews): Support 2D dilation factors.
- const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
- if (!block_shape_array.buffer) {
- return false;
- }
- CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
- int dilation_factor =
- block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0];
-
- // Expand Op
- auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
- if (!post_stb_op) {
- return false;
- }
- bool has_expand_op = false;
- if (post_stb_op->type == OperatorType::kExpandDims) {
- has_expand_op = true;
- CHECK_EQ(post_stb_op->inputs.size(), 2);
- CHECK_EQ(post_stb_op->outputs.size(), 1);
- }
-
- // Conv Op
- const string& input_of_conv_op =
- has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0];
- auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op);
- if (conv_base_op->type != OperatorType::kConv) {
- return false;
- }
- auto* conv_op = static_cast<ConvOperator*>(conv_base_op);
+template <typename T>
+bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op,
+ Operator* post_stb_op, bool has_expand_op,
+ int dilation_factor) {
+ auto* conv_op = static_cast<T*>(conv_base_op);
if (conv_op->inputs.size() != 2) {
// The conv op must only have weights, no bias.
return false;
@@ -158,8 +119,6 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
CHECK_EQ(bias_add_op->inputs.size(), 2);
CHECK_EQ(bias_add_op->outputs.size(), 1);
- LOG(INFO) << "Identified sub-network emulating dilated convolution.";
-
// 2. RE-WIRE OPERATORS
// ***************************************************************************
// Re-use the existing Conv2D op.
@@ -206,9 +165,71 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUnused(stb_op_inputs[1], model);
DeleteArrayIfUnused(stb_op_inputs[2], model);
- LOG(INFO) << "Replaced with Dilated Conv2D op outputting \""
- << conv_op->outputs[0] << "\".";
return true;
}
+bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ auto* stb_op = it->get();
+
+ // 1. IDENTIFY OPERATORS
+ // ***************************************************************************
+ // SpaceToBatch Op.
+ if (stb_op->type != OperatorType::kSpaceToBatchND) {
+ return false;
+ }
+ if (stb_op->inputs.size() != 3) {
+ return false;
+ }
+ CHECK_EQ(stb_op->outputs.size(), 1);
+ // Extract the dilation factor from Input[1] of SpaceToBatch
+ // TODO(mjmatthews): Support 2D dilation factors.
+ const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
+ if (!block_shape_array.buffer) {
+ return false;
+ }
+ CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
+ int dilation_factor =
+ block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0];
+
+ // Expand Op
+ auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
+ if (!post_stb_op) {
+ return false;
+ }
+ bool has_expand_op = false;
+ if (post_stb_op->type == OperatorType::kExpandDims) {
+ has_expand_op = true;
+ CHECK_EQ(post_stb_op->inputs.size(), 2);
+ CHECK_EQ(post_stb_op->outputs.size(), 1);
+ }
+
+ // Conv Op
+ const string& input_of_conv_op =
+ has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0];
+ auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op);
+ bool changed = false;
+ if (conv_base_op->type == OperatorType::kConv) {
+ changed = ResolveDilatedConv<ConvOperator>(model, conv_base_op, stb_op,
+ post_stb_op, has_expand_op,
+ dilation_factor);
+ if (changed) {
+ LOG(INFO) << "Replaced sub-network with Dilated Conv2D op outputting \""
+ << conv_base_op->outputs[0] << "\".";
+ }
+ } else if (identify_depthwise_conv_ &&
+ conv_base_op->type == OperatorType::kDepthwiseConv) {
+ changed = ResolveDilatedConv<DepthwiseConvOperator>(
+ model, conv_base_op, stb_op, post_stb_op, has_expand_op,
+ dilation_factor);
+ if (changed) {
+ LOG(INFO)
+ << "Replaced sub-netork with Dilated DepthwiseConv2D op outputting \""
+ << conv_base_op->outputs[0] << "\".";
+ }
+ }
+
+ return changed;
+}
+
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 6c72e20121..f943da6d85 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -285,7 +285,8 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
const int kheight = weights_shape.dims(1);
const int kwidth = weights_shape.dims(2);
ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
- op->stride_height, 1, 1, op->padding.type,
+ op->stride_height, op->dilation_width_factor,
+ op->dilation_height_factor, op->padding.type,
model->GetArray(output_name).mutable_shape(),
&op->padding.GetOrCreateFixedPadding());
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
index 8266e2c205..8e150db6fa 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
@@ -25,29 +25,57 @@ limitations under the License.
namespace toco {
+namespace {
+
+void RenameArray(Model* model, const string& oldname,
+ const string& desired_newname) {
+ const string& newname = AvailableArrayName(*model, desired_newname);
+ auto& arrays = model->GetMutableArrayMap();
+ arrays[newname] = std::move(arrays[oldname]);
+ arrays.erase(oldname);
+ for (const auto& op : model->operators) {
+ for (string& input : op->inputs) {
+ if (input == oldname) {
+ input = newname;
+ }
+ }
+ for (string& output : op->outputs) {
+ if (output == oldname) {
+ output = newname;
+ }
+ }
+ }
+}
+
+} // namespace
+
// Reorder the elements of an input_array according to the input_axes_order and
// output_axes_order. Then adjust the shapes of the input and output arrays
// accordingly. Note that input_array must have a buffer (that is, it is a
// constant array).
template <typename T, ArrayDataType DataType>
void ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order,
- Array* input_array, Array* output_array) {
- CHECK(input_array->buffer->type == DataType);
- CHECK(!output_array->buffer);
- auto& input_data = input_array->GetMutableBuffer<DataType>().data;
- std::vector<T> reordered_data;
- reordered_data.resize(RequiredBufferSizeForShape(output_array->shape()));
+ const Array& input_array, Array* output_array) {
+ DCHECK(input_array.buffer->type == DataType);
+ DCHECK(!output_array->buffer);
+ const auto& input_data = input_array.GetBuffer<DataType>().data;
+ auto& output_data = output_array->GetMutableBuffer<DataType>().data;
+ output_data.resize(RequiredBufferSizeForShape(output_array->shape()));
// TODO(b/62904716) Shapes should be used directly.
- Shape input_shape = input_array->shape();
+ Shape input_shape = input_array.shape();
Shape output_shape = output_array->shape();
if (AxesCount(input_axes_order) == 2) {
UnextendShape(&input_shape, 2);
UnextendShape(&output_shape, 2);
}
ShuffleArray(input_shape, input_axes_order, output_axes_order, output_shape,
- input_data.data(), reordered_data.data());
- input_data = reordered_data;
- input_array->copy_shape(output_array->shape());
+ input_data.data(), output_data.data());
+ if (input_array.minmax) {
+ output_array->GetOrCreateMinMax() = input_array.GetMinMax();
+ }
+ if (input_array.narrow_range) {
+ output_array->narrow_range = true;
+ }
}
bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
@@ -57,8 +85,11 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
return false;
}
auto* reorder_op = static_cast<ReorderAxesOperator*>(op);
- const auto& input_array_name = reorder_op->inputs[0];
- const auto& output_array_name = reorder_op->outputs[0];
+
+ // Intentionally copies, not references.
+ const string input_array_name = reorder_op->inputs[0];
+ const string output_array_name = reorder_op->outputs[0];
+
auto& input_array = model->GetArray(input_array_name);
auto& output_array = model->GetArray(output_array_name);
if (!input_array.buffer) {
@@ -72,31 +103,23 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
if (input_array.buffer->type == ArrayDataType::kFloat) {
ReorderAxes<float, ArrayDataType::kFloat>(reorder_op->input_axes_order,
reorder_op->output_axes_order,
- &input_array, &output_array);
- } else if (input_array.buffer->type == ArrayDataType::kInt32) {
+ input_array, &output_array);
+ } else if (input_array.buffer->type == ArrayDataType::kUint8) {
+ // TODO(benoitjacob): This path seems unused.
+ // ReorderAxes is only used when importing from
+ // TensorFlow GraphDef, which does not support quantized nodes.
ReorderAxes<uint8, ArrayDataType::kUint8>(reorder_op->input_axes_order,
reorder_op->output_axes_order,
- &input_array, &output_array);
+ input_array, &output_array);
} else {
LOG(FATAL) << "Cannot ReorderAxes unless input buffer is float or uint8.";
}
- input_array.copy_shape(output_array.shape());
-
- // Update the edges of the graph to point to the input array
- for (const auto& other_op : model->operators) {
- for (auto& input : other_op->inputs) {
- if (input == output_array_name) {
- input = input_array_name;
- }
- }
- }
-
AddMessageF("Reordered axes for array %s", input_array_name);
- // Remove the op and output array.
- model->EraseArray(output_array_name);
- model->operators.erase(it);
+ DeleteOpAndArraysIfUnused(model, op);
+ RenameArray(model, output_array_name, input_array_name);
+
return true;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
index fcf30bd347..65346c4fe4 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -24,6 +24,37 @@ limitations under the License.
namespace toco {
+namespace {
+
+TransposeOperator* FindTransposeOpWithInput(const Model& model,
+ const string& array_name) {
+ for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
+ Operator* op = it->get();
+ if (op->type != OperatorType::kTranspose) {
+ continue;
+ }
+ if (op->inputs[0] != array_name) {
+ continue;
+ }
+ const auto& permutation_array = model.GetArray(op->inputs[1]);
+ if (permutation_array.data_type != ArrayDataType::kInt32) {
+ continue;
+ }
+ const auto& permutation_data =
+ permutation_array.GetBuffer<ArrayDataType::kInt32>().data;
+ if (permutation_data.size() != 2) {
+ continue;
+ }
+ if (permutation_data[0] != 1 || permutation_data[1] != 0) {
+ continue;
+ }
+ return static_cast<TransposeOperator*>(op);
+ }
+ return nullptr;
+}
+
+} // namespace
+
bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
auto matmul_it = model->operators.begin() + op_index;
if (matmul_it->get()->type != OperatorType::kMatMul) {
@@ -37,7 +68,13 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
// TransposeOperator. However, the second input is supposed to be 2D, so we
// can actually handle transposition of that matrix, which happens to be more
// common anyway.
- CHECK(!matmul_op->transpose_a);
+ if (matmul_op->transpose_a) {
+ AddMessageF(
+ "Not replacing %s by a FullyConnected operator, because it has "
+ "the transpose_a attribute",
+ LogName(*matmul_op));
+ return false;
+ }
// Reorder the axes on the second input. TensorFlow uses row-major ordering
// on both inputs, however this is inefficient for the FullyConnected
@@ -46,18 +83,35 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
string input_lhs = matmul_op->inputs[0];
string input_rhs = matmul_op->inputs[1];
if (!matmul_op->transpose_b) {
- auto* transpose_op = new TransposeOperator;
- transpose_op->inputs = {
- matmul_op->inputs[1],
- CreateInt32Array(model,
- AvailableArrayName(
- *model, matmul_op->inputs[1] + "/transpose/perm"),
- {1, 0})};
- transpose_op->outputs = {
- AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")};
- model->GetOrCreateArray(transpose_op->outputs[0]);
- model->operators.emplace(matmul_it, transpose_op);
-
+ // Need to transpose input_rhs, by inserting a TransposeOperator.
+ // First, check if there already is a TransposeOperator transposing that
+ // array, so we can just reuse it.
+ auto* transpose_op = FindTransposeOpWithInput(*model, input_rhs);
+ if (!transpose_op) {
+ AddMessageF(
+ "While replacing %s by a FullyConnected operator, created new "
+ "Transpose op wrapping RHS input array %s",
+ LogName(*matmul_op), input_rhs);
+ // No such TransposeOperator found. Create one now.
+ transpose_op = new TransposeOperator;
+ transpose_op->inputs = {
+ input_rhs,
+ CreateInt32Array(
+ model, AvailableArrayName(*model, input_rhs + "/transpose/perm"),
+ {1, 0})};
+ transpose_op->outputs = {
+ AvailableArrayName(*model, input_rhs + "/transpose")};
+ model->GetOrCreateArray(transpose_op->outputs[0]);
+ model->operators.emplace(matmul_it, transpose_op);
+ // Sanity check
+ DCHECK_EQ(transpose_op, FindTransposeOpWithInput(*model, input_rhs));
+ } else {
+ AddMessageF(
+ "While replacing %s by a FullyConnected operator, reused existing "
+ "Transpose op wrapping RHS input array %s",
+ LogName(*matmul_op), input_rhs);
+ }
+ // Re-wire: have the matmul consume the transposed array.
input_rhs = transpose_op->outputs[0];
}
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 4c678e7e73..e02d000e7e 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -641,6 +641,23 @@ tensorflow::Status ConvertDepthwiseConvOperator(
CHECK_EQ(strides.i(3), 1);
conv->stride_height = strides.i(1);
conv->stride_width = strides.i(2);
+ if (HasAttr(node, "dilations")) {
+ const auto& dilations = GetListAttr(node, "dilations");
+ TF_RETURN_IF_ERROR(
+ ExpectValue(dilations.i_size(), 4, "number of dilations"));
+ if (dilations.i(0) != 1 || dilations.i(3) != 1) {
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Can only import Conv ops with dilation along the height "
+ "(1st) or width (2nd) axis. TensorFlow op \"",
+ node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
+ dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
+ }
+ conv->dilation_height_factor = dilations.i(1);
+ conv->dilation_width_factor = dilations.i(2);
+ } else {
+ conv->dilation_height_factor = 1;
+ conv->dilation_width_factor = 1;
+ }
const auto& padding = GetStringAttr(node, "padding");
if (padding == "SAME") {
conv->padding.type = PaddingType::kSame;
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 0fd2732973..6e207fdf54 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -2084,6 +2084,7 @@ class Model {
}
}
const ArrayMap& GetArrayMap() const { return arrays; }
+ ArrayMap& GetMutableArrayMap() { return arrays; }
int64 ArithmeticOpsCount() const { return ops_count; }
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 28d31e3797..a08b02485f 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -101,7 +101,6 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveTensorFlowSwitch);
transformations->Add(new ResolveTensorFlowConcat);
transformations->Add(new ResolveMultiplyByZero);
- transformations->Add(new IdentifyDilatedConv);
transformations->Add(new IdentifyL2Normalization);
transformations->Add(new IdentifyL2Pool);
transformations->Add(new IdentifyRelu1);
@@ -282,6 +281,14 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
}
}
transformations.Add(new ResolveConstantConcatenation);
+ // TODO(b/116063589): TF GraphDef doesn't support dilations on its depthwise
+ // conv, so we need to make sure we don't convert to dilated depthwise conv
+ // when outputing to TF GraphDef.
+ auto* identify_dilated_conv = new IdentifyDilatedConv;
+ if (output_format == TENSORFLOW_GRAPHDEF) {
+ identify_dilated_conv->set_identify_depthwise_conv(false);
+ }
+ transformations.Add(identify_dilated_conv);
RunGraphTransformations(model, "general graph transformations",
transformations);
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 19359cb612..ac76712aeb 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -599,8 +599,8 @@ class _InternalTPUContext(object):
.format(self._eval_batch_size, num_replicas))
if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
raise ValueError(
- 'TPUEstimator.evaluate should be running on single TPU worker. '
- 'got {}.'.format(num_hosts))
+ 'TPUEstimator.evaluate should be running on single TPU'
+ ' instead of a Pod.')
else:
assert mode == model_fn_lib.ModeKeys.PREDICT
if self._predict_batch_size is None:
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 9bcf5b0865..e82dd13b31 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1067,7 +1067,6 @@ tf_gen_op_libs(
"spectral_ops",
"state_ops",
"stateless_random_ops",
- "string_ops",
"summary_ops",
"training_ops",
],
@@ -1075,6 +1074,13 @@ tf_gen_op_libs(
tf_gen_op_libs(
op_lib_names = [
+ "string_ops",
+ ],
+ deps = ["@com_google_absl//absl/strings"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = [
"array_ops",
],
deps = [":protos_all_cc"],
@@ -2095,6 +2101,7 @@ cc_library(
deps = tf_additional_lib_deps() + [
"@com_google_absl//absl/strings",
"//third_party/eigen3",
+ "@com_google_absl//absl/base:core_headers",
"//tensorflow/core/platform/default/build_config:platformlib",
] + if_static([":lib_internal_impl"]),
)
@@ -2287,6 +2294,7 @@ cc_library(
deps = [
"//tensorflow/core/platform/default/build_config:jpeg",
"//tensorflow/core/platform/default/build_config:logging",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
],
)
@@ -2319,6 +2327,7 @@ cc_library(
deps = [
"//tensorflow/core/platform/default/build_config:gif",
"//tensorflow/core/platform/default/build_config:logging",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt
new file mode 100644
index 0000000000..4cb8955dcb
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt
@@ -0,0 +1,19 @@
+op {
+ graph_op_name: "PrintV2"
+ in_arg {
+ name: "input"
+ description: <<END
+The string scalar to print.
+END
+ }
+ attr {
+ name: "output_stream"
+ description: <<END
+A string specifying the output stream or logging level to print to.
+END
+ }
+ summary: "Prints a string scalar."
+ description: <<END
+Prints a string scalar to the desired output_stream.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt
new file mode 100644
index 0000000000..a82dae9e48
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt
@@ -0,0 +1,38 @@
+op {
+ graph_op_name: "StringFormat"
+ in_arg {
+ name: "inputs"
+ description: <<END
+The list of tensors to format into the placeholder string.
+END
+ }
+
+ out_arg {
+ name: "output"
+ description: <<END
+= The resulting string scalar.
+END
+ }
+ attr {
+ name: "template"
+ description: <<END
+A string, the template to format tensor summaries into.
+END
+ }
+ attr {
+ name: "placeholder"
+ description: <<END
+A string, at each placeholder in the template a subsequent tensor summary will be inserted.
+END
+ }
+ attr {
+ name: "summarize"
+ description: <<END
+When formatting the tensor summaries print the first and last summarize entries of each tensor dimension.
+END
+ }
+ summary: "Formats a string template using a list of tensors."
+ description: <<END
+Formats a string template using a list of tensors, pretty-printing tensor summaries.
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt
new file mode 100644
index 0000000000..e22d980424
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "PrintV2"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt
new file mode 100644
index 0000000000..8f0b1db45d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StringFormat"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h
index 364071e066..2d74bf2b28 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.h
+++ b/tensorflow/core/common_runtime/bfc_allocator.h
@@ -304,7 +304,7 @@ class BFCAllocator : public Allocator {
};
// Returns 'bytes' rounded up to the next highest kMinAllocationSize.
- size_t RoundedBytes(size_t bytes);
+ static size_t RoundedBytes(size_t bytes);
// Try to add a new memory region that can satisfy an allocation of
// 'rounded_bytes' bytes. Returns true on success and false on
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
index 81d68e3be4..fb76d6ac29 100644
--- a/tensorflow/core/common_runtime/device.h
+++ b/tensorflow/core/common_runtime/device.h
@@ -106,6 +106,10 @@ class Device : public DeviceBase {
// at completion.
virtual Status Sync() = 0;
+ // Override this to return true for devices that require a Sync() call before
+ // session completion.
+ virtual bool RequiresSyncOnCompletion() const { return false; }
+
// Optionally modify the device's GraphDef before execution.
//
// This method should be considered experimental and is supplied to enable
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index d0a0767d6b..98719542c0 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -2301,13 +2301,15 @@ void ExecutorState::Finish() {
auto done_cb = std::move(done_cb_);
auto runner = std::move(runner_);
mu_.unlock();
- if (sync_on_finish_ && status.ok()) {
+ Device* device = impl_->params_.device;
+ if ((sync_on_finish_ && status.ok()) || device->RequiresSyncOnCompletion()) {
// Block until the device has finished all queued operations. For
// devices like GPUs that continue to execute Ops after their Compute
// methods have completed, this ensures that control is not returned to
// the user until the step (and its side-effects) has actually completed.
- status = impl_->params_.device->Sync();
+ status.Update(device->Sync());
}
+
delete this;
CHECK(done_cb != nullptr);
runner([=]() { done_cb(status); });
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
index 44ffce77a1..42021e51f3 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
@@ -22,6 +22,39 @@ limitations under the License.
namespace tensorflow {
+bool GPUBFCAllocator::GetAllowGrowthValue(const GPUOptions& gpu_options) {
+ const char* force_allow_growth_string =
+ std::getenv("TF_FORCE_GPU_ALLOW_GROWTH");
+ if (force_allow_growth_string == nullptr) {
+ return gpu_options.allow_growth();
+ }
+
+ if (strcmp("false", force_allow_growth_string) == 0) {
+ if (gpu_options.allow_growth()) {
+ LOG(WARNING)
+ << "Overriding allow_growth setting because the"
+ << " TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original"
+ << " config value was " << gpu_options.allow_growth() << ".";
+ }
+ return false;
+ } else if (strcmp("true", force_allow_growth_string) == 0) {
+ if (!gpu_options.allow_growth()) {
+ LOG(WARNING)
+ << "Overriding allow_growth setting because the"
+ << " TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original"
+ << " config value was " << gpu_options.allow_growth() << ".";
+ }
+ return true;
+ }
+
+ LOG(ERROR)
+ << "The TF_FORCE_GPU_ALLOW_GROWTH environment variable is set but could"
+ << " not be parsed: \"" << force_allow_growth_string << "\". Valid"
+ << " values are \"true\" or \"false\". Using original config value"
+ << " of " << gpu_options.allow_growth() << ".";
+ return gpu_options.allow_growth();
+}
+
GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator,
size_t total_memory, const string& name)
: GPUBFCAllocator(sub_allocator, total_memory, GPUOptions(), name) {}
@@ -30,7 +63,7 @@ GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator,
size_t total_memory,
const GPUOptions& gpu_options,
const string& name)
- : BFCAllocator(sub_allocator, total_memory, gpu_options.allow_growth(),
- name) {}
+ : BFCAllocator(sub_allocator, total_memory,
+ GPUBFCAllocator::GetAllowGrowthValue(gpu_options), name) {}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
index 3470f7a9f7..d4c9cee89a 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
@@ -93,6 +93,9 @@ class GPUBFCAllocator : public BFCAllocator {
~GPUBFCAllocator() override {}
TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator);
+
+ private:
+ static bool GetAllowGrowthValue(const GPUOptions& gpu_options);
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
index e313135d8d..60e82ed13b 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
@@ -410,6 +410,8 @@ BENCHMARK(BM_AllocationDelayed)->Arg(1)->Arg(10)->Arg(100)->Arg(1000);
class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test {
protected:
+ void SetUp() override { CHECK_EQ(unsetenv("TF_FORCE_GPU_ALLOW_GROWTH"), 0); }
+
// The following test methods are called from tests. The reason for this is
// that this class is a friend class to BFCAllocator, but tests are not, so
// only methods inside this class can access private members of BFCAllocator.
@@ -510,6 +512,56 @@ class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test {
EXPECT_EQ(10, a.Log2FloorNonZeroSlow(1024));
EXPECT_EQ(10, a.Log2FloorNonZeroSlow(1025));
}
+
+ void TestForceAllowGrowth() {
+ PlatformGpuId platform_gpu_id(0);
+ GPUOptions options;
+ // Unset flag value uses provided option.
+ unsetenv("TF_FORCE_GPU_ALLOW_GROWTH");
+ options.set_allow_growth(true);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator unset_flag_allocator(sub_allocator, 1LL << 31, options,
+ "GPU_0_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+ unset_flag_allocator.curr_region_allocation_bytes_);
+
+ // Unparseable flag value uses provided option.
+ setenv("TF_FORCE_GPU_ALLOW_GROWTH", "unparseable", 1);
+ options.set_allow_growth(true);
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator unparsable_flag_allocator(sub_allocator, 1LL << 31, options,
+ "GPU_1_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+ unparsable_flag_allocator.curr_region_allocation_bytes_);
+
+ // Max of 2GiB total memory. Env variable set forces allow_growth, which
+ // does an initial allocation of 1MiB.
+ setenv("TF_FORCE_GPU_ALLOW_GROWTH", "true", 1);
+ options.set_allow_growth(false);
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator force_allow_growth_allocator(sub_allocator, 1LL << 31,
+ options, "GPU_2_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+ force_allow_growth_allocator.curr_region_allocation_bytes_);
+
+ // If env variable forces allow_growth disabled, all available memory is
+ // allocated.
+ setenv("TF_FORCE_GPU_ALLOW_GROWTH", "false", 1);
+ options.set_allow_growth(true);
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator force_no_allow_growth_allocator(sub_allocator, 1LL << 31,
+ options, "GPU_3_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(1LL << 31),
+ force_no_allow_growth_allocator.curr_region_allocation_bytes_);
+ }
};
TEST_F(GPUBFCAllocatorPrivateMethodsTest, BinDebugInfo) { TestBinDebugInfo(); }
@@ -518,6 +570,10 @@ TEST_F(GPUBFCAllocatorPrivateMethodsTest, Log2FloorNonZeroSlow) {
TestLog2FloorNonZeroSlow();
}
+TEST_F(GPUBFCAllocatorPrivateMethodsTest, ForceAllowGrowth) {
+ TestForceAllowGrowth();
+}
+
} // namespace tensorflow
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h
index ec93b9aad9..016d1a92c1 100644
--- a/tensorflow/core/example/feature_util.h
+++ b/tensorflow/core/example/feature_util.h
@@ -103,6 +103,7 @@ limitations under the License.
#include <iterator>
#include <type_traits>
+#include "absl/base/macros.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -113,10 +114,10 @@ namespace tensorflow {
namespace internal {
-// DEPRECATED: Use GetFeature instead.
// TODO(gorban): Update all clients in a followup CL.
// Returns a reference to a feature corresponding to the name.
// Note: it will create a new Feature if it is missing in the example.
+ABSL_DEPRECATED("Use GetFeature instead.")
Feature& ExampleFeature(const string& name, Example* example);
// Specializations of RepeatedFieldTrait define a type of RepeatedField
@@ -314,9 +315,9 @@ bool HasFeature(const string& key, const Example& example) {
return HasFeature<FeatureType...>(key, GetFeatures(example));
}
-// DEPRECATED: use HasFeature instead.
// TODO(gorban): update all clients in a followup CL.
template <typename... FeatureType>
+ABSL_DEPRECATED("Use HasFeature instead.")
bool ExampleHasFeature(const string& key, const Example& example) {
return HasFeature<FeatureType...>(key, example);
}
diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc
index 1258e40c93..af59500aee 100644
--- a/tensorflow/core/framework/cancellation.cc
+++ b/tensorflow/core/framework/cancellation.cc
@@ -89,6 +89,16 @@ bool CancellationManager::DeregisterCallback(CancellationToken token) {
}
}
+bool CancellationManager::TryDeregisterCallback(CancellationToken token) {
+ mutex_lock lock(mu_);
+ if (is_cancelled_ || is_cancelling_) {
+ return false;
+ } else {
+ callbacks_.erase(token);
+ return true;
+ }
+}
+
CancellationManager::~CancellationManager() {
if (!callbacks_.empty()) {
StartCancel();
diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h
index acdaaf6a90..7a5d942486 100644
--- a/tensorflow/core/framework/cancellation.h
+++ b/tensorflow/core/framework/cancellation.h
@@ -122,6 +122,15 @@ class CancellationManager {
// cancellation manager.
bool DeregisterCallback(CancellationToken token);
+ // Deregister the callback that, when registered, was associated
+ // with the given cancellation token. Returns true iff the callback
+ // was deregistered and will not be invoked; otherwise returns false
+ // immediately, with no guarantee that the callback has completed.
+ //
+ // This method is guaranteed to return true if StartCancel has not been
+ // called.
+ bool TryDeregisterCallback(CancellationToken token);
+
private:
bool is_cancelling_;
std::atomic_bool is_cancelled_;
diff --git a/tensorflow/core/framework/cancellation_test.cc b/tensorflow/core/framework/cancellation_test.cc
index e3f18240b5..bf7593bc5f 100644
--- a/tensorflow/core/framework/cancellation_test.cc
+++ b/tensorflow/core/framework/cancellation_test.cc
@@ -115,4 +115,56 @@ TEST(Cancellation, IsCancelled) {
delete cm;
}
+TEST(Cancellation, TryDeregisterWithoutCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_TRUE(deregistered);
+ delete manager;
+ EXPECT_FALSE(is_cancelled);
+}
+
+TEST(Cancellation, TryDeregisterAfterCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ manager->StartCancel();
+ EXPECT_TRUE(is_cancelled);
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_FALSE(deregistered);
+ delete manager;
+}
+
+TEST(Cancellation, TryDeregisterDuringCancel) {
+ Notification cancel_started, finish_callback, cancel_complete;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(token, [&]() {
+ cancel_started.Notify();
+ finish_callback.WaitForNotification();
+ });
+ EXPECT_TRUE(registered);
+
+ thread::ThreadPool w(Env::Default(), "test", 1);
+ w.Schedule([&]() {
+ manager->StartCancel();
+ cancel_complete.Notify();
+ });
+ cancel_started.WaitForNotification();
+
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_FALSE(deregistered);
+
+ finish_callback.Notify();
+ cancel_complete.WaitForNotification();
+ delete manager;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index 53ac639b4c..446c31b17f 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/base/macros.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
@@ -176,9 +177,9 @@ class DeviceBase {
return nullptr;
}
- // DEPRECATED: Use `this->GetAllocator()` or `this->GetScopedAllocator()`.
// This method is provided for backwards compatibility, and will be removed
// in a future release.
+ ABSL_DEPRECATED("Use `this->GetAllocator()` or `this->GetScopedAllocator()`.")
Allocator* GetStepAllocator(AllocatorAttributes attr, ResourceMgr*) {
return GetAllocator(attr);
}
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 516afa517d..eb9c79ff2d 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -948,9 +948,69 @@ void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
}
}
+// Appends the spacing between elements for a given dim onto a result string
+void PrintDimSpacing(int dim_index, int num_dims, string* result) {
+ if (dim_index == num_dims - 1) {
+ strings::StrAppend(result, " ");
+ return;
+ }
+ for (int j = 0; j < num_dims - dim_index - 1; j++) {
+ strings::StrAppend(result, "\n");
+ }
+ for (int j = 0; j <= dim_index; j++) {
+ strings::StrAppend(result, " ");
+ }
+}
+
+// Print from left dim to right dim recursively.
+template <typename T>
+void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
+ int64 num_elts_at_ends, int num_dims, const T* data,
+ int64 data_index, string* result) {
+ // We have recursed beyond all the dimensions into a single element
+ // of the tensor.
+ if (dim_index == num_dims) {
+ strings::StrAppend(result, PrintOneElement(data[data_index]));
+ return;
+ }
+
+ strings::StrAppend(result, "[");
+ int64 element_count = shape[dim_index];
+ int64 start_of_end =
+ std::max(num_elts_at_ends, element_count - num_elts_at_ends);
+
+ // Loop every element of one dim.
+ int64 elements_per_iter = 1;
+ for (int i = dim_index + 1; i < num_dims; i++) {
+ elements_per_iter *= shape[i];
+ }
+ for (int64 i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
+ if (i > 0) {
+ PrintDimSpacing(dim_index, num_dims, result);
+ }
+
+ // As for each element, print the sub-dim.
+ PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
+ data_index + elements_per_iter * i, result);
+ }
+ if (element_count > 2 * num_elts_at_ends) {
+ PrintDimSpacing(dim_index, num_dims, result);
+ strings::StrAppend(result, "...");
+ }
+ for (int64 i = start_of_end; i < element_count; i++) {
+ // As for each element, print the sub-dim.
+ PrintDimSpacing(dim_index, num_dims, result);
+ PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
+ data_index + elements_per_iter * i, result);
+ }
+
+ strings::StrAppend(result, "]");
+}
+
template <typename T>
string SummarizeArray(int64 limit, int64 num_elts,
- const TensorShape& tensor_shape, const char* data) {
+ const TensorShape& tensor_shape, const char* data,
+ const bool print_v2) {
string ret;
const T* array = reinterpret_cast<const T*>(data);
@@ -963,17 +1023,26 @@ string SummarizeArray(int64 limit, int64 num_elts,
if (num_elts > limit) strings::StrAppend(&ret, "...");
return ret;
}
- int64 data_index = 0;
- const int shape_size = tensor_shape.dims();
- PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
+ if (print_v2) {
+ const int num_dims = tensor_shape.dims();
+ PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
+ } else {
+ int64 data_index = 0;
+ const int shape_size = tensor_shape.dims();
+ PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
+
+ if (num_elts > limit) strings::StrAppend(&ret, "...");
+ }
- if (num_elts > limit) strings::StrAppend(&ret, "...");
return ret;
}
} // namespace
-string Tensor::SummarizeValue(int64 max_entries) const {
+string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
const int64 num_elts = NumElements();
+ if (max_entries < 0) {
+ max_entries = num_elts;
+ }
size_t limit = std::min(max_entries, num_elts);
if ((limit > 0) && (buf_ == nullptr)) {
return strings::StrCat("uninitialized Tensor of ", num_elts,
@@ -982,50 +1051,54 @@ string Tensor::SummarizeValue(int64 max_entries) const {
const char* data = limit > 0 ? tensor_data().data() : nullptr;
switch (dtype()) {
case DT_HALF:
- return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data);
+ return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
+ print_v2);
break;
case DT_FLOAT:
- return SummarizeArray<float>(limit, num_elts, shape_, data);
+ return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
break;
case DT_DOUBLE:
- return SummarizeArray<double>(limit, num_elts, shape_, data);
+ return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT32:
- return SummarizeArray<uint32>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT32:
- return SummarizeArray<int32>(limit, num_elts, shape_, data);
+ return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT8:
case DT_QUINT8:
- return SummarizeArray<uint8>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT16:
case DT_QUINT16:
- return SummarizeArray<uint16>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT16:
case DT_QINT16:
- return SummarizeArray<int16>(limit, num_elts, shape_, data);
+ return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT8:
case DT_QINT8:
- return SummarizeArray<int8>(limit, num_elts, shape_, data);
+ return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT64:
- return SummarizeArray<uint64>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT64:
- return SummarizeArray<int64>(limit, num_elts, shape_, data);
+ return SummarizeArray<int64>(limit, num_elts, shape_, data, print_v2);
break;
case DT_BOOL:
// TODO(tucker): Is it better to emit "True False..."? This
// will emit "1 0..." which is more compact.
- return SummarizeArray<bool>(limit, num_elts, shape_, data);
+ return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
break;
default: {
// All irregular cases
string ret;
+ if (print_v2) {
+ strings::StrAppend(&ret, "[");
+ }
// TODO(irving): Don't call flat every time around this
// loop.
for (size_t i = 0; i < limit; ++i) {
@@ -1045,6 +1118,9 @@ string Tensor::SummarizeValue(int64 max_entries) const {
}
}
if (max_entries < num_elts) strings::StrAppend(&ret, "...");
+ if (print_v2) {
+ strings::StrAppend(&ret, "]");
+ }
return ret;
}
}
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 696fd277cd..5f5d2021a4 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -430,7 +430,7 @@ class Tensor {
int64 begin) const;
/// Render the first `max_entries` values in `*this` into a string.
- string SummarizeValue(int64 max_entries) const;
+ string SummarizeValue(int64 max_entries, bool print_v2 = false) const;
/// A human-readable summary of the tensor suitable for debugging.
string DebugString() const;
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index 9a78cdc91e..fc05c86990 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -1295,6 +1295,63 @@ TEST(SummarizeValue, STRING) {
EXPECT_EQ("one two three four five one...", x.SummarizeValue(6));
}
+TEST(SummarizeValue, INT32_PRINT_V2) {
+ Tensor x = MkTensor<int>(DT_INT32, TensorShape({5}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[1 2 ... 4 0]", x.SummarizeValue(2, true));
+ EXPECT_EQ("[1 ... 0]", x.SummarizeValue(1, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({2, 2}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[1 2]\n [3 4]]", x.SummarizeValue(16, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[[[1]]\n\n [[2]]]\n\n\n [[[3]]\n\n [[4]]]]",
+ x.SummarizeValue(16, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({0}), {});
+ EXPECT_EQ("[]", x.SummarizeValue(16, true));
+}
+
+TEST(SummarizeValue, INT32Dims_PRINT_V2) {
+ Tensor x = MkTensor<int>(DT_INT32, TensorShape({3, 4}),
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ EXPECT_EQ("[[1 ... 4]\n ...\n [9 ... 12]]", x.SummarizeValue(1, true));
+ EXPECT_EQ("[[1 2 3 4]\n [5 6 7 8]\n [9 10 11 12]]",
+ x.SummarizeValue(10, true));
+ EXPECT_EQ("[[1 2 3 4]\n [5 6 7 8]\n [9 10 11 12]]",
+ x.SummarizeValue(-1, true));
+}
+
+TEST(SummarizeValue, FLOAT_PRINT_V2) {
+ Tensor x = MkTensor<float>(DT_FLOAT, TensorShape({5}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[1 2 ... 4 0]", x.SummarizeValue(2, true));
+ EXPECT_EQ("[1 ... 0]", x.SummarizeValue(1, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[1 2]\n [3 4]]", x.SummarizeValue(16, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[[[1]]\n\n [[2]]]\n\n\n [[[3]]\n\n [[4]]]]",
+ x.SummarizeValue(16, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({0}), {});
+ EXPECT_EQ("[]", x.SummarizeValue(16, true));
+}
+
+TEST(SummarizeValue, BOOL_PRINT_V2) {
+ Tensor x = MkTensor<bool>(DT_BOOL, TensorShape({5}), {false, true, true});
+ EXPECT_EQ("[0 1 1 0 1]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[0 1 1 0 1]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[0 1 ... 0 1]", x.SummarizeValue(2, true));
+}
+
+TEST(SummarizeValue, STRING_PRINT_V2) {
+ Tensor x = MkTensor<string>(DT_STRING, TensorShape({5}),
+ {"one", "two", "three", "four", "five"});
+ EXPECT_EQ("[one two three four five]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[one two three four five]", x.SummarizeValue(-1, true));
+ x = MkTensor<string>(DT_STRING, TensorShape({5, 1, 5}),
+ {"one", "two", "three", "four", "five"});
+ EXPECT_EQ("[one two three four five one...]", x.SummarizeValue(6, true));
+}
+
void BM_CreateAndDestroy(int iters) {
TensorShape shape({10, 20});
while (--iters) {
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index bd0284d43a..b00196f587 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -32,7 +32,7 @@ namespace test {
namespace graph {
// Converts "g" into its corresponding GraphDef "def".
-// DEPRECATED: call g->ToGraphDef(def) instead.
+ABSL_DEPRECATED("Call g->ToGraphDef(def) instead.")
void ToGraphDef(Graph* g, GraphDef* def);
// A few helpers to construct a graph.
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index e84df10778..7128a50be0 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -49,6 +49,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":graph_utils",
+ ":function_utils",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -67,6 +68,7 @@ tf_cc_test(
srcs = ["fusion_utils_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":function_utils",
":fusion_utils",
":graph_utils",
"//tensorflow/core:framework",
@@ -78,6 +80,40 @@ tf_cc_test(
)
cc_library(
+ name = "function_utils",
+ srcs = ["function_utils.cc"],
+ hdrs = [
+ "function_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "function_utils_test",
+ srcs = ["function_utils_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core/kernels:cast_op",
+ "//tensorflow/tools/graph_transforms:transform_utils",
+ ],
+)
+
+cc_library(
name = "graph_utils",
srcs = ["graph_utils.cc"],
hdrs = [
@@ -137,6 +173,7 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ ":function_utils",
":graph_utils",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
@@ -409,3 +446,42 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
],
)
+
+cc_library(
+ name = "vectorization_utils",
+ srcs = ["vectorization_utils.cc"],
+ hdrs = [
+ "vectorization_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":graph_utils",
+ "@com_google_absl//absl/strings",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/utils:functions",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "vectorization_utils_test",
+ srcs = ["vectorization_utils_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":vectorization_utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core/kernels:cast_op",
+ "//tensorflow/tools/graph_transforms:transform_utils",
+ ] + tf_protos_all(),
+)
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc
new file mode 100644
index 0000000000..e95ea1a4c1
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc
@@ -0,0 +1,196 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+namespace {
+
+template <typename Predicate, typename Collection>
+std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate,
+ const Collection& collection) {
+ std::vector<int> indices = {};
+ unsigned idx = 0;
+ for (auto&& element : collection) {
+ if (predicate(element)) {
+ indices.push_back(idx);
+ }
+ idx++;
+ }
+ return indices;
+}
+
+} // namespace
+
+FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name,
+ const string& output, int position)
+ : node_name(node_name), node_output(output), position(position) {
+ full_str = strings::StrCat(node_name, ":", node_output, ":", position);
+}
+
+FunctionDefTensorDesc::FunctionDefTensorDesc(const string& input) {
+ // Parses node_name:node_output:position string into its components.
+ full_str = input;
+ StringPiece capture;
+ StringPiece remaining;
+
+ // Parse "node_name"
+ if (strings::Scanner(input)
+ .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
+ .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
+ .GetResult(&remaining, &capture)) {
+ node_name = string(capture.data(), capture.size());
+ }
+
+ // Parse "node_output" if it exists
+ if (strings::Scanner(remaining)
+ .OneLiteral(":")
+ .RestartCapture()
+ .One(strings::Scanner::LETTER)
+ .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE)
+ .GetResult(&remaining, &capture)) {
+ node_output = string(capture.data(), capture.size());
+ }
+
+ // Parse "position" if it exists
+ if (strings::Scanner(remaining)
+ .OneLiteral(":")
+ .RestartCapture()
+ .Many(strings::Scanner::DIGIT)
+ .GetResult(nullptr, &capture)) {
+ CHECK(strings::safe_strto32(capture, &position));
+ }
+}
+
+// TODO(rachelim): Create a utility class similar to MutableGraphView for
+// FunctionDefs, and use that to manipulate functions. It'll be more
+// performant if we kept mappings of nodes->inputs/outputs, so that we don't
+// have to search over all nodes each time.
+// Note that we're not using GrapplerFunctionItem because it doesn't cover
+// some of our desired uses (eg changing the outputs of a function), and the
+// FunctionDef -> GraphDef conversion isn't really necessary in this case.
+void ReplaceReferences(const string& from, const string& to,
+ FunctionDef* func) {
+ for (NodeDef& n : *func->mutable_node_def()) {
+ std::replace(n.mutable_input()->begin(), n.mutable_input()->end(), from,
+ to);
+ }
+
+ for (auto& p : *func->mutable_ret()) {
+ if (p.second == from) {
+ p.second = to;
+ }
+ }
+}
+
+void AddFunctionOutputWithUniqueName(StringPiece prefix,
+ StringPiece output_tensor_name,
+ FunctionDef* function, DataType dt) {
+ string name = string(prefix);
+ int id = function->signature().output_arg_size();
+ while (ContainsFunctionOutputWithName(name, *function)) {
+ name = strings::StrCat(prefix, "/_", id);
+ ++id;
+ }
+ auto* output = function->mutable_signature()->mutable_output_arg()->Add();
+ output->set_name(name);
+ output->set_type(dt);
+
+ (*function->mutable_ret())[name] = string(output_tensor_name);
+}
+
+NodeDef* AddNode(StringPiece name, StringPiece op,
+ const std::vector<string>& inputs,
+ const std::vector<std::pair<string, AttrValue>>& attributes,
+ FunctionDef* fd) {
+ NodeDef* node = fd->add_node_def();
+ if (!name.empty()) {
+ node->set_name(string(name));
+ } else {
+ SetUniqueFunctionNodeName(op, fd, node);
+ }
+ node->set_op(string(op));
+ for (const string& input : inputs) {
+ node->add_input(input);
+ }
+ for (auto attr : attributes) {
+ (*node->mutable_attr())[attr.first] = attr.second;
+ }
+ return node;
+}
+
+bool ContainsFunctionNodeWithName(StringPiece name,
+ const FunctionDef& function) {
+ return FindFunctionNodeWithName(name, function) != -1;
+}
+
+bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
+ return FindFunctionNodeWithOp(op, function) != -1;
+}
+
+bool ContainsFunctionOutputWithName(StringPiece name,
+ const FunctionDef& function) {
+ return FindFunctionOutputWithName(name, function) != -1;
+}
+
+int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) {
+ std::vector<int> indices = GetElementIndicesWithPredicate(
+ [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
+ function.signature().input_arg());
+ return indices.empty() ? -1 : indices.front();
+}
+
+int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) {
+ std::vector<int> indices = GetElementIndicesWithPredicate(
+ [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
+ function.signature().output_arg());
+ return indices.empty() ? -1 : indices.front();
+}
+
+int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
+ std::vector<int> indices = GetElementIndicesWithPredicate(
+ [&name](const NodeDef& node) { return node.name() == name; },
+ function.node_def());
+ return indices.empty() ? -1 : indices.front();
+}
+
+int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
+ std::vector<int> indices = GetElementIndicesWithPredicate(
+ [&op](const NodeDef& node) { return node.op() == op; },
+ function.node_def());
+
+ return indices.empty() ? -1 : indices.front();
+}
+
+void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
+ NodeDef* node) {
+ string name = string(prefix);
+ int id = function->node_def_size();
+ while (ContainsFunctionNodeWithName(name, *function)) {
+ name = strings::StrCat(prefix, "/_", id);
+ ++id;
+ }
+ node->set_name(std::move(name));
+}
+
+} // end namespace function_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.h b/tensorflow/core/grappler/optimizers/data/function_utils.h
new file mode 100644
index 0000000000..d4ce824652
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.h
@@ -0,0 +1,108 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+// This namespace contains utility functions for querying and modifying
+// FunctionDefs.
+
+// Describes a FunctionDef input tensor. In FunctionDefs, input tensor strings
+// have the format node_name:node_output:position (if they derive from nodes),
+// or input_name (if they derive from an argument).
+struct FunctionDefTensorDesc {
+ FunctionDefTensorDesc() = default;
+
+ FunctionDefTensorDesc(const string& node_name, const string& output,
+ int position);
+
+ // Parses node_name:node_output:position string into its components.
+ explicit FunctionDefTensorDesc(const string& input);
+
+ // TODO(rachelim): Add provisions to deal with special formats, like how
+ // GrapplerFunctionItem expands node output range if position is not defined
+ string full_str;
+ string node_name;
+ string node_output;
+ int position = -1;
+};
+
+// Replaces all references to `from` tensor in func's nodes' inputs and retvals
+// to `to` tensor. This is similar to `MutableGraphView::ReplaceInputs`.
+void ReplaceReferences(const string& from, const string& to, FunctionDef* func);
+
+// Adds a function output to the function def, ensuring that the output key
+// is unique, and maps to output_tensor_name in the ret dict.
+void AddFunctionOutputWithUniqueName(StringPiece prefix,
+ StringPiece output_tensor_name,
+ FunctionDef* function, DataType dt);
+
+// Adds a node to a FunctionDef.
+NodeDef* AddNode(StringPiece name, StringPiece op,
+ const std::vector<string>& inputs,
+ const std::vector<std::pair<string, AttrValue>>& attributes,
+ FunctionDef* fd);
+
+// Checks whether the function contains a node with the given name.
+bool ContainsFunctionNodeWithName(StringPiece name,
+ const FunctionDef& function);
+
+// Checks whether the function contains a node with the given op.
+bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
+
+// Checks whether the function contains an output with the given name.
+bool ContainsFunctionOutputWithName(StringPiece name,
+ const FunctionDef& function);
+
+// Returns the index of the function input with the given name or -1 if the
+// function node does not exist.
+int FindFunctionInputWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function output with the given name or -1 if the
+// function node does not exist.
+int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function node with the given name or -1 if the
+// function node does not exist.
+int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function node with the given op or -1 if the
+// function node does not exist.
+int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
+
+// Sets the function node name using the `prefix` as a prefix while guaranteeing
+// the name is unique across the functions nodes.
+void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
+ NodeDef* node);
+
+} // end namespace function_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils_test.cc b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc
new file mode 100644
index 0000000000..3739e20eb1
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc
@@ -0,0 +1,164 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+namespace {
+
+TEST(FunctionDefTensorDesc, Parsing) {
+ FunctionDefTensorDesc f("Cast:y:0");
+ EXPECT_EQ(f.full_str, "Cast:y:0");
+ EXPECT_EQ(f.node_name, "Cast");
+ EXPECT_EQ(f.node_output, "y");
+ EXPECT_EQ(f.position, 0);
+
+ FunctionDefTensorDesc f2("Arg0");
+ EXPECT_EQ(f2.full_str, "Arg0");
+ EXPECT_EQ(f2.node_name, "Arg0");
+ EXPECT_EQ(f2.node_output, "");
+ EXPECT_EQ(f2.position, -1);
+}
+
+TEST(ReplaceReferencesTest, ReplaceReferencesTest) {
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer", {"arg0: int32"}, {"out: int32", "out2: int64"}, {}, {},
+ {{"out", "MapDefun:output:0"}, {"out2", "Cast:y:0"}});
+ NodeDef* derive_node =
+ AddNode("X", "Some_Op", {"MapDefun:output:0"}, {}, &outer);
+ // Check that both the input to "X" and retval of "outer" are replaced.
+ ReplaceReferences("MapDefun:output:0", "arg0", &outer);
+ EXPECT_EQ(outer.ret().at("out"), "arg0");
+ EXPECT_EQ(derive_node->input(0), "arg0");
+}
+
+TEST(FunctionUtilsTest, AddFunctionOutputWithUniqueName) {
+ FunctionDef function = test::function::XTimesTwo();
+ AddFunctionOutputWithUniqueName("y", "two", &function, DT_INT64);
+ EXPECT_TRUE(ContainsFunctionOutputWithName("y/_1", function));
+ EXPECT_EQ(function.ret().at("y/_1"), "two");
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionNodeWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_FALSE(ContainsFunctionNodeWithName(
+ "weird_name_that_should_not_be_there", function));
+ EXPECT_TRUE(ContainsFunctionNodeWithName("two", function));
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionNodeWithOp) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there",
+ function));
+ EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function));
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionOutputWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_TRUE(ContainsFunctionOutputWithName("y", function));
+ EXPECT_FALSE(ContainsFunctionOutputWithName("Add:z:0", function));
+}
+
+TEST(FunctionUtilsTest, FindFunctionNodeWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(
+ FindFunctionNodeWithName("weird_name_that_should_not_be_there", function),
+ -1);
+ EXPECT_NE(FindFunctionNodeWithName("two", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionNodeWithOp) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(
+ FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function),
+ -1);
+ EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionInputWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(FindFunctionInputWithName("x", function), 0);
+ EXPECT_EQ(FindFunctionInputWithName("not_a_name", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionOutputWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(FindFunctionOutputWithName("y", function), 0);
+ EXPECT_EQ(FindFunctionOutputWithName("Add:z:0", function), -1);
+}
+
+TEST(FunctionUtilsTest, SetUniqueFunctionNodeName) {
+ FunctionDef function = test::function::XTimesTwo();
+ NodeDef node;
+ SetUniqueFunctionNodeName("abc", &function, &node);
+ for (const NodeDef& function_node : function.node_def()) {
+ EXPECT_NE(node.name(), function_node.name());
+ }
+ auto* new_node = function.add_node_def();
+ *new_node = node;
+
+ NodeDef other;
+ SetUniqueFunctionNodeName("abc", &function, &other);
+ EXPECT_NE(other.name(), new_node->name());
+}
+
+TEST(FunctionUtilsTest, AddNodeToFunctionDef) {
+ FunctionDef func;
+ const char* op_name = "xxx";
+ AddNode(op_name, op_name, {}, {}, &func);
+
+ const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func));
+ EXPECT_EQ(node1.op(), op_name);
+ EXPECT_EQ(node1.input_size(), 0);
+ EXPECT_EQ(node1.attr_size(), 0);
+
+ const std::vector<string> inputs({"input1", "input2"});
+ AddNode("", op_name, inputs, {}, &func);
+ const NodeDef& node2 =
+ func.node_def(FindFunctionNodeWithName("xxx/_2", func));
+ EXPECT_EQ(node2.op(), op_name);
+ EXPECT_EQ(node2.attr_size(), 0);
+ EXPECT_EQ(node2.input_size(), inputs.size());
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ EXPECT_EQ(node2.input(i), inputs[i]);
+ }
+
+ AttrValue a1, a2;
+ a1.set_type(DT_INT32);
+ a2.set_type(DT_INT64);
+ const std::vector<std::pair<string, AttrValue>> attrs(
+ {{"attr1", a1}, {"attr2", a2}});
+ AddNode("", op_name, {}, attrs, &func);
+ const NodeDef& node3 =
+ func.node_def(FindFunctionNodeWithName("xxx/_3", func));
+ EXPECT_EQ(node3.op(), op_name);
+ EXPECT_EQ(node3.input_size(), 0);
+ EXPECT_EQ(node3.attr_size(), attrs.size());
+ for (size_t i = 0; i < attrs.size(); ++i) {
+ EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type());
+ }
+}
+
+} // namespace
+} // namespace function_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
index 01a78c04b0..b3bfee138f 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -407,7 +408,7 @@ void LazyConjunctionNodes(const FunctionDef& first_function,
auto* if_node = fused_function->add_node_def();
// This is guaranteed to succeed.
TF_CHECK_OK(if_builder.Finalize(if_node));
- graph_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
+ function_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
GetMutableOutputNode(fused_function, 0) = if_node->name() + ":output:0";
}
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
index d5c6466080..e667affeea 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -110,9 +111,9 @@ TEST(FusionUtilsTest, FuseFunctionWithPredicate) {
CheckUniqueNames(*fused_function);
ASSERT_TRUE(
- graph_utils::ContainsFunctionNodeWithOp("Equal", *fused_function));
+ function_utils::ContainsFunctionNodeWithOp("Equal", *fused_function));
const auto &equal_node = fused_function->node_def(
- graph_utils::FindFunctionNodeWithOp("Equal", *fused_function));
+ function_utils::FindFunctionNodeWithOp("Equal", *fused_function));
EXPECT_EQ(xtimes_two->signature().output_arg(0).name(),
fused_function->signature().output_arg(0).name());
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index d4ab444036..b3f60e34f9 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -108,26 +108,6 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
return graph->AddNode(std::move(node));
}
-NodeDef* AddNode(StringPiece name, StringPiece op,
- const std::vector<string>& inputs,
- const std::vector<std::pair<string, AttrValue>>& attributes,
- FunctionDef* fd) {
- NodeDef* node = fd->add_node_def();
- if (!name.empty()) {
- node->set_name(string(name));
- } else {
- SetUniqueFunctionNodeName(op, fd, node);
- }
- node->set_op(string(op));
- for (const string& input : inputs) {
- node->add_input(input);
- }
- for (auto attr : attributes) {
- (*node->mutable_attr())[attr.first] = attr.second;
- }
- return node;
-}
-
template <>
NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) {
return AddScalarConstNodeHelper(
@@ -196,6 +176,11 @@ bool Compare(const GraphDef& g1, const GraphDef& g2) {
return true;
}
+bool ContainsGraphFunctionWithName(StringPiece name,
+ const FunctionDefLibrary& library) {
+ return FindGraphFunctionWithName(name, library) != -1;
+}
+
bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
return FindGraphNodeWithName(name, graph) != -1;
}
@@ -204,18 +189,14 @@ bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
return FindGraphNodeWithOp(op, graph) != -1;
}
-bool ContainsGraphFunctionWithName(StringPiece name,
- const FunctionDefLibrary& library) {
- return FindGraphFunctionWithName(name, library) != -1;
-}
-
-bool ContainsFunctionNodeWithName(StringPiece name,
- const FunctionDef& function) {
- return FindFunctionNodeWithName(name, function) != -1;
-}
-
-bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
- return FindFunctionNodeWithOp(op, function) != -1;
+int FindGraphFunctionWithName(StringPiece name,
+ const FunctionDefLibrary& library) {
+ std::vector<int> indices = GetElementIndicesWithPredicate(
+ [&name](const FunctionDef& function) {
+ return function.signature().name() == name;
+ },
+ library.function());
+ return indices.empty() ? -1 : indices.front();
}
int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
@@ -237,31 +218,6 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
[&op](const NodeDef& node) { return node.op() == op; }, graph.node());
}
-int FindGraphFunctionWithName(StringPiece name,
- const FunctionDefLibrary& library) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&name](const FunctionDef& function) {
- return function.signature().name() == name;
- },
- library.function());
- return indices.empty() ? -1 : indices.front();
-}
-
-int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&name](const NodeDef& node) { return node.name() == name; },
- function.node_def());
- return indices.empty() ? -1 : indices.front();
-}
-
-int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&op](const NodeDef& node) { return node.op() == op; },
- function.node_def());
-
- return indices.empty() ? -1 : indices.front();
-}
-
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
if (node.input_size() == 0) return nullptr;
GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
@@ -284,17 +240,6 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
node->set_name(std::move(name));
}
-void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
- NodeDef* node) {
- string name = string(prefix);
- int id = function->node_def_size();
- while (ContainsFunctionNodeWithName(name, *function)) {
- name = strings::StrCat(prefix, "/_", id);
- ++id;
- }
- node->set_name(std::move(name));
-}
-
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
FunctionDef* function) {
string name = string(prefix);
@@ -305,7 +250,6 @@ void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
}
function->mutable_signature()->set_name(std::move(name));
}
-
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 6f431c232d..1652afcd9e 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -37,12 +37,6 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<std::pair<string, AttrValue>>& attributes,
MutableGraphView* graph);
-// Adds a node to a FunctionDef.
-NodeDef* AddNode(StringPiece name, StringPiece op,
- const std::vector<string>& inputs,
- const std::vector<std::pair<string, AttrValue>>& attributes,
- FunctionDef* fd);
-
// Adds a Const node with the given value to the graph.
template <typename T>
NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) {
@@ -76,13 +70,6 @@ bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph);
bool ContainsGraphFunctionWithName(StringPiece name,
const FunctionDefLibrary& library);
-// Checks whether the function contains a node with the given name.
-bool ContainsFunctionNodeWithName(StringPiece name,
- const FunctionDef& function);
-
-// Checks whether the function contains a node with the given op.
-bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
-
// Checks whether the graph contains a node with the given op.
bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph);
@@ -95,14 +82,6 @@ int FindGraphNodeWithName(StringPiece name, const GraphDef& graph);
int FindGraphFunctionWithName(StringPiece name,
const FunctionDefLibrary& library);
-// Returns the index of the function node with the given name or -1 if the
-// function node does not exist.
-int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function);
-
-// Returns the index of the function node with the given op or -1 if the
-// function node does not exist.
-int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
-
// Returns the index of the first node with the given op or -1 if no such node
// exists.
int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph);
@@ -119,11 +98,6 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
// is unique across the graph.
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node);
-// Sets the function node name using the `prefix` as a prefix while guaranteeing
-// the name is unique across the functions nodes.
-void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
- NodeDef* node);
-
// Sets the node name using the `prefix` name as a prefix while guaranteeing the
// name is unique across the graph.
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index c19ac7b880..6877c207c4 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -112,20 +112,6 @@ TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
ContainsGraphFunctionWithName(new_function->signature().name(), library));
}
-TEST(GraphUtilsTest, ContainsFunctionNodeWithName) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_FALSE(ContainsFunctionNodeWithName(
- "weird_name_that_should_not_be_there", function));
- EXPECT_TRUE(ContainsFunctionNodeWithName("two", function));
-}
-
-TEST(GraphUtilsTest, ContainsFunctionNodeWithOp) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there",
- function));
- EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function));
-}
-
TEST(GraphUtilsTest, ContainsNodeWithOp) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
@@ -150,22 +136,6 @@ TEST(GraphUtilsTest, FindGraphNodeWithName) {
EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
}
-TEST(GraphUtilsTest, FindFunctionNodeWithName) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_EQ(
- FindFunctionNodeWithName("weird_name_that_should_not_be_there", function),
- -1);
- EXPECT_NE(FindFunctionNodeWithName("two", function), -1);
-}
-
-TEST(GraphUtilsTest, FindFunctionNodeWithOp) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_EQ(
- FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function),
- -1);
- EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1);
-}
-
TEST(GraphUtilsTest, FindGraphFunctionWithName) {
FunctionDefLibrary library;
EXPECT_EQ(FindGraphFunctionWithName("new_function", library), -1);
@@ -225,21 +195,6 @@ TEST(GraphUtilsTest, SetUniqueGraphNodeName) {
EXPECT_NE(node2->name(), node3->name());
}
-TEST(GraphUtilsTest, SetUniqueFunctionNodeName) {
- FunctionDef function = test::function::XTimesTwo();
- NodeDef node;
- SetUniqueFunctionNodeName("abc", &function, &node);
- for (const NodeDef& function_node : function.node_def()) {
- EXPECT_NE(node.name(), function_node.name());
- }
- auto* new_node = function.add_node_def();
- *new_node = node;
-
- NodeDef other;
- SetUniqueFunctionNodeName("abc", &function, &other);
- EXPECT_NE(other.name(), new_node->name());
-}
-
TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
FunctionDefLibrary library;
FunctionDef* new_function = library.add_function();
@@ -251,43 +206,6 @@ TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
other_function->signature().name());
}
-TEST(GraphUtilsTest, AddNodeToFunctionDef) {
- FunctionDef func;
- const char* op_name = "xxx";
- AddNode(op_name, op_name, {}, {}, &func);
-
- const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func));
- EXPECT_EQ(node1.op(), op_name);
- EXPECT_EQ(node1.input_size(), 0);
- EXPECT_EQ(node1.attr_size(), 0);
-
- const std::vector<string> inputs({"input1", "input2"});
- AddNode("", op_name, inputs, {}, &func);
- const NodeDef& node2 =
- func.node_def(FindFunctionNodeWithName("xxx/_2", func));
- EXPECT_EQ(node2.op(), op_name);
- EXPECT_EQ(node2.attr_size(), 0);
- EXPECT_EQ(node2.input_size(), inputs.size());
- for (size_t i = 0; i < inputs.size(); ++i) {
- EXPECT_EQ(node2.input(i), inputs[i]);
- }
-
- AttrValue a1, a2;
- a1.set_type(DT_INT32);
- a2.set_type(DT_INT64);
- const std::vector<std::pair<string, AttrValue>> attrs(
- {{"attr1", a1}, {"attr2", a2}});
- AddNode("", op_name, {}, attrs, &func);
- const NodeDef& node3 =
- func.node_def(FindFunctionNodeWithName("xxx/_3", func));
- EXPECT_EQ(node3.op(), op_name);
- EXPECT_EQ(node3.input_size(), 0);
- EXPECT_EQ(node3.attr_size(), attrs.size());
- for (size_t i = 0; i < attrs.size(); ++i) {
- EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type());
- }
-}
-
TEST(GraphUtilsTest, GetInputNode) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index a019b77eb7..07766aa7b3 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -52,8 +53,8 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
// Add MapDefun node
NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Add();
map_defun_node->set_op("MapDefun");
- graph_utils::SetUniqueFunctionNodeName(map_defun_node->op(), vectorized_func,
- map_defun_node);
+ function_utils::SetUniqueFunctionNodeName(map_defun_node->op(),
+ vectorized_func, map_defun_node);
// Set attrs and inputs
for (const string& k : {"f", "output_types", "output_shapes"}) {
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
new file mode 100644
index 0000000000..6a59eb0d32
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -0,0 +1,341 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+
+#include "absl/strings/str_join.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/functions.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/scanner.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+using function_utils::FunctionDefTensorDesc;
+
+namespace {
+
+void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node,
+ const string& output_retval, const DataType t) {
+ // Set to unknown shape
+ TensorShapeProto tensor_shape_proto;
+ PartialTensorShape().AsProto(&tensor_shape_proto);
+
+ function_utils::AddFunctionOutputWithUniqueName(
+ "vectorized_out", output_retval, map_defun_fn, t);
+
+ *(*map_defun_node->mutable_attr())["output_shapes"]
+ .mutable_list()
+ ->add_shape() = tensor_shape_proto;
+ (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t);
+}
+
+void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node, int output_position) {
+ DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size())
+ << "Trying to remove output that doesn't exist. Output number: "
+ << output_position;
+
+ int num_later_outputs =
+ map_defun_fn->signature().output_arg_size() - output_position - 1;
+
+ // Remove from map_defun_fn's ret dict and output args
+ map_defun_fn->mutable_ret()->erase(
+ map_defun_fn->signature().output_arg(output_position).name());
+ map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange(
+ output_position, 1);
+
+ // Renumber outputs that come after
+ for (int i = 0; i < num_later_outputs; ++i) {
+ function_utils::ReplaceReferences(
+ strings::StrCat(map_defun_node->name(),
+ ":output:", output_position + i + 1),
+ strings::StrCat(map_defun_node->name(),
+ ":output:", output_position + i),
+ outer_scope);
+ }
+ map_defun_node->mutable_attr()
+ ->at("output_shapes")
+ .mutable_list()
+ ->mutable_shape()
+ ->DeleteSubrange(output_position, 1);
+ map_defun_node->mutable_attr()
+ ->at("output_types")
+ .mutable_list()
+ ->mutable_type()
+ ->ExtractSubrange(output_position, 1, nullptr);
+}
+
+Status ConvertCastOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node, const NodeDef& cast_node,
+ const FunctionDefTensorDesc& output_desc,
+ std::map<string, string>* conversion_map) {
+ if (output_desc.node_output != "y" || output_desc.position != 0) {
+ // We expect the Cast node to have only one output, with the name "y".
+ return errors::Internal("Cannot convert Cast op output.");
+ }
+
+ // Promote Cast inputs to outputs of MapDefun
+ DCHECK_EQ(cast_node.input_size(), 1);
+ AddMapDefunOutput(map_defun_fn, map_defun_node, cast_node.input(0),
+ cast_node.attr().at("SrcT").type());
+
+ // Add new Cast node
+ NodeDef* new_cast_node = outer_scope->add_node_def();
+ *new_cast_node = cast_node;
+ new_cast_node->clear_name();
+ function_utils::SetUniqueFunctionNodeName(
+ strings::StrCat("vectorized/", cast_node.name()), outer_scope,
+ new_cast_node);
+ new_cast_node->set_input(
+ 0, strings::StrCat(map_defun_node->name(), ":output:",
+ map_defun_fn->signature().output_arg_size() - 1));
+
+ // Add the output mapping to conversion map
+ (*conversion_map)[strings::StrCat(output_desc.node_name, ":y:0")] =
+ strings::StrCat(new_cast_node->name(), ":y:0");
+
+ return Status::OK();
+}
+
+Status ConvertUnpackOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node, const NodeDef& unpack_node,
+ const FunctionDefTensorDesc& output_desc,
+ std::map<string, string>* conversion_map) {
+ if (output_desc.node_output != "output") {
+ return errors::Internal("Cannot convert Unpack op output.");
+ }
+
+ // Promote Unpack inputs to outputs of MapDefun
+ AddMapDefunOutput(map_defun_fn, map_defun_node, unpack_node.input(0),
+ unpack_node.attr().at("T").type());
+
+ // Add new Unpack node
+ NodeDef* new_unpack_node = outer_scope->add_node_def();
+ *new_unpack_node = unpack_node;
+ new_unpack_node->clear_name();
+ function_utils::SetUniqueFunctionNodeName(
+ strings::StrCat("vectorized/", unpack_node.name()), outer_scope,
+ new_unpack_node);
+
+ // Increment "axis" attr by 1:
+ (*new_unpack_node->mutable_attr())["axis"].set_i(
+ unpack_node.attr().at("axis").i() + 1);
+ new_unpack_node->set_input(
+ 0, strings::StrCat(map_defun_node->name(), ":output:",
+ map_defun_fn->signature().output_arg_size() - 1));
+
+ // Add the output mappings to conversion map
+ int num = new_unpack_node->attr().at("num").i();
+ for (int i = 0; i < num; ++i) {
+ (*conversion_map)[strings::StrCat(output_desc.node_name, ":output:", i)] =
+ strings::StrCat(new_unpack_node->name(), ":output:", i);
+ }
+
+ return Status::OK();
+}
+
+int FindOutputToConvert(const FunctionDef& function,
+ const std::set<string>& unconvertible,
+ FunctionDefTensorDesc* f) {
+ for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) {
+ const string& ret_key = function.signature().output_arg(i).name();
+ *f = FunctionDefTensorDesc(function.ret().at(ret_key));
+
+ if (unconvertible.find(f->node_name) == unconvertible.end()) {
+ return i;
+ }
+ }
+ return -1;
+}
+
+// Helper class that vectorizes the body of a MapDefun node, adding new
+// operations to the graph that collectively compute the same value as what
+// running the MapDefun function on slices of the input would produce.
+// Each instance of the class encapsulates all the data necessary to vectorize a
+// MapDefun op in place.
+class Vectorization {
+ public:
+ Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node)
+ : outer_scope_(outer_scope),
+ map_defun_fn_(map_defun_fn),
+ map_defun_node_(map_defun_node) {}
+
+ // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in
+ // the outer_scope_, until there are no convertible outputs remaining.
+ // This method is idempotent.
+ void Vectorize();
+
+ private:
+ // Vectorizes the map defun function's output at output_position
+ Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc);
+ // Given a descriptor of the original output tensor, gets a string
+ // corresponding to the converted output tensor.
+ Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc,
+ string* converted);
+ Status AddConversionMappingFromInput(
+ const FunctionDefTensorDesc& output_desc);
+
+ // Adds mappings from node's outputs tensors to converted output tensors,
+ // creating the necessary new node(s). Generally, the steps to convert an op
+ // are:
+ // 1) Promote the inputs of the op inputs to outputs of the map_defun_fn_,
+ // and modify map_defun_node_ attrs accordingly
+ // 2) Create new node(s) in outer_scope_ that act on batched input tensors.
+ // These operations collectively compute the same value as what running
+ // the original operation on slices of the input tensors would produce.
+ // For example, a Cast op in MapDefun translates to a Cast op in
+ // outer_scope_, since the vectorized version of Cast is itself.
+ // 3) Set inputs of new node(s) to the corresponding converted inputs (that
+ // are now outputs of map_defun_node_)
+ // 4) For each output of the old node, add the mapping of output strings to
+ // the conversion map (eg "Cast:y:0" -> "Vectorize/Cast:y:0")
+ Status AddConversionMappingFromOp(const NodeDef& node,
+ const FunctionDefTensorDesc& output_desc);
+
+ // Maps a tensor name to the name of the corresponding vectorized tensor. For
+ // example, "Cast:y:0" -> "Vectorize/Cast:y:0"
+ std::map<string, string> conversion_map_;
+ // Unconvertible node names
+ std::set<string> unconvertible_;
+
+ FunctionDef* outer_scope_;
+ FunctionDef* map_defun_fn_;
+ NodeDef* map_defun_node_;
+};
+
+Status Vectorization::AddConversionMappingFromOp(
+ const NodeDef& node, const FunctionDefTensorDesc& output_desc) {
+ for (const string& input_name : node.input()) {
+ if (IsControlInput(input_name)) {
+ return errors::InvalidArgument(
+ "Vectorizing outputs with control inputs is currently not "
+ "supported.");
+ }
+ }
+
+ // TODO(rachelim): Have some mechanism for registering converters and some
+ // uniform, simpler way to represent them.
+
+ // TODO(rachelim): Do step (1) outside of the individual op converters, when
+ // we know how to find out the type of the input.
+ if (node.op() == "Cast") {
+ return ConvertCastOp(outer_scope_, map_defun_fn_, map_defun_node_, node,
+ output_desc, &conversion_map_);
+ } else if (node.op() == "Unpack") {
+ return ConvertUnpackOp(outer_scope_, map_defun_fn_, map_defun_node_, node,
+ output_desc, &conversion_map_);
+ }
+ return errors::Unimplemented("Op converter for \"", node.op(),
+ "\" not implemented yet");
+}
+
+Status Vectorization::AddConversionMappingFromInput(
+ const FunctionDefTensorDesc& output_desc) {
+ int input_index = function_utils::FindFunctionInputWithName(
+ output_desc.node_name, *map_defun_fn_);
+ if (input_index == -1) {
+ return errors::Internal("Cannot convert non-existent input.");
+ }
+
+ conversion_map_[output_desc.full_str] = map_defun_node_->input(input_index);
+ return Status::OK();
+}
+
+Status Vectorization::ConvertOutputHelper(
+ const FunctionDefTensorDesc& output_desc, string* converted) {
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) {
+ *converted = *found;
+ return Status::OK();
+ }
+
+ int index = function_utils::FindFunctionNodeWithName(output_desc.node_name,
+ *map_defun_fn_);
+ if (index == -1) { // The output comes from an input
+ TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc));
+ } else {
+ TF_RETURN_IF_ERROR(AddConversionMappingFromOp(
+ map_defun_fn_->node_def(index), output_desc));
+ }
+ *converted = conversion_map_.at(output_desc.full_str);
+ return Status::OK();
+}
+
+Status Vectorization::ConvertOutput(int output_position,
+ const FunctionDefTensorDesc& output_desc) {
+ string converted_output_name;
+ TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name));
+
+ // Remove the old output and make everything that referenced it point
+ // to the new string
+ function_utils::ReplaceReferences(
+ strings::StrCat(map_defun_node_->name(), ":output:", output_position),
+ converted_output_name, outer_scope_);
+ RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_,
+ output_position);
+
+ return Status::OK();
+}
+
+void Vectorization::Vectorize() {
+ while (true) {
+ FunctionDefTensorDesc desc;
+ int output_position =
+ FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc);
+ if (output_position == -1) break;
+
+ if (!ConvertOutput(output_position, desc).ok()) {
+ unconvertible_.insert(desc.node_name);
+ }
+ }
+
+ // If we've converted all the outputs of the MapDefun function, we no longer
+ // need the MapDefun node and can delete it.
+ if (map_defun_fn_->signature().output_arg_size() == 0) {
+ outer_scope_->mutable_node_def()->DeleteSubrange(
+ function_utils::FindFunctionNodeWithName(map_defun_node_->name(),
+ *outer_scope_),
+ 1);
+ }
+
+ if (!unconvertible_.empty()) {
+ VLOG(2) << "The following nodes could not be converted: ["
+ << absl::StrJoin(unconvertible_, ", ") << "].";
+ }
+}
+} // namespace
+
+void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node) {
+ Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize();
+}
+
+} // end namespace vectorization_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
new file mode 100644
index 0000000000..bb405faa77
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
@@ -0,0 +1,90 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// Given a function, `map_defun_fn`, that is mapped across some input vector
+// elements via a MapDefun operation, `VectorizeMapDefun` attempts to
+// vectorize the MapDefun by "lifting" operations from the `map_defun_fn` to the
+// `outer_scope`; that is, replacing `map_defun_fn` operations with new
+// `outer_scope` operations that produce the same vector output(s) as executing
+// the `map_defun_fn` operations on elements of vector input(s) would. If all
+// `map_defun_fn` operations are successfully lifted, `map_defun_node` is
+// eliminated from `outer_scope` altogether. However, if some operations cannot
+// be lifted, and this vectorization only succeeds partially, `map_defun_node`
+// remains to be used for operations that were not lifted.
+//
+// Example:
+// If the input to the `VectorizeMapDefun` function is a MapDefun
+// whose `map_defun_fn` performs the Cast operation, the vectorization will
+// eliminate the MapDefun. This is because the Cast operation supports
+// any tensor shape and can thus be lifted to the `outer_scope`.
+//
+// Before:
+//
+//
+// outer_scope +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | map_defun_fn +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// outer_scope +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node);
+
+} // end namespace vectorization_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
new file mode 100644
index 0000000000..e129fa9237
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -0,0 +1,600 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+namespace {
+
+NodeDef* AddCastNode(const string& name, const std::vector<string>& inputs,
+ DataType src, DataType dst, bool truncate,
+ FunctionDef* fn) {
+ NodeDef* node = function_utils::AddNode(name, "Cast", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("SrcT", src, node);
+ graph_transforms::SetNodeAttr("DstT", dst, node);
+ graph_transforms::SetNodeAttr("Truncate", truncate, node);
+ return node;
+}
+
+NodeDef* AddUnstackNode(const string& name, const std::vector<string>& inputs,
+ DataType t, int axis, int num, FunctionDef* fn) {
+ NodeDef* node = function_utils::AddNode(name, "Unpack", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("T", t, node);
+ graph_transforms::SetNodeAttr("axis", axis, node);
+ graph_transforms::SetNodeAttr("num", num, node);
+ return node;
+}
+
+NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
+ const std::vector<DataType>& t_arguments,
+ const std::vector<DataType>& output_types,
+ const std::vector<TensorShape>& output_shapes,
+ const string& function_name, FunctionDef* fn) {
+ NameAttrList func;
+ func.set_name(function_name);
+ NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("Targuments", t_arguments, node);
+ graph_transforms::SetNodeAttr("output_types", output_types, node);
+ graph_transforms::SetNodeAttr("output_shapes", output_shapes, node);
+ graph_transforms::SetNodeAttr("f", func, node);
+ return node;
+}
+
+// TODO(rachelim): Use FunctionDefHelper::Create instead
+FunctionDef CreateFunction(
+ StringPiece name, const std::vector<std::pair<string, DataType>>& inputs,
+ const std::vector<std::pair<string, DataType>>& outputs,
+ const std::map<string, string>& rets) {
+ FunctionDef func;
+ auto* signature = func.mutable_signature();
+ signature->set_name(string(name));
+ for (const auto& x : inputs) {
+ auto* arg_def = signature->add_input_arg();
+ arg_def->set_name(x.first);
+ arg_def->set_type(x.second);
+ }
+ for (const auto& x : outputs) {
+ auto* arg_def = signature->add_output_arg();
+ arg_def->set_name(x.first);
+ arg_def->set_type(x.second);
+ }
+ for (const auto& x : rets) {
+ (*func.mutable_ret())[x.first] = x.second;
+ }
+
+ return func;
+}
+
+TEST(FunctionDefInputDescTest, ConstructedCorrectly) {}
+
+// Before:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// | +-----------+ Arg0 +---+ Arg1 +----+ |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+//
+// After:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | | | |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"ret0", "arg0"}, {"ret1", "arg1"}});
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"ret0", "ret1"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ EXPECT_EQ(outer.ret().at("mapdefun"), "ret0");
+ EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1");
+}
+
+// Before:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// | +-----------+ Arg0 +---+ Arg1 +----+ |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | +------+ | +---v--+ | |
+// | | |Const | | | Op0 | | |
+// | | +---v--+ | +---+--+ | |
+// | | | | | | |
+// | | | +---v--+ +---v--+ | |
+// | | +---| XOp1 | | XOp2 | | |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+// where XOp1 and XOp2 are not convertible.
+//
+// After:
+//
+// No change because the ops are not convertible.
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}});
+ NodeDef* x_op1 =
+ function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner);
+ CHECK_NOTNULL(x_op1);
+
+ NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner);
+ CHECK_NOTNULL(x_op2);
+
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"x", DT_INT32}, {"y", DT_INT32}},
+ {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x", "y"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDef outer_copy(outer);
+ FunctionDef inner_copy(inner);
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ // They should be unchanged
+ EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+ EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner));
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& cast_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ EXPECT_EQ(cast_node.input(0), "x");
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +---------------+ Arg0 +-------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +---------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +----------+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +-------------------+
+// | +---+--+ |
+// | | |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +----------+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
+ // Tests that behavior is correct when an output is used more than once.
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}, {"ret1", DT_INT64}},
+ {{"ret0", "Cast:y:0"}, {"ret1", "Cast:y:0"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}, {"mapdefun_0", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64, DT_INT64},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& cast_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ EXPECT_EQ(cast_node.input(0), "x");
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +------------------+ Arg0 +------------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v---+ num=3 | |
+// | | |Unstack| axis=0 | |
+// | | ++--+--++ | |
+// | | | | | | |
+// | | +----+ | +-------+ | |
+// | | | | | | |
+// | | MapDefun +---v--+ +-v----+ +--v---+ | |
+// | +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+ |
+// | +---+--+ +--+---+ +--+---+ |
+// | | | | |
+// | +---v--+ +--v---+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v---+ num=3 |
+// | |Unstack| axis=1 |
+// | ++--+--++ |
+// | | | | |
+// | +----+ | +-------+ |
+// | | | | |
+// | | | | |
+// | +---v--+ +-v----+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
+ FunctionDef inner = CreateFunction(
+ "inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}},
+ {{"ret0", "MyUnstack:output:0"},
+ {"ret1", "MyUnstack:output:1"},
+ {"ret2", "MyUnstack:output:2"}});
+ NodeDef* unstack_op =
+ AddUnstackNode("MyUnstack", {"arg0"}, DT_INT32, 0, 3, &inner);
+ CHECK_NOTNULL(unstack_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT32},
+ {"mapdefun_0", DT_INT32},
+ {"mapdefun_1", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"},
+ {"mapdefun_0", "MapDefun:output:1"},
+ {"mapdefun_1", "MapDefun:output:2"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32},
+ {{1}, {1}, {1}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& unpack_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ EXPECT_EQ(unpack_node.input(0), "x");
+ EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
+ EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
+ EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(unpack_node.name(), ":output:0"));
+ EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ strings::StrCat(unpack_node.name(), ":output:1"));
+ EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ strings::StrCat(unpack_node.name(), ":output:2"));
+ EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +------------------+ Arg0 +------------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | +---+--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v---+ num=3 | |
+// | | |Unstack| axis=0 | |
+// | | ++--+--++ | |
+// | | | | | | |
+// | | +----+ | +-------+ | |
+// | | | | | | |
+// | | MapDefun +---v--+ +-v----+ +--v---+ | |
+// | +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+ |
+// | +---+--+ +--+---+ +--+---+ |
+// | | | | |
+// | +---v--+ +--v---+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---+--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v---+ num=3 |
+// | |Unstack| axis=1 |
+// | ++--+--++ |
+// | | | | |
+// | +----+ | +-------+ |
+// | | | | |
+// | | | | |
+// | +---v--+ +-v----+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
+ FunctionDef inner = CreateFunction(
+ "inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}},
+ {{"ret0", "MyUnstack:output:0"},
+ {"ret1", "MyUnstack:output:1"},
+ {"ret2", "MyUnstack:output:2"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ CHECK_NOTNULL(cast_op);
+ NodeDef* unstack_op =
+ AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner);
+ CHECK_NOTNULL(unstack_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT32},
+ {"mapdefun_0", DT_INT32},
+ {"mapdefun_1", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"},
+ {"mapdefun_0", "MapDefun:output:1"},
+ {"mapdefun_1", "MapDefun:output:2"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32},
+ {{1}, {1}, {1}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& cast_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ EXPECT_EQ(cast_node.input(0), "x");
+ const NodeDef& unpack_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
+ EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
+ EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
+
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(unpack_node.name(), ":output:0"));
+ EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ strings::StrCat(unpack_node.name(), ":output:1"));
+ EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ strings::StrCat(unpack_node.name(), ":output:2"));
+ EXPECT_EQ(outer.node_def_size(), 2);
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | +---------+ | |
+// | | +---v--+ | | |
+// | | |Print | | | |
+// | | +---+--+ | | |
+// | | : +---v--+ | |
+// | | ::::::> Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// No change because we don't deal with control inputs for now.
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
+ // The attrs aren't relevant
+ NodeDef* print_op =
+ function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner);
+ CHECK_NOTNULL(print_op);
+ NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64,
+ false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDef outer_copy(outer);
+ FunctionDef inner_copy(inner);
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ // They should be unchanged
+ EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+}
+
+// TODO(rachelim): More test cases when we get around to implementing them:
+// [] A badly defined converter, e.g. doesn't produce nodes that have the
+// same number of outputs/inputs as the nodes to be converted
+// [] Converter where the 'converted' form has multiple nodes.
+// [] Case with dependent nodes, e.g. ops with const inputs that are
+// broadcasted.
+// [] Python-side tests to actually run the functions to make sure
+// they work.
+
+} // namespace
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 1ed1b22931..4b0cbfaa82 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -352,7 +352,7 @@ Status MetaOptimizer::RunOptimizer(
Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
- LOG(INFO) << "Starting optimization for grappler item: " << item.id;
+ VLOG(1) << "Starting optimization for grappler item: " << item.id;
optimization_results_.clear();
// 1. Optimize main graph
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 7aa1169061..b0d04a7213 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2707,6 +2707,7 @@ cc_library(
)
LOGGING_DEPS = [
+ "@com_google_absl//absl/strings",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -2764,6 +2765,7 @@ tf_cc_tests(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "@com_google_absl//absl/strings",
],
)
@@ -4401,6 +4403,7 @@ cc_library(
":reduce_join_op",
":regex_full_match_op",
":regex_replace_op",
+ ":string_format_op",
":string_join_op",
":string_length_op",
":string_split_op",
@@ -4432,6 +4435,30 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "string_format_op",
+ prefix = "string_format_op",
+ deps = STRING_DEPS + ["@com_google_absl//absl/strings"],
+)
+
+tf_cc_test(
+ name = "string_format_op_test",
+ size = "small",
+ srcs = ["string_format_op_test.cc"],
+ deps = [
+ ":string_format_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
+tf_kernel_library(
name = "string_join_op",
prefix = "string_join_op",
deps = STRING_DEPS,
diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc
index 6b6a14e9a7..8bafd5739d 100644
--- a/tensorflow/core/kernels/logging_ops.cc
+++ b/tensorflow/core/kernels/logging_ops.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <iostream>
+#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -74,8 +75,7 @@ class PrintOp : public OpKernel {
string msg;
strings::StrAppend(&msg, message_);
for (int i = 1; i < ctx->num_inputs(); ++i) {
- strings::StrAppend(&msg, "[", ctx->input(i).SummarizeValue(summarize_),
- "]");
+ strings::StrAppend(&msg, ctx->input(i).SummarizeValue(summarize_));
}
std::cerr << msg << std::endl;
}
@@ -90,6 +90,59 @@ class PrintOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Print").Device(DEVICE_CPU), PrintOp);
+class PrintV2Op : public OpKernel {
+ public:
+ explicit PrintV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_stream", &output_stream_));
+
+ auto output_stream_index =
+ std::find(std::begin(valid_output_streams_),
+ std::end(valid_output_streams_), output_stream_);
+
+ if (output_stream_index == std::end(valid_output_streams_)) {
+ string error_msg = strings::StrCat(
+ "Unknown output stream: ", output_stream_, ", Valid streams are:");
+ for (auto valid_stream : valid_output_streams_) {
+ strings::StrAppend(&error_msg, " ", valid_stream);
+ }
+ OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* input_;
+ OP_REQUIRES_OK(ctx, ctx->input("input", &input_));
+ const string& msg = input_->scalar<string>()();
+
+ if (output_stream_ == "stdout") {
+ std::cout << msg << std::endl;
+ } else if (output_stream_ == "stderr") {
+ std::cerr << msg << std::endl;
+ } else if (output_stream_ == "log(info)") {
+ LOG(INFO) << msg << std::endl;
+ } else if (output_stream_ == "log(warning)") {
+ LOG(WARNING) << msg << std::endl;
+ } else if (output_stream_ == "log(error)") {
+ LOG(ERROR) << msg << std::endl;
+ } else {
+ string error_msg = strings::StrCat(
+ "Unknown output stream: ", output_stream_, ", Valid streams are:");
+ for (auto valid_stream : valid_output_streams_) {
+ strings::StrAppend(&error_msg, " ", valid_stream);
+ }
+ OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
+ }
+ }
+
+ const char* valid_output_streams_[6] = {"stdout", "stderr", "log(info)",
+ "log(warning)", "log(error)"};
+
+ private:
+ string output_stream_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("PrintV2").Device(DEVICE_CPU), PrintV2Op);
+
class TimestampOp : public OpKernel {
public:
explicit TimestampOp(OpKernelConstruction* context) : OpKernel(context) {}
diff --git a/tensorflow/core/kernels/logging_ops_test.cc b/tensorflow/core/kernels/logging_ops_test.cc
index 5e6958f364..a259d995fa 100644
--- a/tensorflow/core/kernels/logging_ops_test.cc
+++ b/tensorflow/core/kernels/logging_ops_test.cc
@@ -23,11 +23,33 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace {
+class PrintingV2GraphTest : public OpsTestBase {
+ protected:
+ Status Init(const string& output_stream = "log(warning)") {
+ TF_CHECK_OK(NodeDefBuilder("op", "PrintV2")
+ .Input(FakeInput(DT_STRING))
+ .Attr("output_stream", output_stream)
+ .Finalize(node_def()));
+ return InitOp();
+ }
+};
+
+TEST_F(PrintingV2GraphTest, StringSuccess) {
+ TF_ASSERT_OK(Init());
+ AddInputFromArray<string>(TensorShape({}), {"bar"});
+ TF_ASSERT_OK(RunOpKernel());
+}
+
+TEST_F(PrintingV2GraphTest, InvalidOutputStream) {
+ ASSERT_NE(::tensorflow::Status::OK(), (Init("invalid_output_stream")));
+}
+
class PrintingGraphTest : public OpsTestBase {
protected:
Status Init(DataType input_type1, DataType input_type2, string msg = "",
diff --git a/tensorflow/core/kernels/multinomial_op.cc b/tensorflow/core/kernels/multinomial_op.cc
index 7a64788448..82dfece4a2 100644
--- a/tensorflow/core/kernels/multinomial_op.cc
+++ b/tensorflow/core/kernels/multinomial_op.cc
@@ -75,7 +75,7 @@ struct MultinomialFunctor<CPUDevice, T, OutputType> {
// lambda. Since we want to let each worker have its own copy, we pass
// "gen" by reference and explicitly do a copy assignment here.
random::PhiloxRandom gen_copy = gen;
- // Skip takes units of 128 bytes. +3 is so rounding doesn't lead to
+ // Skip takes units of 128 bits. +3 is so rounding doesn't lead to
// us using the same state in different batches.
gen_copy.Skip(start_row * (num_samples + 3) / 4);
random::SimplePhilox simple_philox(&gen_copy);
diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h
index 5fb1c92f94..272aa3b4f5 100644
--- a/tensorflow/core/kernels/queue_base.h
+++ b/tensorflow/core/kernels/queue_base.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <deque>
#include <vector>
+#include "absl/base/macros.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
@@ -82,6 +83,9 @@ class QueueBase : public QueueInterface {
// NOTE(mrry): This method is deprecated. Use
// `tensorflow::batch_util::CopySliceToElement()` defined in
// "./batch_util.h" instead.
+ ABSL_DEPRECATED(
+ "Use `tensorflow::batch_util::CopySliceToElement()` defined in "
+ "\"./batch_util.h\" instead.")
static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
int64 index);
diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc
index e4ca89eca3..5318d8c133 100644
--- a/tensorflow/core/kernels/reduction_ops_sum.cc
+++ b/tensorflow/core/kernels/reduction_ops_sum.cc
@@ -76,15 +76,7 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("output")
.HostMemory("reduction_indices"),
ReductionOp<CPUDevice, int32, int64, Eigen::internal::SumReducer<int32>>);
-REGISTER_KERNEL_BUILDER(
- Name("Sum")
- .Device(DEVICE_GPU)
- .TypeConstraint<int64>("T")
- .TypeConstraint<int32>("Tidx")
- .HostMemory("input")
- .HostMemory("output")
- .HostMemory("reduction_indices"),
- ReductionOp<CPUDevice, int64, int32, Eigen::internal::SumReducer<int64>>);
+
#endif
#ifdef TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/string_format_op.cc b/tensorflow/core/kernels/string_format_op.cc
new file mode 100644
index 0000000000..e4a1887f8d
--- /dev/null
+++ b/tensorflow/core/kernels/string_format_op.cc
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iostream>
+#include "absl/strings/str_split.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+class StringFormatOp : public OpKernel {
+ public:
+ explicit StringFormatOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string template_;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("template", &template_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("placeholder", &placeholder_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_));
+
+ split_template_ = absl::StrSplit(template_, placeholder_);
+ int64 num_placeholders = split_template_.size() - 1;
+ OP_REQUIRES(ctx, ctx->num_inputs() == num_placeholders,
+ errors::InvalidArgument(strings::StrCat(
+ "num placeholders in template and num inputs must match: ",
+ num_placeholders, " vs. ", ctx->num_inputs())));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ Tensor* formatted_string = nullptr;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &formatted_string));
+
+ string msg;
+ strings::StrAppend(&msg, split_template_[0].c_str());
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ strings::StrAppend(&msg, ctx->input(i).SummarizeValue(summarize_, true));
+ strings::StrAppend(&msg, split_template_[i + 1].c_str());
+ }
+
+ formatted_string->scalar<string>()() = msg;
+ }
+
+ private:
+ int32 summarize_ = 0;
+ string placeholder_;
+ std::vector<std::string> split_template_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StringFormat").Device(DEVICE_CPU),
+ StringFormatOp);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/string_format_op_test.cc b/tensorflow/core/kernels/string_format_op_test.cc
new file mode 100644
index 0000000000..13130a5797
--- /dev/null
+++ b/tensorflow/core/kernels/string_format_op_test.cc
@@ -0,0 +1,66 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace {
+
+class StringFormatGraphTest : public OpsTestBase {
+ protected:
+ Status Init(int num_inputs, DataType input_type,
+ const string& template_ = "%s", const string& placeholder = "%s",
+ int summarize = 3) {
+ TF_CHECK_OK(NodeDefBuilder("op", "StringFormat")
+ .Input(FakeInput(num_inputs, input_type))
+ .Attr("template", template_)
+ .Attr("placeholder", placeholder)
+ .Attr("summarize", summarize)
+ .Finalize(node_def()));
+ return InitOp();
+ }
+};
+
+TEST_F(StringFormatGraphTest, Int32Success_7) {
+ TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s"));
+
+ AddInputFromArray<int32>(TensorShape({7}), {1, 2, 3, 4, 5, 6, 7});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_STRING, TensorShape({}));
+ test::FillValues<string>(&expected, {"First tensor: [1 2 3 ... 5 6 7]"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(StringFormatGraphTest, Int32Success_3_3) {
+ TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s", "%s", 1));
+
+ AddInputFromArray<int32>(TensorShape({3, 3}), {1, 2, 3, 4, 5, 6, 7, 8, 9});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_STRING, TensorShape({}));
+ test::FillValues<string>(&expected, {"First tensor: [[1 ... 3]\n ..."
+ "\n [7 ... 9]]"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+} // end namespace
+} // end namespace tensorflow
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index e59958749c..2360432d96 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -38880,6 +38880,30 @@ op {
is_stateful: true
}
op {
+ name: "PrintV2"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ attr {
+ name: "output_stream"
+ type: "string"
+ default_value {
+ s: "stderr"
+ }
+ allowed_values {
+ list {
+ s: "stdout"
+ s: "stderr"
+ s: "log(info)"
+ s: "log(warning)"
+ s: "log(error)"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "PriorityQueue"
output_arg {
name: "handle"
@@ -70188,6 +70212,43 @@ op {
}
}
op {
+ name: "StringFormat"
+ input_arg {
+ name: "inputs"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "template"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "placeholder"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "summarize"
+ type: "int"
+ default_value {
+ i: 3
+ }
+ }
+}
+op {
name: "StringJoin"
input_arg {
name: "inputs"
diff --git a/tensorflow/core/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc
index f78f7a897a..f84142c992 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops.cc
@@ -37,7 +37,6 @@ using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-
REGISTER_OP("CudnnRNNParamsSize")
.Input("num_layers: int32")
.Input("num_units: int32")
@@ -52,11 +51,16 @@ REGISTER_OP("CudnnRNNParamsSize")
.Attr("seed2: int = 0")
.Output("params_size: S")
.SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ // num_layers, num_units, and input_size should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+
c->set_output(0, c->Vector(1));
return Status::OK();
});
-
REGISTER_OP("CudnnRNN")
.Input("input: T")
.Input("input_h: T")
@@ -248,7 +252,6 @@ REGISTER_OP("CudnnRNNParamsToCanonical")
return Status::OK();
});
-
REGISTER_OP("CudnnRNNCanonicalToParams")
.Input("num_layers: int32")
.Input("num_units: int32")
diff --git a/tensorflow/core/ops/cudnn_rnn_ops_test.cc b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
index 2dd867561b..13c3b933f4 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops_test.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
@@ -26,7 +26,16 @@ namespace tensorflow {
TEST(CudnnRNNOpsTest, ParamsSize_ShapeFn) {
ShapeInferenceTestOp op("CudnnRNNParamsSize");
- INFER_OK(op, "[1];[1];[1]", "[1]");
+ INFER_OK(op, "[];[];[]", "[1]");
+ INFER_OK(op, "?;[];[]", "[1]");
+ INFER_OK(op, "[];?;[]", "[1]");
+ INFER_OK(op, "[];[];?", "[1]");
+ INFER_OK(op, "[];?;?", "[1]");
+ INFER_OK(op, "?;?;?", "[1]");
+
+ INFER_ERROR("Shape must be rank 0 ", op, "[1,2];?;[]");
+ INFER_ERROR("Shape must be rank 0 ", op, "?;[2];[]");
+ INFER_ERROR("Shape must be rank 0 ", op, "?;?;[1]");
}
TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {
diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc
index 639d211767..2034d3601b 100644
--- a/tensorflow/core/ops/logging_ops.cc
+++ b/tensorflow/core/ops/logging_ops.cc
@@ -20,6 +20,8 @@ limitations under the License.
namespace tensorflow {
+using shape_inference::InferenceContext;
+
REGISTER_OP("Assert")
.Input("condition: bool")
.Input("data: T")
@@ -44,6 +46,23 @@ REGISTER_OP("Print")
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("Print");
+REGISTER_OP("PrintV2")
+ .Input("input: string")
+ .SetIsStateful()
+ .Attr(
+ "output_stream: {'stdout', 'stderr', 'log(info)', "
+ "'log(warning)', 'log(error)'} = 'stderr'")
+ .SetShapeFn([](InferenceContext* c) {
+ // Make sure that the input is a scalar.
+ if (c->Rank(c->input(0)) != 0) {
+ return errors::InvalidArgument("input must be a scalar, but has rank: ",
+ c->Rank(c->input(0)));
+ }
+ return Status::OK();
+ });
+
+WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("PrintV2");
+
// ----------------------------------------------------------------------------
// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
// inputs or outputs in various ways.
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 4ece1c8953..29e327753b 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -19521,6 +19521,30 @@ op {
is_stateful: true
}
op {
+ name: "PrintV2"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ attr {
+ name: "output_stream"
+ type: "string"
+ default_value {
+ s: "stderr"
+ }
+ allowed_values {
+ list {
+ s: "stdout"
+ s: "stderr"
+ s: "log(info)"
+ s: "log(warning)"
+ s: "log(error)"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "PriorityQueue"
output_arg {
name: "handle"
@@ -32735,6 +32759,43 @@ op {
}
}
op {
+ name: "StringFormat"
+ input_arg {
+ name: "inputs"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "template"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "placeholder"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "summarize"
+ type: "int"
+ default_value {
+ i: 3
+ }
+ }
+}
+op {
name: "StringJoin"
input_arg {
name: "inputs"
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index ef8b15dc8a..99159839d0 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -102,6 +103,32 @@ REGISTER_OP("AsString")
.Attr("fill: string = ''")
.SetShapeFn(shape_inference::UnchangedShape);
+REGISTER_OP("StringFormat")
+ .Input("inputs: T")
+ .Output("output: string")
+ .Attr("T: list(type) >= 0")
+ .Attr("template: string = '%s'")
+ .Attr("placeholder: string = '%s'")
+ .Attr("summarize: int = 3")
+ .SetShapeFn([](InferenceContext* c) {
+ string template_;
+ string placeholder;
+ TF_RETURN_IF_ERROR(c->GetAttr("template", &template_));
+ TF_RETURN_IF_ERROR(c->GetAttr("placeholder", &placeholder));
+
+ std::vector<std::string> split_template;
+ split_template = absl::StrSplit(template_, placeholder);
+ int64 num_placeholders = split_template.size() - 1;
+ if (c->num_inputs() != num_placeholders) {
+ return errors::InvalidArgument(strings::StrCat(
+ "num placeholders in template and num inputs must match: ",
+ num_placeholders, " vs. ", c->num_inputs()));
+ }
+
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
REGISTER_OP("StringJoin")
.Input("inputs: N * string")
.Attr("N: int")
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 83228fab6f..83ea8539ed 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -25,6 +25,7 @@ limitations under the License.
#ifdef _WIN32
#include <io.h> // for _mktemp
#endif
+#include "absl/base/macros.h"
#include "include/json/json.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -63,7 +64,7 @@ constexpr int kGetChildrenDefaultPageSize = 1000;
// The HTTP response code "308 Resume Incomplete".
constexpr uint64 HTTP_CODE_RESUME_INCOMPLETE = 308;
// The environment variable that overrides the size of the readahead buffer.
-// DEPRECATED. Use GCS_BLOCK_SIZE_MB instead.
+ABSL_DEPRECATED("Use GCS_BLOCK_SIZE_MB instead.")
constexpr char kReadaheadBufferSize[] = "GCS_READAHEAD_BUFFER_SIZE_BYTES";
// The environment variable that disables the GCS block cache for reads.
// This is the explicit alternative to setting BLOCK_SIZE or MAX_SIZE to 0, and
diff --git a/tensorflow/core/platform/default/cord.h b/tensorflow/core/platform/default/cord.h
index 1ab682182c..5823374d1a 100644
--- a/tensorflow/core/platform/default/cord.h
+++ b/tensorflow/core/platform/default/cord.h
@@ -16,9 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
#define TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
-class Cord;
-namespace absl {
-using ::Cord;
-} // namespace absl
+// TODO(ebrevdo): Fill this in.
#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h
index 30059dc02e..156af6cdea 100644
--- a/tensorflow/core/platform/file_system.h
+++ b/tensorflow/core/platform/file_system.h
@@ -255,10 +255,13 @@ class WritableFile {
/// \brief Append 'data' to the file.
virtual Status Append(StringPiece data) = 0;
+ // TODO(ebrevdo): Remove this ifdef when absl is updated.
+#if defined(PLATFORM_GOOGLE)
// \brief Append 'data' to the file.
virtual Status Append(const absl::Cord& cord) {
return errors::Unimplemented("Append(absl::Cord) is not implemented");
}
+#endif
/// \brief Close the file.
///
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h
index 0f04b65f60..b9ca8ab395 100644
--- a/tensorflow/core/util/sparse/sparse_tensor.h
+++ b/tensorflow/core/util/sparse/sparse_tensor.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/base/macros.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -95,21 +96,21 @@ class SparseTensor {
SparseTensor() : dims_(0) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
: SparseTensor(ix, vals, TensorShapeToVector(shape),
UndefinedOrder(TensorShapeToVector(shape))) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape)
: SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
const VarDimArray order)
: SparseTensor(ix, vals, TensorShapeToVector(shape), order) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
const VarDimArray order)
: ix_(ix),
@@ -237,9 +238,10 @@ class SparseTensor {
static Status Split(const SparseTensor& tensor, const int split_dim,
const int num_split, std::vector<SparseTensor>* result);
- // DEPRECATED: use the form of Split() that takes an output pointer and
- // returns a status instead.
template <typename T>
+ ABSL_DEPRECATED(
+ "Use the form of Split() that takes an output pointer and returns a "
+ "status instead.")
static std::vector<SparseTensor> Split(const SparseTensor& tensor,
const int split_dim,
const int num_split,
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index eb636dbf54..1d72bcd2b6 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -3741,98 +3741,28 @@ func BoostedTreesExampleDebugOutputs(scope *Scope, tree_ensemble_handle tf.Outpu
return op.Output(0)
}
-// Computes the sum along sparse segments of a tensor.
-//
-// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
-// misisng, the `output` tensor at that position will be zeroed.
-//
-// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
-// for an explanation of segments.
-//
-// For example:
-//
-// ```python
-// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
-//
-// tf.sparse_segment_sum_with_num_segments(
-// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3)
-// # => [[0 0 0 0]
-// # [0 0 0 0]
-// # [0 0 0 0]]
-//
-// tf.sparse_segment_sum_with_num_segments(c,
-// tf.constant([0, 1]),
-// tf.constant([0, 2],
-// num_segments=4))
-// # => [[ 1 2 3 4]
-// # [ 0 0 0 0]
-// # [-1 -2 -3 -4]
-// # [ 0 0 0 0]]
-// ```
-//
-// Arguments:
-//
-// indices: A 1-D tensor. Has same rank as `segment_ids`.
-// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
-// num_segments: Should equal the number of distinct segment IDs.
-//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `num_segments`.
-func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SparseSegmentSumWithNumSegments",
- Input: []tf.Input{
- data, indices, segment_ids, num_segments,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// PreventGradientAttr is an optional argument to PreventGradient.
-type PreventGradientAttr func(optionalAttr)
-
-// PreventGradientMessage sets the optional message attribute to value.
-//
-// value: Will be printed in the error when anyone tries to differentiate
-// this operation.
-// If not specified, defaults to ""
-func PreventGradientMessage(value string) PreventGradientAttr {
- return func(m optionalAttr) {
- m["message"] = value
- }
-}
-
-// An identity op that triggers an error if a gradient is requested.
-//
-// When executed in a graph, this op outputs its input tensor as-is.
+// Makes the summary of accumulated stats for the batch.
//
-// When building ops to compute gradients, the TensorFlow gradient system
-// will return an error when trying to lookup the gradient of this op,
-// because no gradient must ever be registered for this function. This
-// op exists to prevent subtle bugs from silently returning unimplemented
-// gradients in some corner cases.
+// The summary stats contains gradients and hessians accumulated into the corresponding node and bucket for each example.
//
// Arguments:
-// input: any tensor.
+// node_ids: int32 Rank 1 Tensor containing node ids, which each example falls into for the requested layer.
+// gradients: float32; Rank 2 Tensor (shape=[#examples, 1]) for gradients.
+// hessians: float32; Rank 2 Tensor (shape=[#examples, 1]) for hessians.
+// bucketized_features_list: int32 list of Rank 1 Tensors, each containing the bucketized feature (for each feature column).
+// max_splits: int; the maximum number of splits possible in the whole tree.
+// num_buckets: int; equals to the maximum possible value of bucketized feature.
//
-// Returns the same input tensor.
-func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) {
+// Returns output Rank 4 Tensor (shape=[#features, #splits, #buckets, 2]) containing accumulated stats put into the corresponding node and bucket. The first index of 4th dimension refers to gradients, and the second to hessians.
+func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf.Output, hessians tf.Output, bucketized_features_list []tf.Output, max_splits int64, num_buckets int64) (stats_summary tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
+ attrs := map[string]interface{}{"max_splits": max_splits, "num_buckets": num_buckets}
opspec := tf.OpSpec{
- Type: "PreventGradient",
+ Type: "BoostedTreesMakeStatsSummary",
Input: []tf.Input{
- input,
+ node_ids, gradients, hessians, tf.OutputList(bucketized_features_list),
},
Attrs: attrs,
}
@@ -3840,21 +3770,6 @@ func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientA
return op.Output(0)
}
-// Computes asin of x element-wise.
-func Asin(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Asin",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes the sum along sparse segments of a tensor.
//
// Read
@@ -4564,37 +4479,142 @@ func AddV2(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
-// NthElementAttr is an optional argument to NthElement.
-type NthElementAttr func(optionalAttr)
+// Computes exponential of x element-wise. \\(y = e^x\\).
+func Exp(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Exp",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
-// NthElementReverse sets the optional reverse attribute to value.
+// Returns an element-wise indication of the sign of a number.
//
-// value: When set to True, find the nth-largest value in the vector and vice
-// versa.
-// If not specified, defaults to false
-func NthElementReverse(value bool) NthElementAttr {
+// `y = sign(x) = -1` if `x < 0`; 0 if `x == 0`; 1 if `x > 0`.
+//
+// For complex numbers, `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`.
+func Sign(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Sign",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ArgMinAttr is an optional argument to ArgMin.
+type ArgMinAttr func(optionalAttr)
+
+// ArgMinOutputType sets the optional output_type attribute to value.
+// If not specified, defaults to DT_INT64
+func ArgMinOutputType(value tf.DataType) ArgMinAttr {
return func(m optionalAttr) {
- m["reverse"] = value
+ m["output_type"] = value
}
}
-// Finds values of the `n`-th order statistic for the last dimension.
+// Returns the index with the smallest value across dimensions of a tensor.
//
-// If the input is a vector (rank-1), finds the entries which is the nth-smallest
-// value in the vector and outputs their values as scalar tensor.
+// Note that in case of ties the identity of the return value is not guaranteed.
//
-// For matrices (resp. higher rank input), computes the entries which is the
-// nth-smallest value in each row (resp. vector along the last dimension). Thus,
+// Arguments:
//
-// values.shape = input.shape[:-1]
+// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`.
+// Describes which dimension of the input Tensor to reduce across. For vectors,
+// use dimension = 0.
+func ArgMin(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMinAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ArgMin",
+ Input: []tf.Input{
+ input, dimension,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Convert the quantized 'input' tensor into a lower-precision 'output', using the
+//
+// output range specified with 'requested_output_min' and 'requested_output_max'.
+//
+// [input_min, input_max] are scalar floats that specify the range for the float
+// interpretation of the 'input' data. For example, if input_min is -1.0f and
+// input_max is 1.0f, and we are dealing with quint16 quantized data, then a 0
+// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f.
//
// Arguments:
-// input: 1-D or higher with last dimension at least `n+1`.
-// n: 0-D. Position of sorted vector to select along the last dimension (along
-// each row for matrices). Valid range of n is `[0, input.shape[:-1])`
//
-// Returns The `n`-th order statistic along each last dimensional slice.
-func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) {
+// input_min: The float value that the minimum quantized input value represents.
+// input_max: The float value that the maximum quantized input value represents.
+// requested_output_min: The float value that the minimum quantized output value represents.
+// requested_output_max: The float value that the maximum quantized output value represents.
+// out_type: The type of the output. Should be a lower bit depth than Tinput.
+//
+// Returns The requested_output_min value is copied into this output.The requested_output_max value is copied into this output.
+func Requantize(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, requested_output_min tf.Output, requested_output_max tf.Output, out_type tf.DataType) (output tf.Output, output_min tf.Output, output_max tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"out_type": out_type}
+ opspec := tf.OpSpec{
+ Type: "Requantize",
+ Input: []tf.Input{
+ input, input_min, input_max, requested_output_min, requested_output_max,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// PreventGradientAttr is an optional argument to PreventGradient.
+type PreventGradientAttr func(optionalAttr)
+
+// PreventGradientMessage sets the optional message attribute to value.
+//
+// value: Will be printed in the error when anyone tries to differentiate
+// this operation.
+// If not specified, defaults to ""
+func PreventGradientMessage(value string) PreventGradientAttr {
+ return func(m optionalAttr) {
+ m["message"] = value
+ }
+}
+
+// An identity op that triggers an error if a gradient is requested.
+//
+// When executed in a graph, this op outputs its input tensor as-is.
+//
+// When building ops to compute gradients, the TensorFlow gradient system
+// will return an error when trying to lookup the gradient of this op,
+// because no gradient must ever be registered for this function. This
+// op exists to prevent subtle bugs from silently returning unimplemented
+// gradients in some corner cases.
+//
+// Arguments:
+// input: any tensor.
+//
+// Returns the same input tensor.
+func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
@@ -4603,9 +4623,9 @@ func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthEleme
a(attrs)
}
opspec := tf.OpSpec{
- Type: "NthElement",
+ Type: "PreventGradient",
Input: []tf.Input{
- input, n,
+ input,
},
Attrs: attrs,
}
@@ -4613,6 +4633,21 @@ func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthEleme
return op.Output(0)
}
+// Computes asin of x element-wise.
+func Asin(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Asin",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the maximum along segments of a tensor.
//
// Read
@@ -4662,61 +4697,37 @@ func UnsortedSegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output, num
return op.Output(0)
}
-// Computes exponential of x element-wise. \\(y = e^x\\).
-func Exp(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Exp",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
+// NthElementAttr is an optional argument to NthElement.
+type NthElementAttr func(optionalAttr)
-// Returns an element-wise indication of the sign of a number.
-//
-// `y = sign(x) = -1` if `x < 0`; 0 if `x == 0`; 1 if `x > 0`.
+// NthElementReverse sets the optional reverse attribute to value.
//
-// For complex numbers, `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`.
-func Sign(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Sign",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// ArgMinAttr is an optional argument to ArgMin.
-type ArgMinAttr func(optionalAttr)
-
-// ArgMinOutputType sets the optional output_type attribute to value.
-// If not specified, defaults to DT_INT64
-func ArgMinOutputType(value tf.DataType) ArgMinAttr {
+// value: When set to True, find the nth-largest value in the vector and vice
+// versa.
+// If not specified, defaults to false
+func NthElementReverse(value bool) NthElementAttr {
return func(m optionalAttr) {
- m["output_type"] = value
+ m["reverse"] = value
}
}
-// Returns the index with the smallest value across dimensions of a tensor.
+// Finds values of the `n`-th order statistic for the last dimension.
//
-// Note that in case of ties the identity of the return value is not guaranteed.
+// If the input is a vector (rank-1), finds the entries which is the nth-smallest
+// value in the vector and outputs their values as scalar tensor.
+//
+// For matrices (resp. higher rank input), computes the entries which is the
+// nth-smallest value in each row (resp. vector along the last dimension). Thus,
+//
+// values.shape = input.shape[:-1]
//
// Arguments:
+// input: 1-D or higher with last dimension at least `n+1`.
+// n: 0-D. Position of sorted vector to select along the last dimension (along
+// each row for matrices). Valid range of n is `[0, input.shape[:-1])`
//
-// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`.
-// Describes which dimension of the input Tensor to reduce across. For vectors,
-// use dimension = 0.
-func ArgMin(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMinAttr) (output tf.Output) {
+// Returns The `n`-th order statistic along each last dimensional slice.
+func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) {
if scope.Err() != nil {
return
}
@@ -4725,9 +4736,9 @@ func ArgMin(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgM
a(attrs)
}
opspec := tf.OpSpec{
- Type: "ArgMin",
+ Type: "NthElement",
Input: []tf.Input{
- input, dimension,
+ input, n,
},
Attrs: attrs,
}
@@ -4735,38 +4746,56 @@ func ArgMin(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgM
return op.Output(0)
}
-// Convert the quantized 'input' tensor into a lower-precision 'output', using the
+// Computes the sum along sparse segments of a tensor.
//
-// output range specified with 'requested_output_min' and 'requested_output_max'.
+// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
+// misisng, the `output` tensor at that position will be zeroed.
//
-// [input_min, input_max] are scalar floats that specify the range for the float
-// interpretation of the 'input' data. For example, if input_min is -1.0f and
-// input_max is 1.0f, and we are dealing with quint16 quantized data, then a 0
-// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
+//
+// For example:
+//
+// ```python
+// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+//
+// tf.sparse_segment_sum_with_num_segments(
+// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3)
+// # => [[0 0 0 0]
+// # [0 0 0 0]
+// # [0 0 0 0]]
+//
+// tf.sparse_segment_sum_with_num_segments(c,
+// tf.constant([0, 1]),
+// tf.constant([0, 2],
+// num_segments=4))
+// # => [[ 1 2 3 4]
+// # [ 0 0 0 0]
+// # [-1 -2 -3 -4]
+// # [ 0 0 0 0]]
+// ```
//
// Arguments:
//
-// input_min: The float value that the minimum quantized input value represents.
-// input_max: The float value that the maximum quantized input value represents.
-// requested_output_min: The float value that the minimum quantized output value represents.
-// requested_output_max: The float value that the maximum quantized output value represents.
-// out_type: The type of the output. Should be a lower bit depth than Tinput.
+// indices: A 1-D tensor. Has same rank as `segment_ids`.
+// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+// num_segments: Should equal the number of distinct segment IDs.
//
-// Returns The requested_output_min value is copied into this output.The requested_output_max value is copied into this output.
-func Requantize(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, requested_output_min tf.Output, requested_output_max tf.Output, out_type tf.DataType) (output tf.Output, output_min tf.Output, output_max tf.Output) {
+// Returns Has same shape as data, except for dimension 0 which
+// has size `num_segments`.
+func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"out_type": out_type}
opspec := tf.OpSpec{
- Type: "Requantize",
+ Type: "SparseSegmentSumWithNumSegments",
Input: []tf.Input{
- input, input_min, input_max, requested_output_min, requested_output_max,
+ data, indices, segment_ids, num_segments,
},
- Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
+ return op.Output(0)
}
// Computes the determinant of one or more square matrices.
@@ -9229,6 +9258,66 @@ func RandomStandardNormal(scope *Scope, shape tf.Output, dtype tf.DataType, opti
return op.Output(0)
}
+// RandomUniformIntAttr is an optional argument to RandomUniformInt.
+type RandomUniformIntAttr func(optionalAttr)
+
+// RandomUniformIntSeed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func RandomUniformIntSeed(value int64) RandomUniformIntAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// RandomUniformIntSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomUniformIntSeed2(value int64) RandomUniformIntAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Outputs random integers from a uniform distribution.
+//
+// The generated values are uniform integers in the range `[minval, maxval)`.
+// The lower bound `minval` is included in the range, while the upper bound
+// `maxval` is excluded.
+//
+// The random integers are slightly biased unless `maxval - minval` is an exact
+// power of two. The bias is small for values of `maxval - minval` significantly
+// smaller than the range of the output (either `2^32` or `2^64`).
+//
+// Arguments:
+// shape: The shape of the output tensor.
+// minval: 0-D. Inclusive lower bound on the generated integers.
+// maxval: 0-D. Exclusive upper bound on the generated integers.
+//
+// Returns A tensor of the specified shape filled with uniform random integers.
+func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "RandomUniformInt",
+ Input: []tf.Input{
+ shape, minval, maxval,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl.
type ResourceApplyFtrlAttr func(optionalAttr)
@@ -11926,38 +12015,6 @@ func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...Fix
return op.Output(0)
}
-// The gradient operator for the SparseAdd op.
-//
-// The SparseAdd op calculates A + B, where A, B, and the sum are all represented
-// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t.
-// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty
-// values of A and B.
-//
-// Arguments:
-// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to
-// the non-empty values of the sum.
-// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`.
-// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`.
-// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size
-// `[nnz(sum), ndims]`.
-//
-// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the
-// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the
-// non-empty values of B.
-func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SparseAddGrad",
- Input: []tf.Input{
- backprop_val_grad, a_indices, b_indices, sum_indices,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
// String lengths of `input`.
//
// Computes the length of each string given in the input tensor.
@@ -12814,6 +12871,123 @@ func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) {
return op.Output(0)
}
+// ShapeAttr is an optional argument to Shape.
+type ShapeAttr func(optionalAttr)
+
+// ShapeOutType sets the optional out_type attribute to value.
+// If not specified, defaults to DT_INT32
+func ShapeOutType(value tf.DataType) ShapeAttr {
+ return func(m optionalAttr) {
+ m["out_type"] = value
+ }
+}
+
+// Returns the shape of a tensor.
+//
+// This operation returns a 1-D integer tensor representing the shape of `input`.
+//
+// For example:
+//
+// ```
+// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
+// shape(t) ==> [2, 2, 3]
+// ```
+func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Shape",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the power of one value to another.
+//
+// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
+// corresponding elements in `x` and `y`. For example:
+//
+// ```
+// # tensor 'x' is [[2, 2]], [3, 3]]
+// # tensor 'y' is [[8, 16], [2, 3]]
+// tf.pow(x, y) ==> [[256, 65536], [9, 27]]
+// ```
+func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Pow",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes fingerprints of the input strings.
+//
+// Arguments:
+// input: vector of strings to compute fingerprints on.
+//
+// Returns a (N,2) shaped matrix where N is the number of elements in the input
+// vector. Each row contains the low and high parts of the fingerprint.
+func SdcaFprint(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SdcaFprint",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// The gradient operator for the SparseAdd op.
+//
+// The SparseAdd op calculates A + B, where A, B, and the sum are all represented
+// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t.
+// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty
+// values of A and B.
+//
+// Arguments:
+// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to
+// the non-empty values of the sum.
+// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`.
+// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`.
+// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size
+// `[nnz(sum), ndims]`.
+//
+// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the
+// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the
+// non-empty values of B.
+func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseAddGrad",
+ Input: []tf.Input{
+ backprop_val_grad, a_indices, b_indices, sum_indices,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
// Computes the mean along segments of a tensor.
//
// Read
@@ -13006,6 +13180,79 @@ func InTopKV2(scope *Scope, predictions tf.Output, targets tf.Output, k tf.Outpu
return op.Output(0)
}
+// RandomPoissonV2Attr is an optional argument to RandomPoissonV2.
+type RandomPoissonV2Attr func(optionalAttr)
+
+// RandomPoissonV2Seed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func RandomPoissonV2Seed(value int64) RandomPoissonV2Attr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// RandomPoissonV2Seed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomPoissonV2Seed2(value int64) RandomPoissonV2Attr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// RandomPoissonV2Dtype sets the optional dtype attribute to value.
+// If not specified, defaults to DT_INT64
+func RandomPoissonV2Dtype(value tf.DataType) RandomPoissonV2Attr {
+ return func(m optionalAttr) {
+ m["dtype"] = value
+ }
+}
+
+// Outputs random values from the Poisson distribution(s) described by rate.
+//
+// This op uses two algorithms, depending on rate. If rate >= 10, then
+// the algorithm by Hormann is used to acquire samples via
+// transformation-rejection.
+// See http://www.sciencedirect.com/science/article/pii/0167668793909974.
+//
+// Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
+// random variables.
+// See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
+// Programming, Volume 2. Addison Wesley
+//
+// Arguments:
+// shape: 1-D integer tensor. Shape of independent samples to draw from each
+// distribution described by the shape parameters given in rate.
+// rate: A tensor in which each scalar is a "rate" parameter describing the
+// associated poisson distribution.
+//
+// Returns A tensor with shape `shape + shape(rate)`. Each slice
+// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
+// `rate[i0, i1, ...iN]`.
+func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonV2Attr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "RandomPoissonV2",
+ Input: []tf.Input{
+ shape, rate,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg.
type DecodeAndCropJpegAttr func(optionalAttr)
@@ -20288,164 +20535,6 @@ func SdcaOptimizer(scope *Scope, sparse_example_indices []tf.Output, sparse_feat
return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights
}
-// ShapeAttr is an optional argument to Shape.
-type ShapeAttr func(optionalAttr)
-
-// ShapeOutType sets the optional out_type attribute to value.
-// If not specified, defaults to DT_INT32
-func ShapeOutType(value tf.DataType) ShapeAttr {
- return func(m optionalAttr) {
- m["out_type"] = value
- }
-}
-
-// Returns the shape of a tensor.
-//
-// This operation returns a 1-D integer tensor representing the shape of `input`.
-//
-// For example:
-//
-// ```
-// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
-// shape(t) ==> [2, 2, 3]
-// ```
-func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Shape",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes the power of one value to another.
-//
-// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
-// corresponding elements in `x` and `y`. For example:
-//
-// ```
-// # tensor 'x' is [[2, 2]], [3, 3]]
-// # tensor 'y' is [[8, 16], [2, 3]]
-// tf.pow(x, y) ==> [[256, 65536], [9, 27]]
-// ```
-func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Pow",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes fingerprints of the input strings.
-//
-// Arguments:
-// input: vector of strings to compute fingerprints on.
-//
-// Returns a (N,2) shaped matrix where N is the number of elements in the input
-// vector. Each row contains the low and high parts of the fingerprint.
-func SdcaFprint(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SdcaFprint",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// RandomPoissonV2Attr is an optional argument to RandomPoissonV2.
-type RandomPoissonV2Attr func(optionalAttr)
-
-// RandomPoissonV2Seed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func RandomPoissonV2Seed(value int64) RandomPoissonV2Attr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// RandomPoissonV2Seed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomPoissonV2Seed2(value int64) RandomPoissonV2Attr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// RandomPoissonV2Dtype sets the optional dtype attribute to value.
-// If not specified, defaults to DT_INT64
-func RandomPoissonV2Dtype(value tf.DataType) RandomPoissonV2Attr {
- return func(m optionalAttr) {
- m["dtype"] = value
- }
-}
-
-// Outputs random values from the Poisson distribution(s) described by rate.
-//
-// This op uses two algorithms, depending on rate. If rate >= 10, then
-// the algorithm by Hormann is used to acquire samples via
-// transformation-rejection.
-// See http://www.sciencedirect.com/science/article/pii/0167668793909974.
-//
-// Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
-// random variables.
-// See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
-// Programming, Volume 2. Addison Wesley
-//
-// Arguments:
-// shape: 1-D integer tensor. Shape of independent samples to draw from each
-// distribution described by the shape parameters given in rate.
-// rate: A tensor in which each scalar is a "rate" parameter describing the
-// associated poisson distribution.
-//
-// Returns A tensor with shape `shape + shape(rate)`. Each slice
-// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
-// `rate[i0, i1, ...iN]`.
-func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonV2Attr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "RandomPoissonV2",
- Input: []tf.Input{
- shape, rate,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// MatrixTriangularSolveAttr is an optional argument to MatrixTriangularSolve.
type MatrixTriangularSolveAttr func(optionalAttr)
@@ -20959,66 +21048,6 @@ func UnsortedSegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output, nu
return op.Output(0)
}
-// RandomUniformIntAttr is an optional argument to RandomUniformInt.
-type RandomUniformIntAttr func(optionalAttr)
-
-// RandomUniformIntSeed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func RandomUniformIntSeed(value int64) RandomUniformIntAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// RandomUniformIntSeed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomUniformIntSeed2(value int64) RandomUniformIntAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Outputs random integers from a uniform distribution.
-//
-// The generated values are uniform integers in the range `[minval, maxval)`.
-// The lower bound `minval` is included in the range, while the upper bound
-// `maxval` is excluded.
-//
-// The random integers are slightly biased unless `maxval - minval` is an exact
-// power of two. The bias is small for values of `maxval - minval` significantly
-// smaller than the range of the output (either `2^32` or `2^64`).
-//
-// Arguments:
-// shape: The shape of the output tensor.
-// minval: 0-D. Inclusive lower bound on the generated integers.
-// maxval: 0-D. Exclusive upper bound on the generated integers.
-//
-// Returns A tensor of the specified shape filled with uniform random integers.
-func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "RandomUniformInt",
- Input: []tf.Input{
- shape, minval, maxval,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes the mean along sparse segments of a tensor.
//
// Read
@@ -28116,35 +28145,6 @@ func MakeIterator(scope *Scope, dataset tf.Output, iterator tf.Output) (o *tf.Op
return scope.AddOperation(opspec)
}
-// Makes the summary of accumulated stats for the batch.
-//
-// The summary stats contains gradients and hessians accumulated into the corresponding node and bucket for each example.
-//
-// Arguments:
-// node_ids: int32 Rank 1 Tensor containing node ids, which each example falls into for the requested layer.
-// gradients: float32; Rank 2 Tensor (shape=[#examples, 1]) for gradients.
-// hessians: float32; Rank 2 Tensor (shape=[#examples, 1]) for hessians.
-// bucketized_features_list: int32 list of Rank 1 Tensors, each containing the bucketized feature (for each feature column).
-// max_splits: int; the maximum number of splits possible in the whole tree.
-// num_buckets: int; equals to the maximum possible value of bucketized feature.
-//
-// Returns output Rank 4 Tensor (shape=[#features, #splits, #buckets, 2]) containing accumulated stats put into the corresponding node and bucket. The first index of 4th dimension refers to gradients, and the second to hessians.
-func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf.Output, hessians tf.Output, bucketized_features_list []tf.Output, max_splits int64, num_buckets int64) (stats_summary tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"max_splits": max_splits, "num_buckets": num_buckets}
- opspec := tf.OpSpec{
- Type: "BoostedTreesMakeStatsSummary",
- Input: []tf.Input{
- node_ids, gradients, hessians, tf.OutputList(bucketized_features_list),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Adjust the contrast of one or more images.
//
// `images` is a tensor of at least 3 dimensions. The last 3 dimensions are
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index d70e9c5798..9730e9933a 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2324,6 +2324,8 @@ py_library(
deps = [
":framework_for_generated_wrappers",
":logging_ops_gen",
+ ":platform",
+ ":string_ops",
":util",
],
)
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 39a2922ac0..ef7527d887 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -463,7 +463,7 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
}
// Override default py3 behavior of attempting to encode into Unicode.
-%typemap(out) std::string tensorflow::GetResourceHandleShapeAndType {
+%typemap(out) std::string tensorflow::GetHandleShapeAndType {
$result = PyBytes_FromStringAndSize($1.data(), $1.size());
}
@@ -782,7 +782,7 @@ def TF_Reset(target, containers=None, config=None):
%unignore TF_TryEvaluateConstant_wrapper;
%noexception TF_TryEvaluateConstant_wrapper;
%unignore ExtendSession;
-%unignore ResourceHandleShapeAndType;
+%unignore HandleShapeAndType;
%include "tensorflow/python/client/tf_session_helper.h"
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 8edd6419d3..419c376b45 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 19)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 20)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index c1bc27d443..a2686c68a9 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -34,6 +34,7 @@ cc_library(
"//tensorflow/python:safe_ptr",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
+ "@com_google_absl//absl/types:variant",
],
)
@@ -146,6 +147,7 @@ cuda_py_test(
"//tensorflow/python:clip_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:layers",
+ "//tensorflow/python:list_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:resource_variable_ops",
],
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 50a6ce6324..d95e0fe721 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -608,8 +608,9 @@ def _ones(shape, dtype):
_default_vspace = imperative_grad.VSpace(
num_elements_fn=_num_elements,
aggregate_fn=_aggregate_grads,
- zeros=_zeros,
- ones=_ones)
+ zeros_fn=_zeros,
+ ones_fn=_ones,
+ graph_shape_fn=gen_array_ops.shape)
pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index f938ed5df8..32731747b7 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -1022,6 +1022,18 @@ class BackpropTest(test.TestCase):
resource_variable_ops.ResourceVariable(2.0))
self.assertAllEqual(gradients_constants, gradients_variables)
+ def testUnknownShapes(self):
+ with context.graph_mode():
+ with backprop.GradientTape() as tape:
+ a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
+ tape.watch(a)
+ b = a**3
+
+ db_da = tape.gradient(b, a)
+
+ with self.cached_session() as sess:
+ self.assertEqual((8.0, 12.0), sess.run((b, db_da), feed_dict={a: 2.0}))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index a68c6ab3b4..bcb1881264 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -73,16 +73,36 @@ def _create_substitute_placeholder(value, name=None, dtype=None):
with ops.control_dependencies(None):
placeholder = graph_placeholder(
dtype=dtype or value.dtype, shape=value.shape, name=name)
- if placeholder.dtype == dtypes_module.resource:
- if isinstance(value, ops.EagerTensor):
- handle_data = value._handle_data # pylint: disable=protected-access
+ _copy_handle_data(value, placeholder)
+ return placeholder
+
+
+def _copy_handle_data(source_t, target_t):
+ """Copies HandleData for variant and resource type tensors if available.
+
+ The CppShapeInferenceResult::HandleData proto contains information about the
+ shapes and types of the element tensors of resource/variant type tensors.
+ We need to copy this across function boundaries, i.e., when capturing a
+ placeholder or when returning a function tensor as output. If we don't do this
+ the element tensors will have unknown shapes, e.g., if a TensorList variant
+ tensor is captured as a placeholder, elements popped from that list would have
+ unknown shape.
+
+ Args:
+ source_t: The tensor to copy HandleData from.
+ target_t: The tensor to copy HandleData to.
+ """
+ if (target_t.dtype == dtypes_module.resource or
+ target_t.dtype == dtypes_module.variant):
+ if isinstance(source_t, ops.EagerTensor):
+ handle_data = source_t._handle_data # pylint: disable=protected-access
else:
- handle_data = resource_variable_ops.get_resource_handle_data(value)
+ handle_data = resource_variable_ops.get_resource_handle_data(source_t)
if handle_data is not None and handle_data.is_set:
# pylint: disable=protected-access
- pywrap_tensorflow.SetResourceHandleShapeAndType(
- placeholder.graph._c_graph, placeholder._as_tf_output(),
- handle_data.SerializeToString())
+ pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
+ target_t._as_tf_output(),
+ handle_data.SerializeToString())
# pylint: enable=protected-access
# Ensure that shapes and dtypes are propagated.
shapes, types = zip(*[(pair.shape, pair.dtype)
@@ -91,12 +111,10 @@ def _create_substitute_placeholder(value, name=None, dtype=None):
shapes = [[d.size for d in s.dim]
if not s.unknown_rank else None for s in shapes]
pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
- placeholder._op._graph._c_graph, # pylint: disable=protected-access
- placeholder._as_tf_output(), # pylint: disable=protected-access
+ target_t._op._graph._c_graph, # pylint: disable=protected-access
+ target_t._as_tf_output(), # pylint: disable=protected-access
shapes, ranks, types)
- return placeholder
-
def _get_device_functions(ctx, graph):
"""Returns a tuple of device functions representing the device stack."""
@@ -435,6 +453,7 @@ class _EagerDefinedFunction(object):
self._num_outputs = len(self.signature.output_arg)
self._output_types = [o.type for o in self.signature.output_arg]
self._output_shapes = [o.shape for o in outputs]
+ self._func_graph_outputs = outputs
self.grad_func_name = None
self.python_grad_func = None
self._c_func = c_api_util.ScopedTFFunction(fn)
@@ -511,6 +530,8 @@ class _EagerDefinedFunction(object):
else:
for i, shape in enumerate(self._output_shapes):
outputs[i].set_shape(shape)
+ for i, func_graph_output in enumerate(self._func_graph_outputs):
+ _copy_handle_data(func_graph_output, outputs[i])
return outputs
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 4a1bde3f5e..e4513cc87c 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -48,6 +48,7 @@ from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
@@ -438,10 +439,17 @@ class FunctionTest(test.TestCase):
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
- self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+ self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
+ # We do not return v directly since the tensor conversion function of
+ # ResourceVariable returns the read value and not the resource itself.
+ return v._handle
compiled = function.defun(f)
- compiled()
+ var_handle = compiled()
+ self.assertEqual(var_handle.dtype, dtypes.resource)
+ self.assertEqual(var_handle.shape, tensor_shape.scalar())
+ var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
+ self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
def testVariableInLoopInFunction(self):
@@ -465,10 +473,17 @@ class FunctionTest(test.TestCase):
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
- self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+ self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
+ # We do not return v directly since the tensor conversion function of
+ # ResourceVariable returns the read value and not the resource itself.
+ return v._handle
compiled = function.defun(f)
- compiled()
+ var_handle = compiled()
+ self.assertEqual(var_handle.dtype, dtypes.resource)
+ self.assertEqual(var_handle.shape, tensor_shape.scalar())
+ var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
+ self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
def testDefunShapeInferenceWithCapturedVariableInGraphMode(self):
with context.graph_mode():
@@ -477,12 +492,34 @@ class FunctionTest(test.TestCase):
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
- self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+ self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
# Check that shape inference works while creating the defun
compiled = function.defun(f)
compiled()
+ def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self):
+ with context.graph_mode():
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32,
+ element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+ tensor_list = list_ops.tensor_list_push_back(tensor_list,
+ constant_op.constant(1.0))
+ tensor_list = list_ops.tensor_list_push_back(tensor_list,
+ constant_op.constant(2.0))
+
+ def f():
+ tl, value = list_ops.tensor_list_pop_back(
+ tensor_list, element_dtype=dtypes.float32)
+ self.assertEqual(value.shape, tensor_shape.scalar())
+ return tl
+
+ compiled = function.defun(f)
+ output_tensor_list = compiled()
+ _, value = list_ops.tensor_list_pop_back(
+ output_tensor_list, element_dtype=dtypes.float32)
+ self.assertEqual(value.shape, tensor_shape.scalar())
+
@test_util.run_in_graph_and_eager_modes
def testDefunForcesResourceVariables(self):
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 5f027d107c..5f5af4ab6c 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -23,8 +23,9 @@ import collections
from tensorflow.python import pywrap_tensorflow
-VSpace = collections.namedtuple(
- "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"])
+VSpace = collections.namedtuple("VSpace", [
+ "aggregate_fn", "num_elements_fn", "zeros_fn", "ones_fn", "graph_shape_fn"
+])
def imperative_grad(
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index a0f6be459e..196e20e4d7 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/python/eager/pywrap_tfe.h"
+#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
@@ -889,12 +890,239 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
return static_cast<tensorflow::DataType>(id);
}
+class PyTapeTensor {
+ public:
+ PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
+ const tensorflow::TensorShape& shape)
+ : id_(id), dtype_(dtype), shape_(shape) {}
+ PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
+ PyObject* shape)
+ : id_(id), dtype_(dtype), shape_(shape) {
+ Py_INCREF(absl::get<1>(shape_));
+ }
+ PyTapeTensor(const PyTapeTensor& other) {
+ id_ = other.id_;
+ dtype_ = other.dtype_;
+ shape_ = other.shape_;
+ if (shape_.index() == 1) {
+ Py_INCREF(absl::get<1>(shape_));
+ }
+ }
+
+ ~PyTapeTensor() {
+ if (shape_.index() == 1) {
+ Py_DECREF(absl::get<1>(shape_));
+ }
+ }
+ PyObject* GetShape() const;
+ PyObject* GetDType() const { return PyLong_FromLong(dtype_); }
+ tensorflow::int64 GetID() const { return id_; }
+
+ private:
+ tensorflow::int64 id_;
+ tensorflow::DataType dtype_;
+ absl::variant<tensorflow::TensorShape, PyObject*> shape_;
+};
+
+class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
+ PyTapeTensor> {
+ public:
+ explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
+ Py_INCREF(py_vspace_);
+ }
+
+ tensorflow::Status Initialize() {
+ num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
+ if (num_elements_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
+ if (aggregate_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
+ if (zeros_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
+ if (ones_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
+ if (graph_shape_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ return tensorflow::Status::OK();
+ }
+
+ ~PyVSpace() override {
+ Py_XDECREF(num_elements_);
+ Py_XDECREF(aggregate_fn_);
+ Py_XDECREF(zeros_fn_);
+ Py_XDECREF(ones_fn_);
+ Py_XDECREF(graph_shape_fn_);
+
+ Py_DECREF(py_vspace_);
+ }
+
+ tensorflow::int64 NumElements(PyObject* tensor) const final {
+ if (EagerTensor_CheckExact(tensor)) {
+ return PyEagerTensor_NumElements(tensor);
+ }
+ PyObject* arglist =
+ Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
+ PyObject* result = PyEval_CallObject(num_elements_, arglist);
+ Py_DECREF(arglist);
+ if (result == nullptr) {
+ // The caller detects whether a python exception has been raised.
+ return -1;
+ }
+ tensorflow::int64 r = MakeInt(result);
+ Py_DECREF(result);
+ return r;
+ }
+
+ PyObject* AggregateGradients(
+ tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
+ PyObject* list = PyList_New(gradient_tensors.size());
+ for (int i = 0; i < gradient_tensors.size(); ++i) {
+ // Note: stealing a reference to the gradient tensors.
+ CHECK(gradient_tensors[i] != nullptr);
+ CHECK(gradient_tensors[i] != Py_None);
+ PyList_SET_ITEM(list, i,
+ reinterpret_cast<PyObject*>(gradient_tensors[i]));
+ }
+ PyObject* arglist = Py_BuildValue("(O)", list);
+ CHECK(arglist != nullptr);
+ PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
+ Py_DECREF(arglist);
+ Py_DECREF(list);
+ return result;
+ }
+
+ void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
+
+ PyObject* Zeros(const PyTapeTensor& tensor) const final {
+ PyObject* py_shape = tensor.GetShape();
+ PyObject* py_dtype = tensor.GetDType();
+ PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+ PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
+ Py_DECREF(arg_list);
+ Py_DECREF(py_dtype);
+ Py_DECREF(py_shape);
+ return reinterpret_cast<PyObject*>(result);
+ }
+
+ PyObject* Ones(const PyTapeTensor& tensor) const final {
+ PyObject* py_shape = tensor.GetShape();
+ PyObject* py_dtype = tensor.GetDType();
+ PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+ PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
+ Py_DECREF(arg_list);
+ Py_DECREF(py_dtype);
+ Py_DECREF(py_shape);
+ return result;
+ }
+
+ PyObject* GraphShape(PyObject* tensor) const {
+ PyObject* arg_list = Py_BuildValue("(O)", tensor);
+ PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
+ Py_DECREF(arg_list);
+ return result;
+ }
+
+ tensorflow::Status CallBackwardFunction(
+ PyBackwardFunction* backward_function,
+ tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
+ std::vector<PyObject*>* result) const final {
+ PyObject* grads = PyTuple_New(output_gradients.size());
+ for (int i = 0; i < output_gradients.size(); ++i) {
+ if (output_gradients[i] == nullptr) {
+ Py_INCREF(Py_None);
+ PyTuple_SET_ITEM(grads, i, Py_None);
+ } else {
+ PyTuple_SET_ITEM(grads, i,
+ reinterpret_cast<PyObject*>(output_gradients[i]));
+ }
+ }
+ PyObject* py_result = (*backward_function)(grads);
+ Py_DECREF(grads);
+ if (py_result == nullptr) {
+ return tensorflow::errors::Internal("gradient function threw exceptions");
+ }
+ result->clear();
+ PyObject* seq =
+ PySequence_Fast(py_result, "expected a sequence of gradients");
+ if (seq == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "gradient function did not return a list");
+ }
+ int len = PySequence_Fast_GET_SIZE(seq);
+ VLOG(1) << "Gradient length is " << len;
+ result->reserve(len);
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
+ if (item == Py_None) {
+ result->push_back(nullptr);
+ } else {
+ Py_INCREF(item);
+ result->push_back(item);
+ }
+ }
+ Py_DECREF(seq);
+ Py_DECREF(py_result);
+ return tensorflow::Status::OK();
+ }
+
+ void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
+
+ private:
+ PyObject* py_vspace_;
+
+ PyObject* num_elements_;
+ PyObject* aggregate_fn_;
+ PyObject* zeros_fn_;
+ PyObject* ones_fn_;
+ PyObject* graph_shape_fn_;
+};
+PyVSpace* py_vspace = nullptr;
+
+PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
+ if (py_vspace != nullptr) {
+ delete py_vspace;
+ }
+
+ py_vspace = new PyVSpace(e);
+ auto status = py_vspace->Initialize();
+ if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
+ delete py_vspace;
+ return nullptr;
+ }
+
+ Py_RETURN_NONE;
+}
+
+PyObject* PyTapeTensor::GetShape() const {
+ if (shape_.index() == 0) {
+ auto& shape = absl::get<0>(shape_);
+ PyObject* py_shape = PyTuple_New(shape.dims());
+ for (int i = 0; i < shape.dims(); ++i) {
+ PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
+ }
+
+ return py_shape;
+ }
+
+ return py_vspace->GraphShape(absl::get<1>(shape_));
+}
+
class GradientTape
- : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
+ : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+ PyTapeTensor> {
public:
explicit GradientTape(bool persistent, bool watch_accessed_variables)
- : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>(
- persistent),
+ : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+ PyTapeTensor>(persistent),
watch_accessed_variables_(watch_accessed_variables) {}
virtual ~GradientTape() {
@@ -1175,7 +1403,24 @@ void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
}
-static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
+bool ListContainsNone(PyObject* list) {
+ if (list == Py_None) return true;
+ tensorflow::Safe_PyObjectPtr seq(
+ PySequence_Fast(list, "expected a sequence"));
+ if (seq == nullptr) {
+ return false;
+ }
+
+ int len = PySequence_Size(list);
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
+ if (item == Py_None) return true;
+ }
+
+ return false;
+}
+
+static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
TFE_TensorHandle* t = EagerTensor_Handle(tensor);
tensorflow::int64 id = PyEagerTensor_ID(tensor);
@@ -1183,16 +1428,16 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
const tensorflow::Status status = t->handle->Shape(&tensor_shape);
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
- return tensorflow::eager::TapeTensor{id, t->handle->dtype,
- tensorflow::TensorShape({})};
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
} else {
- return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensor_shape};
+ return PyTapeTensor(id, t->handle->dtype, tensor_shape);
}
}
tensorflow::int64 id = FastTensorId(tensor);
if (PyErr_Occurred()) {
- return tensorflow::eager::TapeTensor{
- id, static_cast<tensorflow::DataType>(0), tensorflow::TensorShape({})};
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
@@ -1200,16 +1445,21 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
tensorflow::DataType dtype =
static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
Py_DECREF(dtype_enum);
- if (PyErr_Occurred() != nullptr) {
- return tensorflow::eager::TapeTensor{id, dtype,
- tensorflow::TensorShape({})};
+ if (PyErr_Occurred()) {
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
static char _shape_tuple[] = "_shape_tuple";
PyObject* shape_tuple = PyObject_CallMethod(tensor, _shape_tuple, nullptr);
- if (PyErr_Occurred() != nullptr) {
- return tensorflow::eager::TapeTensor{id, dtype,
- tensorflow::TensorShape({})};
+ if (PyErr_Occurred()) {
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
+
+ if (ListContainsNone(shape_tuple)) {
+ return PyTapeTensor(id, dtype, tensor);
+ }
+
auto l = MakeIntList(shape_tuple);
Py_DECREF(shape_tuple);
// Replace -1, which represents accidental Nones which can occur in graph mode
@@ -1220,7 +1470,7 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
}
}
tensorflow::TensorShape shape(l);
- return tensorflow::eager::TapeTensor{id, dtype, shape};
+ return PyTapeTensor(id, dtype, shape);
}
std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
@@ -1286,7 +1536,7 @@ void TapeSetRecordOperation(
const std::vector<tensorflow::DataType>& input_dtypes,
const std::function<PyBackwardFunction*()>& backward_function_getter,
const std::function<void(PyBackwardFunction*)>& backward_function_killer) {
- std::vector<tensorflow::eager::TapeTensor> output_info;
+ std::vector<PyTapeTensor> output_info;
PyObject* seq = PySequence_Fast(output_tensors,
"expected a sequence of integer tensor ids");
int len = PySequence_Size(output_tensors);
@@ -1362,180 +1612,6 @@ void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
}
}
-class PyVSpace
- : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction> {
- public:
- explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
- Py_INCREF(py_vspace_);
- }
-
- tensorflow::Status Initialize() {
- num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
- if (num_elements_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
- if (aggregate_fn_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- zeros_ = PyObject_GetAttrString(py_vspace_, "zeros");
- if (zeros_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- ones_ =
- PyObject_GetAttrString(reinterpret_cast<PyObject*>(py_vspace_), "ones");
- if (ones_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- return tensorflow::Status::OK();
- }
-
- ~PyVSpace() override {
- Py_XDECREF(num_elements_);
- Py_XDECREF(aggregate_fn_);
- Py_XDECREF(zeros_);
- Py_XDECREF(ones_);
-
- Py_DECREF(py_vspace_);
- }
-
- tensorflow::int64 NumElements(PyObject* tensor) const final {
- if (EagerTensor_CheckExact(tensor)) {
- return PyEagerTensor_NumElements(tensor);
- }
- PyObject* arglist =
- Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
- PyObject* result = PyEval_CallObject(num_elements_, arglist);
- Py_DECREF(arglist);
- if (result == nullptr) {
- // The caller detects whether a python exception has been raised.
- return -1;
- }
- tensorflow::int64 r = MakeInt(result);
- Py_DECREF(result);
- return r;
- }
-
- PyObject* AggregateGradients(
- tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
- PyObject* list = PyList_New(gradient_tensors.size());
- for (int i = 0; i < gradient_tensors.size(); ++i) {
- // Note: stealing a reference to the gradient tensors.
- CHECK(gradient_tensors[i] != nullptr);
- CHECK(gradient_tensors[i] != Py_None);
- PyList_SET_ITEM(list, i,
- reinterpret_cast<PyObject*>(gradient_tensors[i]));
- }
- PyObject* arglist = Py_BuildValue("(O)", list);
- CHECK(arglist != nullptr);
- PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
- Py_DECREF(arglist);
- Py_DECREF(list);
- return result;
- }
-
- void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
-
- PyObject* Zeros(tensorflow::TensorShape shape,
- tensorflow::DataType dtype) const final {
- PyObject* py_shape = PyTuple_New(shape.dims());
- for (int i = 0; i < shape.dims(); ++i) {
- PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
- }
- PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
- PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
- PyObject* result = PyEval_CallObject(zeros_, arg_list);
- Py_DECREF(arg_list);
- Py_DECREF(py_dtype);
- Py_DECREF(py_shape);
- return reinterpret_cast<PyObject*>(result);
- }
-
- PyObject* Ones(tensorflow::TensorShape shape,
- tensorflow::DataType dtype) const final {
- PyObject* py_shape = PyTuple_New(shape.dims());
- for (int i = 0; i < shape.dims(); ++i) {
- PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
- }
- PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
- PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
- PyObject* result = PyEval_CallObject(ones_, arg_list);
- Py_DECREF(arg_list);
- Py_DECREF(py_dtype);
- Py_DECREF(py_shape);
- return result;
- }
-
- tensorflow::Status CallBackwardFunction(
- PyBackwardFunction* backward_function,
- tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
- std::vector<PyObject*>* result) const final {
- PyObject* grads = PyTuple_New(output_gradients.size());
- for (int i = 0; i < output_gradients.size(); ++i) {
- if (output_gradients[i] == nullptr) {
- Py_INCREF(Py_None);
- PyTuple_SET_ITEM(grads, i, Py_None);
- } else {
- PyTuple_SET_ITEM(grads, i,
- reinterpret_cast<PyObject*>(output_gradients[i]));
- }
- }
- PyObject* py_result = (*backward_function)(grads);
- Py_DECREF(grads);
- if (py_result == nullptr) {
- return tensorflow::errors::Internal("gradient function threw exceptions");
- }
- result->clear();
- PyObject* seq =
- PySequence_Fast(py_result, "expected a sequence of gradients");
- if (seq == nullptr) {
- return tensorflow::errors::InvalidArgument(
- "gradient function did not return a list");
- }
- int len = PySequence_Fast_GET_SIZE(seq);
- VLOG(1) << "Gradient length is " << len;
- result->reserve(len);
- for (int i = 0; i < len; ++i) {
- PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
- if (item == Py_None) {
- result->push_back(nullptr);
- } else {
- Py_INCREF(item);
- result->push_back(item);
- }
- }
- Py_DECREF(seq);
- Py_DECREF(py_result);
- return tensorflow::Status::OK();
- }
-
- void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
-
- private:
- PyObject* py_vspace_;
-
- PyObject* num_elements_;
- PyObject* aggregate_fn_;
- PyObject* zeros_;
- PyObject* ones_;
-};
-PyVSpace* py_vspace = nullptr;
-
-PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
- if (py_vspace != nullptr) {
- delete py_vspace;
- }
-
- py_vspace = new PyVSpace(e);
- auto status = py_vspace->Initialize();
- if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
- delete py_vspace;
- return nullptr;
- }
-
- Py_RETURN_NONE;
-}
-
std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
if (seq == nullptr) {
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 36048a2bfd..756d32d03f 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -422,9 +422,13 @@ class _EnsembleGrower(object):
self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str(
tree_hparams.pruning_mode)
- if (self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING
- and tree_hparams.tree_complexity <= 0):
- raise ValueError('For pruning, tree_complexity must be positive.')
+ if tree_hparams.tree_complexity > 0:
+ if self._pruning_mode_parsed == boosted_trees_ops.PruningMode.NO_PRUNING:
+ raise ValueError(
+ 'Tree complexity have no effect unless pruning mode is chosen.')
+ else:
+ if self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING:
+ raise ValueError('For pruning, tree_complexity must be positive.')
# pylint: enable=protected-access
@abc.abstractmethod
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 9409cb5cc7..d4cb3e27d0 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -564,6 +564,41 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
self.assertEqual(1, ensemble.trees[0].nodes[0].bucketized_split.feature_id)
self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold)
+ def testTreeComplexityIsSetCorrectly(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ num_steps = 10
+ # Tree complexity is set but no pruning.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ tree_complexity=1e-3)
+ with self.assertRaisesRegexp(ValueError, 'Tree complexity have no effect'):
+ est.train(input_fn, steps=num_steps)
+
+ # Pruning but no tree complexity.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ pruning_mode='pre')
+ with self.assertRaisesRegexp(ValueError,
+ 'tree_complexity must be positive'):
+ est.train(input_fn, steps=num_steps)
+
+ # All is good.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ pruning_mode='pre',
+ tree_complexity=1e-3)
+ est.train(input_fn, steps=num_steps)
+
class BoostedTreesDebugOutputsTest(test_util.TensorFlowTestCase):
"""Test debug/model explainability outputs for individual predictions.
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index a8aef3a009..68b3170dfe 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -762,13 +762,12 @@ class _FuncGraph(ops.Graph):
if handle_data:
handle_data = handle_data.SerializeToString()
else:
- handle_data = c_api.GetResourceHandleShapeAndType(
- tensor.graph._c_graph, tensor._as_tf_output())
+ handle_data = c_api.GetHandleShapeAndType(tensor.graph._c_graph,
+ tensor._as_tf_output())
if handle_data:
- c_api.SetResourceHandleShapeAndType(ph.graph._c_graph,
- ph._as_tf_output(),
- compat.as_bytes(handle_data))
+ c_api.SetHandleShapeAndType(ph.graph._c_graph, ph._as_tf_output(),
+ compat.as_bytes(handle_data))
else:
ph._handle_data = tensor._handle_data
# pylint: enable=protected-access
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 343f52fe8f..8bb177939e 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -2532,8 +2532,8 @@ def _set_shape_and_handle_data_for_outputs_c_api(op):
output._shape_val = output._c_api_shape()
# Set the resource handle data for compatibility with the Python shape
# inference code.
- serialized = c_api.GetResourceHandleShapeAndType(op._graph._c_graph,
- output._as_tf_output())
+ serialized = c_api.GetHandleShapeAndType(op._graph._c_graph, # pylint: disable=protected-access
+ output._as_tf_output())
if serialized:
output._handle_data = (
cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index b7398238f5..c302072aa1 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -24,6 +24,7 @@ from collections import OrderedDict
import contextlib
import gc
import itertools
+import os
import math
import random
import re
@@ -868,6 +869,19 @@ def device(use_gpu):
yield
+class CapturedWrites(object):
+ """A utility class to load the captured writes made to a stream."""
+
+ def __init__(self, capture_location):
+ self.capture_location = capture_location
+
+ def contents(self):
+ """Get the captured writes as a single string."""
+ with open(self.capture_location) as tmp_file:
+ output_data = "".join(tmp_file.readlines())
+ return output_data
+
+
class ErrorLoggingSession(session.Session):
"""Wrapper around a Session that logs errors in run().
"""
@@ -934,6 +948,52 @@ class TensorFlowTestCase(googletest.TestCase):
self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir
+ @contextlib.contextmanager
+ def captureWritesToStream(self, stream):
+ """A context manager that captures the writes to a given stream.
+
+ This context manager captures all writes to a given stream inside of a
+ `CapturedWrites` object. When this context manager is created, it yields
+ the `CapturedWrites` object. The captured contents can be accessed by
+ calling `.contents()` on the `CapturedWrites`.
+
+ For this function to work, the stream must have a file descriptor that
+ can be modified using `os.dup` and `os.dup2`, and the stream must support
+ a `.flush()` method. The default python sys.stdout and sys.stderr are
+ examples of this. Note that this does not work in Colab or Jupyter
+ notebooks, because those use alternate stdout streams.
+
+ Example:
+ ```python
+ class MyOperatorTest(test_util.TensorFlowTestCase):
+ def testMyOperator(self):
+ input = [1.0, 2.0, 3.0, 4.0, 5.0]
+ with self.captureWritesToStream(sys.stdout) as captured:
+ result = MyOperator(input).eval()
+ self.assertStartsWith(captured.contents(), "This was printed.")
+ ```
+
+ Args:
+ stream: The stream whose writes should be captured. This
+ stream must have a file descriptor, support writing via using that
+ file descriptor, and must have a `.flush()` method.
+
+ Yields:
+ A `CapturedWrites` object that contains all writes to the specified stream
+ made during this context.
+ """
+ stream.flush()
+ fd = stream.fileno()
+ tmp_file_path = tempfile.mktemp(dir=self.get_temp_dir())
+ tmp_file = open(tmp_file_path, "w")
+ orig_fd = os.dup(fd)
+ os.dup2(tmp_file.fileno(), fd)
+ try:
+ yield CapturedWrites(tmp_file_path)
+ finally:
+ tmp_file.close()
+ os.dup2(orig_fd, fd)
+
def _AssertProtoEquals(self, a, b, msg=None):
"""Asserts that a and b are the same proto.
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index befe82f4ec..6dfbbf3694 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -360,7 +360,10 @@ class BaseLogger(Callback):
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', 0)
- self.seen += batch_size
+ # In case of distribution strategy we can potentially run multiple steps
+ # at the same time, we should account for that in the `seen` calculation.
+ num_steps = logs.get('num_steps', 1)
+ self.seen += batch_size * num_steps
for k, v in logs.items():
if k in self.stateful_metrics:
@@ -448,10 +451,13 @@ class ProgbarLogger(Callback):
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', 0)
+ # In case of distribution strategy we can potentially run multiple steps
+ # at the same time, we should account for that in the `seen` calculation.
+ num_steps = logs.get('num_steps', 1)
if self.use_steps:
- self.seen += 1
+ self.seen += num_steps
else:
- self.seen += batch_size
+ self.seen += batch_size * num_steps
for k in self.params['metrics']:
if k in logs:
@@ -1068,7 +1074,7 @@ class TensorBoard(Callback):
logs = logs or {}
batch_logs = {('batch_' + k): v
for k, v in logs.items()
- if k not in ['batch', 'size']}
+ if k not in ['batch', 'size', 'num_steps']}
self._write_custom_summaries(self._total_batches_seen, batch_logs)
self._total_batches_seen += 1
@@ -1092,7 +1098,7 @@ class TensorBoard(Callback):
# batch number as Tensorboard summaries
logs = {('epoch_' + k): v
for k, v in logs.items()
- if k not in ['batch', 'size']}
+ if k not in ['batch', 'size', 'num_steps']}
self._write_custom_summaries(epoch, logs)
# pop the histogram summary op after each epoch
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index d133595793..26c5ec4efc 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -293,11 +293,16 @@ def _experimental_fit_loop(
for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+ if steps_per_epoch is None:
+ raise ValueError('steps_per_epoch should be specified in the fit call.')
+ steps_per_run_var = K.variable(
+ value=min(steps_per_epoch, current_strategy.steps_per_run),
+ dtype='int32',
+ name='steps_per_run_var')
+
with current_strategy.scope():
- # TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
- # steps_per_epoch and number of epochs.
ctx = current_strategy.run_steps_on_dataset(
- step_fn, iterator, iterations=current_strategy.steps_per_run,
+ step_fn, iterator, iterations=steps_per_run_var,
initial_loop_values=initial_loop_values)
train_op = ctx.run_op
@@ -309,14 +314,6 @@ def _experimental_fit_loop(
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
-
- assert steps_per_epoch is not None
-
- # TODO(sourabhbajaj): Convert this into a proper validation function
- if callbacks:
- raise NotImplementedError(
- 'Callbacks are not supported with TPUStrategy right now.')
-
callbacks = cbks.configure_callbacks(
callbacks,
model,
@@ -327,17 +324,26 @@ def _experimental_fit_loop(
steps_per_epoch=steps_per_epoch,
verbose=verbose)
# TODO(priyag, sourabhbajaj): Add callbacks support for per step callback
- # TODO(priyag, sourabhbajaj): Fix the number of steps run with steps_per_run
# TODO(priyag, sourabhbajaj): Add validation.
+
+ # Calculate the steps each time on the device.
+ steps_to_run = [current_strategy.steps_per_run] * (
+ steps_per_epoch // current_strategy.steps_per_run)
+ if steps_per_epoch % current_strategy.steps_per_run:
+ steps_to_run.append(steps_per_epoch % current_strategy.steps_per_run)
+
callbacks.on_train_begin()
for epoch in range(initial_epoch, epochs):
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
- for step_index in range(0, steps_per_epoch, current_strategy.steps_per_run):
- # TODO(sourabhbajaj): Replace size with a combination of steps_per_run
- # and batch_size
- batch_logs = {'batch': step_index, 'size': 1}
+ step_index = 0
+ prev_step_count = None
+ for step_count in steps_to_run:
+ batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
callbacks.on_batch_begin(step_index, batch_logs)
+ if prev_step_count is None or step_count != prev_step_count:
+ steps_per_run_var.load(step_count, K.get_session())
+ prev_step_count = step_count
try:
_, outputs = K.get_session().run([train_op, output_tensors])
except errors.OutOfRangeError:
@@ -350,6 +356,7 @@ def _experimental_fit_loop(
batch_logs.update(outputs)
callbacks.on_batch_end(step_index, batch_logs)
+ step_index = step_index + step_count
if callbacks.model.stop_training:
break
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index a048eaa69f..9dc6df77f1 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -961,6 +961,19 @@ tf_py_test(
)
tf_py_test(
+ name = "string_format_op_test",
+ size = "small",
+ srcs = ["string_format_op_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:math_ops",
+ ],
+)
+
+tf_py_test(
name = "string_join_op_test",
size = "small",
srcs = ["string_join_op_test.py"],
diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py
index 82729b9e27..79fe9de62f 100644
--- a/tensorflow/python/kernel_tests/logging_ops_test.py
+++ b/tensorflow/python/kernel_tests/logging_ops_test.py
@@ -18,14 +18,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import sys
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
class LoggingOpsTest(test.TestCase):
@@ -57,6 +66,305 @@ class LoggingOpsTest(test.TestCase):
out.eval()
+class PrintV2Test(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensor(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor)
+ self.evaluate(print_op)
+
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorVarySummarize(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=1)
+ self.evaluate(print_op)
+
+ expected = "[0 ... 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=2)
+ self.evaluate(print_op)
+
+ expected = "[0 1 ... 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=3)
+ self.evaluate(print_op)
+
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=-1)
+ self.evaluate(print_op)
+
+ expected = "[0 1 2 3 4 5 6 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneVariable(self):
+ with self.test_session():
+ var = variables.Variable(math_ops.range(10))
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(var)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintTwoVariablesInStructWithAssignAdd(self):
+ with self.test_session():
+ var_one = variables.Variable(2.14)
+ plus_one = var_one.assign_add(1.0)
+ var_two = variables.Variable(math_ops.range(10))
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ with self.captureWritesToStream(sys.stderr) as printed:
+ self.evaluate(plus_one)
+ print_op = logging_ops.print_v2(var_one, {"second": var_two})
+ self.evaluate(print_op)
+ expected = "3.14 {'second': [0 1 2 ... 7 8 9]}"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintTwoTensors(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, tensor * 10)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9] [0 10 20 ... 70 80 90]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintPlaceholderGeneration(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2("{}6", {"{}": tensor * 10})
+ self.evaluate(print_op)
+ expected = "{}6 {'{}': [0 10 20 ... 70 80 90]}"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintNoTensors(self):
+ with self.test_session():
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(23, [23, 5], {"6": 12})
+ self.evaluate(print_op)
+ expected = "23 [23, 5] {'6': 12}"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintFloatScalar(self):
+ with self.test_session():
+ tensor = ops.convert_to_tensor(434.43)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor)
+ self.evaluate(print_op)
+ expected = "434.43"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintStringScalar(self):
+ with self.test_session():
+ tensor = ops.convert_to_tensor("scalar")
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor)
+ self.evaluate(print_op)
+ expected = "scalar"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintComplexTensorStruct(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ small_tensor = constant_op.constant([0.3, 12.4, -16.1])
+ big_tensor = math_ops.mul(tensor, 10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ "first:", tensor, "middle:",
+ {"small": small_tensor, "Big": big_tensor}, 10,
+ [tensor * 2, tensor])
+ self.evaluate(print_op)
+ # Note that the keys in the dict will always be sorted,
+ # so 'Big' comes before 'small'
+ expected = ("first: [0 1 2 ... 7 8 9] "
+ "middle: {'Big': [0 10 20 ... 70 80 90], "
+ "'small': [0.3 12.4 -16.1]} "
+ "10 [[0 2 4 ... 14 16 18], [0 1 2 ... 7 8 9]]")
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintSparseTensor(self):
+ with self.test_session():
+ ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
+ val = [0, 10, 13, 4, 14, 32, 33]
+ shape = [5, 6]
+
+ sparse = sparse_tensor.SparseTensor(
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.int64),
+ constant_op.constant(shape, dtypes.int64))
+
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(sparse)
+ self.evaluate(print_op)
+ expected = ("'SparseTensor(indices=[[0 0]\n"
+ " [1 0]\n"
+ " [1 3]\n"
+ " ...\n"
+ " [1 4]\n"
+ " [3 2]\n"
+ " [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])'")
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintSparseTensorInDataStruct(self):
+ with self.test_session():
+ ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
+ val = [0, 10, 13, 4, 14, 32, 33]
+ shape = [5, 6]
+
+ sparse = sparse_tensor.SparseTensor(
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.int64),
+ constant_op.constant(shape, dtypes.int64))
+
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2([sparse])
+ self.evaluate(print_op)
+ expected = ("['SparseTensor(indices=[[0 0]\n"
+ " [1 0]\n"
+ " [1 3]\n"
+ " ...\n"
+ " [1 4]\n"
+ " [3 2]\n"
+ " [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])']")
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorStdout(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stdout) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=sys.stdout)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogInfo(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.info)
+ self.evaluate(print_op)
+ self.assertTrue("I" in printed.contents())
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogWarning(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.warning)
+ self.evaluate(print_op)
+ self.assertTrue("W" in printed.contents())
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogError(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.error)
+ self.evaluate(print_op)
+ self.assertTrue("E" in printed.contents())
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testInvalidOutputStreamRaisesError(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.assertRaises(ValueError):
+ print_op = logging_ops.print_v2(
+ tensor, output_stream="unknown")
+ self.evaluate(print_op)
+
+ def testPrintOpName(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ print_op = logging_ops.print_v2(tensor, name="print_name")
+ self.assertEqual(print_op.name, "print_name")
+
+ def testNoDuplicateFormatOpGraphModeAfterExplicitFormat(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ formatted_string = string_ops.string_format("{}", tensor)
+ print_op = logging_ops.print_v2(formatted_string)
+ self.evaluate(print_op)
+ graph_ops = ops.get_default_graph().get_operations()
+ format_ops = [op for op in graph_ops if op.type == "StringFormat"]
+ # Should be only 1 format_op for graph mode.
+ self.assertEqual(len(format_ops), 1)
+
+ def testPrintOneTensorEagerOnOpCreate(self):
+ with self.test_session():
+ with context.eager_mode():
+ tensor = math_ops.range(10)
+ expected = "[0 1 2 ... 7 8 9]"
+ with self.captureWritesToStream(sys.stderr) as printed:
+ logging_ops.print_v2(tensor)
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintInDefunWithoutExplicitEvalOfPrint(self):
+ @function.defun
+ def f():
+ tensor = math_ops.range(10)
+ logging_ops.print_v2(tensor)
+ return tensor
+
+ expected = "[0 1 2 ... 7 8 9]"
+ with self.captureWritesToStream(sys.stderr) as printed_one:
+ x = f()
+ self.evaluate(x)
+ self.assertTrue((expected + "\n") in printed_one.contents())
+
+ # We execute the function again to make sure it doesn't only print on the
+ # first call.
+ with self.captureWritesToStream(sys.stderr) as printed_two:
+ y = f()
+ self.evaluate(y)
+ self.assertTrue((expected + "\n") in printed_two.contents())
+
+
class PrintGradientTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
@@ -65,6 +373,11 @@ class PrintGradientTest(test.TestCase):
inp_printed = logging_ops.Print(inp, [inp])
self.assertEqual(inp.get_shape(), inp_printed.get_shape())
+ def testPrintString(self):
+ inp = constant_op.constant(2.0, shape=[100, 32])
+ inp_printed = logging_ops.Print(inp, ["hello"])
+ self.assertEqual(inp.get_shape(), inp_printed.get_shape())
+
def testPrintGradient(self):
with self.cached_session():
inp = constant_op.constant(2.0, shape=[100, 32], name="in")
diff --git a/tensorflow/python/kernel_tests/string_format_op_test.py b/tensorflow/python/kernel_tests/string_format_op_test.py
new file mode 100644
index 0000000000..afa71db909
--- /dev/null
+++ b/tensorflow/python/kernel_tests/string_format_op_test.py
@@ -0,0 +1,384 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.kernels.logging_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class StringFormatOpTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDim(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.test_session():
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{}", [tensor])
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneVariableScalar(self):
+ with self.test_session():
+ var = variables.Variable(3.34)
+ format_output = string_ops.string_format("{}", [var])
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ out = self.evaluate(format_output)
+ expected = "3.34"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneVariableOneDim(self):
+ with self.test_session():
+ var = variables.Variable(math_ops.range(10))
+ format_output = string_ops.string_format("{}", [var])
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatTwoVariablesWithAssignAdd(self):
+ with self.test_session():
+ var_one = variables.Variable(2.14)
+ plus_one = var_one.assign_add(1.0)
+ var_two = variables.Variable(math_ops.range(10))
+ format_output = string_ops.string_format("{}, {}", [var_one, var_two])
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ self.evaluate(plus_one)
+ out = self.evaluate(format_output)
+ expected = "3.14, [0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimFloat(self):
+ with self.test_session():
+ tensor = constant_op.constant([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = "[0 0.1 0.2 ... 0.5 0.6 0.7]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimMatchesSummarize(self):
+ with self.test_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=3)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimVarySummarize(self):
+ with self.test_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=-1)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.test_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=1)
+ out = self.evaluate(format_output)
+ expected = "[0 ... 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.test_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=2)
+ out = self.evaluate(format_output)
+ expected = "[0 1 ... 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.test_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=10)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimAlmostSummarize(self):
+ with self.test_session():
+ tensor = math_ops.range(5)
+ format_output = string_ops.string_format("{}", tensor, summarize=3)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTwoDimLessThanSummarize(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(4), [2, 2])
+ format_output = string_ops.string_format("{}", tensor, summarize=3)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1]\n"
+ " [2 3]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTwoDim(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTwoDimSummarizeTwo(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("{}", tensor, summarize=2)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1 ... 8 9]\n"
+ " [10 11 ... 18 19]\n"
+ " ...\n"
+ " [80 81 ... 88 89]\n"
+ " [90 91 ... 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorThreeDim(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(1000), [10, 10, 10])
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = ("[[[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]\n"
+ "\n"
+ " [[100 101 102 ... 107 108 109]\n"
+ " [110 111 112 ... 117 118 119]\n"
+ " [120 121 122 ... 127 128 129]\n"
+ " ...\n [170 171 172 ... 177 178 179]\n"
+ " [180 181 182 ... 187 188 189]\n"
+ " [190 191 192 ... 197 198 199]]\n"
+ "\n"
+ " [[200 201 202 ... 207 208 209]\n"
+ " [210 211 212 ... 217 218 219]\n"
+ " [220 221 222 ... 227 228 229]\n"
+ " ...\n"
+ " [270 271 272 ... 277 278 279]\n"
+ " [280 281 282 ... 287 288 289]\n"
+ " [290 291 292 ... 297 298 299]]\n"
+ "\n"
+ " ...\n"
+ "\n"
+ " [[700 701 702 ... 707 708 709]\n"
+ " [710 711 712 ... 717 718 719]\n"
+ " [720 721 722 ... 727 728 729]\n"
+ " ...\n"
+ " [770 771 772 ... 777 778 779]\n"
+ " [780 781 782 ... 787 788 789]\n"
+ " [790 791 792 ... 797 798 799]]\n"
+ "\n"
+ " [[800 801 802 ... 807 808 809]\n"
+ " [810 811 812 ... 817 818 819]\n"
+ " [820 821 822 ... 827 828 829]\n"
+ " ...\n"
+ " [870 871 872 ... 877 878 879]\n"
+ " [880 881 882 ... 887 888 889]\n"
+ " [890 891 892 ... 897 898 899]]\n"
+ "\n"
+ " [[900 901 902 ... 907 908 909]\n"
+ " [910 911 912 ... 917 918 919]\n"
+ " [920 921 922 ... 927 928 929]\n"
+ " ...\n"
+ " [970 971 972 ... 977 978 979]\n"
+ " [980 981 982 ... 987 988 989]\n"
+ " [990 991 992 ... 997 998 999]]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTemplatePrefix(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}", tensor)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTemplatePrefixAndSuffix(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}, suffix",
+ tensor)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]], suffix")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTemplateSuffix(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("{}, suffix", tensor)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]], suffix")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatNoTensor(self):
+ with self.test_session():
+ format_output = string_ops.string_format("No tensor.", ())
+ out = self.evaluate(format_output)
+ expected = "No tensor."
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatMultiTensor(self):
+ with self.test_session():
+ tensor_one = array_ops.reshape(math_ops.range(100), [10, 10])
+ tensor_two = tensor_one * 10
+ format_output = string_ops.string_format("One: {},\nTwo: {}",
+ (tensor_one, tensor_two))
+ out = self.evaluate(format_output)
+ expected = ("One: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]],\n"
+ "Two: [[0 10 20 ... 70 80 90]\n"
+ " [100 110 120 ... 170 180 190]\n"
+ " [200 210 220 ... 270 280 290]\n"
+ " ...\n"
+ " [700 710 720 ... 770 780 790]\n"
+ " [800 810 820 ... 870 880 890]\n"
+ " [900 910 920 ... 970 980 990]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatSummarizeOne(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}", tensor,
+ summarize=1)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 ... 9]\n"
+ " ...\n"
+ " [90 ... 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatSummarizeTwo(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}", tensor,
+ summarize=2)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 ... 8 9]\n"
+ " [10 11 ... 18 19]\n"
+ " ...\n"
+ " [80 81 ... 88 89]\n"
+ " [90 91 ... 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatPlaceholder(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: %t%", tensor,
+ placeholder="%t%")
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testTensorCountMustMatchPlaceholderCount(self):
+ with self.test_session():
+ with self.assertRaisesRegexp(
+ ValueError, r"2 placeholder\(s\) in template does not match 1 "
+ r"tensor\(s\) provided as input"):
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{} {}", tensor)
+ self.evaluate(format_output)
+ with self.test_session():
+ with self.assertRaisesRegexp(
+ ValueError, r"2 placeholder\(s\) in template does not match 1 "
+ r"tensor\(s\) provided as input"):
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{} {}", [tensor])
+ self.evaluate(format_output)
+ with self.test_session():
+ with self.assertRaisesRegexp(
+ ValueError, r"1 placeholder\(s\) in template does not match 2 "
+ r"tensor\(s\) provided as input"):
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{}", (tensor, tensor))
+ self.evaluate(format_output)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 0e20fadb2b..87f8bd85a5 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -610,9 +610,10 @@ def _EnforceShapeInvariant(merge_var, next_var):
"less-specific shape." %
(input_t.name, input_t.shape, n_shape))
else:
- if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
- raise TypeError("Type %s not supported" % type(var))
- if isinstance(var, ops.IndexedSlices):
+ if not isinstance(merge_var,
+ (ops.IndexedSlices, sparse_tensor.SparseTensor)):
+ raise TypeError("Type %s not supported" % type(merge_var))
+ if isinstance(merge_var, ops.IndexedSlices):
m_values_shape = merge_var.values.get_shape()
m_indices_shape = merge_var.indices.get_shape()
m_shape_shape = tensor_shape.TensorShape(None)
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index de260f3140..325418d5f7 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -29,7 +29,6 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
@@ -301,21 +300,21 @@ def random_flip_left_right(image, seed=None):
def _random_flip(image, flip_index, seed, scope_name):
"""Randomly (50% chance) flip an image along axis `flip_index`.
- Args:
- image: 4-D Tensor of shape `[batch, height, width, channels]` or
- 3-D Tensor of shape `[height, width, channels]`.
- flip_index: The dimension along which to flip the image.
- Vertical: 0, Horizontal: 1
- seed: A Python integer. Used to create a random seed. See
- `tf.set_random_seed`
- for behavior.
- scope_name: Name of the scope in which the ops are added.
- Returns:
- A tensor of the same type and shape as `image`.
+ Args:
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
+ flip_index: Dimension along which to flip image. Vertical: 0, Horizontal: 1
+ seed: A Python integer. Used to create a random seed. See
+ `tf.set_random_seed`
+ for behavior.
+ scope_name: Name of the scope in which the ops are added.
- Raises:
- ValueError: if the shape of `image` not supported.
+ Returns:
+ A tensor of the same type and shape as `image`.
+
+ Raises:
+ ValueError: if the shape of `image` not supported.
"""
with ops.name_scope(None, scope_name, [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
@@ -334,15 +333,16 @@ def _random_flip(image, flip_index, seed, scope_name):
result = result[0] # TODO(b/111124878) remove this logic (CondV2).
return fix_image_flip_shape(image, result)
elif shape.ndims == 4:
+ batch_size = array_ops.shape(image)[0]
uniform_random = random_ops.random_uniform(
- [array_ops.shape(image)[0]], 0, 1.0, seed=seed
+ [batch_size], 0, 1.0, seed=seed
)
- mirror_cond = math_ops.less(uniform_random, .5)
- return array_ops.where(
- mirror_cond,
- image,
- functional_ops.map_fn(lambda x: array_ops.reverse(x, [flip_index]), image, dtype=image.dtype)
+ flips = math_ops.round(
+ array_ops.reshape(uniform_random, [batch_size, 1, 1, 1])
)
+ flips = math_ops.cast(flips, image.dtype)
+ flipped_input = array_ops.reverse(image, [flip_index + 1])
+ return flips * flipped_input + (1 - flips) * image
else:
raise ValueError('\'image\' must have either 3 or 4 dimensions.')
diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py
index df41933f8a..4c53f33af1 100644
--- a/tensorflow/python/ops/logging_ops.py
+++ b/tensorflow/python/ops/logging_ops.py
@@ -19,13 +19,24 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import pprint
+import random
+import sys
+
+import six
+
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_logging_ops
+from tensorflow.python.ops import string_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_logging_ops import *
# pylint: enable=wildcard-import
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.util import nest
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@@ -40,7 +51,32 @@ from tensorflow.python.util.tf_export import tf_export
# For users with Python 3 or Python 2.7
# with `from __future__ import print_function`, we could also allow lowercase.
# See https://github.com/tensorflow/tensorflow/issues/18053
-@tf_export("Print")
+
+
+# pylint: disable=invalid-name
+@deprecated("2018-08-20", "Use tf.print instead of tf.Print. Note that "
+ "tf.print returns a no-output operator that directly "
+ "prints the output. Outside of defuns or eager mode, "
+ "this operator will not be executed unless it is "
+ "directly specified in session.run or used as a "
+ "control dependency for other operators. This is "
+ "only a concern in graph mode. Below is an example "
+ "of how to ensure tf.print executes in graph mode:\n"
+ """```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ print_op = tf.print(tensor)
+ with tf.control_dependencies([print_op]):
+ out = tf.add(tensor, tensor)
+ sess.run(out)
+ ```
+Additionally, to use tf.print in python 2.7, users must make sure to import
+the following:
+
+ `from __future__ import print_function`
+""")
+@tf_export(v1=["Print"])
def Print(input_, data, message=None, first_n=None, summarize=None,
name=None):
"""Prints a list of tensors.
@@ -66,6 +102,228 @@ def Print(input_, data, message=None, first_n=None, summarize=None,
A `Tensor`. Has the same type and contents as `input_`.
"""
return gen_logging_ops._print(input_, data, message, first_n, summarize, name)
+# pylint: enable=invalid-name
+
+
+def _generate_placeholder_string(x, default_placeholder="{}"):
+ """Generate and return a string that does not appear in `x`."""
+ placeholder = default_placeholder
+ rng = random.Random(5)
+ while placeholder in x:
+ placeholder = placeholder + str(rng.randint(0, 9))
+ return placeholder
+
+
+# Temporarily disable pylint g-doc-args error to allow giving more context
+# about what the kwargs are.
+# Because we are using arbitrary-length positional arguments, python 2
+# does not support explicitly specifying the keyword arguments in the
+# function definition.
+# pylint: disable=g-doc-args
+@tf_export("print")
+def print_v2(*inputs, **kwargs):
+ """Print the specified inputs.
+
+ Returns an operator that prints the specified inputs to a desired
+ output stream or logging level. The inputs may be dense or sparse Tensors,
+ primitive python objects, data structures that contain Tensors, and printable
+ python objects. Printed tensors will recursively show the first and last
+ `summarize` elements of each dimension.
+
+ With eager execution enabled and/or inside a `tf.contrib.eager.defun` this
+ operator will automatically execute, and users only need to call `tf.print`
+ without using the return value. When constructing graphs outside of a
+ `tf.contrib.eager.defun`, one must either include the returned op
+ in the input to `session.run`, or use the operator as a control dependency for
+ executed ops by specifying `with tf.control_dependencies([print_op])`.
+
+ @compatibility(python2)
+ In python 2.7, make sure to import the following:
+ `from __future__ import print_function`
+ @end_compatibility
+
+ Example:
+ Single-input usage:
+ ```python
+ tf.enable_eager_execution()
+ tensor = tf.range(10)
+ tf.print(tensor, output_stream=sys.stderr)
+ ```
+ (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)
+
+ Multi-input usage:
+ ```python
+ tf.enable_eager_execution()
+ tensor = tf.range(10)
+ tf.print("tensors:", tensor, {2: tensor * 2}, output_stream=sys.stdout)
+ ```
+ (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
+ sys.stdout)
+
+ Usage in a defun:
+ ```python
+ tf.enable_eager_execution()
+
+ @tf.contrib.eager.defun
+ def f():
+ tensor = tf.range(10)
+ tf.print(tensor, output_stream=sys.stderr)
+ return tensor
+
+ range_tensor = f()
+ ```
+ (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)
+
+ Usage when constructing graphs:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ print_op = tf.print("tensors:", tensor, {2: tensor * 2},
+ output_stream=sys.stdout)
+ with tf.control_dependencies([print_op]):
+ tripled_tensor = tensor * 3
+ sess.run(tripled_tensor)
+ ```
+ (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
+ sys.stdout)
+
+ Note: This op is only partially compatible with Jupyter notebooks and colabs.
+ Because it prints to the C++ standard out / standard error, this will go
+ in the notebook kernel's console output, not in the notebook cell output.
+
+ Args:
+ *inputs: Positional arguments that are the inputs to print. Inputs in the
+ printed output will be separated by spaces. Inputs may be python
+ primitives, tensors, data structures such as dicts and lists that
+ may contain tensors (with the data structures possibly nested in
+ arbitrary ways), and printable python objects.
+ output_stream: The output stream or logging level to print to. Defaults to
+ sys.stderr, but sys.stdout, tf.logging.info, tf.logging.warning, and
+ tf.logging.error are also supported.
+ summarize: The first and last `summarize` elements within each dimension are
+ recursively printed per Tensor. If None, then the first 3 and last 3
+ elements of each dimension are printed for each tensor. If set to -1, it
+ will print all elements of every tensor.
+ name: A name for the operation (optional).
+
+ Returns:
+ A print operator that prints the specified inputs in the specified output
+ stream or logging level.
+
+ Raises:
+ ValueError: If an unsupported output stream is specified.
+ """
+ # Because we are using arbitrary-length positional arguments, python 2
+ # does not support explicitly specifying the keyword arguments in the
+ # function definition. So, we manually get the keyword arguments w/ default
+ # values here.
+ output_stream = kwargs.pop("output_stream", sys.stderr)
+ name = kwargs.pop("name", None)
+ summarize = kwargs.pop("summarize", 3)
+ if kwargs:
+ raise ValueError("Unrecognized keyword arguments for tf.print: %s" % kwargs)
+ format_name = None
+ if name:
+ format_name = name + "_format"
+
+ # Match the C++ string constants representing the different output streams.
+ # Keep this updated!
+ output_stream_to_constant = {
+ sys.stdout: "stdout",
+ sys.stderr: "stderr",
+ tf_logging.INFO: "log(info)",
+ tf_logging.info: "log(info)",
+ tf_logging.WARN: "log(warning)",
+ tf_logging.warning: "log(warning)",
+ tf_logging.warn: "log(warning)",
+ tf_logging.ERROR: "log(error)",
+ tf_logging.error: "log(error)",
+ }
+
+ output_stream_string = output_stream_to_constant.get(output_stream)
+ if not output_stream_string:
+ raise ValueError(
+ "Unsupported output stream or logging level " +
+ str(output_stream) + ". Supported streams are sys.stdout, "
+ "sys.stderr, tf.logging.info, "
+ "tf.logging.warning, tf.logging.error")
+
+ # If we are only printing a single string scalar, there is no need to format
+ if (len(inputs) == 1 and tensor_util.is_tensor(inputs[0])
+ and (not isinstance(inputs[0], sparse_tensor.SparseTensor))
+ and inputs[0].shape and (inputs[0].dtype == dtypes.string)):
+ formatted_string = inputs[0]
+ # Otherwise, we construct an appropriate template for the tensors we are
+ # printing, and format the template using those tensors.
+ else:
+ # For each input to this print function, we extract any nested tensors,
+ # and construct an appropriate template to format representing the
+ # printed input.
+ templates = []
+ tensors = []
+ tensor_free_structure = nest.map_structure(
+ lambda x: "" if tensor_util.is_tensor(x) else x,
+ inputs)
+ tensor_free_template = " ".join(pprint.pformat(x)
+ for x in tensor_free_structure)
+ placeholder = _generate_placeholder_string(tensor_free_template)
+
+ for input_ in inputs:
+ placeholders = []
+ # Use the nest utilities to flatten & process any nested elements in this
+ # input. The placeholder for a tensor in the template should be the
+ # placeholder string, and the placeholder for a non-tensor can just be
+ # the printed value of the non-tensor itself.
+ for x in nest.flatten(input_):
+ # support sparse tensors
+ if isinstance(x, sparse_tensor.SparseTensor):
+ tensors.extend([x.indices, x.values, x.dense_shape])
+ placeholders.append(
+ "SparseTensor(indices={}, values={}, shape={})".format(
+ placeholder, placeholder, placeholder)
+ )
+ elif tensor_util.is_tensor(x):
+ tensors.append(x)
+ placeholders.append(placeholder)
+ else:
+ placeholders.append(x)
+
+ if isinstance(input_, six.string_types):
+ # If the current input to format/print is a normal string, that string
+ # can act as the template.
+ cur_template = input_
+ else:
+ # We pack the placeholders into a data structure that matches the
+ # input data structure format, then format that data structure
+ # into a string template.
+ #
+ # NOTE: We must use pprint.pformat here for building the template for
+ # unordered data structures such as `dict`, because `str` doesn't
+ # guarantee orderings, while pprint prints in sorted order. pprint
+ # will match the ordering of `nest.flatten`.
+ # This even works when nest.flatten reorders OrderedDicts, because
+ # pprint is printing *after* the OrderedDicts have been reordered.
+ cur_template = pprint.pformat(
+ nest.pack_sequence_as(input_, placeholders))
+ templates.append(cur_template)
+
+ # We join the templates for the various inputs into a single larger
+ # template. We also remove all quotes surrounding the placeholders, so that
+ # the formatted/printed output will not contain quotes around tensors.
+ # (example of where these quotes might appear: if we have added a
+ # placeholder string into a list, then pretty-formatted that list)
+ template = " ".join(templates)
+ template = template.replace("'" + placeholder + "'", placeholder)
+ formatted_string = string_ops.string_format(
+ inputs=tensors, template=template, placeholder=placeholder,
+ summarize=summarize,
+ name=format_name)
+
+ return gen_logging_ops.print_v2(formatted_string,
+ output_stream=output_stream_string,
+ name=name)
+# pylint: enable=g-doc-args
@ops.RegisterGradient("Print")
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 55c2eb5fa4..9e477ab8af 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -48,7 +48,7 @@ def get_resource_handle_data(graph_op):
assert ops._USE_C_SHAPES # pylint: disable=protected-access
assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck
- handle_data = pywrap_tensorflow.GetResourceHandleShapeAndType(
+ handle_data = pywrap_tensorflow.GetHandleShapeAndType(
graph_op.graph._c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access
return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index b2c6937368..5d949467fd 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -29,14 +29,15 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.util import compat as util_compat
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_string_ops import *
+from tensorflow.python.util import compat as util_compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
@@ -103,6 +104,87 @@ def regex_replace(source, pattern, rewrite, replace_global=True):
rewrite=rewrite, replace_global=replace_global)
+@tf_export("strings.format")
+def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
+ r"""Formats a string template using a list of tensors.
+
+ Formats a string template using a list of tensors, abbreviating tensors by
+ only printing the first and last `summarize` elements of each dimension
+ (recursively). If formatting only one tensor into a template, the tensor does
+ not have to be wrapped in a list.
+
+ Example:
+ Formatting a single-tensor template:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ formatted = tf.strings.format("tensor: {}, suffix", tensor)
+ out = sess.run(formatted)
+ expected = "tensor: [0 1 2 ... 7 8 9], suffix"
+
+ assert(out.decode() == expected)
+ ```
+
+ Formatting a multi-tensor template:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor_one = tf.reshape(tf.range(100), [10, 10])
+ tensor_two = tf.range(10)
+ formatted = tf.strings.format("first: {}, second: {}, suffix",
+ (tensor_one, tensor_two))
+
+ out = sess.run(formatted)
+ expected = ("first: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]], second: [0 1 2 ... 7 8 9], suffix")
+
+ assert(out.decode() == expected)
+ ```
+
+ Args:
+ template: A string template to format tensor values into.
+ inputs: A list of `Tensor` objects, or a single Tensor.
+ The list of tensors to format into the template string. If a solitary
+ tensor is passed in, the input tensor will automatically be wrapped as a
+ list.
+ placeholder: An optional `string`. Defaults to `{}`.
+ At each placeholder occurring in the template, a subsequent tensor
+ will be inserted.
+ summarize: An optional `int`. Defaults to `3`.
+ When formatting the tensors, show the first and last `summarize`
+ entries of each tensor dimension (recursively). If set to -1, all
+ elements of the tensor will be shown.
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar `Tensor` of type `string`.
+
+ Raises:
+ ValueError: if the number of placeholders does not match the number of
+ inputs.
+ """
+ # If there is only one tensor to format, we will automatically wrap it in a
+ # list to simplify the user experience
+ if tensor_util.is_tensor(inputs):
+ inputs = [inputs]
+ if template.count(placeholder) != len(inputs):
+ raise ValueError("%s placeholder(s) in template does not match %s tensor(s)"
+ " provided as input" % (template.count(placeholder),
+ len(inputs)))
+
+ return gen_string_ops.string_format(inputs,
+ template=template,
+ placeholder=placeholder,
+ summarize=summarize,
+ name=name)
+
+
@tf_export("string_split")
def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=invalid-name
"""Split elements of `source` based on `delimiter` into a `SparseTensor`.
diff --git a/tensorflow/stream_executor/device_description.h b/tensorflow/stream_executor/device_description.h
index 7f99d81ef3..a4580d6462 100644
--- a/tensorflow/stream_executor/device_description.h
+++ b/tensorflow/stream_executor/device_description.h
@@ -22,8 +22,7 @@ limitations under the License.
#include <map>
#include <memory>
-#include "tensorflow/stream_executor/platform/port.h"
-
+#include "absl/base/macros.h"
#include "tensorflow/stream_executor/launch_dim.h"
#include "tensorflow/stream_executor/platform/port.h"
@@ -359,9 +358,8 @@ class DeviceDescriptionBuilder {
bool ThreadDimOk(const DeviceDescription &device_description,
const ThreadDim &thread_dim);
-// [deprecated] Use MathUtil::CeilOfRatio directly instead.
-//
// Equivalent to ceil(double(element_count) / threads_per_block).
+ABSL_DEPRECATED("Use MathUtil::CeilOfRatio directly instead.")
uint64 DivideCeil(uint64 x, uint64 y);
// Calculate the number of threads/blocks required to process element_count
diff --git a/tensorflow/stream_executor/plugin_registry.h b/tensorflow/stream_executor/plugin_registry.h
index 49628ecd24..3065b5cb77 100644
--- a/tensorflow/stream_executor/plugin_registry.h
+++ b/tensorflow/stream_executor/plugin_registry.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <map>
+#include "absl/base/macros.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/fft.h"
@@ -97,6 +98,7 @@ class PluginRegistry {
// TODO(b/22689637): Deprecated/temporary. Will be deleted once all users are
// on MultiPlatformManager / PlatformId.
template <typename FactoryT>
+ ABSL_DEPRECATED("Use MultiPlatformManager / PlatformId instead.")
port::StatusOr<FactoryT> GetFactory(PlatformKind platform_kind,
PluginId plugin_id);
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index d04025b681..4a8a270afa 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <tuple>
#include <vector>
+#include "absl/base/macros.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/strcat.h"
@@ -81,8 +82,8 @@ class StreamExecutor {
port::Status Init();
port::Status Init(int device_ordinal, DeviceOptions device_options);
- // DEPRECATED: Do not use; use platform() instead.
// Returns the platform that this StreamExecutor is acting upon.
+ ABSL_DEPRECATED("Use platform() instead.")
PlatformKind platform_kind() const { return platform_kind_; }
// Returns a reference to the platform that created this executor.
@@ -255,15 +256,15 @@ class StreamExecutor {
// [deprecated] Blocks the caller while a data segment of the given size is
// copied from the host source to the device destination.
- //
- // Deprecation: prefer explicit H2D below, to avoid error-prone API usage.
+ ABSL_DEPRECATED(
+ "Prefer SynchronousMemcpyH2D, to avoid error-prone API usage.")
bool SynchronousMemcpy(DeviceMemoryBase *device_dst, const void *host_src,
uint64 size) SE_MUST_USE_RESULT;
// [deprecated] Blocks the caller while a data segment of the given size is
// copied from the device source to the host destination.
- //
- // Deprecation: prefer explicit D2H below, to avoid error-prone API usage.
+ ABSL_DEPRECATED(
+ "Prefer SynchronousMemcpyD2H, to avoid error-prone API usage.")
bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &device_src,
uint64 size) SE_MUST_USE_RESULT;
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 14ab885c91..6ff4343e9e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -1593,6 +1593,10 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "print"
+ argspec: "args=[], varargs=inputs, keywords=kwargs, defaults=None"
+ }
+ member_method {
name: "py_func"
argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
index 018be7b9f9..c81c156518 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
@@ -1,6 +1,10 @@
path: "tensorflow.strings"
tf_module {
member_method {
+ name: "format"
+ argspec: "args=[\'template\', \'inputs\', \'placeholder\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'{}\', \'3\', \'None\'], "
+ }
+ member_method {
name: "join"
argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 323d2fc519..db90c007d4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -581,10 +581,6 @@ tf_module {
argspec: "args=[\'op_type\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "Print"
- argspec: "args=[\'input_\', \'data\', \'message\', \'first_n\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
name: "abs"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1541,6 +1537,10 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "print"
+ argspec: "args=[], varargs=inputs, keywords=kwargs, defaults=None"
+ }
+ member_method {
name: "py_func"
argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
index 018be7b9f9..c81c156518 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
@@ -1,6 +1,10 @@
path: "tensorflow.strings"
tf_module {
member_method {
+ name: "format"
+ argspec: "args=[\'template\', \'inputs\', \'placeholder\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'{}\', \'3\', \'None\'], "
+ }
+ member_method {
name: "join"
argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}