author     Yifei Feng <yifeif@google.com>  2018-04-17 12:18:44 -0700
committer  Yifei Feng <yifeif@google.com>  2018-04-17 12:18:44 -0700
commit     8bed1ea47d96c53db7d8b68b811b1487635d4106 (patch)
tree       2260bf78d4b834a1009c9ac7ca4979a0a5b41fdf
parent     f1b892b608a3e2b5fa8a16c03ac3c3ca6293ad65 (diff)
parent     b50142067e776fc86ce2ba3d01d01c7c16da671f (diff)
Merge commit for internal changes
-rw-r--r--  configure.py  2
-rw-r--r--  tensorflow/c/eager/c_api.cc  14
-rw-r--r--  tensorflow/c/eager/runtime.cc  9
-rw-r--r--  tensorflow/compiler/aot/embedded_protocol_buffers.cc  1
-rw-r--r--  tensorflow/compiler/jit/BUILD  1
-rw-r--r--  tensorflow/compiler/tf2xla/lib/BUILD  25
-rw-r--r--  tensorflow/compiler/tf2xla/lib/cholesky.cc  159
-rw-r--r--  tensorflow/compiler/tf2xla/lib/cholesky.h  4
-rw-r--r--  tensorflow/compiler/tf2xla/lib/util.cc  63
-rw-r--r--  tensorflow/compiler/tf2xla/lib/util.h  30
-rw-r--r--  tensorflow/compiler/tf2xla/lib/util_test.cc  145
-rw-r--r--  tensorflow/compiler/xla/service/BUILD  1
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc  228
-rw-r--r--  tensorflow/compiler/xla/service/cpu/compiler_functor.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc  6
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_test.cc  32
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_window_test.cc  17
-rw-r--r--  tensorflow/contrib/BUILD  2
-rw-r--r--  tensorflow/contrib/__init__.py  2
-rw-r--r--  tensorflow/contrib/autograph/__init__.py  1
-rw-r--r--  tensorflow/contrib/autograph/converters/asserts.py  6
-rw-r--r--  tensorflow/contrib/autograph/converters/break_statements.py  92
-rw-r--r--  tensorflow/contrib/autograph/converters/builtin_functions.py  12
-rw-r--r--  tensorflow/contrib/autograph/converters/builtin_functions_test.py  38
-rw-r--r--  tensorflow/contrib/autograph/converters/call_trees.py  6
-rw-r--r--  tensorflow/contrib/autograph/converters/control_flow.py  8
-rw-r--r--  tensorflow/contrib/autograph/converters/converter_test_base.py  7
-rw-r--r--  tensorflow/contrib/autograph/converters/ifexp.py  2
-rw-r--r--  tensorflow/contrib/autograph/converters/lists.py  4
-rw-r--r--  tensorflow/contrib/autograph/converters/side_effect_guards.py  6
-rw-r--r--  tensorflow/contrib/autograph/impl/conversion.py  21
-rw-r--r--  tensorflow/contrib/autograph/impl/conversion_test.py  10
-rw-r--r--  tensorflow/contrib/autograph/operators/BUILD  13
-rw-r--r--  tensorflow/contrib/autograph/operators/control_flow.py  32
-rw-r--r--  tensorflow/contrib/autograph/operators/control_flow_test.py  29
-rw-r--r--  tensorflow/contrib/autograph/operators/data_structures.py  56
-rw-r--r--  tensorflow/contrib/autograph/operators/data_structures_test.py  44
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/activity.py  49
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py  189
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/type_info.py  73
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py  43
-rw-r--r--  tensorflow/contrib/autograph/pyct/transformer.py  46
-rw-r--r--  tensorflow/contrib/autograph/pyct/transformer_test.py  102
-rw-r--r--  tensorflow/contrib/autograph/utils/builtins.py  19
-rw-r--r--  tensorflow/contrib/autograph/utils/builtins_test.py  5
-rw-r--r--  tensorflow/contrib/cmake/external/grpc.cmake  2
-rwxr-xr-x  tensorflow/contrib/cmake/tf_python.cmake  9
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py  25
-rw-r--r--  tensorflow/contrib/data/__init__.py  2
-rw-r--r--  tensorflow/contrib/data/kernels/BUILD  12
-rw-r--r--  tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc  274
-rw-r--r--  tensorflow/contrib/data/ops/dataset_ops.cc  17
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py  173
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py  8
-rw-r--r--  tensorflow/contrib/data/python/ops/BUILD  10
-rw-r--r--  tensorflow/contrib/data/python/ops/batching.py  7
-rw-r--r--  tensorflow/contrib/data/python/ops/interleave_ops.py  100
-rw-r--r--  tensorflow/contrib/distribute/README.md  5
-rw-r--r--  tensorflow/contrib/distribute/python/BUILD  37
-rw-r--r--  tensorflow/contrib/distribute/python/combinations.py  17
-rw-r--r--  tensorflow/contrib/distribute/python/minimize_loss_test.py  35
-rw-r--r--  tensorflow/contrib/distribute/python/tpu_strategy.py  82
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/boosted_trees.py  56
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py  80
-rw-r--r--  tensorflow/contrib/framework/__init__.py  4
-rw-r--r--  tensorflow/contrib/gan/python/train.py  5
-rw-r--r--  tensorflow/contrib/linalg/BUILD  19
-rw-r--r--  tensorflow/contrib/linalg/__init__.py  2
-rw-r--r--  tensorflow/contrib/linalg/python/kernel_tests/linear_operator_kronecker_test.py  194
-rw-r--r--  tensorflow/contrib/linalg/python/ops/linear_operator_kronecker.py  560
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py  36
-rw-r--r--  tensorflow/contrib/lite/BUILD  1
-rw-r--r--  tensorflow/contrib/lite/interpreter.cc  17
-rw-r--r--  tensorflow/contrib/lite/interpreter.h  26
-rw-r--r--  tensorflow/contrib/lite/kernels/BUILD  3
-rw-r--r--  tensorflow/contrib/lite/profiling/BUILD  44
-rw-r--r--  tensorflow/contrib/lite/profiling/profile_buffer.h  150
-rw-r--r--  tensorflow/contrib/lite/profiling/profile_buffer_test.cc  102
-rw-r--r--  tensorflow/contrib/lite/profiling/profiler.h  174
-rw-r--r--  tensorflow/contrib/lite/profiling/profiler_test.cc  105
-rw-r--r--  tensorflow/contrib/lite/toco/BUILD  5
-rw-r--r--  tensorflow/contrib/lite/toco/args.h  1
-rw-r--r--  tensorflow/contrib/lite/toco/dump_graphviz.cc  12
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc  2
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h  20
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc  1
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc  307
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc  88
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h  25
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/quantize.cc  139
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc  86
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc  25
-rw-r--r--  tensorflow/contrib/lite/toco/toco_cmdline_flags.cc  7
-rw-r--r--  tensorflow/contrib/lite/toco/toco_flags.proto  11
-rw-r--r--  tensorflow/contrib/lite/toco/toco_tooling.cc  26
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.cc  73
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.h  18
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops.py  16
-rw-r--r--  tensorflow/contrib/proto/BUILD  16
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/BUILD  86
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl  89
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py  68
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py  300
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py  180
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt  161
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt  16
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt  20
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt  21
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt  32
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt  62
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt  21
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/test_case.py  35
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/test_example.proto  149
-rw-r--r--  tensorflow/contrib/rpc/BUILD  16
-rw-r--r--  tensorflow/contrib/rpc/python/kernel_tests/BUILD  80
-rw-r--r--  tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py  71
-rw-r--r--  tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py  336
-rw-r--r--  tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py  101
-rw-r--r--  tensorflow/contrib/rpc/python/kernel_tests/test_example.proto  171
-rw-r--r--  tensorflow/contrib/seq2seq/BUILD  8
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py  4
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py  2
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/decoder.py  39
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.cc  6
-rw-r--r--  tensorflow/contrib/tpu/__init__.py  1
-rw-r--r--  tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc  1
-rw-r--r--  tensorflow/core/BUILD  2
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestGainsPerFeature.pbtxt  42
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_BoostedTreesPredict.pbtxt  6
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_BoostedTreesTrainingPredict.pbtxt  6
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_BoostedTreesUpdateEnsemble.pbtxt  4
-rw-r--r--  tensorflow/core/common_runtime/process_util.cc  4
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h  2
-rw-r--r--  tensorflow/core/framework/graph_transfer_info.proto  91
-rw-r--r--  tensorflow/core/framework/op_kernel.cc  48
-rw-r--r--  tensorflow/core/grappler/costs/BUILD  1
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.cc  132
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.h  16
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties_test.cc  4
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator.cc  2
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD  1
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc  145
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc  9
-rw-r--r--  tensorflow/core/grappler/optimizers/graph_optimizer_stage.h  8
-rw-r--r--  tensorflow/core/kernels/BUILD  2
-rw-r--r--  tensorflow/core/kernels/boosted_trees/prediction_ops.cc  16
-rw-r--r--  tensorflow/core/kernels/boosted_trees/stats_ops.cc  53
-rw-r--r--  tensorflow/core/kernels/boosted_trees/training_ops.cc  19
-rw-r--r--  tensorflow/core/kernels/cudnn_rnn_ops.cc  689
-rw-r--r--  tensorflow/core/kernels/cwise_op_clip.cc  25
-rw-r--r--  tensorflow/core/kernels/cwise_op_clip.h  8
-rw-r--r--  tensorflow/core/kernels/cwise_op_clip_gpu.cu.cc  24
-rw-r--r--  tensorflow/core/kernels/data/BUILD  1
-rw-r--r--  tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc  700
-rw-r--r--  tensorflow/core/kernels/gather_op.cc  2
-rw-r--r--  tensorflow/core/kernels/hexagon/BUILD  1
-rw-r--r--  tensorflow/core/kernels/hexagon/graph_transfer_utils.cc  2
-rw-r--r--  tensorflow/core/kernels/hexagon/graph_transfer_utils.h  4
-rw-r--r--  tensorflow/core/kernels/hexagon/graph_transferer.cc  126
-rw-r--r--  tensorflow/core/kernels/hexagon/graph_transferer.h  21
-rw-r--r--  tensorflow/core/kernels/hexagon/graph_transferer_test.cc  37
-rw-r--r--  tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc  36
-rw-r--r--  tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h  4
-rw-r--r--  tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc  29
-rw-r--r--  tensorflow/core/kernels/matching_files_op.cc  1
-rw-r--r--  tensorflow/core/kernels/maxpooling_op.cc  7
-rw-r--r--  tensorflow/core/kernels/sdca_internal.cc  36
-rw-r--r--  tensorflow/core/kernels/sdca_internal.h  2
-rw-r--r--  tensorflow/core/kernels/segment_reduction_ops.h  6
-rw-r--r--  tensorflow/core/kernels/string_to_hash_bucket_op.h  2
-rw-r--r--  tensorflow/core/ops/boosted_trees_ops.cc  37
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v1.pbtxt  94
-rw-r--r--  tensorflow/core/ops/ops.pbtxt  90
-rw-r--r--  tensorflow/core/platform/default/build_config.bzl  86
-rw-r--r--  tensorflow/core/platform/default/fingerprint.h  10
-rw-r--r--  tensorflow/core/platform/fingerprint.h  8
-rw-r--r--  tensorflow/core/profiler/BUILD  4
-rw-r--r--  tensorflow/docs_src/performance/xla/operation_semantics.md  19
-rw-r--r--  tensorflow/docs_src/programmers_guide/datasets.md  2
-rw-r--r--  tensorflow/java/maven/libtensorflow/pom.xml  2
-rw-r--r--  tensorflow/java/maven/libtensorflow_jni/pom.xml  2
-rw-r--r--  tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml  2
-rw-r--r--  tensorflow/java/maven/pom.xml  2
-rw-r--r--  tensorflow/java/maven/proto/pom.xml  2
-rw-r--r--  tensorflow/java/maven/tensorflow/pom.xml  2
-rw-r--r--  tensorflow/python/BUILD  29
-rw-r--r--  tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py  48
-rw-r--r--  tensorflow/python/data/ops/dataset_ops.py  16
-rw-r--r--  tensorflow/python/estimator/BUILD  67
-rw-r--r--  tensorflow/python/estimator/canned/boosted_trees.py  240
-rw-r--r--  tensorflow/python/estimator/canned/boosted_trees_test.py  170
-rw-r--r--  tensorflow/python/estimator/canned/head.py  10
-rw-r--r--  tensorflow/python/estimator/canned/head_test.py  22
-rw-r--r--  tensorflow/python/estimator/replicate_model_fn.py  824
-rw-r--r--  tensorflow/python/estimator/replicate_model_fn_test.py  1739
-rw-r--r--  tensorflow/python/framework/dtypes.py  1
-rw-r--r--  tensorflow/python/framework/dtypes_test.py  1
-rw-r--r--  tensorflow/python/framework/python_op_gen.cc  2
-rw-r--r--  tensorflow/python/framework/tensor_shape_test.py  2
-rw-r--r--  tensorflow/python/kernel_tests/BUILD  6
-rw-r--r--  tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py  14
-rw-r--r--  tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py  51
-rw-r--r--  tensorflow/python/kernel_tests/clip_ops_test.py  66
-rw-r--r--  tensorflow/python/kernel_tests/cwise_ops_test.py  7
-rw-r--r--  tensorflow/python/kernel_tests/gather_op_test.py  9
-rw-r--r--  tensorflow/python/kernel_tests/init_ops_test.py  272
-rw-r--r--  tensorflow/python/kernel_tests/rnn_test.py  22
-rw-r--r--  tensorflow/python/kernel_tests/softmax_op_test.py  14
-rw-r--r--  tensorflow/python/ops/clip_ops.py  26
-rw-r--r--  tensorflow/python/ops/cudnn_rnn_grad.py  47
-rw-r--r--  tensorflow/python/ops/init_ops.py  374
-rw-r--r--  tensorflow/python/ops/linalg/linear_operator_test_util.py  12
-rw-r--r--  tensorflow/python/ops/rnn_cell_impl.py  38
-rw-r--r--  tensorflow/python/ops/standard_ops.py  4
-rw-r--r--  tensorflow/python/ops/template.py  12
-rw-r--r--  tensorflow/python/util/tf_inspect.py  30
-rw-r--r--  tensorflow/tools/api/generator/create_python_api.py  2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt  2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt  2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt  2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt  2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt  2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt  2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt  2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt  2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt  2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt  2
-rw-r--r--  tensorflow/tools/docs/parser.py  28
-rw-r--r--  tensorflow/tools/docs/parser_test.py  99
-rw-r--r--  tensorflow/tools/pip_package/BUILD  2
-rw-r--r--  tensorflow/workspace.bzl  20
-rw-r--r--  third_party/llvm/llvm.BUILD  1
233 files changed, 9567 insertions, 4525 deletions
diff --git a/configure.py b/configure.py
index 8fb8979111..b745e374a2 100644
--- a/configure.py
+++ b/configure.py
@@ -226,8 +226,6 @@ def setup_python(environ_cp):
# Set-up env variables used by python_configure.bzl
write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path)
write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path)
- write_to_bazelrc('build --force_python=py%s' % python_major_version)
- write_to_bazelrc('build --host_force_python=py%s' % python_major_version)
write_to_bazelrc('build --python_path=\"%s"' % python_bin_path)
environ_cp['PYTHON_BIN_PATH'] = python_bin_path
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index c96a38dec3..393851d13c 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -116,9 +116,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
opts->async, std::move(device_mgr), r);
}
-void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
- delete ctx;
-}
+void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { delete ctx; }
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* list = new TF_DeviceList;
@@ -581,7 +579,6 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
return nullptr;
}
-
#ifdef TENSORFLOW_EAGER_USE_XLA
// Synthesizes and returns a wrapper function over `op`, which must be a
// primitive op (e.g. matmul).
@@ -725,9 +722,7 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
}
const tensorflow::FunctionDef* fdef;
- {
- fdef = op->ctx->context.FindFunctionDef(op->name);
- }
+ { fdef = op->ctx->context.FindFunctionDef(op->name); }
std::vector<TF_DataType> const_input_types;
std::vector<TF_DataType> arg_input_types;
tensorflow::gtl::FlatMap<int, int> op_input_to_func_input;
@@ -940,8 +935,8 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
} else {
// Execute checks if retvals[i] is nullptr or not to figure if it needs to
// allocate it.
- std::vector<tensorflow::TensorHandle*> handle_retvals(*num_retvals,
- nullptr);
+ tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals(
+ *num_retvals);
status->status = tensorflow::EagerExecute(
&op->ctx->context, op->device, op->inputs, kernel, maybe_stats.get(),
handle_retvals.data(), *num_retvals);
@@ -1091,7 +1086,6 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
}
} // namespace tensorflow
-
TFE_Op::~TFE_Op() {
for (tensorflow::TensorHandle* h : inputs) {
h->Unref();
diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc
index abe2793ce8..e6c51ab17a 100644
--- a/tensorflow/c/eager/runtime.cc
+++ b/tensorflow/c/eager/runtime.cc
@@ -184,8 +184,7 @@ void CombineUnordered(const tensorflow::Fprint128& a,
inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s,
const tensorflow::Fprint128& b) {
- // TODO(agarwal): avoid ToString().
- tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString());
+ tensorflow::Fprint128 a = tensorflow::Fingerprint128(s);
return FingerprintCat128(a, b);
}
@@ -213,10 +212,8 @@ tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) const {
if (node_def_finalized_) return f;
}
for (const auto& p : string_attrs_) {
- // TODO(agarwal): avoid ToString().
- CombineUnordered(CacheKeyHelper(p.first, tensorflow::Fingerprint128(
- p.second.ToString())),
- &f);
+ CombineUnordered(
+ CacheKeyHelper(p.first, tensorflow::Fingerprint128(p.second)), &f);
}
for (const auto& p : int_attrs_) {
CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
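
Both hunks above exploit the same change: tensorflow::Fingerprint128 now accepts a StringPiece directly (see the fingerprint.h entries in the file list), so the ToString() copies flagged by the TODOs can be dropped. CombineUnordered then folds each attribute's fingerprint into the cache key by addition, which makes the key independent of map iteration order. A minimal Python sketch of that combining scheme, with hashlib standing in for Fingerprint128:

    import hashlib

    def fingerprint(s):
        # Stand-in for tensorflow::Fingerprint128 (any stable 128-bit hash).
        return int.from_bytes(hashlib.md5(s.encode()).digest(), "little")

    def combine_unordered(a, b):
        # Addition is commutative, so the result does not depend on the
        # order in which attributes are visited (cf. CombineUnordered).
        return (a + b) % (1 << 128)

    def cache_key(attrs):
        key = 0
        for name, value in attrs.items():
            key = combine_unordered(key, fingerprint("%s=%r" % (name, value)))
        return key

    assert cache_key({"T": "float", "N": 2}) == cache_key({"N": 2, "T": "float"})
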
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
index 6489929a57..0048eec93b 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include <string>
#include "llvm/ADT/Triple.h"
-#include "llvm/ExecutionEngine/ObjectMemoryBuffer.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 6edeb7047f..50fa95c4f3 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -318,6 +318,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/kernels:bounds_check",
],
)
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index 344773c8c5..fde1977c1b 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -39,6 +39,7 @@ cc_library(
":batch_dot",
":triangular_solve",
":util",
+ ":while_loop",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -126,6 +127,30 @@ cc_library(
],
)
+xla_test(
+ name = "util_test",
+ srcs = ["util_test.cc"],
+ deps = [
+ ":batch_dot",
+ ":util",
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
cc_library(
name = "while_loop",
srcs = ["while_loop.cc"],
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index e795701181..203365e2ab 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
+#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -31,68 +32,122 @@ namespace tensorflow {
namespace {
+// The Cholesky–Banachiewicz algorithm. See
+// https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky–Banachiewicz_and_Cholesky–Crout_algorithms
+// for a description.
+//
// def cholesky_unblocked(a):
// assert len(a.shape) == 2 and a.shape[-2] == a.shape[-1]
// n = a.shape[-2]
// l = np.zeros_like(a)
// for j in xrange(n):
-// r = l[..., j, :j]
-// l[..., j, j] = np.sqrt(a[..., j, j] - np.dot(r, r))
-// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j],
-// np.transpose(r))) / l[..., j, j]
+// row = l[..., j, :j]
+// row_t = np.swapaxes(row, -1, -2)
+// l[..., j, j] = np.sqrt(a[..., j, j] - np.dot(row, row_t))
+// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) /
+// l[..., j, j]
// return l
xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(a));
- xla::ComputationDataHandle l = Zeros(builder, *shape);
- const int64 n = xla::ShapeUtil::GetDimension(*shape, -2);
- for (int j = 0; j < n; ++j) {
-    // Picture of block structure:
-    // ...    \
-    //        \
-    // -- r -- d
-    //         |\
-    //    B    c \
-    //         |  \
-    //         |   ...
-    //
-    //          ^
-    //          column j
- TF_ASSIGN_OR_RETURN(auto d,
- SliceInMinorDims(builder, a, {j, j}, {j + 1, j + 1}));
- TF_ASSIGN_OR_RETURN(auto c,
- SliceInMinorDims(builder, a, {j + 1, j}, {n, j + 1}));
- xla::ComputationDataHandle new_d_squared = d;
- xla::ComputationDataHandle br;
- if (j > 0) {
- TF_ASSIGN_OR_RETURN(auto r,
- SliceInMinorDims(builder, l, {j, 0}, {j + 1, j}));
- TF_ASSIGN_OR_RETURN(auto b,
- SliceInMinorDims(builder, l, {j + 1, 0}, {n, j}));
- TF_ASSIGN_OR_RETURN(auto r_squared,
- BatchDot(builder, r, r, /*transpose_x=*/false,
- /*transpose_y=*/true, /*conjugate_x=*/false,
- /*conjugate_y=*/false));
- new_d_squared = builder->Sub(new_d_squared, r_squared);
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
+ builder->GetShape(a));
+ const int n_dims = xla::ShapeUtil::Rank(*a_shape);
+ const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1);
+ gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(a_shape->dimensions()),
+ /*pos=*/0,
+ /*len=*/n_dims - 2);
- TF_ASSIGN_OR_RETURN(br, BatchDot(builder, b, r, /*transpose_x=*/false,
- /*transpose_y=*/true,
- /*conjugate_x=*/false,
- /*conjugate_y=*/false));
- }
- auto new_d_inv = builder->Pow(
- new_d_squared, FloatLiteral(builder, shape->element_type(), -0.5));
- auto new_d = builder->Mul(new_d_inv, new_d_squared);
- TF_ASSIGN_OR_RETURN(l, UpdateSliceInMinorDims(builder, l, new_d, {j, j}));
+ xla::ComputationDataHandle l = Zeros(builder, *a_shape);
- if (j > 0) {
- c = builder->Sub(c, br);
+ // Construct the for loop body to iterate over rows.
+ auto body_fn = [&](xla::ComputationDataHandle i,
+ gtl::ArraySlice<xla::ComputationDataHandle> loop_vars,
+ xla::ComputationBuilder* body_builder)
+ -> xla::StatusOr<std::vector<xla::ComputationDataHandle>> {
+ xla::Shape col_shape;
+ xla::Shape row_shape;
+ for (int64 d : major_dims) {
+ row_shape.add_dimensions(d);
+ col_shape.add_dimensions(d);
}
- auto new_c = builder->Mul(c, new_d_inv);
- TF_ASSIGN_OR_RETURN(l,
- UpdateSliceInMinorDims(builder, l, new_c, {j + 1, j}));
- }
- return l;
+ row_shape.add_dimensions(1);
+ row_shape.add_dimensions(n);
+ row_shape.set_element_type(a_shape->element_type());
+ auto mask_zeros_row = Zeros(body_builder, row_shape);
+
+ col_shape.add_dimensions(n);
+ col_shape.add_dimensions(1);
+ col_shape.set_element_type(a_shape->element_type());
+ auto mask_zeros_col = Zeros(body_builder, col_shape);
+
+ std::vector<int32> mask_vector(n);
+ std::iota(mask_vector.begin(), mask_vector.end(), 0);
+ auto mask_range = body_builder->ConstantR1<int32>(mask_vector);
+ auto mask_range_row = body_builder->Broadcast(
+ body_builder->Reshape(mask_range, {0}, {1, n}), major_dims);
+ auto mask_range_col = body_builder->Broadcast(
+ body_builder->Reshape(mask_range, {0}, {n, 1}), major_dims);
+ auto body_a = loop_vars[0];
+ auto body_l = loop_vars[1];
+
+ // row = l[..., i, :i]
+ // select the whole i-th row, then mask out all columns past i-1
+ auto zero = body_builder->ConstantR0<int32>(0);
+ TF_ASSIGN_OR_RETURN(auto l_i, DynamicSliceInMinorDims(body_builder, body_l,
+ {i, zero}, {1, n}));
+ auto row = body_builder->Select(body_builder->Ge(mask_range_row, i),
+ mask_zeros_row, l_i);
+ // a[..., i, i]
+ TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a,
+ {i, i}, {1, 1}));
+ // np.dot(row, np.swapaxes(row, -1, -2))
+ xla::ComputationDataHandle diag_dot;
+ TF_ASSIGN_OR_RETURN(diag_dot, BatchDot(body_builder, row, row,
+ /*transpose_x=*/false,
+ /*transpose_y=*/true));
+ // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row,
+ // np.swapaxes(row, -1, -2)))
+ auto l_ii = body_builder->Pow(
+ body_builder->Sub(a_ii, diag_dot),
+ FloatLiteral(body_builder, a_shape->element_type(), 0.5));
+
+ // a[..., i+1:, i]
+ auto ip1 = body_builder->Add(i, body_builder->ConstantR0<int32>(1));
+ // select the whole i-th column, then mask out all rows above i+1
+ TF_ASSIGN_OR_RETURN(
+ auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1}));
+ auto a_ip1i = body_builder->Select(body_builder->Le(mask_range_col, i),
+ mask_zeros_col, a_0i);
+
+ // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) /
+ // l[..., i, i]
+ // The columns in [i, n] are zeroed out in `row`, so we just have to
+ // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i],
+ // r.T)
+ TF_ASSIGN_OR_RETURN(auto dot, BatchDot(body_builder, body_l, row,
+ /*transpose_x=*/false,
+ /*transpose_y=*/true));
+ // np.dot(l[..., i+1:, :i], r.T)
+ auto dot_ip1 = body_builder->Select(body_builder->Le(mask_range_col, i),
+ mask_zeros_col, dot);
+
+ auto col_update =
+ body_builder->Div(body_builder->Sub(a_ip1i, dot_ip1), l_ii);
+ TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims(
+ body_builder, body_l, col_update, {i}));
+ // Assign the diagonal after the rest of the column because otherwise the
+ // column assign will wrap around and overwrite the diagonal assign.
+ TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims(
+ body_builder, body_l, l_ii, {i, i}));
+
+ return std::vector<xla::ComputationDataHandle>{body_a, body_l};
+ };
+
+ TF_ASSIGN_OR_RETURN(
+ auto cholesky_while,
+ XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder));
+
+ return cholesky_while[1];
}
} // namespace
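
The docstring above states the Cholesky–Banachiewicz recurrence as NumPy pseudocode. Expanded into a runnable reference for the single-matrix case (a sketch of the math only, not of the XLA lowering):

    import numpy as np

    def cholesky_unblocked(a):
        # Row-by-row Cholesky-Banachiewicz, mirroring the comment above.
        assert a.ndim == 2 and a.shape[-2] == a.shape[-1]
        n = a.shape[-2]
        l = np.zeros_like(a)
        for j in range(n):
            row = l[j, :j]
            l[j, j] = np.sqrt(a[j, j] - np.dot(row, row))
            l[j + 1:, j] = (a[j + 1:, j] - np.dot(l[j + 1:, :j], row)) / l[j, j]
        return l

    a = np.array([[4., 12., -16.], [12., 37., -43.], [-16., -43., 98.]])
    np.testing.assert_allclose(cholesky_unblocked(a), np.linalg.cholesky(a))

The XLA version cannot slice with a loop-carried index the way Python does, since shapes must be static: it instead slices whole rows and columns, zeroes out entries at or beyond column i by Select-ing against an iota mask, and writes the results back with dynamic update slices, with the loop driven by XlaForEachIndex.
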
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h
index e083a383be..17da8d8b22 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.h
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -29,7 +29,7 @@ namespace tensorflow {
// the block size to use.
// TODO(phawkins): check for negative values on the diagonal and return an
// error, instead of silently yielding NaNs.
-// TODO(mattjj): handle the complex Hermitian case
+// TODO(znado): handle the complex Hermitian case
xla::StatusOr<xla::ComputationDataHandle> Cholesky(
xla::ComputationBuilder* builder, xla::ComputationDataHandle a,
int64 block_size = 256);
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index f579669bbd..31d823ca33 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -140,13 +140,47 @@ xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
return builder->Slice(x, padded_start, padded_end, strides);
}
+std::vector<int64> PrependMajorDims(xla::ComputationBuilder* builder,
+ const gtl::ArraySlice<int64>& major_dims,
+ const gtl::ArraySlice<int64>& indices) {
+ std::vector<int64> output(indices.size() + major_dims.size());
+ std::copy(major_dims.begin(), major_dims.end(), output.begin());
+ std::copy(indices.begin(), indices.end(), output.begin() + major_dims.size());
+ return output;
+}
+
+xla::StatusOr<xla::ComputationDataHandle> DynamicSliceInMinorDims(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+ const std::vector<xla::ComputationDataHandle>& starts,
+ const gtl::ArraySlice<int64>& sizes) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+ int64 n_minor_dims = starts.size();
+ TF_RET_CHECK(n_minor_dims == sizes.size());
+ TF_RET_CHECK(n_minor_dims <= n_dims);
+ gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape->dimensions()),
+ /*pos=*/0,
+ /*len=*/n_dims - sizes.size());
+ TF_ASSIGN_OR_RETURN(auto padded_starts,
+ PrependZerosInMajorDims(builder, x, starts));
+ auto padded_sizes = PrependMajorDims(builder, major_dims, sizes);
+ return builder->DynamicSlice(x, padded_starts, padded_sizes);
+}
+
xla::StatusOr<xla::ComputationDataHandle> UpdateSlice(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start) {
// TODO(phawkins): make int64 work on all backends, remove the int32 cast.
std::vector<int32> start_as_int32(start.begin(), start.end());
- return builder->DynamicUpdateSlice(
- x, update, builder->ConstantR1<int32>(start_as_int32));
+ auto start_constant = builder->ConstantR1<int32>(start_as_int32);
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> start_constant_shape,
+ builder->GetShape(start_constant));
+ const int64 start_length =
+ xla::ShapeUtil::GetDimension(*start_constant_shape, -1);
+ TF_RET_CHECK(start_length == n_dims);
+ return builder->DynamicUpdateSlice(x, update, start_constant);
}
xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
@@ -162,6 +196,29 @@ xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
return UpdateSlice(builder, x, update, padded_start);
}
+xla::StatusOr<xla::ComputationDataHandle> DynamicUpdateSliceInMinorDims(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+ const xla::ComputationDataHandle& update,
+ const std::vector<xla::ComputationDataHandle>& starts) {
+ TF_ASSIGN_OR_RETURN(auto padded_starts,
+ PrependZerosInMajorDims(builder, x, starts));
+ return builder->DynamicUpdateSlice(x, update, padded_starts);
+}
+
+xla::StatusOr<xla::ComputationDataHandle> PrependZerosInMajorDims(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+ const std::vector<xla::ComputationDataHandle>& starts) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+ auto zero = builder->Reshape(builder->ConstantR0<int32>(0), {1});
+ std::vector<xla::ComputationDataHandle> padded_starts(n_dims, zero);
+ for (int i = 0; i < starts.size(); ++i) {
+ padded_starts[n_dims - starts.size() + i] =
+ builder->Reshape(starts[i], {1});
+ }
+ return builder->ConcatInDim(padded_starts, 0);
+}
+
xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
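
All of the new helpers reduce to one trick: pad the caller's minor-dimension start indices with zeros (and, for slicing, pad the sizes with the full major-dimension extents) so that DynamicSlice and DynamicUpdateSlice only ever move data in the trailing dimensions. In NumPy terms, a sketch of the semantics (not of the builder API):

    import numpy as np

    def dynamic_slice_in_minor_dims(x, starts, sizes):
        # Pad starts with zeros and sizes with the full major extents,
        # as PrependZerosInMajorDims / PrependMajorDims do above.
        n_major = x.ndim - len(starts)
        full_starts = [0] * n_major + list(starts)
        full_sizes = list(x.shape[:n_major]) + list(sizes)
        return x[tuple(slice(b, b + s)
                       for b, s in zip(full_starts, full_sizes))]

    x = np.arange(2 * 4 * 4).reshape(2, 4, 4)
    # Rows [1, 3) and columns [0, 2) of every batch element.
    assert dynamic_slice_in_minor_dims(x, [1, 0], [2, 2]).shape == (2, 2, 2)

DynamicUpdateSliceInMinorDims is the same construction in reverse: the padded start vector positions `update` inside `x` while the leading batch dimensions stay aligned.
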
diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h
index 51f8baaf00..b684123f13 100644
--- a/tensorflow/compiler/tf2xla/lib/util.h
+++ b/tensorflow/compiler/tf2xla/lib/util.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -32,16 +32,39 @@ xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder,
xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
xla::PrimitiveType type, double value);
+// Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros
+// prepended until the array is length n_dims.
+xla::ComputationDataHandle PrependZerosInMajorDims(
+ xla::ComputationBuilder* builder,
+ gtl::ArraySlice<xla::ComputationDataHandle> starts);
+
// Returns an integer scalar constant of 'type' with 'value'.
// If 'type' is complex, returns a real value with zero imaginary component.
xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
xla::PrimitiveType type, int64 value);
+// Builds a vector of zeros of length rank(x) with the last two values being
+// those in `starts`.
+xla::StatusOr<xla::ComputationDataHandle> PrependZerosInMajorDims(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+ const std::vector<xla::ComputationDataHandle>& starts);
+
// Performs a slice in the minor dimensions of a Tensor.
xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
gtl::ArraySlice<int64> start, gtl::ArraySlice<int64> end);
+// Builds a 1-d vector out of a concatenation of `major_dims` and `starts`.
+std::vector<int64> PrependMajorDims(xla::ComputationBuilder* builder,
+ const gtl::ArraySlice<int64>& major_dims,
+ const gtl::ArraySlice<int64>& indices);
+
+// Performs a dynamic slice in the minor dimensions of a Tensor.
+xla::StatusOr<xla::ComputationDataHandle> DynamicSliceInMinorDims(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+ const std::vector<xla::ComputationDataHandle>& starts,
+ const gtl::ArraySlice<int64>& sizes);
+
// Updates a slice of 'x', i.e.,
// x[start[0], ..., start[n]] = update
xla::StatusOr<xla::ComputationDataHandle> UpdateSlice(
@@ -54,6 +77,11 @@ xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start);
+xla::StatusOr<xla::ComputationDataHandle> DynamicUpdateSliceInMinorDims(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+ const xla::ComputationDataHandle& update,
+ const std::vector<xla::ComputationDataHandle>& starts);
+
// Transposes a stack of matrices `x` by swapping the last two dimensions.
xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x);
diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc
new file mode 100644
index 0000000000..b6bd33af2e
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/util_test.cc
@@ -0,0 +1,145 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/util.h"
+
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace {
+
+using UtilTest = xla::ClientLibraryTestBase;
+using UtilLeftLookingTest = xla::ClientLibraryTestBase;
+
+xla::Array2D<float> BValsRight() {
+ return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
+}
+
+xla::Array2D<float> BValsLeft() {
+ return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
+}
+
+xla::Array2D<float> AValsFull() {
+ return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}};
+}
+
+xla::Array3D<float> BatchedAValsFull() {
+ return {{
+ {2, 0, 1, 2},
+ {3, 6, 0, 1},
+ {4, 7, 9, 0},
+ {5, 8, 10, 11},
+ },
+ {
+ {16, 24, 8, 12},
+ {24, 61, 82, 48},
+ {8, 82, 456, 106},
+ {12, 48, 106, 62},
+ }};
+}
+
+XLA_TEST_F(UtilTest, Simple2dLookup) {
+ xla::ComputationBuilder builder(client_, TestName());
+
+ xla::ComputationDataHandle a, x, y;
+ auto a_data = CreateR2Parameter<float>(BValsRight(), 0, "a", &builder, &a);
+ auto x_data = CreateR0Parameter<int>(2, 1, "x", &builder, &x);
+ auto y_data = CreateR0Parameter<int>(1, 2, "y", &builder, &y);
+ auto result = DynamicSliceInMinorDims(&builder, a, {x, y}, {1, 1});
+ TF_ASSERT_OK(result.status());
+
+ ComputeAndCompareR2<float>(&builder, {{10}},
+ {a_data.get(), x_data.get(), y_data.get()},
+ xla::ErrorSpec(1e-2, 1e-2));
+}
+
+XLA_TEST_F(UtilTest, Simple3dLookup) {
+ xla::ComputationBuilder builder(client_, TestName());
+
+ xla::ComputationDataHandle a, index;
+ auto a_data =
+ CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
+ auto index_data = CreateR0Parameter<int>(1, 1, "index", &builder, &index);
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto l_index,
+ DynamicSliceInMinorDims(&builder, a,
+ {index, builder.ConstantR0<int32>(0)}, {1, 4}));
+
+ ComputeAndCompareR3<float>(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}},
+ {a_data.get(), index_data.get()});
+}
+
+XLA_TEST_F(UtilTest, SimpleSliceUpdate) {
+ xla::ComputationBuilder builder(client_, TestName());
+
+ xla::ComputationDataHandle a, b, x, y;
+ auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a);
+ auto b_data = CreateR2Parameter<float>({{9, 1, -10}}, 1, "b", &builder, &b);
+ auto x_data = CreateR0Parameter<int>(2, 2, "x", &builder, &x);
+ auto y_data = CreateR0Parameter<int>(1, 3, "y", &builder, &y);
+
+ auto result = DynamicUpdateSliceInMinorDims(&builder, a, b, {x, y});
+ TF_ASSERT_OK(result.status());
+
+ xla::Array2D<float> expected(
+ {{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 9, 1, -10}, {5, 8, 10, 11}}});
+
+ ComputeAndCompareR2<float>(
+ &builder, expected,
+ {a_data.get(), b_data.get(), x_data.get(), y_data.get()});
+}
+
+XLA_TEST_F(UtilTest, RowBatchDot) {
+ xla::ComputationBuilder builder(client_, TestName());
+
+ int n = 4;
+
+ xla::ComputationDataHandle a, row, index;
+ auto a_data =
+ CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
+ auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
+ "row", &builder, &row);
+ // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull().
+ auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto l_index,
+ DynamicSliceInMinorDims(&builder, a,
+ {index, builder.ConstantR0<int32>(0)}, {1, n}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto dot, BatchDot(&builder, l_index, row,
+ /*transpose_x=*/false, /*transpose_y=*/true));
+
+ ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
+ {a_data.get(), row_data.get(), index_data.get()});
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index ddc099807d..9831a09c1f 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1283,6 +1283,7 @@ cc_library(
":hlo_creation_utils",
":hlo_pass",
":hlo_query",
+ ":pattern_matcher",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 6cb1bd5669..8d26938c6e 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
+#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -44,8 +45,11 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace xla {
+
namespace {
+namespace m = match;
+
// Returns whether operand is a literal with the given value.
bool IsLiteralWithValue(const HloInstruction* operand, int8 value) {
return operand->opcode() == HloOpcode::kConstant &&
@@ -105,6 +109,7 @@ HloComputation* CreateScalarBinaryComputation(HloModule* module,
module->AddEmbeddedComputation(b.Build(scalar_op));
return scalar_computation;
}
+
} // namespace
// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
@@ -350,8 +355,9 @@ bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape(
}
Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
- auto lhs = add->mutable_operand(0);
- auto rhs = add->mutable_operand(1);
+ HloInstruction *lhs, *rhs;
+ CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs))));
+
// A + 0 => A
VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString();
if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) {
@@ -366,7 +372,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
// Canonicalization: Put constants on the right. This makes the reassociation
// rules below simpler.
VLOG(10) << "trying transform [Const + A => A + Const]";
- if (lhs->IsConstant() && !rhs->IsConstant()) {
+ if (Match(add, m::Add(m::Constant(), m::NonConstant()))) {
return ReplaceWithNewInstruction(
add,
HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs));
@@ -379,16 +385,13 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
// (A + C1) + (B + C2) => A + B + (C1 + C2).
//
VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]";
- if (rhs->IsConstant() && lhs->opcode() == HloOpcode::kAdd &&
- !lhs->operand(0)->IsConstant() && lhs->operand(1)->IsConstant()) {
- auto* c1 = lhs->mutable_operand(1);
- auto* c2 = rhs;
-
+ HloInstruction *a, *c1, *c2;
+ if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)),
+ m::Constant(&c2)))) {
TF_ASSIGN_OR_RETURN(auto* sum_of_constants,
MakeBinaryHlo(HloOpcode::kAdd, c1, c2));
return ReplaceWithNewInstruction(
- add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd,
- lhs->mutable_operand(0),
+ add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a,
sum_of_constants));
}
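
The recurring edit in this file is the switch from hand-written opcode checks to the new pattern_matcher library (added as the :pattern_matcher dependency above): Match(instr, m::Add(m::Op(&lhs), m::Op(&rhs))) simultaneously tests the shape of the expression tree and binds the matched sub-instructions to the supplied pointers. A toy Python sketch of that bind-while-matching idea (illustrative only; the real API is the C++ m:: namespace):

    # Expressions are tuples: (opcode, operand0, operand1, ...).
    class Op:
        # Matches any node; records it in the `out` list if one is given.
        def __init__(self, out=None):
            self.out = out
        def match(self, node):
            if self.out is not None:
                self.out.append(node)
            return True

    class Add:
        # Matches an "add" node whose operands match the sub-patterns.
        def __init__(self, lhs, rhs):
            self.lhs, self.rhs = lhs, rhs
        def match(self, node):
            return (isinstance(node, tuple) and node[0] == "add"
                    and self.lhs.match(node[1]) and self.rhs.match(node[2]))

    lhs, rhs = [], []
    assert Add(Op(lhs), Op(rhs)).match(("add", ("mul", "x", "y"), "z"))
    assert lhs == [("mul", "x", "y")] and rhs == ["z"]
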
@@ -397,11 +400,11 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) {
// If a bitcast feeds a bitcast, make it a single bitcast.
- if (bitcast->operand(0)->opcode() == HloOpcode::kBitcast) {
+ HloInstruction* op;
+ if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) {
return ReplaceWithNewInstruction(
- bitcast, HloInstruction::CreateUnary(
- bitcast->shape(), HloOpcode::kBitcast,
- bitcast->mutable_operand(0)->mutable_operand(0)));
+ bitcast,
+ HloInstruction::CreateUnary(bitcast->shape(), HloOpcode::kBitcast, op));
}
// All bitcasts can be eliminated (assuming layout constraints are
// satisfied).
@@ -418,11 +421,10 @@ Status AlgebraicSimplifierVisitor::HandleBitcastConvert(
Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
// If a copy feeds a copy, make it a single copy.
- if (copy->operand(0)->opcode() == HloOpcode::kCopy) {
+ HloInstruction* op;
+ if (Match(copy, m::Copy(m::Copy(m::Op(&op))))) {
return ReplaceWithNewInstruction(
- copy, HloInstruction::CreateUnary(
- copy->shape(), HloOpcode::kCopy,
- copy->mutable_operand(0)->mutable_operand(0)));
+ copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op));
}
// All copies can be eliminated (assuming layout constraints are satisfied).
ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0));
@@ -462,12 +464,10 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate(
} else if (operands.size() == 2) {
// A binary concat with a broadcasted scalar as an operand can be converted
// into a pad which is simpler to fold into other operations.
- bool is_effective_low_pad =
- operands[0]->opcode() == HloOpcode::kBroadcast &&
- ShapeUtil::IsScalar(operands[0]->operand(0)->shape());
- bool is_effective_high_pad =
- operands[1]->opcode() == HloOpcode::kBroadcast &&
- ShapeUtil::IsScalar(operands[1]->operand(0)->shape());
+ bool is_effective_low_pad = Match(
+ operands[0], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
+ bool is_effective_high_pad = Match(
+ operands[1], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
if (!is_effective_low_pad && !is_effective_high_pad) {
return Status::OK();
}
@@ -537,8 +537,8 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
}
Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
- auto lhs = sub->mutable_operand(0);
- auto rhs = sub->mutable_operand(1);
+ HloInstruction *lhs, *rhs;
+ CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs))));
// A - 0 => A
VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString();
if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) {
@@ -547,7 +547,7 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
// Canonicalize subtraction of a constant to addition.
VLOG(10) << "trying transform [A - Const => A + (-Const)]";
- if (rhs->IsConstant() && !lhs->IsConstant()) {
+ if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs)))) {
HloInstruction* negative_const = computation_->AddInstruction(
HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs));
return ReplaceWithNewInstruction(
@@ -559,56 +559,53 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
}
Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
- auto lhs = divide->mutable_operand(0);
- auto rhs = divide->mutable_operand(1);
+ Shape* shape;
+ HloInstruction *a, *b, *c, *d;
+ CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
// A/1 => A
VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString();
- if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(divide, lhs)) {
+ if (IsAll(b, 1) && ReplaceInstructionIfSameShape(divide, a)) {
return Status::OK();
}
// exp(A)/exp(B) => exp(A-B)
- if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) {
+ if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b)))
+ .WithShape(m::Shape(&shape)))) {
VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString();
- HloInstruction* subtract =
- computation_->AddInstruction(HloInstruction::CreateBinary(
- divide->shape(), HloOpcode::kSubtract, lhs->mutable_operand(0),
- rhs->mutable_operand(0)));
+ HloInstruction* subtract = computation_->AddInstruction(
+ HloInstruction::CreateBinary(*shape, HloOpcode::kSubtract, a, b));
return ReplaceWithNewInstruction(
- divide, HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp,
- subtract));
+ divide, HloInstruction::CreateUnary(*shape, HloOpcode::kExp, subtract));
}
// A/exp(B) => A*exp(-B)
- if (rhs->opcode() == HloOpcode::kExp) {
+ if (Match(divide, m::Divide(m::Op(&a), m::Exp(m::Op(&b))))) {
VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString();
- HloInstruction* negate =
- computation_->AddInstruction(HloInstruction::CreateUnary(
- divide->shape(), HloOpcode::kNegate, rhs->mutable_operand(0)));
+ HloInstruction* negate = computation_->AddInstruction(
+ HloInstruction::CreateUnary(divide->shape(), HloOpcode::kNegate, b));
HloInstruction* new_exp = computation_->AddInstruction(
HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate));
return ReplaceWithNewInstruction(
- divide, HloInstruction::CreateBinary(
- divide->shape(), HloOpcode::kMultiply, lhs, new_exp));
+ divide, HloInstruction::CreateBinary(divide->shape(),
+ HloOpcode::kMultiply, a, new_exp));
}
// A/pow(B,C) => A*pow(B,-C)
- if (rhs->opcode() == HloOpcode::kPower) {
+ if (Match(divide, m::Divide(m::Op(&a), m::Power(m::Op(&b), m::Op(&c))))) {
VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString();
// The output shape of the created negate operator should be the same as the
// input.
- const Shape& negate_shape = rhs->operand(1)->shape();
- HloInstruction* negate =
- computation_->AddInstruction(HloInstruction::CreateUnary(
- negate_shape, HloOpcode::kNegate, rhs->mutable_operand(1)));
+ const Shape& negate_shape = c->shape();
+ HloInstruction* negate = computation_->AddInstruction(
+ HloInstruction::CreateUnary(negate_shape, HloOpcode::kNegate, c));
// And the power operator should retain the output shape of the old one.
- const Shape& new_power_shape = rhs->shape();
- HloInstruction* new_power = computation_->AddInstruction(
- HloInstruction::CreateBinary(new_power_shape, HloOpcode::kPower,
- rhs->mutable_operand(0), negate));
+ const Shape& new_power_shape = b->shape();
+ HloInstruction* new_power =
+ computation_->AddInstruction(HloInstruction::CreateBinary(
+ new_power_shape, HloOpcode::kPower, b, negate));
return ReplaceWithNewInstruction(
divide, HloInstruction::CreateBinary(
- divide->shape(), HloOpcode::kMultiply, lhs, new_power));
+ divide->shape(), HloOpcode::kMultiply, a, new_power));
}
// Simplifying integral division would produce unexpected results.
@@ -620,28 +617,24 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
//
// (Backends can do this transformation, but generally only if the constant is
// a scalar.)
- if (lhs->opcode() != HloOpcode::kConstant &&
- rhs->opcode() == HloOpcode::kConstant) {
+ if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) {
HloInstruction* one =
computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::One(lhs->shape().element_type()).CloneToUnique()));
- HloInstruction* inverse =
- computation_->AddInstruction(HloInstruction::CreateBinary(
- rhs->shape(), HloOpcode::kDivide, one, rhs));
+ Literal::One(a->shape().element_type()).CloneToUnique()));
+ HloInstruction* inverse = computation_->AddInstruction(
+ HloInstruction::CreateBinary(b->shape(), HloOpcode::kDivide, one, b));
return ReplaceWithNewInstruction(
- divide, HloInstruction::CreateBinary(
- divide->shape(), HloOpcode::kMultiply, lhs, inverse));
+ divide, HloInstruction::CreateBinary(divide->shape(),
+ HloOpcode::kMultiply, a, inverse));
}
// (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C)
- if (lhs->opcode() == HloOpcode::kDivide &&
- rhs->opcode() == HloOpcode::kDivide) {
- TF_ASSIGN_OR_RETURN(auto a_times_d, MakeBinaryHlo(HloOpcode::kMultiply,
- lhs->mutable_operand(0),
- rhs->mutable_operand(1)));
- TF_ASSIGN_OR_RETURN(auto b_times_c, MakeBinaryHlo(HloOpcode::kMultiply,
- lhs->mutable_operand(1),
- rhs->mutable_operand(0)));
+ if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)),
+ m::Divide(m::Op(&c), m::Op(&d))))) {
+ TF_ASSIGN_OR_RETURN(auto a_times_d,
+ MakeBinaryHlo(HloOpcode::kMultiply, a, d));
+ TF_ASSIGN_OR_RETURN(auto b_times_c,
+ MakeBinaryHlo(HloOpcode::kMultiply, b, c));
TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide,
a_times_d, b_times_c));
@@ -649,24 +642,21 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
}
// (A / B) / C => A / (B * C)
- if (lhs->opcode() == HloOpcode::kDivide) {
- TF_ASSIGN_OR_RETURN(
- auto b_times_c,
- MakeBinaryHlo(HloOpcode::kMultiply, lhs->mutable_operand(1), rhs));
+ if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) {
+ TF_ASSIGN_OR_RETURN(auto b_times_c,
+ MakeBinaryHlo(HloOpcode::kMultiply, b, c));
return ReplaceWithNewInstruction(
- divide,
- HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide,
- lhs->mutable_operand(0), b_times_c));
+ divide, HloInstruction::CreateBinary(divide->shape(),
+ HloOpcode::kDivide, a, b_times_c));
}
// A / (B / C) => (A*C) / B
- if (rhs->opcode() == HloOpcode::kDivide) {
- TF_ASSIGN_OR_RETURN(auto a_times_c, MakeBinaryHlo(HloOpcode::kMultiply, lhs,
- rhs->mutable_operand(1)));
+ if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) {
+ TF_ASSIGN_OR_RETURN(auto a_times_c,
+ MakeBinaryHlo(HloOpcode::kMultiply, a, c));
return ReplaceWithNewInstruction(
- divide,
- HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide,
- a_times_c, rhs->mutable_operand(0)));
+ divide, HloInstruction::CreateBinary(divide->shape(),
+ HloOpcode::kDivide, a_times_c, b));
}
return Status::OK();
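
Every HandleDivide rewrite above is an exact algebraic identity (up to floating-point rounding), and integral division is deliberately left alone since truncation would change results. A quick numeric spot check of the identities, in Python:

    import math

    A, B, C, D = 3.0, 1.7, 2.5, 0.9

    assert math.isclose(math.exp(A) / math.exp(B), math.exp(A - B))  # exp(A)/exp(B) => exp(A-B)
    assert math.isclose(A / math.exp(B), A * math.exp(-B))           # A/exp(B)      => A*exp(-B)
    assert math.isclose(A / B**C, A * B**-C)                         # A/pow(B,C)    => A*pow(B,-C)
    assert math.isclose((A / B) / (C / D), (A * D) / (B * C))        # (A/B)/(C/D)   => (A*D)/(B*C)
    assert math.isclose((A / B) / C, A / (B * C))                    # (A/B)/C       => A/(B*C)
    assert math.isclose(A / (B / C), (A * C) / B)                    # A/(B/C)       => (A*C)/B
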
@@ -674,8 +664,8 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
HloInstruction* dot) {
- HloInstruction* lhs = dot->mutable_operand(0);
- HloInstruction* rhs = dot->mutable_operand(1);
+ HloInstruction *lhs, *rhs;
+ CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
int64 lhs_collapsing_dim =
dot->dot_dimension_numbers().lhs_contracting_dimensions(0);
if (lhs->IsRank2Transpose()) {
@@ -792,8 +782,8 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat(
const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0);
const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0);
- HloInstruction* lhs = dot->mutable_operand(0);
- HloInstruction* rhs = dot->mutable_operand(1);
+ HloInstruction *lhs, *rhs;
+ CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
TF_ASSIGN_OR_RETURN(
HloInstruction * optimized_lhs_concat,
@@ -923,8 +913,8 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
}
Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
- auto lhs = dot->mutable_operand(0);
- auto rhs = dot->mutable_operand(1);
+ HloInstruction *lhs, *rhs;
+ CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
// Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or
// below.
@@ -976,8 +966,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
}
Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
- auto lhs = multiply->mutable_operand(0);
- auto rhs = multiply->mutable_operand(1);
+ HloInstruction *lhs, *rhs;
+ CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs))));
// A*1 => A
VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString();
if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) {
@@ -990,10 +980,9 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
}
// exp(A) * exp(B) => exp(A+B)
- if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) {
+ if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) {
auto add = computation_->AddInstruction(HloInstruction::CreateBinary(
- multiply->shape(), HloOpcode::kAdd, lhs->mutable_operand(0),
- rhs->mutable_operand(0)));
+ multiply->shape(), HloOpcode::kAdd, lhs, rhs));
return ReplaceWithNewInstruction(
multiply,
HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add));
@@ -1004,20 +993,19 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) {
// ln(exp(A)) => A
VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString();
- auto operand = log->mutable_operand(0);
- if (operand->opcode() == HloOpcode::kExp &&
- ReplaceInstructionIfSameShape(log, operand->mutable_operand(0))) {
+ HloInstruction *a, *b;
+ if (Match(log, m::Log(m::Exp(m::Op(&a)))) &&
+ ReplaceInstructionIfSameShape(log, a)) {
return Status::OK();
}
// ln(pow(A,B)) => B*ln(A)
- if (operand->opcode() == HloOpcode::kPower) {
- auto new_log = computation_->AddInstruction(HloInstruction::CreateUnary(
- log->shape(), HloOpcode::kLog, operand->mutable_operand(0)));
+ if (Match(log, m::Log(m::Power(m::Op(&a), m::Op(&b))))) {
+ auto new_log = computation_->AddInstruction(
+ HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a));
return ReplaceWithNewInstruction(
- log,
- HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
- new_log, operand->mutable_operand(1)));
+ log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
+ new_log, b));
}
return Status::OK();
@@ -1120,7 +1108,8 @@ bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
} // namespace
Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
- auto operand = broadcast->mutable_operand(0);
+ HloInstruction* operand;
+ CHECK(Match(broadcast, m::Broadcast(m::Op(&operand))));
auto dims = broadcast->dimensions();
// A degenerate broadcast of a reshape that does not change the number of
// elements can be replaced by a reshape.
@@ -1231,30 +1220,28 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
// Complex(Real(c), Imag(c)) -> c
Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) {
- auto real = complex->mutable_operand(0);
- auto imag = complex->mutable_operand(1);
- if (real->opcode() == HloOpcode::kReal &&
- imag->opcode() == HloOpcode::kImag &&
- real->operand(0) == imag->operand(0)) {
- return ReplaceInstruction(complex, real->mutable_operand(0));
+ HloInstruction *c0, *c1;
+ if (Match(complex, m::Complex(m::Real(m::Op(&c0)), m::Imag(m::Op(&c1)))) &&
+ c0 == c1) {
+ return ReplaceInstruction(complex, c0);
}
return Status::OK();
}
// Real(Complex(r, i)) -> r
Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) {
- auto operand = real->mutable_operand(0);
- if (operand->opcode() == HloOpcode::kComplex) {
- return ReplaceInstruction(real, operand->mutable_operand(0));
+ HloInstruction* op;
+ if (Match(real, m::Real(m::Complex(m::Op(&op), m::Op())))) {
+ return ReplaceInstruction(real, op);
}
return Status::OK();
}
// Imag(Complex(r, i)) -> i
Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
- auto operand = imag->mutable_operand(0);
- if (operand->opcode() == HloOpcode::kComplex) {
- return ReplaceInstruction(imag, operand->mutable_operand(1));
+ HloInstruction* op;
+ if (Match(imag, m::Imag(m::Complex(m::Op(), m::Op(&op))))) {
+ return ReplaceInstruction(imag, op);
}
return Status::OK();
}
@@ -1351,8 +1338,8 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString();
- auto lhs = power->mutable_operand(0);
- auto rhs = power->mutable_operand(1);
+ HloInstruction *lhs, *rhs;
+ CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
if (IsAll(rhs, 0)) {
auto one = HloInstruction::CreateConstant(
Literal::One(power->shape().element_type()).CloneToUnique());
@@ -1372,9 +1359,10 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
}
// pow(exp(A),B) => exp(A*B)
- if (lhs->opcode() == HloOpcode::kExp) {
+ HloInstruction *a, *b;
+ if (Match(power, m::Power(m::Exp(m::Op(&a)), m::Op(&b)))) {
auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary(
- power->shape(), HloOpcode::kMultiply, lhs->operands()[0], rhs));
+ power->shape(), HloOpcode::kMultiply, a, b));
return ReplaceWithNewInstruction(
power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp,
a_times_b));
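The exp, log and pow rewrites in the surrounding hunks apply the standard identities below (valid for real arguments where the logarithm is defined; again, floating-point results may differ in the last bits):

```latex
e^{A} \cdot e^{B} = e^{A+B}, \qquad
\ln\bigl(e^{A}\bigr) = A, \qquad
\ln\bigl(A^{B}\bigr) = B \ln A, \qquad
\bigl(e^{A}\bigr)^{B} = e^{A B}
```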
@@ -1707,7 +1695,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
HloInstruction::CreateReshape(reduce->shape(), arg));
return ReplaceWithNewInstruction(
reduce, HloInstruction::CreateMap(reduce->shape(),
- {reshape, init_value}, function));
+ {init_value, reshape}, function));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
index 61b2da7a7d..6a7eb85e3b 100644
--- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
+++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
@@ -25,11 +25,11 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
-#include "llvm/ExecutionEngine/ObjectMemoryBuffer.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/MCContext.h"
#include "llvm/Object/ObjectFile.h"
+#include "llvm/Support/SmallVectorMemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/IPO.h"
@@ -158,7 +158,7 @@ std::unique_ptr<llvm::MemoryBuffer> CompilerFunctor::operator()(
// Construct ObjectFile from machine code buffer.
return std::unique_ptr<llvm::MemoryBuffer>(
- new llvm::ObjectMemoryBuffer(std::move(stream_buffer)));
+ new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer)));
}
static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index 99c5e16db7..e97113dfa0 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -115,7 +115,7 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
for (int i = 0; i < hlo->operand_count(); i++) {
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
operand_to_generator.at(hlo->operand(i))(
- ElementwiseSourceIndex(index, *hlo, 0)));
+ ElementwiseSourceIndex(index, *hlo, i)));
operands.push_back(operand_value);
}
return ir_emitter_->EmitScalarCall(hlo->shape().element_type(),
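The one-character fix above replaces a hard-coded operand index 0 with the loop variable i, so each operand's value is generated from its own source index rather than operand 0's. A minimal plain-Python sketch of this bug class (the generator lambdas are hypothetical stand-ins, not the XLA API):

```python
# Buggy version: always reads generator 0, mirroring
# ElementwiseSourceIndex(index, *hlo, 0) inside the loop.
def gather_operands_buggy(generators):
    return [generators[0]() for _ in range(len(generators))]

# Fixed version: index with the loop variable, mirroring
# ElementwiseSourceIndex(index, *hlo, i).
def gather_operands_fixed(generators):
    return [generators[i]() for i in range(len(generators))]

gens = [lambda: 'a', lambda: 'b', lambda: 'c']
assert gather_operands_buggy(gens) == ['a', 'a', 'a']  # operand 0 repeated
assert gather_operands_fixed(gens) == ['a', 'b', 'c']  # one value per operand
```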
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index b4f9a9db9c..52bc2c0448 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -1604,8 +1604,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
// Evaluate computation with specified literal operands.
auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
auto result_val_literal = Literal::CreateR0<ReturnT>(result_val);
- std::vector<const Literal*> args = {curr_val_literal.get(),
- result_val_literal.get()};
+ std::vector<const Literal*> args = {result_val_literal.get(),
+ curr_val_literal.get()};
std::unique_ptr<Literal> computed_result =
embedded_evaluator.Evaluate<const Literal*>(*function, args)
@@ -1804,7 +1804,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
const auto result_val_literal =
Literal::CreateR0<ReturnT>(result_val);
const std::vector<const Literal*> args = {
- curr_val_literal.get(), result_val_literal.get()};
+ result_val_literal.get(), curr_val_literal.get()};
std::unique_ptr<Literal> computed_result =
embedded_evaluator.Evaluate<const Literal*>(*function, args)
.ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 768beec15e..423ccadb5b 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -52,6 +52,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -934,5 +935,36 @@ XLA_TEST_F(ReduceInitializerTest, U64InitializerBigValue) {
DoTest<uint64>(1234556789123, 1024);
}
+// Tests the operational semantics that the init value is passed on the lhs
+// for reduces. This can be checked by performing an "identity" reduce (one
+// that simply returns one of the parameters). Here we return the rhs, which,
+// for a 1D array with one element, should not be the init value.
+XLA_TEST_F(ReduceTest, ReduceIdentity) {
+ ComputationBuilder builder(client_, TestName());
+ Shape single_float = ShapeUtil::MakeShape(F32, {});
+ builder.Parameter(0, single_float, "lhs-unused");
+ builder.Parameter(1, single_float, "rhs-used");
+ auto computation_status = builder.Build();
+ TF_ASSERT_OK(computation_status.status());
+
+ Shape operand_shape = ShapeUtil::MakeShape(F32, {1});
+ builder.Reduce(builder.Parameter(0, operand_shape, "operand"),
+ builder.Parameter(1, single_float, "init"),
+ computation_status.ValueOrDie(), {0});
+
+ float operand[] = {42.0f};
+ float init = 58.5f;
+ float expected = 42.0f;
+ std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>(operand);
+ std::unique_ptr<GlobalData> input_global_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ std::unique_ptr<Literal> input_literal2 = Literal::CreateR0<float>(init);
+ std::unique_ptr<GlobalData> input_global_data2 =
+ client_->TransferToServer(*input_literal2).ConsumeValueOrDie();
+ ComputeAndCompareR0<float>(
+ &builder, expected, {input_global_data.get(), input_global_data2.get()},
+ ErrorSpec(0.0001));
+}
+
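The test above pins down the convention that the evaluator and simplifier changes in this commit enforce: the reduce computation is invoked as computation(accumulator, current_element), with the init value arriving on the lhs. A minimal plain-Python sketch (not the XLA evaluator) using the test's values:

```python
def reduce(operand, init, computation):
    # The accumulator (seeded with init) is always passed on the lhs.
    acc = init
    for x in operand:
        acc = computation(acc, x)
    return acc

# An "identity" computation that returns its rhs, like the test's computation.
identity_rhs = lambda lhs, rhs: rhs

# With a single-element operand, the result is the element (42.0),
# never the init value (58.5).
assert reduce([42.0], 58.5, identity_rhs) == 42.0
```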
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 6a054a5dd3..0a09766722 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -1435,5 +1435,22 @@ ENTRY R3Window {
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
}
+TEST_F(HloTestBase, ReduceWindowIdentity) {
+ const string& hlo_string = R"(
+HloModule ReduceWindowIdentity
+identity.pad_to_reduce_window {
+ param0 = f32[] parameter(0)
+ ROOT param1 = f32[] parameter(1)
+}
+ENTRY reduce-window-identity {
+ operand = f32[1,32,64]{2,1,0} parameter(0)
+ constant.4466 = f32[] constant(0)
+ ROOT reduce-window = f32[1,33,64]{2,1,0} reduce-window(operand, constant.4466), window={size=1x1x1 pad=0_0x1_0x0_0}, to_apply=identity.pad_to_reduce_window
+}
+
+)";
+ EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 9bef0d8b61..7e47516550 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -77,6 +77,7 @@ py_library(
"//tensorflow/contrib/optimizer_v2:optimizer_v2_py",
"//tensorflow/contrib/periodic_resample:init_py",
"//tensorflow/contrib/predictor",
+ "//tensorflow/contrib/proto",
"//tensorflow/contrib/quantization:quantization_py",
"//tensorflow/contrib/quantize:quantize_graph",
"//tensorflow/contrib/autograph",
@@ -86,6 +87,7 @@ py_library(
"//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py",
"//tensorflow/contrib/resampler:resampler_py",
"//tensorflow/contrib/rnn:rnn_py",
+ "//tensorflow/contrib/rpc",
"//tensorflow/contrib/saved_model:saved_model_py",
"//tensorflow/contrib/seq2seq:seq2seq_py",
"//tensorflow/contrib/signal:signal_py",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index aaddb06fa0..36cc5144d0 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -64,12 +64,14 @@ from tensorflow.contrib import nn
from tensorflow.contrib import opt
from tensorflow.contrib import periodic_resample
from tensorflow.contrib import predictor
+from tensorflow.contrib import proto
from tensorflow.contrib import quantization
from tensorflow.contrib import quantize
from tensorflow.contrib import recurrent
from tensorflow.contrib import reduce_slice_ops
from tensorflow.contrib import resampler
from tensorflow.contrib import rnn
+from tensorflow.contrib import rpc
from tensorflow.contrib import saved_model
from tensorflow.contrib import seq2seq
from tensorflow.contrib import signal
diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py
index a39f44b21a..3386c4eca4 100644
--- a/tensorflow/contrib/autograph/__init__.py
+++ b/tensorflow/contrib/autograph/__init__.py
@@ -21,6 +21,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+# TODO(mdan): Bring only the relevant symbols to the top level.
from tensorflow.contrib.autograph import utils
from tensorflow.contrib.autograph.impl.api import convert
from tensorflow.contrib.autograph.impl.api import converted_call
diff --git a/tensorflow/contrib/autograph/converters/asserts.py b/tensorflow/contrib/autograph/converters/asserts.py
index f011a97ade..2d9e2c58e3 100644
--- a/tensorflow/contrib/autograph/converters/asserts.py
+++ b/tensorflow/contrib/autograph/converters/asserts.py
@@ -27,8 +27,6 @@ from tensorflow.contrib.autograph.pyct import transformer
class AssertsTransformer(transformer.Base):
"""Transforms Print nodes to Call so they can be handled as functions."""
- # pylint:disable=invalid-name
-
def visit_Assert(self, node):
self.generic_visit(node)
@@ -44,9 +42,7 @@ class AssertsTransformer(transformer.Base):
elif isinstance(node.msg, gast.Str):
return templates.replace(template, test=node.test, msg=node.msg)
else:
- raise NotImplementedError('Can only convert string messages for now.')
-
- # pylint:enable=invalid-name
+ raise NotImplementedError('can only convert string messages for now.')
def transform(node, context):
diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py
index 62115d4005..5dfb7a59d5 100644
--- a/tensorflow/contrib/autograph/converters/break_statements.py
+++ b/tensorflow/contrib/autograph/converters/break_statements.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import gast
-
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import templates
from tensorflow.contrib.autograph.pyct import transformer
@@ -35,86 +33,62 @@ class BreakCanonicalizationTransformer(transformer.Base):
# Each item is a list [break_used, break_variable_name]
self.break_uses = []
- def _create_break_check(self):
- template = """
- (not var_name)
- """
- expr, = templates.replace(template, var_name=self.break_uses[-1][1])
- return expr.value
-
- def _create_break_trigger(self):
+ def visit_Break(self, node):
+ self.break_uses[-1][0] = True
template = """
var_name = True
+ continue
"""
- block = templates.replace(template, var_name=self.break_uses[-1][1])
- block.append(gast.Continue())
- return block
-
- def _create_break_init(self):
- template = """
- var_name = False
- """
- assign, = templates.replace(template, var_name=self.break_uses[-1][1])
- return assign
-
- # TODO(mdan): Surely the transformer supports this better?
- def _manual_visit_list(self, block):
- new_block = []
- for n in block:
- new_n = self.visit(n)
- if isinstance(new_n, list):
- new_block.extend(new_n)
- else:
- new_block.append(new_n)
- return new_block
+ return templates.replace(template, var_name=self.break_uses[-1][1])
def visit_While(self, node):
- self.generic_visit(node.test)
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
-
break_var = self.context.namer.new_symbol('break_requested',
scope.referenced)
+
self.break_uses.append([False, break_var])
- node.body = self._manual_visit_list(node.body)
+ node = self.generic_visit(node)
if self.break_uses[-1][0]:
- node.test = gast.BoolOp(gast.And(), [
- node.test,
- gast.UnaryOp(gast.Not(), gast.Name(break_var, gast.Load(), None))
- ])
- final_nodes = [self._create_break_init(), node]
- else:
- final_nodes = node
+ template = """
+ var_name = False
+ while original_test and not var_name:
+ original_body
+ else:
+ original_orelse
+ """
+ node = templates.replace(
+ template,
+ var_name=break_var,
+ original_test=node.test,
+ original_body=node.body,
+ original_orelse=node.orelse)
self.break_uses.pop()
- for n in node.orelse:
- self.generic_visit(n)
- return final_nodes
+ return node
def visit_For(self, node):
- self.generic_visit(node.target)
- self.generic_visit(node.iter)
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
-
break_var = self.context.namer.new_symbol('break_requested',
scope.referenced)
+
self.break_uses.append([False, break_var])
- node.body = self._manual_visit_list(node.body)
+ node = self.generic_visit(node)
if self.break_uses[-1][0]:
+ template = """
+ var_name = False
+ original_for
+ """
+ node = templates.replace(
+ template,
+ var_name=break_var,
+ original_for=node)
extra_cond = templates.replace_as_expression(
'not var_name', var_name=break_var)
- anno.setanno(node, 'extra_cond', extra_cond)
- final_nodes = [self._create_break_init(), node]
- else:
- final_nodes = node
+ new_for_node = node[1]
+ anno.setanno(new_for_node, 'extra_cond', extra_cond)
self.break_uses.pop()
- for n in node.orelse:
- self.generic_visit(n)
- return final_nodes
-
- def visit_Break(self, node):
- self.break_uses[-1][0] = True
- return self._create_break_trigger()
+ return node
def transform(node, context):
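The rewrite above replaces manual AST surgery with whole-template replacement. Conceptually, a while loop containing a break is canonicalized as in this hand-written sketch of the transformation's effect (illustrative, not the transformer's literal output):

```python
# Before: a loop that exits early with break.
def before(n):
    i = 0
    while i < n:
        if i == 3:
            break
        i += 1
    return i

# After: the break sets a flag and continues, and the loop test also
# checks the flag, matching the "while original_test and not var_name"
# template above.
def after(n):
    i = 0
    break_requested = False
    while i < n and not break_requested:
        if i == 3:
            break_requested = True
            continue
        i += 1
    return i

assert before(10) == after(10) == 3
```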
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py
index 0349ce29ce..317711a866 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions.py
+++ b/tensorflow/contrib/autograph/converters/builtin_functions.py
@@ -34,24 +34,24 @@ class BuiltinFunctionTransformer(transformer.Base):
def __init__(self, context):
super(BuiltinFunctionTransformer, self).__init__(context)
- # pylint:disable=invalid-name
-
def _convert_builtin(self, node):
template = """
- autograph_utils.dynamic_builtin(func, args)
+ ag__.utils.dynamic_builtin(func, args)
"""
return templates.replace(template, func=node.func, args=node.args)[0].value
def _convert_print(self, node):
template = """
- autograph_utils.dynamic_print(args)
+ ag__.utils.dynamic_print(args)
"""
return templates.replace(template, args=node.args)[0].value
def visit_Call(self, node):
self.generic_visit(node)
# TODO(mdan): This won't work if the function was hidden.
- if isinstance(node.func, gast.Name) and node.func.id in ('len', 'range'):
+ # TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead.
+ if (isinstance(node.func, gast.Name) and
+ node.func.id in ('len', 'range', 'xrange')):
return self._convert_builtin(node)
# Print needs to be handled separately because it can be read as statement.
if isinstance(node.func, gast.Name) and node.func.id == 'print':
@@ -70,8 +70,6 @@ class BuiltinFunctionTransformer(transformer.Base):
function_call = templates.replace(template, fname='print', args=args)[0]
return self.visit(function_call)
- # pylint:enable=invalid-name
-
def transform(node, context):
return BuiltinFunctionTransformer(context).visit(node)
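With this change, calls to len, range and xrange are routed through a dynamic dispatcher so the generated code can later handle tensor arguments. Below is a simplified stand-in for ag__.utils.dynamic_builtin, showing only the plain-Python fallback:

```python
def dynamic_builtin(func, *args):
    # A real implementation would special-case tensor arguments; plain
    # Python values simply defer to the original builtin.
    return func(*args)

assert dynamic_builtin(len, [0, 0, 0]) == 3
assert list(dynamic_builtin(range, 3)) == [0, 1, 2]
```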
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py
index ac7e756c47..30272409df 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py
+++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py
@@ -26,8 +26,6 @@ from tensorflow.contrib.autograph.converters import builtin_functions
from tensorflow.contrib.autograph.converters import converter_test_base
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import logging_ops
-from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
@@ -49,7 +47,7 @@ class BuiltinFunctionsTest(converter_test_base.TestCase):
self.assertEqual(3, result.test_fn([0, 0, 0]))
- def test_print_with_op(self):
+ def test_print(self):
def test_fn(a):
print(a)
@@ -57,14 +55,12 @@ class BuiltinFunctionsTest(converter_test_base.TestCase):
node = self.parse_and_analyze(test_fn, {'print': print})
node = builtin_functions.transform(node, self.ctx)
- # Note: it's relevant not to include script_ops.py_func here, to verify
- # that tf.Print is used.
- with self.compiled(node, logging_ops.Print) as result:
+ with self.compiled(node) as result:
with self.test_session() as sess:
try:
out_capturer = six.StringIO()
sys.stdout = out_capturer
- result.test_fn('a')
+ result.test_fn(constant_op.constant('a'))
sess.run(sess.graph.get_operations())
self.assertEqual(out_capturer.getvalue(), 'a\n')
finally:
@@ -72,41 +68,19 @@ class BuiltinFunctionsTest(converter_test_base.TestCase):
def test_print_with_op_multiple_values(self):
- def test_fn(a, b):
- print(a, b)
-
- node = self.parse_and_analyze(test_fn, {'print': print})
- node = builtin_functions.transform(node, self.ctx)
-
- # Note: it's relevant not to include script_ops.py_func here, to verify
- # that tf.Print is used.
- with self.compiled(node, logging_ops.Print) as result:
- with self.test_session() as sess:
- try:
- out_capturer = six.StringIO()
- sys.stdout = out_capturer
- result.test_fn('a', 1)
- sess.run(sess.graph.get_operations())
- self.assertEqual(out_capturer.getvalue(), 'a 1\n')
- finally:
- sys.stdout = sys.__stdout__
-
- def test_print_with_py_func(self):
-
def test_fn(a, b, c):
print(a, b, c)
node = self.parse_and_analyze(test_fn, {'print': print})
node = builtin_functions.transform(node, self.ctx)
- # Note: it's relevant not to include logging_ops.Print here, to verify
- # that py_func is used.
- with self.compiled(node, script_ops.py_func) as result:
+ with self.compiled(node) as result:
with self.test_session() as sess:
try:
out_capturer = six.StringIO()
sys.stdout = out_capturer
- result.test_fn('a', 1, [2, 3])
+ result.test_fn(
+ constant_op.constant('a'), constant_op.constant(1), [2, 3])
sess.run(sess.graph.get_operations())
self.assertEqual(out_capturer.getvalue(), 'a 1 [2, 3]\n')
finally:
diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py
index b9088026c1..685fd39d7c 100644
--- a/tensorflow/contrib/autograph/converters/call_trees.py
+++ b/tensorflow/contrib/autograph/converters/call_trees.py
@@ -198,7 +198,7 @@ class CallTreeTransformer(transformer.Base):
def _wrap_to_py_func_no_return(self, node):
# TODO(mdan): Properly handle varargs, etc.
template = """
- autograph_utils.wrap_py_func(func, None, (args,), kwargs, True)
+ ag__.utils.wrap_py_func(func, None, (args,), kwargs, True)
"""
return templates.replace(
template,
@@ -209,7 +209,7 @@ class CallTreeTransformer(transformer.Base):
def _wrap_to_py_func_single_return(self, node, dtype):
# TODO(mdan): Properly handle varargs, etc.
template = """
- autograph_utils.wrap_py_func(func, dtype, (args,), kwargs, False)
+ ag__.utils.wrap_py_func(func, dtype, (args,), kwargs, False)
"""
return templates.replace_as_expression(
template,
@@ -237,7 +237,7 @@ class CallTreeTransformer(transformer.Base):
# Before we could convert all the time though, we'd need a reasonable
# caching mechanism.
template = """
- autograph_api.converted_call(func, True, False, {}, args)
+ ag__.converted_call(func, True, False, {}, args)
"""
call_expr = templates.replace(template, func=node.func, args=node.args)
new_call = call_expr[0].value
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py
index 55a28e8ac3..2e26cdb3d9 100644
--- a/tensorflow/contrib/autograph/converters/control_flow.py
+++ b/tensorflow/contrib/autograph/converters/control_flow.py
@@ -78,7 +78,7 @@ class ControlFlowTransformer(transformer.Base):
def _create_cond_expr(self, results, test, body_name, orelse_name):
if results is not None:
template = """
- results = autograph_utils.run_cond(test, body_name, orelse_name)
+ results = ag__.utils.run_cond(test, body_name, orelse_name)
"""
return templates.replace(
template,
@@ -88,7 +88,7 @@ class ControlFlowTransformer(transformer.Base):
orelse_name=orelse_name)
else:
template = """
- autograph_utils.run_cond(test, body_name, orelse_name)
+ ag__.utils.run_cond(test, body_name, orelse_name)
"""
return templates.replace(
template, test=test, body_name=body_name, orelse_name=orelse_name)
@@ -207,7 +207,7 @@ class ControlFlowTransformer(transformer.Base):
def body_name(state_ssf):
body
return state_ssf,
- state_ast_tuple = __ops.while_loop(
+ state_ast_tuple = ag__.while_loop(
test_name, body_name, (state,), (extra_deps,))
"""
node = templates.replace(
@@ -264,7 +264,7 @@ class ControlFlowTransformer(transformer.Base):
def body_name(iterate, state_ssf):
body
return state_ssf,
- state_ast_tuple = __ops.for_loop(
+ state_ast_tuple = ag__.for_loop(
iterated, extra_cond_name, body_name, (state,))
"""
node = templates.replace(
diff --git a/tensorflow/contrib/autograph/converters/converter_test_base.py b/tensorflow/contrib/autograph/converters/converter_test_base.py
index 6f75e9a529..23b61cf781 100644
--- a/tensorflow/contrib/autograph/converters/converter_test_base.py
+++ b/tensorflow/contrib/autograph/converters/converter_test_base.py
@@ -76,9 +76,10 @@ class TestCase(test.TestCase):
try:
result, source = compiler.ast_to_object(node)
result.tf = self.make_fake_mod('fake_tf', *symbols)
- result.autograph_utils = utils
- result.autograph_api = self.make_fake_mod('fake_api', converted_call)
- result.__dict__['__ops'] = operators
+ fake_ag = self.make_fake_mod('fake_ag', converted_call)
+ fake_ag.__dict__.update(operators.__dict__)
+ fake_ag.__dict__['utils'] = utils
+ result.__dict__['ag__'] = fake_ag
yield result
except Exception: # pylint:disable=broad-except
if source is None:
diff --git a/tensorflow/contrib/autograph/converters/ifexp.py b/tensorflow/contrib/autograph/converters/ifexp.py
index bb0c0a36a7..616d222762 100644
--- a/tensorflow/contrib/autograph/converters/ifexp.py
+++ b/tensorflow/contrib/autograph/converters/ifexp.py
@@ -27,7 +27,7 @@ class IfExp(transformer.Base):
def visit_IfExp(self, node):
template = """
- autograph_utils.run_cond(test, lambda: (body,), lambda: (orelse,))
+ ag__.utils.run_cond(test, lambda: (body,), lambda: (orelse,))
"""
desugared_ifexp = templates.replace_as_expression(
template, test=node.test, body=node.body, orelse=node.orelse)
diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py
index 234a0a7487..6dda554acc 100644
--- a/tensorflow/contrib/autograph/converters/lists.py
+++ b/tensorflow/contrib/autograph/converters/lists.py
@@ -45,7 +45,7 @@ class ListTransformer(transformer.Base):
if not anno.hasanno(node, 'element_type'):
raise NotImplementedError(
'type inference for empty lists is not yet supported; '
- 'use utils.set_element_type(<list>, <dtype>) to continue')
+ 'use set_element_type(<list>, <dtype>) to continue')
dtype = anno.getanno(node, 'element_type')
if not isinstance(dtype, dtypes.DType):
# TODO(mdan): Allow non-TF dtypes?
@@ -74,7 +74,7 @@ class ListTransformer(transformer.Base):
if qn.qn[-1] == 'append' and (len(call_node.args) == 1):
template = """
- target = autograph_utils.dynamic_list_append(target, element)
+ target = ag__.utils.dynamic_list_append(target, element)
"""
node = templates.replace(
template,
diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards.py b/tensorflow/contrib/autograph/converters/side_effect_guards.py
index 1c1293d2c4..3bcb2d3c42 100644
--- a/tensorflow/contrib/autograph/converters/side_effect_guards.py
+++ b/tensorflow/contrib/autograph/converters/side_effect_guards.py
@@ -160,8 +160,8 @@ class SideEffectGuardTransformer(transformer.Base):
[alias_map.get(s, s).ast() for s in guarded_args], None)
template = """
- with autograph_utils.control_dependency_on_returns(call):
- aliased_guarded_args = autograph_utils.alias_tensors(guarded_args)
+ with ag__.utils.control_dependency_on_returns(call):
+ aliased_guarded_args = ag__.utils.alias_tensors(guarded_args)
"""
control_deps_guard = templates.replace(
template,
@@ -172,7 +172,7 @@ class SideEffectGuardTransformer(transformer.Base):
alias_map = {}
template = """
- with autograph_utils.control_dependency_on_returns(call):
+ with ag__.utils.control_dependency_on_returns(call):
pass
"""
control_deps_guard = templates.replace(template, call=node.value)[-1]
diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py
index bcf31b8961..c9fb972953 100644
--- a/tensorflow/contrib/autograph/impl/conversion.py
+++ b/tensorflow/contrib/autograph/impl/conversion.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import imp
+
import gast
from tensorflow.contrib.autograph import operators
@@ -221,12 +223,17 @@ def _add_reserved_symbol(namespace, name, entity):
def _add_self_references(namespace, api_module):
- # Manually add the utils namespace which may be used from generated code.
- _add_reserved_symbol(namespace, 'autograph_utils', utils)
- _add_reserved_symbol(namespace, '__ops', operators)
- # We also make reference to the api module for dynamic conversion, but
- # to avoid circular references we don't import it here.
- _add_reserved_symbol(namespace, 'autograph_api', api_module)
+ # Craft a module that exposes parts of the external API as well as certain
+ # internal modules.
+ ag_internal = imp.new_module('autograph')
+ ag_internal.converted_call = api_module.converted_call
+ ag_internal.utils = utils
+ # TODO(mdan): Add safeguards against name clashes.
+ # We don't want to create a submodule because we want the operators to be
+ # accessible as ag__.<operator>.
+ ag_internal.__dict__.update(operators.__dict__)
+
+ _add_reserved_symbol(namespace, 'ag__', ag_internal)
def function_to_graph(f, conversion_map, arg_values, arg_types,
@@ -312,6 +319,8 @@ def node_to_graph(node, ctx, nocompile_decorators):
node = ifexp.transform(node, ctx)
node, deps = decorators.transform(node, nocompile_decorators)
node = break_statements.transform(node, ctx)
+ node = _static_analysis_pass(node, ctx)
+
node = asserts.transform(node, ctx)
# Note: sequencing continue canonicalization before for loop one avoids
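The reworked _add_self_references crafts a single ag__ module exposing converted_call, utils and the operators at top level. A minimal sketch of the wiring pattern (converted_call and the operators dict here are simplified stand-ins):

```python
import imp

def make_ag_internal():
    def converted_call(f, *args):
        return f(*args)
    operators_dict = {'for_loop': lambda *a: None,
                      'while_loop': lambda *a: None}

    ag_internal = imp.new_module('autograph')
    ag_internal.converted_call = converted_call
    # Updating __dict__ (rather than nesting a submodule) is what makes the
    # operators accessible as ag__.<operator>.
    ag_internal.__dict__.update(operators_dict)
    return ag_internal

ag = make_ag_internal()
assert ag.converted_call(len, [1, 2, 3]) == 3
assert callable(ag.for_loop)
```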
diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py
index 962009c71f..f0b597c12f 100644
--- a/tensorflow/contrib/autograph/impl/conversion_test.py
+++ b/tensorflow/contrib/autograph/impl/conversion_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import gast
from tensorflow.contrib.autograph import utils
+from tensorflow.contrib.autograph.impl import api
from tensorflow.contrib.autograph.impl import conversion
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
@@ -28,6 +29,9 @@ from tensorflow.python.platform import test
class ConversionTest(test.TestCase):
+ def _simple_conversion_map(self):
+ return conversion.ConversionMap(True, (), (), api)
+
def test_is_whitelisted_for_graph(self):
def test_fn():
@@ -39,7 +43,7 @@ class ConversionTest(test.TestCase):
def test_entity_to_graph_unsupported_types(self):
with self.assertRaises(ValueError):
- conversion_map = conversion.ConversionMap(True, (), (), None)
+ conversion_map = self._simple_conversion_map()
conversion.entity_to_graph('dummy', conversion_map, None, None)
def test_entity_to_graph_callable(self):
@@ -47,7 +51,7 @@ class ConversionTest(test.TestCase):
def f(a):
return a + b
- conversion_map = conversion.ConversionMap(True, (), (), None)
+ conversion_map = self._simple_conversion_map()
ast, name, ns = conversion.entity_to_graph(f, conversion_map, None, None)
self.assertTrue(isinstance(ast, gast.FunctionDef), ast)
self.assertEqual('tf__f', name)
@@ -61,7 +65,7 @@ class ConversionTest(test.TestCase):
def f(a):
return g(a)
- conversion_map = conversion.ConversionMap(True, (), (), None)
+ conversion_map = self._simple_conversion_map()
conversion.entity_to_graph(f, conversion_map, None, None)
self.assertTrue(f in conversion_map.dependency_cache)
diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD
index 4c62468575..efb8d441dd 100644
--- a/tensorflow/contrib/autograph/operators/BUILD
+++ b/tensorflow/contrib/autograph/operators/BUILD
@@ -21,11 +21,24 @@ py_library(
srcs = [
"__init__.py",
"control_flow.py",
+ "data_structures.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
"//tensorflow/contrib/autograph/utils",
+ "//tensorflow/python:tensor_array_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "data_structures_test",
+ srcs = ["data_structures_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":operators",
+ "//tensorflow/python:client_testlib",
],
)
diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py
index 81ae64f110..d9d8b0d593 100644
--- a/tensorflow/contrib/autograph/operators/control_flow.py
+++ b/tensorflow/contrib/autograph/operators/control_flow.py
@@ -25,6 +25,9 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
+# TODO(mdan): Rename _loop to _stmt to follow Python nomenclature.
+# TODO(mdan): Rename arguments to match the AST names.
+
def for_loop(iterated, extra_cond, loop_body, init_state):
"""Functional form of a for statement.
@@ -182,3 +185,32 @@ def _py_while_loop(loop_cond, loop_body, init_state, opts):
while loop_cond(*state):
state = loop_body(*state)
return state
+
+
+def if_stmt(cond, body, orelse):
+ """Functional form of an if statement.
+
+ Args:
+ cond: Boolean.
+ body: Callable with no arguments, and outputs of the positive (if) branch
+ as return type.
+ orelse: Callable with no arguments, and outputs of the negative (else)
+ branch as return type.
+
+ Returns:
+ Tuple containing the statement outputs.
+ """
+ if tensor_util.is_tensor(cond):
+ return _tf_if_stmt(cond, body, orelse)
+ else:
+ return _py_if_stmt(cond, body, orelse)
+
+
+def _tf_if_stmt(cond, body, orelse):
+ """Overload of if_stmt that stages a TF cond."""
+ return control_flow_ops.cond(cond, body, orelse)
+
+
+def _py_if_stmt(cond, body, orelse):
+ """Overload of if_stmt that executes a Python if statement."""
+ return body() if cond else orelse()
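The new if_stmt follows the same dispatch pattern as for_loop and while_loop above: pick the TF overload when the condition is a staged tensor, plain Python otherwise. A self-contained sketch of that pattern (is_tensor is a crude stand-in for tensor_util.is_tensor, and the TF branch is elided):

```python
def is_tensor(x):
    # Rough heuristic standing in for tensor_util.is_tensor.
    return hasattr(x, 'dtype') and hasattr(x, 'graph')

def if_stmt(cond, body, orelse):
    if is_tensor(cond):
        # Would stage control_flow_ops.cond(cond, body, orelse) here.
        raise NotImplementedError('TF branch elided in this sketch')
    return body() if cond else orelse()

assert if_stmt(True, lambda: 1, lambda: -1) == 1
assert if_stmt(False, lambda: 1, lambda: -1) == -1
```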
diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/contrib/autograph/operators/control_flow_test.py
index 9112b1627f..a0cd0bfa82 100644
--- a/tensorflow/contrib/autograph/operators/control_flow_test.py
+++ b/tensorflow/contrib/autograph/operators/control_flow_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph import operators
+from tensorflow.contrib.autograph.operators import control_flow
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class ForLoopTest(test.TestCase):
def test_tensor(self):
- s = operators.for_loop(
+ s = control_flow.for_loop(
constant_op.constant([1, 2, 3, 4]),
extra_cond=lambda s: True,
loop_body=lambda i, s: (s + i,),
@@ -38,7 +38,7 @@ class ForLoopTest(test.TestCase):
self.assertEqual((10,), sess.run(s))
def test_python(self):
- s = operators.for_loop(
+ s = control_flow.for_loop(
range(5),
extra_cond=lambda s: True,
loop_body=lambda i, s: (s + i,),
@@ -47,7 +47,7 @@ class ForLoopTest(test.TestCase):
def test_dataset(self):
to_int32 = lambda i: math_ops.cast(i, dtypes.int32)
- s = operators.for_loop(
+ s = control_flow.for_loop(
dataset_ops.Dataset.range(5).map(to_int32),
extra_cond=lambda s: True,
loop_body=lambda i, s: (s + i,),
@@ -60,7 +60,7 @@ class WhileLoopTest(test.TestCase):
def test_tensor(self):
n = constant_op.constant(5)
- results = operators.while_loop(
+ results = control_flow.while_loop(
loop_cond=lambda i, s: i < n,
loop_body=lambda i, s: (i + 1, s + i,),
init_state=(0, 0),
@@ -70,7 +70,7 @@ class WhileLoopTest(test.TestCase):
def test_python(self):
n = 5
- results = operators.while_loop(
+ results = control_flow.while_loop(
loop_cond=lambda i, s: i < n,
loop_body=lambda i, s: (i + 1, s + i),
init_state=(0, 0),
@@ -78,5 +78,22 @@ class WhileLoopTest(test.TestCase):
self.assertEqual((5, 10), results)
+class IfStmtTest(test.TestCase):
+
+ def test_tensor(self):
+ def test_if_stmt(cond):
+ return control_flow.if_stmt(
+ cond=cond,
+ body=lambda: 1,
+ orelse=lambda: -1)
+ with self.test_session() as sess:
+ self.assertEqual(1, sess.run(test_if_stmt(constant_op.constant(True))))
+ self.assertEqual(-1, sess.run(test_if_stmt(constant_op.constant(False))))
+
+ def test_python(self):
+ self.assertEqual(1, control_flow.if_stmt(True, lambda: 1, lambda: -1))
+ self.assertEqual(-1, control_flow.if_stmt(False, lambda: 1, lambda: -1))
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/operators/data_structures.py b/tensorflow/contrib/autograph/operators/data_structures.py
new file mode 100644
index 0000000000..c862306baa
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/data_structures.py
@@ -0,0 +1,56 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operators specific to data structures: list append, subscripts, etc."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import tensor_array_ops
+
+# TODO(mdan): Add support for TensorList once functional.
+# TODO(mdan): Add primitives for empty list, list with elements.
+
+
+def append(target, element):
+ """The list append function.
+
+ Note: it is unspecified whether target will be mutated or not. If target is
+ a TensorFlow entity, it will typically not be mutated. If target is a plain
+ list, it will be. In general, if the target is mutated then the return value
+ should point to the original entity.
+
+ Args:
+ target: An entity that supports append semantics.
+ element: The element to append.
+
+ Returns:
+ Same as target, after the append was performed.
+ """
+ if isinstance(target, tensor_array_ops.TensorArray):
+ return _tf_tensorarray_append(target, element)
+ else:
+ return _py_append(target, element)
+
+
+def _tf_tensorarray_append(target, element):
+ """Overload of append that stages a TensorArray write at the last position."""
+ return target.write(target.size(), element)
+
+
+def _py_append(target, element):
+ """Overload of append that executes a Python list append."""
+ target.append(element)
+ return target
diff --git a/tensorflow/contrib/autograph/operators/data_structures_test.py b/tensorflow/contrib/autograph/operators/data_structures_test.py
new file mode 100644
index 0000000000..577d28c34d
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/data_structures_test.py
@@ -0,0 +1,44 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for data_structures module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.operators import data_structures
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
+class AppendTest(test.TestCase):
+
+ def test_tf_tensorarray(self):
+ l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
+ l1 = data_structures.append(l, 1)
+ l2 = data_structures.append(l1, 2)
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(l1.stack()), [1])
+ self.assertAllEqual(sess.run(l2.stack()), [1, 2])
+
+ def test_python(self):
+ l = []
+ self.assertAllEqual(data_structures.append(l, 1), [1])
+ self.assertAllEqual(data_structures.append(l, 2), [1, 2])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
index b6817e9d75..2c14c2c8c2 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
@@ -133,18 +133,18 @@ class Scope(object):
def mark_param(self, name):
self.params.add(name)
- def mark_creation(self, name):
+ def mark_creation(self, name, writes_create_symbol=False):
if name.is_composite():
parent = name.parent
if self.has(parent):
- # This is considered mutation of the parent, not creation.
- # TODO(mdan): Is that really so?
- return
+ if not writes_create_symbol:
+ return
else:
raise ValueError('Unknown symbol "%s".' % parent)
self.created.add(name)
def mark_write(self, name):
+ """Marks the given symbol as modified in the current scope."""
self.modified.add(name)
if self.isolated:
self.mark_creation(name)
@@ -170,15 +170,37 @@ class ActivityAnalyzer(transformer.Base):
self.scope = Scope(parent_scope)
self._in_return_statement = False
- def _track_symbol(self, node):
- # This can happen when we have an attribute (or subscript) on a function
- # call. Example: a().b
+ @property
+ def _in_constructor(self):
+ innermost = self.enclosing_entities[-1]
+ if len(self.enclosing_entities) > 1:
+ parent = self.enclosing_entities[-2]
+ return isinstance(parent, gast.ClassDef) and innermost.name == '__init__'
+ return False
+
+ def _node_sets_self_attribute(self, node):
+ if anno.hasanno(node, anno.Basic.QN):
+ qn = anno.getanno(node, anno.Basic.QN)
+ # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'.
+ if qn.has_attr and qn.parent.qn == ('self',):
+ return True
+
+ def _track_symbol(self,
+ node,
+ composite_writes_alter_parent=False,
+ writes_create_symbol=False):
+ # A QN may be missing when we have an attribute (or subscript) on a function
+ # call. Example: a().b
if not anno.hasanno(node, anno.Basic.QN):
return
qn = anno.getanno(node, anno.Basic.QN)
if isinstance(node.ctx, gast.Store):
self.scope.mark_write(qn)
+ if qn.is_composite() and composite_writes_alter_parent:
+ self.scope.mark_write(qn.parent)
+ if writes_create_symbol:
+ self.scope.mark_creation(qn, writes_create_symbol=True)
elif isinstance(node.ctx, gast.Load):
self.scope.mark_read(qn)
elif isinstance(node.ctx, gast.Param):
@@ -207,7 +229,18 @@ class ActivityAnalyzer(transformer.Base):
def visit_Attribute(self, node):
self.generic_visit(node)
- self._track_symbol(node)
+ if self._in_constructor and self._node_sets_self_attribute(node):
+ self._track_symbol(
+ node, composite_writes_alter_parent=True, writes_create_symbol=True)
+ else:
+ self._track_symbol(node)
+ return node
+
+ def visit_Subscript(self, node):
+ self.generic_visit(node)
+ # Subscript writes (e.g. a[b] = "value") are considered to modify
+ # both the element itself (a[b]) and its parent (a).
+ self._track_symbol(node, composite_writes_alter_parent=True)
return node
def visit_Print(self, node):
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
index 65e1a8f0ea..ef79a295bf 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
@@ -144,10 +144,21 @@ class ActivityAnalyzerTest(test.TestCase):
anno.getanno(node.body[0].body[2].value,
NodeAnno.IS_LOCAL)) # b in return b
+ def assertSymbolSetsAre(self, expected, actual, name):
+ expected = set(expected)
+ actual = set(str(s) for s in actual)
+ self.assertSetEqual(
+ expected, actual, 'for symbol set: %s\n'
+ ' Expected: %s\n'
+ ' Got: %s\n'
+ ' Missing: %s\n'
+ ' Extra: %s\n' % (name.upper(), expected, actual,
+ expected - actual, actual - expected))
+
def assertScopeIsRmc(self, scope, used, modified, created):
- self.assertItemsEqual(used, tuple(str(s) for s in scope.used))
- self.assertItemsEqual(modified, tuple(str(s) for s in scope.modified))
- self.assertItemsEqual(created, tuple(str(s) for s in scope.created))
+ self.assertSymbolSetsAre(used, scope.used, 'read')
+ self.assertSymbolSetsAre(modified, scope.modified, 'modified')
+ self.assertSymbolSetsAre(created, scope.created, 'created')
def test_print_statement(self):
@@ -172,7 +183,7 @@ class ActivityAnalyzerTest(test.TestCase):
# arguments.
self.assertScopeIsRmc(print_args_scope, ('a', 'b'), (), ())
- def test_call(self):
+ def test_call_args(self):
def test_fn(a):
b = 0
@@ -187,6 +198,57 @@ class ActivityAnalyzerTest(test.TestCase):
self.assertScopeIsRmc(
anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), (), ())
+ def test_call_args_attributes(self):
+
+ def foo(*_):
+ pass
+
+ def test_fn(a):
+ a.c = 0
+ foo(a.b, a.c)
+ return a.d
+
+ node = self._parse_and_analyze(test_fn)
+ call_node = node.body[0].body[1].value
+ self.assertScopeIsRmc(
+ anno.getanno(call_node, NodeAnno.ARGS_SCOPE),
+ ('a', 'a.b', 'a.c'),
+ (),
+ (),
+ )
+ self.assertScopeIsRmc(
+ anno.getanno(call_node, NodeAnno.ARGS_SCOPE).parent,
+ ('a', 'a.b', 'a.c', 'a.d', 'foo'),
+ ('a.c',),
+ ('a',),
+ )
+
+ def test_call_args_subscripts(self):
+
+ def foo(*_):
+ pass
+
+ def test_fn(a):
+ b = 1
+ c = 2
+ foo(a[0], a[b])
+ return a[c]
+
+ node = self._parse_and_analyze(test_fn)
+ call_node = node.body[0].body[2].value
+ self.assertScopeIsRmc(
+ anno.getanno(call_node, NodeAnno.ARGS_SCOPE),
+ ('a', 'a[0]', 'a[b]', 'b'),
+ (),
+ (),
+ )
+ self.assertScopeIsRmc(
+ anno.getanno(call_node, NodeAnno.ARGS_SCOPE).parent,
+ ('a', 'a[0]', 'a[b]', 'a[c]', 'b', 'c', 'foo'),
+ ('b', 'c'),
+ ('a', 'b', 'c'),
+ )
+
def test_while(self):
def test_fn(a):
@@ -253,7 +315,72 @@ class ActivityAnalyzerTest(test.TestCase):
anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('x', 'z', 'u'),
('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
- def test_nested_if_else_creation(self):
+ def test_if_attributes(self):
+
+ def test_fn(a):
+ if a > 0:
+ a.b = -a.c
+ d = 2 * a
+ else:
+ a.b = a.c
+ d = 1
+ return d
+
+ node = self._parse_and_analyze(test_fn)
+ if_node = node.body[0].body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE),
+ ('a', 'a.c'),
+ ('a.b', 'd'),
+ ('d',),
+ )
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
+ ('a', 'a.c'),
+ ('a.b', 'd'),
+ ('d',),
+ )
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent,
+ ('a', 'a.c', 'd'),
+ ('a.b', 'd'),
+ ('a', 'd'),
+ )
+
+ def test_if_subscripts(self):
+
+ def test_fn(a, b, c, e):
+ if a > 0:
+ a[b] = -a[c]
+ d = 2 * a
+ else:
+ a[0] = e
+ d = 1
+ return d
+
+ node = self._parse_and_analyze(test_fn)
+ if_node = node.body[0].body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE),
+ ('a', 'b', 'c', 'a[c]'),
+ ('a', 'a[b]', 'd'),
+ ('d',),
+ )
+ # TODO(mdan): Should subscript writes (a[0] = 1) be considered to read "a"?
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
+ ('a', 'e'),
+ ('a', 'a[0]', 'd'),
+ ('d',),
+ )
+ self.assertScopeIsRmc(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent,
+ ('a', 'b', 'c', 'd', 'e', 'a[c]'),
+ ('a', 'd', 'a[b]', 'a[0]'),
+ ('a', 'b', 'c', 'd', 'e'),
+ )
+
+ def test_nested_if(self):
def test_fn(b):
if b > 0:
@@ -272,7 +399,7 @@ class ActivityAnalyzerTest(test.TestCase):
anno.getanno(inner_if_node, NodeAnno.ORELSE_SCOPE), ('b',), ('a',),
('a',))
- def test_function_def(self):
+ def test_nested_function(self):
def test_fn(a):
@@ -287,44 +414,48 @@ class ActivityAnalyzerTest(test.TestCase):
return b, c
node = self._parse_and_analyze(test_fn)
- fndef_node = node.body[0].body[0]
+ fn_def_node = node.body[0].body[0]
self.assertScopeIsRmc(
- anno.getanno(fndef_node,
+ anno.getanno(fn_def_node,
NodeAnno.BODY_SCOPE).parent, ('b', 'i', 'f', 'c', 'a'),
('f', 'b', 'c', 'i'), ('f', 'a', 'b', 'c', 'i'))
self.assertScopeIsRmc(
- anno.getanno(fndef_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',), (
+ anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',), (
'x',
'y',
))
- def test_call_with_composite_names(self):
+ def test_constructor_attributes(self):
- def foo(*_):
- pass
+ class TestClass(object):
+
+ def __init__(self, a):
+ self.b = a
+ self.b.c = 1
+
+ node = self._parse_and_analyze(TestClass)
+ init_node = node.body[0].body[0]
+ self.assertScopeIsRmc(
+ anno.getanno(init_node, NodeAnno.BODY_SCOPE),
+ ('self', 'a', 'self.b'),
+ ('self', 'self.b', 'self.b.c'),
+ ('self', 'a', 'self.b'),
+ )
+
+ def test_aug_assign_subscripts(self):
def test_fn(a):
- foo(a.b, a.c)
- if a > 0:
- a.b = 2
- else:
- d = 2
- d.e = a.c
- f = d.e + 1
- a.c = f
+ a[0] += 1
node = self._parse_and_analyze(test_fn)
- call_node = node.body[0].body[0].value
+ fn_node = node.body[0]
self.assertScopeIsRmc(
- anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'a.b', 'a.c'), (),
- ())
- if_node = node.body[0].body[1]
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a',), ('a.b',), ())
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
- ('a', 'a.c', 'd', 'd.e', 'f'), ('a.c', 'd', 'd.e', 'f'), ('d', 'f'))
+ anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
+ ('a',),
+ ('a', 'a[0]'),
+ ('a',),
+ )
if __name__ == '__main__':
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
index 203aa3c3d1..2f553e1e23 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
@@ -48,6 +48,9 @@ from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.python.util import tf_inspect
+# TODO(mdan): Remove the duplication between this and activity.py.
+# In particular, the symbol definitions we track here could as well be tracked
+# there because they follow the same rules for visibility.
class Scope(object):
"""Tracks symbol value references.
@@ -99,20 +102,16 @@ class TypeInfoResolver(transformer.Base):
def __init__(self, context):
super(TypeInfoResolver, self).__init__(context)
self.scope = Scope(None)
- self.function_level = 0
def visit_FunctionDef(self, node):
self.scope = Scope(self.scope)
- self.function_level += 1
- self.generic_visit(node)
- self.function_level -= 1
+ node = self.generic_visit(node)
self.scope = self.scope.parent
return node
def _visit_block(self, block):
self.scope = Scope(self.scope)
- for i, n in enumerate(block):
- block[i] = self.generic_visit(n)
+ block = self.visit_block(block)
self.scope = self.scope.parent
return block
@@ -137,7 +136,7 @@ class TypeInfoResolver(transformer.Base):
def _process_function_arg(self, arg_name):
str_name = str(arg_name)
- if self.function_level == 1 and str_name in self.context.arg_types:
+ if len(self.enclosing_entities) == 1 and str_name in self.context.arg_types:
# Forge a node to hold the type information, so that method calls on
# it can resolve the type.
type_holder = arg_name.ast()
@@ -168,16 +167,8 @@ class TypeInfoResolver(transformer.Base):
anno.getanno(definition, 'element_type'))
return node
- def _process_tuple_assignment(self, source, t):
- for i, e in enumerate(t.elts):
- if isinstance(e, gast.Tuple):
- self._process_tuple_assignment(source, e)
- else:
- self.scope.setval(
- anno.getanno(e, anno.Basic.QN),
- gast.Subscript(source, gast.Index(i), ctx=gast.Store()))
-
def _process_variable_assignment(self, source, targets):
+ # Special case: constructors.
if isinstance(source, gast.Call):
func = source.func
if anno.hasanno(func, 'live_val'):
@@ -190,15 +181,26 @@ class TypeInfoResolver(transformer.Base):
# We can have a whitelist of no-side-effects constructors.
# We can also step inside the constructor and further analyze.
- for t in targets:
- if isinstance(t, gast.Tuple):
- # need to recurse on the case of assigning nested tuples,
- # ex. a, (b, c) = f()
- self._process_tuple_assignment(source, t)
- elif isinstance(t, (gast.Name, gast.Attribute)):
- self.scope.setval(anno.getanno(t, anno.Basic.QN), source)
+ # Multiple targets mean multiple assignment.
+ for target in targets:
+ # Tuple target means unpacking.
+ if isinstance(target, gast.Tuple):
+ for i, target_item in enumerate(target.elts):
+ # Two cases here:
+ # 1. Static unpacking, e.g. a, b = c, d
+ # 2. Dynamic unpacking, e.g. a, b = c
+ # The former case is optimized away.
+ if isinstance(source, (gast.Tuple, gast.List)):
+ source_item = source.elts[i]
+ else:
+ source_item = gast.Subscript(source, gast.Index(i), ctx=None)
+ self._process_variable_assignment(source_item, (target_item,))
+ elif isinstance(target, (gast.Name, gast.Attribute)):
+ target_symbol = anno.getanno(target, anno.Basic.QN)
+ self.scope.setval(target_symbol, source)
else:
- raise ValueError('Dont know how to handle assignment to %s' % t)
+ raise ValueError(
+ 'assignment target has unknown type: %s' % target)
def visit_With(self, node):
for wi in node.items:
@@ -218,19 +220,26 @@ class TypeInfoResolver(transformer.Base):
# type that it specified.
if (anno.getanno(node.func, 'live_val') is
self.context.type_annotation_func):
- # Expecting the actual type to be the second argument.
+
if len(node.args) != 2:
raise ValueError('"%s" must have exactly two parameters'
% self.context.type_annotation_func)
- if not anno.hasanno(node.args[0], anno.Basic.QN):
+ target_arg, type_arg = node.args
+ if not anno.hasanno(target_arg, anno.Basic.QN):
raise ValueError('the first argument of "%s" must be a symbol'
% self.context.type_annotation_func)
- if not anno.hasanno(node.args[1], 'live_val'):
- raise ValueError(
- 'the second argument of "%s" must be statically resolvable' %
- self.context.type_annotation_func)
- target_symbol = anno.getanno(node.args[0], anno.Basic.QN)
- element_type = anno.getanno(node.args[1], 'live_val')
+ if isinstance(type_arg, gast.Str):
+ element_type = type_arg.s
+ elif isinstance(type_arg, gast.Num):
+ element_type = type_arg.n
+ else:
+ if not anno.hasanno(type_arg, 'live_val'):
+ raise ValueError(
+ 'the second argument of "%s" must be statically resolvable' %
+ self.context.type_annotation_func)
+ element_type = anno.getanno(type_arg, 'live_val')
+
+ target_symbol = anno.getanno(target_arg, anno.Basic.QN)
# Find the definition of this symbol and annotate it with the given
# data type. That in turn will cause future uses of the symbol
# to receive the same type annotation.
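With the change above, the second argument of the type annotation function may now be a string or number literal, in addition to a statically resolvable live value. A hedged usage sketch, assuming the annotation function is exposed as utils.set_element_type, as in the tests below:

def user_fn():
  a = []
  utils.set_element_type(a, 'tf.int32')  # string literal, taken verbatim
  b = []
  utils.set_element_type(b, 2)  # number literal, as in the tests below
  return a, b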
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
index c0de4a6043..46b7701624 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
@@ -196,23 +196,46 @@ class TypeInfoResolverTest(test.TestCase):
f_ref = node.body[0].body[1].value
self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo)
- def test_nested_assignment(self):
+ def test_nested_unpacking(self):
- def test_fn(foo):
- a, (b, c) = foo
+ class Foo(object):
+ pass
+
+ class Bar(object):
+ pass
+
+ def test_fn():
+ a, (b, c) = (Foo(), (Bar(), Foo()))
return a, b, c
- node = self._parse_and_analyze(test_fn, {'foo': (1, 2, 3)})
- lhs = node.body[0].body[1].value.elts
- a = lhs[0]
- b = lhs[1]
- c = lhs[2]
- # TODO(mdan): change these once we have the live values propagating
- # correctly
+ node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar})
+ a, b, c = node.body[0].body[1].value.elts
+    self.assertEqual(Foo, anno.getanno(a, 'type'))
+    self.assertEqual(Bar, anno.getanno(b, 'type'))
+    self.assertEqual(Foo, anno.getanno(c, 'type'))
self.assertFalse(anno.hasanno(a, 'live_val'))
self.assertFalse(anno.hasanno(b, 'live_val'))
self.assertFalse(anno.hasanno(c, 'live_val'))
+ def test_inner_scope(self):
+
+ def test_fn():
+ a = []
+ utils.set_element_type(a, 1)
+ for _ in a:
+ b = []
+ utils.set_element_type(b, 2)
+ return a, b
+
+ node = self._parse_and_analyze(test_fn, {'utils': utils})
+ a, b = node.body[0].body[2].body[2].value.elts
+    self.assertEqual(1, anno.getanno(a, 'element_type'))
+    self.assertEqual(2, anno.getanno(b, 'element_type'))
+ self.assertFalse(anno.hasanno(a, 'type'))
+ self.assertFalse(anno.hasanno(b, 'type'))
+ self.assertFalse(anno.hasanno(a, 'live_val'))
+ self.assertFalse(anno.hasanno(b, 'live_val'))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py
index b38d52c5b2..e102ab7630 100644
--- a/tensorflow/contrib/autograph/pyct/transformer.py
+++ b/tensorflow/contrib/autograph/pyct/transformer.py
@@ -40,7 +40,13 @@ def try_ast_to_source(node):
class Base(gast.NodeTransformer):
- """Base class for specialized transformers."""
+ """Base class for specialized transformers.
+
+ Scope-local state tracking: to keep state across nodes, at the level of
+ (possibly nested) scopes, use enter/exit_local_scope and set/get_local.
+ You must call enter/exit_local_scope manually, but the transformer detects
+ when they are not properly paired.
+ """
def __init__(self, context):
"""Initialize the transformer. Subclasses should call this.
@@ -53,20 +59,51 @@ class Base(gast.NodeTransformer):
self.context = context
self._enclosing_entities = []
+ # A stack that allows keeping mutable, scope-local state where scopes may be
+ # nested. For example, it can be used to track the usage of break
+ # statements in each loop, where loops may be nested.
+ self._local_scope_state = []
+ self.enter_local_scope()
+
@property
def enclosing_entities(self):
return tuple(self._enclosing_entities)
+ def enter_local_scope(self):
+ self._local_scope_state.append({})
+
+ def exit_local_scope(self):
+ return self._local_scope_state.pop()
+
+ def set_local(self, name, value):
+ self._local_scope_state[-1][name] = value
+
+ def get_local(self, name, default=None):
+ return self._local_scope_state[-1].get(name, default)
+
def debug_print(self, node):
"""Helper method useful for debugging."""
if __debug__:
print(pretty_printer.fmt(node))
return node
+ def visit_block(self, nodes):
+ """Helper equivalent to generic_visit, but for node lists."""
+ results = []
+ for node in nodes:
+ replacement = self.visit(node)
+ if replacement:
+ if isinstance(replacement, (list, tuple)):
+ results.extend(replacement)
+ else:
+ results.append(replacement)
+ return results
+
def visit(self, node):
source_code = self.context.source_code
source_file = self.context.source_file
did_enter_function = False
+ local_scope_state_size = len(self._local_scope_state)
try:
if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)):
@@ -97,3 +134,10 @@ class Base(gast.NodeTransformer):
finally:
if did_enter_function:
self._enclosing_entities.pop()
+
+ if local_scope_state_size != len(self._local_scope_state):
+        raise AssertionError(
+            'Inconsistent local scope stack. Before entering node %s, the'
+            ' stack had length %d, after exit it has length %d. This'
+            ' indicates enter_local_scope and exit_local_scope are not'
+            ' well paired.' % (node, local_scope_state_size,
+                               len(self._local_scope_state)))
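Condensed from the tests that follow, a hedged sketch of the intended pairing: a visitor opens a local scope around a compound statement, child visitors accumulate state with set_local, and the parent reads the result back before popping the scope. The subclass and annotation key are illustrative only:

class StringCollector(transformer.Base):
  """Illustrative subclass: collects string constants per enclosing loop."""

  def visit_Str(self, node):
    self.set_local('strings', self.get_local('strings', '') + node.s)
    return self.generic_visit(node)

  def visit_For(self, node):
    self.enter_local_scope()  # state set below is local to this loop
    node = self.generic_visit(node)
    strings = self.get_local('strings')
    self.exit_local_scope()  # must pair with enter_local_scope
    anno.setanno(node, 'loop_strings', strings)
    return node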
diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py
index 57f1c31ef6..f96b0dc377 100644
--- a/tensorflow/contrib/autograph/pyct/transformer_test.py
+++ b/tensorflow/contrib/autograph/pyct/transformer_test.py
@@ -27,6 +27,17 @@ from tensorflow.python.platform import test
class TransformerTest(test.TestCase):
+  def _context_for_node_testing(self):
+ return context.EntityContext(
+ namer=None,
+ source_code=None,
+ source_file=None,
+ namespace=None,
+ arg_values=None,
+ arg_types=None,
+ owner_type=None,
+ recursive=False)
+
def test_entity_scope_tracking(self):
class TestTransformer(transformer.Base):
@@ -42,16 +53,7 @@ class TransformerTest(test.TestCase):
anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
return self.generic_visit(node)
- tr = TestTransformer(
- context.EntityContext(
- namer=None,
- source_code=None,
- source_file=None,
- namespace=None,
- arg_values=None,
- arg_types=None,
- owner_type=None,
- recursive=False))
+    tr = TestTransformer(self._context_for_node_testing())
def test_function():
a = 0
@@ -92,6 +94,86 @@ class TransformerTest(test.TestCase):
inner_function, lambda_node),
anno.getanno(lambda_expr, 'enclosing_entities'))
+ def test_statement_info_stack(self):
+
+ class TestTransformer(transformer.Base):
+
+ # Extract all string constants from the block.
+ def visit_Str(self, node):
+ self.set_local('string', self.get_local('string', default='') + node.s)
+ return self.generic_visit(node)
+
+ def _annotate_result(self, node):
+ self.enter_local_scope()
+ node = self.generic_visit(node)
+ anno.setanno(node, 'test', self.get_local('string'))
+ self.exit_local_scope()
+ return node
+
+ def visit_While(self, node):
+ return self._annotate_result(node)
+
+ def visit_For(self, node):
+ return self._annotate_result(node)
+
+    tr = TestTransformer(self._context_for_node_testing())
+
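+    # Note: test_function below is only parsed and transformed, never
+    # executed, so parse-only statements such as `raise '1'` are fine here.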
+ def test_function(a):
+ """Docstring."""
+ assert a == 'This should not be counted'
+ for i in range(3):
+ _ = 'a'
+ if i > 2:
+ return 'b'
+ else:
+ _ = 'c'
+ while True:
+ raise '1'
+ return 'nor this'
+
+ node, _ = parser.parse_entity(test_function)
+ node = tr.visit(node)
+
+ for_node = node.body[0].body[2]
+ while_node = for_node.body[1].orelse[1]
+
+ self.assertFalse(anno.hasanno(for_node, 'string'))
+ self.assertEqual('abc', anno.getanno(for_node, 'test'))
+ self.assertFalse(anno.hasanno(while_node, 'string'))
+ self.assertEqual('1', anno.getanno(while_node, 'test'))
+
+ def test_statement_info_stack_checks_integrity(self):
+
+ class TestTransformer(transformer.Base):
+
+ def visit_If(self, node):
+ self.enter_local_scope()
+ return self.generic_visit(node)
+
+ def visit_For(self, node):
+ node = self.generic_visit(node)
+ self.exit_local_scope()
+ return node
+
+    tr = TestTransformer(self._context_for_node_testing())
+
+ def no_exit(a):
+ if a > 0:
+ print(a)
+ return None
+
+ node, _ = parser.parse_entity(no_exit)
+ with self.assertRaises(AssertionError):
+ tr.visit(node)
+
+ def no_entry(a):
+ for _ in a:
+ print(a)
+
+ node, _ = parser.parse_entity(no_entry)
+ with self.assertRaises(AssertionError):
+ tr.visit(node)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py
index 0a0e72d70e..211e8eaee9 100644
--- a/tensorflow/contrib/autograph/utils/builtins.py
+++ b/tensorflow/contrib/autograph/utils/builtins.py
@@ -28,24 +28,17 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.util import tf_inspect
def dynamic_builtin(f, *args, **kwargs):
"""Converts a builtin function call inline."""
- # Some built-ins may be objects.
- if not tf_inspect.isbuiltin(f) and f not in (range,):
- return f(*args, **kwargs)
-
if f is len:
return dynamic_len(*args, **kwargs)
if six.PY2 and f is xrange:
return dynamic_range(*args, **kwargs)
if f is range:
return dynamic_range(*args, **kwargs)
-
- raise NotImplementedError(
- 'The "%s" builtin is not yet supported.' % f.__name__)
+  raise ValueError('builtin "%s" is not supported' % f)
def dynamic_len(list_or_tensor):
@@ -98,9 +91,15 @@ def dynamic_print(*values):
if all(map(is_tf_print_compatible, values)):
return logging_ops.Print(1, values)
- def flushed_print(*vals):
+ def print_wrapper(*vals):
+ if six.PY3:
+ # TensorFlow doesn't seem to generate Unicode when passing strings to
+ # py_func. This causes the print to add a "b'" wrapper to the output,
+ # which is probably never what you want.
+ vals = tuple(v.decode() if isinstance(v, bytes) else v for v in vals)
print(*vals)
+ # The flush helps avoid garbled output in IPython.
sys.stdout.flush()
return py_func.wrap_py_func(
- flushed_print, None, values, use_dummy_return=True)
+ print_wrapper, None, values, use_dummy_return=True)
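A hedged illustration of the dispatch contract above: known builtins are rewritten to their dynamic equivalents, and anything else, including user functions that merely shadow a builtin name, now raises instead of being silently called:

builtins.dynamic_builtin(len, [1, 2, 3])  # -> 3 (a shape op for tensors)
builtins.dynamic_builtin(range, 3)        # -> dynamic_range(3)

def my_range(x):  # a user function, even one that were named `range`
  return x

builtins.dynamic_builtin(my_range, 1)     # raises ValueError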
diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py
index d9f7913d89..163e698407 100644
--- a/tensorflow/contrib/autograph/utils/builtins_test.py
+++ b/tensorflow/contrib/autograph/utils/builtins_test.py
@@ -76,8 +76,9 @@ class BuiltinsTest(test.TestCase):
def range(x): # pylint:disable=redefined-builtin
return x
- # Functions that just have the names of builtins are ignored.
- self.assertEqual(builtins.dynamic_builtin(range, 1), 1)
+ # Functions that just have the names of builtins are rejected.
+    with self.assertRaises(ValueError):
+      builtins.dynamic_builtin(range, 1)
if six.PY2:
self.assertListEqual(
list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2])
diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake
index 35c2a294ec..693dc7cd67 100644
--- a/tensorflow/contrib/cmake/external/grpc.cmake
+++ b/tensorflow/contrib/cmake/external/grpc.cmake
@@ -17,7 +17,7 @@ include (ExternalProject)
set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include)
set(GRPC_URL https://github.com/grpc/grpc.git)
set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc)
-set(GRPC_TAG 09386db3939cae1ac12e5f09b735adfa8958c68e)
+set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f)
if(WIN32)
if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 1c3206f1a2..954e215fcc 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -330,8 +330,10 @@ GENERATE_PYTHON_OP_LIB("ctc_ops")
GENERATE_PYTHON_OP_LIB("cudnn_rnn_ops")
GENERATE_PYTHON_OP_LIB("data_flow_ops")
GENERATE_PYTHON_OP_LIB("dataset_ops")
-GENERATE_PYTHON_OP_LIB("decode_proto_ops")
-GENERATE_PYTHON_OP_LIB("encode_proto_ops")
+GENERATE_PYTHON_OP_LIB("decode_proto_ops"
+ DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_decode_proto_op.py)
+GENERATE_PYTHON_OP_LIB("encode_proto_ops"
+ DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_encode_proto_op.py)
GENERATE_PYTHON_OP_LIB("image_ops")
GENERATE_PYTHON_OP_LIB("io_ops")
GENERATE_PYTHON_OP_LIB("linalg_ops")
@@ -345,7 +347,8 @@ GENERATE_PYTHON_OP_LIB("random_ops")
GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/remote_fused_graph/pylib/python/ops/gen_remote_fused_graph_ops.py)
GENERATE_PYTHON_OP_LIB("resource_variable_ops")
-GENERATE_PYTHON_OP_LIB("rpc_ops")
+GENERATE_PYTHON_OP_LIB("rpc_ops"
+ DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rpc/python/ops/gen_rpc_op.py)
GENERATE_PYTHON_OP_LIB("script_ops")
GENERATE_PYTHON_OP_LIB("sdca_ops")
GENERATE_PYTHON_OP_LIB("set_ops")
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index c28c3a18e4..b615824460 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -1640,31 +1640,6 @@ class CudnnRNNRelu(_CudnnRNNNoInputC):
_NUM_PARAMS_PER_LAYER = CUDNN_RNN_RELU_PARAMS_PER_LAYER
-@ops.RegisterGradient("CudnnRNN")
-def _cudnn_rnn_backward(op, *grad):
- if not op.get_attr("is_training"):
- raise ValueError(
- "CudnnRNN must set is_training to True to be used in gradients")
- return gen_cudnn_rnn_ops.cudnn_rnn_backprop(
- input=op.inputs[0],
- input_h=op.inputs[1],
- input_c=op.inputs[2],
- params=op.inputs[3],
- output=op.outputs[0],
- output_h=op.outputs[1],
- output_c=op.outputs[2],
- output_backprop=grad[0],
- output_h_backprop=grad[1],
- output_c_backprop=grad[2],
- reserve_space=op.outputs[3],
- dropout=op.get_attr("dropout"),
- seed=op.get_attr("seed"),
- seed2=op.get_attr("seed2"),
- rnn_mode=op.get_attr("rnn_mode"),
- input_mode=op.get_attr("input_mode"),
- direction=op.get_attr("direction"))
-
-
ops.RegisterShape("CudnnRNNParamsSize")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("CudnnRNNParamsToCanonical")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("CudnnRNNCanonicalToParams")(common_shapes.call_cpp_shape_fn)
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 637b1dc46c..077cbba9d2 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -41,6 +41,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
@@prefetch_to_device
@@read_batch_features
@@rejection_resample
+@@sample_from_datasets
@@scan
@@shuffle_and_repeat
@@sliding_window_batch
@@ -69,6 +70,7 @@ from tensorflow.contrib.data.python.ops.get_single_element import get_single_ele
from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length
from tensorflow.contrib.data.python.ops.grouping import group_by_window
from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave
+from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD
index 83ada6fb67..c56910c783 100644
--- a/tensorflow/contrib/data/kernels/BUILD
+++ b/tensorflow/contrib/data/kernels/BUILD
@@ -19,6 +19,17 @@ cc_library(
)
cc_library(
+ name = "directed_interleave_dataset_op",
+ srcs = ["directed_interleave_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
name = "ignore_errors_dataset_op",
srcs = ["ignore_errors_dataset_op.cc"],
deps = [
@@ -52,6 +63,7 @@ cc_library(
cc_library(
name = "dataset_kernels",
deps = [
+ ":directed_interleave_dataset_op",
":ignore_errors_dataset_op",
":prefetching_kernels",
":threadpool_dataset_op",
diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
new file mode 100644
index 0000000000..48d3734162
--- /dev/null
+++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
@@ -0,0 +1,274 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/hash/hash.h"
+
+namespace tensorflow {
+
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class DirectedInterleaveDatasetOp : public DatasetOpKernel {
+ public:
+ explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx)
+ : DatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ DatasetBase* selector_input;
+ OP_REQUIRES_OK(ctx,
+ GetDatasetFromVariantTensor(ctx->input(0), &selector_input));
+
+ OP_REQUIRES(
+ ctx,
+ selector_input->output_dtypes().size() == 1 &&
+ selector_input->output_dtypes()[0] == DT_INT64 &&
+ selector_input->output_shapes().size() == 1 &&
+ selector_input->output_shapes()[0].IsCompatibleWith(
+ PartialTensorShape({})),
+ errors::InvalidArgument(
+ "The selector input must be a dataset of scalar int64 elements."));
+
+ std::vector<DatasetBase*> data_inputs;
+ for (size_t i = 1; i < ctx->num_inputs(); ++i) {
+ DatasetBase* input;
+ OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
+ data_inputs.push_back(input);
+
+ OP_REQUIRES(
+ ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
+ errors::InvalidArgument(
+ "All inputs must have the same output_dtypes. First input "
+ "has types ",
+ DataTypeVectorString(data_inputs[0]->output_dtypes()),
+ ", and input ", i - 1, " has types ",
+ DataTypeVectorString(input->output_dtypes())));
+ }
+ *output = new Dataset(ctx, selector_input, std::move(data_inputs));
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
+ std::vector<DatasetBase*> data_inputs)
+ : GraphDatasetBase(ctx),
+ selector_input_(selector_input),
+ data_inputs_(std::move(data_inputs)) {
+ selector_input_->Ref();
+
+ output_shapes_ = data_inputs_[0]->output_shapes();
+ data_inputs_[0]->Ref();
+ for (size_t i = 1; i < data_inputs_.size(); ++i) {
+ const DatasetBase* data_input = data_inputs_[i];
+ data_input->Ref();
+ for (size_t j = 0; j < output_shapes_.size(); ++j) {
+ output_shapes_[j] = MostSpecificCompatibleShape(
+ output_shapes_[j], data_input->output_shapes()[j]);
+ }
+ }
+ }
+
+ ~Dataset() override {
+ selector_input_->Unref();
+ for (DatasetBase* data_input : data_inputs_) {
+ data_input->Unref();
+ }
+ }
+
+ std::unique_ptr<IteratorBase> MakeIterator(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::DirectedInterleave")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return data_inputs_[0]->output_dtypes();
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() override {
+ return strings::StrCat("DirectedInterleaveDatasetOp::Dataset");
+ }
+
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* selector_input_node;
+ TF_RETURN_IF_ERROR(
+ b->AddParentDataset(ctx, selector_input_, &selector_input_node));
+ std::vector<Node*> data_input_nodes(data_inputs_.size());
+ for (size_t i = 0; i < data_inputs_.size(); ++i) {
+ TF_RETURN_IF_ERROR(
+ b->AddParentDataset(ctx, data_inputs_[i], &data_input_nodes[i]));
+ }
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}},
+ {{1, data_input_nodes}}, {}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ selector_input_impl_(params.dataset->selector_input_->MakeIterator(
+ params.prefix + ".selector")),
+ num_active_inputs_(params.dataset->data_inputs_.size()) {
+ data_input_impls_.reserve(params.dataset->data_inputs_.size());
+ for (size_t i = 0; i < params.dataset->data_inputs_.size(); ++i) {
+ const DatasetBase* data_input = params.dataset->data_inputs_[i];
+ data_input_impls_.push_back(data_input->MakeIterator(
+ strings::StrCat(params.prefix, "[", i, "]")));
+ }
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (!selector_input_impl_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ while (true) {
+ std::vector<Tensor> selector_result;
+ *end_of_sequence = false;
+ TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(
+ ctx, &selector_result, end_of_sequence));
+ if (*end_of_sequence) {
+ selector_input_impl_.reset();
+ for (auto& data_input_impl : data_input_impls_) {
+ data_input_impl.reset();
+ }
+ return Status::OK();
+ }
+
+ int64 selected_input = selector_result[0].scalar<int64>()();
+          if (selected_input < 0 || selected_input >= data_input_impls_.size()) {
+ return errors::InvalidArgument(
+ "Selector index out of range: ", selected_input,
+ " >= ", data_input_impls_.size());
+ }
+
+ if (data_input_impls_[selected_input]) {
+ bool end_of_selected_input = false;
+ TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext(
+ ctx, out_tensors, &end_of_selected_input));
+
+ if (!end_of_selected_input) {
+ return Status::OK();
+ }
+
+ data_input_impls_[selected_input].reset();
+ --num_active_inputs_;
+
+ if (num_active_inputs_ == 0) {
+ selector_input_impl_.reset();
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ }
+
+ LOG(WARNING) << "DirectedInterleave selected an exhausted input: "
+ << selected_input;
+ }
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (selector_input_impl_) {
+ TF_RETURN_IF_ERROR(SaveParent(writer, selector_input_impl_));
+ } else {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
+ }
+ for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+ const auto& data_input_impl = data_input_impls_[i];
+ if (data_input_impl) {
+ TF_RETURN_IF_ERROR(SaveParent(writer, data_input_impl));
+ } else {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("data_input_impl_empty[", i, "]")),
+ ""));
+ }
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (!reader->Contains(full_name("selector_input_impl_empty"))) {
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, selector_input_impl_));
+ } else {
+ selector_input_impl_.reset();
+ }
+ for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+ if (!reader->Contains(full_name(
+ strings::StrCat("data_input_impl_empty[", i, "]")))) {
+ TF_RETURN_IF_ERROR(
+ RestoreParent(ctx, reader, data_input_impls_[i]));
+ } else {
+ data_input_impls_[i].reset();
+ }
+ }
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ std::unique_ptr<IteratorBase> selector_input_impl_ GUARDED_BY(mu_);
+ std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
+ GUARDED_BY(mu_);
+ int64 num_active_inputs_ GUARDED_BY(mu_);
+ };
+
+ static PartialTensorShape MostSpecificCompatibleShape(
+ const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
+ PartialTensorShape output_tensorshape;
+ if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
+ return output_tensorshape;
+ auto dims1 = ts1.dim_sizes();
+ auto dims2 = ts2.dim_sizes();
+ for (int d = 0; d < ts1.dims(); d++) {
+ if (dims1[d] == dims2[d])
+ output_tensorshape.Concatenate(dims1[d]);
+ else
+ output_tensorshape.Concatenate(-1);
+ }
+ return output_tensorshape;
+ }
+
+ const DatasetBase* const selector_input_;
+ const std::vector<DatasetBase*> data_inputs_;
+ std::vector<PartialTensorShape> output_shapes_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
+ DirectedInterleaveDatasetOp);
+
+} // namespace
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
index cf0a8bbccb..137deb6352 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -17,6 +17,23 @@ limitations under the License.
namespace tensorflow {
+REGISTER_OP("DirectedInterleaveDataset")
+ .Input("selector_input_dataset: variant")
+ .Input("data_input_datasets: N * variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .Attr("N: int >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+A substitute for `InterleaveDataset` on a fixed list of `N` datasets.
+
+selector_input_dataset: A dataset of scalar `DT_INT64` elements that determines
+ which of the `N` data inputs should produce the next output element.
+data_input_datasets: `N` datasets with the same type that will be interleaved
+ according to the values of `selector_input_dataset`.
+)doc");
+
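In Python terms, the selector drives which input produces each element. A hedged sketch using the DirectedInterleaveDataset wrapper added later in this change (assuming the usual dataset_ops and interleave_ops imports): a repeating range selector yields a deterministic round robin:

selector = dataset_ops.Dataset.range(3).repeat()  # 0, 1, 2, 0, 1, 2, ...
inputs = [dataset_ops.Dataset.from_tensors(i).repeat() for i in range(3)]
dataset = interleave_ops.DirectedInterleaveDataset(selector, inputs)
# Produces 0, 1, 2, 0, 1, 2, ...: input i serves selector value i.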
REGISTER_OP("IgnoreErrorsDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
index 256ad8d94d..f8556a1b28 100644
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
@@ -30,6 +30,7 @@ from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -94,6 +95,76 @@ class InterleaveDatasetSerializationTest(
self.run_core_tests(_build_dataset, None, 20)
+class ParallelInterleaveDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def setUp(self):
+ self.input_values = np.array([4, 5, 6], dtype=np.int64)
+ self.num_repeats = 2
+ self.num_outputs = np.sum(self.input_values) * 2
+
+ def _build_ds(self, cycle_length, block_length, sloppy=False):
+ return (dataset_ops.Dataset.from_tensor_slices(
+ self.input_values).repeat(self.num_repeats).apply(
+ interleave_ops.parallel_interleave(
+ lambda x: dataset_ops.Dataset.range(10 * x, 11 * x),
+ cycle_length, block_length, sloppy)))
+
+ def testSerializationCore(self):
+ # cycle_length > 1, block_length > 1
+ cycle_length = 2
+ block_length = 3
+ self.run_core_tests(
+ lambda: self._build_ds(cycle_length, block_length),
+ lambda: self._build_ds(cycle_length * 2, block_length * 1),
+ self.num_outputs)
+ # cycle_length = 1
+ cycle_length = 1
+ block_length = 3
+ self.run_core_tests(lambda: self._build_ds(cycle_length, block_length),
+ None, self.num_outputs)
+ # block_length = 1
+ cycle_length = 2
+ block_length = 1
+ self.run_core_tests(lambda: self._build_ds(cycle_length, block_length),
+ None, self.num_outputs)
+
+ def testSerializationWithSloppy(self):
+ break_points = self.gen_break_points(self.num_outputs, 10)
+ expected_outputs = np.repeat(
+ np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]),
+ self.num_repeats).tolist()
+
+ def run_test(cycle_length, block_length):
+ actual = self.gen_outputs(
+ lambda: self._build_ds(cycle_length, block_length, True),
+ break_points, self.num_outputs)
+ self.assertSequenceEqual(sorted(actual), expected_outputs)
+
+ # cycle_length > 1, block_length > 1
+ run_test(2, 3)
+ # cycle_length = 1
+ run_test(1, 3)
+ # block_length = 1
+ run_test(2, 1)
+
+ def testSparseCore(self):
+
+ def _map_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
+
+ def _interleave_fn(x):
+ return dataset_ops.Dataset.from_tensor_slices(
+ sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
+
+ def _build_dataset():
+ return dataset_ops.Dataset.range(10).map(_map_fn).apply(
+ interleave_ops.parallel_interleave(_interleave_fn, 1))
+
+ self.run_core_tests(_build_dataset, None, 20)
+
+
class ParallelInterleaveDatasetTest(test.TestCase):
def setUp(self):
@@ -836,5 +907,107 @@ class ParallelInterleaveDatasetTest(test.TestCase):
sess.run(self.next_element)
+class DirectedInterleaveDatasetTest(test.TestCase):
+
+ def testBasic(self):
+ selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
+ input_datasets = [
+ dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10)
+ ]
+ dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset,
+ input_datasets)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for _ in range(100):
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def _normalize(self, vec):
+ batched = (len(vec.shape) == 2)
+ return vec / vec.sum(axis=1, keepdims=True) if batched else vec / vec.sum()
+
+ def _chi2(self, expected, actual):
+ actual = np.asarray(actual)
+ expected = np.asarray(expected)
+ diff = actual - expected
+ chi2 = np.sum(diff * diff / expected, axis=0)
+ return chi2
+
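As a quick numeric sanity check on _chi2 (an illustration, not part of the test): with expected probabilities [0.5, 0.5] and observed frequencies [0.49, 0.51],

import numpy as np
expected = np.array([0.5, 0.5])
observed = np.array([0.49, 0.51])
chi2 = np.sum((observed - expected) ** 2 / expected)  # 0.0004, under 1e-3

so distributions this close pass the 1e-3 bound used below.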
+ def testSampleFromDatasets(self):
+ random_seed.set_random_seed(1618)
+ num_samples = 10000
+ rand_probs = self._normalize(np.random.random_sample((10,)))
+ rand_probs2 = self._normalize(np.random.random_sample((15,)))
+
+ for probs in [[.5, .5], [.85, .05, .1], rand_probs, rand_probs2]:
+ probs = np.asarray(probs)
+
+ # Create a dataset that samples each integer in `[0, probs.shape[0])`
+ # with probability given by `probs[i]`.
+ dataset = interleave_ops.sample_from_datasets([
+ dataset_ops.Dataset.from_tensors(i).repeat(None)
+ for i in range(probs.shape[0])
+ ], probs)
+ dataset = dataset.take(num_samples)
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ freqs = np.zeros_like(probs)
+ for _ in range(num_samples):
+ freqs[sess.run(next_element)] += 1
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ # Use chi-squared test to assert that the observed distribution
+ # matches the expected distribution. Based on the implementation
+ # in "tensorflow/python/kernel_tests/multinomial_op_test.py".
+ self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3)
+
+ def testErrors(self):
+ with self.assertRaisesRegexp(ValueError,
+ r"vector of length `len\(datasets\)`"):
+ interleave_ops.sample_from_datasets(
+ [dataset_ops.Dataset.range(10),
+ dataset_ops.Dataset.range(20)],
+ weights=[0.25, 0.25, 0.25, 0.25])
+
+ with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"):
+ interleave_ops.sample_from_datasets(
+ [dataset_ops.Dataset.range(10),
+ dataset_ops.Dataset.range(20)],
+ weights=[1, 1])
+
+ with self.assertRaisesRegexp(TypeError, "must have the same type"):
+ interleave_ops.sample_from_datasets([
+ dataset_ops.Dataset.from_tensors(0),
+ dataset_ops.Dataset.from_tensors(0.0)
+ ])
+
+
+class SampleFromDatasetsSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, probs, num_samples):
+ dataset = interleave_ops.sample_from_datasets(
+ [
+ dataset_ops.Dataset.from_tensors(i).repeat(None)
+ for i in range(len(probs))
+ ],
+ probs,
+ seed=1813)
+ return dataset.take(num_samples)
+
+ def testSerializationCore(self):
+ self.run_core_tests(
+ lambda: self._build_dataset([0.5, 0.5], 100),
+ lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py
index b13ad9ba4e..d0cb203a3a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py
@@ -48,8 +48,8 @@ class SequenceDatasetSerializationTest(
self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10)
def testInvalidSkip(self):
- with self.assertRaisesRegexp(
- ValueError, 'Shape must be rank 0 but is rank 1'):
+ with self.assertRaisesRegexp(ValueError,
+ 'Shape must be rank 0 but is rank 1'):
self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0)
def _build_take_dataset(self, count):
@@ -75,8 +75,8 @@ class SequenceDatasetSerializationTest(
self.run_core_tests(lambda: self._build_take_dataset(0), None, 0)
def testInvalidTake(self):
- with self.assertRaisesRegexp(
- ValueError, 'Shape must be rank 0 but is rank 1'):
+ with self.assertRaisesRegexp(ValueError,
+ 'Shape must be rank 0 but is rank 1'):
self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0)
def _build_repeat_dataset(self, count, take_count=3):
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 0e4590829b..e00f2304cc 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -172,8 +172,18 @@ py_library(
srcs = ["interleave_ops.py"],
srcs_version = "PY2AND3",
deps = [
+ ":contrib_op_loader",
+ ":gen_dataset_ops",
+ ":random_ops",
+ "//tensorflow/contrib/stateless",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
],
)
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 1eba010b56..28db949da9 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -370,9 +370,10 @@ def assert_element_shape(expected_shapes):
def _check_shape(*elements):
flatten_tensors = nest.flatten(elements)
flatten_shapes = nest.flatten(expected_shapes)
- checked_tensors = [with_shape(shape, tensor)
- for shape, tensor in zip(flatten_shapes,
- flatten_tensors)]
+ checked_tensors = [
+ with_shape(shape, tensor)
+ for shape, tensor in zip(flatten_shapes, flatten_tensors)
+ ]
return nest.pack_sequence_as(elements, checked_tensors)
def _apply_fn(dataset):
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 91f19da02d..106a1ef388 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -17,7 +17,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib import stateless
+from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
+from tensorflow.contrib.data.python.ops import random_ops
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
@@ -140,3 +151,92 @@ def sloppy_interleave(map_func, cycle_length, block_length=1):
prefetch_input_elements=None)
return _apply_fn
+
+
+class DirectedInterleaveDataset(dataset_ops.Dataset):
+ """A substitute for `Dataset.interleave()` on a fixed list of datasets."""
+
+ def __init__(self, selector_input, data_inputs):
+ self._selector_input = selector_input
+ self._data_inputs = list(data_inputs)
+
+ for data_input in data_inputs[1:]:
+ if (data_input.output_types != data_inputs[0].output_types or
+ data_input.output_classes != data_inputs[0].output_classes):
+ raise TypeError("All datasets must have the same type.")
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ return gen_dataset_ops.directed_interleave_dataset(
+ self._selector_input._as_variant_tensor(),
+ [data_input._as_variant_tensor() for data_input in self._data_inputs],
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)))
+ # pylint: enable=protected-access
+
+ @property
+ def output_classes(self):
+ return self._data_inputs[0].output_classes
+
+ @property
+ def output_shapes(self):
+ ret = self._data_inputs[0].output_shapes
+ for data_input in self._data_inputs[1:]:
+ ret = nest.pack_sequence_as(ret, [
+ ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
+ nest.flatten(ret), nest.flatten(data_input.output_shapes))
+ ])
+ return ret
+
+ @property
+ def output_types(self):
+ return self._data_inputs[0].output_types
+
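The output_shapes merge above relies on TensorShape.most_specific_compatible_shape, which the C++ kernel mirrors: dimensions on which all inputs agree are kept, and the rest generalize to unknown. A small illustration:

from tensorflow.python.framework import tensor_shape

s1 = tensor_shape.TensorShape([None, 3])
s2 = tensor_shape.TensorShape([2, 3])
assert s1.most_specific_compatible_shape(s2).as_list() == [None, 3]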
+
+def sample_from_datasets(datasets, weights=None, seed=None):
+ """Samples elements at random from the datasets in `datasets`.
+
+ Args:
+ datasets: A list of @{tf.data.Dataset} objects with compatible structure.
+ weights: (Optional.) A list of `len(datasets)` floating-point values,
+ where `weights[i]` represents the probability with which an element
+ should be sampled from `datasets[i]`. Defaults to a uniform distribution
+ across `datasets`.
+ seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ random seed that will be used to create the distribution. See
+ @{tf.set_random_seed} for behavior.
+
+ Returns:
+ A dataset that interleaves elements from `datasets` at random, according to
+ `weights` if provided, otherwise with uniform probability.
+
+ Raises:
+ TypeError: If the `datasets` or `weights` arguments have the wrong type.
+ ValueError: If the `weights` argument is specified and does not match the
+ length of the `datasets` element.
+ """
+ num_datasets = len(datasets)
+ if weights is None:
+ weights = array_ops.ones(
+ [num_datasets], dtype=dtypes.float32, name="weights")
+ else:
+ weights = ops.convert_to_tensor(weights, name="weights")
+ if weights.dtype not in (dtypes.float32, dtypes.float64):
+ raise TypeError("`weights` must be convertible to a tensor of "
+ "`tf.float32` or `tf.float64` elements.")
+ if not weights.shape.is_compatible_with([num_datasets]):
+ raise ValueError("`weights` must be a vector of length `len(datasets)`.")
+
+ # The `stateless_multinomial()` op expects log-probabilities, as opposed to
+ # weights.
+ logits = math_ops.log(weights, name="logits")
+
+ def select_dataset(seed):
+ return array_ops.squeeze(
+ stateless.stateless_multinomial([logits], 1, seed=seed), axis=[0, 1])
+
+ selector_input = random_ops.RandomDataset(seed).batch(2).map(select_dataset)
+
+ return DirectedInterleaveDataset(selector_input, datasets)
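A hedged end-to-end sketch of the new API, mirroring the tests earlier in this change: sampling from two infinite datasets with an 80/20 split (the seed value is arbitrary):

dataset = interleave_ops.sample_from_datasets(
    [dataset_ops.Dataset.from_tensors(0).repeat(),
     dataset_ops.Dataset.from_tensors(1).repeat()],
    weights=[0.8, 0.2], seed=42)
# Roughly 80% of elements come from the first dataset, 20% from the second.
next_element = dataset.take(1000).make_one_shot_iterator().get_next()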
diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md
index 14de1e8f49..5d22d9aa2b 100644
--- a/tensorflow/contrib/distribute/README.md
+++ b/tensorflow/contrib/distribute/README.md
@@ -116,7 +116,8 @@ in the input function gives a solid boost in performance. When using
## Caveats
This feature is in early stages and there are a lot of improvements forthcoming:
-* Metrics are not yet supported during distributed training.
+* Metrics are not yet supported during distributed training. They are still
+supported during evaluation.
* Summaries are only computed in the first tower in `MirroredStrategy`.
* Evaluation is not yet distributed.
* Eager support is in the works; performance can be more challenging with eager
@@ -130,6 +131,8 @@ adjusting your learning rate or batch size according to the number of GPUs.
We are working on addressing this limitation by splitting each batch across GPUs
instead.
* PartitionedVariables are not supported yet.
+* Input pipelines with Datasets that capture stateful objects and rely on
+`make_initializable_iterator` are not supported yet.
## What's next?
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 5aad21cccd..837a1f1348 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -131,6 +131,7 @@ py_library(
deps = [
":mirrored_strategy",
":one_device_strategy",
+ ":tpu_strategy",
"//tensorflow/contrib/optimizer_v2:training",
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
@@ -225,14 +226,30 @@ py_library(
],
)
-cuda_py_test(
- name = "minimize_loss_test",
+py_library(
+ name = "tpu_strategy",
+ srcs = ["tpu_strategy.py"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/contrib/distribute/python:one_device_strategy",
+ "//tensorflow/contrib/eager/python:datasets",
+ "//tensorflow/contrib/optimizer_v2:training",
+ "//tensorflow/contrib/tpu",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/eager:context",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
+ name = "minimize_loss_test_lib",
+ testonly = 1,
srcs = ["minimize_loss_test.py"],
- additional_deps = [
+ deps = [
":combinations",
":single_loss_example",
- "@absl_py//absl/testing:parameterized",
- "//third_party/py/numpy",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:variables",
@@ -240,6 +257,16 @@ cuda_py_test(
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
"//tensorflow/python/ops/losses",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+cuda_py_test(
+ name = "minimize_loss_test",
+ srcs = ["minimize_loss_test.py"],
+ additional_deps = [
+ ":minimize_loss_test_lib",
],
tags = [
"multi_and_single_gpu",
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 02b1e7ef9f..1f66997e6e 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -45,6 +45,7 @@ from absl.testing import parameterized
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import one_device_strategy
+from tensorflow.contrib.distribute.python import tpu_strategy
from tensorflow.contrib.optimizer_v2 import adam as adam_v2
from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2
from tensorflow.python.eager import context
@@ -55,6 +56,7 @@ from tensorflow.python.util import tf_inspect
GPU_TEST = "test_gpu" in sys.argv[0]
+TPU_TEST = "test_tpu" in sys.argv[0]
def generate(combinations):
@@ -108,6 +110,11 @@ def generate(combinations):
if "distribution" in kwargs:
distribution = kwargs["distribution"]
kwargs["distribution"] = distribution.strategy
+ if distribution.required_tpu and not TPU_TEST:
+ self.skipTest("Test requires a TPU, but it's not available.")
+ if not distribution.required_tpu and TPU_TEST:
+ self.skipTest("Test that doesn't require a TPU.")
+
if not distribution.required_gpus:
if GPU_TEST:
self.skipTest("Test that doesn't require GPUs.")
@@ -232,10 +239,12 @@ class NamedObject(object):
class NamedDistribution(object):
"""Translates DistributionStrategy and its data into a good name."""
- def __init__(self, name, distribution, required_gpus):
+ def __init__(self, name, distribution, required_gpus=None,
+ required_tpu=False):
self._distribution = distribution
self._name = name
self._required_gpus = required_gpus
+ self._required_tpu = required_tpu
def __repr__(self):
return self._name
@@ -248,10 +257,16 @@ class NamedDistribution(object):
def required_gpus(self):
return self._required_gpus
+ @property
+ def required_tpu(self):
+ return self._required_tpu
+
one_device_strategy = NamedDistribution(
"OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"),
None)
+tpu_strategy = NamedDistribution(
+ "TPU", tpu_strategy.TpuStrategy(), required_tpu=True)
mirrored_strategy_with_gpu_and_cpu = NamedDistribution(
"MirroredCPUAndGPU",
mirrored_strategy.MirroredStrategy(["/gpu:0", "/cpu:0"]), 1)
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index 0fa90df79b..4219d54cbd 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -25,6 +25,7 @@ from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example
from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example
+from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import test
@@ -42,24 +43,46 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
combinations.times(
combinations.distributions_and_v1_optimizers(),
combinations.combine(mode=["graph"], use_callable_loss=[True, False])
- + combinations.combine(mode=["eager"], use_callable_loss=[True])))
- def testTrainNetwork(self, distribution, optimizer_fn,
- use_callable_loss=True):
+ + combinations.combine(mode=["eager"], use_callable_loss=[True]),
+ combinations.combine(is_tpu=[False])) +
+ combinations.combine(
+ distribution=[combinations.tpu_strategy],
+ optimizer_fn=[combinations.adam_optimizer_v1_fn],
+ mode=["graph"],
+ use_callable_loss=[False],
+ is_tpu=[True]))
+ def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss,
+ is_tpu):
with distribution.scope():
model_fn, dataset, layer = minimize_loss_example(
optimizer_fn,
use_bias=True,
use_callable_loss=use_callable_loss)
+ # TODO(isaprykin): Eliminate `is_tpu`. Probably add a
+ # `DistributionStrategy.create_monitor` so that each DistributionStrategy
+ # could influence its training loop. That method would return an instance
+ # of Monitor. TPUMonitor would execute tpu.initialize_system() and
+ # tpu.shutdown_system().
+ if is_tpu:
+ dataset = dataset.batch(2)
+
iterator = distribution.distribute_dataset(dataset)
def run_step():
+ # TODO(isaprykin): Make iterator get_next() return a list of sub-
+ # batches for each iteration. Pass iterator.get_next() and not iterator
+ # to call_for_each_tower.
return distribution.group(
distribution.call_for_each_tower(
- model_fn, iterator.get_next(), run_concurrently=layer.built))
+ model_fn,
+ iterator.get_next() if not is_tpu else iterator,
+ run_concurrently=layer.built))
if not context.executing_eagerly():
with self.test_session() as sess:
+ if is_tpu:
+ sess.run(tpu.initialize_system())
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
@@ -70,6 +93,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
weights.append(self.evaluate(distribution.fetch(layer.kernel)))
biases.append(self.evaluate(distribution.fetch(layer.bias)))
+ if is_tpu:
+ with self.test_session() as sess:
+ sess.run(tpu.shutdown_system())
+
error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
self.assertTrue(is_not_increasing)
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
new file mode 100644
index 0000000000..0ac307dd6a
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -0,0 +1,82 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TPU Distribution Strategy.
+
+This is experimental. It's not ready for general use.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib import tpu
+from tensorflow.contrib.distribute.python import one_device_strategy
+from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+
+
+# TODO(isaprykin): Consider whether inheriting is really appropriate.
+class TpuStrategy(one_device_strategy.OneDeviceStrategy):
+
+ def __init__(self, master=None, iterations=None, model_dir=None):
+ super(TpuStrategy, self).__init__('/cpu:0')
+
+ def _call_for_each_tower(self, fn, *args, **kwargs):
+ kwargs.pop('run_concurrently', None)
+
+ # TODO(isaprykin): Give an API for many iterations per step.
+ iterations = 1
+
+ # TODO(isaprykin): Do not hard code shapes and input format :)
+ # TODO(isaprykin): Detect the number of TPU cores automatically.
+
+ def dequeueing_fn(*args, **kwargs):
+ del args, kwargs
+ x, = tpu.infeed_dequeue_tuple(dtypes=[dtypes.float32], shapes=[[1, 1, 1]])
+ return fn(x)
+
+ iterator = args[0]
+
+ def infeed_input(i):
+ """Get input, split it and then enqueue."""
+ batches = iterator.get_next()
+ batches = array_ops.split(batches, 2)
+
+ infeeds = [
+ tpu_ops.infeed_enqueue_tuple(
+ inputs=[batches[j]], shapes=[[1, 1, 1]], device_ordinal=j)
+ for j in range(2)
+ ]
+
+ with ops.control_dependencies(infeeds):
+ return i + 1
+
+ with ops.device('/task:0/device:CPU:0'):
+ enqueue_ops = control_flow_ops.while_loop(
+ lambda i: i < iterations,
+ infeed_input, [constant_op.constant(0)],
+ parallel_iterations=1)
+
+ def iterate_on_tpu():
+ return tpu.repeat(iterations, dequeueing_fn, [])
+
+ with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access
+ tpu_result = tpu.batch_parallel(iterate_on_tpu, [], num_shards=2)
+
+ return control_flow_ops.group(tpu_result, enqueue_ops)
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index 314c54ed00..bd641014e9 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -17,10 +17,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees
+def _validate_input_fn_and_repeat_dataset(train_input_fn):
+ """Validates whether the input_fn is valid, and repeat() if tf.Dataset."""
+ def _input_fn():
+ result_input_fn = train_input_fn()
+ if isinstance(result_input_fn, dataset_ops.Dataset):
+ return result_input_fn.repeat()
+ return result_input_fn
+
+ return _input_fn
+
+
class _BoostedTreesEstimator(estimator.Estimator):
"""An Estimator for Tensorflow Boosted Trees models."""
@@ -36,6 +48,7 @@ class _BoostedTreesEstimator(estimator.Estimator):
l1_regularization=0.,
l2_regularization=0.,
tree_complexity=0.,
+ min_node_weight=0.,
config=None):
"""Initializes a `BoostedTreesEstimator` instance.
@@ -65,13 +78,16 @@ class _BoostedTreesEstimator(estimator.Estimator):
l2_regularization: regularization multiplier applied to the square weights
of the tree leafs.
tree_complexity: regularization factor to penalize trees with more leaves.
+    min_node_weight: minimum hessian a node must have for a split to be
+      considered. The value will be compared with
+      sum(leaf_hessian) / (batch_size * n_batches_per_layer).
config: `RunConfig` object to configure the runtime settings.
"""
# pylint:disable=protected-access
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity)
+ tree_complexity, min_node_weight)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
@@ -96,6 +112,7 @@ def boosted_trees_classifier_train_in_memory(
l1_regularization=0.,
l2_regularization=0.,
tree_complexity=0.,
+ min_node_weight=0.,
config=None,
train_hooks=None):
"""Trains a boosted tree classifier with in memory dataset.
@@ -108,10 +125,13 @@ def boosted_trees_classifier_train_in_memory(
bucketized_feature_2 = bucketized_column(
numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
- def input_fn_train():
+ def train_input_fn():
dataset = create-dataset-from-training-data
- # Don't use repeat or cache, since it is assumed to be one epoch
- # This is either tf.data.Dataset, or a tuple of feature dict and label.
+      # This is a tf.data.Dataset of a tuple of a feature dict and label.
+      # e.g. Dataset.zip((Dataset.from_tensors({'f1': f1_array, ...}),
+      # Dataset.from_tensors(label_array)))
+      # The returned Dataset shouldn't be batched.
+      # If the Dataset repeats, only the first repetition is used for training.
return dataset
classifier = boosted_trees_classifier_train_in_memory(
@@ -162,6 +182,9 @@ def boosted_trees_classifier_train_in_memory(
l2_regularization: regularization multiplier applied to the square weights
of the tree leafs.
tree_complexity: regularization factor to penalize trees with more leaves.
+    min_node_weight: minimum hessian a node must have for a split to be
+      considered. The value will be compared with
+      sum(leaf_hessian) / (batch_size * n_batches_per_layer).
config: `RunConfig` object to configure the runtime settings.
train_hooks: a list of Hook instances to be passed to estimator.train().
@@ -184,7 +207,7 @@ def boosted_trees_classifier_train_in_memory(
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity)
+ tree_complexity, min_node_weight)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
@@ -202,7 +225,9 @@ def boosted_trees_classifier_train_in_memory(
in_memory_classifier = estimator.Estimator(
model_fn=_model_fn, model_dir=model_dir, config=config)
- in_memory_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
+ in_memory_classifier.train(
+ input_fn=_validate_input_fn_and_repeat_dataset(train_input_fn),
+ hooks=train_hooks)
return in_memory_classifier
# pylint: enable=protected-access
@@ -220,6 +245,7 @@ def boosted_trees_regressor_train_in_memory(
l1_regularization=0.,
l2_regularization=0.,
tree_complexity=0.,
+ min_node_weight=0.,
config=None,
train_hooks=None):
"""Trains a boosted tree regressor with in memory dataset.
@@ -232,10 +258,13 @@ def boosted_trees_regressor_train_in_memory(
bucketized_feature_2 = bucketized_column(
numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
- def input_fn_train():
+ def train_input_fn():
dataset = create-dataset-from-training-data
- # Don't use repeat or cache, since it is assumed to be one epoch
- # This is either tf.data.Dataset, or a tuple of feature dict and label.
+      # This is a tf.data.Dataset of a tuple of a feature dict and label.
+      # e.g. Dataset.zip((Dataset.from_tensors({'f1': f1_array, ...}),
+      # Dataset.from_tensors(label_array)))
+      # The returned Dataset shouldn't be batched.
+      # If the Dataset repeats, only the first repetition is used for training.
return dataset
regressor = boosted_trees_regressor_train_in_memory(
@@ -279,6 +308,9 @@ def boosted_trees_regressor_train_in_memory(
l2_regularization: regularization multiplier applied to the square weights
of the tree leafs.
tree_complexity: regularization factor to penalize trees with more leaves.
+ min_node_weight: minimum hessian a node must have for a split to be
+ considered. The value will be compared with
+ sum(leaf_hessian) / (batch_size * n_batches_per_layer).
config: `RunConfig` object to configure the runtime settings.
train_hooks: a list of Hook instances to be passed to estimator.train().
@@ -300,7 +332,7 @@ def boosted_trees_regressor_train_in_memory(
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity)
+ tree_complexity, min_node_weight)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
@@ -317,7 +349,9 @@ def boosted_trees_regressor_train_in_memory(
in_memory_regressor = estimator.Estimator(
model_fn=_model_fn, model_dir=model_dir, config=config)
- in_memory_regressor.train(input_fn=train_input_fn, hooks=train_hooks)
+ in_memory_regressor.train(
+ input_fn=_validate_input_fn_and_repeat_dataset(train_input_fn),
+ hooks=train_hooks)
return in_memory_regressor
# pylint: enable=protected-access
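The `min_node_weight` docstring above describes a per-layer gate on candidate splits. A rough NumPy illustration of that comparison (not the actual kernel code; `leaf_hessians` and `node_can_split` are hypothetical names):

```python
import numpy as np

def node_can_split(leaf_hessians, batch_size, n_batches_per_layer,
                   min_node_weight):
    # Illustration only: a node's summed hessian, normalized by the
    # number of examples seen per layer, must clear the threshold for a
    # split on that node to be considered.
    node_weight = np.sum(leaf_hessians) / (batch_size * n_batches_per_layer)
    return node_weight >= min_node_weight
```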
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
index eee5910687..76cbefe5e9 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
@@ -21,6 +21,7 @@ import numpy as np
from tensorflow.contrib.estimator.python.estimator import boosted_trees
from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.feature_column import feature_column
@@ -49,12 +50,24 @@ def _make_train_input_fn(is_classification):
"""Makes train input_fn for classification/regression."""
def _input_fn():
- features = dict(FEATURES_DICT)
- if is_classification:
- labels = CLASSIFICATION_LABELS
- else:
- labels = REGRESSION_LABELS
- return features, labels
+ features_dict = dict(FEATURES_DICT)
+ labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS
+ return features_dict, labels
+
+ return _input_fn
+
+
+def _make_train_input_fn_dataset(is_classification):
+ """Makes input_fn using Dataset."""
+
+ def _input_fn():
+ features_dict = dict(FEATURES_DICT)
+ labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS
+ ds = dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensors(features_dict),
+ dataset_ops.Dataset.from_tensors(labels)
+ ))
+ return ds
return _input_fn
@@ -132,15 +145,13 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
est = boosted_trees.boosted_trees_classifier_train_in_memory(
- train_input_fn=train_input_fn,
- feature_columns=self._feature_columns,
- n_trees=1,
- max_depth=5)
+ train_input_fn=train_input_fn, feature_columns=self._feature_columns,
+ n_trees=1, max_depth=5)
# It will stop after 5 steps because of the max depth and num trees.
self._assert_checkpoint(
est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
- # Check eval.
+ # Check evaluate and predict.
eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
self.assertAllClose(eval_res['accuracy'], 1.0)
# Validate predictions.
@@ -148,24 +159,59 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
self.assertAllClose([[0], [1], [1], [0], [0]],
[pred['class_ids'] for pred in predictions])
+ def testBinaryClassifierTrainInMemoryWithDataset(self):
+ train_input_fn = _make_train_input_fn_dataset(is_classification=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.boosted_trees_classifier_train_in_memory(
+ train_input_fn=train_input_fn, feature_columns=self._feature_columns,
+ n_trees=1, max_depth=5)
+ # It will stop after 5 steps because of the max depth and num trees.
+ self._assert_checkpoint(
+ est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+
+ # Check evaluate and predict.
+ eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+ self.assertAllClose(eval_res['accuracy'], 1.0)
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertAllClose([[0], [1], [1], [0], [0]],
+ [pred['class_ids'] for pred in predictions])
+
def testRegressorTrainInMemoryAndEvalAndInfer(self):
train_input_fn = _make_train_input_fn(is_classification=False)
predict_input_fn = numpy_io.numpy_input_fn(
x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
est = boosted_trees.boosted_trees_regressor_train_in_memory(
- train_input_fn=train_input_fn,
- feature_columns=self._feature_columns,
- n_trees=1,
- max_depth=5)
+ train_input_fn=train_input_fn, feature_columns=self._feature_columns,
+ n_trees=1, max_depth=5)
# It will stop after 5 steps because of the max depth and num trees.
self._assert_checkpoint(
est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
- # Check eval.
+ # Check evaluate and predict.
+ eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+ self.assertAllClose(eval_res['average_loss'], 2.478283)
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertAllClose(
+ [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+ [pred['predictions'] for pred in predictions])
+
+ def testRegressorTrainInMemoryWithDataset(self):
+ train_input_fn = _make_train_input_fn_dataset(is_classification=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.boosted_trees_regressor_train_in_memory(
+ train_input_fn=train_input_fn, feature_columns=self._feature_columns,
+ n_trees=1, max_depth=5)
+ # It will stop after 5 steps because of the max depth and num trees.
+ self._assert_checkpoint(
+ est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+ # Check evaluate and predict.
eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
self.assertAllClose(eval_res['average_loss'], 2.478283)
- # Validate predictions.
predictions = list(est.predict(input_fn=predict_input_fn))
self.assertAllClose(
[[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py
index 4a5ed0ab0f..11397e86bd 100644
--- a/tensorflow/contrib/framework/__init__.py
+++ b/tensorflow/contrib/framework/__init__.py
@@ -72,7 +72,9 @@ See the @{$python/contrib.framework} guide.
@@variable
@@VariableDeviceChooser
@@convolutional_delta_orthogonal
+@@convolutional_orthogonal_1d
@@convolutional_orthogonal_2d
+@@convolutional_orthogonal_3d
@@zero_initializer
@@load_checkpoint
@@ -118,7 +120,9 @@ from tensorflow.python.framework.tensor_spec import BoundedTensorSpec
from tensorflow.python.framework.tensor_spec import TensorSpec
from tensorflow.python.ops.array_ops import broadcast_to
from tensorflow.python.ops.init_ops import convolutional_delta_orthogonal
+from tensorflow.python.ops.init_ops import convolutional_orthogonal_1d
from tensorflow.python.ops.init_ops import convolutional_orthogonal_2d
+from tensorflow.python.ops.init_ops import convolutional_orthogonal_3d
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ['nest', 'broadcast_to']
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py
index 73acd05b60..6fa43059f3 100644
--- a/tensorflow/contrib/gan/python/train.py
+++ b/tensorflow/contrib/gan/python/train.py
@@ -710,7 +710,10 @@ def gan_train_ops(
be used to train a generator/discriminator pair.
"""
if isinstance(model, namedtuples.CycleGANModel):
- saved_params = locals()
+ # Get and store all arguments other than model and loss from locals().
+ # The contents of locals() should not be modified, and modifying them may
+ # not affect the values seen by the interpreter, so make a copy.
+ # https://docs.python.org/2/library/functions.html#locals.
+ saved_params = dict(locals())
saved_params.pop('model', None)
saved_params.pop('loss', None)
kwargs = saved_params.pop('kwargs', {})
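The reasoning behind `dict(locals())` is a general CPython point: `locals()` returns the frame's variable mapping, and mutating it directly is undefined. A self-contained sketch of the save-and-filter pattern used above, with a hypothetical signature:

```python
def train_ops_like(model, loss, summarize_gradients=False, **kwargs):
    # Snapshot the arguments first; popping from the mapping returned by
    # locals() directly is undefined behavior in CPython, so copy it.
    saved_params = dict(locals())
    saved_params.pop('model', None)
    saved_params.pop('loss', None)
    # Fold the catch-all kwargs back into the flat parameter dict.
    saved_params.update(saved_params.pop('kwargs', {}))
    return saved_params

print(train_ops_like('m', 'l', summarize_gradients=True, beta=0.5))
# {'summarize_gradients': True, 'beta': 0.5}
```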
diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD
index 8b7ff75ba5..2c5fa7af89 100644
--- a/tensorflow/contrib/linalg/BUILD
+++ b/tensorflow/contrib/linalg/BUILD
@@ -61,3 +61,22 @@ cuda_py_test(
shard_count = 5,
tags = ["noasan"],
)
+
+cuda_py_test(
+ name = "linear_operator_kronecker_test",
+ size = "medium",
+ srcs = ["python/kernel_tests/linear_operator_kronecker_test.py"],
+ additional_deps = [
+ ":linalg_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+ shard_count = 8,
+ tags = ["noasan"],
+)
diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py
index 14cc3b2b49..38bd66b13f 100644
--- a/tensorflow/contrib/linalg/__init__.py
+++ b/tensorflow/contrib/linalg/__init__.py
@@ -22,6 +22,7 @@ See the @{$python/contrib.linalg} guide.
@@LinearOperatorIdentity
@@LinearOperatorScaledIdentity
@@LinearOperatorFullMatrix
+@@LinearOperatorKronecker
@@LinearOperatorLowerTriangular
@@LinearOperatorLowRankUpdate
@@LinearOperatorComposition
@@ -36,6 +37,7 @@ from __future__ import print_function
from tensorflow.contrib.linalg.python.ops.linear_operator_addition import *
from tensorflow.contrib.linalg.python.ops.linear_operator_block_diag import *
+from tensorflow.contrib.linalg.python.ops.linear_operator_kronecker import *
from tensorflow.python.ops.linalg.linear_operator import *
from tensorflow.python.ops.linalg.linear_operator_composition import *
from tensorflow.python.ops.linalg.linear_operator_diag import *
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_kronecker_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_kronecker_test.py
new file mode 100644
index 0000000000..6574da22a1
--- /dev/null
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_kronecker_test.py
@@ -0,0 +1,194 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.linalg.python.ops import linear_operator_kronecker as kronecker
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.linalg import linalg as linalg_lib
+from tensorflow.python.ops.linalg import linear_operator_test_util
+from tensorflow.python.ops.linalg import linear_operator_util
+from tensorflow.python.platform import test
+
+linalg = linalg_lib
+random_seed.set_random_seed(23)
+rng = np.random.RandomState(0)
+
+
+def _kronecker_dense(factors):
+ """Convert a list of factors, into a dense Kronecker product."""
+ product = factors[0]
+ for factor in factors[1:]:
+ product = product[..., array_ops.newaxis, :, array_ops.newaxis]
+ factor_to_mul = factor[..., array_ops.newaxis, :, array_ops.newaxis, :]
+ product *= factor_to_mul
+ product = array_ops.reshape(
+ product,
+ shape=array_ops.concat(
+ [array_ops.shape(product)[:-4],
+ [array_ops.shape(product)[-4] * array_ops.shape(product)[-3],
+ array_ops.shape(product)[-2] * array_ops.shape(product)[-1]]
+ ], axis=0))
+
+ return product
+
+
+class KroneckerDenseTest(test.TestCase):
+
+ def testKroneckerDenseMatrix(self):
+ x = ops.convert_to_tensor([[2., 3.], [1., 2.]], dtype=dtypes.float32)
+ y = ops.convert_to_tensor([[1., 2.], [5., -1.]], dtype=dtypes.float32)
+ # From explicitly writing out the kronecker product of x and y.
+ z = ops.convert_to_tensor([
+ [2., 4., 3., 6.],
+ [10., -2., 15., -3.],
+ [1., 2., 2., 4.],
+ [5., -1., 10., -2.]], dtype=dtypes.float32)
+ # From explicitly writing out the kronecker product of y and x.
+ w = ops.convert_to_tensor([
+ [2., 3., 4., 6.],
+ [1., 2., 2., 4.],
+ [10., 15., -2., -3.],
+ [5., 10., -1., -2.]], dtype=dtypes.float32)
+
+ with self.test_session():
+ self.assertAllClose(_kronecker_dense([x, y]).eval(), z.eval())
+ self.assertAllClose(_kronecker_dense([y, x]).eval(), w.eval())
+
+
+class SquareLinearOperatorKroneckerTest(
+ linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
+ """Most tests done in the base class LinearOperatorDerivedClassTest."""
+
+ def setUp(self):
+ # Increase from 1e-6 to 1e-4
+ self._atol[dtypes.float32] = 1e-4
+ self._atol[dtypes.complex64] = 1e-4
+ self._rtol[dtypes.float32] = 1e-4
+ self._rtol[dtypes.complex64] = 1e-4
+
+ @property
+ def _operator_build_infos(self):
+ build_info = linear_operator_test_util.OperatorBuildInfo
+ return [
+ build_info((1, 1), factors=[(1, 1), (1, 1)]),
+ build_info((8, 8), factors=[(2, 2), (2, 2), (2, 2)]),
+ build_info((12, 12), factors=[(2, 2), (3, 3), (2, 2)]),
+ build_info((1, 3, 3), factors=[(1, 1), (1, 3, 3)]),
+ build_info((3, 6, 6), factors=[(3, 1, 1), (1, 2, 2), (1, 3, 3)]),
+ ]
+
+ def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ shape = list(build_info.shape)
+ expected_factors = build_info.__dict__["factors"]
+ matrices = [
+ linear_operator_test_util.random_positive_definite_matrix(
+ block_shape, dtype, force_well_conditioned=True)
+ for block_shape in expected_factors
+ ]
+
+ if use_placeholder:
+ matrices_ph = [
+ array_ops.placeholder(dtype=dtype) for _ in expected_factors
+ ]
+ # Evaluate here because (i) you cannot feed a tensor, and (ii)
+ # values are random and we want the same value used for both mat and
+ # feed_dict.
+ matrices = self.evaluate(matrices)
+ operator = kronecker.LinearOperatorKronecker(
+ [linalg.LinearOperatorFullMatrix(
+ m_ph, is_square=True) for m_ph in matrices_ph],
+ is_square=True)
+ feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)}
+ else:
+ operator = kronecker.LinearOperatorKronecker(
+ [linalg.LinearOperatorFullMatrix(
+ m, is_square=True) for m in matrices])
+ feed_dict = None
+ # Should be auto-set.
+ self.assertTrue(operator.is_square)
+
+ matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)
+
+ kronecker_dense = _kronecker_dense(matrices)
+
+ if not use_placeholder:
+ kronecker_dense.set_shape(shape)
+
+ return operator, kronecker_dense, feed_dict
+
+ def test_is_x_flags(self):
+ # Matrix with two positive eigenvalues, 1 and 1.
+ # The matrix values do not affect auto-setting of the flags.
+ matrix = [[1., 0.], [1., 1.]]
+ operator = kronecker.LinearOperatorKronecker(
+ [linalg.LinearOperatorFullMatrix(matrix),
+ linalg.LinearOperatorFullMatrix(matrix)],
+ is_positive_definite=True,
+ is_non_singular=True,
+ is_self_adjoint=False)
+ self.assertTrue(operator.is_positive_definite)
+ self.assertTrue(operator.is_non_singular)
+ self.assertFalse(operator.is_self_adjoint)
+
+ def test_is_non_singular_auto_set(self):
+ # Matrix with two positive eigenvalues, 11 and 8.
+ # The matrix values do not affect auto-setting of the flags.
+ matrix = [[11., 0.], [1., 8.]]
+ operator_1 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)
+ operator_2 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)
+
+ operator = kronecker.LinearOperatorKronecker(
+ [operator_1, operator_2],
+ is_positive_definite=False, # No reason it HAS to be False...
+ is_non_singular=None)
+ self.assertFalse(operator.is_positive_definite)
+ self.assertTrue(operator.is_non_singular)
+
+ with self.assertRaisesRegexp(ValueError, "always non-singular"):
+ kronecker.LinearOperatorKronecker(
+ [operator_1, operator_2], is_non_singular=False)
+
+ def test_name(self):
+ matrix = [[11., 0.], [1., 8.]]
+ operator_1 = linalg.LinearOperatorFullMatrix(matrix, name="left")
+ operator_2 = linalg.LinearOperatorFullMatrix(matrix, name="right")
+
+ operator = kronecker.LinearOperatorKronecker([operator_1, operator_2])
+
+ self.assertEqual("left_x_right", operator.name)
+
+ def test_different_dtypes_raises(self):
+ operators = [
+ linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3)),
+ linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3).astype(np.float32))
+ ]
+ with self.assertRaisesRegexp(TypeError, "same dtype"):
+ kronecker.LinearOperatorKronecker(operators)
+
+ def test_empty_or_one_operators_raises(self):
+ with self.assertRaisesRegexp(ValueError, ">=1 operators"):
+ kronecker.LinearOperatorKronecker([])
+
+
+if __name__ == "__main__":
+ test.main()
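The `_kronecker_dense` helper above builds the dense product with a broadcast-then-reshape trick rather than an explicit double loop. A NumPy sketch of the same trick, checked against `np.kron` (all names below are illustrative):

```python
import numpy as np

def kronecker_dense_np(factors):
    # Same broadcast-and-reshape trick as _kronecker_dense above:
    # [..., R1, 1, C1, 1] * [..., 1, R2, 1, C2] -> [..., R1, R2, C1, C2],
    # then merge (R1, R2) and (C1, C2) into the final matrix dimensions.
    product = factors[0]
    for factor in factors[1:]:
        product = (product[..., :, None, :, None] *
                   factor[..., None, :, None, :])
        r1, r2, c1, c2 = product.shape[-4:]
        product = product.reshape(product.shape[:-4] + (r1 * r2, c1 * c2))
    return product

x = np.array([[2., 3.], [1., 2.]])
y = np.array([[1., 2.], [5., -1.]])
assert np.allclose(kronecker_dense_np([x, y]), np.kron(x, y))
```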
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_kronecker.py b/tensorflow/contrib/linalg/python/ops/linear_operator_kronecker.py
new file mode 100644
index 0000000000..79080d194f
--- /dev/null
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_kronecker.py
@@ -0,0 +1,560 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Construct the Kronecker product of one or more `LinearOperators`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import common_shapes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.linalg import linalg_impl as linalg
+from tensorflow.python.ops.linalg import linear_operator
+
+
+def _vec(x):
+ """Stacks column of matrix to form a single column."""
+ return array_ops.reshape(
+ array_ops.matrix_transpose(x),
+ array_ops.concat(
+ [array_ops.shape(x)[:-2], [-1]], axis=0))
+
+
+def _unvec_by(y, num_col):
+ """Unstack vector to form a matrix, with a specified amount of columns."""
+ return array_ops.matrix_transpose(
+ array_ops.reshape(
+ y,
+ array_ops.concat(
+ [array_ops.shape(y)[:-1], [num_col, -1]], axis=0)))
+
+
+def _rotate_last_dim(x, rotate_right=False):
+ """Rotate the last dimension either left or right."""
+ ndims = array_ops.rank(x)
+ if rotate_right:
+ transpose_perm = array_ops.concat(
+ [[ndims - 1], math_ops.range(0, ndims - 1)], axis=0)
+ else:
+ transpose_perm = array_ops.concat(
+ [math_ops.range(1, ndims), [0]], axis=0)
+ return array_ops.transpose(x, transpose_perm)
+
+
+class LinearOperatorKronecker(linear_operator.LinearOperator):
+ """Kronecker product between two `LinearOperators`.
+
+ This operator composes one or more linear operators `[op1,...,opJ]`,
+ building a new `LinearOperator` representing the Kronecker product:
+ `op1 x op2 x ... x opJ` (we omit parentheses as the Kronecker product is
+ associative).
+
+ If `opj` has shape `batch_shape_j + [M_j, N_j]`, then the composed operator
+ will have shape equal to `broadcast_batch_shape + [prod M_j, prod N_j]`,
+ where the product is over all operators.
+
+ ```python
+ # Create a 4 x 4 linear operator composed of two 2 x 2 operators.
+ operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
+ operator_2 = LinearOperatorFullMatrix([[1., 0.], [2., 1.]])
+ operator = LinearOperatorKronecker([operator_1, operator_2])
+
+ operator.to_dense()
+ ==> [[1., 2., 0., 0.],
+ [3., 4., 0., 0.],
+ [2., 4., 1., 2.],
+ [6., 8., 3., 4.]]
+
+ operator.shape
+ ==> [4, 4]
+
+ operator.log_abs_determinant()
+ ==> scalar Tensor
+
+ x = ... Shape [4, 2] Tensor
+ operator.matmul(x)
+ ==> Shape [4, 2] Tensor
+
+ # Create a [2, 3] batch of 4 x 5 linear operators.
+ matrix_45 = tf.random_normal(shape=[2, 3, 4, 5])
+ operator_45 = LinearOperatorFullMatrix(matrix_45)
+
+ # Create a [2, 3] batch of 5 x 6 linear operators.
+ matrix_56 = tf.random_normal(shape=[2, 3, 5, 6])
+ operator_56 = LinearOperatorFullMatrix(matrix_56)
+
+ # Compose to create a [2, 3] batch of 20 x 30 operators.
+ operator_large = LinearOperatorKronecker([operator_45, operator_56])
+
+ # Create a shape [2, 3, 30, 2] vector.
+ x = tf.random_normal(shape=[2, 3, 30, 2])
+ operator_large.matmul(x)
+ ==> Shape [2, 3, 20, 2] Tensor
+ ```
+
+ #### Performance
+
+ The cost of an operation on `LinearOperatorKronecker` is roughly the sum of
+ the cost of that operation on each factor, which is typically far cheaper
+ than performing the operation on the materialized dense product.
+
+ #### Matrix property hints
+
+ This `LinearOperator` is initialized with boolean flags of the form `is_X`,
+ for `X = non_singular, self_adjoint, positive_definite, square`.
+ These have the following meaning:
+
+ * If `is_X == True`, callers should expect the operator to have the
+ property `X`. This is a promise that should be fulfilled, but is *not* a
+ runtime assert. For example, finite floating point precision may result
+ in these promises being violated.
+ * If `is_X == False`, callers should expect the operator to not have `X`.
+ * If `is_X == None` (the default), callers should have no expectation either
+ way.
+ """
+
+ def __init__(self,
+ operators,
+ is_non_singular=None,
+ is_self_adjoint=None,
+ is_positive_definite=None,
+ is_square=None,
+ name=None):
+ r"""Initialize a `LinearOperatorKronecker`.
+
+ `LinearOperatorKronecker` is initialized with a list of operators
+ `[op_1,...,op_J]`.
+
+ Args:
+ operators: Iterable of `LinearOperator` objects, each with
+ the same `dtype` and composable shape, representing the Kronecker
+ factors.
+ is_non_singular: Expect that this operator is non-singular.
+ is_self_adjoint: Expect that this operator is equal to its hermitian
+ transpose.
+ is_positive_definite: Expect that this operator is positive definite,
+ meaning the quadratic form `x^H A x` has positive real part for all
+ nonzero `x`. Note that we do not require the operator to be
+ self-adjoint to be positive-definite. See:
+ https://en.wikipedia.org/wiki/Positive-definite_matrix\
+ #Extension_for_non_symmetric_matrices
+ is_square: Expect that this operator acts like square [batch] matrices.
+ name: A name for this `LinearOperator`. Default is the individual
+ operator names joined with `_x_`.
+
+ Raises:
+ TypeError: If all operators do not have the same `dtype`.
+ ValueError: If `operators` is empty.
+ """
+ # Validate operators.
+ check_ops.assert_proper_iterable(operators)
+ operators = list(operators)
+ if not operators:
+ raise ValueError(
+ "Expected a list of >=1 operators. Found: %s" % operators)
+ self._operators = operators
+
+ # Validate dtype.
+ dtype = operators[0].dtype
+ for operator in operators:
+ if operator.dtype != dtype:
+ name_type = (str((o.name, o.dtype)) for o in operators)
+ raise TypeError(
+ "Expected all operators to have the same dtype. Found %s"
+ % " ".join(name_type))
+
+ # Auto-set and check hints.
+ # A Kronecker product is invertible, if and only if all factors are
+ # invertible.
+ if all(operator.is_non_singular for operator in operators):
+ if is_non_singular is False:
+ raise ValueError(
+ "The Kronecker product of non-singular operators is always "
+ "non-singular.")
+ is_non_singular = True
+
+ if all(operator.is_self_adjoint for operator in operators):
+ if is_self_adjoint is False:
+ raise ValueError(
+ "The Kronecker product of self-adjoint operators is always "
+ "self-adjoint.")
+ is_self_adjoint = True
+
+ # The eigenvalues of a Kronecker product are equal to the products of eigen
+ # values of the corresponding factors.
+ if all(operator.is_positive_definite for operator in operators):
+ if is_positive_definite is False:
+ raise ValueError("The Kronecker product of positive-definite operators "
+ "is always positive-definite.")
+ is_positive_definite = True
+
+ # Initialization.
+ graph_parents = []
+ for operator in operators:
+ graph_parents.extend(operator.graph_parents)
+
+ if name is None:
+ name = operators[0].name
+ for operator in operators[1:]:
+ name += "_x_" + operator.name
+ with ops.name_scope(name, values=graph_parents):
+ super(LinearOperatorKronecker, self).__init__(
+ dtype=dtype,
+ graph_parents=graph_parents,
+ is_non_singular=is_non_singular,
+ is_self_adjoint=is_self_adjoint,
+ is_positive_definite=is_positive_definite,
+ is_square=is_square,
+ name=name)
+
+ @property
+ def operators(self):
+ return self._operators
+
+ def _shape(self):
+ # Get final matrix shape.
+ domain_dimension = self.operators[0].domain_dimension
+ for operator in self.operators[1:]:
+ domain_dimension *= operator.domain_dimension
+
+ range_dimension = self.operators[0].range_dimension
+ for operator in self.operators[1:]:
+ range_dimension *= operator.range_dimension
+
+ matrix_shape = tensor_shape.TensorShape([
+ range_dimension, domain_dimension])
+
+ # Get broadcast batch shape.
+ # broadcast_shape checks for compatibility.
+ batch_shape = self.operators[0].batch_shape
+ for operator in self.operators[1:]:
+ batch_shape = common_shapes.broadcast_shape(
+ batch_shape, operator.batch_shape)
+
+ return batch_shape.concatenate(matrix_shape)
+
+ def _shape_tensor(self):
+ domain_dimension = self.operators[0].domain_dimension_tensor()
+ for operator in self.operators[1:]:
+ domain_dimension *= operator.domain_dimension_tensor()
+
+ range_dimension = self.operators[0].range_dimension_tensor()
+ for operator in self.operators[1:]:
+ range_dimension *= operator.range_dimension_tensor()
+
+ matrix_shape = [range_dimension, domain_dimension]
+
+ # Get broadcast batch shape.
+ # broadcast_shape checks for compatibility.
+ batch_shape = self.operators[0].batch_shape_tensor()
+ for operator in self.operators[1:]:
+ batch_shape = array_ops.broadcast_dynamic_shape(
+ batch_shape, operator.batch_shape_tensor())
+
+ return array_ops.concat((batch_shape, matrix_shape), 0)
+
+ def _matmul(self, x, adjoint=False, adjoint_arg=False):
+ # Here we heavily rely on Roth's column Lemma [1]:
+ # (A x B) * vec X = vec BXA^T,
+ # where vec stacks all the columns of the matrix under each other. In our
+ # case, x represents a batch of vec X (i.e. we think of x as a batch of
+ # column vectors, rather than a matrix). Each member of the batch can be
+ # reshaped to a matrix (hence we get a batch of matrices).
+ # We can iteratively apply this lemma by noting that if B is a Kronecker
+ # product, then we can apply the lemma again.
+
+ # [1] W. E. Roth, "On direct product matrices,"
+ # Bulletin of the American Mathematical Society, vol. 40, pp. 461-468,
+ # 1934
+
+ # Efficiency
+
+ # Naively doing the Kronecker product, by calculating the dense matrix and
+ # applying it can take cubic time in the size of domain_dimension
+ # (assuming a square matrix). The other issue is that calculating the dense
+ # matrix can be prohibitively expensive, in that it can take a large amount
+ # of memory.
+ #
+ # This implementation avoids this memory blow up by only computing matmuls
+ # with the factors. In this way, we don't have to realize the dense matrix.
+ # In terms of complexity, if we have Kronecker Factors of size:
+ # (n1, n1), (n2, n2), (n3, n3), ... (nJ, nJ), with N = \prod n_i, and we
+ # have as input a [N, M] matrix, the naive approach would take O(N^2 M).
+ # With this approach (ignoring reshaping of tensors and transposes for now),
+ # the time complexity can be O(M * (\sum n_i) * N). There is also the
+ # benefit of batched multiplication (in this example, the batch size is
+ # roughly M * N), so this can be much faster. However, this does not factor
+ # in the cost of the several tensor transposes, which can affect cache
+ # behavior.
+
+ # Below we document the shape manipulation for adjoint=False,
+ # adjoint_arg=False, but the general case of different adjoints is still
+ # handled.
+
+ if adjoint_arg:
+ x = linalg.adjoint(x)
+
+ # Always add a batch dimension to enable broadcasting to work.
+ batch_shape = array_ops.concat(
+ [array_ops.ones_like(self.batch_shape_tensor()), [1, 1]], 0)
+ x += array_ops.zeros(batch_shape, dtype=x.dtype.base_dtype)
+
+ # x has shape [B, R, C], where B represents some number of batch dimensions,
+ # R represents the number of rows, and C represents the number of columns.
+ # In order to apply Roth's column lemma, we need to operate on a batch of
+ # column vectors, so we reshape into a batch of column vectors. We put it
+ # at the front to ensure that broadcasting between operators to the batch
+ # dimensions B still works.
+ output = _rotate_last_dim(x, rotate_right=True)
+
+ # Also expand the shape to be [A, C, B, R]. The first dimension will be
+ # used to accumulate dimensions from each operator matmul.
+ output = output[array_ops.newaxis, ...]
+
+ # In this loop, A is going to refer to the value of the accumulated
+ # dimension. A = 1 at the start, and will end up being self.range_dimension.
+ # V will refer to the last dimension. V = R at the start, and will end up
+ # being 1 in the end.
+ for operator in self.operators[:-1]:
+ # Reshape output from [A, C, B, V] to be
+ # [A, C, B, V / op.domain_dimension, op.domain_dimension]
+ if adjoint:
+ operator_dimension = operator.range_dimension_tensor()
+ else:
+ operator_dimension = operator.domain_dimension_tensor()
+
+ output = _unvec_by(output, operator_dimension)
+
+ # We are computing (XA^T) = (AX^T)^T.
+ # output has [A, C, B, V / op.domain_dimension, op.domain_dimension],
+ # which is being converted to:
+ # [A, C, B, V / op.domain_dimension, op.range_dimension]
+ output = array_ops.matrix_transpose(output)
+ output = operator.matmul(output, adjoint=adjoint, adjoint_arg=False)
+ output = array_ops.matrix_transpose(output)
+ # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
+ output = _rotate_last_dim(output, rotate_right=False)
+ output = _vec(output)
+ output = _rotate_last_dim(output, rotate_right=True)
+
+ # After the loop, we will have
+ # A = self.range_dimension / op[-1].range_dimension
+ # V = op[-1].domain_dimension
+
+ # We convert that using matvec to get:
+ # [A, C, B, op[-1].range_dimension]
+ output = self.operators[-1].matvec(output, adjoint=adjoint)
+ # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
+ output = _rotate_last_dim(output, rotate_right=False)
+ output = _vec(output)
+ output = _rotate_last_dim(output, rotate_right=False)
+
+ if x.shape.is_fully_defined():
+ column_dim = x.shape[-1]
+ broadcast_batch_shape = common_shapes.broadcast_shape(
+ x.shape[:-2], self.batch_shape)
+ if adjoint:
+ matrix_dimensions = [self.domain_dimension, column_dim]
+ else:
+ matrix_dimensions = [self.range_dimension, column_dim]
+
+ print("x: ", x)
+ print("bathc_shape:", self.batch_shape)
+ print("self.shape:", self.shape)
+ print("output: ", output)
+ output.set_shape(broadcast_batch_shape.concatenate(
+ matrix_dimensions))
+
+ return output
+
+ def _determinant(self):
+ # Note that we have |X1 x X2| = |X1| ** n * |X2| ** m, where X1 is an m x m
+ # matrix, and X2 is an n x n matrix. We can iteratively apply this property
+ # to get the determinant of |X1 x X2 x X3 ...|. If T is the product of the
+ # domain dimension of all operators, then we have:
+ # |X1 x X2 x X3 ...| =
+ # |X1| ** (T / m) * |X2 x X3 ... | ** m =
+ # |X1| ** (T / m) * |X2| ** (m * (T / m) / n) * ... =
+ # |X1| ** (T / m) * |X2| ** (T / n) * | X3 x X4... | ** (m * n)
+ # And by doing induction we have product(|X_i| ** (T / dim(X_i))).
+ total = self.domain_dimension_tensor()
+ determinant = 1.
+ for operator in self.operators:
+ determinant *= operator.determinant() ** math_ops.cast(
+ total / operator.domain_dimension_tensor(),
+ dtype=operator.dtype)
+ return determinant
+
+ def _log_abs_determinant(self):
+ # This will be sum((total / dim(x_i)) * log |X_i|)
+ total = self.domain_dimension_tensor()
+ log_abs_det = 0.
+ for operator in self.operators:
+ log_abs_det += operator.log_abs_determinant() * math_ops.cast(
+ total / operator.domain_dimension_tensor(),
+ dtype=operator.dtype)
+ return log_abs_det
+
+ def _trace(self):
+ # tr(A x B) = tr(A) * tr(B)
+ trace = 1.
+ for operator in self.operators:
+ trace *= operator.trace()
+ return trace
+
+ def _solve(self, rhs, adjoint=False, adjoint_arg=False):
+ # Here we follow the same use of Roth's column lemma as in `matmul`, with
+ # the key difference that we replace all `matmul` instances with `solve`.
+ # This follows from the property that inv(A x B) = inv(A) x inv(B).
+
+ # Below we document the shape manipulation for adjoint=False,
+ # adjoint_arg=False, but the general case of different adjoints is still
+ # handled.
+
+ if adjoint_arg:
+ rhs = linalg.adjoint(rhs)
+
+ # Always add a batch dimension to enable broadcasting to work.
+ batch_shape = array_ops.concat(
+ [array_ops.ones_like(self.batch_shape_tensor()), [1, 1]], 0)
+ rhs += array_ops.zeros(batch_shape, dtype=rhs.dtype.base_dtype)
+
+ # rhs has shape [B, R, C], where B represents some number of batch
+ # dimensions, R represents the number of rows, and C represents the number
+ # of columns.
+ # In order to apply Roth's column lemma, we need to operate on a batch of
+ # column vectors, so we reshape into a batch of column vectors. We put it
+ # at the front to ensure that broadcasting between operators to the batch
+ # dimensions B still works.
+ output = _rotate_last_dim(rhs, rotate_right=True)
+
+ # Also expand the shape to be [A, C, B, R]. The first dimension will be
+ # used to accumulate dimensions from each operator matmul.
+ output = output[array_ops.newaxis, ...]
+
+ # In this loop, A is going to refer to the value of the accumulated
+ # dimension. A = 1 at the start, and will end up being self.range_dimension.
+ # V will refer to the last dimension. V = R at the start, and will end up
+ # being 1 in the end.
+ for operator in self.operators[:-1]:
+ # Reshape output from [A, C, B, V] to be
+ # [A, C, B, V / op.domain_dimension, op.domain_dimension]
+ if adjoint:
+ operator_dimension = operator.range_dimension_tensor()
+ else:
+ operator_dimension = operator.domain_dimension_tensor()
+
+ output = _unvec_by(output, operator_dimension)
+
+ # We are computing (X A^-T) = (A^-1 X^T)^T.
+ # output has [A, C, B, V / op.domain_dimension, op.domain_dimension],
+ # which is being converted to:
+ # [A, C, B, V / op.domain_dimension, op.range_dimension]
+ output = array_ops.matrix_transpose(output)
+ output = operator.solve(output, adjoint=adjoint, adjoint_arg=False)
+ output = array_ops.matrix_transpose(output)
+ # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
+ output = _rotate_last_dim(output, rotate_right=False)
+ output = _vec(output)
+ output = _rotate_last_dim(output, rotate_right=True)
+
+ # After the loop, we will have
+ # A = self.range_dimension / op[-1].range_dimension
+ # V = op[-1].domain_dimension
+
+ # We convert that using matvec to get:
+ # [A, C, B, op[-1].range_dimension]
+ output = self.operators[-1].solvevec(output, adjoint=adjoint)
+ # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
+ output = _rotate_last_dim(output, rotate_right=False)
+ output = _vec(output)
+ output = _rotate_last_dim(output, rotate_right=False)
+
+ if rhs.shape.is_fully_defined():
+ column_dim = rhs.shape[-1]
+ broadcast_batch_shape = common_shapes.broadcast_shape(
+ rhs.shape[:-2], self.batch_shape)
+ if adjoint:
+ matrix_dimensions = [self.domain_dimension, column_dim]
+ else:
+ matrix_dimensions = [self.range_dimension, column_dim]
+
+ output.set_shape(broadcast_batch_shape.concatenate(
+ matrix_dimensions))
+
+ return output
+
+ def _diag_part(self):
+ diag_part = self.operators[0].diag_part()
+ for operator in self.operators[1:]:
+ diag_part = diag_part[..., :, array_ops.newaxis]
+ op_diag_part = operator.diag_part()[..., array_ops.newaxis, :]
+ diag_part *= op_diag_part
+ diag_part = array_ops.reshape(
+ diag_part,
+ shape=array_ops.concat(
+ [array_ops.shape(diag_part)[:-2], [-1]], axis=0))
+ if self.range_dimension > self.domain_dimension:
+ diag_dimension = self.domain_dimension
+ else:
+ diag_dimension = self.range_dimension
+ diag_part.set_shape(
+ self.batch_shape.concatenate(diag_dimension))
+ return diag_part
+
+ def _to_dense(self):
+ product = self.operators[0].to_dense()
+ for operator in self.operators[1:]:
+ # Expand product to shape [B, R1, 1, C1, 1].
+ product = product[
+ ..., :, array_ops.newaxis, :, array_ops.newaxis]
+ # Operator has shape [B, 1, R2, 1, C2].
+ op_to_mul = operator.to_dense()[
+ ..., array_ops.newaxis, :, array_ops.newaxis, :]
+ # This is now [B, R1, R2, C1, C2].
+ product *= op_to_mul
+ # Now merge together dimensions to get [B, R1 * R2, C1 * C2].
+ product = array_ops.reshape(
+ product,
+ shape=array_ops.concat(
+ [array_ops.shape(product)[:-4],
+ [array_ops.shape(product)[-4] * array_ops.shape(product)[-3],
+ array_ops.shape(product)[-2] * array_ops.shape(product)[-1]]
+ ], axis=0))
+ product.set_shape(self.shape)
+ return product
+
+ def _assert_non_singular(self):
+ if all(operator.is_square for operator in self.operators):
+ asserts = [operator.assert_non_singular() for operator in self.operators]
+ return control_flow_ops.group(asserts)
+ else:
+ raise errors.InvalidArgumentError(
+ node_def=None, op=None, message="All Kronecker factors must be "
+ "square for the product to be invertible.")
+
+ def _assert_self_adjoint(self):
+ if all(operator.is_square for operator in self.operators):
+ asserts = [operator.assert_self_adjoint() for operator in self.operators]
+ return control_flow_ops.group(asserts)
+ else:
+ raise errors.InvalidArgumentError(
+ node_def=None, op=None, message="All Kronecker factors must be "
+ "square for the product to be self adjoint.")
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index ac50699f59..6e6c812adc 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -105,11 +105,13 @@ def make_example_dict(example_protos, example_weights):
def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero):
random.seed(1)
+
sparse_features = [
SparseFeatureColumn(
- [int(i / num_non_zero) for i in range(num_examples * num_non_zero)],
- [int(random.random() * dim) for _ in range(
- num_examples * num_non_zero)],
+ [i for i in range(num_examples) for _ in range(num_non_zero)], [
+ i for _ in range(num_examples)
+ for i in random.sample(range(dim), num_non_zero)
+ ],
[num_non_zero**(-0.5) for _ in range(num_examples * num_non_zero)])
]
examples_dict = dict(
@@ -289,6 +291,34 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
# It would be 0.01 without shuffling and 0.02 with adaptive sampling.
self.assertNear(0.0, lr.approximate_duality_gap().eval(), err=1e-3)
+ def testSparseDuplicate(self):
+ # Setup test data
+ example_protos = [
+ make_example_proto({
+ 'age': [0] * 5,
+ 'gender': [0] * 5
+ }, 0),
+ make_example_proto({
+ 'age': [1] * 5,
+ 'gender': [1] * 5
+ }, 1),
+ ]
+ example_weights = [1.0, 1.0]
+ with self._single_threaded_test_session():
+ examples = make_example_dict(example_protos, example_weights)
+ variables = make_variable_dict(1, 1)
+ options = dict(
+ symmetric_l2_regularization=1,
+ symmetric_l1_regularization=0,
+ loss_type='logistic_loss')
+
+ lr = SdcaModel(examples, variables, options)
+ variables_lib.global_variables_initializer().run()
+ train_op = lr.minimize()
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ 'Duplicate'):
+ train_op.run()
+
def testDistributedSimple(self):
# Setup test data
example_protos = [
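The rewritten generator above uses `random.sample` so each example's sparse feature indices are distinct, matching the new `testSparseDuplicate` expectation that duplicate indices raise an error. A small standalone sketch of the index layout it produces (the `num_examples`, `dim`, and `num_non_zero` values below are arbitrary):

```python
import random

random.seed(1)
num_examples, dim, num_non_zero = 2, 6, 3
example_ids = [i for i in range(num_examples) for _ in range(num_non_zero)]
# random.sample draws without replacement, so no duplicates per example.
feature_ids = [i for _ in range(num_examples)
               for i in random.sample(range(dim), num_non_zero)]
print(example_ids)   # [0, 0, 0, 1, 1, 1]
print(feature_ids)   # three distinct feature ids per example
```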
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index 9c4533079c..1534f97d76 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -137,6 +137,7 @@ cc_library(
"//tensorflow/contrib/lite/kernels:eigen_support",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/nnapi:nnapi_lib",
+ "//tensorflow/contrib/lite/profiling:profiler",
"//tensorflow/contrib/lite/schema:schema_fbs",
],
)
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index f258654608..91b6c414bf 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -14,10 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/interpreter.h"
+
#include <cassert>
#include <cstdarg>
#include <cstdint>
#include <cstring>
+
#include "tensorflow/contrib/lite/arena_planner.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/error_reporter.h"
@@ -26,6 +28,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/nnapi_delegate.h"
+#include "tensorflow/contrib/lite/profiling/profiler.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/util.h"
@@ -245,11 +248,8 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
// Initialize the output tensors's delegate-related fields.
for (int tensor_index : subgraph.output_tensors) {
TfLiteTensor* tensor = &tensors_[tensor_index];
- TF_LITE_ENSURE_EQ(&context_, tensor->delegate, nullptr);
- TF_LITE_ENSURE_EQ(&context_, tensor->buffer_handle,
- kTfLiteNullBufferHandle);
- // buffer_handle will be filled in delegate's `Prepare`
- // function.
+ TF_LITE_ENSURE(&context_, tensor->delegate == nullptr ||
+ tensor->delegate == delegate);
tensor->delegate = delegate;
}
@@ -547,6 +547,7 @@ TfLiteStatus Interpreter::Invoke() {
TfLiteNode& node = nodes_and_registration_[node_index].first;
const TfLiteRegistration& registration =
nodes_and_registration_[node_index].second;
+ SCOPED_OPERATOR_PROFILE(profiler_, node_index);
// TODO(ycling): This is an extra loop through inputs to check if the data
// need to be copied from Delegate buffer to raw memory, which is often not
@@ -570,6 +571,12 @@ TfLiteStatus Interpreter::Invoke() {
}
}
+ if (!allow_buffer_handle_output_) {
+ for (int tensor_index : outputs_) {
+ EnsureTensorDataIsReadable(tensor_index);
+ }
+ }
+
return status;
}
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index df67cce9de..a49134b95e 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -20,10 +20,12 @@ limitations under the License.
#include <cstdio>
#include <cstdlib>
#include <vector>
+
#include "tensorflow/contrib/lite/allocation.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/memory_planner.h"
+#include "tensorflow/contrib/lite/profiling/profiler.h"
namespace tflite {
@@ -282,6 +284,7 @@ class Interpreter {
// Ensure the data in `tensor.data` is readable. In case delegate is used,
// it might require to copy the data from delegate buffer to raw memory.
+ // WARNING: This is an experimental API and subject to change.
TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) {
TF_LITE_ENSURE(&context_, tensor_index < tensors_size());
TfLiteTensor* tensor = &tensors_[tensor_index];
@@ -320,6 +323,12 @@ class Interpreter {
TfLiteBufferHandle* buffer_handle,
TfLiteDelegate** delegate);
+ void SetProfiler(profiling::Profiler* profiler) { profiler_ = profiler; }
+
+ profiling::Profiler* GetProfiler() { return profiler_; }
+
// The default capacity of `tensors_` vector.
static constexpr int kTensorsReservedCapacity = 128;
// The capacity headroom of `tensors_` vector before calling ops'
@@ -328,6 +337,18 @@ class Interpreter {
// pointers to existing tensors.
static constexpr int kTensorsCapacityHeadroom = 16;
+ // Set if buffer handle output is allowed.
+ //
+ // When using hardware delegation, Interpreter will make the data of output
+ // tensors available in `tensor->data` by default. If the application can
+ // consume the buffer handle directly (e.g. reading output from OpenGL
+ // texture), it can set this flag to true, so Interpreter won't copy the data
+ // from buffer handle to CPU memory.
+ // WARNING: This is an experimental API and subject to change.
+ void SetAllowBufferHandleOutput(bool allow_buffer_handle_output) {
+ allow_buffer_handle_output_ = allow_buffer_handle_output;
+ }
+
private:
// Give 'op_reg' a chance to initialize itself using the contents of
// 'buffer'.
@@ -518,6 +539,11 @@ class Interpreter {
std::unique_ptr<NNAPIDelegate> nnapi_delegate_;
std::unique_ptr<MemoryPlanner> memory_planner_;
+
+ bool allow_buffer_handle_output_ = false;
+
+ // Profiler for this interpreter instance.
+ profiling::Profiler* profiler_ = nullptr;
};
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index ac7c3f071f..8cfa7e53d1 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -825,8 +825,7 @@ tf_cc_test(
"comparisons_test.cc",
],
tags = [
- "tflite_not_portable_ios_arm64",
- "tflite_not_portable_ios_x86_64",
+ "tflite_not_portable_ios",
],
deps = [
":builtin_ops",
diff --git a/tensorflow/contrib/lite/profiling/BUILD b/tensorflow/contrib/lite/profiling/BUILD
new file mode 100644
index 0000000000..15999e5d41
--- /dev/null
+++ b/tensorflow/contrib/lite/profiling/BUILD
@@ -0,0 +1,44 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+common_copts = [
+ "-Wall",
+]
+
+cc_library(
+ name = "profiler",
+ hdrs = ["profiler.h"],
+ copts = common_copts,
+ deps = [":profile_buffer"],
+)
+
+cc_test(
+ name = "profiler_test",
+ srcs = ["profiler_test.cc"],
+ copts = ["-DTFLITE_PROFILING_ENABLED"],
+ defines = ["TFLITE_PROFILING_ENABLED"],
+ deps = [
+ ":profiler",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "profile_buffer",
+ hdrs = ["profile_buffer.h"],
+ copts = common_copts,
+)
+
+cc_test(
+ name = "profile_buffer_test",
+ srcs = ["profile_buffer_test.cc"],
+ copts = ["-DTFLITE_PROFILING_ENABLED"],
+ defines = ["TFLITE_PROFILING_ENABLED"],
+ deps = [
+ ":profile_buffer",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/tensorflow/contrib/lite/profiling/profile_buffer.h b/tensorflow/contrib/lite/profiling/profile_buffer.h
new file mode 100644
index 0000000000..3bfe02571b
--- /dev/null
+++ b/tensorflow/contrib/lite/profiling/profile_buffer.h
@@ -0,0 +1,150 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_BUFFER_H_
+#define TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_BUFFER_H_
+
+#include <cstddef>
+#include <cstdint>
+
+namespace tflite {
+namespace profiling {
+
+// A profiling event.
+struct ProfileEvent {
+ // Describes the type of event.
+ // The event_metadata field may contain additional data for interpreting
+ // the event.
+ enum class EventType {
+ // Default event type, the metadata field has no special significance.
+ DEFAULT = 0,
+ // The event is an operator invocation and the event_metadata field is the
+ // index of operator node.
+ OPERATOR_INVOKE_EVENT = 1
+ };
+
+ // Label of the event. This usually describes the event.
+ const char* tag;
+ // Timestamp in microseconds when the event began.
+ int64_t begin_timestamp_ms;
+ // Timestamp in microseconds when the event ended.
+ int64_t end_timestamp_ms;
+ // The field containing the type of event. This must be one of the event types
+ // in EventType.
+ EventType event_type;
+ // Extra data describing the details of the event.
+ uint32_t event_metadata;
+};
+} // namespace profiling
+} // namespace tflite
+
+#ifdef TFLITE_PROFILING_ENABLED
+
+#include <sys/time.h>
+#include <vector>
+
+namespace tflite {
+namespace profiling {
+constexpr uint32_t kInvalidEventHandle = static_cast<uint32_t>(~0) - 1;
+
+// A ring buffer of profile events.
+// This class is not thread safe.
+class ProfileBuffer {
+ public:
+ ProfileBuffer(uint32_t max_num_entries, bool enabled)
+ : enabled_(enabled), current_index_(0), event_buffer_(max_num_entries) {}
+
+ // Adds an event to the buffer with its begin timestamp set to the current
+ // time. Returns a handle to the event that can be used to call EndEvent. If
+ // the buffer is disabled this has no effect.
+ // The tag of the event must remain valid for the lifetime of the buffer.
+ uint32_t BeginEvent(const char* tag, ProfileEvent::EventType event_type,
+ uint32_t event_metadata) {
+ if (!enabled_) {
+ return kInvalidEventHandle;
+ }
+ int64_t timestamp = NowMicros();
+ int index = current_index_ % event_buffer_.size();
+ event_buffer_[index].tag = tag;
+ event_buffer_[index].event_type = event_type;
+ event_buffer_[index].event_metadata = event_metadata;
+ event_buffer_[index].begin_timestamp_ms = timestamp;
+ event_buffer_[index].end_timestamp_ms = 0;
+ current_index_++;
+ return index;
+ }
+
+ // Sets the enabled state of buffer to |enabled|
+ void SetEnabled(bool enabled) { enabled_ = enabled; }
+
+ // Sets the end timestamp of the event for the given handle to the current
+ // time. If the buffer is disabled or the event has been overwritten, this
+ // operation has no effect.
+ void EndEvent(uint32_t event_handle) {
+ if (!enabled_ || event_handle == kInvalidEventHandle ||
+ event_handle > current_index_) {
+ return;
+ }
+ const uint32_t max_size = event_buffer_.size();
+ if (current_index_ > (max_size + event_handle)) {
+ // Ignore, buffer has already overflowed.
+ return;
+ }
+
+ int event_index = event_handle % max_size;
+ event_buffer_[event_index].end_timestamp_ms = NowMicros();
+ }
+
+ // Returns the size of the buffer.
+ size_t Size() const {
+ return (current_index_ >= event_buffer_.size()) ? event_buffer_.size()
+ : current_index_;
+ }
+
+ // Resets the buffer.
+ void Reset() {
+ enabled_ = false;
+ current_index_ = 0;
+ }
+
+ // Returns the profile event at the given index. If the index is invalid,
+ // nullptr is returned. The returned event may get overwritten if more
+ // events are added to the buffer.
+ const struct ProfileEvent* const At(int index) const {
+ size_t size = Size();
+ if (index >= size) {
+ return nullptr;
+ }
+ const uint32_t max_size = event_buffer_.size();
+ uint32_t start =
+ (current_index_ > max_size) ? current_index_ % max_size : max_size;
+ index = (index + start) % max_size;
+ return &event_buffer_[index];
+ }
+
+ private:
+ static int64_t NowMicros() {
+ // TODO(shashishekhar): Refactor this to a separate file.
+ struct timeval tv;
+ gettimeofday(&tv, nullptr);
+ return static_cast<uint64_t>(tv.tv_sec) * 1000000 + tv.tv_usec;
+ }
+ bool enabled_;
+ uint32_t current_index_;
+ std::vector<ProfileEvent> event_buffer_;
+};
+} // namespace profiling
+} // namespace tflite
+#endif // TFLITE_PROFILING_ENABLED
+#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_BUFFER_H_
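The ring-buffer logic in `ProfileBuffer` (wrapping handles with `%`, rejecting stale handles after overflow, and remapping logical indices in `At`) is easiest to see stripped of the C++ scaffolding. A Python sketch of the same semantics, not the actual API:

```python
class RingBufferSketch:
    # Python sketch of ProfileBuffer's ring semantics (max_entries > 0).
    def __init__(self, max_entries):
        self.events = [None] * max_entries
        self.current = 0  # monotonically increasing event handle

    def begin_event(self, tag):
        handle = self.current
        self.events[handle % len(self.events)] = {"tag": tag, "ended": False}
        self.current += 1
        return handle

    def end_event(self, handle):
        # A handle is stale once the ring has wrapped all the way past it.
        if self.current > handle + len(self.events):
            return
        self.events[handle % len(self.events)]["ended"] = True

    def at(self, index):
        # Map a logical index (0 = oldest surviving event) to a ring slot.
        size = min(self.current, len(self.events))
        if index >= size:
            return None
        wrapped = self.current > len(self.events)
        start = self.current % len(self.events) if wrapped else 0
        return self.events[(index + start) % len(self.events)]
```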
diff --git a/tensorflow/contrib/lite/profiling/profile_buffer_test.cc b/tensorflow/contrib/lite/profiling/profile_buffer_test.cc
new file mode 100644
index 0000000000..0c5f0cd314
--- /dev/null
+++ b/tensorflow/contrib/lite/profiling/profile_buffer_test.cc
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/profiling/profile_buffer.h"
+#include "tensorflow/contrib/lite/testing/util.h"
+
+namespace tflite {
+namespace profiling {
+
+namespace {
+
+std::vector<const ProfileEvent*> GetProfileEvents(const ProfileBuffer& buffer) {
+ std::vector<const ProfileEvent*> events;
+ for (auto i = 0; i < buffer.Size(); i++) {
+ events.push_back(buffer.At(i));
+ }
+ return events;
+}
+
+TEST(ProfileBufferTest, Empty) {
+ ProfileBuffer buffer(/*max_size*/ 0, /*enabled*/ true);
+ EXPECT_EQ(0, buffer.Size());
+}
+
+TEST(ProfileBufferTest, AddEvent) {
+ ProfileBuffer buffer(/*max_size*/ 10, /*enabled*/ true);
+ EXPECT_EQ(0, buffer.Size());
+ auto event_handle = buffer.BeginEvent(
+ "hello", ProfileEvent::EventType::DEFAULT, /* event_metadata */ 42);
+
+ EXPECT_GE(event_handle, 0);
+ EXPECT_EQ(1, buffer.Size());
+
+ auto event = GetProfileEvents(buffer)[0];
+ EXPECT_EQ(event->tag, "hello");
+ EXPECT_GT(event->begin_timestamp_ms, 0);
+ EXPECT_EQ(event->event_type, ProfileEvent::EventType::DEFAULT);
+ EXPECT_EQ(event->event_metadata, 42);
+
+ buffer.EndEvent(event_handle);
+ EXPECT_EQ(1, buffer.Size());
+ EXPECT_GE(event->end_timestamp_ms, event->begin_timestamp_ms);
+}
+
+TEST(ProfileBufferTest, OverFlow) {
+ const int max_size = 4;
+ ProfileBuffer buffer{max_size, true};
+ std::vector<std::string> eventNames = {"first", "second", "third", "fourth"};
+ for (int i = 0; i < 2 * max_size; i++) {
+ buffer.BeginEvent(eventNames[i % 4].c_str(),
+ ProfileEvent::EventType::DEFAULT, i);
+ size_t expected_size = std::min(i + 1, max_size);
+ EXPECT_EQ(expected_size, buffer.Size());
+ }
+ EXPECT_EQ(max_size, buffer.Size());
+ for (int j = 0; j < buffer.Size(); ++j) {
+ auto event = buffer.At(j);
+ EXPECT_EQ(eventNames[j % 4], event->tag);
+ EXPECT_EQ(ProfileEvent::EventType::DEFAULT, event->event_type);
+ EXPECT_EQ(4 + j, event->event_metadata);
+ }
+}
+
+TEST(ProfileBufferTest, Enable) {
+ ProfileBuffer buffer(/*max_size*/ 10, /*enabled*/ false);
+ EXPECT_EQ(0, buffer.Size());
+ auto event_handle = buffer.BeginEvent(
+ "hello", ProfileEvent::EventType::DEFAULT, /* event_metadata */ 42);
+ EXPECT_EQ(kInvalidEventHandle, event_handle);
+ EXPECT_EQ(0, buffer.Size());
+ buffer.SetEnabled(true);
+ event_handle = buffer.BeginEvent("hello", ProfileEvent::EventType::DEFAULT,
+ /* event_metadata */ 42);
+ EXPECT_GE(event_handle, 0);
+ EXPECT_EQ(1, buffer.Size());
+}
+
+} // namespace
+} // namespace profiling
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/profiling/profiler.h b/tensorflow/contrib/lite/profiling/profiler.h
new file mode 100644
index 0000000000..dfa98a6708
--- /dev/null
+++ b/tensorflow/contrib/lite/profiling/profiler.h
@@ -0,0 +1,174 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILER_H_
+#define TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILER_H_
+
+#include <vector>
+
+#include "tensorflow/contrib/lite/profiling/profile_buffer.h"
+
+#ifdef TFLITE_PROFILING_ENABLED
+
+namespace tflite {
+namespace profiling {
+class ScopedProfile;
+class ScopedOperatorProfile;
+
+// Controls whether profiling is enabled or disabled and collects profiles.
+// TFLite is used on platforms that don't have posix threads, so the profiler is
+// kept as simple as possible. It is designed to be used only on a single
+// thread.
+//
+// Profiles are collected using Scoped*Profile objects that begin and end a
+// profile event.
+// Example usage is shown below.
+//
+// Say a Worker class has a DoWork method, and we are interested in profiling
+// the overall execution time of DoWork as well as the time spent in the
+// Task1 and Task2 functions. We instrument each function that needs to be
+// profiled with a named ScopedProfile object:
+//
+// class Worker {
+//  public:
+//   void DoWork() {
+//     ScopedProfile profile(&profiler, "DoWork");
+//     Task1();
+//     Task2();
+//     .....
+//   }
+//
+//   void Task1() {
+//     ScopedProfile profile(&profiler, "Task1");
+//     ....
+//   }
+//
+//   void Task2() {
+//     ScopedProfile profile(&profiler, "Task2");
+//     ....
+//   }
+//
+//   Profiler profiler;
+// };
+//
+// Note that each ScopedProfile must be a named local so that it lives until
+// the end of the enclosing scope; a temporary would end its event
+// immediately.
+//
+// Profiles are collected by enabling profiling and then retrieving the
+// profile events:
+//
+// void ProfileWorker() {
+//   Worker worker;
+//   worker.profiler.StartProfiling();
+//   worker.DoWork();
+//   worker.profiler.StopProfiling();
+//   // Profiling is complete, extract profiles.
+//   auto profile_events = worker.profiler.GetProfileEvents();
+// }
+//
+class Profiler {
+ public:
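+  // Constructs a profiler that is initially disabled; the buffer holds up to
+  // 1024 events.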
+ Profiler() : buffer_(1024, false) {}
+
+ void StartProfiling() { buffer_.SetEnabled(true); }
+ void StopProfiling() { buffer_.SetEnabled(false); }
+ void Reset() { buffer_.Reset(); }
+ std::vector<const ProfileEvent*> GetProfileEvents() {
+ std::vector<const ProfileEvent*> profile_events;
+ profile_events.reserve(buffer_.Size());
+ for (int i = 0; i < buffer_.Size(); i++) {
+ profile_events.push_back(buffer_.At(i));
+ }
+ return profile_events;
+ }
+
+ private:
+ friend class ScopedProfile;
+ friend class ScopedOperatorProfile;
+ ProfileBuffer* GetProfileBuffer() { return &buffer_; }
+ ProfileBuffer buffer_;
+};
+
+class ScopedProfile {
+ public:
+  // Adds a profile event to the profile that begins with the construction
+  // of the object and ends when the object goes out of scope.
+  // The lifetime of tag should be at least the lifetime of the profiler.
+ ScopedProfile(Profiler* profiler, const char* tag) {
+ if (profiler) {
+ buffer_ = profiler->GetProfileBuffer();
+ event_handle_ =
+ buffer_->BeginEvent(tag, ProfileEvent::EventType::DEFAULT, 0);
+ }
+ }
+ ~ScopedProfile() {
+ if (buffer_) {
+ buffer_->EndEvent(event_handle_);
+ }
+ }
+
+ private:
+  // Initialized to nullptr so the destructor is a no-op when no profiler was
+  // provided to the constructor.
+  ProfileBuffer* buffer_ = nullptr;
+  int32_t event_handle_ = 0;
+};
+
+class ScopedOperatorProfile {
+ public:
+  // Adds a profile event to the profile that begins with the construction
+  // of the object and ends when the object goes out of scope.
+  // The lifetime of tag should be at least the lifetime of the profiler.
+ ScopedOperatorProfile(Profiler* profiler, const char* tag, int node_index) {
+ if (profiler) {
+ buffer_ = profiler->GetProfileBuffer();
+ event_handle_ = buffer_->BeginEvent(
+ tag, ProfileEvent::EventType::OPERATOR_INVOKE_EVENT, node_index);
+ }
+ }
+
+ ~ScopedOperatorProfile() {
+ if (buffer_) {
+ buffer_->EndEvent(event_handle_);
+ }
+ }
+
+ private:
+  ProfileBuffer* buffer_ = nullptr;
+  int32_t event_handle_ = 0;
+};
+
+} // namespace profiling
+} // namespace tflite
+
+#define SCOPED_OPERATOR_PROFILE(profiler, node_index) \
+ tflite::profiling::ScopedOperatorProfile _profile((profiler), "OpInvoke", \
+ (node_index))
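+
+// For example (illustrative; 'profiler' and 'node_index' are whatever the
+// call site has in scope):
+//
+//   SCOPED_OPERATOR_PROFILE(profiler, node_index);
+//
+// expands to a ScopedOperatorProfile named _profile that records an
+// OPERATOR_INVOKE_EVENT tagged "OpInvoke" until the enclosing scope ends.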
+#else
+
+namespace tflite {
+namespace profiling {
+// A noop version of profiler when profiling is disabled.
+class Profiler {
+ public:
+ Profiler() {}
+ void StartProfiling() {}
+ void StopProfiling() {}
+ void Reset() {}
+ std::vector<const ProfileEvent*> GetProfileEvents() { return {}; }
+};
+} // namespace profiling
+} // namespace tflite
+
+#define SCOPED_OPERATOR_PROFILE(profiler, node_index)
+
+#endif // TFLITE_PROFILING_ENABLED
+
+#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILER_H_
diff --git a/tensorflow/contrib/lite/profiling/profiler_test.cc b/tensorflow/contrib/lite/profiling/profiler_test.cc
new file mode 100644
index 0000000000..994523a8fb
--- /dev/null
+++ b/tensorflow/contrib/lite/profiling/profiler_test.cc
@@ -0,0 +1,105 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <chrono> // NOLINT(build/c++11)
+#include <cmath>
+#include <thread> // NOLINT(build/c++11)
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/profiling/profiler.h"
+#include "tensorflow/contrib/lite/testing/util.h"
+
+namespace tflite {
+namespace profiling {
+namespace {
+
+void AssertDurationOfEventAroundMs(const ProfileEvent* event,
+ double expected_ms, double eps_ms) {
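+  // The timestamps appear to be recorded in microseconds despite the "_ms"
+  // field suffix in this revision, hence the division by 1e3 below.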
+ double duration_ms =
+ (event->end_timestamp_ms - event->begin_timestamp_ms) / 1e3;
+ EXPECT_NEAR(expected_ms, duration_ms, eps_ms);
+}
+
+void SleepForQuarterSecond(Profiler* profiler) {
+ ScopedProfile profile(profiler, "SleepForQuarter");
+ std::this_thread::sleep_for(std::chrono::milliseconds(250));
+}
+
+void ChildFunction(Profiler* profiler) {
+ ScopedProfile profile(profiler, "Child");
+ SleepForQuarterSecond(profiler);
+}
+
+void ParentFunction(Profiler* profiler) {
+ ScopedProfile profile(profiler, "Parent");
+ for (int i = 0; i < 2; i++) {
+ ChildFunction(profiler);
+ }
+}
+
+TEST(ProfilerTest, NoProfilesAreCollectedWhenDisabled) {
+ Profiler profiler;
+ ParentFunction(&profiler);
+ auto profile_events = profiler.GetProfileEvents();
+ EXPECT_EQ(0, profile_events.size());
+}
+
+TEST(ProfilerTest, ProfilesAreCollected) {
+ Profiler profiler;
+ profiler.StartProfiling();
+ ParentFunction(&profiler);
+ profiler.StopProfiling();
+ auto profile_events = profiler.GetProfileEvents();
+  // ParentFunction calls ChildFunction twice.
+  // Each ChildFunction call invokes SleepForQuarterSecond once.
+  // We expect 1 entry for ParentFunction, 2 for ChildFunction and 2 for
+  // SleepForQuarterSecond. Total: 1 + 2 + 2 = 5.
+  // The profiles should look like:
+  //   Parent ~ 500 ms (due to 2 Child calls)
+  //   - Child ~ 250 ms (due to one SleepForQuarter call)
+  //     - SleepForQuarter ~ 250 ms
+  //   - Child ~ 250 ms (due to one SleepForQuarter call)
+  //     - SleepForQuarter ~ 250 ms
+  //
+ ASSERT_EQ(5, profile_events.size());
+ EXPECT_EQ("Parent", profile_events[0]->tag);
+ EXPECT_EQ("Child", profile_events[1]->tag);
+ EXPECT_EQ("SleepForQuarter", profile_events[2]->tag);
+ EXPECT_EQ("Child", profile_events[3]->tag);
+ EXPECT_EQ("SleepForQuarter", profile_events[4]->tag);
+
+ AssertDurationOfEventAroundMs(profile_events[0], /*expected_ms*/ 500,
+ /*eps_ms*/ 2);
+ AssertDurationOfEventAroundMs(profile_events[1], /*expected_ms*/ 250,
+ /*eps_ms*/ 2);
+ AssertDurationOfEventAroundMs(profile_events[2], /*expected_ms*/ 250,
+ /*eps_ms*/ 2);
+ AssertDurationOfEventAroundMs(profile_events[3], /*expected_ms*/ 250,
+ /*eps_ms*/ 2);
+ AssertDurationOfEventAroundMs(profile_events[4], /*expected_ms*/ 250,
+ /*eps_ms*/ 2);
+}
+
+} // namespace
+} // namespace profiling
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 5b86e4e5ae..398978b145 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -238,6 +238,7 @@ cc_library(
"graph_transformations/merge_reshape_into_preceding_transpose.cc",
"graph_transformations/propagate_activation_function_into_constants.cc",
"graph_transformations/propagate_array_data_types.cc",
+ "graph_transformations/propagate_fake_quant_num_bits.cc",
"graph_transformations/propagate_fixed_sizes.cc",
"graph_transformations/quantization_util.cc",
"graph_transformations/quantization_util.h",
@@ -249,6 +250,7 @@ cc_library(
"graph_transformations/remove_trivial_binary.cc",
"graph_transformations/remove_trivial_concatenation.cc",
"graph_transformations/remove_trivial_concatenation_input.cc",
+ "graph_transformations/remove_trivial_fake_quant.cc",
"graph_transformations/remove_trivial_passthrough.cc",
"graph_transformations/remove_trivial_passthrough.h",
"graph_transformations/remove_trivial_quantized_activation_func.cc",
@@ -303,7 +305,7 @@ cc_library(
":runtime",
":toco_port",
":tooling_util",
- ":types_proto_cc",
+ "//tensorflow/contrib/lite/kernels/internal:quantization_util",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -378,7 +380,6 @@ cc_library(
":toco_graphviz_dump_options",
":toco_port",
":types_proto_cc",
- "//tensorflow/contrib/lite/kernels/internal:quantization_util",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@protobuf_archive//:protobuf_headers",
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index 7a7059e357..71e7318ac3 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -237,6 +237,7 @@ struct ParsedTocoFlags {
Arg<string> input_types;
Arg<bool> debug_disable_recurrent_cell_fusion = Arg<bool>(false);
Arg<bool> drop_control_dependency = Arg<bool>(false);
+ Arg<bool> propagate_fake_quant_num_bits = Arg<bool>(false);
};
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc
index c8352741b4..c289ddcd92 100644
--- a/tensorflow/contrib/lite/toco/dump_graphviz.cc
+++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc
@@ -95,10 +95,8 @@ Color GetColorForArray(const Model& model, const string& array_name) {
array_name == dump_options.graphviz_last_array) {
return Color(0x9E, 0x9E, 0x9E);
}
- for (const string& output_array : model.flags.output_arrays()) {
- if (array_name == output_array) {
- return Color(0x9E, 0x9E, 0x9E);
- }
+ if (IsOutputArray(model, array_name)) {
+ return Color(0x9E, 0x9E, 0x9E);
}
// Remaining arrays are intermediate activation arrays.
// Lighter tone of the same grey as for input/output arrays:
@@ -119,6 +117,12 @@ void AppendArrayVal(string* string, Array const& array, int index) {
return;
}
AppendF(string, "%d", data[index]);
+ } else if (array.buffer->type == ArrayDataType::kInt16) {
+ const auto& data = array.GetBuffer<ArrayDataType::kInt16>().data;
+ if (index >= data.size()) {
+ return;
+ }
+ AppendF(string, "%d", data[index]);
} else if (array.buffer->type == ArrayDataType::kInt32) {
const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
if (index >= data.size()) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
index badefeca88..708ecf6e0a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
@@ -47,7 +47,7 @@ bool EnsureBiasVectors::Run(Model* model, std::size_t op_index) {
op->type == OperatorType::kDepthwiseConv ||
op->type == OperatorType::kFullyConnected) {
if (ProcessLinearOperator(model, op)) {
- AddMessageF("Added bias vector to %s", LogName(*op));
+ AddMessageF("Added bias vector to %s as %s", LogName(*op), op->inputs[2]);
return true;
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index dbf029a853..56b3dec5c4 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -135,6 +135,7 @@ DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv)
DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants)
DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes)
+DECLARE_GRAPH_TRANSFORMATION(PropagateFakeQuantNumBits)
DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes)
DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax)
DECLARE_GRAPH_TRANSFORMATION(Quantize)
@@ -144,6 +145,7 @@ DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialBinaryOperator)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenation)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenationInput)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialFakeQuant)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialSlice)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedActivationFunc)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedMinMax)
@@ -163,7 +165,6 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge)
DECLARE_GRAPH_TRANSFORMATION(ResolveSqueezeAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile)
-DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantReshape)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTranspose)
@@ -210,6 +211,23 @@ class RemoveTrivialReshape : public GraphTransformation {
bool treat_expand_dims_as_trivial_ = false;
};
+class ResolveConstantFakeQuant : public GraphTransformation {
+ public:
+ bool Run(Model* model, std::size_t op_index) override;
+ const char* Name() const override { return "ResolveConstantFakeQuant"; }
+
+ // True if the num_bits should adjust the final data type.
+ bool propagate_fake_quant_num_bits() const {
+ return propagate_fake_quant_num_bits_;
+ }
+ void set_propagate_fake_quant_num_bits(bool val) {
+ propagate_fake_quant_num_bits_ = val;
+ }
+
+ private:
+ bool propagate_fake_quant_num_bits_ = false;
+};
+
#undef DECLARE_GRAPH_TRANSFORMATION
} // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
index 183b3d3f2e..45d9f73a1e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
new file mode 100644
index 0000000000..0bce183c18
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
@@ -0,0 +1,307 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void ChangeArrayDataType(GraphTransformation* transformation, Array* array,
+ ArrayDataType new_data_type,
+ const MinMax* new_minmax) {
+ // Ensure the array ends up in the new type (if it hasn't yet been quantized).
+ array->final_data_type = new_data_type;
+
+ if (array->minmax && array->quantization_params) {
+ // The array is already quantized and has min/max info.
+ // As we are changing the data type we need to fix up the existing min/max
+ // to the new data type range.
+
+ double old_quantized_min, old_quantized_max;
+ CHECK(GetQuantizedDataTypeNumericalRange(
+ array->data_type, &old_quantized_min, &old_quantized_max))
+ << "Existing data type is not quantized: "
+ << ArrayDataTypeName(array->data_type);
+ double new_quantized_min, new_quantized_max;
+ CHECK(GetQuantizedDataTypeNumericalRange(new_data_type, &new_quantized_min,
+ &new_quantized_max))
+ << "New data type is not quantized: "
+ << ArrayDataTypeName(new_data_type);
+
+ // Compute new minmax values.
+ double min = (old_quantized_min - array->quantization_params->zero_point) *
+ array->quantization_params->scale;
+ double max =
+ (old_quantized_max + 1 - array->quantization_params->zero_point) *
+ array->quantization_params->scale;
+ max = max - 1.0 / (new_quantized_max + 1);
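+    // For example (illustrative numbers): rescaling a kUint8 array with
+    // zero_point=128 and scale=0.5 to kInt16 gives min = (0 - 128) * 0.5 =
+    // -64 and max = (256 - 128) * 0.5 - 1 / 32768 ~= 64.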
+
+ auto& array_minmax = array->GetOrCreateMinMax();
+ transformation->AddMessageF(
+ "Rescaling min/max from %g,%g (%s) to %g,%g (%s)", array_minmax.min,
+ array_minmax.max, ArrayDataTypeName(array->data_type), min, max,
+ ArrayDataTypeName(new_data_type));
+
+ array_minmax.min = min;
+ array_minmax.max = max;
+    // Recompute the quantization params for the new data type.
+    GetQuantizationParams(new_data_type, array_minmax,
+                          array->quantization_params.get());
+
+ // Directly change the type as the array was already quantized.
+ array->data_type = new_data_type;
+ } else {
+ // Array has not yet been quantized so we can just set the final data type
+ // and assign the new min/max value (if provided).
+ CHECK(!array->quantization_params);
+
+ if (!array->minmax && new_minmax) {
+ transformation->AddMessageF("Forcing new minmax to %g,%g (%s)",
+ new_minmax->min, new_minmax->max,
+ ArrayDataTypeName(new_data_type));
+ auto& array_minmax = array->GetOrCreateMinMax();
+ array_minmax.min = new_minmax->min;
+ array_minmax.max = new_minmax->max;
+ }
+ }
+}
+
+// Returns true if the op blocks our backward recursive data type propagation.
+bool DoesOpBlockBackwardPropagation(const Operator& op) {
+ switch (op.type) {
+ case OperatorType::kConcatenation:
+ case OperatorType::kTensorFlowConcat:
+ case OperatorType::kTensorFlowConcatV2:
+ // Concat shouldn't block propagation, but we do expect that all inputs
+ // have the same range.
+ return false;
+ case OperatorType::kDequantize:
+ // Dequantize ops are inserted between the value we care about and the
+ // FakeQuant so make sure we move across them.
+ case OperatorType::kGather:
+ // Gathers need their parameters changed to the appropriate data type.
+ case OperatorType::kTensorFlowReshape:
+ case OperatorType::kTranspose:
+ // Reshapes and transposes don't change values.
+ return false;
+ default:
+ return true;
+ }
+}
+
+// Returns true if the input of an op blocks our backward recursive data type
+// propagation.
+bool DoesOpInputBlockBackwardPropagation(const Operator& op, int input_index) {
+ switch (op.type) {
+ case OperatorType::kGather:
+ // Ignore gather indices.
+      return input_index != 0;
+ case OperatorType::kTensorFlowReshape:
+ case OperatorType::kTranspose:
+ // Ignore reshape/transpose shapes/dimensions.
+ return input_index != 0;
+ default:
+ return false;
+ }
+}
+
+// Propagates the data type up into the input arrays if they are model inputs
+// that may need their type changed. May act recursively if the inputs are
+// produced by ops that we can move over (such as Dequantize).
+bool RecursivelyBackwardPropagateDataType(GraphTransformation* transformation,
+ Model* model, Operator* op,
+ ArrayDataType new_data_type,
+ const MinMax& new_minmax) {
+ bool did_change = false;
+ for (int input_index = 0; input_index < op->inputs.size(); ++input_index) {
+ const auto& input = op->inputs[input_index];
+ auto& input_array = model->GetArray(input);
+ if (input_array.final_data_type == new_data_type) {
+      // Final data type is already the new type - skip.
+ continue;
+ }
+
+ // Prevent moving into constant param args that we don't want to modify.
+ if (DoesOpInputBlockBackwardPropagation(*op, input_index)) {
+ continue;
+ }
+
+ if (input_array.final_data_type != new_data_type) {
+ transformation->AddMessageF(
+ "Adjusting input final data type of array %s from %s to %s", input,
+ ArrayDataTypeName(input_array.final_data_type),
+ ArrayDataTypeName(new_data_type));
+ did_change = true;
+ ChangeArrayDataType(transformation, &input_array, new_data_type,
+ &new_minmax);
+
+ // Walk up into all ops producing the inputs to this op.
+ for (auto& producing_op : model->operators) {
+ if (!DoesOpBlockBackwardPropagation(*producing_op)) {
+ for (const auto& output : producing_op->outputs) {
+ if (input == output) {
+ did_change |= RecursivelyBackwardPropagateDataType(
+ transformation, model, producing_op.get(), new_data_type,
+ new_minmax);
+ }
+ }
+ }
+ }
+ }
+ }
+ return did_change;
+}
+
+// Returns true if the op blocks our forward recursive data type propagation.
+bool DoesOpBlockForwardPropagation(const Operator& op) {
+ switch (op.type) {
+ case OperatorType::kFakeQuant:
+ // Always stop at another FakeQuant, as it will likely have different
+ // parameters.
+ return true;
+ default:
+ return false;
+ }
+}
+
+// Recurses down the graph, setting the data type of all arrays, until it
+// reaches an operator that blocks propagation (such as another FakeQuant) or
+// an array whose final_data_type is already specified.
+bool RecursivelyForwardPropagateDataType(GraphTransformation* transformation,
+ Model* model, Operator* op,
+ ArrayDataType new_data_type) {
+ bool did_change = false;
+ for (const auto& output : op->outputs) {
+ auto& output_array = model->GetArray(output);
+ if (output_array.final_data_type == new_data_type) {
+      // Final data type is already the new type - skip.
+ continue;
+ }
+
+ if (output_array.final_data_type == ArrayDataType::kNone ||
+ output_array.final_data_type != new_data_type) {
+ transformation->AddMessageF(
+ "Adjusting output final data type of array %s from %s to %s", output,
+ ArrayDataTypeName(output_array.final_data_type),
+ ArrayDataTypeName(new_data_type));
+ did_change = true;
+ ChangeArrayDataType(transformation, &output_array, new_data_type,
+ nullptr);
+
+ // Walk down into all ops consuming the output of this op.
+ for (auto& consuming_op : model->operators) {
+ if (!DoesOpBlockForwardPropagation(*consuming_op)) {
+ for (const auto& input : consuming_op->inputs) {
+ if (input == output) {
+ did_change |= RecursivelyForwardPropagateDataType(
+ transformation, model, consuming_op.get(), new_data_type);
+ }
+ }
+ }
+ }
+ }
+ }
+ return did_change;
+}
+
+} // namespace
+
+// Propagates the num_bits on a FakeQuant operator into the final data types
+// of inputs and outputs. For example, if FakeQuant.num_bits==16 then we know
+// the output must be int16 and we assume all inputs up until the preceding op
+// are also int16.
+//
+// This can be thought of as a bidirectional flood-fill of the num_bits implied
+// final_data_type that terminates at other FakeQuant ops (and a few others as
+// determined by DoesOpBlockBackwardPropagation/DoesOpBlockForwardPropagation).
+// Once all FakeQuant ops have been visited, the arrays should all have
+// appropriate final_data_types if the source graph was annotated with the
+// proper FakeQuant ops.
+//
+// Annotating a graph requires following a few hard rules:
+// - every input MUST have a FakeQuant immediately following it
+// - every output MUST have a FakeQuant immediately preceding it
+// - important arithmetic ops (such as FullyConnected) SHOULD have a FakeQuant
+//   immediately following them
+// - all trained weights (RHS of FullyConnected ops, params on Gather ops, etc)
+// MUST have FakeQuants between them and the consuming op
+// Additional FakeQuants may be used if desired, especially in areas that may
+// suffer from large precision changes - such as between a Softmax and a
+// FullyConnected. Only by validating accuracy differences between float
+// inference with the FakeQuant ops simulating quantization and the actually
+// quantized graph can you be sure the appropriate FakeQuant ops are present.
+//
+// You can tell if you're missing some FakeQuants by looking for warnings from
+// quantize.cc about minmax ranges being determined by the contents of constant
+// arrays. This will almost never produce functional models during inference.
+//
+// As this op may change the data types and ranges of input and output arrays,
+// downstream tools must also be sure to parse the output model flags to get the
+// post-Transform values that may have changed due to this transformation.
+//
+// This isn't a GraphTransformation in the traditional sense, as it affects ops
+// outside of the one under transformation. This is primarily so that we can
+// utilize the graph traversal and repeated pass system underlying the
+// transformation system to exhaustively find all FakeQuant ops. It also gets us
+// nice logging and integration with the graphviz video dumping mode.
+// In general you should not copy this style of transformation and stick to
+// local-only changes as seen in the other transformations.
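+//
+// Illustrative example (hypothetical graph): given
+//
+//   input -> FakeQuant#1(num_bits=16) -> FullyConnected
+//         -> FakeQuant#2(num_bits=16) -> output
+//
+// running on FakeQuant#1 marks the input array kInt16 (backward propagation)
+// and the arrays between the two FakeQuants kInt16 (forward propagation,
+// stopping at FakeQuant#2); running on FakeQuant#2 likewise marks the output
+// array.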
+bool PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+ if (op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+ auto* fakequant_op = static_cast<FakeQuantOperator*>(op);
+
+ ArrayDataType quantized_data_type = ArrayDataType::kNone;
+ if (!InferQuantizedDataTypeFromFakeQuant(*fakequant_op,
+ &quantized_data_type)) {
+ AddMessageF("FakeQuant op %s num_bits=%d is out of range, ignoring",
+ LogName(*op), fakequant_op->num_bits);
+ return false;
+ }
+ const auto& final_minmax = *fakequant_op->minmax;
+
+ AddMessageF(
+ "Beginning propagation of fake quant %s num_bits=%d min=%g max=%g to %s",
+ LogName(*op), fakequant_op->num_bits, final_minmax.min, final_minmax.max,
+ ArrayDataTypeName(quantized_data_type));
+
+ bool did_change = false;
+
+ // Propagate the FakeQuant information backward up the graph.
+ // This will possibly adjust input arrays or constant types (like Gather).
+ did_change |= RecursivelyBackwardPropagateDataType(
+ this, model, op, quantized_data_type, final_minmax);
+
+ // Propagate the FakeQuant information forward down the graph.
+ // This will possibly adjust output arrays.
+ did_change |=
+ RecursivelyForwardPropagateDataType(this, model, op, quantized_data_type);
+
+ return did_change;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc
index e080df4bed..d74cad9a62 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc
@@ -22,6 +22,20 @@ limitations under the License.
namespace toco {
+bool InferQuantizedDataTypeFromFakeQuant(
+ const FakeQuantOperator& op, ArrayDataType* out_quantized_data_type) {
+ if (op.num_bits <= 8) {
+ *out_quantized_data_type = ArrayDataType::kUint8;
+ return true;
+ } else if (op.num_bits <= 16) {
+ *out_quantized_data_type = ArrayDataType::kInt16;
+ return true;
+ } else {
+ *out_quantized_data_type = ArrayDataType::kNone;
+ return false;
+ }
+}
+
bool GetQuantizedDataTypeNumericalRange(ArrayDataType data_type,
double* out_min_value,
double* out_max_value) {
@@ -103,6 +117,80 @@ void GetQuantizationParams(ArrayDataType data_type, const MinMax& minmax,
}
}
+namespace {
+
+template <ArrayDataType A>
+std::unique_ptr<GenericBuffer> QuantizeBuffer(
+ const GenericBuffer& buffer,
+ const QuantizationParams& quantization_params) {
+ const auto inverse_scale = 1. / quantization_params.scale;
+ CHECK(buffer.type == ArrayDataType::kFloat);
+ const auto& float_buffer =
+ static_cast<const Buffer<ArrayDataType::kFloat>&>(buffer);
+ auto* quantized_buffer = new Buffer<A>;
+ quantized_buffer->data.resize(float_buffer.data.size());
+ for (std::size_t i = 0; i < float_buffer.data.size(); i++) {
+ const float src_val = float_buffer.data[i];
+ double scaled_val; // Astonishingly, using 'float' degrades accuracy just
+ // enough to make a few tests fail!
+ if (quantization_params.scale == 0) {
+ CHECK_EQ(src_val, 0) << "The quantization scale for this array is 0, "
+ << "so all its values should be 0.";
+ scaled_val = quantization_params.zero_point;
+ } else {
+ scaled_val = quantization_params.zero_point + inverse_scale * src_val;
+ }
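+    // For example (illustrative numbers): scale=0.5 (inverse_scale=2) and
+    // zero_point=128 map src_val=-64 to round(128 + 2 * -64) = 0.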
+ quantized_buffer->data[i] =
+ tflite::SafeCast<DataType<A>>(std::round(scaled_val));
+ }
+ return std::unique_ptr<GenericBuffer>(quantized_buffer);
+}
+
+template <ArrayDataType A>
+void QuantizeArray(GraphTransformation* transformation, Model* model,
+ const string& name,
+ const QuantizationParams& quantization_params) {
+ auto& array = model->GetArray(name);
+ CHECK(array.data_type == ArrayDataType::kFloat);
+ CHECK(!array.quantization_params);
+ array.GetOrCreateQuantizationParams() = quantization_params;
+ if (array.buffer) {
+ array.buffer = QuantizeBuffer<A>(*array.buffer, quantization_params);
+ }
+ array.data_type = A;
+ array.final_data_type = A;
+ transformation->AddMessageF(
+ "Quantized array %s to %s zero_point=%g, scale=%g", name,
+ ArrayDataTypeName(array.data_type), quantization_params.zero_point,
+ quantization_params.scale);
+}
+
+} // namespace
+
+void QuantizeArray(GraphTransformation* transformation, Model* model,
+ const string& name, ArrayDataType quantized_data_type,
+ const QuantizationParams& quantization_params) {
+ ArrayDataType adjusted_data_type = quantized_data_type;
+ auto& array = model->GetArray(name);
+ if (array.final_data_type == ArrayDataType::kInt16) {
+ adjusted_data_type = array.final_data_type;
+ }
+
+ switch (adjusted_data_type) {
+ case ArrayDataType::kUint8:
+ return QuantizeArray<ArrayDataType::kUint8>(transformation, model, name,
+ quantization_params);
+ case ArrayDataType::kInt16:
+ return QuantizeArray<ArrayDataType::kInt16>(transformation, model, name,
+ quantization_params);
+ case ArrayDataType::kInt32:
+ return QuantizeArray<ArrayDataType::kInt32>(transformation, model, name,
+ quantization_params);
+ default:
+ LOG(FATAL) << "Unhandled case.";
+ }
+}
+
bool IsArrayQuantizedRangeSubset(GraphTransformation* transformation,
const Array& array, double clamp_min,
double clamp_max) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
index 35fb310777..79a2ce7e50 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
@@ -15,11 +15,17 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_
#define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
namespace toco {
+// Gets the target quantized data type of an array based on the fake quant op.
+// For example, if the num_bits is 8 the data type will be kUint8.
+bool InferQuantizedDataTypeFromFakeQuant(
+ const FakeQuantOperator& op, ArrayDataType* out_quantized_data_type);
+
// Gets the min/max numerical range for the given quantized data type.
// For example, kUint8 will return [0,255].
// Returns true if the ranges were set and false if the type is not quantized.
@@ -32,11 +38,28 @@ bool GetQuantizedDataTypeNumericalRange(ArrayDataType data_type,
ArrayDataType GetQuantizedDataType(const Array& array,
ArrayDataType default_type);
-// Gets the quantization params for the array with the given data type and
+// Computes the quantization params for the array with the given data type and
// minmax.
void GetQuantizationParams(ArrayDataType data_type, const MinMax& minmax,
QuantizationParams* quantization_params);
+// Computes the quantization params for the given data type and minmax values.
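+// For example (illustrative):
+//   QuantizationParams params;
+//   GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(minmax, &params);
+//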
+template <ArrayDataType A>
+void GetQuantizationParamsFromMinMax(const MinMax& minmax,
+ QuantizationParams* quantization_params) {
+ using Integer = DataType<A>;
+ const double rmin = minmax.min;
+ const double rmax = minmax.max;
+ *quantization_params =
+ ::tflite::ChooseQuantizationParams<Integer>(rmin, rmax);
+}
+
+// Quantizes an array by setting its data type and (if constant) quantizing
+// all values in the array.
+void QuantizeArray(GraphTransformation* transformation, Model* model,
+ const string& name, ArrayDataType quantized_data_type,
+ const QuantizationParams& quantization_params);
+
// Returns true if the given array, when quantized, contains only values between
// the provided clamp min/max.
// Either clamp_min or clamp_max may be +/-infinity to indicate that the value
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index d6cae3cdbf..fa46e6bc38 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -57,72 +57,6 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kTranspose || type == OperatorType::kMean;
}
-template <ArrayDataType A>
-std::unique_ptr<GenericBuffer> QuantizeBuffer(
- const GenericBuffer& buffer,
- const QuantizationParams& quantization_params) {
- const auto inverse_scale = 1. / quantization_params.scale;
- CHECK(buffer.type == ArrayDataType::kFloat);
- const auto& float_buffer =
- static_cast<const Buffer<ArrayDataType::kFloat>&>(buffer);
- auto* quantized_buffer = new Buffer<A>;
- quantized_buffer->data.resize(float_buffer.data.size());
- for (std::size_t i = 0; i < float_buffer.data.size(); i++) {
- const float src_val = float_buffer.data[i];
- double scaled_val; // Astonishingly, using 'float' degrades accuracy just
- // enough to make a few tests fail!
- if (quantization_params.scale == 0) {
- CHECK_EQ(src_val, 0) << "The quantization scale for this array is 0, "
- << "so all its values should be 0.";
- scaled_val = quantization_params.zero_point;
- } else {
- scaled_val = quantization_params.zero_point + inverse_scale * src_val;
- }
- quantized_buffer->data[i] =
- tflite::SafeCast<DataType<A>>(std::round(scaled_val));
- }
- return std::unique_ptr<GenericBuffer>(quantized_buffer);
-}
-
-template <ArrayDataType A>
-void QuantizeArray(GraphTransformation* transformation, Model* model,
- const string& name,
- const QuantizationParams& quantization_params) {
- auto& array = model->GetArray(name);
- CHECK(array.data_type == ArrayDataType::kFloat);
- CHECK(!array.quantization_params);
- array.GetOrCreateQuantizationParams() = quantization_params;
- if (array.buffer) {
- array.buffer = QuantizeBuffer<A>(*array.buffer, quantization_params);
- }
- array.data_type = A;
- transformation->AddMessageF("Quantized array %s", name);
-}
-
-void QuantizeArray(GraphTransformation* transformation, Model* model,
- const string& name, ArrayDataType quantized_data_type,
- const QuantizationParams& quantization_params) {
- ArrayDataType adjusted_data_type = quantized_data_type;
- auto& array = model->GetArray(name);
- if (array.final_data_type == ArrayDataType::kInt16) {
- adjusted_data_type = array.final_data_type;
- }
-
- switch (adjusted_data_type) {
- case ArrayDataType::kUint8:
- return QuantizeArray<ArrayDataType::kUint8>(transformation, model, name,
- quantization_params);
- case ArrayDataType::kInt16:
- return QuantizeArray<ArrayDataType::kInt16>(transformation, model, name,
- quantization_params);
- case ArrayDataType::kInt32:
- return QuantizeArray<ArrayDataType::kInt32>(transformation, model, name,
- quantization_params);
- default:
- LOG(FATAL) << "Unhandled case.";
- }
-}
-
const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
auto& array = model->GetArray(array_name);
// Normally we should have a MinMax recorded on this Array,
@@ -245,6 +179,8 @@ bool ChooseQuantizationForOperatorInput(
const auto& input_weights = model->GetArray(op.inputs[weights_input_index]);
if (!input_activations.quantization_params ||
!input_weights.quantization_params) {
+    transformation->AddMessageF(
+        "Input array %s is a bias vector, but its activation/weights inputs "
+        "have no qparams yet",
+        input);
return false;
}
const auto input_activations_scale =
@@ -366,6 +302,9 @@ bool ChooseQuantizationForOperatorOutput(
const auto& output = op.outputs[output_index];
auto& array = model->GetArray(output);
if (array.data_type != ArrayDataType::kFloat) {
+ transformation->AddMessageF("Array data type already set to %s, final=%s",
+ ArrayDataTypeName(array.data_type),
+ ArrayDataTypeName(array.final_data_type));
return false;
}
*quantized_data_type = model->GetArray(op.inputs[0]).data_type;
@@ -427,29 +366,22 @@ bool ChooseQuantizationForOperatorOutput(
// Fixes array minmax info to match the quantization parameters.
// This is required for when quantization parameters change for an array during
// quantization (such as ChooseQuantizationForOperatorOutput).
-void FixMinMaxPostQuantization(ArrayDataType quantized_data_type,
+void FixMinMaxPostQuantization(GraphTransformation* transformation,
+ ArrayDataType quantized_data_type,
const QuantizationParams& quantization_params,
MinMax* minmax) {
- double qmin, qmax;
- switch (quantized_data_type) {
- case ArrayDataType::kUint8:
- qmin = 0;
- qmax = 255;
- break;
- case ArrayDataType::kInt16:
- qmin = -32768;
- qmax = 32767;
- break;
- default:
- // No update required.
- return;
+ double quantized_min, quantized_max;
+ if (!GetQuantizedDataTypeNumericalRange(quantized_data_type, &quantized_min,
+ &quantized_max)) {
+ // Not quantized - no update required.
+ return;
}
// Compute new minmax values.
- double min =
- (qmin - quantization_params.zero_point) * quantization_params.scale;
- double max =
- (qmax - quantization_params.zero_point) * quantization_params.scale;
+ double min = (quantized_min - quantization_params.zero_point) *
+ quantization_params.scale;
+ double max = (quantized_max - quantization_params.zero_point) *
+ quantization_params.scale;
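+  // For example (illustrative): kUint8 with zero_point=128 and scale=0.5
+  // gives min = (0 - 128) * 0.5 = -64 and max = (255 - 128) * 0.5 = 63.5.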
// If we are close to the existing minmax values don't bother changing them.
// This prevents propagating small floating point precision errors.
@@ -457,6 +389,9 @@ void FixMinMaxPostQuantization(ArrayDataType quantized_data_type,
const double width = max - min;
if (std::abs(min - minmax->min) > kMinMaxThreshold * width ||
std::abs(max - minmax->max) > kMinMaxThreshold * width) {
+ transformation->AddMessageF(
+ "Adjusting min/max from %g,%g to %g,%g to match quantization params",
+ minmax->min, minmax->max, min, max);
minmax->min = min;
minmax->max = max;
}
@@ -566,10 +501,33 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
// input instead.
for (int i = 0; i < model->flags.output_arrays_size(); i++) {
if (model->flags.output_arrays(i) == dequantize_op->outputs[0]) {
- model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
+ // TODO(b/78013785): never rename output arrays.
+ if (IsInputArray(*model, dequantize_op->inputs[0])) {
+ // The op input is an input array and the output is an output
+ // array and we can't have an array be both. Insert a copy
+ // op to ensure the two arrays stay separate.
+ AddMessageF(
+ "Tried to rename output array %d while removing dequant "
+ "op %s but array is also an input; inserting copy %s "
+ "-> %s",
+ i, LogName(*dequantize_op), model->flags.output_arrays(i),
+ dequantize_op->inputs[0]);
+ InsertCopyOperator(model, dequantize_op->inputs[0],
+ dequantize_op->outputs[0]);
+ } else {
+ // Op output is strictly used as an output array, so we can
+ // just rename the array and directly bypass the op.
+ AddMessageF(
+ "Renaming output array %d after removing dequant op %s: "
+ "%s -> %s",
+ i, LogName(*dequantize_op), model->flags.output_arrays(i),
+ dequantize_op->inputs[0]);
+ model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
+ model->EraseArray(dequantize_op->outputs[0]);
+ }
+ break;
}
}
- model->EraseArray(dequantize_op->outputs[0]);
model->operators.erase(dequantize_it);
}
changed = true;
@@ -615,7 +573,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
CHECK(output_array.minmax)
<< "Output array named " << output << " lacks minmax";
auto& output_minmax = output_array.GetMinMax();
- FixMinMaxPostQuantization(quantized_data_type, quantization_params,
+ FixMinMaxPostQuantization(this, quantized_data_type, quantization_params,
&output_minmax);
QuantizeArray(this, model, output, quantized_data_type,
@@ -626,6 +584,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
auto& dequantized_output_array =
model->GetOrCreateArray(dequantized_output);
dequantized_output_array.data_type = ArrayDataType::kFloat;
+ dequantized_output_array.final_data_type = output_array.data_type;
auto& dequantized_output_minmax =
dequantized_output_array.GetOrCreateMinMax();
dequantized_output_minmax.min = output_minmax.min;
@@ -642,6 +601,12 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
dequantize_op->outputs = {dequantized_output};
for (int i = 0; i < model->flags.output_arrays_size(); i++) {
if (model->flags.output_arrays(i) == output) {
+ // TODO(b/78013785): never rename output arrays.
+ AddMessageF(
+ "Renaming output array %d after inserting dequant op %s: %s -> "
+ "%s",
+ i, LogName(*dequantize_op), model->flags.output_arrays(i),
+ dequantized_output);
model->flags.set_output_arrays(i, dequantized_output);
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc
new file mode 100644
index 0000000000..2c8d04440f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc
@@ -0,0 +1,86 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool IsFakeQuantTrivial(GraphTransformation* transformation, const Model& model,
+ const FakeQuantOperator& fakequant_op) {
+ CHECK(fakequant_op.type == OperatorType::kFakeQuant);
+
+ if (!fakequant_op.minmax) {
+ // Require ReadFakeQuantMinMax to have run.
+ return false;
+ }
+
+ // FakeQuants are trivial if they are taking input from another identical
+ // FakeQuant op.
+ auto* producing_op = GetOpWithOutput(model, fakequant_op.inputs[0]);
+ if (!producing_op || producing_op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+ const auto& producing_fakequant_op =
+ *static_cast<FakeQuantOperator*>(producing_op);
+ if (!producing_fakequant_op.minmax) {
+ // Require ReadFakeQuantMinMax to have run.
+ return false;
+ }
+
+ if (*fakequant_op.minmax == *producing_fakequant_op.minmax &&
+ fakequant_op.num_bits == producing_fakequant_op.num_bits) {
+ transformation->AddMessageF(
+ "%s is trivial because it is preceded by an identical FakeQuant %s",
+ LogName(fakequant_op), LogName(producing_fakequant_op));
+ return true;
+ }
+
+ return false;
+}
+
+} // namespace
+
+// Removes FakeQuant ops that are trivial (have no effect, are redundant, etc).
+bool RemoveTrivialFakeQuant::Run(Model* model, std::size_t op_index) {
+ const auto op_it = model->operators.begin() + op_index;
+ auto* op = op_it->get();
+ if (op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+ auto* fakequant_op = static_cast<FakeQuantOperator*>(op);
+
+ if (!IsFakeQuantTrivial(this, *model, *fakequant_op)) {
+ AddMessageF("%s is not trivial", LogName(*fakequant_op));
+ return false;
+ }
+
+ AddMessageF("Removing trivial %s", LogName(*fakequant_op));
+
+ CHECK_EQ(fakequant_op->inputs.size(), 1);
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
index 625d90205a..efb7bb2184 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"
@@ -45,9 +46,29 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
}
const auto& input_array = model->GetArray(fakequant_op->inputs[0]);
+ // Determine the final data type in the same way as PropagateFakeQuantNumBits.
+ ArrayDataType quantized_data_type = input_array.final_data_type;
+ if (!InferQuantizedDataTypeFromFakeQuant(*fakequant_op,
+ &quantized_data_type)) {
+ AddMessageF("Unsupported FakeQuant num_bits=%d", fakequant_op->num_bits);
+ return false;
+ }
+
+ AddMessageF("Resolving constant %s", LogName(*fakequant_op));
+
auto& output_array = model->GetArray(fakequant_op->outputs[0]);
CHECK(input_array.data_type == ArrayDataType::kFloat);
output_array.data_type = ArrayDataType::kFloat;
+
+ // We'll set the final data type to what the fake quant indicates we should
+ // have (and would have been set if this stayed around until
+ // PropagateFakeQuantNumBits).
+ if (propagate_fake_quant_num_bits()) {
+ output_array.final_data_type = quantized_data_type;
+ }
+
CHECK(!output_array.buffer);
const auto& input_buffer = input_array.GetBuffer<ArrayDataType::kFloat>();
output_array.GetOrCreateMinMax() = *fakequant_op->minmax;
@@ -66,7 +87,9 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
const double dst_val = qparams.scale * (quantized_val - qparams.zero_point);
output_buffer.data[i] = dst_val;
}
- if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
+
+ if (IsDiscardableArray(*model, fakequant_op->inputs[0]) &&
+ CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
model->EraseArray(fakequant_op->inputs[0]);
}
model->operators.erase(fakequant_it);
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index cc7803dd86..d1d68b6b47 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -126,6 +126,11 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.debug_disable_recurrent_cell_fusion.default_value(),
"If true, disable fusion of known identifiable cell subgraphs into "
"cells. This includes, for example, specific forms of LSTM cell."),
+ Flag("propagate_fake_quant_num_bits",
+ parsed_flags.propagate_fake_quant_num_bits.bind(),
+ parsed_flags.propagate_fake_quant_num_bits.default_value(),
+ "If true, use FakeQuant* operator num_bits attributes to adjust "
+ "array data_types."),
};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
@@ -211,6 +216,8 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone);
READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone);
READ_TOCO_FLAG(drop_control_dependency, FlagRequirement::kNone);
+ READ_TOCO_FLAG(debug_disable_recurrent_cell_fusion, FlagRequirement::kNone);
+ READ_TOCO_FLAG(propagate_fake_quant_num_bits, FlagRequirement::kNone);
// Deprecated flag handling.
if (parsed_toco_flags.input_type.specified()) {
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index 3237147a73..751aca948c 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -37,7 +37,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
-// Next ID to use: 14.
+// Next ID to use: 15.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@@ -141,4 +141,13 @@ message TocoFlags {
// Disables transformations that fuse subgraphs such as known LSTMs (not all
// LSTMs are identified).
optional bool debug_disable_recurrent_cell_fusion = 13;
+
+ // Uses the FakeQuantWithMinMaxArgs.num_bits attribute to adjust quantized
+ // array data types throughout the graph. The graph must be properly annotated
+ // with FakeQuant* ops on at least the edges and may contain additional ops on
+ // the interior of the graph to widen/narrow as desired.
+ //
+ // Input and output array data types may change because of this propagation
+ // and users must be sure to query the final data_type values.
+ optional bool propagate_fake_quant_num_bits = 14;
}
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 5ba093a830..b69852453c 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -66,6 +66,7 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new RemoveTensorFlowIdentity);
transformations->Add(new RemoveTrivialConcatenation);
transformations->Add(new RemoveTrivialConcatenationInput);
+ transformations->Add(new RemoveTrivialFakeQuant);
transformations->Add(new RemoveTrivialSlice);
transformations->Add(new RemoveUnusedOp);
transformations->Add(new EnsureBiasVectors);
@@ -109,7 +110,6 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveMeanAttributes);
transformations->Add(new ResolveConstantShapeOrRank);
transformations->Add(new MakeInitialDequantizeOperator);
- transformations->Add(new ResolveConstantFakeQuant);
transformations->Add(new UnpartitionEmbeddingLookup);
}
@@ -233,6 +233,12 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
MakeGeneralGraphTransformationsSet(&transformations);
auto* remove_trivial_reshape = new RemoveTrivialReshape;
transformations.Add(remove_trivial_reshape);
+ auto* resolve_constant_fake_quant = new ResolveConstantFakeQuant;
+ if (quantize_output) {
+ resolve_constant_fake_quant->set_propagate_fake_quant_num_bits(
+ toco_flags.propagate_fake_quant_num_bits());
+ }
+ transformations.Add(resolve_constant_fake_quant);
if (SupportsFusedActivationFunction(output_format)) {
transformations.Add(new FuseActivationFunctions);
} else {
@@ -264,9 +270,21 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
RunGraphTransformations(model, "general graph transformations",
transformations);
+ // Fix any issues with IO edges. This must happen after any transform that
+ // may modify the structure of the edges.
+ FixEdgeArrays(model);
+
if (quantize_output) {
+ if (toco_flags.propagate_fake_quant_num_bits()) {
+ RunGraphTransformations(model,
+ "fake quant propagation graph transformations",
+ {new PropagateFakeQuantNumBits});
+ }
RunGraphTransformations(model, "pre-quantization graph transformations",
- {new HardcodeMinMax, new DropFakeQuant});
+ {
+ new HardcodeMinMax,
+ new DropFakeQuant,
+ });
}
if (quantize_output) {
@@ -303,10 +321,6 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(model);
}
- // Fix any issues with IO edges. This must happen after any transform that
- // may modify the structure of the edges.
- FixEdgeArrays(model);
-
LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model);
if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) {
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 224df9973e..ecac0c28a5 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -93,9 +93,18 @@ string ArrayDataTypeName(ArrayDataType data_type) {
}
}
-bool IsInputArray(const Model& model, const string& name) {
+bool IsInputArray(const Model& model, const string& array_name) {
for (const auto& input_array : model.flags.input_arrays()) {
- if (input_array.name() == name) {
+ if (array_name == input_array.name()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool IsOutputArray(const Model& model, const string& array_name) {
+ for (const auto& output_array : model.flags.output_arrays()) {
+ if (array_name == output_array) {
return true;
}
}
@@ -106,10 +115,8 @@ bool IsArrayConsumed(const Model& model, const string& name) {
if (GetOpWithInput(model, name)) {
return true;
}
- for (const string& model_output : model.flags.output_arrays()) {
- if (model_output == name) {
- return true;
- }
+ if (IsOutputArray(model, name)) {
+ return true;
}
for (const auto& rnn_state : model.flags.rnn_states()) {
if (rnn_state.back_edge_source_array() == name) {
@@ -379,6 +386,7 @@ string HelpfulOperatorTypeName(const Operator& op) {
bool OperatorSupportsFusedActivation(OperatorType type) {
switch (type) {
case OperatorType::kConcatenation:
+ case OperatorType::kFakeQuant:
case OperatorType::kGather:
case OperatorType::kSlice:
case OperatorType::kSqueeze:
@@ -1064,16 +1072,38 @@ void FixEdgeArrays(Model* model) {
}
}
+namespace {
+void CopyArrayAttribs(const Array& source_array, Array* target_array) {
+ target_array->data_type = source_array.data_type;
+ target_array->final_data_type = source_array.final_data_type;
+ target_array->copy_shape(source_array.shape());
+
+ if (source_array.minmax) {
+ target_array->GetOrCreateMinMax() = source_array.GetMinMax();
+ } else {
+ target_array->minmax.reset();
+ }
+
+ if (source_array.quantization_params) {
+ target_array->GetOrCreateQuantizationParams() =
+ source_array.GetQuantizationParams();
+ } else {
+ target_array->quantization_params.reset();
+ }
+}
+} // namespace
+
void InsertCopyOperator(Model* model, const string& source_array_name,
const string& target_array_name) {
+ // Reshape to the same size. This should be a no-op.
+ const Array& source_array = model->GetArray(source_array_name);
+ std::vector<int> shape = source_array.shape().dims();
+
// Drop constant data from the target array as the copy will be done at
// runtime.
Array& target_array = model->GetOrCreateArray(target_array_name);
target_array.buffer.reset();
-
- // Reshape to the same size. This should be a no-op.
- const Array& source_array = model->GetArray(source_array_name);
- std::vector<int> shape = source_array.shape().dims();
+ CopyArrayAttribs(source_array, &target_array);
// Insert copy operator.
auto* copy_op = new TensorFlowReshapeOperator;
@@ -1089,6 +1119,7 @@ void CloneArray(Model* model, const string& source_array_name,
CHECK(!model->HasArray(target_array_name));
const Array& source_array = model->GetArray(source_array_name);
Array& target_array = model->GetOrCreateArray(target_array_name);
+ CopyArrayAttribs(source_array, &target_array);
if (source_array.minmax) {
const auto& smm = source_array.GetMinMax();
@@ -1513,14 +1544,9 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
if (model.IsOptionalArray(array_name)) return false;
// The model's input and output arrays are externally allocated.
// They are not transient arrays.
- if (IsInputArray(model, array_name)) {
+ if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
return false;
}
- for (const string& output_array : model.flags.output_arrays()) {
- if (array_name == output_array) {
- return false;
- }
- }
const auto& array = &model.GetArray(array_name);
// An array with a constant buffer isn't a transient array.
if (!!array->buffer) {
@@ -1898,15 +1924,8 @@ int AxesCount(AxesOrder axes_order) {
}
bool IsDiscardableArray(const Model& model, const string& array_name) {
- for (const auto& input_array : model.flags.input_arrays()) {
- if (array_name == input_array.name()) {
- return false;
- }
- }
- for (const string& output_array : model.flags.output_arrays()) {
- if (array_name == output_array) {
- return false;
- }
+ if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
+ return false;
}
for (const auto& rnn_state : model.flags.rnn_states()) {
if (!rnn_state.discardable()) {
@@ -1960,8 +1979,8 @@ void CheckFinalDataTypesSatisfied(const Model& model) {
CHECK(array.final_data_type == array.data_type)
<< "Array \"" << array_entry.first
<< "\" has mis-matching actual and final data types ("
- << static_cast<int>(array.data_type) << ","
- << static_cast<int>(array.final_data_type) << ").";
+ << ArrayDataTypeName(array.data_type) << ","
+ << ArrayDataTypeName(array.final_data_type) << ").";
}
}
}
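
The new CopyArrayAttribs helper above copies the attributes a runtime copy or clone must preserve and clears the optional ones when the source lacks them. A rough Python rendering of the same copy-or-reset behavior, using a hypothetical Array class as a stand-in for toco::Array:

    class Array(object):
        """Hypothetical stand-in for toco::Array; names mirror the C++."""

        def __init__(self):
            self.data_type = None
            self.final_data_type = None
            self.shape = []
            self.minmax = None
            self.quantization_params = None

    def copy_array_attribs(source, target):
        # Data types and shape always transfer.
        target.data_type = source.data_type
        target.final_data_type = source.final_data_type
        target.shape = list(source.shape)
        # Optional attributes transfer when present and are cleared otherwise,
        # so stale minmax/quantization data never lingers on the target.
        target.minmax = dict(source.minmax) if source.minmax else None
        target.quantization_params = (
            dict(source.quantization_params)
            if source.quantization_params else None)
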
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index ed0ecd4d0f..4c705f4e5f 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -28,7 +28,6 @@ limitations under the License.
#if TOCO_SUPPORT_PORTABLE_PROTOS
#include "third_party/protobuf/src/google/protobuf/text_format.h"
#endif // TOCO_SUPPORT_PORTABLE_PROTOS
-#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
@@ -57,7 +56,11 @@ string LogName(const Operator& op);
string ArrayDataTypeName(ArrayDataType data_type);
-bool IsInputArray(const Model& model, const string& name);
+// Returns true if the given array is specified as a model input array.
+bool IsInputArray(const Model& model, const string& array_name);
+// Returns true if the given array is specified as a model output array.
+bool IsOutputArray(const Model& model, const string& array_name);
+
bool IsArrayConsumed(const Model& model, const string& name);
int CountTrueOutputs(const Model& model, const Operator& op);
@@ -175,17 +178,6 @@ void CloneArray(Model* model, const string& source_array_name,
void ResolveModelFlags(const ModelFlags& model_flags, Model* model);
-template <ArrayDataType A>
-void GetQuantizationParamsFromMinMax(const MinMax& minmax,
- QuantizationParams* quantization_params) {
- using Integer = DataType<A>;
- const double rmin = minmax.min;
- const double rmax = minmax.max;
-
- *quantization_params =
- ::tflite::ChooseQuantizationParams<Integer>(rmin, rmax);
-}
-
template <typename T>
T ConvertOperator(Operator* o, OperatorType type) {
if (o != nullptr && o->type == type) {
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 2bf281b791..6f7d8a19c2 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -62,6 +62,7 @@ def _safe_div(numerator, denominator, name):
0,
name=name)
+
@deprecated(None, 'Please switch to tf.metrics.true_positives. Note that the '
'order of the labels and predictions arguments has been switched.')
def streaming_true_positives(predictions,
@@ -107,6 +108,7 @@ def streaming_true_positives(predictions,
updates_collections=updates_collections,
name=name)
+
@deprecated(None, 'Please switch to tf.metrics.true_negatives. Note that the '
'order of the labels and predictions arguments has been switched.')
def streaming_true_negatives(predictions,
@@ -152,6 +154,7 @@ def streaming_true_negatives(predictions,
updates_collections=updates_collections,
name=name)
+
@deprecated(None, 'Please switch to tf.metrics.false_positives. Note that the '
'order of the labels and predictions arguments has been switched.')
def streaming_false_positives(predictions,
@@ -197,6 +200,7 @@ def streaming_false_positives(predictions,
updates_collections=updates_collections,
name=name)
+
@deprecated(None, 'Please switch to tf.metrics.false_negatives. Note that the '
'order of the labels and predictions arguments has been switched.')
def streaming_false_negatives(predictions,
@@ -241,6 +245,7 @@ def streaming_false_negatives(predictions,
updates_collections=updates_collections,
name=name)
+
@deprecated(None, 'Please switch to tf.metrics.mean')
def streaming_mean(values,
weights=None,
@@ -290,6 +295,7 @@ def streaming_mean(values,
updates_collections=updates_collections,
name=name)
+
@deprecated(None, 'Please switch to tf.metrics.mean_tensor')
def streaming_mean_tensor(values,
weights=None,
@@ -345,7 +351,7 @@ def streaming_mean_tensor(values,
@deprecated(None, 'Please switch to tf.metrics.accuracy. Note that the order '
- 'of the labels and predictions arguments has been switched.')
+ 'of the labels and predictions arguments has been switched.')
def streaming_accuracy(predictions,
labels,
weights=None,
@@ -402,8 +408,9 @@ def streaming_accuracy(predictions,
updates_collections=updates_collections,
name=name)
+
@deprecated(None, 'Please switch to tf.metrics.precision. Note that the order '
- 'of the labels and predictions arguments has been switched.')
+ 'of the labels and predictions arguments has been switched.')
def streaming_precision(predictions,
labels,
weights=None,
@@ -459,8 +466,9 @@ def streaming_precision(predictions,
updates_collections=updates_collections,
name=name)
+
@deprecated(None, 'Please switch to tf.metrics.recall. Note that the order '
- 'of the labels and predictions arguments has been switched.')
+ 'of the labels and predictions arguments has been switched.')
def streaming_recall(predictions,
labels,
weights=None,
@@ -981,7 +989,7 @@ def streaming_curve_points(labels=None,
@deprecated(None, 'Please switch to tf.metrics.auc. Note that the order of '
- 'the labels and predictions arguments has been switched.')
+ 'the labels and predictions arguments has been switched.')
def streaming_auc(predictions,
labels,
weights=None,
diff --git a/tensorflow/contrib/proto/BUILD b/tensorflow/contrib/proto/BUILD
index 046652cbc5..3e9b1a0b8d 100644
--- a/tensorflow/contrib/proto/BUILD
+++ b/tensorflow/contrib/proto/BUILD
@@ -4,6 +4,8 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
+load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+
py_library(
name = "proto",
srcs = [
@@ -14,3 +16,17 @@ py_library(
"//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
],
)
+
+py_library(
+ name = "proto_pip",
+ data = [
+ "//tensorflow/contrib/proto/python/kernel_tests:test_messages",
+ ] + if_static(
+ [],
+ otherwise = ["//tensorflow/contrib/proto/python/kernel_tests:libtestexample.so"],
+ ),
+ deps = [
+ ":proto",
+ "//tensorflow/contrib/proto/python/kernel_tests:py_test_deps",
+ ],
+)
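
The if_static(a, otherwise = b) macro used here and throughout the new BUILD files comes from build_config_root.bzl: it yields a when TensorFlow links its kernels statically and b otherwise, so the test-only libtestexample.so is shipped and loaded only in non-static builds. A Python-syntax sketch of the selection semantics; the real macro returns a Bazel select(), so this is illustration only:

    def if_static(extra_deps, otherwise=None):
        # Stand-in for the build configuration the real select() keys on.
        linking_statically = True
        return extra_deps if linking_statically else (otherwise or [])

    # Static build: nothing extra. Dynamic build: ship the shared object so
    # the tests can load it at runtime (see test_case.py below).
    print(if_static([], otherwise=['libtestexample.so']))
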
diff --git a/tensorflow/contrib/proto/python/kernel_tests/BUILD b/tensorflow/contrib/proto/python/kernel_tests/BUILD
new file mode 100644
index 0000000000..a380a131f8
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/BUILD
@@ -0,0 +1,86 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+# Much of the work in this BUILD file actually happens in the corresponding
+# build_defs.bzl, which creates an individual testcase for each example .pbtxt
+# file in this directory.
+#
+load(":build_defs.bzl", "decode_proto_test_suite")
+load(":build_defs.bzl", "encode_proto_test_suite")
+
+# This expands to a tf_py_test for each test file.
+# It defines the test_suite :decode_proto_op_tests.
+decode_proto_test_suite(
+ name = "decode_proto_tests",
+ examples = glob(["*.pbtxt"]),
+)
+
+# This expands to a tf_py_test for each test file.
+# It defines the test_suite :encode_proto_op_tests.
+encode_proto_test_suite(
+ name = "encode_proto_tests",
+ examples = glob(["*.pbtxt"]),
+)
+
+# Below here are tests that are not tied to an example text proto.
+filegroup(
+ name = "test_messages",
+ srcs = glob(["*.pbtxt"]),
+)
+
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
+load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+
+tf_py_test(
+ name = "decode_proto_fail_test",
+ size = "small",
+ srcs = ["decode_proto_fail_test.py"],
+ additional_deps = [
+ ":py_test_deps",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/proto:proto",
+ "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
+ ],
+ data = if_static(
+ [],
+ otherwise = [":libtestexample.so"],
+ ),
+ tags = [
+ "no_pip", # TODO(b/78026780)
+ "no_windows", # TODO(b/78028010)
+ ],
+)
+
+py_library(
+ name = "test_case",
+ srcs = ["test_case.py"],
+ deps = ["//tensorflow/python:client_testlib"],
+)
+
+py_library(
+ name = "py_test_deps",
+ deps = [
+ ":test_case",
+ ":test_example_proto_py",
+ ],
+)
+
+tf_proto_library(
+ name = "test_example_proto",
+ srcs = ["test_example.proto"],
+ cc_api_version = 2,
+ protodeps = ["//tensorflow/core:protos_all"],
+)
+
+tf_cc_shared_object(
+ name = "libtestexample.so",
+ linkstatic = 1,
+ deps = [
+ ":test_example_proto_cc",
+ ],
+)
diff --git a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl b/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl
new file mode 100644
index 0000000000..f425601691
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl
@@ -0,0 +1,89 @@
+"""BUILD rules for generating file-driven proto test cases.
+
+The decode_proto_test_suite() and encode_proto_test_suite() rules take a list
+of text protos and generate a tf_py_test() for each one.
+"""
+
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+load("//tensorflow:tensorflow.bzl", "register_extension_info")
+load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+
+def _test_name(test, path):
+ return "%s_%s_test" % (test, path.split("/")[-1].split(".")[0])
+
+def decode_proto_test_suite(name, examples):
+ """Build the decode_proto py_test for each test filename."""
+ for test_filename in examples:
+ tf_py_test(
+ name = _test_name("decode_proto", test_filename),
+ srcs = ["decode_proto_op_test.py"],
+ size = "small",
+ data = [test_filename] + if_static(
+ [],
+ otherwise = [":libtestexample.so"],
+ ),
+ main = "decode_proto_op_test.py",
+ args = [
+ "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
+ ],
+ additional_deps = [
+ ":py_test_deps",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/proto:proto",
+ "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
+ ],
+ tags = [
+ "no_pip", # TODO(b/78026780)
+ "no_windows", # TODO(b/78028010)
+ ],
+ )
+ native.test_suite(
+ name = name,
+ tests = [":" + _test_name("decode_proto", test_filename)
+ for test_filename in examples],
+ )
+
+def encode_proto_test_suite(name, examples):
+ """Build the encode_proto py_test for each test filename."""
+ for test_filename in examples:
+ tf_py_test(
+ name = _test_name("encode_proto", test_filename),
+ srcs = ["encode_proto_op_test.py"],
+ size = "small",
+ data = [test_filename] + if_static(
+ [],
+ otherwise = [":libtestexample.so"],
+ ),
+ main = "encode_proto_op_test.py",
+ args = [
+ "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
+ ],
+ additional_deps = [
+ ":py_test_deps",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/proto:proto",
+ "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
+ "//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
+ ],
+ tags = [
+ "no_pip", # TODO(b/78026780)
+ "no_windows", # TODO(b/78028010)
+ ],
+ )
+ native.test_suite(
+ name = name,
+ tests = [":" + _test_name("encode_proto", test_filename)
+ for test_filename in examples],
+ )
+
+register_extension_info(
+ extension_name = "decode_proto_test_suite",
+ label_regex_map = {
+ "deps": "deps:decode_example_.*",
+ })
+
+register_extension_info(
+ extension_name = "encode_proto_test_suite",
+ label_regex_map = {
+ "deps": "deps:encode_example_.*",
+ })
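
Concretely, _test_name() keeps only the basename of the example path up to the first dot, so each .pbtxt file in the directory expands into one decode target and one encode target. A quick plain-Python check of the naming scheme, mirroring the Starlark helper above:

    def _test_name(test, path):
        # Same logic as the Starlark helper above.
        return "%s_%s_test" % (test, path.split("/")[-1].split(".")[0])

    # minmax.TestCase.pbtxt -> decode_proto_minmax_test
    print(_test_name("decode_proto", "minmax.TestCase.pbtxt"))
    # ragged.TestCase.pbtxt -> encode_proto_ragged_test
    print(_test_name("encode_proto", "ragged.TestCase.pbtxt"))
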
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py
new file mode 100644
index 0000000000..5298342ee7
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py
@@ -0,0 +1,68 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# Python3 preparedness imports.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.proto.python.kernel_tests import test_case
+from tensorflow.contrib.proto.python.ops import decode_proto_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class DecodeProtoFailTest(test_case.ProtoOpTestCase):
+ """Test failure cases for DecodeToProto."""
+
+ def _TestCorruptProtobuf(self, sanitize):
+ """Test failure cases for DecodeToProto."""
+
+ # The goal here is to check the error reporting.
+ # Testing against a variety of corrupt protobufs is
+ # done by fuzzing.
+ corrupt_proto = 'This is not a binary protobuf'
+
+ # Numpy silently truncates the strings if you don't specify dtype=object.
+ batch = np.array(corrupt_proto, dtype=object)
+ msg_type = 'tensorflow.contrib.proto.TestCase'
+ field_names = ['sizes']
+ field_types = [dtypes.int32]
+
+ with self.test_session() as sess:
+ ctensor, vtensor = decode_proto_op.decode_proto(
+ batch,
+ message_type=msg_type,
+ field_names=field_names,
+ output_types=field_types,
+ sanitize=sanitize)
+ with self.assertRaisesRegexp(errors.DataLossError,
+ 'Unable to parse binary protobuf'
+ '|Failed to consume entire buffer'):
+ _ = sess.run([ctensor] + vtensor)
+
+ def testCorrupt(self):
+ self._TestCorruptProtobuf(sanitize=False)
+
+ def testSanitizerCorrupt(self):
+ self._TestCorruptProtobuf(sanitize=True)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py
new file mode 100644
index 0000000000..d1c13c82bc
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py
@@ -0,0 +1,300 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Table-driven test for decode_proto op.
+
+This test is run once with each of the *.TestCase.pbtxt files
+in the test directory.
+"""
+# Python3 preparedness imports.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from google.protobuf import text_format
+
+from tensorflow.contrib.proto.python.kernel_tests import test_case
+from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
+from tensorflow.contrib.proto.python.ops import decode_proto_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.platform import flags
+from tensorflow.python.platform import test
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('message_text_file', None,
+ 'A file containing a text serialized TestCase protobuf.')
+
+
+class DecodeProtoOpTest(test_case.ProtoOpTestCase):
+
+ def _compareValues(self, fd, vs, evs):
+ """Compare lists/arrays of field values."""
+
+ if len(vs) != len(evs):
+ self.fail('Field %s decoded %d outputs, expected %d' %
+ (fd.name, len(vs), len(evs)))
+ for i, ev in enumerate(evs):
+ # Special case fuzzy match for float32. TensorFlow seems to mess with
+ # FLT_MAX slightly and the test doesn't work otherwise.
+ # TODO(nix): ask on TF list about why FLT_MAX doesn't pass through.
+ if fd.cpp_type == fd.CPPTYPE_FLOAT:
+ # Numpy isclose() is better than assertIsClose() which uses an absolute
+ # value comparison.
+ self.assertTrue(
+ np.isclose(vs[i], ev), 'expected %r, actual %r' % (ev, vs[i]))
+ elif fd.cpp_type == fd.CPPTYPE_STRING:
+ # In Python3 string tensor values will be represented as bytes, so we
+ # reencode the proto values to match that.
+ self.assertEqual(vs[i], ev.encode('ascii'))
+ else:
+ # Doubles and other types pass through unscathed.
+ self.assertEqual(vs[i], ev)
+
+ def _compareRepeatedPrimitiveValue(self, batch_shape, sizes, fields,
+ field_dict):
+ """Compare protos of type RepeatedPrimitiveValue.
+
+ Args:
+ batch_shape: the shape of the input tensor of serialized messages.
+ sizes: int matrix of repeat counts returned by decode_proto
+ fields: list of test_example_pb2.FieldSpec (types and expected values)
+ field_dict: map from field names to decoded numpy tensors of values
+ """
+
+ # Check that expected values match.
+ for field in fields:
+ values = field_dict[field.name]
+ self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype)
+
+ fd = field.expected.DESCRIPTOR.fields_by_name[field.name]
+
+ # Values has the same shape as the input plus an extra
+ # dimension for repeats.
+ self.assertEqual(list(values.shape)[:-1], batch_shape)
+
+ # Nested messages are represented as TF strings, requiring
+ # some special handling.
+ if field.name == 'message_value':
+ vs = []
+ for buf in values.flat:
+ msg = test_example_pb2.PrimitiveValue()
+ msg.ParseFromString(buf)
+ vs.append(msg)
+ evs = getattr(field.expected, field.name)
+ if len(vs) != len(evs):
+ self.fail('Field %s decoded %d outputs, expected %d' %
+ (fd.name, len(vs), len(evs)))
+ for v, ev in zip(vs, evs):
+ self.assertEqual(v, ev)
+ continue
+
+ # This can be a little confusing. For testing we are using
+ # RepeatedPrimitiveValue in two ways: it's the proto that we
+ # decode for testing, and it's used in the expected value as a
+ # union type. The two cases are slightly different: this is the
+ # second case.
+ # We may be fetching the uint64_value from the test proto, but
+ # in the expected proto we store it in the int64_value field
+ # because TensorFlow doesn't support unsigned int64.
+ tf_type_to_primitive_value_field = {
+ dtypes.float32: 'float_value',
+ dtypes.float64: 'double_value',
+ dtypes.int32: 'int32_value',
+ dtypes.uint8: 'uint8_value',
+ dtypes.int8: 'int8_value',
+ dtypes.string: 'string_value',
+ dtypes.int64: 'int64_value',
+ dtypes.bool: 'bool_value',
+ # Unhandled TensorFlow types:
+ # DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32
+ # DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16
+ }
+ tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
+ if tf_field_name is None:
+ self.fail('Unhandled tensorflow type %d' % field.dtype)
+
+ self._compareValues(fd, values.flat,
+ getattr(field.expected, tf_field_name))
+
+ def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch,
+ message_type, message_format, sanitize,
+ force_disordered=False):
+ """Run decode tests on a batch of messages.
+
+ Args:
+ fields: list of test_example_pb2.FieldSpec (types and expected values)
+ case_sizes: expected sizes array
+ batch_shape: the shape of the input tensor of serialized messages
+ batch: list of serialized messages
+ message_type: descriptor name for messages
+ message_format: format of messages, 'text' or 'binary'
+ sanitize: whether to sanitize binary protobuf inputs
+ force_disordered: whether to force fields encoded out of order.
+ """
+
+ if force_disordered:
+ # Exercise code path that handles out-of-order fields by prepending extra
+ # fields with tag numbers higher than any real field. Note that this won't
+ # work with sanitization because that forces reserialization using a
+ # trusted decoder and encoder.
+ assert not sanitize
+ extra_fields = test_example_pb2.ExtraFields()
+ extra_fields.string_value = 'IGNORE ME'
+ extra_fields.bool_value = False
+ extra_msg = extra_fields.SerializeToString()
+ batch = [extra_msg + msg for msg in batch]
+
+ # Numpy silently truncates the strings if you don't specify dtype=object.
+ batch = np.array(batch, dtype=object)
+ batch = np.reshape(batch, batch_shape)
+
+ field_names = [f.name for f in fields]
+ output_types = [f.dtype for f in fields]
+
+ with self.test_session() as sess:
+ sizes, vtensor = decode_proto_op.decode_proto(
+ batch,
+ message_type=message_type,
+ field_names=field_names,
+ output_types=output_types,
+ message_format=message_format,
+ sanitize=sanitize)
+
+ vlist = sess.run([sizes] + vtensor)
+ sizes = vlist[0]
+ # Values is a list of tensors, one for each field.
+ value_tensors = vlist[1:]
+
+ # Check that the repeat sizes are correct.
+ self.assertTrue(
+ np.all(np.array(sizes.shape) == batch_shape + [len(field_names)]))
+
+ # Check that the decoded sizes match the expected sizes.
+ self.assertEqual(len(sizes.flat), len(case_sizes))
+ self.assertTrue(
+ np.all(sizes.flat == np.array(
+ case_sizes, dtype=np.int32)))
+
+ field_dict = dict(zip(field_names, value_tensors))
+
+ self._compareRepeatedPrimitiveValue(batch_shape, sizes, fields,
+ field_dict)
+
+ def testBinary(self):
+ with open(FLAGS.message_text_file, 'r') as fp:
+ case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
+
+ batch = [primitive.SerializeToString() for primitive in case.primitive]
+ self._runDecodeProtoTests(
+ case.field,
+ case.sizes,
+ list(case.shape),
+ batch,
+ 'tensorflow.contrib.proto.RepeatedPrimitiveValue',
+ 'binary',
+ sanitize=False)
+
+ def testBinaryDisordered(self):
+ with open(FLAGS.message_text_file, 'r') as fp:
+ case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
+
+ batch = [primitive.SerializeToString() for primitive in case.primitive]
+ self._runDecodeProtoTests(
+ case.field,
+ case.sizes,
+ list(case.shape),
+ batch,
+ 'tensorflow.contrib.proto.RepeatedPrimitiveValue',
+ 'binary',
+ sanitize=False,
+ force_disordered=True)
+
+ def testPacked(self):
+ with open(FLAGS.message_text_file, 'r') as fp:
+ case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
+
+ # Now try with the packed serialization.
+ # We test the packed representations by loading the same test cases
+ # using PackedPrimitiveValue instead of RepeatedPrimitiveValue.
+ # To do this we rely on the text format being the same for packed and
+ # unpacked fields, and reparse the test message using the packed version
+ # of the proto.
+ packed_batch = [
+ # Note: float_format='.17g' is necessary to ensure preservation of
+ # doubles and floats in text format.
+ text_format.Parse(
+ text_format.MessageToString(
+ primitive, float_format='.17g'),
+ test_example_pb2.PackedPrimitiveValue()).SerializeToString()
+ for primitive in case.primitive
+ ]
+
+ self._runDecodeProtoTests(
+ case.field,
+ case.sizes,
+ list(case.shape),
+ packed_batch,
+ 'tensorflow.contrib.proto.PackedPrimitiveValue',
+ 'binary',
+ sanitize=False)
+
+ def testText(self):
+ with open(FLAGS.message_text_file, 'r') as fp:
+ case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
+
+ # Note: float_format='.17g' is necessary to ensure preservation of
+ # doubles and floats in text format.
+ text_batch = [
+ text_format.MessageToString(
+ primitive, float_format='.17g') for primitive in case.primitive
+ ]
+
+ self._runDecodeProtoTests(
+ case.field,
+ case.sizes,
+ list(case.shape),
+ text_batch,
+ 'tensorflow.contrib.proto.RepeatedPrimitiveValue',
+ 'text',
+ sanitize=False)
+
+ def testSanitizerGood(self):
+ with open(FLAGS.message_text_file, 'r') as fp:
+ case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
+
+ batch = [primitive.SerializeToString() for primitive in case.primitive]
+ self._runDecodeProtoTests(
+ case.field,
+ case.sizes,
+ list(case.shape),
+ batch,
+ 'tensorflow.contrib.proto.RepeatedPrimitiveValue',
+ 'binary',
+ sanitize=True)
+
+
+if __name__ == '__main__':
+ test.main()
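
As the tests above exercise it, decode_proto takes a string tensor of serialized messages plus parallel field_names/output_types lists, and returns a sizes tensor of shape batch_shape + [len(field_names)] together with one value tensor per field, each carrying a trailing repeat dimension padded with defaults. A minimal sketch of a call, assuming the generated test_example_pb2 module and a session to evaluate the resulting tensors:

    import numpy as np

    from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
    from tensorflow.contrib.proto.python.ops import decode_proto_op
    from tensorflow.python.framework import dtypes

    msg = test_example_pb2.RepeatedPrimitiveValue(
        double_value=[23.5], bool_value=[True])
    # dtype=object keeps numpy from truncating the serialized strings.
    batch = np.array([msg.SerializeToString()], dtype=object)
    sizes, values = decode_proto_op.decode_proto(
        batch,
        message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
        field_names=['double_value', 'bool_value'],
        output_types=[dtypes.float64, dtypes.bool])
    # sizes has shape [1, 2]; values is a list of two tensors, each shaped
    # [1, max_repeat_count] for its field.
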
diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py
new file mode 100644
index 0000000000..30e58e6336
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py
@@ -0,0 +1,180 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Table-driven test for encode_proto op.
+
+This test is run once with each of the *.TestCase.pbtxt files
+in the test directory.
+
+It tests that encode_proto is a lossless inverse of decode_proto
+(for the specified fields).
+"""
+# Python3 readiness boilerplate
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from google.protobuf import text_format
+
+from tensorflow.contrib.proto.python.kernel_tests import test_case
+from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
+from tensorflow.contrib.proto.python.ops import decode_proto_op
+from tensorflow.contrib.proto.python.ops import encode_proto_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import flags
+from tensorflow.python.platform import test
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('message_text_file', None,
+ 'A file containing a text serialized TestCase protobuf.')
+
+
+class EncodeProtoOpTest(test_case.ProtoOpTestCase):
+
+ def testBadInputs(self):
+ # Invalid field name
+ with self.test_session():
+ with self.assertRaisesOpError('Unknown field: non_existent_field'):
+ encode_proto_op.encode_proto(
+ sizes=[[1]],
+ values=[np.array([[0.0]], dtype=np.int32)],
+ message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
+ field_names=['non_existent_field']).eval()
+
+ # Incorrect types.
+ with self.test_session():
+ with self.assertRaisesOpError(
+ 'Incompatible type for field double_value.'):
+ encode_proto_op.encode_proto(
+ sizes=[[1]],
+ values=[np.array([[0.0]], dtype=np.int32)],
+ message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
+ field_names=['double_value']).eval()
+
+ # Incorrect shapes of sizes.
+ with self.test_session():
+ with self.assertRaisesOpError(
+ r'sizes should be batch_size \+ \[len\(field_names\)\]'):
+ sizes = array_ops.placeholder(dtypes.int32)
+ values = array_ops.placeholder(dtypes.float64)
+ encode_proto_op.encode_proto(
+ sizes=sizes,
+ values=[values],
+ message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
+ field_names=['double_value']).eval(feed_dict={
+ sizes: [[[0, 0]]],
+ values: [[0.0]]
+ })
+
+ # Inconsistent shapes of values.
+ with self.test_session():
+ with self.assertRaisesOpError(
+ 'Values must match up to the last dimension'):
+ sizes = array_ops.placeholder(dtypes.int32)
+ values1 = array_ops.placeholder(dtypes.float64)
+ values2 = array_ops.placeholder(dtypes.int32)
+ (encode_proto_op.encode_proto(
+ sizes=[[1, 1]],
+ values=[values1, values2],
+ message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
+ field_names=['double_value', 'int32_value']).eval(feed_dict={
+ values1: [[0.0]],
+ values2: [[0], [0]]
+ }))
+
+ def _testRoundtrip(self, in_bufs, message_type, fields):
+
+ field_names = [f.name for f in fields]
+ out_types = [f.dtype for f in fields]
+
+ with self.test_session() as sess:
+ sizes, field_tensors = decode_proto_op.decode_proto(
+ in_bufs,
+ message_type=message_type,
+ field_names=field_names,
+ output_types=out_types)
+
+ out_tensors = encode_proto_op.encode_proto(
+ sizes,
+ field_tensors,
+ message_type=message_type,
+ field_names=field_names)
+
+ out_bufs, = sess.run([out_tensors])
+
+ # Check that the re-encoded tensor has the same shape.
+ self.assertEqual(in_bufs.shape, out_bufs.shape)
+
+ # Compare the input and output.
+ for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat):
+ in_obj = test_example_pb2.RepeatedPrimitiveValue()
+ in_obj.ParseFromString(in_buf)
+
+ out_obj = test_example_pb2.RepeatedPrimitiveValue()
+ out_obj.ParseFromString(out_buf)
+
+ # Check that the deserialized objects are identical.
+ self.assertEqual(in_obj, out_obj)
+
+ # Check that the input and output serialized messages are identical.
+ # If we fail here, there is a difference in the serialized
+ # representation but the new serialization still parses. This could
+ # be harmless (a change in map ordering?) or it could be bad (e.g.
+ # loss of packing in the encoding).
+ self.assertEqual(in_buf, out_buf)
+
+ def testRoundtrip(self):
+ with open(FLAGS.message_text_file, 'r') as fp:
+ case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
+
+ in_bufs = [primitive.SerializeToString() for primitive in case.primitive]
+
+ # np.array silently truncates strings if you don't specify dtype=object.
+ in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shape))
+ return self._testRoundtrip(
+ in_bufs, 'tensorflow.contrib.proto.RepeatedPrimitiveValue', case.field)
+
+ def testRoundtripPacked(self):
+ with open(FLAGS.message_text_file, 'r') as fp:
+ case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
+
+ # Now try with the packed serialization.
+ # We test the packed representations by loading the same test cases
+ # using PackedPrimitiveValue instead of RepeatedPrimitiveValue.
+ # To do this we rely on the text format being the same for packed and
+ # unpacked fields, and reparse the test message using the packed version
+ # of the proto.
+ in_bufs = [
+ # Note: float_format='.17g' is necessary to ensure preservation of
+ # doubles and floats in text format.
+ text_format.Parse(
+ text_format.MessageToString(
+ primitive, float_format='.17g'),
+ test_example_pb2.PackedPrimitiveValue()).SerializeToString()
+ for primitive in case.primitive
+ ]
+
+ # np.array silently truncates strings if you don't specify dtype=object.
+ in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shape))
+ return self._testRoundtrip(
+ in_bufs, 'tensorflow.contrib.proto.PackedPrimitiveValue', case.field)
+
+
+if __name__ == '__main__':
+ test.main()
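
encode_proto is the inverse mapping: given the sizes matrix and one value tensor per field, it reserializes a message per batch element. A minimal sketch reusing the shapes from testBadInputs above, this time with a dtype that matches the field; the field and message names come from the .proto definition later in this change:

    import numpy as np

    from tensorflow.contrib.proto.python.ops import encode_proto_op

    # One message in the batch, one double_value entry for it.
    serialized = encode_proto_op.encode_proto(
        sizes=[[1]],
        values=[np.array([[23.5]], dtype=np.float64)],
        message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
        field_names=['double_value'])
    # serialized is a DT_STRING tensor of shape [1] holding the binary proto.
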
diff --git a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt
new file mode 100644
index 0000000000..b170f89c0f
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt
@@ -0,0 +1,161 @@
+primitive {
+ double_value: -1.7976931348623158e+308
+ double_value: 2.2250738585072014e-308
+ double_value: 1.7976931348623158e+308
+ float_value: -3.402823466e+38
+ float_value: 1.175494351e-38
+ float_value: 3.402823466e+38
+ int64_value: -9223372036854775808
+ int64_value: 9223372036854775807
+ uint64_value: 0
+ uint64_value: 18446744073709551615
+ int32_value: -2147483648
+ int32_value: 2147483647
+ fixed64_value: 0
+ fixed64_value: 18446744073709551615
+ fixed32_value: 0
+ fixed32_value: 4294967295
+ bool_value: false
+ bool_value: true
+ string_value: ""
+ string_value: "I refer to the infinite."
+ uint32_value: 0
+ uint32_value: 4294967295
+ sfixed32_value: -2147483648
+ sfixed32_value: 2147483647
+ sfixed64_value: -9223372036854775808
+ sfixed64_value: 9223372036854775807
+ sint32_value: -2147483648
+ sint32_value: 2147483647
+ sint64_value: -9223372036854775808
+ sint64_value: 9223372036854775807
+}
+shape: 1
+sizes: 3
+sizes: 3
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+field {
+ name: "double_value"
+ dtype: DT_DOUBLE
+ expected {
+ double_value: -1.7976931348623158e+308
+ double_value: 2.2250738585072014e-308
+ double_value: 1.7976931348623158e+308
+ }
+}
+field {
+ name: "float_value"
+ dtype: DT_FLOAT
+ expected {
+ float_value: -3.402823466e+38
+ float_value: 1.175494351e-38
+ float_value: 3.402823466e+38
+ }
+}
+field {
+ name: "int64_value"
+ dtype: DT_INT64
+ expected {
+ int64_value: -9223372036854775808
+ int64_value: 9223372036854775807
+ }
+}
+field {
+ name: "uint64_value"
+ dtype: DT_INT64
+ expected {
+ int64_value: 0
+ int64_value: -1
+ }
+}
+field {
+ name: "int32_value"
+ dtype: DT_INT32
+ expected {
+ int32_value: -2147483648
+ int32_value: 2147483647
+ }
+}
+field {
+ name: "fixed64_value"
+ dtype: DT_INT64
+ expected {
+ int64_value: 0
+ int64_value: -1 # unsigned is 18446744073709551615
+ }
+}
+field {
+ name: "fixed32_value"
+ dtype: DT_INT32
+ expected {
+ int32_value: 0
+ int32_value: -1 # unsigned is 4294967295
+ }
+}
+field {
+ name: "bool_value"
+ dtype: DT_BOOL
+ expected {
+ bool_value: false
+ bool_value: true
+ }
+}
+field {
+ name: "string_value"
+ dtype: DT_STRING
+ expected {
+ string_value: ""
+ string_value: "I refer to the infinite."
+ }
+}
+field {
+ name: "uint32_value"
+ dtype: DT_INT32
+ expected {
+ int32_value: 0
+ int32_value: -1 # unsigned is 4294967295
+ }
+}
+field {
+ name: "sfixed32_value"
+ dtype: DT_INT32
+ expected {
+ int32_value: -2147483648
+ int32_value: 2147483647
+ }
+}
+field {
+ name: "sfixed64_value"
+ dtype: DT_INT64
+ expected {
+ int64_value: -9223372036854775808
+ int64_value: 9223372036854775807
+ }
+}
+field {
+ name: "sint32_value"
+ dtype: DT_INT32
+ expected {
+ int32_value: -2147483648
+ int32_value: 2147483647
+ }
+}
+field {
+ name: "sint64_value"
+ dtype: DT_INT64
+ expected {
+ int64_value: -9223372036854775808
+ int64_value: 9223372036854775807
+ }
+}
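
The "# unsigned is ..." comments above record the promotion convention: because TensorFlow has no unsigned 64-bit dtype, unsigned proto fields come back reinterpreted in the same-width signed type, so the unsigned maxima appear as -1. The two's-complement identity behind those expected values, as a quick numpy check:

    import numpy as np

    # uint64 max reinterpreted as int64 is -1, matching int64_value: -1 above.
    assert np.uint64(18446744073709551615).astype(np.int64) == -1
    # Likewise uint32 max viewed as int32, matching int32_value: -1.
    assert np.uint32(4294967295).astype(np.int32) == -1
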
diff --git a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt
new file mode 100644
index 0000000000..c664e52851
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt
@@ -0,0 +1,16 @@
+primitive {
+ message_value {
+ double_value: 23.5
+ }
+}
+shape: 1
+sizes: 1
+field {
+ name: "message_value"
+ dtype: DT_STRING
+ expected {
+ message_value {
+ double_value: 23.5
+ }
+ }
+}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt
new file mode 100644
index 0000000000..125651d7ea
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt
@@ -0,0 +1,20 @@
+primitive {
+ bool_value: true
+}
+shape: 1
+sizes: 1
+sizes: 0
+field {
+ name: "bool_value"
+ dtype: DT_BOOL
+ expected {
+ bool_value: true
+ }
+}
+field {
+ name: "double_value"
+ dtype: DT_DOUBLE
+ expected {
+ double_value: 0.0
+ }
+}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt
new file mode 100644
index 0000000000..db7555bf2d
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt
@@ -0,0 +1,21 @@
+primitive {
+ fixed32_value: 4294967295
+ uint32_value: 4294967295
+}
+shape: 1
+sizes: 1
+sizes: 1
+field {
+ name: "fixed32_value"
+ dtype: DT_INT64
+ expected {
+ int64_value: 4294967295
+ }
+}
+field {
+ name: "uint32_value"
+ dtype: DT_INT64
+ expected {
+ int64_value: 4294967295
+ }
+}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt
new file mode 100644
index 0000000000..61c7ac53f7
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt
@@ -0,0 +1,32 @@
+primitive {
+ double_value: 23.5
+ double_value: 123.0
+ bool_value: true
+}
+primitive {
+ double_value: 3.1
+ bool_value: false
+}
+shape: 2
+sizes: 2
+sizes: 1
+sizes: 1
+sizes: 1
+field {
+ name: "double_value"
+ dtype: DT_DOUBLE
+ expected {
+ double_value: 23.5
+ double_value: 123.0
+ double_value: 3.1
+ double_value: 0.0
+ }
+}
+field {
+ name: "bool_value"
+ dtype: DT_BOOL
+ expected {
+ bool_value: true
+ bool_value: false
+ }
+}
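
This case pins down the ragged-batch contract: sizes records the true per-message repeat counts, while each field's value tensor is padded out to the longest count in the batch with the proto default. For double_value above, the decoded dense tensor would therefore be:

    import numpy as np

    # Message 0 contributed two doubles, message 1 only one, so the second
    # row is padded with the default 0.0 (the fourth expected value above).
    double_values = np.array([[23.5, 123.0],
                              [3.1, 0.0]], dtype=np.float64)
    sizes_for_double_value = np.array([2, 1], dtype=np.int32)
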
diff --git a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt
new file mode 100644
index 0000000000..f4828076d5
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt
@@ -0,0 +1,62 @@
+primitive {
+ double_value: 23.5
+ bool_value: true
+}
+primitive {
+ double_value: 44.0
+ bool_value: false
+}
+primitive {
+ double_value: 3.14159
+ bool_value: true
+}
+primitive {
+ double_value: 1.414
+ bool_value: true
+}
+primitive {
+ double_value: -32.2
+ bool_value: false
+}
+primitive {
+ double_value: 0.0001
+ bool_value: true
+}
+shape: 3
+shape: 2
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+field {
+ name: "double_value"
+ dtype: DT_DOUBLE
+ expected {
+ double_value: 23.5
+ double_value: 44.0
+ double_value: 3.14159
+ double_value: 1.414
+ double_value: -32.2
+ double_value: 0.0001
+ }
+}
+field {
+ name: "bool_value"
+ dtype: DT_BOOL
+ expected {
+ bool_value: true
+ bool_value: false
+ bool_value: true
+ bool_value: true
+ bool_value: false
+ bool_value: true
+ }
+}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt
new file mode 100644
index 0000000000..dc20ac147b
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt
@@ -0,0 +1,21 @@
+primitive {
+ double_value: 23.5
+ bool_value: true
+}
+shape: 1
+sizes: 1
+sizes: 1
+field {
+ name: "double_value"
+ dtype: DT_DOUBLE
+ expected {
+ double_value: 23.5
+ }
+}
+field {
+ name: "bool_value"
+ dtype: DT_BOOL
+ expected {
+ bool_value: true
+ }
+}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_case.py b/tensorflow/contrib/proto/python/kernel_tests/test_case.py
new file mode 100644
index 0000000000..b95202c5df
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/test_case.py
@@ -0,0 +1,35 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Test case base for testing proto operations."""
+
+# Python3 preparedness imports.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ctypes as ct
+import os
+
+from tensorflow.python.platform import test
+
+
+class ProtoOpTestCase(test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ super(ProtoOpTestCase, self).__init__(methodName)
+ lib = os.path.join(os.path.dirname(__file__), 'libtestexample.so')
+ if os.path.isfile(lib):
+ ct.cdll.LoadLibrary(lib)
diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto
new file mode 100644
index 0000000000..dc495034ff
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto
@@ -0,0 +1,149 @@
+// Test description and protos to work with it.
+//
+// Many of the protos in this file are for unit tests that haven't been written yet.
+
+syntax = "proto2";
+
+import "tensorflow/core/framework/types.proto";
+
+package tensorflow.contrib.proto;
+
+// A TestCase holds a proto and a bunch of assertions
+// about how it should decode.
+message TestCase {
+ // A batch of primitives to be serialized and decoded.
+ repeated RepeatedPrimitiveValue primitive = 1;
+ // The shape of the batch.
+ repeated int32 shape = 2;
+ // Expected sizes for each field.
+ repeated int32 sizes = 3;
+ // Expected values for each field.
+ repeated FieldSpec field = 4;
+};
+
+// FieldSpec describes the expected output for a single field.
+message FieldSpec {
+ optional string name = 1;
+ optional tensorflow.DataType dtype = 2;
+ optional RepeatedPrimitiveValue expected = 3;
+};
+
+message TestValue {
+ optional PrimitiveValue primitive_value = 1;
+ optional EnumValue enum_value = 2;
+ optional MessageValue message_value = 3;
+ optional RepeatedMessageValue repeated_message_value = 4;
+ optional RepeatedPrimitiveValue repeated_primitive_value = 6;
+}
+
+message PrimitiveValue {
+ optional double double_value = 1;
+ optional float float_value = 2;
+ optional int64 int64_value = 3;
+ optional uint64 uint64_value = 4;
+ optional int32 int32_value = 5;
+ optional fixed64 fixed64_value = 6;
+ optional fixed32 fixed32_value = 7;
+ optional bool bool_value = 8;
+ optional string string_value = 9;
+ optional bytes bytes_value = 12;
+ optional uint32 uint32_value = 13;
+ optional sfixed32 sfixed32_value = 15;
+ optional sfixed64 sfixed64_value = 16;
+ optional sint32 sint32_value = 17;
+ optional sint64 sint64_value = 18;
+}
+
+// NOTE: This definition must be kept in sync with PackedPrimitiveValue.
+message RepeatedPrimitiveValue {
+ repeated double double_value = 1;
+ repeated float float_value = 2;
+ repeated int64 int64_value = 3;
+ repeated uint64 uint64_value = 4;
+ repeated int32 int32_value = 5;
+ repeated fixed64 fixed64_value = 6;
+ repeated fixed32 fixed32_value = 7;
+ repeated bool bool_value = 8;
+ repeated string string_value = 9;
+ repeated bytes bytes_value = 12;
+ repeated uint32 uint32_value = 13;
+ repeated sfixed32 sfixed32_value = 15;
+ repeated sfixed64 sfixed64_value = 16;
+ repeated sint32 sint32_value = 17;
+ repeated sint64 sint64_value = 18;
+ repeated PrimitiveValue message_value = 19;
+}
+
+// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue
+// in the text format, but the binary serialization is different.
+// We test the packed representations by loading the same test cases
+// using this definition instead of RepeatedPrimitiveValue.
+// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue
+// in every way except the packed=true declaration.
+message PackedPrimitiveValue {
+ repeated double double_value = 1 [packed = true];
+ repeated float float_value = 2 [packed = true];
+ repeated int64 int64_value = 3 [packed = true];
+ repeated uint64 uint64_value = 4 [packed = true];
+ repeated int32 int32_value = 5 [packed = true];
+ repeated fixed64 fixed64_value = 6 [packed = true];
+ repeated fixed32 fixed32_value = 7 [packed = true];
+ repeated bool bool_value = 8 [packed = true];
+ repeated string string_value = 9;
+ repeated bytes bytes_value = 12;
+ repeated uint32 uint32_value = 13 [packed = true];
+ repeated sfixed32 sfixed32_value = 15 [packed = true];
+ repeated sfixed64 sfixed64_value = 16 [packed = true];
+ repeated sint32 sint32_value = 17 [packed = true];
+ repeated sint64 sint64_value = 18 [packed = true];
+ repeated PrimitiveValue message_value = 19;
+}
+
+message EnumValue {
+ enum Color {
+ RED = 0;
+ ORANGE = 1;
+ YELLOW = 2;
+ GREEN = 3;
+ BLUE = 4;
+ INDIGO = 5;
+ VIOLET = 6;
+ };
+ optional Color enum_value = 14;
+ repeated Color repeated_enum_value = 15;
+}
+
+message InnerMessageValue {
+ optional float float_value = 2;
+ repeated bytes bytes_values = 8;
+}
+
+message MiddleMessageValue {
+ repeated int32 int32_values = 5;
+ optional InnerMessageValue message_value = 11;
+ optional uint32 uint32_value = 13;
+}
+
+message MessageValue {
+ optional double double_value = 1;
+ optional MiddleMessageValue message_value = 11;
+}
+
+message RepeatedMessageValue {
+ message NestedMessageValue {
+ optional float float_value = 2;
+ repeated bytes bytes_values = 8;
+ }
+
+ repeated NestedMessageValue message_values = 11;
+}
+
+// Message containing fields with field numbers higher than any field above. An
+// instance of this message is prepended to each binary message in the test to
+// exercise the code path that handles fields encoded out of order of field
+// number.
+message ExtraFields {
+ optional string string_value = 1776;
+ optional bool bool_value = 1777;
+}
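
The packed/unpacked pairing above can be checked directly: the two message types print identical text format but serialize to different wire bytes, which is exactly the property testPacked relies on. A sketch, assuming the generated test_example_pb2 module from this package:

    from google.protobuf import text_format

    from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2

    repeated = test_example_pb2.RepeatedPrimitiveValue(int32_value=[1, 2, 3])
    packed = text_format.Parse(
        text_format.MessageToString(repeated),
        test_example_pb2.PackedPrimitiveValue())

    # Same text format; different wire format. Unpacked emits one tag per
    # element, packed emits a single length-delimited run of varints.
    assert (text_format.MessageToString(repeated) ==
            text_format.MessageToString(packed))
    assert repeated.SerializeToString() != packed.SerializeToString()
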
diff --git a/tensorflow/contrib/rpc/BUILD b/tensorflow/contrib/rpc/BUILD
index 597f18c771..dbd311a276 100644
--- a/tensorflow/contrib/rpc/BUILD
+++ b/tensorflow/contrib/rpc/BUILD
@@ -4,6 +4,8 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
+load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+
py_library(
name = "rpc",
srcs = [
@@ -11,3 +13,17 @@ py_library(
],
deps = ["//tensorflow/contrib/rpc/python/ops:rpc_op_py"],
)
+
+py_library(
+ name = "rpc_pip",
+ data = if_static(
+ [],
+ otherwise = ["//tensorflow/contrib/rpc/python/kernel_tests:libtestexample.so"],
+ ),
+ deps = [
+ ":rpc",
+ "//tensorflow/contrib/rpc/python/kernel_tests:py_test_deps",
+ "//tensorflow/contrib/rpc/python/kernel_tests:rpc_op_test_base",
+ "//tensorflow/contrib/rpc/python/kernel_tests:rpc_op_test_servicer",
+ ],
+)
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/BUILD b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
new file mode 100644
index 0000000000..2311c15a68
--- /dev/null
+++ b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
@@ -0,0 +1,80 @@
+# TODO(b/76425722): Port everything in here to OS (currently excluded).
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
+load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+# Placeholder for loading internal BUILD rule.
+
+tf_proto_library(
+ name = "test_example_proto",
+ srcs = ["test_example.proto"],
+ has_services = 1,
+ cc_api_version = 2,
+ protodeps = ["//tensorflow/core:protos_all"],
+)
+
+py_library(
+ name = "py_test_deps",
+ deps = [":test_example_proto_py"],
+)
+
+py_library(
+ name = "rpc_op_test_base",
+ srcs = ["rpc_op_test_base.py"],
+ deps = [
+ ":test_example_proto_py",
+ "//tensorflow/contrib/proto",
+ "//tensorflow/contrib/rpc",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "rpc_op_test_servicer",
+ srcs = ["rpc_op_test_servicer.py"],
+ deps = [
+ ":py_test_deps",
+ ":rpc_op_test_base",
+ "//tensorflow/core:protos_all_py",
+ "//third_party/py/numpy",
+ ],
+)
+
+tf_cc_shared_object(
+ name = "libtestexample.so",
+ linkstatic = 1,
+ deps = [
+ ":test_example_proto_cc",
+ ],
+)
+
+tf_py_test(
+ name = "rpc_op_test",
+ size = "small",
+ srcs = ["rpc_op_test.py"],
+ additional_deps = [
+ ":py_test_deps",
+ ":rpc_op_test_base",
+ ":rpc_op_test_servicer",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ ],
+ data = if_static(
+ [],
+ otherwise = [":libtestexample.so"],
+ ),
+ tags = [
+ "no_pip", # TODO(b/78026780)
+ "no_windows", # TODO(b/78028010)
+ ],
+)
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py
new file mode 100644
index 0000000000..e2e0dbc7a2
--- /dev/null
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py
@@ -0,0 +1,71 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Tests for RpcOp."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ctypes as ct
+import os
+
+import grpc
+from grpc.framework.foundation import logging_pool
+import portpicker
+
+from tensorflow.contrib.rpc.python.kernel_tests import rpc_op_test_base
+from tensorflow.contrib.rpc.python.kernel_tests import rpc_op_test_servicer
+from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2_grpc
+from tensorflow.python.platform import test
+
+
+class RpcOpTest(test.TestCase, rpc_op_test_base.RpcOpTestBase):
+ _protocol = 'grpc'
+
+ invalid_method_string = 'Method not found'
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ super(RpcOpTest, self).__init__(methodName)
+ lib = os.path.join(os.path.dirname(__file__), 'libtestexample.so')
+ if os.path.isfile(lib):
+ ct.cdll.LoadLibrary(lib)
+
+ def get_method_name(self, suffix):
+ return '/tensorflow.contrib.rpc.TestCaseService/%s' % suffix
+
+ def setUp(self):
+ super(RpcOpTest, self).setUp()
+
+ service_port = portpicker.pick_unused_port()
+
+ server = grpc.server(logging_pool.pool(max_workers=25))
+ servicer = rpc_op_test_servicer.RpcOpTestServicer()
+ test_example_pb2_grpc.add_TestCaseServiceServicer_to_server(
+ servicer, server)
+ self._address = 'localhost:%d' % service_port
+ server.add_insecure_port(self._address)
+ server.start()
+ self._server = server
+
+ def tearDown(self):
+ # TODO(ebrevdo): Figure out why this sometimes times out.
+ # self._service.ExitLoop()
+ # self._service_thread.join()
+ # self._server.stop()
+ super(RpcOpTest, self).tearDown()
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
new file mode 100644
index 0000000000..89f3ee1a1c
--- /dev/null
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
@@ -0,0 +1,336 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Base class for RpcOp tests."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+
+import numpy as np
+
+from tensorflow.contrib.proto.python.ops import decode_proto_op
+from tensorflow.contrib.proto.python.ops import encode_proto_op
+from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2
+from tensorflow.contrib.rpc.python.ops import rpc_op
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+
+__all__ = ['I_WARNED_YOU', 'RpcOpTestBase']
+
+I_WARNED_YOU = 'I warned you!'
+
+
+class RpcOpTestBase(object):
+ # pylint: disable=missing-docstring,invalid-name
+ """Base class for RpcOp tests."""
+
+ def get_method_name(self, suffix):
+ raise NotImplementedError
+
+ def rpc(self, *args, **kwargs):
+ return rpc_op.rpc(*args, protocol=self._protocol, **kwargs)
+
+ def try_rpc(self, *args, **kwargs):
+ return rpc_op.try_rpc(*args, protocol=self._protocol, **kwargs)
+
+ def testScalarHostPortRpc(self):
+ with self.test_session() as sess:
+ request_tensors = (
+ test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString())
+ response_tensors = self.rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address=self._address,
+ request=request_tensors)
+ self.assertEqual(response_tensors.shape, ())
+ response_values = sess.run(response_tensors)
+ response_message = test_example_pb2.TestCase()
+ self.assertTrue(response_message.ParseFromString(response_values))
+ self.assertAllEqual([2, 3, 4], response_message.shape)
+
+ def testScalarHostPortTryRpc(self):
+ with self.test_session() as sess:
+ request_tensors = (
+ test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString())
+ response_tensors, status_code, status_message = self.try_rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address=self._address,
+ request=request_tensors)
+ self.assertEqual(status_code.shape, ())
+ self.assertEqual(status_message.shape, ())
+ self.assertEqual(response_tensors.shape, ())
+ response_values, status_code_values, status_message_values = (
+ sess.run((response_tensors, status_code, status_message)))
+ response_message = test_example_pb2.TestCase()
+ self.assertTrue(response_message.ParseFromString(response_values))
+ self.assertAllEqual([2, 3, 4], response_message.shape)
+ # For the base Rpc op, don't expect to get error status back.
+ self.assertEqual(errors.OK, status_code_values)
+ self.assertEqual(b'', status_message_values)
+
+ def testEmptyHostPortRpc(self):
+ with self.test_session() as sess:
+ request_tensors = []
+ response_tensors = self.rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address=self._address,
+ request=request_tensors)
+ self.assertAllEqual(response_tensors.shape, [0])
+ response_values = sess.run(response_tensors)
+ self.assertAllEqual(response_values.shape, [0])
+
+ def testInvalidAddresses(self):
+ with self.test_session() as sess:
+ with self.assertRaisesOpError(self.invalid_method_string):
+ sess.run(
+ self.rpc(
+ method='/InvalidService.IncrementTestShapes',
+ address=self._address,
+ request=''))
+
+ with self.assertRaisesOpError(self.invalid_method_string):
+ sess.run(
+ self.rpc(
+ method=self.get_method_name('InvalidMethodName'),
+ address=self._address,
+ request=''))
+
+ # This also covers the case of address=''
+ # and address='localhost:293874293874'
+ with self.assertRaises(errors.UnavailableError):
+ sess.run(
+ self.rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address='unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@',
+ request=''))
+
+ # Test invalid method with the TryRpc op
+ _, status_code_value, status_message_value = sess.run(
+ self.try_rpc(
+ method=self.get_method_name('InvalidMethodName'),
+ address=self._address,
+ request=''))
+ self.assertEqual(errors.UNIMPLEMENTED, status_code_value)
+      self.assertIn(self.invalid_method_string,
+                    status_message_value.decode('ascii'))
+
+ def testAlwaysFailingMethod(self):
+ with self.test_session() as sess:
+ response_tensors = self.rpc(
+ method=self.get_method_name('AlwaysFailWithInvalidArgument'),
+ address=self._address,
+ request='')
+ self.assertEqual(response_tensors.shape, ())
+ with self.assertRaisesOpError(I_WARNED_YOU):
+ sess.run(response_tensors)
+
+ def testSometimesFailingMethodWithManyRequests(self):
+ with self.test_session() as sess:
+ # Fail hard by default.
+ response_tensors = self.rpc(
+ method=self.get_method_name('SometimesFailWithInvalidArgument'),
+ address=self._address,
+ request=[''] * 20)
+ self.assertEqual(response_tensors.shape, (20,))
+ with self.assertRaisesOpError(I_WARNED_YOU):
+ sess.run(response_tensors)
+
+ # Don't fail hard, use TryRpc - return the failing status instead.
+ response_tensors, status_code, status_message = self.try_rpc(
+ method=self.get_method_name('SometimesFailWithInvalidArgument'),
+ address=self._address,
+ request=[''] * 20)
+ self.assertEqual(response_tensors.shape, (20,))
+ self.assertEqual(status_code.shape, (20,))
+ self.assertEqual(status_message.shape, (20,))
+ status_code_values, status_message_values = sess.run((status_code,
+ status_message))
+      self.assertTrue(
+          all(x in (errors.OK, errors.INVALID_ARGUMENT)
+              for x in status_code_values))
+ expected_message_values = np.where(
+ status_code_values == errors.INVALID_ARGUMENT,
+ I_WARNED_YOU.encode('ascii'), b'')
+ self.assertAllEqual(expected_message_values, status_message_values)
+
+ def testVecHostPortRpc(self):
+ with self.test_session() as sess:
+ request_tensors = [
+ test_example_pb2.TestCase(
+ shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
+ ]
+ response_tensors = self.rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address=self._address,
+ request=request_tensors)
+ self.assertEqual(response_tensors.shape, (20,))
+ response_values = sess.run(response_tensors)
+ self.assertEqual(response_values.shape, (20,))
+ for i in range(20):
+ response_message = test_example_pb2.TestCase()
+ self.assertTrue(response_message.ParseFromString(response_values[i]))
+ self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
+
+ def testVecHostPortManyParallelRpcs(self):
+ with self.test_session() as sess:
+ request_tensors = [
+ test_example_pb2.TestCase(
+ shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
+ ]
+ many_response_tensors = [
+ self.rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address=self._address,
+ request=request_tensors) for _ in range(10)
+ ]
+      # Launch 10 parallel calls to the RpcOp, each carrying a batch of
+      # 20 RPC requests.
+ many_response_values = sess.run(many_response_tensors)
+ self.assertEqual(10, len(many_response_values))
+ for response_values in many_response_values:
+ self.assertEqual(response_values.shape, (20,))
+ for i in range(20):
+ response_message = test_example_pb2.TestCase()
+ self.assertTrue(response_message.ParseFromString(response_values[i]))
+ self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
+
+ def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
+ with self.test_session() as sess:
+ request_tensors = encode_proto_op.encode_proto(
+ message_type='tensorflow.contrib.rpc.TestCase',
+ field_names=['shape'],
+ sizes=[[3]] * 20,
+ values=[
+ [[i, i + 1, i + 2] for i in range(20)],
+ ])
+ response_tensor_strings = self.rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address=self._address,
+ request=request_tensors)
+ _, (response_shape,) = decode_proto_op.decode_proto(
+ bytes=response_tensor_strings,
+ message_type='tensorflow.contrib.rpc.TestCase',
+ field_names=['shape'],
+ output_types=[dtypes.int32])
+ response_shape_values = sess.run(response_shape)
+ self.assertAllEqual([[i + 1, i + 2, i + 3]
+ for i in range(20)], response_shape_values)
+
+ def testVecHostPortRpcCancelsUponSessionTimeOutWhenSleepingForever(self):
+ with self.test_session() as sess:
+ request_tensors = [''] * 25 # This will launch 25 RPC requests.
+ response_tensors = self.rpc(
+ method=self.get_method_name('SleepForever'),
+ address=self._address,
+ request=request_tensors)
+ for timeout_ms in [1, 500, 1000]:
+ options = config_pb2.RunOptions(timeout_in_ms=timeout_ms)
+ with self.assertRaises((errors.UnavailableError,
+ errors.DeadlineExceededError)):
+ sess.run(response_tensors, options=options)
+
+ def testVecHostPortRpcCancelsUponConfiguredTimeOutWhenSleepingForever(self):
+ with self.test_session() as sess:
+ request_tensors = [''] * 25 # This will launch 25 RPC requests.
+ response_tensors = self.rpc(
+ method=self.get_method_name('SleepForever'),
+ address=self._address,
+ timeout_in_ms=1000,
+ request=request_tensors)
+ with self.assertRaises(errors.DeadlineExceededError):
+ sess.run(response_tensors)
+
+ def testTryRpcPropagatesDeadlineErrorWithSometimesTimingOutRequests(self):
+ with self.test_session() as sess:
+ response_tensors, status_code, status_message = self.try_rpc(
+ method=self.get_method_name('SometimesSleepForever'),
+ timeout_in_ms=1000,
+ address=self._address,
+ request=[''] * 20)
+ self.assertEqual(response_tensors.shape, (20,))
+ self.assertEqual(status_code.shape, (20,))
+ self.assertEqual(status_message.shape, (20,))
+ status_code_values = sess.run(status_code)
+      self.assertTrue(
+          all(x in (errors.OK, errors.DEADLINE_EXCEEDED)
+              for x in status_code_values))
+
+ def testTryRpcWithMultipleAddressesSingleRequest(self):
+ flatten = lambda x: list(itertools.chain.from_iterable(x))
+ with self.test_session() as sess:
+ addresses = flatten([[
+ self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
+ ] for _ in range(10)])
+ request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
+ response_tensors, status_code, _ = self.try_rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address=addresses,
+ request=request)
+ response_tensors_values, status_code_values = sess.run((response_tensors,
+ status_code))
+ self.assertAllEqual(
+ flatten([errors.OK, errors.UNAVAILABLE] for _ in range(10)),
+ status_code_values)
+ for i in range(10):
+ self.assertTrue(response_tensors_values[2 * i])
+ self.assertFalse(response_tensors_values[2 * i + 1])
+
+ def testTryRpcWithMultipleMethodsSingleRequest(self):
+ flatten = lambda x: list(itertools.chain.from_iterable(x))
+ with self.test_session() as sess:
+ methods = flatten(
+ [[self.get_method_name('IncrementTestShapes'), 'InvalidMethodName']
+ for _ in range(10)])
+ request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
+ response_tensors, status_code, _ = self.try_rpc(
+ method=methods, address=self._address, request=request)
+ response_tensors_values, status_code_values = sess.run((response_tensors,
+ status_code))
+ self.assertAllEqual(
+ flatten([errors.OK, errors.UNIMPLEMENTED] for _ in range(10)),
+ status_code_values)
+ for i in range(10):
+ self.assertTrue(response_tensors_values[2 * i])
+ self.assertFalse(response_tensors_values[2 * i + 1])
+
+ def testTryRpcWithMultipleAddressesAndRequests(self):
+ flatten = lambda x: list(itertools.chain.from_iterable(x))
+ with self.test_session() as sess:
+ addresses = flatten([[
+ self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
+ ] for _ in range(10)])
+ requests = [
+ test_example_pb2.TestCase(
+ shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
+ ]
+ response_tensors, status_code, _ = self.try_rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address=addresses,
+ request=requests)
+ response_tensors_values, status_code_values = sess.run((response_tensors,
+ status_code))
+ self.assertAllEqual(
+ flatten([errors.OK, errors.UNAVAILABLE] for _ in range(10)),
+ status_code_values)
+ for i in range(20):
+ if i % 2 == 1:
+ self.assertFalse(response_tensors_values[i])
+ else:
+ response_message = test_example_pb2.TestCase()
+ self.assertTrue(
+ response_message.ParseFromString(response_tensors_values[i]))
+ self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
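For context, a minimal sketch of driving these ops outside the test harness.
It assumes a TestCaseService server is already listening at `address`, and it
assumes get_method_name formats the standard gRPC method path
'/tensorflow.contrib.rpc.TestCaseService/<Method>' for the service defined in
test_example.proto:

    import tensorflow as tf

    from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2
    from tensorflow.contrib.rpc.python.ops import rpc_op

    address = 'localhost:50051'  # assumption: a test server runs here
    request = test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString()

    # try_rpc never raises; it returns per-request status tensors instead.
    response, status_code, status_message = rpc_op.try_rpc(
        method='/tensorflow.contrib.rpc.TestCaseService/IncrementTestShapes',
        address=address,
        request=request,
        protocol='grpc')

    with tf.Session() as sess:
      resp, code, msg = sess.run((response, status_code, status_message))
    reply = test_example_pb2.TestCase()
    reply.ParseFromString(resp)
    print(list(reply.shape), code, msg)  # expect [2, 3, 4], errors.OK, b''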
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py
new file mode 100644
index 0000000000..7cbd636cb1
--- /dev/null
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py
@@ -0,0 +1,101 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Test servicer for RpcOp tests."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+import time
+
+import grpc
+
+from tensorflow.contrib.rpc.python.kernel_tests import rpc_op_test_base
+from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2_grpc
+
+
+class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer):
+ """Test servicer for RpcOp tests."""
+
+ def IncrementTestShapes(self, request, context):
+ """Increment the entries in the shape attribute of request.
+
+ Args:
+ request: input TestCase.
+ context: the rpc context.
+
+ Returns:
+ output TestCase.
+ """
+ for i in range(len(request.shape)):
+ request.shape[i] += 1
+ return request
+
+ def AlwaysFailWithInvalidArgument(self, request, context):
+ """Always fails with an InvalidArgument status.
+
+ Args:
+ request: input TestCase.
+ context: the rpc context.
+
+    Returns:
+      None; the RPC fails with an InvalidArgument status instead.
+ """
+ del request
+ context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
+ context.set_details(rpc_op_test_base.I_WARNED_YOU)
+
+ def SometimesFailWithInvalidArgument(self, request, context):
+ """Sometimes fails with an InvalidArgument status.
+
+ Args:
+ request: input TestCase.
+ context: the rpc context.
+
+ Returns:
+ output TestCase.
+ """
+ if random.randint(0, 1) == 1:
+ context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
+ context.set_details(rpc_op_test_base.I_WARNED_YOU)
+ return request
+
+ def SleepForever(self, request, context):
+ """Sleeps forever.
+
+ Args:
+ request: input TestCase.
+ context: the rpc context.
+
+    Returns:
+      None; no response is produced before the client deadline.
+ """
+ # TODO(ebrevdo): Make this async wait like the stubby version.
+ time.sleep(5)
+
+ def SometimesSleepForever(self, request, context):
+ """Sometimes sleeps forever.
+
+ Args:
+ request: input TestCase.
+ context: the rpc context.
+
+ Returns:
+ output TestCase.
+ """
+ if random.randint(0, 1) == 1:
+ time.sleep(5)
+ return request
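A sketch of hosting this servicer for the tests, assuming the standard gRPC
codegen naming (add_TestCaseServiceServicer_to_server is what the protoc gRPC
plugin generates for the TestCaseService in test_example.proto):

    from concurrent import futures

    import grpc

    from tensorflow.contrib.rpc.python.kernel_tests import rpc_op_test_servicer
    from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2_grpc

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    test_example_pb2_grpc.add_TestCaseServiceServicer_to_server(
        rpc_op_test_servicer.RpcOpTestServicer(), server)
    port = server.add_insecure_port('localhost:0')  # ask the OS for a free port
    server.start()
    # ... point the RpcOp tests at 'localhost:%d' % port ...
    server.stop(grace=None)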
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto b/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto
new file mode 100644
index 0000000000..96f4550f62
--- /dev/null
+++ b/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto
@@ -0,0 +1,171 @@
+// Test description and protos to work with it.
+//
+// Many of the protos in this file are for unit tests that haven't been written yet.
+
+syntax = "proto2";
+
+import "tensorflow/core/framework/types.proto";
+
+package tensorflow.contrib.rpc;
+
+// A TestCase holds a proto and a bunch of assertions
+// about how it should decode.
+message TestCase {
+ // A batch of primitives to be serialized and decoded.
+ repeated RepeatedPrimitiveValue primitive = 1;
+ // The shape of the batch.
+ repeated int32 shape = 2;
+ // Expected sizes for each field.
+ repeated int32 sizes = 3;
+ // Expected values for each field.
+ repeated FieldSpec field = 4;
+};
+
+service TestCaseService {
+ // Copy input, and increment each entry in 'shape' by 1.
+ rpc IncrementTestShapes(TestCase) returns (TestCase) {
+ }
+
+ // Sleep forever.
+ rpc SleepForever(TestCase) returns (TestCase) {
+ }
+
+ // Sleep forever 50% of the time, return immediately the other 50%.
+ rpc SometimesSleepForever(TestCase) returns (TestCase) {
+ }
+
+ // Always fails with InvalidArgument.
+ rpc AlwaysFailWithInvalidArgument(TestCase) returns (TestCase) {
+ }
+
+ // Fails with InvalidArgument 50% of the time.
+ rpc SometimesFailWithInvalidArgument(TestCase) returns (TestCase) {
+ }
+};
+
+// FieldSpec describes the expected output for a single field.
+message FieldSpec {
+ optional string name = 1;
+ optional tensorflow.DataType dtype = 2;
+ optional RepeatedPrimitiveValue expected = 3;
+};
+
+message TestValue {
+ optional PrimitiveValue primitive_value = 1;
+ optional EnumValue enum_value = 2;
+ optional MessageValue message_value = 3;
+ optional RepeatedMessageValue repeated_message_value = 4;
+ optional RepeatedPrimitiveValue repeated_primitive_value = 6;
+}
+
+message PrimitiveValue {
+ optional double double_value = 1;
+ optional float float_value = 2;
+ optional int64 int64_value = 3;
+ optional uint64 uint64_value = 4;
+ optional int32 int32_value = 5;
+ optional fixed64 fixed64_value = 6;
+ optional fixed32 fixed32_value = 7;
+ optional bool bool_value = 8;
+ optional string string_value = 9;
+ optional bytes bytes_value = 12;
+ optional uint32 uint32_value = 13;
+ optional sfixed32 sfixed32_value = 15;
+ optional sfixed64 sfixed64_value = 16;
+ optional sint32 sint32_value = 17;
+ optional sint64 sint64_value = 18;
+}
+
+// NOTE: This definition must be kept in sync with PackedPrimitiveValue.
+message RepeatedPrimitiveValue {
+ repeated double double_value = 1;
+ repeated float float_value = 2;
+ repeated int64 int64_value = 3;
+ repeated uint64 uint64_value = 4;
+ repeated int32 int32_value = 5;
+ repeated fixed64 fixed64_value = 6;
+ repeated fixed32 fixed32_value = 7;
+ repeated bool bool_value = 8;
+ repeated string string_value = 9;
+ repeated bytes bytes_value = 12;
+ repeated uint32 uint32_value = 13;
+ repeated sfixed32 sfixed32_value = 15;
+ repeated sfixed64 sfixed64_value = 16;
+ repeated sint32 sint32_value = 17;
+ repeated sint64 sint64_value = 18;
+ repeated PrimitiveValue message_value = 19;
+}
+
+// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue
+// in the text format, but the binary serialization is different.
+// We test the packed representations by loading the same test cases
+// using this definition instead of RepeatedPrimitiveValue.
+// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue
+// in every way except the packed=true declaration.
+message PackedPrimitiveValue {
+ repeated double double_value = 1 [packed = true];
+ repeated float float_value = 2 [packed = true];
+ repeated int64 int64_value = 3 [packed = true];
+ repeated uint64 uint64_value = 4 [packed = true];
+ repeated int32 int32_value = 5 [packed = true];
+ repeated fixed64 fixed64_value = 6 [packed = true];
+ repeated fixed32 fixed32_value = 7 [packed = true];
+ repeated bool bool_value = 8 [packed = true];
+ repeated string string_value = 9;
+ repeated bytes bytes_value = 12;
+ repeated uint32 uint32_value = 13 [packed = true];
+ repeated sfixed32 sfixed32_value = 15 [packed = true];
+ repeated sfixed64 sfixed64_value = 16 [packed = true];
+ repeated sint32 sint32_value = 17 [packed = true];
+ repeated sint64 sint64_value = 18 [packed = true];
+ repeated PrimitiveValue message_value = 19;
+}
+
+message EnumValue {
+ enum Color {
+ RED = 0;
+ ORANGE = 1;
+ YELLOW = 2;
+ GREEN = 3;
+ BLUE = 4;
+ INDIGO = 5;
+ VIOLET = 6;
+ };
+ optional Color enum_value = 14;
+ repeated Color repeated_enum_value = 15;
+}
+
+
+message InnerMessageValue {
+ optional float float_value = 2;
+ repeated bytes bytes_values = 8;
+}
+
+message MiddleMessageValue {
+ repeated int32 int32_values = 5;
+ optional InnerMessageValue message_value = 11;
+ optional uint32 uint32_value = 13;
+}
+
+message MessageValue {
+ optional double double_value = 1;
+ optional MiddleMessageValue message_value = 11;
+}
+
+message RepeatedMessageValue {
+ message NestedMessageValue {
+ optional float float_value = 2;
+ repeated bytes bytes_values = 8;
+ }
+
+ repeated NestedMessageValue message_values = 11;
+}
+
+// Message containing fields with field numbers higher than any field above. An
+// instance of this message is prepended to each binary message in the test to
+// exercise the code path that handles fields encoded out of order of field
+// number.
+message ExtraFields {
+ optional string string_value = 1776;
+ optional bool bool_value = 1777;
+}
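The RepeatedPrimitiveValue/PackedPrimitiveValue pair above can be exercised
directly; a sketch, assuming the generated test_example_pb2 module. Since
proto2 parsers accept both wire forms for packed-capable fields, the same
values round-trip between the two messages while the encoded size differs:

    from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2

    repeated = test_example_pb2.RepeatedPrimitiveValue(int32_value=[1, 2, 3])
    packed = test_example_pb2.PackedPrimitiveValue(int32_value=[1, 2, 3])

    unpacked_bytes = repeated.SerializeToString()  # one tag byte per element
    packed_bytes = packed.SerializeToString()  # one tag + length-delimited blob

    # Parsers accept either wire form for a packed-capable field.
    roundtrip = test_example_pb2.PackedPrimitiveValue()
    roundtrip.ParseFromString(unpacked_bytes)
    assert list(roundtrip.int32_value) == [1, 2, 3]
    assert len(packed_bytes) < len(unpacked_bytes)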
diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD
index a62069a252..1a1591d798 100644
--- a/tensorflow/contrib/seq2seq/BUILD
+++ b/tensorflow/contrib/seq2seq/BUILD
@@ -3,9 +3,12 @@
licenses(["notice"]) # Apache 2.0
-exports_files(["LICENSE"])
+package(default_visibility = [
+ "//learning/brain/google/xla/tests:__subpackages__",
+ "//tensorflow:__subpackages__",
+])
-package(default_visibility = ["//tensorflow:__subpackages__"])
+exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@@ -38,6 +41,7 @@ tf_custom_op_py_library(
"//tensorflow/python:check_ops",
"//tensorflow/python:clip_ops",
"//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:control_flow_util",
"//tensorflow/python:embedding_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:functional_ops",
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
index ac830ae98e..b549cbf568 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
@@ -92,14 +92,18 @@ class DynamicDecodeRNNTest(test.TestCase):
# Mostly a smoke test
time_steps = max_out
+ expected_length = sequence_length
if maximum_iterations is not None:
time_steps = min(max_out, maximum_iterations)
+ expected_length = [min(x, maximum_iterations) for x in expected_length]
self.assertEqual(
_t((batch_size, time_steps, cell_depth)),
sess_results["final_outputs"].rnn_output.shape)
self.assertEqual(
_t((batch_size, time_steps)),
sess_results["final_outputs"].sample_id.shape)
+ self.assertItemsEqual(expected_length,
+ sess_results["final_sequence_length"])
def testDynamicDecodeRNNBatchMajor(self):
self._testDynamicDecodeRNN(time_major=False)
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index a0f57417b8..1c9d179e3c 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -655,7 +655,7 @@ def monotonic_attention(p_choose_i, previous_attention, mode):
shifted_1mp_choose_i = array_ops.concat(
[array_ops.ones((batch_size, 1)), 1 - p_choose_i[:, :-1]], 1)
# Compute attention distribution recursively as
- # q[i] = (1 - p_choose_i[i])*q[i - 1] + previous_attention[i]
+ # q[i] = (1 - p_choose_i[i - 1])*q[i - 1] + previous_attention[i]
# attention[i] = p_choose_i[i]*q[i]
attention = p_choose_i*array_ops.transpose(functional_ops.scan(
# Need to use reshape to remind TF of the shape between loop iterations
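A numpy sketch of the recurrence as corrected in the comment above, written
out for a single batch element (the function name is illustrative):

    import numpy as np

    def monotonic_attention_recursive(p_choose_i, previous_attention):
      q = np.zeros_like(p_choose_i)
      attention = np.zeros_like(p_choose_i)
      for i in range(len(p_choose_i)):
        # q[i] = (1 - p_choose_i[i - 1])*q[i - 1] + previous_attention[i],
        # where the (1 - p) factor is taken to be 1 at i = 0.
        one_minus_p = 1.0 - p_choose_i[i - 1] if i > 0 else 1.0
        prev_q = q[i - 1] if i > 0 else 0.0
        q[i] = one_minus_p * prev_q + previous_attention[i]
        attention[i] = p_choose_i[i] * q[i]
      return attention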
diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index 898493662d..e69725ff8a 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
@@ -181,6 +182,15 @@ def dynamic_decode(decoder,
raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
type(decoder))
+ def _is_xla_tensor(tensor):
+ try:
+ op = tensor.op
+ except AttributeError:
+ return False
+ if control_flow_util.IsInXLAContext(op):
+ return True
+ return False
+
with variable_scope.variable_scope(scope, "decoder") as varscope:
# Properly cache variable values inside the while_loop
if varscope.caching_device is None:
@@ -198,6 +208,11 @@ def dynamic_decode(decoder,
decoder.output_dtype,
decoder.batch_size)
+    is_xla = any(_is_xla_tensor(i) for i in nest.flatten(initial_inputs))
+ if is_xla and maximum_iterations is None:
+ raise ValueError("maximum_iterations is required for XLA compilation.")
if maximum_iterations is not None:
initial_finished = math_ops.logical_or(
initial_finished, 0 >= maximum_iterations)
@@ -215,11 +230,13 @@ def dynamic_decode(decoder,
batch_size, name="batch_size"))
return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)
+ dynamic_size = maximum_iterations is None or not is_xla
+
def _create_ta(s, d):
return tensor_array_ops.TensorArray(
dtype=d,
- size=0,
- dynamic_size=True,
+ size=0 if dynamic_size else maximum_iterations,
+ dynamic_size=dynamic_size,
element_shape=_shape(decoder.batch_size, s))
initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
@@ -251,11 +268,8 @@ def dynamic_decode(decoder,
next_finished = decoder_finished
else:
next_finished = math_ops.logical_or(decoder_finished, finished)
- if maximum_iterations is not None:
- next_finished = math_ops.logical_or(
- next_finished, time + 1 >= maximum_iterations)
next_sequence_lengths = array_ops.where(
- math_ops.logical_and(math_ops.logical_not(finished), next_finished),
+ math_ops.logical_not(finished),
array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
sequence_lengths)
@@ -296,11 +310,16 @@ def dynamic_decode(decoder,
res = control_flow_ops.while_loop(
condition,
body,
- loop_vars=[
- initial_time, initial_outputs_ta, initial_state, initial_inputs,
- initial_finished, initial_sequence_lengths,
- ],
+ loop_vars=(
+ initial_time,
+ initial_outputs_ta,
+ initial_state,
+ initial_inputs,
+ initial_finished,
+ initial_sequence_lengths,
+ ),
parallel_iterations=parallel_iterations,
+ maximum_iterations=maximum_iterations,
swap_memory=swap_memory)
final_outputs_ta = res[1]
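A numpy sketch of the new sequence-length bookkeeping in the loop body above:
an entry's length keeps advancing to time + 1 at every step where it was not
already finished, so the recorded length ends up being the number of steps the
entry actually ran (names illustrative):

    import numpy as np

    def next_sequence_lengths(sequence_lengths, finished, time):
      # Mirrors: where(logical_not(finished), fill(time + 1), lengths).
      return np.where(np.logical_not(finished),
                      np.full_like(sequence_lengths, time + 1),
                      sequence_lengths)

    lengths = np.zeros(3, dtype=np.int64)
    finished = np.array([False, False, False])
    # Per-step "decoder_finished" flags for three entries over three steps.
    steps = [[False, True, False], [False, True, False], [True, True, True]]
    for time, decoder_finished in enumerate(steps):
      lengths = next_sequence_lengths(lengths, finished, time)
      finished = np.logical_or(finished, decoder_finished)
    print(lengths)  # [3 1 3]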
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 567b4af88d..b81ae9dc3e 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -444,8 +444,8 @@ class Converter {
* remove this and annotate the edge as a control dependency.
************************************************************************/
// skip control nodes
- if (input_name[0] == '^' ) continue;
- string name = input_name;
+ if (input_name[0] == '^') continue;
+ string name = input_name;
auto first = name.find_first_of(':');
if (first != string::npos && first + 2 == name.size() &&
name[first + 1] == '0')
@@ -2511,7 +2511,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
std::vector<string> input_names;
std::vector<tensorflow::DataType> input_dtypes;
for (const std::pair<int, int>& input : s.input_inds) {
- VLOG(2) << "parsing input. Node id= " << input.first ;
+ VLOG(2) << "parsing input. Node id= " << input.first;
int node_id = input.first;
int output_idx = input.second;
tensorflow::Node* node = s.graph.FindNodeId(node_id);
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index bb60f3e2d7..dc90668559 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -43,6 +43,7 @@
@@TPUEstimator
@@TPUEstimatorSpec
@@RunConfig
+@@InputPipelineConfig
@@TPUConfig
"""
diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
index b53f9be2e2..5e85a967ad 100644
--- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
@@ -128,6 +128,7 @@ Status WriteTensorboardTPUProfile(const string& logdir, const string& run,
// Dumps profile data to <logdir>/plugins/profile/<run>/.
string host_prefix = host.empty() ? "" : StrCat(host, ".");
string profile_run_dir = JoinPath(logdir, kProfilePluginDirectory, run);
+ *os << "Creating directory: " << profile_run_dir;
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(profile_run_dir));
// Ignore computation_graph for now.
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 3882377d3d..823893d02c 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -160,6 +160,7 @@ exports_files(["ops/ops.pbtxt"])
#
# Note that some protos are in neither additional_core_proto_srcs nor this
# filegroup; e.g. ones with individual proto_library targets.
+# LINT.IfChange
CORE_PROTO_SRCS = [
"example/example.proto",
"example/feature.proto",
@@ -201,6 +202,7 @@ CORE_PROTO_SRCS = [
"util/memmapped_file_system.proto",
"util/saved_tensor_slice.proto",
]
+# LINT.ThenChange(//tensorflow/core/android_proto_config.asciipb)
# Protos which are not needed on mobile builds, but should be included in
# protos_all.
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestGainsPerFeature.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestGainsPerFeature.pbtxt
index 62876a293c..3f181e91ce 100644
--- a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestGainsPerFeature.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestGainsPerFeature.pbtxt
@@ -13,6 +13,30 @@ END
A list of Rank 3 tensor (#shape=[max_splits, bucket, 2]) for accumulated stats summary (gradient/hessian) per node per buckets for each feature. The first dimension of the tensor is the maximum number of splits, and thus not all elements of it will be used, but only the indexes specified by node_ids will be used.
END
}
+ in_arg {
+ name: "l1"
+ description: <<END
+l1 regularization factor on leaf weights, per instance based.
+END
+ }
+ in_arg {
+ name: "l2"
+ description: <<END
+l2 regularization factor on leaf weights, per instance based.
+END
+ }
+ in_arg {
+ name: "tree_complexity"
+ description: <<END
+adjustment to the gain, per leaf based.
+END
+ }
+ in_arg {
+ name: "min_node_weight"
+ description: <<END
+minimum average of hessians in a node required for the node to be considered for splitting.
+END
+ }
out_arg {
name: "node_ids_list"
description: <<END
@@ -44,24 +68,6 @@ A list of Rank 2 tensors, with the same shape/conditions as left_node_contribs_l
END
}
attr {
- name: "l1"
- description: <<END
-l1 regularization factor on leaf weights, per instance based.
-END
- }
- attr {
- name: "l2"
- description: <<END
-l2 regularization factor on leaf weights, per instance based.
-END
- }
- attr {
- name: "tree_complexity"
- description: <<END
-adjustment to the gain, per leaf based.
-END
- }
- attr {
name: "max_splits"
description: <<END
the number of nodes that can be split in the whole tree. Used as a dimension of output tensors.
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesPredict.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesPredict.pbtxt
index b23e77a1fa..60ad9b4640 100644
--- a/tensorflow/core/api_def/base_api/api_def_BoostedTreesPredict.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesPredict.pbtxt
@@ -27,12 +27,6 @@ scalar, dimension of the logits, to be used for partial logits
shape.
END
}
- attr {
- name: "max_depth"
- description: <<END
-scalar, max depth of trees. To be used for parallelization costs.
-END
- }
summary: "Runs multiple additive regression ensemble predictors on input instances and"
description: <<END
computes the logits. It is designed to be used during prediction.
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesTrainingPredict.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesTrainingPredict.pbtxt
index 7203d3cb58..f8a3639c9b 100644
--- a/tensorflow/core/api_def/base_api/api_def_BoostedTreesTrainingPredict.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesTrainingPredict.pbtxt
@@ -54,12 +54,6 @@ scalar, dimension of the logits, to be used for partial logits
shape.
END
}
- attr {
- name: "max_depth"
- description: <<END
-scalar, max depth of trees. To be used for parallelization costs.
-END
- }
summary: "Runs multiple additive regression ensemble predictors on input instances and"
description: <<END
computes the update to cached logits. It is designed to be used during training.
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesUpdateEnsemble.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesUpdateEnsemble.pbtxt
index 00f8953875..3cf486d087 100644
--- a/tensorflow/core/api_def/base_api/api_def_BoostedTreesUpdateEnsemble.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesUpdateEnsemble.pbtxt
@@ -51,13 +51,13 @@ of the feature's splits. Will be added to the previous node values to constitute
the values of the right nodes.
END
}
- attr {
+ in_arg {
name: "max_depth"
description: <<END
Max depth of the tree to build.
END
}
- attr {
+ in_arg {
name: "learning_rate"
description: <<END
shrinkage const for each new tree.
diff --git a/tensorflow/core/common_runtime/process_util.cc b/tensorflow/core/common_runtime/process_util.cc
index 7ff360ee26..22fd940d82 100644
--- a/tensorflow/core/common_runtime/process_util.cc
+++ b/tensorflow/core/common_runtime/process_util.cc
@@ -54,12 +54,12 @@ int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
if (inter_op != 0) return inter_op;
#ifdef INTEL_MKL
// MKL library executes ops in parallel using OMP threads
- // Set inter_op conservatively to avoid thread oversubscription that could
+ // Set inter_op conservatively to avoid thread oversubscription that could
// lead to severe perf degradations and OMP resource exhaustion
const int mkl_intra_op = omp_get_max_threads();
CHECK_GE(mkl_intra_op, 1);
const int32 mkl_inter_op = std::max(
- (port::NumSchedulableCPUs() + mkl_intra_op - 1) / mkl_intra_op, 2);
+ (port::NumSchedulableCPUs() + mkl_intra_op - 1) / mkl_intra_op, 2);
VLOG(0) << "Creating new thread pool with default inter op setting: "
<< mkl_inter_op
<< ". Tune using inter_op_parallelism_threads for best performance.";
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
index 62b299d5c2..0abac4f3c7 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
@@ -35,7 +35,7 @@ class GrpcByteSource : public TensorResponse::Source {
explicit GrpcByteSource(::grpc::ByteBuffer* buffer) : buffer_(buffer) {}
~GrpcByteSource() override { DeleteStream(); }
- typedef ::grpc::GrpcProtoBufferReader Reader;
+ typedef ::grpc::ProtoBufferReader Reader;
protobuf::io::ZeroCopyInputStream* contents() override {
DeleteStream();
diff --git a/tensorflow/core/framework/graph_transfer_info.proto b/tensorflow/core/framework/graph_transfer_info.proto
index 016259ddbf..41dd54d78c 100644
--- a/tensorflow/core/framework/graph_transfer_info.proto
+++ b/tensorflow/core/framework/graph_transfer_info.proto
@@ -8,6 +8,46 @@ option java_package = "org.tensorflow.framework";
import "tensorflow/core/framework/types.proto";
+message GraphTransferNodeInput {
+ int32 node_id = 1;
+ int32 output_port = 2;
+}
+message GraphTransferNodeInfo {
+ string name = 1;
+ int32 node_id = 2;
+ string type_name = 3;
+ int32 soc_op_id = 4;
+ int32 padding_id = 5;
+ int32 input_count = 6;
+ int32 output_count = 7;
+};
+message GraphTransferConstNodeInfo {
+ string name = 1;
+ int32 node_id = 2;
+ repeated int64 shape = 3;
+ bytes data = 4;
+ DataType dtype = 5;
+};
+message GraphTransferNodeInputInfo {
+ int32 node_id = 1;
+ repeated GraphTransferNodeInput node_input = 2;
+};
+message GraphTransferNodeOutputInfo {
+ int32 node_id = 1;
+ repeated int32 max_byte_size = 2;
+};
+message GraphTransferGraphInputNodeInfo {
+ string name = 1;
+ repeated int64 shape = 2;
+ DataType dtype = 3;
+}
+
+message GraphTransferGraphOutputNodeInfo {
+ string name = 1;
+ repeated int64 shape = 2;
+ DataType dtype = 3;
+}
+
// Protocol buffer representing a handle to a tensorflow resource. Handles are
// not valid across executions, but can be serialized back and forth from within
// a single run.
@@ -16,53 +56,14 @@ message GraphTransferInfo {
NOP = 0;
HEXAGON = 1;
}
- message NodeInput {
- int32 node_id = 1;
- int32 output_port = 2;
- }
- message NodeInfo {
- string name = 1;
- int32 node_id = 2;
- string type_name = 3;
- int32 soc_op_id = 4;
- int32 padding_id = 5;
- int32 input_count = 6;
- int32 output_count = 7;
- };
- message ConstNodeInfo {
- string name = 1;
- int32 node_id = 2;
- repeated int64 shape = 3;
- bytes data = 4;
- DataType dtype = 5;
- };
- message NodeInputInfo {
- int32 node_id = 1;
- repeated NodeInput node_input = 2;
- };
- message NodeOutputInfo {
- int32 node_id = 1;
- repeated int32 max_byte_size = 2;
- };
- message GraphInputNodeInfo {
- string name = 1;
- repeated int64 shape = 2;
- DataType dtype = 3;
- }
-
- message GraphOutputNodeInfo {
- string name = 1;
- repeated int64 shape = 2;
- DataType dtype = 3;
- }
- repeated NodeInfo node_info = 1;
- repeated ConstNodeInfo const_node_info = 2;
- repeated NodeInputInfo node_input_info = 3;
- repeated NodeOutputInfo node_output_info = 4;
+ repeated GraphTransferNodeInfo node_info = 1;
+ repeated GraphTransferConstNodeInfo const_node_info = 2;
+ repeated GraphTransferNodeInputInfo node_input_info = 3;
+ repeated GraphTransferNodeOutputInfo node_output_info = 4;
// Input Node parameters of transferred graph
- repeated GraphInputNodeInfo graph_input_node_info = 5;
- repeated GraphOutputNodeInfo graph_output_node_info = 6;
+ repeated GraphTransferGraphInputNodeInfo graph_input_node_info = 5;
+ repeated GraphTransferGraphOutputNodeInfo graph_output_node_info = 6;
// Destination of graph transfer
Destination destination = 7;
};
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 05171006b0..ca91d68f79 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -1273,51 +1273,59 @@ const Eigen::SyclDevice& OpKernelContext::eigen_device() const {
}
#endif
+namespace {
+template <class OpKernelT>
+void CtxFailureInternal(OpKernelT* op_kernel, const char* file, int line,
+ const Status& s) {
+ const string logging_prefix =
+ file == nullptr ? "CtxFailure: "
+ : strings::StrCat("CtxFailure at ", io::Basename(file),
+ ":", line, ": ");
+
+ if (errors::IsOutOfRange(s)) {
+ // VLOG OutOfRange errors. Dataset ops create OutOfRange errors when they
+ // reach end-of-sequence.
+ VLOG(1) << logging_prefix << s;
+ } else {
+ LOG(WARNING) << logging_prefix << s;
+ }
+ op_kernel->SetStatus(s);
+}
+} // anonymous namespace
+
void OpKernelConstruction::CtxFailure(const Status& s) {
- VLOG(1) << s;
- SetStatus(s);
+ CtxFailureInternal(this, nullptr, 0, s);
}
void OpKernelConstruction::CtxFailureWithWarning(const Status& s) {
- LOG(WARNING) << s;
- SetStatus(s);
+ CtxFailureInternal(this, nullptr, 0, s);
}
void OpKernelConstruction::CtxFailure(const char* file, int line,
const Status& s) {
- VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
- << " : " << s;
- SetStatus(s);
+ CtxFailureInternal(this, file, line, s);
}
void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line,
const Status& s) {
- LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
- << " : " << s;
- SetStatus(s);
+ CtxFailureInternal(this, file, line, s);
}
void OpKernelContext::CtxFailure(const Status& s) {
- VLOG(1) << s;
- SetStatus(s);
+ CtxFailureInternal(this, nullptr, 0, s);
}
void OpKernelContext::CtxFailureWithWarning(const Status& s) {
- LOG(WARNING) << s;
- SetStatus(s);
+ CtxFailureInternal(this, nullptr, 0, s);
}
void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) {
- VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
- << " : " << s;
- SetStatus(s);
+ CtxFailureInternal(this, file, line, s);
}
void OpKernelContext::CtxFailureWithWarning(const char* file, int line,
const Status& s) {
- LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
- << " : " << s;
- SetStatus(s);
+ CtxFailureInternal(this, file, line, s);
}
} // namespace tensorflow
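A Python sketch of the logging policy the shared CtxFailureInternal helper now
applies to all eight variants above (names illustrative): OutOfRange is logged
at verbose level, since dataset ops use it to signal end-of-sequence, and
every other status is logged as a warning.

    import logging
    import os.path

    def ctx_failure(kernel_set_status, status, is_out_of_range,
                    file=None, line=0):
      prefix = ('CtxFailure: ' if file is None else
                'CtxFailure at %s:%d: ' % (os.path.basename(file), line))
      if is_out_of_range:
        logging.debug('%s%s', prefix, status)  # expected end-of-sequence
      else:
        logging.warning('%s%s', prefix, status)
      kernel_set_status(status)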
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index 33949319d5..ddbf7f3697 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -41,6 +41,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":utils",
+ "//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 9fa2b7a259..a9c777e551 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
@@ -355,6 +356,8 @@ void VerboseLogUnknownDimensionSources(
// information is refined.
class TopoQueue {
public:
+ explicit TopoQueue(const std::unordered_map<const Node*, int>& topo_order)
+ : queue_(CompareNodes(topo_order)) {}
void push(const Node* n) { queue_.insert(n); }
const Node* pop() {
CHECK(!empty());
@@ -371,9 +374,15 @@ class TopoQueue {
// Graph nodes are created in (roughly) topological order. Therefore we can
// use their id to ensure they're sorted topologically.
struct CompareNodes {
+ explicit CompareNodes(
+ const std::unordered_map<const Node*, int>& topo_ordering)
+ : topo_order(topo_ordering) {}
bool operator()(const Node* lhs, const Node* rhs) const {
- return lhs->id() < rhs->id();
+ return topo_order.at(lhs) < topo_order.at(rhs);
}
+
+ private:
+ const std::unordered_map<const Node*, int>& topo_order;
};
std::set<const Node*, CompareNodes> queue_;
};
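A Python sketch of the data structure above: a de-duplicating work queue that
always pops the node with the smallest precomputed topological rank (ranks are
assumed unique, which a topological ordering provides):

    import heapq

    class TopoQueue(object):

      def __init__(self, topo_order):
        self._order = topo_order  # node -> unique topological rank
        self._heap = []
        self._members = set()

      def push(self, node):
        if node not in self._members:
          self._members.add(node)
          heapq.heappush(self._heap, (self._order[node], node))

      def pop(self):
        _, node = heapq.heappop(self._heap)
        self._members.remove(node)
        return node

      def empty(self):
        return not self._heap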
@@ -689,9 +698,36 @@ Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
// nodes to propagate any known shape from the Merge node.
Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
const Node* node, bool relax,
- TopoQueue* new_shapes) {
+ bool* new_shapes) const {
InferenceContext* c = shape_refiner->GetContext(node);
- CHECK_NE(c, nullptr);
+ if (!c) {
+    // The shape refiner can't handle loops. Therefore we first need to remove
+    // all the incoming non-control edges.
+ std::vector<Edge> edges;
+ std::vector<const Edge*> edge_ptrs;
+ for (const Edge* edge : node->in_edges()) {
+ if (!edge->IsControlEdge()) {
+ edges.push_back(*edge);
+ edge_ptrs.push_back(edge);
+ }
+ }
+ for (const Edge* edge : edge_ptrs) {
+ if (!edge->IsControlEdge()) {
+ graph_->RemoveEdge(edge);
+ }
+ }
+ // Now we can run shape inference
+ TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(node, relax, new_shapes));
+ // And add all the edges back
+ for (const Edge& edge : edges) {
+ graph_->AddEdge(edge.src(), edge.src_output(), edge.dst(),
+ edge.dst_input());
+ }
+
+ c = shape_refiner->GetContext(node);
+ *new_shapes = true;
+ CHECK_NE(c, nullptr);
+ }
ShapeHandle out1;
TF_RETURN_IF_ERROR(c->WithRank(c->output(1), 0, &out1));
@@ -711,6 +747,11 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
}
InferenceContext* in = shape_refiner->GetContext(e->src());
+ if (!relax && !in) {
+ // Handling a loop for the first time, the back edge won't have any shape
+ // info.
+ continue;
+ }
ShapeHandle input = in->output(e->src_output());
if (relax) {
c->RelaxInput(e->dst_input(), input);
@@ -731,7 +772,7 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
if (!shape_refiner->EquivalentShapes(out, c->output(0))) {
c->set_output(0, out);
- new_shapes->push(node);
+ *new_shapes = true;
}
return Status::OK();
@@ -740,7 +781,7 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
Status GraphProperties::OverwriteFedPorts(
SymbolicShapeRefiner* shape_refiner,
const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
- const Node* node, TopoQueue* new_shapes) const {
+ const Node* node, bool* new_shapes) const {
auto it = fed_ports.find(node->name());
Status status;
if (it != fed_ports.end()) {
@@ -749,7 +790,7 @@ Status GraphProperties::OverwriteFedPorts(
for (const int output_port : it->second) {
status.Update(shape_refiner->SetUnknownShape(node, output_port));
}
- new_shapes->push(node);
+ *new_shapes = true;
}
return status;
}
@@ -758,9 +799,12 @@ Status GraphProperties::OverwriteFedPorts(
// outputs.
Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
const Node* node, bool relax,
- TopoQueue* new_shapes) {
+ bool* new_shapes) {
auto enter_ctx = shape_refiner->GetContext(node);
- CHECK_NE(enter_ctx, nullptr);
+ if (!enter_ctx) {
+ TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(node, relax, new_shapes));
+ enter_ctx = shape_refiner->GetContext(node);
+ }
for (const Edge* e : node->in_edges()) {
if (e->IsControlEdge()) {
@@ -775,7 +819,7 @@ Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
enter_ctx->MergeInput(0, input);
}
enter_ctx->set_output(0, input);
- new_shapes->push(node);
+ *new_shapes = true;
}
}
return Status::OK();
@@ -784,7 +828,7 @@ Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
Status GraphProperties::UpdateShapes(
SymbolicShapeRefiner* shape_refiner, bool relax,
const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
- const Node* n, TopoQueue* new_shapes) const {
+ const Node* n, bool* new_shapes) const {
if (n->IsEnter()) {
// The Enter shape function always forwards an UnknownShape, so do the right
// thing here.
@@ -800,7 +844,7 @@ Status GraphProperties::UpdateShapes(
// We want to avoid propagating through loops on the merge pass because
// the shapes are not guaranteed to converge.
if (relax || !n->IsNextIteration()) {
- new_shapes->push(n);
+ *new_shapes = true;
}
}
}
@@ -837,11 +881,15 @@ Status GraphProperties::PropagateShapes(
while (!new_shapes->empty() &&
num_loop_iterations++ < max_loop_iterations) {
const Node* n = new_shapes->pop();
- for (const Edge* e : n->out_edges()) {
- if (!e->IsControlEdge()) {
- const Node* fanout = e->dst();
- TF_RETURN_IF_ERROR(UpdateShapes(shape_refiner, relax, fed_ports,
- fanout, new_shapes));
+ bool updated = false;
+ TF_RETURN_IF_ERROR(
+ UpdateShapes(shape_refiner, relax, fed_ports, n, &updated));
+ if (updated) {
+ for (const Edge* e : n->out_edges()) {
+ if (!e->IsControlEdge()) {
+ const Node* fanout = e->dst();
+ new_shapes->push(fanout);
+ }
}
}
}
@@ -913,7 +961,12 @@ Status GraphProperties::UpdateResource(
queue_shapes_and_types)) {
qctx->set_output_handle_shapes_and_types(0, queue_shapes_and_types);
- new_shapes->push(qnode);
+ for (const Edge* e : qnode->out_edges()) {
+ if (!e->IsControlEdge()) {
+ const Node* fanout = e->dst();
+ new_shapes->push(fanout);
+ }
+ }
}
return Status::OK();
@@ -923,6 +976,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
FunctionLibraryDefinition function_library(OpRegistry::Global(),
item_.graph.library());
Graph graph(function_library);
+ graph_ = &graph;
ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
shape_refiner.set_require_shape_inference_fns(false);
shape_refiner.set_disable_constant_propagation(true);
@@ -932,6 +986,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
// the device placement of nodes has also completed, so there
// is no need to validate colocation constraints again.
options.validate_colocation_constraints = false;
+ options.validate_shape = false;
Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);
TF_RETURN_IF_ERROR(s);
@@ -944,14 +999,29 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
}
}
+ std::unordered_map<const NodeDef*, int> topo_order;
+ TF_RETURN_IF_ERROR(ComputeTopologicalOrder(item_.graph, &topo_order));
+
+ std::unordered_map<string, int> order_by_name;
+  for (const auto& topo : topo_order) {
+ order_by_name[topo.first->name()] = topo.second;
+ }
+
// List the resources and the nodes using them. Also collect the Enter and
// Merge nodes.
+ std::unordered_map<const Node*, int> graph_topo_order;
std::unordered_map<const Node*, std::unordered_set<const Node*>> resources;
- std::unordered_set<const Node*> enter_nodes;
std::unordered_set<const Node*> merge_nodes;
std::unordered_set<const Node*> fed_nodes;
+ std::unordered_set<const Node*> primary_inputs;
int num_loops = 0;
for (const Node* const node : graph.nodes()) {
+ auto it = order_by_name.find(node->name());
+ if (it == order_by_name.end()) {
+ continue;
+ }
+ graph_topo_order[node] = it->second;
+
for (int i = 0; i < node->num_inputs(); ++i) {
if (node->input_type(i) == DataType::DT_RESOURCE) {
const Node* resource;
@@ -959,8 +1029,8 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
resources[resource].insert(node);
}
}
- if (node->IsEnter()) {
- enter_nodes.insert(node);
+ if (node->num_inputs() == 0) {
+ primary_inputs.insert(node);
} else if (node->IsMerge()) {
merge_nodes.insert(node);
} else if (node->IsNextIteration()) {
@@ -979,22 +1049,20 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
// we exclusively relax shapes and propagate shapes through loops until
// reaching fixed point.
for (int relax = 0; relax < 2; relax++) {
- TopoQueue new_shapes;
- // Force the propagation of shapes of Enter nodes manually (the Enter shape
- // function always forwards an UnknownShape).
- for (const Node* node : enter_nodes) {
- TF_RETURN_IF_ERROR(
- UpdateShapes(&refiner, relax, fed_ports, node, &new_shapes));
- }
+ TopoQueue new_shapes(graph_topo_order);
// Seed the propagation of shapes through merge nodes.
- for (const Node* node : merge_nodes) {
- TF_RETURN_IF_ERROR(
- UpdateShapes(&refiner, relax, fed_ports, node, &new_shapes));
+ if (relax) {
+ for (const Node* node : merge_nodes) {
+ new_shapes.push(node);
+ }
+ }
+ // Also seed the propagation of shapes in the fanout of primary inputs.
+ for (const Node* node : primary_inputs) {
+ new_shapes.push(node);
}
// Also seed the propagation of shapes in the fanout of fed nodes.
for (const Node* node : fed_nodes) {
- TF_RETURN_IF_ERROR(
- OverwriteFedPorts(&refiner, fed_ports, node, &new_shapes));
+ new_shapes.push(node);
}
// Propagate shapes normally.
TF_RETURN_IF_ERROR(PropagateShapes(&refiner, relax, &new_shapes, resources,
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
index 8ff572fe4f..30351f58fd 100644
--- a/tensorflow/core/grappler/costs/graph_properties.h
+++ b/tensorflow/core/grappler/costs/graph_properties.h
@@ -24,6 +24,8 @@ limitations under the License.
#include "tensorflow/core/grappler/grappler_item.h"
namespace tensorflow {
+class Graph;
+
namespace grappler {
class SymbolicShapeRefiner;
@@ -95,24 +97,22 @@ class GraphProperties {
// Update the output shapes of a Merge node, and enqueue its fanout in
// new_shapes if needed.
- static Status UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
- const Node* node, bool relax,
- TopoQueue* new_shapes);
+ Status UpdateMergeNode(SymbolicShapeRefiner* shape_refiner, const Node* node,
+ bool relax, bool* new_shapes) const;
// Process the Enter node, and enqueue its fanout in new_shapes if needed.
static Status UpdateEnter(SymbolicShapeRefiner* shape_refiner,
- const Node* node, bool relax,
- TopoQueue* new_shapes);
+ const Node* node, bool relax, bool* new_shapes);
// Process a node that is used to feed the model.
Status OverwriteFedPorts(
SymbolicShapeRefiner* shape_refiner,
const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
- const Node* node, TopoQueue* new_shapes) const;
+ const Node* node, bool* new_shapes) const;
// Update the shapes for node 'n'. If output shapes for n have changed,
// enqueue its fanout in 'new_shapes'.
Status UpdateShapes(
SymbolicShapeRefiner* shape_refiner, bool relax,
const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
- const Node* n, TopoQueue* new_shapes) const;
+ const Node* n, bool* new_shapes) const;
// Propagate the shapes for the nodes enqueued in new_shapes and their
// transitive fanout until a fixed point is reached.
Status PropagateShapes(
@@ -127,6 +127,8 @@ class GraphProperties {
std::map<string, std::vector<OpInfo::TensorProperties>> input_properties_;
std::map<string, std::vector<OpInfo::TensorProperties>> output_properties_;
const std::vector<OpInfo::TensorProperties> missing_properties_;
+
+ Graph* graph_;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index d3d89b59af..3de697bd37 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -303,9 +303,9 @@ TEST_F(GraphPropertiesTest, Queues) {
root.WithOpName("Queue5"),
{DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
Output rnd2 =
- ops::RandomNormal(root.WithOpName("rnd"), {10}, DataType::DT_DOUBLE);
+ ops::RandomNormal(root.WithOpName("rnd2"), {10}, DataType::DT_DOUBLE);
Output rnd3 =
- ops::RandomNormal(root.WithOpName("rnd"), {1, 2, 3}, DataType::DT_FLOAT);
+ ops::RandomNormal(root.WithOpName("rnd3"), {1, 2, 3}, DataType::DT_FLOAT);
auto enqueue5 =
ops::QueueEnqueue(root.WithOpName("Enqueue5"), q5, {rnd, rnd2, rnd3});
auto dequeue5 = ops::QueueDequeue(
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 087190ad2a..b35873ce38 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -35,6 +35,7 @@ constexpr char kMatMul[] = "MatMul";
constexpr char kSparseMatMul[] = "SparseMatMul";
constexpr char kPlaceholder[] = "Placeholder";
constexpr char kIdentity[] = "Identity";
+constexpr char kIdentityN[] = "IdentityN";
constexpr char kRefIdentity[] = "RefIdentity";
constexpr char kNoOp[] = "NoOp";
constexpr char kReshape[] = "Reshape";
@@ -211,6 +212,7 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kIdentityN, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kRefIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kStopGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kPreventGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 96342fedc1..3070eb1799 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -112,6 +112,7 @@ tf_cc_test(
name = "constant_folding_test",
srcs = ["constant_folding_test.cc"],
shard_count = 5,
+ tags = ["noasan"],
deps = [
":constant_folding",
"//tensorflow/cc:cc_ops",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index b80ae5fa40..232132e1e8 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -260,7 +260,7 @@ NodeDef* GetTailOfValuePreservingChain(
is_value_preserving_non_branching);
}
-// Graph optimizer context extension specific to ArithmeticOptimizer
+// Graph optimizer context extension specific to ArithmeticOptimizer.
struct ArithmeticOptimizerContext {
explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
: nodes_to_simplify(nodes_to_simplify) {}
@@ -365,27 +365,37 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
// Check if input can become a part of current optimized nodes group.
virtual bool IsAbsorbableByOptimizedNodesGroup(
- const OptimizedNodesGroup& group, const string& input) const = 0;
+ const OptimizedNodesGroup& group, const NodeDef& node) const = 0;
Status AbsorbInputByOptimizedNodesGroup(const string& input,
OptimizedNodesGroup* group) const {
- NodeDef* node;
- TF_RETURN_IF_ERROR(GetInputNode(input, &node));
-
- if (IsAbsorbableByOptimizedNodesGroup(*group, input)) {
- for (int i = 0; i < node->input_size(); ++i) {
- const string& input_i = node->input(i);
- if (!IsControlInput(input)) {
- TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group));
+ std::deque<const string*> input_tensors;
+ input_tensors.push_front(&input);
+
+ while (!input_tensors.empty()) {
+ const string* input_tensor = input_tensors.front();
+ input_tensors.pop_front();
+
+ // Get a node for the input tensor.
+ NodeDef* input_node;
+ TF_RETURN_IF_ERROR(GetInputNode(*input_tensor, &input_node));
+
+ if (IsAbsorbableByOptimizedNodesGroup(*group, *input_node)) {
+ group->optimized_nodes.push_back(input_node);
+ for (int i = input_node->input_size() - 1; i >= 0; --i) {
+ const string& absorbed_node_input = input_node->input(i);
+ // TODO(ezhulenev): support control inputs
+ if (IsControlInput(absorbed_node_input)) continue;
+ input_tensors.push_front(&absorbed_node_input);
}
+ } else {
+ // If input node can't be absorbed, add it to OptimizedNodesGroup input.
+ OpInfo::TensorProperties properties;
+ TF_RETURN_IF_ERROR(GetTensorProperties(*input_tensor, &properties));
+ group->inputs.emplace_back(*input_tensor, properties.shape());
}
- group->optimized_nodes.push_back(node);
- } else {
- // If node can't be absorbed, add it to OptimizedNodesGroup input
- OpInfo::TensorProperties properties;
- TF_RETURN_IF_ERROR(GetTensorProperties(input, &properties));
- group->inputs.emplace_back(input, properties.shape());
}
+
return Status::OK();
}
@@ -401,9 +411,9 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
group->optimized_nodes.reserve(root_node->input_size());
for (int i = 0; i < root_node->input_size(); ++i) {
const string& input_i = root_node->input(i);
- if (!IsControlInput(input_i)) {
- TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group));
- }
+ // TODO(ezhulenev): add support for control inputs
+ if (IsControlInput(input_i)) continue;
+ TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group));
}
return Status::OK();
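A Python sketch of the control-flow change above: the recursive absorption is
replaced by an explicit deque, with a node's inputs pushed to the front in
reverse so they are still visited in their original left-to-right order
(names illustrative):

    from collections import deque

    def absorb(root_inputs, is_absorbable, inputs_of):
      optimized_nodes, group_inputs = [], []
      work = deque(root_inputs)
      while work:
        node = work.popleft()
        if is_absorbable(node):
          optimized_nodes.append(node)
          for inp in reversed(inputs_of(node)):
            work.appendleft(inp)
        else:
          group_inputs.append(node)
      return optimized_nodes, group_inputs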
@@ -455,6 +465,11 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
optimized_nodes_.insert(node->name());
}
+ void AddAllMembersToOptimizedNodes(const OptimizedNodesGroup& group) {
+ AddToOptimizedNodes(group.root_node);
+ for (const NodeDef* opt : group.optimized_nodes) AddToOptimizedNodes(opt);
+ }
+
bool IsOnTheSameDevice(const OptimizedNodesGroup& group,
const NodeDef& node) const {
return group.root_node->device() == node.device();
@@ -510,7 +525,7 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
// Check if a node can become a root of AddOpsGroup
bool IsSupported(const NodeDef* node) const override {
- if (!CanOptimize(node)) return false;
+ if (!CanOptimize(*node)) return false;
// shape must be symbolically defined and all inputs compatible with it
OpInfo::TensorProperties properties;
@@ -522,59 +537,69 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
protected:
// Check if a node can be absorbed by current OptimizedNodesGroup
bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
- const string& input) const override {
- NodeDef* node;
- Status node_status = GetInputNode(input, &node);
- if (!node_status.ok() || !CanOptimize(node)) return false;
+ const NodeDef& node) const override {
+ if (!CanOptimize(node)) return false;
- if (!IsOnTheSameDevice(group, *node)) {
+ if (!IsOnTheSameDevice(group, node)) {
return false;
}
// with a single output data consumer (presumably, if we reach this node from
// a previously absorbed node or the root node, it means that this node is not
// used as an input to any other op outside of the group)
- if (NumNonControlDataOutputs(*node, *ctx_.node_map) != 1) {
+ if (NumNonControlDataOutputs(node, *ctx_.node_map) != 1) {
return false;
}
// All input shapes must be broadcastable to the node shape
OpInfo::TensorProperties properties;
- Status has_properties = GetTensorProperties(input, &properties);
+ Status has_properties = GetTensorProperties(node.name(), &properties);
return has_properties.ok() &&
- HasAllInputsBroadcastableToShape(*node, properties);
+ HasAllInputsBroadcastableToShape(node, properties);
}
// Node requirements both for a root node and an absorbed node
- bool CanOptimize(const NodeDef* node) const {
+ bool CanOptimize(const NodeDef& node) const {
// TODO(ezhulenev): check if AccumulateNV2 can be supported too
- if (!IsAdd(*node) && !IsAddN(*node)) {
- return false;
- }
- if (IsInPreserveSet(*node) || IsAlreadyOptimized(*node)) {
+ if (!IsAdd(node) && !IsAddN(node)) {
return false;
}
- // it must not be created by this stage at any of previous optimization runs
- if (str_util::StrContains(node->name(), stage_name_)) {
+ if (IsInPreserveSet(node) || IsAlreadyOptimized(node)) {
return false;
}
// TODO(ezhulenev): relax this condition for root node
- return !(IsDrivenByControlDependency(*node) ||
- DrivesControlDependency(*node));
+ return !(IsDrivenByControlDependency(node) ||
+ DrivesControlDependency(node));
}
// Rewrite a group of add ops into a single AddN if all input shapes are
// symbolically equal. If not, create AddN for equal shapes first, and then
// build an Add tree, minimizing the cost of broadcasts.
string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
- // all new nodes will be placed under the scope of a root node
+ VLOG(2) << "Collapse Add/AddN: root=" << group.root_node->name()
+ << " op=" << group.root_node->op()
+ << " num_optimized_nodes=" << group.optimized_nodes.size()
+ << " num_inputs=" << group.inputs.size();
+
+ // Do not optimize any of the nodes that are part of this group.
+ AddAllMembersToOptimizedNodes(group);
+
+ // All new nodes will be placed under the scope of a root node.
auto root_scope_and_name = ParseNodeScopeAndName(group.root_node->name());
- // Find what shapes are present in the inputs of absorbed nodes
+ // Find what shapes are present in the inputs of absorbed nodes.
std::unordered_map<string, std::vector<InputAndShape>> shape_sig_to_inputs;
for (const auto& input : group.inputs) {
shape_sig_to_inputs[ShapeSignature(input.shape)].push_back(input);
}
- // Collect all the shapes from representative elements
+ using SigKV = decltype(shape_sig_to_inputs)::value_type;
+ VLOG(3) << "Add/AddN group has " << shape_sig_to_inputs.size()
+ << " unique shapes: "
+ << str_util::Join(shape_sig_to_inputs, ", ",
+ [](string* out, SigKV p) {
+ strings::StrAppend(out, p.first);
+ });
+
+ // Collect all the shapes from representative elements.
std::vector<TensorShapeProto> shapes;
shapes.reserve(shape_sig_to_inputs.size());
for (const auto& el : shape_sig_to_inputs)
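
The grouping above buckets the absorbed inputs by a shape signature so that equal-shaped inputs can later be collapsed into a single AddN. A small self-contained sketch of the bucketing and the joined VLOG summary, with InputAndShape and ShapeSignature simplified to string-based stand-ins:

    #include <cstdio>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // Illustrative stand-ins for InputAndShape and ShapeSignature.
    struct InputAndShape { std::string name; std::string shape; };
    std::string ShapeSignature(const std::string& shape) { return shape; }

    int main() {
      std::vector<InputAndShape> inputs = {
          {"a", "[2,2]"}, {"b", "[2,2]"}, {"c", "[2,32]"}};
      std::unordered_map<std::string, std::vector<InputAndShape>> sig_to_inputs;
      for (const auto& input : inputs)
        sig_to_inputs[ShapeSignature(input.shape)].push_back(input);
      // Join the signatures the same way the VLOG(3) above does.
      std::string joined;
      for (const auto& kv : sig_to_inputs) {
        if (!joined.empty()) joined += ", ";
        joined += kv.first;
      }
      std::printf("%zu unique shapes: %s\n", sig_to_inputs.size(),
                  joined.c_str());
      return 0;  // prints e.g. "2 unique shapes: [2,32], [2,2]" (order varies)
    }
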
@@ -936,6 +961,7 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
bool IsSupported(const NodeDef* node) const override {
if (!IsBinaryAssociative(*node)) return false;
+ if (IsAlreadyOptimized(*node)) return false;
// has a symbolically defined shape with broadcastable inputs
OpInfo::TensorProperties properties;
@@ -955,33 +981,29 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
// Check if a node can be absorbed by current OptimizedNodesGroup
bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
- const string& input) const override {
- NodeDef* node;
- Status node_status = GetInputNode(input, &node);
- if (!node_status.ok()) return false;
-
- if (!IsSameOp(group, *node)) {
+ const NodeDef& node) const override {
+ if (!IsSameOp(group, node)) {
return false;
}
- if (IsInPreserveSet(*node) || IsAlreadyOptimized(*node)) {
+ if (IsInPreserveSet(node) || IsAlreadyOptimized(node)) {
return false;
}
- if (IsDrivenByControlDependency(*node) || DrivesControlDependency(*node)) {
+ if (IsDrivenByControlDependency(node) || DrivesControlDependency(node)) {
return false;
}
- if (!IsOnTheSameDevice(group, *node)) {
+ if (!IsOnTheSameDevice(group, node)) {
return false;
}
// Optimized nodes are updated in place, which would break the graph if the
// node had multiple output consumers
- if (NumNonControlOutputs(*node, *ctx_.node_map) != 1) {
+ if (NumNonControlOutputs(node, *ctx_.node_map) != 1) {
return false;
}
// All input shapes must be broadcastable to the node shape
OpInfo::TensorProperties properties;
- Status has_properties = GetTensorProperties(input, &properties);
+ Status has_properties = GetTensorProperties(node.name(), &properties);
return has_properties.ok() &&
- HasAllInputsBroadcastableToShape(*node, properties);
+ HasAllInputsBroadcastableToShape(node, properties);
}
std::size_t CountUniqueShapes(const std::vector<InputAndShape>& inputs) {
@@ -993,7 +1015,15 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
}
string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
+ VLOG(2) << "Minimize broadcast: root=" << group.root_node->name()
+ << " op=" << group.root_node->op()
+ << " num_optimized_nodes=" << group.optimized_nodes.size();
+
+ // Do not optimize any of the nodes that are part of this group.
+ AddAllMembersToOptimizedNodes(group);
+
if (CountUniqueShapes(group.inputs) <= 1) {
+ VLOG(3) << "Skip min-bcast group with single unique shape";
// nothing to optimize when all shapes are the same
return group.root_node->name();
}
@@ -1033,8 +1063,8 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
NodeDef* node;
if (!optimized_nodes.empty()) {
// re-purpose optimized nodes to build a new tree
- node = optimized_nodes.front();
- optimized_nodes.pop_front();
+ node = optimized_nodes.back();
+ optimized_nodes.pop_back();
} else {
// or use the root node if no optimized nodes are left
node = group.root_node;
@@ -1101,9 +1131,6 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
AddToOptimizationQueue(node);
}
- // Do not add updated node to any other group
- AddToOptimizedNodes(node);
-
TensorShapeProto shape; // shape is not important at this point
return InputAndShape(node->name(), shape);
}
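
The switch from pop_front() to pop_back() takes re-purposed nodes from the most recently absorbed end of the list, and the explicit AddToOptimizedNodes call below is dropped because AddAllMembersToOptimizedNodes now marks the whole group up front. A sketch of the reuse-or-fall-back-to-root policy, with NodeDef reduced to a plain stand-in struct:

    #include <deque>
    #include <string>

    struct NodeDef { std::string name; };  // stand-in, not tensorflow::NodeDef

    // Take the next node to rewrite from the back of the re-purposed list,
    // falling back to the root once the list is exhausted.
    NodeDef* NextUpdateCandidate(std::deque<NodeDef*>* optimized_nodes,
                                 NodeDef* root_node) {
      if (!optimized_nodes->empty()) {
        NodeDef* node = optimized_nodes->back();  // front() before this change
        optimized_nodes->pop_back();
        return node;
      }
      return root_node;  // no optimized nodes left: rewrite the root in place
    }
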
@@ -1969,8 +1996,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
if (options_.remove_negation)
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
- VLOG(1) << "Simplify arithmetic ops using " << pipeline.NumStages()
- << " arithmetic optimization stages";
+ VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
+ << str_util::Join(pipeline.StageNames(), ", ");
while (!nodes_to_simplify.Empty()) {
NodeDef* node = nodes_to_simplify.PopBack();
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index e639812858..cb1f2ea732 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -105,6 +105,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
options.remove_identity_transpose = false;
options.remove_redundant_bitcast = false;
options.remove_redundant_cast = false;
+ options.remove_negation = false;
optimizer->options_ = options;
}
@@ -2069,20 +2070,20 @@ TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
// a b c D a b
NodeMap node_map(&output);
- const NodeDef* mul1_node = node_map.GetNode("mul1");
+ const NodeDef* mul1_node = node_map.GetNode("mul2");
ASSERT_NE(mul1_node, nullptr);
EXPECT_EQ("a", mul1_node->input(0));
EXPECT_EQ("b", mul1_node->input(1));
- const NodeDef* mul2_node = node_map.GetNode("mul2");
+ const NodeDef* mul2_node = node_map.GetNode("mul1");
ASSERT_NE(mul2_node, nullptr);
- EXPECT_EQ("mul1", mul2_node->input(0));
+ EXPECT_EQ("mul2", mul2_node->input(0));
EXPECT_EQ("c", mul2_node->input(1));
const NodeDef* mul3_node = node_map.GetNode("mul3");
ASSERT_NE(mul3_node, nullptr);
EXPECT_EQ("D", mul3_node->input(0));
- EXPECT_EQ("mul2", mul3_node->input(1));
+ EXPECT_EQ("mul1", mul3_node->input(1));
}
} // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
index 072f772946..ed398525f3 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
@@ -239,6 +239,14 @@ class GraphOptimizerStagePipeline {
std::size_t NumStages() { return stages_.size(); }
+ std::vector<string> StageNames() {
+ std::vector<string> names;
+ for (const auto& stage : stages_) {
+ names.push_back(stage->stage_name());
+ }
+ return names;
+ }
+
private:
std::vector<std::unique_ptr<GraphOptimizerStage<Result>>> stages_;
std::function<bool(const Result&)> break_predicate_;
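
StageNames() exists so the VLOG(1) in SimplifyArithmeticOps can print the configured pipeline. A minimal stand-in for the str_util::Join call it feeds:

    #include <string>
    #include <vector>

    // Minimal stand-in for the str_util::Join call fed by StageNames().
    std::string Join(const std::vector<std::string>& parts,
                     const std::string& sep) {
      std::string out;
      for (size_t i = 0; i < parts.size(); ++i) {
        if (i > 0) out += sep;
        out += parts[i];
      }
      return out;
    }
    // Join({"AddOpsRewrite", "MinimizeBroadcasts"}, ", ")
    //   == "AddOpsRewrite, MinimizeBroadcasts"
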
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 24131cb51e..47cb344091 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3556,6 +3556,7 @@ tf_kernel_library(
"pooling_ops_3d_gpu.cu.cc",
],
deps = [
+ ":bounds_check",
":conv_2d",
":conv_3d",
":conv_ops",
@@ -3566,6 +3567,7 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
+ "//tensorflow/core:stream_executor",
"//third_party/eigen3",
],
)
diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
index b13a450546..1b5ce32b7b 100644
--- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
@@ -50,7 +50,6 @@ class BoostedTreesTrainingPredictOp : public OpKernel {
OP_REQUIRES(context, logits_dimension_ == 1,
errors::InvalidArgument(
"Currently only one dimensional outputs are supported."));
- OP_REQUIRES_OK(context, context->GetAttr("max_depth", &max_depth_));
}
void Compute(OpKernelContext* const context) override {
@@ -155,9 +154,10 @@ class BoostedTreesTrainingPredictOp : public OpKernel {
output_partial_logits(i, 0) = partial_all_logit;
}
};
- // Assume we will not go over more than one full tree. 4 is a magic
- // number.
- const int64 cost = 4 * max_depth_;
+ // 30 is a magic number. The actual cost might be a function of (the
+ // number of layers) * (CPU cycles spent on each layer), but this value
+ // works for many cases. May be tuned later.
+ const int64 cost = 30;
thread::ThreadPool* const worker_threads =
context->device()->tensorflow_cpu_worker_threads()->workers;
Shard(worker_threads->NumThreads(), worker_threads, batch_size,
@@ -168,7 +168,6 @@ class BoostedTreesTrainingPredictOp : public OpKernel {
private:
int32 logits_dimension_; // the size of the output prediction vector.
int32 num_bucketized_features_; // Indicates the number of features.
- int32 max_depth_;
};
REGISTER_KERNEL_BUILDER(Name("BoostedTreesTrainingPredict").Device(DEVICE_CPU),
@@ -186,7 +185,6 @@ class BoostedTreesPredictOp : public OpKernel {
OP_REQUIRES(context, logits_dimension_ == 1,
errors::InvalidArgument(
"Currently only one dimensional outputs are supported."));
- OP_REQUIRES_OK(context, context->GetAttr("max_depth", &max_depth_));
}
void Compute(OpKernelContext* const context) override {
@@ -243,7 +241,10 @@ class BoostedTreesPredictOp : public OpKernel {
output_logits(i, 0) = tree_logit;
}
};
- const int64 cost = (latest_tree + 1) * max_depth_;
+ // 10 is a magic number. The actual cost might depend on (the number of
+ // layers in the trees) and (CPU cycles spent on each layer), but this
+ // value works for many cases. May be tuned later.
+ const int64 cost = (latest_tree + 1) * 10;
thread::ThreadPool* const worker_threads =
context->device()->tensorflow_cpu_worker_threads()->workers;
Shard(worker_threads->NumThreads(), worker_threads, batch_size,
@@ -254,7 +255,6 @@ class BoostedTreesPredictOp : public OpKernel {
int32
logits_dimension_; // Indicates the size of the output prediction vector.
int32 num_bucketized_features_; // Indicates the number of features.
- int32 max_depth_;
};
REGISTER_KERNEL_BUILDER(Name("BoostedTreesPredict").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
index 16e65cf284..6dfcd63ab3 100644
--- a/tensorflow/core/kernels/boosted_trees/stats_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
@@ -29,10 +29,6 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
explicit BoostedTreesCalculateBestGainsPerFeatureOp(
OpKernelConstruction* const context)
: OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("l1", &l1_));
- OP_REQUIRES_OK(context, context->GetAttr("l2", &l2_));
- OP_REQUIRES_OK(context,
- context->GetAttr("tree_complexity", &tree_complexity_));
OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
}
@@ -54,6 +50,20 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
for (const auto& tensor : stats_summary_list) {
stats_summary.emplace_back(tensor.tensor<float, 3>());
}
+ const Tensor* l1_t;
+ OP_REQUIRES_OK(context, context->input("l1", &l1_t));
+ const auto l1 = l1_t->scalar<float>()();
+ const Tensor* l2_t;
+ OP_REQUIRES_OK(context, context->input("l2", &l2_t));
+ const auto l2 = l2_t->scalar<float>()();
+ const Tensor* tree_complexity_t;
+ OP_REQUIRES_OK(context,
+ context->input("tree_complexity", &tree_complexity_t));
+ const auto tree_complexity = tree_complexity_t->scalar<float>()();
+ const Tensor* min_node_weight_t;
+ OP_REQUIRES_OK(context,
+ context->input("min_node_weight", &min_node_weight_t));
+ const auto min_node_weight = min_node_weight_t->scalar<float>()();
// Allocate output lists of tensors:
OpOutputList output_node_ids_list;
@@ -99,6 +109,11 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
cum_grad.push_back(total_grad);
cum_hess.push_back(total_hess);
}
+ // Check if the node has a large enough average hessian.
+ if (total_hess < min_node_weight) {
+ // Do not split the node: the average hessian is too small.
+ continue;
+ }
float best_gain = std::numeric_limits<float>::lowest();
float best_bucket = 0;
float best_contrib_for_left = 0.0;
@@ -106,7 +121,8 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
// Parent gain.
float parent_gain;
float unused;
- CalculateWeightsAndGains(total_grad, total_hess, &unused, &parent_gain);
+ CalculateWeightsAndGains(total_grad, total_hess, l1, l2, &unused,
+ &parent_gain);
for (int bucket = 0; bucket < num_buckets; ++bucket) {
const float cum_grad_bucket = cum_grad[bucket];
@@ -114,13 +130,13 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
// Left child.
float contrib_for_left;
float gain_for_left;
- CalculateWeightsAndGains(cum_grad_bucket, cum_hess_bucket,
+ CalculateWeightsAndGains(cum_grad_bucket, cum_hess_bucket, l1, l2,
&contrib_for_left, &gain_for_left);
// Right child.
float contrib_for_right;
float gain_for_right;
CalculateWeightsAndGains(total_grad - cum_grad_bucket,
- total_hess - cum_hess_bucket,
+ total_hess - cum_hess_bucket, l1, l2,
&contrib_for_right, &gain_for_right);
if (gain_for_left + gain_for_right > best_gain) {
@@ -173,7 +189,7 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
for (int i = 0; i < num_nodes; ++i) {
output_node_ids_vec(i) = output_node_ids[i];
// Adjust the gains to penalize by tree complexity.
- output_gains_vec(i) = output_gains[i] - tree_complexity_;
+ output_gains_vec(i) = output_gains[i] - tree_complexity;
output_thresholds_vec(i) = output_thresholds[i];
// Logits are 1-dimensional for now.
// TODO(nponomareva): Consider multi-dimensional logits.
@@ -184,8 +200,8 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
}
private:
- void CalculateWeightsAndGains(const float g, const float h, float* weight,
- float* gain) {
+ void CalculateWeightsAndGains(const float g, const float h, const float l1,
+ const float l2, float* weight, float* gain) {
//
// The formula for the weight is -(g+l1*sgn(w))/(h+l2); for the gain it is
// (g+l1*sgn(w))^2/(h+l2).
@@ -196,11 +212,11 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
// 1) Assume w>0 => w=-(g+l1)/(h+l2)=> g+l1 < 0 => g < -l1
// 2) Assume w<0 => w=-(g-l1)/(h+l2)=> g-l1 > 0 => g > l1
// For g in (-l1, l1) there is no solution => set w and gain to 0.
- if (l1_ > 0) {
- if (g > l1_) {
- g_with_l1 -= l1_;
- } else if (g < -l1_) {
- g_with_l1 += l1_;
+ if (l1 > 0) {
+ if (g > l1) {
+ g_with_l1 -= l1;
+ } else if (g < -l1) {
+ g_with_l1 += l1;
} else {
*weight = 0.0;
*gain = 0.0;
@@ -208,19 +224,16 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
}
}
// Apply L2 regularization.
- if (h + l2_ <= kEps) {
+ if (h + l2 <= kEps) {
// Avoid division by 0 or infinitesimal.
*weight = 0;
*gain = 0;
} else {
- *weight = -g_with_l1 / (h + l2_);
+ *weight = -g_with_l1 / (h + l2);
*gain = -g_with_l1 * (*weight);
}
}
- float l1_;
- float l2_;
- float tree_complexity_;
int max_splits_;
int num_features_;
};
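
For reference, the regularized weight/gain computation above in a runnable standalone form: the gradient is soft-thresholded by l1, then w = -g_l1/(h + l2) and gain = -g_l1 * w. The kEps value here is illustrative; the kernel defines its own epsilon.

    #include <cstdio>

    // Soft-threshold g by l1, then w = -g_l1 / (h + l2), gain = -g_l1 * w.
    void WeightAndGain(float g, float h, float l1, float l2, float* weight,
                       float* gain) {
      const float kEps = 1e-15f;  // illustrative epsilon
      float g_with_l1 = g;
      if (l1 > 0) {
        if (g > l1) {
          g_with_l1 -= l1;
        } else if (g < -l1) {
          g_with_l1 += l1;
        } else {
          *weight = 0.0f;  // g in (-l1, l1): no solution, see comment above
          *gain = 0.0f;
          return;
        }
      }
      if (h + l2 <= kEps) {
        *weight = 0.0f;  // avoid division by 0 or an infinitesimal
        *gain = 0.0f;
      } else {
        *weight = -g_with_l1 / (h + l2);
        *gain = -g_with_l1 * (*weight);
      }
    }

    int main() {
      float w, gain;
      WeightAndGain(/*g=*/-4.0f, /*h=*/10.0f, /*l1=*/1.0f, /*l2=*/1.0f,
                    &w, &gain);
      std::printf("w=%f gain=%f\n", w, gain);  // w = 3/11, gain = 9/11
      return 0;
    }
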
diff --git a/tensorflow/core/kernels/boosted_trees/training_ops.cc b/tensorflow/core/kernels/boosted_trees/training_ops.cc
index 67cac14c52..a14fd4a133 100644
--- a/tensorflow/core/kernels/boosted_trees/training_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/training_ops.cc
@@ -43,8 +43,6 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
public:
explicit BoostedTreesUpdateEnsembleOp(OpKernelConstruction* const context)
: OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("max_depth", &max_depth_));
- OP_REQUIRES_OK(context, context->GetAttr("learning_rate", &learning_rate_));
OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
int32 pruning_index;
@@ -79,8 +77,15 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
const Tensor* feature_ids_t;
OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
+ const auto feature_ids = feature_ids_t->vec<int32>();
- auto feature_ids = feature_ids_t->vec<int32>();
+ const Tensor* max_depth_t;
+ OP_REQUIRES_OK(context, context->input("max_depth", &max_depth_t));
+ const auto max_depth = max_depth_t->scalar<int32>()();
+
+ const Tensor* learning_rate_t;
+ OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t));
+ const auto learning_rate = learning_rate_t->scalar<float>()();
// Find best splits for each active node.
std::map<int32, SplitCandidate> best_splits;
@@ -125,10 +130,10 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
// For now assume that the weights vectors are one dimensional.
// TODO(nponomareva): change here for multiclass.
const float left_contrib =
- learning_rate_ *
+ learning_rate *
left_node_contribs[feature_idx].matrix<float>()(candidate_idx, 0);
const float right_contrib =
- learning_rate_ *
+ learning_rate *
right_node_contribs[feature_idx].matrix<float>()(candidate_idx, 0);
// unused.
@@ -145,7 +150,7 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
// Update growable tree metadata.
ensemble_resource->SetNumLayersGrown(current_tree, new_num_layers);
// Finalize the tree if needed.
- if (ensemble_resource->GetNumLayersGrown(current_tree) >= max_depth_) {
+ if (ensemble_resource->GetNumLayersGrown(current_tree) >= max_depth) {
// If the tree is finalized, next growing will start from node 0;
node_id_start = 0;
node_id_end = 1;
@@ -216,8 +221,6 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
private:
int32 num_features_;
- float learning_rate_;
- int32 max_depth_;
PruningMode pruning_mode_;
};
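
The motivation for moving max_depth and learning_rate from attrs to inputs: an attr is fixed when the kernel is constructed, while an input tensor is re-read on every Compute call, so the learning rate can now follow a schedule without rebuilding the graph. A toy sketch of the difference (Kernel is illustrative, not OpKernel, and the decay schedule is hypothetical):

    #include <cstdio>

    struct Kernel {
      // Before: float learning_rate_; fixed at construction from an attr.
      // After: the rate arrives as a per-step input.
      void Compute(float learning_rate, float contrib, float* out) {
        *out = learning_rate * contrib;  // same scaling as left/right_contrib
      }
    };

    int main() {
      Kernel k;
      float out;
      for (int step = 0; step < 3; ++step) {
        const float lr = 0.1f / (1 + step);  // hypothetical decay schedule
        k.Compute(lr, 2.0f, &out);
        std::printf("step=%d lr=%.4f contrib=%.4f\n", step, lr, out);
      }
      return 0;
    }
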
diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc
index e4036ddaa9..a21f13a4dd 100644
--- a/tensorflow/core/kernels/cudnn_rnn_ops.cc
+++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc
@@ -78,6 +78,7 @@ using CPUDevice = Eigen::ThreadPoolDevice;
#if GOOGLE_CUDA
using GPUDevice = Eigen::GpuDevice;
+using ::perftools::gputools::StreamExecutor;
template <typename Device, typename T, typename Index>
class CudnnRNNParamsSizeOp;
@@ -101,15 +102,21 @@ enum class TFRNNInputMode {
};
namespace {
-using perftools::gputools::DeviceMemory;
-using perftools::gputools::DeviceMemoryBase;
-using perftools::gputools::ScratchAllocator;
-using perftools::gputools::dnn::AlgorithmConfig;
-using perftools::gputools::dnn::RnnDirectionMode;
-using perftools::gputools::dnn::RnnInputMode;
-using perftools::gputools::dnn::RnnMode;
-using perftools::gputools::dnn::ToDataType;
-using perftools::gputools::port::StatusOr;
+using ::perftools::gputools::DeviceMemory;
+using ::perftools::gputools::DeviceMemoryBase;
+using ::perftools::gputools::ScratchAllocator;
+using ::perftools::gputools::Stream;
+using ::perftools::gputools::dnn::AlgorithmConfig;
+using ::perftools::gputools::dnn::AlgorithmDesc;
+using ::perftools::gputools::dnn::ProfileResult;
+using ::perftools::gputools::dnn::RnnDescriptor;
+using ::perftools::gputools::dnn::RnnDirectionMode;
+using ::perftools::gputools::dnn::RnnInputMode;
+using ::perftools::gputools::dnn::RnnMode;
+using ::perftools::gputools::dnn::RnnSequenceTensorDescriptor;
+using ::perftools::gputools::dnn::RnnStateTensorDescriptor;
+using ::perftools::gputools::dnn::ToDataType;
+using ::perftools::gputools::port::StatusOr;
Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
if (str == "rnn_relu") {
@@ -252,12 +259,12 @@ class CudnnRnnAllocatorInTemp : public ScratchAllocator {
explicit CudnnRnnAllocatorInTemp(OpKernelContext* context)
: context_(context) {}
- int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
+ int64 GetMemoryLimitInBytes(Stream* stream) override {
return std::numeric_limits<int64>::max();
}
- StatusOr<DeviceMemory<uint8>> AllocateBytes(
- perftools::gputools::Stream* stream, int64 byte_size) override {
+ StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
+ int64 byte_size) override {
Tensor temporary_memory;
const DataType tf_data_type = ToTFDataType<T>::value;
int64 allocate_count =
@@ -298,11 +305,11 @@ class CudnnRnnAllocatorInOutput : public ScratchAllocator {
~CudnnRnnAllocatorInOutput() override {}
CudnnRnnAllocatorInOutput(OpKernelContext* context, int output_index)
: context_(context), output_index_(output_index) {}
- int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
+ int64 GetMemoryLimitInBytes(Stream* stream) override {
return std::numeric_limits<int64>::max();
}
- StatusOr<DeviceMemory<uint8>> AllocateBytes(
- perftools::gputools::Stream* stream, int64 byte_size) override {
+ StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
+ int64 byte_size) override {
CHECK(total_byte_size_ == 0)
<< "Reserve space allocator can only be called once";
int64 allocate_count =
@@ -338,12 +345,12 @@ class CudnnRNNPersistentSpaceAllocator : public ScratchAllocator {
~CudnnRNNPersistentSpaceAllocator() override {}
- int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
+ int64 GetMemoryLimitInBytes(Stream* stream) override {
return std::numeric_limits<int64>::max();
}
- StatusOr<DeviceMemory<uint8>> AllocateBytes(
- perftools::gputools::Stream* stream, int64 byte_size) override {
+ StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
+ int64 byte_size) override {
if (total_byte_size_ != 0) {
return Status(error::FAILED_PRECONDITION,
"Persistent space allocator can only be called once");
@@ -374,6 +381,13 @@ struct CudnnModelTypes {
// input-h.
return rnn_mode == RnnMode::kRnnLstm;
}
+
+ string DebugString() const {
+ return strings::Printf(
+ "[rnn_mode, rnn_input_mode, rnn_direction_mode]: %d, %d, %d ",
+ static_cast<int>(rnn_mode), static_cast<int>(rnn_input_mode),
+ static_cast<int>(rnn_direction_mode));
+ }
};
// A helper class that collects the shapes to describe an RNN model.
@@ -381,9 +395,9 @@ struct CudnnRnnModelShapes {
int num_layers;
int input_size;
int num_units;
+ int dir_count;
int seq_length;
int batch_size;
- int dir_count;
TensorShape input_shape;
TensorShape output_shape;
TensorShape hidden_state_shape;
@@ -392,10 +406,11 @@ struct CudnnRnnModelShapes {
return num_layers == rhs.num_layers && input_size == rhs.input_size &&
num_units == rhs.num_units && dir_count == rhs.dir_count;
}
- string RnnDescDebugString() {
+ string DebugString() const {
return strings::Printf(
- "[num_layers, input_size, num_units, dir_count]: [%d, %d, %d, %d]",
- num_layers, input_size, num_units, dir_count);
+ "[num_layers, input_size, num_units, dir_count, seq_length, "
+ "batch_size]: [%d, %d, %d, %d, %d, %d] ",
+ num_layers, input_size, num_units, dir_count, seq_length, batch_size);
}
};
@@ -420,8 +435,15 @@ struct CudnnRnnModelShapesComparator {
}
};
-// Extract and checks the forward input tensors, parameters, and shapes from
-// the OpKernelContext.
+// Pointers to RNN scratch space for a specific set of shape parameters (used as
+// a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp).
+struct RnnScratchSpace {
+ std::unique_ptr<RnnDescriptor> rnn_desc;
+ std::unique_ptr<CudnnRNNPersistentSpaceAllocator> dropout_state_allocator;
+};
+
+// Extracts and checks the forward input tensors, parameters, and shapes from
+// the OpKernelContext.
Status ExtractForwardInput(OpKernelContext* context,
const CudnnModelTypes& model_types,
const Tensor** input, const Tensor** input_h,
@@ -474,13 +496,171 @@ Status ExtractForwardInput(OpKernelContext* context,
return Status::OK();
}
-using perftools::gputools::dnn::RnnDescriptor;
+template <typename T>
+Status CreateForwardAndBackwardIODescriptors(
+ OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
+ std::unique_ptr<RnnSequenceTensorDescriptor>* input_desc,
+ std::unique_ptr<RnnStateTensorDescriptor>* state_desc,
+ std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc) {
+ StreamExecutor* executor = context->op_device_context()->stream()->parent();
+ ::perftools::gputools::dnn::DataType data_type = ToDataType<T>::value;
+
+ const TensorShape& input_shape = model_shapes.input_shape;
+ const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
+ const TensorShape& output_shape = model_shapes.output_shape;
+
+ DCHECK_EQ(input_shape.dims(), 3);
+ auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
+ input_shape.dim_size(0), input_shape.dim_size(1), input_shape.dim_size(2),
+ data_type);
+ TF_RETURN_IF_ERROR(input_desc_s.status());
+ *input_desc = input_desc_s.ConsumeValueOrDie();
+
+ DCHECK_EQ(hidden_state_shape.dims(), 3);
+ auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
+ hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
+ hidden_state_shape.dim_size(2), data_type);
+ TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
+ *state_desc = hidden_state_desc_s.ConsumeValueOrDie();
+
+ DCHECK_EQ(output_shape.dims(), 3);
+ auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
+ output_shape.dim_size(0), output_shape.dim_size(1),
+ output_shape.dim_size(2), data_type);
+ TF_RETURN_IF_ERROR(output_desc_s.status());
+ *output_desc = output_desc_s.ConsumeValueOrDie();
+ return Status::OK();
+}
+
+template <typename T>
+Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
+ const CudnnModelTypes& model_types,
+ const CudnnRnnModelShapes& model_shapes,
+ /* forward inputs */
+ const Tensor* input, const Tensor* input_h,
+ const Tensor* input_c, const Tensor* params,
+ const bool is_training,
+ /* forward outputs, outputs of the function */
+ Tensor* output, Tensor* output_h, Tensor* output_c,
+ ScratchAllocator* reserve_space_allocator,
+ ScratchAllocator* workspace_allocator,
+ ProfileResult* output_profile_result) {
+ std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
+ std::unique_ptr<RnnStateTensorDescriptor> state_desc;
+ std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
+
+ TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
+ context, model_shapes, &input_desc, &state_desc, &output_desc));
+
+ auto input_data = AsDeviceMemory<T>(input);
+ auto input_h_data = AsDeviceMemory<T>(input_h);
+ DeviceMemory<T> input_c_data;
+ if (model_types.HasInputC()) {
+ input_c_data = AsDeviceMemory<T>(input_c);
+ }
+ auto params_data = AsDeviceMemory<T>(params);
+ auto output_data = AsDeviceMemory<T>(output);
+ auto output_h_data = AsDeviceMemory<T>(output_h);
+ DeviceMemory<T> output_c_data;
+ if (model_types.HasInputC()) {
+ output_c_data = AsDeviceMemory<T>(output_c);
+ }
+
+ Stream* stream = context->op_device_context()->stream();
+ bool launch_success =
+ stream
+ ->ThenRnnForward(rnn_desc, *input_desc, input_data, *state_desc,
+ input_h_data, *state_desc, input_c_data, params_data,
+ *output_desc, &output_data, *state_desc,
+ &output_h_data, *state_desc, &output_c_data,
+ is_training, reserve_space_allocator,
+ workspace_allocator, output_profile_result)
+ .ok();
+ return launch_success
+ ? Status::OK()
+ : errors::Internal(
+ "Failed to call ThenRnnForward with model config: ",
+ model_types.DebugString(), ", ", model_shapes.DebugString());
+}
+
+template <typename T>
+Status DoBackward(
+ OpKernelContext* context, const RnnDescriptor& rnn_desc,
+ const CudnnModelTypes& model_types, const CudnnRnnModelShapes& model_shapes,
+ /* forward inputs */
+ const Tensor* input, const Tensor* input_h, const Tensor* input_c,
+ const Tensor* params,
+ /* forward outputs */
+ const Tensor* output, const Tensor* output_h, const Tensor* output_c,
+ /* backprop inputs */
+ const Tensor* output_backprop, const Tensor* output_h_backprop,
+ const Tensor* output_c_backprop, const Tensor* reserve_space,
+ /* backprop outputs, output of the function */
+ Tensor* input_backprop, Tensor* input_h_backprop, Tensor* input_c_backprop,
+ Tensor* params_backprop, ScratchAllocator* workspace_allocator,
+ ProfileResult* output_profile_result) {
+ std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
+ std::unique_ptr<RnnStateTensorDescriptor> state_desc;
+ std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
+
+ TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
+ context, model_shapes, &input_desc, &state_desc, &output_desc));
+
+ auto input_data = AsDeviceMemory<T>(input);
+ auto input_h_data = AsDeviceMemory<T>(input_h);
+ DeviceMemory<T> input_c_data;
+ if (model_types.HasInputC()) {
+ input_c_data = AsDeviceMemory<T>(input_c);
+ }
+ auto params_data = AsDeviceMemory<T>(params);
+ auto output_data = AsDeviceMemory<T>(output);
+ auto output_h_data = AsDeviceMemory<T>(output_h);
+ DeviceMemory<T> output_c_data;
+ if (model_types.HasInputC()) {
+ output_c_data = AsDeviceMemory<T>(output_c);
+ }
+ auto output_backprop_data = AsDeviceMemory<T>(output_backprop);
+ auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop);
+ DeviceMemory<T> output_c_backprop_data;
+ if (model_types.HasInputC()) {
+ output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop);
+ }
+ auto input_backprop_data = AsDeviceMemory<T>(input_backprop);
+ auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop);
+ DeviceMemory<T> input_c_backprop_data;
+ if (model_types.HasInputC()) {
+ input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop);
+ }
+ auto params_backprop_data = AsDeviceMemory<T>(params_backprop);
+ auto reserve_space_uint8 =
+ CastDeviceMemory<uint8, T>(const_cast<Tensor*>(reserve_space));
+
+ // Creates a memory callback for the workspace. The memory lives until the
+ // end of this kernel call.
+ Stream* stream = context->op_device_context()->stream();
+ bool launch_success =
+ stream
+ ->ThenRnnBackward(rnn_desc, *input_desc, input_data, *state_desc,
+ input_h_data, *state_desc, input_c_data,
+ params_data, *output_desc, output_data, *state_desc,
+ output_h_data, *state_desc, output_c_data,
+ output_backprop_data, output_h_backprop_data,
+ output_c_backprop_data, &input_backprop_data,
+ &input_h_backprop_data, &input_c_backprop_data,
+ &params_backprop_data, &reserve_space_uint8,
+ workspace_allocator, output_profile_result)
+ .ok();
+ return launch_success
+ ? Status::OK()
+ : errors::Internal(
+ "Failed to call ThenRnnBackward with model config: ",
+ model_types.DebugString(), ", ", model_shapes.DebugString());
+}
template <typename T>
void RestoreParams(const OpInputList params_input,
const std::vector<RnnDescriptor::ParamsRegion>& params,
- DeviceMemoryBase* data_dst,
- perftools::gputools::Stream* stream) {
+ DeviceMemoryBase* data_dst, Stream* stream) {
int num_params = params.size();
CHECK(params_input.size() == num_params)
<< "Number of params mismatch. Expected " << params_input.size()
@@ -570,7 +750,7 @@ class CudnnRNNKernelCommon : public OpKernel {
TF_RETURN_IF_ERROR(
ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));
- auto* stream = context->op_device_context()->stream();
+ Stream* stream = context->op_device_context()->stream();
// ExtractCudnnRNNParamsInfo is only called by op_kernels that do not require
// a random number generator, therefore set state_allocator to nullptr.
const AlgorithmConfig algo_config;
@@ -585,6 +765,51 @@ class CudnnRNNKernelCommon : public OpKernel {
return Status::OK();
}
+ template <typename T>
+ Status CreateRnnDescriptor(OpKernelContext* context,
+ const CudnnRnnModelShapes& model_shapes,
+ const RnnInputMode& input_mode,
+ const AlgorithmConfig& algo_config,
+ ScratchAllocator* dropout_state_allocator,
+ std::unique_ptr<RnnDescriptor>* rnn_desc) {
+ StreamExecutor* executor = context->op_device_context()->stream()->parent();
+ ::perftools::gputools::dnn::DataType data_type = ToDataType<T>::value;
+ auto rnn_desc_s = executor->createRnnDescriptor(
+ model_shapes.num_layers, model_shapes.num_units,
+ model_shapes.input_size, input_mode, rnn_direction_mode(), rnn_mode(),
+ data_type, algo_config, dropout(), seed(), dropout_state_allocator);
+ TF_RETURN_IF_ERROR(rnn_desc_s.status());
+
+ *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
+ return Status::OK();
+ }
+
+ using RnnStateCache =
+ gtl::FlatMap<CudnnRnnModelShapes, RnnScratchSpace,
+ CudnnRnnModelShapesHasher, CudnnRnnModelShapesComparator>;
+ // Returns a raw rnn descriptor pointer. The cache owns the rnn descriptor
+ // and must outlive the returned pointer.
+ template <typename T>
+ Status GetCachedRnnDescriptor(OpKernelContext* context,
+ const CudnnRnnModelShapes& model_shapes,
+ const RnnInputMode& input_mode,
+ const AlgorithmConfig& algo_config,
+ RnnStateCache* cache,
+ RnnDescriptor** rnn_desc) {
+ RnnScratchSpace& rnn_state = (*cache)[model_shapes];
+ if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
+ CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
+ new CudnnRNNPersistentSpaceAllocator(context);
+ rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
+ Status status =
+ CreateRnnDescriptor<T>(context, model_shapes, input_mode, algo_config,
+ dropout_state_allocator, &rnn_state.rnn_desc);
+ TF_RETURN_IF_ERROR(status);
+ }
+ *rnn_desc = rnn_state.rnn_desc.get();
+ return Status::OK();
+ }
+
private:
int seed_;
int seed2_;
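
GetCachedRnnDescriptor implements a lazy, shape-keyed cache: the expensive descriptor is built at most once per CudnnRnnModelShapes key, rebuilt only when the random-generator state is reset, and the cache keeps ownership of what it hands out. A standalone sketch of the same pattern, with Cache, Desc, and reset_rng as illustrative stand-ins:

    #include <map>
    #include <memory>
    #include <string>

    struct Desc { std::string config; };  // stand-in for RnnDescriptor

    using Cache = std::map<std::string, std::unique_ptr<Desc>>;

    // Create lazily per key; recreate only when the RNG state was reset.
    // The cache keeps ownership, so it must outlive the returned pointer.
    Desc* GetCachedDesc(Cache* cache, const std::string& key, bool reset_rng) {
      std::unique_ptr<Desc>& slot = (*cache)[key];
      if (slot == nullptr || reset_rng) {
        slot.reset(new Desc{key});  // the expensive creation step
      }
      return slot.get();
    }
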
@@ -648,7 +873,7 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(3);
auto input_ptr = StreamExecutorUtil::AsDeviceMemory<T>(input);
- auto* stream = context->op_device_context()->stream();
+ Stream* stream = context->op_device_context()->stream();
std::unique_ptr<RnnDescriptor> rnn_desc;
OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
@@ -789,7 +1014,7 @@ class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
OP_REQUIRES_OK(context,
context->allocate_output(0, {params_size}, &output));
auto output_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
- auto* stream = context->op_device_context()->stream();
+ Stream* stream = context->op_device_context()->stream();
OpInputList weights;
OP_REQUIRES_OK(context, context->input_list("weights", &weights));
@@ -816,13 +1041,6 @@ TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU
-// Pointers to RNN scratch space for a specific set of shape parameters (used as
-// a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp).
-struct RnnScratchSpace {
- std::unique_ptr<RnnDescriptor> rnn_desc;
- std::unique_ptr<CudnnRNNPersistentSpaceAllocator> dropout_state_allocator;
-};
-
// Run the forward operation of the RNN model.
template <typename T>
class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
@@ -842,115 +1060,71 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
OP_REQUIRES_OK(context,
ExtractForwardInput(context, model_types(), &input, &input_h,
&input_c, &params, &model_shapes));
- const auto& input_shape = model_shapes.input_shape;
- const auto& hidden_state_shape = model_shapes.hidden_state_shape;
- const auto& output_shape = model_shapes.output_shape;
-
- Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
- Tensor* output_h = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(1, hidden_state_shape, &output_h));
- Tensor* output_c = nullptr;
- if (HasInputC()) {
- // Only LSTM uses input_c and output_c. So for all other models, we only
- // need to create dummy outputs.
- OP_REQUIRES_OK(
- context, context->allocate_output(2, hidden_state_shape, &output_c));
- } else {
- OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_c));
- }
-
- auto* stream = context->op_device_context()->stream();
- auto* executor = stream->parent();
RnnInputMode input_mode;
OP_REQUIRES_OK(context,
ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
model_shapes.input_size, &input_mode));
- auto data_type = ToDataType<T>::value;
-
- auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
- input_shape.dim_size(0), input_shape.dim_size(1),
- input_shape.dim_size(2), data_type);
- OP_REQUIRES_OK(context, FromExecutorStatus(input_desc_s));
- auto input_desc = input_desc_s.ConsumeValueOrDie();
-
- auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
- hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
- hidden_state_shape.dim_size(2), data_type);
- OP_REQUIRES_OK(context, FromExecutorStatus(hidden_state_desc_s));
- auto hidden_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
-
- auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
- output_shape.dim_size(0), output_shape.dim_size(1),
- output_shape.dim_size(2), data_type);
- OP_REQUIRES_OK(context, FromExecutorStatus(output_desc_s));
- auto output_desc = output_desc_s.ConsumeValueOrDie();
-
- auto input_data = AsDeviceMemory<T>(input);
- auto input_h_data = AsDeviceMemory<T>(input_h);
- DeviceMemory<T> input_c_data;
- if (HasInputC()) {
- input_c_data = AsDeviceMemory<T>(input_c);
- }
- auto params_data = AsDeviceMemory<T>(params);
- auto output_data = AsDeviceMemory<T>(output);
- auto output_h_data = AsDeviceMemory<T>(output_h);
- DeviceMemory<T> output_c_data;
- if (HasInputC()) {
- output_c_data = AsDeviceMemory<T>(output_c);
- }
+ Tensor* output = nullptr;
+ Tensor* output_h = nullptr;
+ Tensor* output_c = nullptr;
+ OP_REQUIRES_OK(context, AllocateOutputs(context, model_shapes, &output,
+ &output_h, &output_c));
+
+ AlgorithmConfig algo_config;
// Creates a memory callback for the reserve_space. The memory lives in the
// output of this kernel, and it will be fed into the backward pass when
// needed.
CudnnRnnAllocatorInOutput<T> reserve_space_allocator(context, 3);
- if (!is_training_) {
- Tensor* dummy_reserve_space = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(3, {}, &dummy_reserve_space));
- }
// Creates a memory callback for the workspace. The memory lives until the
// end of this kernel call.
CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
- bool launch_status = false;
+ Status launch_status;
{
mutex_lock l(mu_);
- RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes];
- if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
- CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
- new CudnnRNNPersistentSpaceAllocator(context);
- rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
- const AlgorithmConfig algo_config;
- auto rnn_desc_s = executor->createRnnDescriptor(
- model_shapes.num_layers, model_shapes.num_units,
- model_shapes.input_size, input_mode, rnn_direction_mode(),
- rnn_mode(), data_type, algo_config, dropout(), seed(),
- dropout_state_allocator);
- OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
- rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie());
- }
- launch_status =
- stream
- ->ThenRnnForward(
- *rnn_state.rnn_desc, *input_desc, input_data,
- *hidden_state_desc, input_h_data, *hidden_state_desc,
- input_c_data, params_data, *output_desc, &output_data,
- *hidden_state_desc, &output_h_data, *hidden_state_desc,
- &output_c_data, is_training_, &reserve_space_allocator,
- &workspace_allocator, /*output_result_profile=*/nullptr)
- .ok();
+ RnnDescriptor* rnn_desc_ptr = nullptr;
+ OP_REQUIRES_OK(
+ context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
+ algo_config, &rnn_state_cache_,
+ &rnn_desc_ptr));
+ launch_status = DoForward<T>(
+ context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
+ input_c, params, is_training_, output, output_h, output_c,
+ &reserve_space_allocator, &workspace_allocator,
+ /*output_profile_result=*/nullptr);
}
- OP_REQUIRES(context, launch_status,
- errors::Internal("Failed to call ThenRnnForward"));
+ OP_REQUIRES_OK(context, launch_status);
}
private:
+ Status AllocateOutputs(OpKernelContext* context,
+ const CudnnRnnModelShapes& model_shapes,
+ Tensor** output, Tensor** output_h,
+ Tensor** output_c) {
+ const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
+ const TensorShape& output_shape = model_shapes.output_shape;
+
+ TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, output));
+ TF_RETURN_IF_ERROR(
+ context->allocate_output(1, hidden_state_shape, output_h));
+ if (HasInputC()) {
+ TF_RETURN_IF_ERROR(
+ context->allocate_output(2, hidden_state_shape, output_c));
+ } else {
+ // Only LSTM uses input_c and output_c. So for all other models, we only
+ // need to create dummy outputs.
+ TF_RETURN_IF_ERROR(context->allocate_output(2, {}, output_c));
+ }
+ if (!is_training_) {
+ Tensor* dummy_reserve_space = nullptr;
+ TF_RETURN_IF_ERROR(context->allocate_output(3, {}, &dummy_reserve_space));
+ }
+ return Status::OK();
+ }
+
mutex mu_;
bool is_training_;
- std::unordered_map<CudnnRnnModelShapes, RnnScratchSpace,
- CudnnRnnModelShapesHasher, CudnnRnnModelShapesComparator>
- rnn_state_cache_ GUARDED_BY(mu_);
+ RnnStateCache rnn_state_cache_ GUARDED_BY(mu_);
};
#define REGISTER_GPU(T) \
@@ -981,184 +1155,141 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
OP_REQUIRES_OK(context,
ExtractForwardInput(context, model_types(), &input, &input_h,
&input_c, &params, &model_shapes));
+ RnnInputMode input_mode;
+ OP_REQUIRES_OK(context,
+ ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
+ model_shapes.input_size, &input_mode));
- const auto& input_shape = model_shapes.input_shape;
- const auto& hidden_state_shape = model_shapes.hidden_state_shape;
- const auto& output_shape = model_shapes.output_shape;
-
- auto data_type = ToDataType<T>::value;
const Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->input("output", &output));
- OP_REQUIRES(context, output_shape == output->shape(),
- errors::InvalidArgument(
- "input_h and input_c must have the same shape: ",
- input_h->shape().DebugString(), " ",
- input_c->shape().DebugString()));
const Tensor* output_h = nullptr;
- OP_REQUIRES_OK(context, context->input("output_h", &output_h));
- OP_REQUIRES(context, output_h->shape() == hidden_state_shape,
- errors::InvalidArgument(
- "Invalid output_h shape: ", output_h->shape().DebugString(),
- " ", hidden_state_shape.DebugString()));
const Tensor* output_c = nullptr;
- if (HasInputC()) {
- // Only LSTM uses input_c and output_c. So for all other models, we only
- // need to create dummy outputs.
- OP_REQUIRES_OK(context, context->input("output_c", &output_c));
- OP_REQUIRES(context, output_c->shape() == hidden_state_shape,
- errors::InvalidArgument("Invalid output_c shape: ",
- output_c->shape().DebugString(), " ",
- hidden_state_shape.DebugString()));
- }
-
const Tensor* output_backprop = nullptr;
- OP_REQUIRES_OK(context,
- context->input("output_backprop", &output_backprop));
- OP_REQUIRES(context, output_backprop->shape() == output_shape,
- errors::InvalidArgument("Invalid output_backprop shapes: ",
- output_backprop->shape().DebugString(),
- " ", output_shape.DebugString()));
-
const Tensor* output_h_backprop = nullptr;
- OP_REQUIRES_OK(context,
- context->input("output_h_backprop", &output_h_backprop));
- OP_REQUIRES(
- context, output_h_backprop->shape() == hidden_state_shape,
- errors::InvalidArgument("Invalid output_h_backprop shapes: ",
- output_h_backprop->shape().DebugString(), " ",
- hidden_state_shape.DebugString()));
const Tensor* output_c_backprop = nullptr;
- if (HasInputC()) {
- OP_REQUIRES_OK(context,
- context->input("output_c_backprop", &output_c_backprop));
- OP_REQUIRES(
- context, output_c_backprop->shape() == hidden_state_shape,
- errors::InvalidArgument("Invalid output_c_backprop shapes: ",
- output_c_backprop->shape().DebugString(), " ",
- hidden_state_shape.DebugString()));
- }
- const Tensor* reserve_space_const = nullptr;
- // This is the same "reserve_space" created by the forward op.
- // It can also be modified by this backward operation.
+ const Tensor* reserve_space = nullptr;
OP_REQUIRES_OK(context,
- context->input("reserve_space", &reserve_space_const));
- // Cudnn needs the reserve space to be writeable. This is fine because they
- // are opaque.
- Tensor* reserve_space = const_cast<Tensor*>(reserve_space_const);
+ ExtractBackwardInputs(context, model_shapes, model_types(),
+ &output, &output_h, &output_c,
+ &output_backprop, &output_h_backprop,
+ &output_c_backprop, &reserve_space));
Tensor* input_backprop = nullptr;
- OP_REQUIRES_OK(
- context, context->allocate_output(0, input->shape(), &input_backprop));
Tensor* input_h_backprop = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(1, input_h->shape(),
- &input_h_backprop));
Tensor* input_c_backprop = nullptr;
- if (HasInputC()) {
- OP_REQUIRES_OK(context, context->allocate_output(2, input_c->shape(),
- &input_c_backprop));
- } else {
- OP_REQUIRES_OK(context,
- context->allocate_output(2, {}, &input_c_backprop));
- }
Tensor* params_backprop = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(3, params->shape(),
- &params_backprop));
-
- auto* stream = context->op_device_context()->stream();
- auto* executor = stream->parent();
- RnnInputMode input_mode;
OP_REQUIRES_OK(context,
- ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
- model_shapes.input_size, &input_mode));
+ AllocateOutputs(context, model_shapes, params->shape(),
+ &input_backprop, &input_h_backprop,
+ &input_c_backprop, &params_backprop));
- auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
- input_shape.dim_size(0), input_shape.dim_size(1),
- input_shape.dim_size(2), data_type);
- OP_REQUIRES_OK(context, FromExecutorStatus(input_desc_s));
- auto input_desc = input_desc_s.ConsumeValueOrDie();
-
- auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
- hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
- hidden_state_shape.dim_size(2), data_type);
- OP_REQUIRES_OK(context, FromExecutorStatus(hidden_state_desc_s));
- auto hidden_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
-
- auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
- output_shape.dim_size(0), output_shape.dim_size(1),
- output_shape.dim_size(2), data_type);
- OP_REQUIRES_OK(context, FromExecutorStatus(output_desc_s));
- auto output_desc = output_desc_s.ConsumeValueOrDie();
-
- auto input_data = AsDeviceMemory<T>(input);
- auto input_h_data = AsDeviceMemory<T>(input_h);
- DeviceMemory<T> input_c_data;
- if (HasInputC()) {
- input_c_data = AsDeviceMemory<T>(input_c);
- }
- auto params_data = AsDeviceMemory<T>(params);
- auto output_data = AsDeviceMemory<T>(output);
- auto output_h_data = AsDeviceMemory<T>(output_h);
- DeviceMemory<T> output_c_data;
- if (HasInputC()) {
- output_c_data = AsDeviceMemory<T>(output_c);
- }
- auto output_backprop_data = AsDeviceMemory<T>(output_backprop);
- auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop);
- DeviceMemory<T> output_c_backprop_data;
- if (HasInputC()) {
- output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop);
- }
- auto input_backprop_data = AsDeviceMemory<T>(input_backprop);
- auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop);
- DeviceMemory<T> input_c_backprop_data;
- if (HasInputC()) {
- input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop);
- }
- auto params_backprop_data = AsDeviceMemory<T>(params_backprop);
- auto reserve_space_uint8 = CastDeviceMemory<uint8, T>(reserve_space);
// Creates a memory callback for the workspace. The memory lives until the
// end of this kernel call.
CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
- bool launch_status = false;
+ const AlgorithmConfig default_algo_config;
+ Status launch_status;
{
mutex_lock l(mu_);
- RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes];
- if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
- CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
- new CudnnRNNPersistentSpaceAllocator(context);
- rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
- const AlgorithmConfig algo_config;
- auto rnn_desc_s = executor->createRnnDescriptor(
- model_shapes.num_layers, model_shapes.num_units,
- model_shapes.input_size, input_mode, rnn_direction_mode(),
- rnn_mode(), data_type, algo_config, dropout(), seed(),
- dropout_state_allocator);
- OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
- rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie());
- }
- launch_status =
- stream
- ->ThenRnnBackward(
- *rnn_state.rnn_desc, *input_desc, input_data,
- *hidden_state_desc, input_h_data, *hidden_state_desc,
- input_c_data, params_data, *output_desc, output_data,
- *hidden_state_desc, output_h_data, *hidden_state_desc,
- output_c_data, output_backprop_data, output_h_backprop_data,
- output_c_backprop_data, &input_backprop_data,
- &input_h_backprop_data, &input_c_backprop_data,
- &params_backprop_data, &reserve_space_uint8,
- &workspace_allocator, /*output_result_profile=*/nullptr)
- .ok();
+ RnnDescriptor* rnn_desc_ptr = nullptr;
+ OP_REQUIRES_OK(
+ context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
+ default_algo_config,
+ &rnn_state_cache_, &rnn_desc_ptr));
+ launch_status = DoBackward<T>(
+ context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
+ input_c, params, output, output_h, output_c, output_backprop,
+ output_h_backprop, output_c_backprop, reserve_space, input_backprop,
+ input_h_backprop, input_c_backprop, params_backprop,
+ &workspace_allocator, /*output_profile_result=*/nullptr);
}
- OP_REQUIRES(context, launch_status,
- errors::Internal("Failed to call ThenRnnBackward"));
+ OP_REQUIRES_OK(context, launch_status);
}
private:
mutex mu_;
- std::unordered_map<CudnnRnnModelShapes, RnnScratchSpace,
- CudnnRnnModelShapesHasher, CudnnRnnModelShapesComparator>
- rnn_state_cache_ GUARDED_BY(mu_);
+ RnnStateCache rnn_state_cache_ GUARDED_BY(mu_);
+
+ Status ExtractBackwardInputs(
+ OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
+ const CudnnModelTypes& model_types, const Tensor** output,
+ const Tensor** output_h, const Tensor** output_c,
+ const Tensor** output_backprop, const Tensor** output_h_backprop,
+ const Tensor** output_c_backprop, const Tensor** reserve_space) {
+ TF_RETURN_IF_ERROR(context->input("output", output));
+ TF_RETURN_IF_ERROR(context->input("output_backprop", output_backprop));
+ TF_RETURN_IF_ERROR(context->input("output_h", output_h));
+ TF_RETURN_IF_ERROR(context->input("output_h_backprop", output_h_backprop));
+ if (model_types.HasInputC()) {
+ TF_RETURN_IF_ERROR(context->input("output_c", output_c));
+ TF_RETURN_IF_ERROR(
+ context->input("output_c_backprop", output_c_backprop));
+ }
+ TF_RETURN_IF_ERROR(context->input("reserve_space", reserve_space));
+ const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
+ const TensorShape& output_shape = model_shapes.output_shape;
+
+ if (output_shape != (*output)->shape()) {
+ return errors::InvalidArgument(
+ "Invalid output shape: ", (*output)->shape().DebugString(), " ",
+ output_shape.DebugString());
+ }
+ if (hidden_state_shape != (*output_h)->shape()) {
+ return errors::InvalidArgument(
+ "Invalid output_h shape: ", (*output_h)->shape().DebugString(), " ",
+ hidden_state_shape.DebugString());
+ }
+
+ if (output_shape != (*output_backprop)->shape()) {
+ return errors::InvalidArgument("Invalid output_backprop shape: ",
+ (*output_backprop)->shape().DebugString(),
+ " ", output_shape.DebugString());
+ }
+ if (hidden_state_shape != (*output_h_backprop)->shape()) {
+ return errors::InvalidArgument(
+ "Invalid output_h_backprop shape: ",
+ (*output_h_backprop)->shape().DebugString(), " ",
+ hidden_state_shape.DebugString());
+ }
+
+ if (model_types.HasInputC()) {
+ if (hidden_state_shape != (*output_c)->shape()) {
+ return errors::InvalidArgument(
+ "Invalid output_c shape: ", (*output_c)->shape().DebugString(), " ",
+ hidden_state_shape.DebugString());
+ }
+ if (hidden_state_shape != (*output_c_backprop)->shape()) {
+ return errors::InvalidArgument(
+ "Invalid output_c_backprop shape: ",
+ (*output_c_backprop)->shape().DebugString(), " ",
+ hidden_state_shape.DebugString());
+ }
+ }
+ return Status::OK();
+ }
+
+ Status AllocateOutputs(OpKernelContext* context,
+ const CudnnRnnModelShapes& model_shapes,
+ const TensorShape& params_shape,
+ Tensor** input_backprop, Tensor** input_h_backprop,
+ Tensor** input_c_backprop, Tensor** params_backprop) {
+ const TensorShape& input_shape = model_shapes.input_shape;
+ const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
+
+ TF_RETURN_IF_ERROR(
+ context->allocate_output(0, input_shape, input_backprop));
+ TF_RETURN_IF_ERROR(
+ context->allocate_output(1, hidden_state_shape, input_h_backprop));
+ if (HasInputC()) {
+ TF_RETURN_IF_ERROR(
+ context->allocate_output(2, hidden_state_shape, input_c_backprop));
+ } else {
+ // Only LSTM uses input_c and output_c. So for all other models, we only
+ // need to create dummy outputs.
+ TF_RETURN_IF_ERROR(context->allocate_output(2, {}, input_c_backprop));
+ }
+ TF_RETURN_IF_ERROR(
+ context->allocate_output(3, params_shape, params_backprop));
+ return Status::OK();
+ }
};
#define REGISTER_GPU(T) \
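
ExtractBackwardInputs centralizes the shape checks that used to be inlined in Compute, and each error now names the tensor being checked (the old inline check for "output" reported a misleading "input_h and input_c must have the same shape" message). A sketch of the per-tensor check, assuming a toy Shape type in place of TensorShape:

    #include <string>

    struct Shape {  // stand-in for TensorShape
      int dims[3];
      bool operator!=(const Shape& o) const {
        return dims[0] != o.dims[0] || dims[1] != o.dims[1] ||
               dims[2] != o.dims[2];
      }
      std::string DebugString() const {
        return "[" + std::to_string(dims[0]) + "," + std::to_string(dims[1]) +
               "," + std::to_string(dims[2]) + "]";
      }
    };

    // Returns an empty string when the shapes agree, else an error message
    // that names the offending tensor.
    std::string CheckShape(const char* name, const Shape& got,
                           const Shape& want) {
      if (got != want) {
        return std::string("Invalid ") + name + " shape: " +
               got.DebugString() + " " + want.DebugString();
      }
      return "";
    }
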
diff --git a/tensorflow/core/kernels/cwise_op_clip.cc b/tensorflow/core/kernels/cwise_op_clip.cc
index bd22f5777c..14d889e8e3 100644
--- a/tensorflow/core/kernels/cwise_op_clip.cc
+++ b/tensorflow/core/kernels/cwise_op_clip.cc
@@ -70,8 +70,9 @@ class ClipOp : public OpKernel {
functor::BinaryLeftClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
out_flat);
} else {
- OP_REQUIRES(ctx, (in0.shape() == in2.shape() &&
- TensorShapeUtils::IsScalar(in1.shape())),
+ OP_REQUIRES(ctx,
+ (in0.shape() == in2.shape() &&
+ TensorShapeUtils::IsScalar(in1.shape())),
errors::InvalidArgument(
"clip_value_min and clip_value_max must be either of "
"the same shape as input, or a scalar. ",
@@ -90,12 +91,12 @@ namespace functor {
template <typename T>
struct UnaryClipFunc {
UnaryClipFunc(const T& value_min, const T& value_max)
- : value_min_(value_min), value_max_(value_max) {}
+ : value_min(value_min), value_max(value_max) {}
const T operator()(const T& value) const {
- return std::max(std::min(value, value_max_), value_min_);
+ return std::max(std::min(value, value_max), value_min);
}
- T value_min_;
- T value_max_;
+ T value_min;
+ T value_max;
};
template <typename T>
struct UnaryClipOp<CPUDevice, T> {
@@ -110,11 +111,11 @@ struct UnaryClipOp<CPUDevice, T> {
// Binary functor for clip [Tensor, Scalar, Tensor]
template <typename T>
struct BinaryRightClipFunc {
- BinaryRightClipFunc(const T& value_min) : value_min_(value_min) {}
+ explicit BinaryRightClipFunc(const T& value_min) : value_min(value_min) {}
const T operator()(const T& value, const T& value_max) const {
- return std::max(std::min(value, value_max), value_min_);
+ return std::max(std::min(value, value_max), value_min);
}
- T value_min_;
+ T value_min;
};
template <typename T>
struct BinaryRightClipOp<CPUDevice, T> {
@@ -130,11 +131,11 @@ struct BinaryRightClipOp<CPUDevice, T> {
// Binary functor for clip [Tensor, Tensor, Scalar]
template <typename T>
struct BinaryLeftClipFunc {
- BinaryLeftClipFunc(const T& value_max) : value_max_(value_max) {}
+ explicit BinaryLeftClipFunc(const T& value_max) : value_max(value_max) {}
const T operator()(const T& value, const T& value_min) const {
- return std::max(std::min(value, value_max_), value_min);
+ return std::max(std::min(value, value_max), value_min);
}
- T value_max_;
+ T value_max;
};
template <typename T>
struct BinaryLeftClipOp<CPUDevice, T> {
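
All three functor specializations above evaluate the same per-element expression, std::max(std::min(value, value_max), value_min); they differ only in which bound is captured at construction (hence the rename away from trailing underscores, which are reserved for class members) and which bound arrives per element. A standalone sketch of the shared clamp semantics (the Clip helper here is illustrative, not the TensorFlow functor):

#include <algorithm>
#include <iostream>
#include <vector>

// Clamp each element of `in` to [lo, hi] -- the expression the
// Unary/BinaryLeft/BinaryRight clip functors all reduce to.
template <typename T>
std::vector<T> Clip(const std::vector<T>& in, T lo, T hi) {
  std::vector<T> out(in.size());
  std::transform(in.begin(), in.end(), out.begin(), [lo, hi](T v) {
    return std::max(std::min(v, hi), lo);
  });
  return out;
}

int main() {
  for (float v : Clip<float>({-2.f, 0.5f, 3.f}, -1.f, 1.f))
    std::cout << v << " ";  // prints: -1 0.5 1
}
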
diff --git a/tensorflow/core/kernels/cwise_op_clip.h b/tensorflow/core/kernels/cwise_op_clip.h
index 1a4bf8cf1d..171b6932c2 100644
--- a/tensorflow/core/kernels/cwise_op_clip.h
+++ b/tensorflow/core/kernels/cwise_op_clip.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CWISE_OP_CLIP_H_
-#define TENSORFLOW_KERNELS_CWISE_OP_CLIP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OP_CLIP_H_
+#define TENSORFLOW_CORE_KERNELS_CWISE_OP_CLIP_H_
#include "tensorflow/core/kernels/cwise_ops_common.h"
@@ -55,7 +55,7 @@ struct TernaryClipOp {
typename TTypes<T>::ConstFlat &in2_flat,
typename TTypes<T>::Flat &out_flat) const;
};
-}
+} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CWISE_OP_CLIP_H_
+#endif // TENSORFLOW_CORE_KERNELS_CWISE_OP_CLIP_H_
diff --git a/tensorflow/core/kernels/cwise_op_clip_gpu.cu.cc b/tensorflow/core/kernels/cwise_op_clip_gpu.cu.cc
index 5c07847548..44dea7dee9 100644
--- a/tensorflow/core/kernels/cwise_op_clip_gpu.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_clip_gpu.cu.cc
@@ -62,10 +62,10 @@ struct UnaryClipOp<GPUDevice, T> {
typename TTypes<T>::Flat &out_flat) const {
CudaLaunchConfig config = GetCudaLaunchConfig(in0_flat.size(), d);
- UnaryClipCustomKernel<
- T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
- in0_flat.size(), in0_flat.data(), in1_flat.data(), in2_flat.data(),
- out_flat.data());
+ UnaryClipCustomKernel<T>
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ in0_flat.size(), in0_flat.data(), in1_flat.data(), in2_flat.data(),
+ out_flat.data());
}
};
@@ -78,10 +78,10 @@ struct BinaryRightClipOp<GPUDevice, T> {
typename TTypes<T>::Flat &out_flat) const {
CudaLaunchConfig config = GetCudaLaunchConfig(in0_flat.size(), d);
- BinaryRightClipCustomKernel<
- T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
- in0_flat.size(), in0_flat.data(), in1_flat.data(), in2_flat.data(),
- out_flat.data());
+ BinaryRightClipCustomKernel<T>
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ in0_flat.size(), in0_flat.data(), in1_flat.data(), in2_flat.data(),
+ out_flat.data());
}
};
@@ -94,10 +94,10 @@ struct BinaryLeftClipOp<GPUDevice, T> {
typename TTypes<T>::Flat &out_flat) const {
CudaLaunchConfig config = GetCudaLaunchConfig(in0_flat.size(), d);
- BinaryLeftClipCustomKernel<
- T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
- in0_flat.size(), in0_flat.data(), in1_flat.data(), in2_flat.data(),
- out_flat.data());
+ BinaryLeftClipCustomKernel<T>
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ in0_flat.size(), in0_flat.data(), in1_flat.data(), in2_flat.data(),
+ out_flat.data());
}
};
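
Each GPU specialization above asks GetCudaLaunchConfig for a block count and then launches one thread per element. At its core that configuration is a ceiling division of the element count by the block size; a sketch of just that arithmetic in plain C++ (the real helper also consults device properties and occupancy, which is omitted here):

#include <cstdint>
#include <iostream>

struct LaunchConfig {
  int64_t block_count;
  int64_t thread_per_block;
};

// Ceiling-divide the element count by the block size so every element
// gets exactly one thread, as the <<<...>>> launches above assume.
LaunchConfig MakeLaunchConfig(int64_t element_count,
                              int64_t thread_per_block = 256) {
  return {(element_count + thread_per_block - 1) / thread_per_block,
          thread_per_block};
}

int main() {
  LaunchConfig cfg = MakeLaunchConfig(1000);
  std::cout << cfg.block_count << " blocks of " << cfg.thread_per_block
            << " threads\n";  // 4 blocks of 256 threads
}
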
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 8c4f0218ee..e856ede44b 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -281,6 +281,7 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
],
)
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 3f88d6dee8..fa33867ec1 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
@@ -35,7 +36,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
graph_def_version_(ctx->graph_def_version()) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
@@ -80,24 +81,28 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
errors::InvalidArgument("`prefetch_input_elements` must be >= 0"));
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(
+ ctx, CapturedFunction::Create(
+ interleave_func_, std::move(other_arguments), &captured_func));
*output =
- new Dataset(input, std::move(captured_func), cycle_length, block_length,
- sloppy, buffer_output_elements, prefetch_input_elements,
- output_types_, output_shapes_);
+ new Dataset(ctx, input, interleave_func_, std::move(captured_func),
+ cycle_length, block_length, sloppy, buffer_output_elements,
+ prefetch_input_elements, output_types_, output_shapes_);
}
private:
- class Dataset : public DatasetBase {
+ class Dataset : public GraphDatasetBase {
public:
- Dataset(const DatasetBase* input,
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
int64 block_length, bool sloppy, int64 buffer_output_elements,
int64 prefetch_input_elements, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : input_(input),
+ : GraphDatasetBase(ctx),
+ input_(input),
+ interleave_func_(func),
captured_func_(std::move(captured_func)),
cycle_length_(cycle_length),
block_length_(block_length),
@@ -128,6 +133,52 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
return "ParallelInterleaveDatasetOp::Dataset";
}
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name()));
+ Node* input_node;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
+ Node* cycle_length_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
+ Node* block_length_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
+ Node* sloppy_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(sloppy_, &sloppy_node));
+ Node* buffer_output_elements_node;
+ TF_RETURN_IF_ERROR(
+ b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
+ Node* prefetch_input_elements_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(prefetch_input_elements_,
+ &prefetch_input_elements_node));
+ DataTypeVector other_arguments_types;
+ other_arguments_types.reserve(captured_func_->captured_inputs().size());
+ std::vector<Node*> other_arguments;
+ other_arguments.reserve(captured_func_->captured_inputs().size());
+ for (const Tensor& t : captured_func_->captured_inputs()) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ other_arguments.emplace_back(node);
+ other_arguments_types.emplace_back(t.dtype());
+ }
+ AttrValue f;
+ b->BuildAttrValue(interleave_func_, &f);
+ AttrValue other_arguments_types_attr;
+ b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
+
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this,
+ {{0, input_node},
+ {2, cycle_length_node},
+ {3, block_length_node},
+ {4, sloppy_node},
+ {5, buffer_output_elements_node},
+ {6, prefetch_input_elements_node}},
+ {{1, other_arguments}},
+ {{"f", f}, {"Targuments", other_arguments_types_attr}}, output));
+ return Status::OK();
+ }
+
private:
int64 num_threads() const {
return cycle_length_ + prefetch_input_elements_;
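
In AsGraphDefInternal above, ordinary inputs are recorded as (index, node) pairs and the variable-length `other_arguments` as an (index, node-list) pair, so the builder can splice the list into the right position when reconstructing the op's input order. A small sketch of that positional-merge bookkeeping (standalone C++; MergeInputs is a hypothetical helper modelling the index scheme, not the AddDataset API):

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Merge single inputs and list inputs into one ordered input vector, the
// way {{0, input}, {2, cycle_length}, ...} is ordered around the list
// input at index 1 above.
std::vector<std::string> MergeInputs(
    const std::map<int, std::string>& singles,
    const std::map<int, std::vector<std::string>>& lists) {
  std::vector<std::string> ordered;
  const int total = static_cast<int>(singles.size() + lists.size());
  for (int i = 0; i < total; ++i) {
    auto it = singles.find(i);
    if (it != singles.end()) {
      ordered.push_back(it->second);
    } else {
      for (const auto& n : lists.at(i)) ordered.push_back(n);
    }
  }
  return ordered;
}

int main() {
  auto ordered = MergeInputs(
      {{0, "input_dataset"}, {2, "cycle_length"}, {3, "block_length"}},
      {{1, {"arg0", "arg1"}}});
  // prints: input_dataset arg0 arg1 cycle_length block_length
  for (const auto& n : ordered) std::cout << n << " ";
}
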
@@ -156,17 +207,17 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// that a caller will block waiting for an element to be produced.
//
// Pointers to these worker states are kept in 2 disjoint data structures:
- // 1. `interleave_` is a vector containing pointers to `WorkerState`s that
- // we
- // are interleaving. Worker threads backing these WorkerStates should
- // be regularly producing values.
- // 2. `staging_` is a deque containing pointers to WorkerStates that we
- // will move to `interleave_` when an iterator in `interleave_` is
- // exhausted.
+ // 1. `interleave_indices_` is a vector containing indices of WorkerStates
+ // in `workers_` that we are interleaving. Worker threads backing these
+ // WorkerStates should be regularly producing values.
+ // 2. `staging_indices_` is a deque containing indices of WorkerStates in
+ // `workers_` that we will move to `interleave_indices_` when an
+ // iterator in `interleave_indices_` is exhausted.
//
// The client calls `GetNext[Internal]()` to retrieve an output element. The
- // internal implementation updates the state of `interleave_` and `staging_`
- // as output iterators (run by the worker threads) are exhausted.
+ // internal implementation updates the state of `interleave_indices_` and
+ // `staging_indices_` as output iterators (run by the worker threads) are
+ // exhausted.
//
// `input_impl_` is the input iterator that generates arguments for the
// flat-map function (`captured_func_`). It is set to an iterator at
@@ -175,18 +226,19 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// memory.
//
// A few invariants are maintained:
- // 1. No element in interleave_ should be a nullptr unless `staging_` is
- // empty and `input_impl_` is empty.
+ // 1. No element in interleave_indices_ should be a -1 unless
+ // `staging_indices_` is empty and `input_impl_` is empty.
// 2. Every `worker_` element is pointed to by at most one element of the
- // union of `interleave_` and `staging_`.
+ // union of `interleave_indices_` and `staging_indices_`.
// 3. Unless `input_impl_` is empty, every `worker_` must be pointed to by
- // an element in `interleave_` or `staging_`.
+ // an element in `interleave_indices_` or `staging_indices_`.
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
- workers_(dataset()->num_threads()) {}
+ workers_(dataset()->num_threads()),
+ worker_thread_states_(dataset()->num_threads()) {}
~Iterator() override {
mutex_lock l(mu_);
@@ -211,10 +263,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// not have an item readily available.
bool can_produce_elements = false;
bool must_wait_for_input = true;
- for (int64 i = 0; i < interleave_.size(); ++i) {
- int64 index = (next_index_ + i) % interleave_.size();
- WorkerState* current_worker = interleave_[index];
- if (!current_worker) continue; // Empty interleave elements.
+ for (int64 i = 0; i < interleave_indices_.size(); ++i) {
+ int64 index = (next_index_ + i) % interleave_indices_.size();
+ int64 current_worker_index = interleave_indices_[index];
+ if (current_worker_index < 0) {
+ continue; // Empty interleave elements.
+ }
+ WorkerState* current_worker = &workers_[current_worker_index];
can_produce_elements |= current_worker->MayHaveElements();
if (!current_worker->outputs.empty()) {
// We have an element!
@@ -222,7 +277,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (i == 0) {
block_count_++;
if (block_count_ == dataset()->block_length_) {
- next_index_ = (index + 1) % interleave_.size();
+ next_index_ = (index + 1) % interleave_indices_.size();
block_count_ = 0;
}
} else {
@@ -245,7 +300,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
break;
} else if (!current_worker->is_producing) {
// This iterator has reached end of input.
- interleave_[index] = nullptr;
+ interleave_indices_[index] = -1;
if (input_impl_) {
// Start prefetching a new iterator.
std::vector<Tensor> args;
@@ -255,16 +310,17 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
input_impl_.reset();
} else {
current_worker->SetInputs(s, std::move(args));
- staging_.emplace_back(current_worker);
+ staging_indices_.emplace_back(current_worker_index);
}
}
- if (!staging_.empty()) {
- // Move a worker from `staging_` to `interleave_`.
- interleave_[index] = staging_.front();
- staging_.pop_front();
+ if (!staging_indices_.empty()) {
+ // Move a worker from `staging_indices_` to
+ // `interleave_indices_`.
+ interleave_indices_[index] = staging_indices_.front();
+            staging_indices_.pop_front();

-          next_index_ = (index + 1) % interleave_.size();
+ next_index_ = (index + 1) % interleave_indices_.size();
block_count_ = 0;
// Restart the inner [for] loop
can_produce_elements = true;
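
The GetNextInternal changes above keep the scheduling policy intact while switching from pointers to indices: scan the interleave slots round-robin starting at next_index_, emit up to block_length_ consecutive elements from one worker before advancing, skip empty (-1) slots, and promote a staged worker when one is exhausted. A standalone simulation of just the round-robin/block-length policy (queues stand in for workers; exhausted-worker replacement and the -1 sentinel are omitted):

#include <deque>
#include <iostream>
#include <vector>

int main() {
  // Each deque stands in for one worker's prefetched output buffer.
  std::vector<std::deque<int>> workers = {{1, 2, 3}, {10, 20}, {100}};
  std::vector<int> interleave = {0, 1, 2};  // indices into `workers`
  const int block_length = 2;
  size_t next_index = 0;
  int block_count = 0;
  size_t remaining = 3 + 2 + 1;
  while (remaining > 0) {
    auto& q = workers[interleave[next_index]];
    if (!q.empty()) {
      std::cout << q.front() << " ";
      q.pop_front();
      --remaining;
      // After block_length consecutive elements, move to the next slot.
      if (++block_count == block_length) {
        next_index = (next_index + 1) % interleave.size();
        block_count = 0;
      }
    } else {
      // Empty slot: advance, as the kernel does for exhausted workers.
      next_index = (next_index + 1) % interleave.size();
      block_count = 0;
    }
  }
  // prints: 1 2 10 20 100 3
}
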
@@ -285,7 +341,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (dataset()->sloppy_) {
sloppy_cond_var_.wait(l);
} else {
- interleave_[next_index_]->cond_var.wait(l);
+ workers_[interleave_indices_[next_index_]].cond_var.wait(l);
}
}
}
@@ -293,6 +349,137 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
"ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
}
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ // The order of locking is important here to avoid deadlock.
+ mutex_lock l(mu_);
+ mutex_lock ckpt_l(ckpt_mu_);
+ if (input_impl_) {
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ } else {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("input_exhausted"), ""));
+ }
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("next_index"), next_index_));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("block_count"), block_count_));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("workers_size"), workers_.size()));
+ for (int i = 0; i < workers_.size(); ++i) {
+ TF_RETURN_IF_ERROR(WriteWorkerStateLocked(writer, i));
+ }
+ for (int i = 0; i < worker_thread_states_.size(); ++i) {
+ TF_RETURN_IF_ERROR(WriteWorkerThreadStateLocked(writer, i));
+ }
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("interleave_size"),
+ interleave_indices_.size()));
+ for (int i = 0; i < interleave_indices_.size(); ++i) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("interleave_indices_", i)),
+ interleave_indices_[i]));
+ }
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("staging_size"),
+ staging_indices_.size()));
+ for (int i = 0; i < staging_indices_.size(); ++i) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("staging_indices_", i)),
+ staging_indices_[i]));
+ }
+ if (!worker_threads_.empty()) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("worker_threads_running"), ""));
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ // The order of locking is important here to avoid deadlock.
+ mutex_lock l(mu_);
+ mutex_lock ckpt_l(ckpt_mu_);
+ if (!reader->Contains(full_name("input_exhausted"))) {
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ } else {
+ input_impl_.reset();
+ }
+ int64 temp;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("next_index"), &temp));
+ next_index_ = size_t(temp);
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("block_count"), &temp));
+ block_count_ = size_t(temp);
+
+ // Restore WorkerStates.
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("workers_size"), &temp));
+ if (temp != dataset()->num_threads()) {
+ return errors::Internal("Expected ", dataset()->num_threads(),
+ " worker states but found ", temp, ".");
+ }
+ for (size_t i = 0; i < dataset()->num_threads(); ++i) {
+ TF_RETURN_IF_ERROR(ReadWorkerStateLocked(reader, i, ctx));
+ }
+ for (size_t i = 0; i < dataset()->num_threads(); ++i) {
+ TF_RETURN_IF_ERROR(ReadWorkerThreadStateLocked(reader, i, ctx));
+ }
+
+ // Restore `interleave_indices_`.
+ std::set<int64> all_indices;
+ {
+ int64 interleave_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("interleave_size"),
+ &interleave_size));
+ interleave_indices_.reserve(interleave_size);
+ for (int64 i = 0; i < interleave_size; ++i) {
+ int64 temp;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("interleave_indices_", i)), &temp));
+ if (temp >= 0 && all_indices.find(temp) != all_indices.end()) {
+ return errors::Internal(
+ "Duplicate entry for ", temp,
+ " found when reading interleave and staging indices.");
+ }
+ if (temp >= 0) {
+ all_indices.insert(temp);
+ }
+ interleave_indices_.emplace_back(temp);
+ }
+ }
+
+ // Restore `staging_indices_`.
+ {
+ int64 staging_size;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("staging_size"), &staging_size));
+ for (int i = 0; i < staging_size; ++i) {
+ int64 temp;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("staging_indices_", i)), &temp));
+ if (all_indices.find(temp) != all_indices.end()) {
+ return errors::Internal(
+ "Duplicate entry for ", temp,
+ " found when reading interleave and staging indices.");
+ }
+ if (temp >= 0) {
+ all_indices.insert(temp);
+ }
+ staging_indices_.emplace_back(temp);
+ }
+ }
+
+ // Start Worker threads.
+ if (reader->Contains(full_name("worker_threads_running"))) {
+ worker_threads_.reserve(dataset()->num_threads());
+ for (size_t i = 0; i < dataset()->num_threads(); ++i) {
+ worker_threads_.emplace_back(ctx->env()->StartThread(
+ {}, "worker_thread",
+ std::bind(&Iterator::WorkerThread, this,
+ new IteratorContext(*ctx), i)));
+ }
+ }
+ return Status::OK();
+ }
+
private:
// OutputElem contains the information from a call to GetNext by an output
// iterator.
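
SaveInternal above flattens the iterator state into flat string keys: a `<name>_size` scalar followed by `<name>_<i>` entries, which RestoreInternal reads back in the same order. A toy round trip of that naming convention (a std::map stands in for IteratorStateWriter/Reader; the real interfaces also carry tensors and full_name() prefixes):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

using State = std::map<std::string, int64_t>;

// Write a vector as "<name>_size" plus "<name>_<i>" scalars, the scheme
// used for interleave_indices_ and staging_indices_ above.
void WriteVector(State* s, const std::string& name,
                 const std::vector<int64_t>& v) {
  (*s)[name + "_size"] = static_cast<int64_t>(v.size());
  for (size_t i = 0; i < v.size(); ++i)
    (*s)[name + "_" + std::to_string(i)] = v[i];
}

std::vector<int64_t> ReadVector(const State& s, const std::string& name) {
  std::vector<int64_t> v(s.at(name + "_size"));
  for (size_t i = 0; i < v.size(); ++i)
    v[i] = s.at(name + "_" + std::to_string(i));
  return v;
}

int main() {
  State ckpt;
  WriteVector(&ckpt, "interleave_indices", {0, 1, -1, 2});
  for (int64_t idx : ReadVector(ckpt, "interleave_indices"))
    std::cout << idx << " ";  // prints: 0 1 -1 2
}
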
@@ -345,6 +532,31 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
};
+ // The internal state of a worker thread that is not already captured
+ // in its `WorkerState`.
+ //
+ // This is needed only for checkpointing purposes. We keep this
+ // separate from `WorkerState` and guard its fields using a separate
+    // lock `ckpt_mu_` so as not to affect the main pipeline's performance.
+ struct WorkerThreadState {
+ // The output element that has been produced from the input iterator
+ // and is waiting to be added to `WorkerState.outputs`.
+ OutputElem output_elem;
+
+ // Whether the input iterator returned an `end_of_sequence`.
+ bool end_of_sequence = false;
+
+ // Status returned from `MakeIteratorFromInputElement`.
+ Status iterator_creation_status;
+
+ // The arguments to be used to construct `iterator`.
+ std::vector<Tensor> input;
+
+ std::unique_ptr<IteratorBase> iterator;
+
+ WorkerThreadState() : output_elem(Status::OK()) {}
+ };
+
Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (worker_threads_.empty()) {
@@ -363,19 +575,38 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
std::bind(&Iterator::WorkerThread, this,
new IteratorContext(*ctx), i)));
if (i < dataset()->cycle_length_) {
- interleave_.push_back(&workers_[i]);
+ interleave_indices_.push_back(i);
} else {
- staging_.push_back(&workers_[i]);
+ staging_indices_.push_back(i);
}
}
- DCHECK(interleave_.size() == dataset()->cycle_length_);
- DCHECK(staging_.size() == dataset()->prefetch_input_elements_);
+ DCHECK(interleave_indices_.size() == dataset()->cycle_length_);
+ DCHECK(staging_indices_.size() ==
+ dataset()->prefetch_input_elements_);
}
return Status::OK();
}
// Produces elements into the worker's output buffers.
void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index) {
+      // Notes on checkpointing thread-local state, i.e., `WorkerThreadState`:
+ //
+ // 1. Any local state that may need to be checkpointed should be kept
+ // in `worker_thread_states_[thread_index]`.
+ // 2. `WorkerThreadState` should contain state that is needed only for
+ // checkpointing, i.e., if we were to remove checkpointing support,
+ // we could keep that state as local variables in this thread.
+ // 3. This thread should only read/write state at `thread_index`
+ // and should not access other thread states.
+ // 4. When restoring from checkpoint, threads are started only after
+ // the restore is complete.
+      // 5. Once restored from a checkpoint, the local state is edited only
+      //    by this thread. Notes 3 & 4 allow assumptions such as temporarily
+      //    caching local state in this thread and using it outside a lock,
+      //    e.g. in `make_new_iterator`.
+      // 6. `ckpt_mu_` should be used judiciously to create *consistent*
+      //    checkpoint markers.
+
      // std::function arguments are copy-constructible, so we pass raw
// pointers, and then immediately wrap them to ensure correct ownership.
std::unique_ptr<IteratorContext> ctx(ctx_ptr);
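
The rules above describe a reader/writer discipline: worker threads take `ckpt_mu_` in shared mode around each state transition, while Save/Restore take it exclusively, so a checkpoint can only ever observe one of the numbered CHECKPOINT_MARKER states, never a mid-transition one. The same pattern with standard-library types (a sketch; TensorFlow's mutex and tf_shared_lock are its own wrappers, std::shared_mutex is used here as a stand-in):

#include <iostream>
#include <mutex>
#include <shared_mutex>
#include <thread>
#include <vector>

std::shared_mutex ckpt_mu;            // shared: worker steps, exclusive: Save
std::vector<int> worker_state(4, 0);  // one slot per worker, like workers_

void WorkerStep(int i) {
  // Shared lock: workers only exclude the checkpointer, not each other;
  // each worker mutates nothing but its own slot (rule 3 above).
  std::shared_lock<std::shared_mutex> l(ckpt_mu);
  ++worker_state[i];
}

void Checkpoint() {
  // Exclusive lock: blocks until no worker is mid-transition, then
  // observes a consistent snapshot of every slot.
  std::unique_lock<std::shared_mutex> l(ckpt_mu);
  int total = 0;
  for (int s : worker_state) total += s;
  std::cout << "checkpointed total = " << total << "\n";  // 4
}

int main() {
  std::vector<std::thread> threads;
  for (int i = 0; i < 4; ++i) threads.emplace_back(WorkerStep, i);
  for (auto& t : threads) t.join();
  Checkpoint();
}
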
@@ -383,38 +614,135 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
mutex_lock l(mu_);
workers_[thread_index].cond_var.notify_all();
});
-
+ bool make_new_iterator;
+ {
+ tf_shared_lock l(ckpt_mu_);
+ // Decide whether a new iterator should be built.
+ // 1. If there is an existing iterator, we use it.
+      // 2. If there was an error in iterator creation that has not yet
+      //    been reported to the client, we attempt to send it to the
+      //    client first.
+ make_new_iterator =
+ worker_thread_states_[thread_index].iterator == nullptr &&
+ worker_thread_states_[thread_index].iterator_creation_status.ok();
+ }
+ // Even though `make_new_iterator` has cached values from
+      // `worker_thread_states_[thread_index]`, which is guarded by ckpt_mu_,
+      // it is safe to *read* `make_new_iterator` outside of a lock without
+ // worrying about concurrent changes to values in
+ // `worker_thread_states_[thread_index]`. See comment at the start of
+ // this function for details.
while (true) {
- // 1. Wait for input.
- std::vector<Tensor> input;
- {
- mutex_lock l(mu_);
- while (!cancelled_ && !workers_[thread_index].is_producing) {
- workers_[thread_index].cond_var.wait(l);
+ // Whether creation of the iterator succeeded.
+ Status iterator_creation_status;
+ // 1. Build a new iterator or use the existing one.
+ if (make_new_iterator) {
+          // 1a. Get new input tensors or use the existing ones.
+
+ bool read_new_input;
+
+ {
+ tf_shared_lock l(ckpt_mu_);
+ // worker_thread_states_[thread_index].input will be non-empty
+ // if checkpointing happened at CHECKPOINT_MARKER_A.
+ read_new_input =
+ worker_thread_states_[thread_index].input.empty();
}
- if (cancelled_) return;
- input.swap(workers_[thread_index].input);
- }
- // 2. Run the user defined function to produce a new iterator.
- std::unique_ptr<IteratorBase> iterator;
- Status s = dataset::MakeIteratorFromInputElement(
- ctx.get(), input, thread_index, dataset()->captured_func_.get(),
- prefix(), &iterator);
- input.clear(); // Release memory as early as possible.
+ if (read_new_input) {
+ mutex_lock l(mu_);
+ while (!cancelled_ && !workers_[thread_index].is_producing) {
+ workers_[thread_index].cond_var.wait(l);
+ }
+ if (cancelled_) return;
+ // Copy the input tensors so that we do not need to block on `mu_`
+ // when building the iterator.
+ // We keep a copy of the input tensors in
+          // `WorkerThreadState.input` while the iterator is in use. This is
+ // used in `RestoreInternal` to re-build the iterator.
+ // TODO(b/78046638): Explore ways to avoid tracking the input
+ // tensors.
+ tf_shared_lock ckpt_l(ckpt_mu_);
+ worker_thread_states_[thread_index].input.swap(
+ workers_[thread_index].input);
+ // CHECKPOINT_MARKER_A
+ // We have the input tensors but have not built the iterator yet.
+          }

-        if (!s.ok()) {
+ // 1b. Run the user defined function to produce a new iterator.
+ {
+ tf_shared_lock l(ckpt_mu_);
+ worker_thread_states_[thread_index].iterator_creation_status =
+ dataset::MakeIteratorFromInputElement(
+ ctx.get(), worker_thread_states_[thread_index].input,
+ thread_index, dataset()->captured_func_.get(), prefix(),
+ &worker_thread_states_[thread_index].iterator);
+ iterator_creation_status =
+ worker_thread_states_[thread_index].iterator_creation_status;
+ if (!iterator_creation_status.ok()) {
+ worker_thread_states_[thread_index].input.clear();
+ }
+ // CHECKPOINT_MARKER_B
+ // Either an iterator has been successfully built and placed in
+ // `worker_thread_states_[thread_index].iterator` or it failed and
+ // a non-OK status has been put in
+ // `worker_thread_states_[thread_index].iterator_creation_status`.
+ }
+ } else {
+ tf_shared_lock l(ckpt_mu_);
+ iterator_creation_status =
+ worker_thread_states_[thread_index].iterator_creation_status;
+ // Mark that we have used up the restored iterator.
+ make_new_iterator = true;
+ }
+ // 2. Start producing elements or send error state to client if
+ // iterator creation failed.
+ if (!iterator_creation_status.ok()) {
mutex_lock l(mu_);
- workers_[thread_index].outputs.emplace_back(s);
+ // Wait for space in the prefetch queue.
+ while (!cancelled_ && workers_[thread_index].outputs.size() ==
+ dataset()->buffer_output_elements_) {
+ workers_[thread_index].cond_var.wait(l);
+ }
+ if (cancelled_) return;
+ tf_shared_lock ckpt_l(ckpt_mu_);
+ workers_[thread_index].outputs.emplace_back(
+ iterator_creation_status);
workers_[thread_index].is_producing = false;
+ worker_thread_states_[thread_index].iterator_creation_status =
+ Status::OK();
+ // CHECKPOINT_MARKER_C
+ // Non-OK iterator creation status has been notified to the
+ // client.
workers_[thread_index].cond_var.notify_one();
} else {
- // 3. Produce elements
bool end_of_sequence = false;
while (!end_of_sequence) {
// 3.a Produce an element!
- std::vector<Tensor> output_elem;
- s = iterator->GetNext(ctx.get(), &output_elem, &end_of_sequence);
+ {
+ tf_shared_lock ckpt_l(ckpt_mu_);
+ if (worker_thread_states_[thread_index]
+ .output_elem.status.ok() &&
+ worker_thread_states_[thread_index]
+ .output_elem.output.empty() &&
+ !worker_thread_states_[thread_index].end_of_sequence) {
+ worker_thread_states_[thread_index].output_elem.status =
+ worker_thread_states_[thread_index].iterator->GetNext(
+ ctx.get(),
+ &worker_thread_states_[thread_index]
+ .output_elem.output,
+ &worker_thread_states_[thread_index].end_of_sequence);
+ end_of_sequence =
+ worker_thread_states_[thread_index].end_of_sequence;
+ } else {
+ end_of_sequence =
+ worker_thread_states_[thread_index].end_of_sequence;
+ }
+ // CHECKPOINT_MARKER_D
+ // An element has been read or an error or end_of_sequence has
+ // been received from the input iterator and is waiting to be
+ // sent to client.
+ }
// 3.b Make it available to the client.
{
@@ -427,30 +755,255 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
if (cancelled_) return;
- // Output the element.
+ tf_shared_lock ckpt_l(ckpt_mu_);
workers_[thread_index].is_producing = !end_of_sequence;
- if (!end_of_sequence) {
- workers_[thread_index].outputs.emplace_back(s);
+
+ // Output the element.
+
+ // Move the temporary state in WorkerThreadState to WorkerState
+ // and mark it as used.
+ if (end_of_sequence) {
+ worker_thread_states_[thread_index].iterator.reset();
+ worker_thread_states_[thread_index].input.clear();
+ worker_thread_states_[thread_index].end_of_sequence = false;
+ } else {
+ workers_[thread_index].outputs.emplace_back(
+ worker_thread_states_[thread_index].output_elem.status);
workers_[thread_index].outputs.back().output.swap(
- output_elem);
+ worker_thread_states_[thread_index].output_elem.output);
}
+ worker_thread_states_[thread_index].output_elem.status =
+ Status::OK();
if (dataset()->sloppy_) {
sloppy_cond_var_.notify_one();
} else {
workers_[thread_index].cond_var.notify_one();
}
+ // CHECKPOINT_MARKER_E
+ // Output element or iterator status has been sent to the
+ // client.
}
}
}
}
}
+ Status WriteWorkerStateLocked(IteratorStateWriter* writer, int index)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
+ string prefix = strings::StrCat("worker_", index);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat(prefix, "_input_size")),
+ workers_[index].input.size()));
+ for (int i = 0; i < workers_[index].input.size(); ++i) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat(prefix, "_input_", i)),
+ workers_[index].input[i]));
+ }
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat(prefix, "_outputs_size")),
+ workers_[index].outputs.size()));
+ for (int i = 0; i < workers_[index].outputs.size(); ++i) {
+ TF_RETURN_IF_ERROR(WriteOutputElemLocked(
+ writer, workers_[index].outputs[i],
+ full_name(strings::StrCat(prefix, "_outputs_", i))));
+ }
+ if (workers_[index].is_producing) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat(prefix, "_is_producing")), ""));
+ }
+ return Status::OK();
+ }
+
+ Status ReadWorkerStateLocked(IteratorStateReader* reader, int index,
+ IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
+ string worker_prefix = strings::StrCat("worker_", index);
+ // Restore inputs.
+ int64 input_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat(worker_prefix, "_input_size")),
+ &input_size));
+ workers_[index].input.reserve(input_size);
+ for (int i = 0; i < input_size; ++i) {
+ workers_[index].input.emplace_back();
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat(worker_prefix, "_input_", i)),
+ &workers_[index].input.back()));
+ }
+ int64 outputs_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat(worker_prefix, "_outputs_size")),
+ &outputs_size));
+ for (int i = 0; i < outputs_size; ++i) {
+ workers_[index].outputs.emplace_back(Status::OK());
+ TF_RETURN_IF_ERROR(ReadOutputElemLocked(
+ reader, &workers_[index].outputs.back(),
+ full_name(strings::StrCat(worker_prefix, "_outputs_", i))));
+ }
+ if (reader->Contains(
+ full_name(strings::StrCat(worker_prefix, "_is_producing")))) {
+ workers_[index].is_producing = true;
+ } else {
+ workers_[index].is_producing = false;
+ }
+ return Status::OK();
+ }
+
+ Status WriteWorkerThreadStateLocked(IteratorStateWriter* writer,
+ int index)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
+ string prefix = strings::StrCat("worker_thread_", index);
+ if (worker_thread_states_[index].iterator != nullptr) {
+ TF_RETURN_IF_ERROR(
+ SaveParent(writer, worker_thread_states_[index].iterator));
+ } else {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat(prefix, "_iterator_exhausted")), ""));
+ }
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat(prefix, "_input_size")),
+ worker_thread_states_[index].input.size()));
+ for (int i = 0; i < worker_thread_states_[index].input.size(); ++i) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat(prefix, "_input_", i)),
+ worker_thread_states_[index].input[i]));
+ }
+ TF_RETURN_IF_ERROR(WriteStatusLocked(
+ writer, strings::StrCat(prefix, "_iterator_creation_status"),
+ worker_thread_states_[index].iterator_creation_status));
+ TF_RETURN_IF_ERROR(WriteOutputElemLocked(
+ writer, worker_thread_states_[index].output_elem,
+ full_name(strings::StrCat(prefix, "_output"))));
+ if (worker_thread_states_[index].end_of_sequence) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat(prefix, "_end_of_sequence")), ""));
+ }
+ return Status::OK();
+ }
+
+ Status ReadWorkerThreadStateLocked(IteratorStateReader* reader, int index,
+ IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
+ string worker_prefix = strings::StrCat("worker_thread_", index);
+ // Restore inputs.
+ int64 input_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat(worker_prefix, "_input_size")),
+ &input_size));
+ worker_thread_states_[index].input.reserve(input_size);
+ for (int i = 0; i < input_size; ++i) {
+ worker_thread_states_[index].input.emplace_back();
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat(worker_prefix, "_input_", i)),
+ &worker_thread_states_[index].input.back()));
+ }
+ // Restore iterator.
+ if (reader->Contains(full_name(
+ strings::StrCat(worker_prefix, "_iterator_exhausted")))) {
+ worker_thread_states_[index].iterator.reset();
+ } else {
+ std::unique_ptr<IteratorBase> iterator;
+ Status s = dataset::MakeIteratorFromInputElement(
+ ctx, worker_thread_states_[index].input, index,
+ dataset()->captured_func_.get(), prefix(), &iterator);
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, iterator));
+ worker_thread_states_[index].iterator.swap(iterator);
+ }
+ TF_RETURN_IF_ERROR(ReadStatusLocked(
+ reader, strings::StrCat(worker_prefix, "_iterator_creation_status"),
+ &worker_thread_states_[index].iterator_creation_status));
+ TF_RETURN_IF_ERROR(ReadOutputElemLocked(
+ reader, &worker_thread_states_[index].output_elem,
+ full_name(strings::StrCat(worker_prefix, "_output"))));
+ if (reader->Contains(full_name(
+ strings::StrCat(worker_prefix, "_end_of_sequence")))) {
+ worker_thread_states_[index].end_of_sequence = true;
+ } else {
+ worker_thread_states_[index].end_of_sequence = false;
+ }
+ return Status::OK();
+ }
+
+ Status WriteOutputElemLocked(IteratorStateWriter* writer,
+ const OutputElem& output_elem,
+ const string& prefix)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
+ TF_RETURN_IF_ERROR(WriteStatusLocked(
+ writer, strings::StrCat(prefix, "_status"), output_elem.status));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(strings::StrCat(prefix, "_output_size"),
+ output_elem.output.size()));
+ for (int i = 0; i < output_elem.output.size(); ++i) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ strings::StrCat(prefix, "_output_", i), output_elem.output[i]));
+ }
+ return Status::OK();
+ }
+
+ Status ReadOutputElemLocked(IteratorStateReader* reader,
+ OutputElem* output_elem, const string& prefix)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
+ TF_RETURN_IF_ERROR(ReadStatusLocked(
+ reader, strings::StrCat(prefix, "_status"), &output_elem->status));
+ int64 output_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ strings::StrCat(prefix, "_output_size"), &output_size));
+ output_elem->output.reserve(output_size);
+ for (int i = 0; i < output_size; ++i) {
+ output_elem->output.emplace_back();
+ TF_RETURN_IF_ERROR(
+ reader->ReadTensor(strings::StrCat(prefix, "_output_", i),
+ &output_elem->output.back()));
+ }
+ return Status::OK();
+ }
+
+ Status WriteStatusLocked(IteratorStateWriter* writer,
+ const string& prefix, const Status& status)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")),
+ static_cast<int64>(status.code())));
+ if (!status.ok()) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name(strings::StrCat(prefix, "_msg")),
+ status.error_message()));
+ }
+ return Status::OK();
+ }
+
+ Status ReadStatusLocked(IteratorStateReader* reader, const string& prefix,
+ Status* status)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
+ int64 code_int;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat(prefix, "_code")), &code_int));
+ error::Code code = static_cast<error::Code>(code_int);
+
+ if (code != error::Code::OK) {
+ string error_message;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat(prefix, "_msg")), &error_message));
+ *status = Status(code, error_message);
+ } else {
+ *status = Status::OK();
+ }
+ return Status::OK();
+ }
+
// Mutex & condition variable to guard mutable iterator internals and
// coordinate among worker threads and client thread[s].
- mutex mu_;
+ mutex mu_ ACQUIRED_BEFORE(ckpt_mu_);
// The main thread waits on this condition variable if running in sloppy
// mode and no values are available.
condition_variable sloppy_cond_var_;
+ // Mutex used to wait for a consistent state while checkpointing.
+ // Only Save and Restore require an exclusive lock on this mutex. In
+ // other scenarios we just acquire a shared lock so the pipeline's
+ // performance should not be affected in the absence of checkpointing.
+ // A thread must not wait on any condition variable while holding
+ // `ckpt_mu_` in either shared or exclusive modes.
+ mutex ckpt_mu_;
// The iterator producing elements which are converted to datasets by
// the dataset()->captured_func_ then interleaved together.
@@ -461,10 +1014,14 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
-    // workers_ elements are in at most one of interleave_ and staging_.
+    // workers_ elements are in at most one of interleave_indices_ and
+    // staging_indices_.
std::vector<WorkerState> workers_ GUARDED_BY(mu_);
- // The iterators to interleave
- std::vector<WorkerState*> interleave_ GUARDED_BY(mu_);
- // Prefetched iterators
- std::deque<WorkerState*> staging_ GUARDED_BY(mu_);
+ // Stores the temporary state of WorkerThreads which is not stored in
+ // WorkerState. This is used for checkpointing purposes only.
+ std::vector<WorkerThreadState> worker_thread_states_ GUARDED_BY(ckpt_mu_);
+
+ // Indices in `workers_` of iterators to interleave.
+ std::vector<int64> interleave_indices_ GUARDED_BY(mu_);
+ // Indices in `workers_` of prefetched iterators.
+ std::deque<int64> staging_indices_ GUARDED_BY(mu_);
// The index into output_elements_ for next element to produce.
size_t next_index_ GUARDED_BY(mu_) = 0;
@@ -479,6 +1036,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
};
const DatasetBase* const input_;
+ const NameAttrList interleave_func_;
const std::unique_ptr<CapturedFunction> captured_func_;
const int64 cycle_length_;
const int64 block_length_;
@@ -492,7 +1050,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
- NameAttrList func_;
+ NameAttrList interleave_func_;
};
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
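
Worth noting from the hunk above: WriteStatusLocked and ReadStatusLocked persist a Status as an integer code plus a message that is written only when the code is non-OK. A minimal round trip of that encoding (standalone C++; Code, Status, and the key names are simplified stand-ins for the TensorFlow types):

#include <iostream>
#include <map>
#include <string>

enum class Code { kOk = 0, kInvalidArgument = 3, kInternal = 13 };
struct Status { Code code = Code::kOk; std::string msg; };
using State = std::map<std::string, std::string>;

void WriteStatus(State* s, const std::string& prefix, const Status& st) {
  (*s)[prefix + "_code"] = std::to_string(static_cast<int>(st.code));
  // The message key is only written for non-OK codes, as above.
  if (st.code != Code::kOk) (*s)[prefix + "_msg"] = st.msg;
}

Status ReadStatus(const State& s, const std::string& prefix) {
  Status st;
  st.code = static_cast<Code>(std::stoi(s.at(prefix + "_code")));
  if (st.code != Code::kOk) st.msg = s.at(prefix + "_msg");
  return st;
}

int main() {
  State ckpt;
  WriteStatus(&ckpt, "worker_thread_0_iterator_creation_status",
              {Code::kInternal, "iterator failed"});
  Status st = ReadStatus(ckpt, "worker_thread_0_iterator_creation_status");
  std::cout << static_cast<int>(st.code) << ": " << st.msg << "\n";
}
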
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc
index 08adf4badb..ef332ebee3 100644
--- a/tensorflow/core/kernels/gather_op.cc
+++ b/tensorflow/core/kernels/gather_op.cc
@@ -143,6 +143,8 @@ TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
TF_CALL_quint16(REGISTER_GATHER_CPU);
TF_CALL_qint16(REGISTER_GATHER_CPU);
+TF_CALL_uint32(REGISTER_GATHER_CPU);
+TF_CALL_uint64(REGISTER_GATHER_CPU);
#undef REGISTER_GATHER_CPU
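
The TF_CALL_* macros used above are X-macro style type dispatch: TF_CALL_<type>(m) expands to m(<type>), stamping out one registration per supported type. A self-contained sketch of the idiom (the macro bodies here are illustrative; the real REGISTER_GATHER_CPU instantiates a kernel registrar rather than printing):

#include <cstdint>
#include <iostream>

// Each TF_CALL_<type>(m) invokes the macro m once with that type,
// mirroring TF_CALL_uint32(REGISTER_GATHER_CPU) above.
#define TF_CALL_uint32(m) m(uint32_t)
#define TF_CALL_uint64(m) m(uint64_t)

// Stand-in registration body: defines one function per type.
#define REGISTER_GATHER_CPU(type)                      \
  void RegisterGather_##type() {                       \
    std::cout << "registered gather for " #type "\n";  \
  }

TF_CALL_uint32(REGISTER_GATHER_CPU);
TF_CALL_uint64(REGISTER_GATHER_CPU);

int main() {
  RegisterGather_uint32_t();
  RegisterGather_uint64_t();
}
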
diff --git a/tensorflow/core/kernels/hexagon/BUILD b/tensorflow/core/kernels/hexagon/BUILD
index 4870d9ae20..66aeec5105 100644
--- a/tensorflow/core/kernels/hexagon/BUILD
+++ b/tensorflow/core/kernels/hexagon/BUILD
@@ -70,6 +70,7 @@ tf_kernel_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
"//third_party/eigen3",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc
index 4040bf52bf..40bf5a4dc7 100644
--- a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc
+++ b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
diff --git a/tensorflow/core/kernels/hexagon/graph_transfer_utils.h b/tensorflow/core/kernels/hexagon/graph_transfer_utils.h
index 352d548bd3..ada96ae4ea 100644
--- a/tensorflow/core/kernels/hexagon/graph_transfer_utils.h
+++ b/tensorflow/core/kernels/hexagon/graph_transfer_utils.h
@@ -20,14 +20,14 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
+class RemoteFusedGraphExecuteInfo;
+
class GraphTransferUtils {
public:
static std::priority_queue<std::tuple<float, int, string>>
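
Both hexagon headers apply the same dependency trim: drop the generated-proto includes from the header, forward-declare the proto classes, and include the .pb.h only in the .cc that actually touches the members. This is also why GraphTransferer gains an out-of-line constructor and destructor and holds GraphTransferInfo by pointer. A single-file sketch of the pattern (Transferer and the inline GraphTransferInfo body are stand-ins; in practice the two halves live in .h/.cc as in graph_transferer.h below):

#include <iostream>

// --- what the header does: no .pb.h include required ---
class GraphTransferInfo;     // forward declaration replaces the include

class Transferer {
 public:
  Transferer();
  ~Transferer();             // out-of-line: destruction needs the full type
 private:
  GraphTransferInfo* info_;  // a pointer member is fine for an incomplete type
};

// --- what the .cc does: only here is the full definition needed ---
class GraphTransferInfo {    // stand-in for the generated proto class
 public:
  int node_count = 0;
};

Transferer::Transferer() : info_(new GraphTransferInfo) {}
Transferer::~Transferer() { delete info_; }

int main() {
  Transferer t;
  std::cout << "no proto dependency at the header level\n";
}
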
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc
index 0963dff5fa..7960cb4b05 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer.cc
+++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <algorithm>
#include <cinttypes>
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
@@ -73,6 +75,12 @@ static Node* FindMutableNodeByName(const string& name, Graph* graph) {
return nullptr;
}
+GraphTransferer::GraphTransferer() {
+ graph_transfer_info_ = new GraphTransferInfo();
+}
+
+GraphTransferer::~GraphTransferer() { delete graph_transfer_info_; }
+
/**
* graph loading functions
* - LoadGraphFromProto
@@ -142,8 +150,8 @@ Status GraphTransferer::LoadGraphFromProto(
for (const std::pair<string, Tensor>& input_node_info :
input_node_info_list) {
- GraphTransferInfo::GraphInputNodeInfo& graph_input_node_info =
- *graph_transfer_info_.add_graph_input_node_info();
+ GraphTransferGraphInputNodeInfo& graph_input_node_info =
+ *graph_transfer_info_->add_graph_input_node_info();
graph_input_node_info.set_name(input_node_info.first);
graph_input_node_info.set_dtype(input_node_info.second.dtype());
for (const int64 dim : ToTensorShapeArray(input_node_info.second.shape())) {
@@ -159,8 +167,8 @@ Status GraphTransferer::LoadGraphFromProto(
const Node* node = node_name_cache_list_.at(node_id);
CHECK_NOTNULL(node);
- GraphTransferInfo::GraphOutputNodeInfo& graph_output_node_info =
- *graph_transfer_info_.add_graph_output_node_info();
+ GraphTransferGraphOutputNodeInfo& graph_output_node_info =
+ *graph_transfer_info_->add_graph_output_node_info();
graph_output_node_info.set_name(strings::StrCat(node_name, ":", port));
// Get output tensor shape type
@@ -231,17 +239,17 @@ Status GraphTransferer::LoadGraphFromProtoFile(
void GraphTransferer::SortParams(const std::vector<string>& output_node_names) {
// TODO(satok): optimize complexity
- std::unordered_map<int, GraphTransferInfo::NodeInputInfo*> input_map;
- for (GraphTransferInfo::NodeInputInfo& input :
- *graph_transfer_info_.mutable_node_input_info()) {
+ std::unordered_map<int, GraphTransferNodeInputInfo*> input_map;
+ for (GraphTransferNodeInputInfo& input :
+ *graph_transfer_info_->mutable_node_input_info()) {
input_map.emplace(input.node_id(), &input);
}
// Setup dependency map placeholder
std::vector<int> output_node_ids;
std::unordered_map<int, std::unordered_set<int>> dependency_map;
- for (const GraphTransferInfo::NodeInfo& params :
- graph_transfer_info_.node_info()) {
+ for (const GraphTransferNodeInfo& params :
+ graph_transfer_info_->node_info()) {
const int node_id = params.node_id();
for (const string& output_node_name : output_node_names) {
if (params.name() == output_node_name) {
@@ -255,7 +263,7 @@ void GraphTransferer::SortParams(const std::vector<string>& output_node_names) {
continue;
}
CHECK_EQ(input_map.count(node_id), 1);
- for (const GraphTransferInfo::NodeInput& node_input :
+ for (const GraphTransferNodeInput& node_input :
input_map.at(node_id)->node_input()) {
dependency_map.at(node_id).emplace(node_input.node_id());
}
@@ -267,8 +275,8 @@ void GraphTransferer::SortParams(const std::vector<string>& output_node_names) {
FillDependencyRec(output_node_id, dependency_map, completed);
}
- std::sort(graph_transfer_info_.mutable_node_info()->begin(),
- graph_transfer_info_.mutable_node_info()->end(),
+ std::sort(graph_transfer_info_->mutable_node_info()->begin(),
+ graph_transfer_info_->mutable_node_info()->end(),
TransferParamsComparator(dependency_map));
}
@@ -278,15 +286,15 @@ void GraphTransferer::EnableStrictCheckMode(const bool enable) {
void GraphTransferer::SetSerializedGraphTransferInfo(
const string& serialized_proto) {
- graph_transfer_info_.ParseFromString(serialized_proto);
+ graph_transfer_info_->ParseFromString(serialized_proto);
}
const GraphTransferInfo& GraphTransferer::GetGraphTransferInfo() const {
- return graph_transfer_info_;
+ return *graph_transfer_info_;
}
GraphTransferInfo& GraphTransferer::GetMutableGraphTransferInfo() {
- return graph_transfer_info_;
+ return *graph_transfer_info_;
}
void GraphTransferer::CacheNode(const Node& node) {
@@ -473,8 +481,8 @@ void GraphTransferer::RegisterConstantNode(const ShapeRefiner& shape_refiner,
data_size = max_bytes_per_data * num_output_elements;
shape_array = BuildShapeArray(shape_handle, context);
- GraphTransferInfo::ConstNodeInfo& const_node_info =
- *graph_transfer_info_.add_const_node_info();
+ GraphTransferConstNodeInfo& const_node_info =
+ *graph_transfer_info_->add_const_node_info();
const_node_info.set_name(node.name());
const_node_info.set_node_id(id);
// TODO(satok): Make this generic. Never assume rank is 4.
@@ -505,8 +513,8 @@ int GraphTransferer::RegisterConstantShape(const std::vector<int>& shape) {
node_name_cache_list_.emplace_back(nullptr);
const int id = node_name_cache_list_.size() - 1;
node_name_to_id_cache_map_.emplace(shape_name, id);
- GraphTransferInfo::ConstNodeInfo& const_node_info =
- *graph_transfer_info_.add_const_node_info();
+ GraphTransferConstNodeInfo& const_node_info =
+ *graph_transfer_info_->add_const_node_info();
const_node_info.set_name(shape_name);
const_node_info.set_node_id(id);
// TODO(satok): Make this generic. Never assume rank is 5.
@@ -528,8 +536,8 @@ int GraphTransferer::RegisterConstTensor(const Tensor& tensor,
node_name_cache_list_.emplace_back(nullptr);
const int id = node_name_cache_list_.size() - 1;
node_name_to_id_cache_map_.emplace(node_name, id);
- GraphTransferInfo::ConstNodeInfo& const_node_info =
- *graph_transfer_info_.add_const_node_info();
+ GraphTransferConstNodeInfo& const_node_info =
+ *graph_transfer_info_->add_const_node_info();
const_node_info.set_name(node_name);
const_node_info.set_node_id(id);
CHECK_EQ(4, SHAPE_ARRAY_SIZE);
@@ -558,8 +566,8 @@ int GraphTransferer::RegisterConstScalar(const DataType dt, const int val,
node_name_cache_list_.emplace_back(nullptr);
const int id = node_name_cache_list_.size() - 1;
node_name_to_id_cache_map_.emplace(val_name, id);
- GraphTransferInfo::ConstNodeInfo& const_node_info =
- *graph_transfer_info_.add_const_node_info();
+ GraphTransferConstNodeInfo& const_node_info =
+ *graph_transfer_info_->add_const_node_info();
const_node_info.set_name(val_name);
const_node_info.set_node_id(id);
// TODO(satok): Do not assume rank is 4 here.
@@ -715,8 +723,8 @@ void GraphTransferer::RegisterPadNode(
CHECK_EQ(2, node.num_inputs());
- GraphTransferInfo::NodeInputInfo& node_input_info =
- *graph_transfer_info_.add_node_input_info();
+ GraphTransferNodeInputInfo& node_input_info =
+ *graph_transfer_info_->add_node_input_info();
node_input_info.set_node_id(id);
AddNodeInputByInputIndex(node, 0, &node_input_info);
@@ -761,8 +769,7 @@ void GraphTransferer::RegisterPadNode(
new_const_tensor,
strings::StrCat(input_node->name(), "_", node.name(), "_1"));
- GraphTransferInfo::NodeInput& node_input =
- *node_input_info.add_node_input();
+ GraphTransferNodeInput& node_input = *node_input_info.add_node_input();
node_input.set_node_id(id);
node_input.set_output_port(0);
} else {
@@ -849,8 +856,7 @@ void GraphTransferer::AppendNodeParams(const string& name, const int id,
const int padding, const int inputs_size,
const std::vector<int>& extra_inputs,
const int outputs_size) {
- GraphTransferInfo::NodeInfo& node_info =
- *graph_transfer_info_.add_node_info();
+ GraphTransferNodeInfo& node_info = *graph_transfer_info_->add_node_info();
node_info.set_name(name);
node_info.set_node_id(id);
node_info.set_type_name(type);
@@ -863,7 +869,7 @@ void GraphTransferer::AppendNodeParams(const string& name, const int id,
void GraphTransferer::AddNodeInputByInputIndex(
const Node& node, const int idx,
- GraphTransferInfo::NodeInputInfo* node_input_info) {
+ GraphTransferNodeInputInfo* node_input_info) {
const Edge* edge = nullptr;
TF_CHECK_OK(node.input_edge(idx, &edge));
const Node* input_node = edge->src();
@@ -873,7 +879,7 @@ void GraphTransferer::AddNodeInputByInputIndex(
const std::string& op_name = input_node->name();
CHECK_GT(node_name_to_id_cache_map_.count(op_name), 0) << op_name;
const int src_id = node_name_to_id_cache_map_[op_name];
- GraphTransferInfo::NodeInput& node_input = *node_input_info->add_node_input();
+ GraphTransferNodeInput& node_input = *node_input_info->add_node_input();
node_input.set_node_id(src_id);
node_input.set_output_port(port);
}
@@ -882,15 +888,14 @@ void GraphTransferer::AppendNodeInputParams(
const int id, const Node& node, const std::vector<int>& extra_inputs) {
VLOG(1) << "Append input params: " << node.name() << ", " << node.num_inputs()
<< ", " << extra_inputs.size();
- GraphTransferInfo::NodeInputInfo& node_input_info =
- *graph_transfer_info_.add_node_input_info();
+ GraphTransferNodeInputInfo& node_input_info =
+ *graph_transfer_info_->add_node_input_info();
node_input_info.set_node_id(id);
for (int i = 0; i < node.num_inputs(); ++i) {
AddNodeInputByInputIndex(node, i, &node_input_info);
}
for (const int extra_input : extra_inputs) {
- GraphTransferInfo::NodeInput& node_input =
- *node_input_info.add_node_input();
+ GraphTransferNodeInput& node_input = *node_input_info.add_node_input();
node_input.set_node_id(extra_input);
node_input.set_output_port(0);
}
@@ -900,8 +905,8 @@ void GraphTransferer::AppendNodeOutputParams(const ShapeRefiner& shape_refiner,
const int id, const Node& node) {
VLOG(1) << "Append output params: " << node.name() << ", "
<< node.num_outputs();
- GraphTransferInfo::NodeOutputInfo& node_output_info =
- *graph_transfer_info_.add_node_output_info();
+ GraphTransferNodeOutputInfo& node_output_info =
+ *graph_transfer_info_->add_node_output_info();
node_output_info.set_node_id(id);
std::vector<DataType> data_types;
@@ -1030,8 +1035,7 @@ GraphTransferer::TransferParamsComparator::TransferParamsComparator(
: dependency_map_(dep_map) {}
bool GraphTransferer::TransferParamsComparator::operator()(
- const GraphTransferInfo::NodeInfo& obj0,
- const GraphTransferInfo::NodeInfo& obj1) {
+ const GraphTransferNodeInfo& obj0, const GraphTransferNodeInfo& obj1) {
const int node_id0 = obj0.node_id();
const int node_id1 = obj1.node_id();
bool obj0_uses_obj1 = false;
@@ -1114,8 +1118,8 @@ void GraphTransferer::ClearCache() {
void GraphTransferer::DumpNodeTransferParams() const {
LOG(INFO) << "*** Const Nodes ***";
- for (const GraphTransferInfo::ConstNodeInfo& params :
- graph_transfer_info_.const_node_info()) {
+ for (const GraphTransferConstNodeInfo& params :
+ graph_transfer_info_->const_node_info()) {
// TODO(satok): Stop assuming shape size is 4.
CHECK_EQ(params.shape_size(), 4);
LOG(INFO) << "[ " << params.node_id() << " \"" << params.name()
@@ -1131,8 +1135,8 @@ void GraphTransferer::DumpNodeTransferParams() const {
}
LOG(INFO) << "******\n";
LOG(INFO) << "*** Op Nodes ***";
- for (const GraphTransferInfo::NodeInfo& params :
- graph_transfer_info_.node_info()) {
+ for (const GraphTransferNodeInfo& params :
+ graph_transfer_info_->node_info()) {
LOG(INFO) << "[ " << params.node_id() << " \"" << params.name();
LOG(INFO) << " type: " << params.type_name();
LOG(INFO) << " padding: " << ToPaddingDebugString(params.padding_id());
@@ -1146,18 +1150,18 @@ void GraphTransferer::DumpNodeTransferParams() const {
}
LOG(INFO) << "******\n";
LOG(INFO) << "*** Node input params ***";
- for (const GraphTransferInfo::NodeInputInfo& params :
- graph_transfer_info_.node_input_info()) {
+ for (const GraphTransferNodeInputInfo& params :
+ graph_transfer_info_->node_input_info()) {
LOG(INFO) << "[ " << params.node_id() << " ]";
- for (const GraphTransferInfo::NodeInput& node_input : params.node_input()) {
+ for (const GraphTransferNodeInput& node_input : params.node_input()) {
LOG(INFO) << " src node id = " << node_input.node_id()
<< ", output port = " << node_input.output_port();
}
}
LOG(INFO) << "******\n";
LOG(INFO) << "*** Node output params ***";
- for (const GraphTransferInfo::NodeOutputInfo& params :
- graph_transfer_info_.node_output_info()) {
+ for (const GraphTransferNodeOutputInfo& params :
+ graph_transfer_info_->node_output_info()) {
LOG(INFO) << "[ " << params.node_id() << " ]";
for (const int max_size : params.max_byte_size()) {
LOG(INFO) << " max_size = " << max_size;
@@ -1167,8 +1171,8 @@ void GraphTransferer::DumpNodeTransferParams() const {
}
void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
- for (const GraphTransferInfo::ConstNodeInfo& params :
- graph_transfer_info_.const_node_info()) {
+ for (const GraphTransferConstNodeInfo& params :
+ graph_transfer_info_->const_node_info()) {
std::stringstream sstream;
// TODO(satok): Stop assuming shape size is 4.
CHECK_EQ(params.shape_size(), 4);
@@ -1182,9 +1186,9 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
LOG(INFO) << sstream.str();
}
LOG(INFO) << "Const node count = "
- << graph_transfer_info_.const_node_info_size();
- for (const GraphTransferInfo::NodeInfo& params :
- graph_transfer_info_.node_info()) {
+ << graph_transfer_info_->const_node_info_size();
+ for (const GraphTransferNodeInfo& params :
+ graph_transfer_info_->node_info()) {
std::stringstream sstream;
sstream << "---(OP) [" << params.name().c_str() << "," << std::hex
<< params.node_id() << std::dec << "," << params.soc_op_id() << ","
@@ -1197,12 +1201,12 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
<< "," << params.output_count() << "," << params.type_name() << "]";
LOG(INFO) << sstream.str();
}
- LOG(INFO) << "Op node count = " << graph_transfer_info_.node_info_size();
- for (const GraphTransferInfo::NodeInputInfo& params :
- graph_transfer_info_.node_input_info()) {
+ LOG(INFO) << "Op node count = " << graph_transfer_info_->node_info_size();
+ for (const GraphTransferNodeInputInfo& params :
+ graph_transfer_info_->node_input_info()) {
std::stringstream sstream;
sstream << "---(INPUT) [" << std::hex << params.node_id() << std::dec;
- for (const GraphTransferInfo::NodeInput& node_input : params.node_input()) {
+ for (const GraphTransferNodeInput& node_input : params.node_input()) {
sstream << "," << std::hex << node_input.node_id() << std::dec << ","
<< node_input.output_port();
}
@@ -1210,9 +1214,9 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
LOG(INFO) << sstream.str();
}
LOG(INFO) << "Input params count = "
- << graph_transfer_info_.node_input_info_size();
- for (const GraphTransferInfo::NodeOutputInfo& params :
- graph_transfer_info_.node_output_info()) {
+ << graph_transfer_info_->node_input_info_size();
+ for (const GraphTransferNodeOutputInfo& params :
+ graph_transfer_info_->node_output_info()) {
std::stringstream sstream;
sstream << "---(OUTPUT) [" << std::hex << params.node_id() << std::dec;
for (const int max_size : params.max_byte_size()) {
@@ -1222,7 +1226,7 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
LOG(INFO) << sstream.str();
}
LOG(INFO) << "Output params count = "
- << graph_transfer_info_.node_output_info_size();
+ << graph_transfer_info_->node_output_info_size();
}
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.h b/tensorflow/core/kernels/hexagon/graph_transferer.h
index 0d43d028cd..86c1c5625f 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer.h
+++ b/tensorflow/core/kernels/hexagon/graph_transferer.h
@@ -22,8 +22,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/shape_refiner.h"
-#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
@@ -34,6 +32,10 @@ limitations under the License.
namespace tensorflow {
+class GraphTransferInfo;
+class GraphTransferNodeInfo;
+class GraphTransferNodeInputInfo;
+
// GraphTransferer transfers graph definitions into SoC memory.
// This functionality is effective if the SoC is capable of running
// the graph on that chip.
@@ -47,7 +49,9 @@ class GraphTransferer {
static constexpr int SHAPE_ARRAY_SIZE = MAX_SUPPORTED_RANK;
using TensorShapeMap = RemoteFusedGraphExecuteUtils::TensorShapeMap;
- GraphTransferer() = default;
+ GraphTransferer();
+
+ ~GraphTransferer();
// Load graph structure into GraphTransferer
// TODO(satok): Pass a pair of TensorShape and DataType instead of
@@ -96,8 +100,8 @@ class GraphTransferer {
public:
TransferParamsComparator(
const std::unordered_map<int, std::unordered_set<int>>& dep_map);
- bool operator()(const GraphTransferInfo::NodeInfo& obj0,
- const GraphTransferInfo::NodeInfo& obj1);
+ bool operator()(const GraphTransferNodeInfo& obj0,
+ const GraphTransferNodeInfo& obj1);
const std::unordered_map<int, std::unordered_set<int>>& dependency_map_;
};
@@ -174,9 +178,8 @@ class GraphTransferer {
const std::vector<int>& extra_inputs,
const int outputs_size);
- void AddNodeInputByInputIndex(
- const Node& node, const int idx,
- GraphTransferInfo::NodeInputInfo* node_input_info);
+ void AddNodeInputByInputIndex(const Node& node, const int idx,
+ GraphTransferNodeInputInfo* node_input_info);
void AppendNodeInputParams(const int id, const Node& node,
const std::vector<int>& extra_inputs);
@@ -211,7 +214,7 @@ class GraphTransferer {
// Dump pretty print of parameters
void DumpNodeTransferParams() const;
- GraphTransferInfo graph_transfer_info_{};
+ GraphTransferInfo* graph_transfer_info_;
std::vector<const Node*> node_name_cache_list_{};
std::unordered_map<string, int> node_name_to_id_cache_map_{};
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc
index 20b09f144b..765795b1f4 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc
+++ b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc
@@ -191,9 +191,9 @@ static GraphDef CreatePoolGraphDef() {
return def;
}
-static const GraphTransferInfo::ConstNodeInfo* FindConstNodeInfo(
+static const GraphTransferConstNodeInfo* FindConstNodeInfo(
const GraphTransferer& gt, const string& name) {
- for (const GraphTransferInfo::ConstNodeInfo& params :
+ for (const GraphTransferConstNodeInfo& params :
gt.GetGraphTransferInfo().const_node_info()) {
if (params.name() == name) {
return &params;
@@ -202,9 +202,9 @@ static const GraphTransferInfo::ConstNodeInfo* FindConstNodeInfo(
return nullptr;
}
-static const GraphTransferInfo::NodeInfo* FindNodeInfo(
- const GraphTransferer& gt, const string& name) {
- for (const GraphTransferInfo::NodeInfo& params :
+static const GraphTransferNodeInfo* FindNodeInfo(const GraphTransferer& gt,
+ const string& name) {
+ for (const GraphTransferNodeInfo& params :
gt.GetGraphTransferInfo().node_info()) {
if (params.name() == name) {
return &params;
@@ -213,9 +213,9 @@ static const GraphTransferInfo::NodeInfo* FindNodeInfo(
return nullptr;
}
-static const GraphTransferInfo::NodeInputInfo* FindNodeInputInfo(
+static const GraphTransferNodeInputInfo* FindNodeInputInfo(
const GraphTransferer& gt, const int node_id) {
- for (const GraphTransferInfo::NodeInputInfo& params :
+ for (const GraphTransferNodeInputInfo& params :
gt.GetGraphTransferInfo().node_input_info()) {
if (params.node_id() == node_id) {
return &params;
@@ -224,9 +224,9 @@ static const GraphTransferInfo::NodeInputInfo* FindNodeInputInfo(
return nullptr;
}
-static const GraphTransferInfo::NodeOutputInfo* FindNodeOutputInfo(
+static const GraphTransferNodeOutputInfo* FindNodeOutputInfo(
const GraphTransferer& gt, const int node_id) {
- for (const GraphTransferInfo::NodeOutputInfo& params :
+ for (const GraphTransferNodeOutputInfo& params :
gt.GetGraphTransferInfo().node_output_info()) {
if (params.node_id() == node_id) {
return &params;
@@ -236,21 +236,21 @@ static const GraphTransferInfo::NodeOutputInfo* FindNodeOutputInfo(
}
static void SanityCheckNodes(const GraphTransferer& gt) {
- for (const GraphTransferInfo::NodeInfo& params :
+ for (const GraphTransferNodeInfo& params :
gt.GetGraphTransferInfo().node_info()) {
if (params.input_count() > 0) {
- const GraphTransferInfo::NodeInputInfo* input_params =
+ const GraphTransferNodeInputInfo* input_params =
FindNodeInputInfo(gt, params.node_id());
ASSERT_NE(nullptr, input_params);
EXPECT_EQ(params.input_count(), input_params->node_input_size());
EXPECT_EQ(params.node_id(), input_params->node_id());
- for (const GraphTransferInfo::NodeInput& node_input :
+ for (const GraphTransferNodeInput& node_input :
input_params->node_input()) {
EXPECT_GE(node_input.output_port(), 0);
}
}
if (params.output_count() > 0) {
- const GraphTransferInfo::NodeOutputInfo* output_params =
+ const GraphTransferNodeOutputInfo* output_params =
FindNodeOutputInfo(gt, params.node_id());
ASSERT_NE(nullptr, output_params);
EXPECT_EQ(params.output_count(), output_params->max_byte_size_size());
@@ -273,8 +273,7 @@ TEST_F(GraphTransfererTest, LoadAddGraph) {
const int const_node_count =
gt_.GetGraphTransferInfo().const_node_info_size();
ASSERT_EQ(2, const_node_count);
- const GraphTransferInfo::ConstNodeInfo* params_a =
- FindConstNodeInfo(gt_, NAME_A);
+ const GraphTransferConstNodeInfo* params_a = FindConstNodeInfo(gt_, NAME_A);
ASSERT_TRUE(params_a != nullptr);
EXPECT_EQ(NAME_A, params_a->name());
ASSERT_EQ(4, params_a->shape_size());
@@ -284,8 +283,7 @@ TEST_F(GraphTransfererTest, LoadAddGraph) {
EXPECT_EQ(1, params_a->shape(3));
EXPECT_EQ(4, params_a->data().length());
- const GraphTransferInfo::ConstNodeInfo* params_b =
- FindConstNodeInfo(gt_, NAME_B);
+ const GraphTransferConstNodeInfo* params_b = FindConstNodeInfo(gt_, NAME_B);
ASSERT_TRUE(params_b != nullptr);
ASSERT_EQ(4, params_b->shape_size());
EXPECT_EQ(1, params_b->shape(0));
@@ -328,7 +326,7 @@ TEST_F(GraphTransfererTest, LoadConvGraph) {
ASSERT_EQ(2, const_node_count);
const int op_node_count = gt_.GetGraphTransferInfo().node_info_size();
ASSERT_EQ(4, op_node_count);
- const GraphTransferInfo::NodeInfo* params_conv = FindNodeInfo(gt_, "conv");
+ const GraphTransferNodeInfo* params_conv = FindNodeInfo(gt_, "conv");
ASSERT_TRUE(params_conv != nullptr);
const int id = params_conv->node_id();
EXPECT_GE(id, 0);
@@ -354,8 +352,7 @@ TEST_F(GraphTransfererTest, LoadMaxPoolGraph) {
ASSERT_EQ(2, const_node_count);
const int op_node_count = gt_.GetGraphTransferInfo().node_info_size();
ASSERT_EQ(4, op_node_count);
- const GraphTransferInfo::NodeInfo* params_max_pool =
- FindNodeInfo(gt_, "maxpool");
+ const GraphTransferNodeInfo* params_max_pool = FindNodeInfo(gt_, "maxpool");
ASSERT_TRUE(params_max_pool != nullptr);
const int id = params_max_pool->node_id();
EXPECT_GE(id, 0);
diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
index 9c2e1e123c..66d24d171d 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
+++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
+#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/kernels/hexagon/soc_interface.h"
@@ -54,9 +55,9 @@ static uint8* FindAlignedPointer(uint8* ptr) {
return data_ptr;
}
-/* static */ GraphTransferInfo::NodeInfo* HexagonControlWrapper::FindNodeInfo(
+/* static */ GraphTransferNodeInfo* HexagonControlWrapper::FindNodeInfo(
const string& name, GraphTransferInfo* graph_transfer_info) {
- for (GraphTransferInfo::NodeInfo& node_info :
+ for (GraphTransferNodeInfo& node_info :
*graph_transfer_info->mutable_node_info()) {
if (node_info.name() == name) {
return &node_info;
@@ -138,9 +139,9 @@ bool HexagonControlWrapper::SetupGraph() {
graph_transferer_.GetMutableGraphTransferInfo();
// Overwrite op type of input nodes for hexagon
- for (const GraphTransferInfo::GraphInputNodeInfo& graph_input :
+ for (const GraphTransferGraphInputNodeInfo& graph_input :
graph_transfer_info.graph_input_node_info()) {
- GraphTransferInfo::NodeInfo* node_info =
+ GraphTransferNodeInfo* node_info =
FindNodeInfo(graph_input.name(), &graph_transfer_info);
CHECK_NE(node_info, nullptr);
}
@@ -148,13 +149,13 @@ bool HexagonControlWrapper::SetupGraph() {
// Generate a new output node which is connected to the graph output node
// TODO(satok): Support multiple output nodes
CHECK_EQ(graph_transfer_info.graph_output_node_info_size(), 1);
- for (const GraphTransferInfo::GraphOutputNodeInfo& graph_output :
+ for (const GraphTransferGraphOutputNodeInfo& graph_output :
graph_transfer_info.graph_output_node_info()) {
const int new_output_node_id = graph_transfer_info.node_info_size() +
graph_transfer_info.const_node_info_size() +
2 /* offset for ids */;
// Register a new output node
- GraphTransferInfo::NodeInfo& new_output_node_info =
+ GraphTransferNodeInfo& new_output_node_info =
*graph_transfer_info.add_node_info();
new_output_node_info.set_name(OUTPUT_OP_NAME);
new_output_node_info.set_node_id(new_output_node_id);
@@ -169,14 +170,13 @@ bool HexagonControlWrapper::SetupGraph() {
const string node_name = tid.first.ToString();
const int port = tid.second;
// Register node input for the new output node
- const GraphTransferInfo::NodeInfo* node_info =
+ const GraphTransferNodeInfo* node_info =
FindNodeInfo(node_name, &graph_transfer_info);
CHECK_NE(node_info, nullptr);
- GraphTransferInfo::NodeInputInfo& node_input_info =
+ GraphTransferNodeInputInfo& node_input_info =
*graph_transfer_info.add_node_input_info();
node_input_info.set_node_id(new_output_node_id);
- GraphTransferInfo::NodeInput& node_input =
- *node_input_info.add_node_input();
+ GraphTransferNodeInput& node_input = *node_input_info.add_node_input();
node_input.set_node_id(node_info->node_id());
node_input.set_output_port(port);
}
@@ -189,12 +189,12 @@ bool HexagonControlWrapper::SetupGraph() {
int inputs_count = 0;
int outputs_count = 0;
- for (const GraphTransferInfo::NodeInputInfo& input_params :
+ for (const GraphTransferNodeInputInfo& input_params :
graph_transfer_info.node_input_info()) {
inputs_count += input_params.node_input_size();
}
- for (const GraphTransferInfo::NodeOutputInfo& output_params :
+ for (const GraphTransferNodeOutputInfo& output_params :
graph_transfer_info.node_output_info()) {
outputs_count += output_params.max_byte_size_size();
}
@@ -204,15 +204,14 @@ bool HexagonControlWrapper::SetupGraph() {
// Construct node input parameters
std::unordered_map<int, std::tuple<void*, int>> inputs_map;
- for (const GraphTransferInfo::NodeInputInfo& input_params :
+ for (const GraphTransferNodeInputInfo& input_params :
graph_transfer_info.node_input_info()) {
const int count = input_params.node_input_size();
CHECK(count <= MAX_IN_OUT_COUNT);
int node_ids[MAX_IN_OUT_COUNT];
int ports[MAX_IN_OUT_COUNT];
for (int i = 0; i < count; ++i) {
- const GraphTransferInfo::NodeInput& node_input =
- input_params.node_input(i);
+ const GraphTransferNodeInput& node_input = input_params.node_input(i);
node_ids[i] = node_input.node_id() + NODE_ID_OFFSET;
ports[i] = node_input.output_port();
}
@@ -224,7 +223,7 @@ bool HexagonControlWrapper::SetupGraph() {
// Construct node output parameters
std::unordered_map<int, std::tuple<void*, int>> outputs_map;
- for (const GraphTransferInfo::NodeOutputInfo& output_params :
+ for (const GraphTransferNodeOutputInfo& output_params :
graph_transfer_info.node_output_info()) {
const int count = output_params.max_byte_size_size();
CHECK(count <= MAX_IN_OUT_COUNT);
@@ -244,7 +243,7 @@ bool HexagonControlWrapper::SetupGraph() {
// Initialize graph
// 1. Setup const nodes
- for (const GraphTransferInfo::ConstNodeInfo& params :
+ for (const GraphTransferConstNodeInfo& params :
graph_transfer_info.const_node_info()) {
const int node_id = params.node_id();
// TODO(satok): Stop assuming shape size is 4.
@@ -267,8 +266,7 @@ bool HexagonControlWrapper::SetupGraph() {
}
// 2. Setup op nodes
- for (const GraphTransferInfo::NodeInfo& params :
- graph_transfer_info.node_info()) {
+ for (const GraphTransferNodeInfo& params : graph_transfer_info.node_info()) {
const int node_id = params.node_id();
const int op_id = params.soc_op_id();
CHECK(inputs_map.count(node_id) == 1);
diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h
index dca1f94a9b..132cfde2db 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h
+++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h
@@ -67,8 +67,8 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor {
// CAVEAT: Need offset as HVX library reserves some ids
static constexpr int NODE_ID_OFFSET = 0x10000;
- static GraphTransferInfo::NodeInfo* FindNodeInfo(
- const string& node_name, GraphTransferInfo* graph_transfer_info);
+ static GraphTransferNodeInfo* FindNodeInfo(
+ const string& name, GraphTransferInfo* graph_transfer_info);
const RemoteFusedGraphExecuteInfo* execute_info_{};
GraphTransferer graph_transferer_{};
diff --git a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
index 3f794dfb1a..5fb6b9247f 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
+++ b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
@@ -29,6 +29,7 @@ adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
#include <memory>
+#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
@@ -209,7 +210,7 @@ BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(
const GraphTransferInfo& graph_transfer_info) {
RemoteFusedGraphExecuteInfo execute_info;
execute_info.set_executor_name("build_hexagon_remote_fused_graph_executor");
- for (const GraphTransferInfo::GraphInputNodeInfo& input :
+ for (const GraphTransferGraphInputNodeInfo& input :
graph_transfer_info.graph_input_node_info()) {
execute_info.add_graph_input_node_name(input.name());
RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
@@ -221,7 +222,7 @@ BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(
}
}
- for (const GraphTransferInfo::GraphOutputNodeInfo& output :
+ for (const GraphTransferGraphOutputNodeInfo& output :
graph_transfer_info.graph_output_node_info()) {
execute_info.add_graph_output_node_name(output.name());
RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
@@ -325,8 +326,8 @@ static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
// 1. check node_info
ASSERT_EQ(gfi0.node_info_size(), gfi1.node_info_size());
for (int i = 0; i < gfi0.node_info_size(); ++i) {
- const GraphTransferInfo::NodeInfo& ni0 = gfi0.node_info(i);
- const GraphTransferInfo::NodeInfo& ni1 = gfi1.node_info(i);
+ const GraphTransferNodeInfo& ni0 = gfi0.node_info(i);
+ const GraphTransferNodeInfo& ni1 = gfi1.node_info(i);
EXPECT_EQ(ni0.DebugString(), ni1.DebugString());
EXPECT_EQ(ni0.ByteSizeLong(), ni1.ByteSizeLong());
}
@@ -334,8 +335,8 @@ static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
// 2. check const_node_info
ASSERT_EQ(gfi0.const_node_info_size(), gfi1.const_node_info_size());
for (int i = 0; i < gfi0.const_node_info_size(); ++i) {
- const GraphTransferInfo::ConstNodeInfo& cni0 = gfi0.const_node_info(i);
- const GraphTransferInfo::ConstNodeInfo& cni1 = gfi1.const_node_info(i);
+ const GraphTransferConstNodeInfo& cni0 = gfi0.const_node_info(i);
+ const GraphTransferConstNodeInfo& cni1 = gfi1.const_node_info(i);
ASSERT_EQ(cni0.shape_size(), cni1.shape_size());
for (int j = 0; j < cni0.shape_size(); ++j) {
EXPECT_EQ(cni0.shape(j), cni1.shape(j));
@@ -347,8 +348,8 @@ static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
// 3. check node_input_info
ASSERT_EQ(gfi0.node_input_info_size(), gfi1.node_input_info_size());
for (int i = 0; i < gfi0.node_input_info_size(); ++i) {
- const GraphTransferInfo::NodeInputInfo& nii0 = gfi0.node_input_info(i);
- const GraphTransferInfo::NodeInputInfo& nii1 = gfi1.node_input_info(i);
+ const GraphTransferNodeInputInfo& nii0 = gfi0.node_input_info(i);
+ const GraphTransferNodeInputInfo& nii1 = gfi1.node_input_info(i);
EXPECT_EQ(nii0.ByteSizeLong(), nii1.ByteSizeLong());
EXPECT_EQ(nii0.DebugString(), nii1.DebugString());
}
@@ -356,8 +357,8 @@ static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
// 4. check node_output_info
ASSERT_EQ(gfi0.node_output_info_size(), gfi1.node_output_info_size());
for (int i = 0; i < gfi0.node_output_info_size(); ++i) {
- const GraphTransferInfo::NodeOutputInfo& noi0 = gfi0.node_output_info(i);
- const GraphTransferInfo::NodeOutputInfo& noi1 = gfi1.node_output_info(i);
+ const GraphTransferNodeOutputInfo& noi0 = gfi0.node_output_info(i);
+ const GraphTransferNodeOutputInfo& noi1 = gfi1.node_output_info(i);
ASSERT_EQ(noi0.max_byte_size_size(), noi1.max_byte_size_size());
for (int j = 0; j < noi0.max_byte_size_size(); ++j) {
EXPECT_EQ(noi0.max_byte_size(j), noi1.max_byte_size(j));
@@ -370,9 +371,9 @@ static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
ASSERT_EQ(gfi0.graph_input_node_info_size(),
gfi1.graph_input_node_info_size());
for (int i = 0; i < gfi0.graph_input_node_info_size(); ++i) {
- const GraphTransferInfo::GraphInputNodeInfo& gini0 =
+ const GraphTransferGraphInputNodeInfo& gini0 =
gfi0.graph_input_node_info(i);
- const GraphTransferInfo::GraphInputNodeInfo& gini1 =
+ const GraphTransferGraphInputNodeInfo& gini1 =
gfi0.graph_input_node_info(i);
EXPECT_EQ(gini0.ByteSizeLong(), gini1.ByteSizeLong());
EXPECT_EQ(gini0.DebugString(), gini1.DebugString());
@@ -382,9 +383,9 @@ static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
ASSERT_EQ(gfi0.graph_output_node_info_size(),
gfi1.graph_output_node_info_size());
for (int i = 0; i < gfi0.graph_output_node_info_size(); ++i) {
- const GraphTransferInfo::GraphOutputNodeInfo& goni0 =
+ const GraphTransferGraphOutputNodeInfo& goni0 =
gfi0.graph_output_node_info(i);
- const GraphTransferInfo::GraphOutputNodeInfo& goni1 =
+ const GraphTransferGraphOutputNodeInfo& goni1 =
gfi0.graph_output_node_info(i);
EXPECT_EQ(goni0.ByteSizeLong(), goni1.ByteSizeLong());
EXPECT_EQ(goni0.DebugString(), goni1.DebugString());
diff --git a/tensorflow/core/kernels/matching_files_op.cc b/tensorflow/core/kernels/matching_files_op.cc
index cdff7bad5f..7912ca1563 100644
--- a/tensorflow/core/kernels/matching_files_op.cc
+++ b/tensorflow/core/kernels/matching_files_op.cc
@@ -60,6 +60,7 @@ class MatchingFilesOp : public OpKernel {
output(index++) = all_fnames[i][j];
}
}
+ std::sort(&output(0), &output(0) + num_files);
}
};
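
The std::sort added above makes MatchingFilesOp return its matches in lexicographic order rather than filesystem order. A small sketch of the user-visible effect in TF 1.x graph mode (the glob pattern is illustrative):

    import tensorflow as tf

    # After this change, two processes globbing the same directory observe the
    # same, deterministically sorted file order.
    filenames = tf.matching_files("/tmp/data/*.tfrecord")
    with tf.Session() as sess:
        print(sess.run(filenames))  # lexicographically sorted byte strings
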
diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc
index 23176b8577..aaaf45d3e7 100644
--- a/tensorflow/core/kernels/maxpooling_op.cc
+++ b/tensorflow/core/kernels/maxpooling_op.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/eigen_pooling.h"
#include "tensorflow/core/kernels/ops_util.h"
@@ -38,7 +39,6 @@ limitations under the License.
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
-#include "tensorflow/core/kernels/bounds_check.h"
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
@@ -174,7 +174,8 @@ static void SpatialMaxPoolWithArgMaxHelper(
// Although this check is in the inner loop, it is worth the cost because it
// prevents memory corruption. Our benchmark shows that the performance
// impact is quite small.
- //CHECK(input_backprop_index >= in_start && input_backprop_index < in_end)
+ // CHECK(input_backprop_index >= in_start && input_backprop_index <
+ // in_end)
FastBoundsCheck(input_backprop_index - in_start, in_end - in_start);
input_backprop_flat(input_backprop_index) += out_backprop_flat(index);
}
@@ -873,7 +874,7 @@ template <typename T>
struct LaunchMaxPoolingWithArgmax<CPUDevice, T> {
static void launch(OpKernelContext* context, const PoolParameters& params,
const Tensor& input, Tensor* output, Tensor* argmax,
- bool propogate_nans) {
+ bool propagate_nans) {
Tensor unused;
SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>(
context, output, argmax, nullptr, input, unused, params);
diff --git a/tensorflow/core/kernels/sdca_internal.cc b/tensorflow/core/kernels/sdca_internal.cc
index 623de2a482..3e16ba8d04 100644
--- a/tensorflow/core/kernels/sdca_internal.cc
+++ b/tensorflow/core/kernels/sdca_internal.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <random>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
@@ -368,9 +369,9 @@ Status Examples::Initialize(OpKernelContext* const context,
TF_RETURN_IF_ERROR(CreateDenseFeatureRepresentation(
worker_threads, num_examples, num_dense_features, weights,
dense_features_inputs, &examples_));
- ComputeSquaredNormPerExample(worker_threads, num_examples,
- num_sparse_features, num_dense_features,
- &examples_);
+ TF_RETURN_IF_ERROR(ComputeSquaredNormPerExample(
+ worker_threads, num_examples, num_sparse_features, num_dense_features,
+ &examples_));
return Status::OK();
}
@@ -382,7 +383,7 @@ Status Examples::CreateSparseFeatureRepresentation(
const OpInputList& sparse_feature_values_inputs,
std::vector<Example>* const examples) {
mutex mu;
- Status result GUARDED_BY(mu);
+ Status result; // Guarded by mu
auto parse_partition = [&](const int64 begin, const int64 end) {
// The static_cast here is safe since begin and end can be at most
// num_examples which is an int.
@@ -460,7 +461,7 @@ Status Examples::CreateDenseFeatureRepresentation(
const OpInputList& dense_features_inputs,
std::vector<Example>* const examples) {
mutex mu;
- Status result GUARDED_BY(mu);
+ Status result; // Guarded by mu
auto parse_partition = [&](const int64 begin, const int64 end) {
// The static_cast here is safe since begin and end can be at most
// num_examples which is an int.
@@ -486,14 +487,17 @@ Status Examples::CreateDenseFeatureRepresentation(
return result;
}
-void Examples::ComputeSquaredNormPerExample(
+Status Examples::ComputeSquaredNormPerExample(
const DeviceBase::CpuWorkerThreads& worker_threads, const int num_examples,
const int num_sparse_features, const int num_dense_features,
std::vector<Example>* const examples) {
+ mutex mu;
+ Status result; // Guarded by mu
// Compute norm of examples.
auto compute_example_norm = [&](const int64 begin, const int64 end) {
// The static_cast here is safe since begin and end can be at most
// num_examples which is an int.
+ gtl::FlatSet<int64> previous_indices;
for (int example_id = static_cast<int>(begin); example_id < end;
++example_id) {
double squared_norm = 0;
@@ -501,12 +505,19 @@ void Examples::ComputeSquaredNormPerExample(
for (int j = 0; j < num_sparse_features; ++j) {
const Example::SparseFeatures& sparse_features =
example->sparse_features_[j];
- if (sparse_features.values) {
- const Eigen::Tensor<float, 0, Eigen::RowMajor> sn =
- sparse_features.values->square().sum();
- squared_norm += sn();
- } else {
- squared_norm += sparse_features.indices->size();
+ previous_indices.clear();
+ for (int64 k = 0; k < sparse_features.indices->size(); ++k) {
+ const int64 feature_index = (*sparse_features.indices)(k);
+ if (!previous_indices.insert(feature_index).second) {
+ mutex_lock l(mu);
+ result =
+ errors::InvalidArgument("Duplicate index in sparse vector.");
+ return;
+ }
+ const double feature_value = sparse_features.values == nullptr
+ ? 1.0
+ : (*sparse_features.values)(k);
+ squared_norm += feature_value * feature_value;
}
}
for (int j = 0; j < num_dense_features; ++j) {
@@ -521,6 +532,7 @@ void Examples::ComputeSquaredNormPerExample(
const int64 kCostPerUnit = num_dense_features + num_sparse_features;
Shard(worker_threads.num_threads, worker_threads.workers, num_examples,
kCostPerUnit, compute_example_norm);
+ return result;
}
} // namespace sdca
diff --git a/tensorflow/core/kernels/sdca_internal.h b/tensorflow/core/kernels/sdca_internal.h
index bfdb3febdc..897c488702 100644
--- a/tensorflow/core/kernels/sdca_internal.h
+++ b/tensorflow/core/kernels/sdca_internal.h
@@ -369,7 +369,7 @@ class Examples {
// Computes the squared example norm per example, i.e., |x|^2. This function
// modifies the |examples| passed in and adds the squared norm per example.
- static void ComputeSquaredNormPerExample(
+ static Status ComputeSquaredNormPerExample(
const DeviceBase::CpuWorkerThreads& worker_threads, int num_examples,
int num_sparse_features, int num_dense_features,
std::vector<Example>* const examples);
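
The rewritten norm computation walks each sparse vector index-by-index instead of squaring the whole values tensor, which lets it reject duplicate indices and treat a missing values tensor as implicit 1.0s. A minimal plain-Python sketch of that logic (not the TensorFlow kernel itself):

    def squared_norm(indices, values=None):
        # values=None mirrors the kernel's nullptr case: every present index
        # contributes an implicit value of 1.0.
        seen = set()
        total = 0.0
        for k, index in enumerate(indices):
            # Mirrors the gtl::FlatSet check: a repeated feature index becomes
            # an error instead of being silently double-counted.
            if index in seen:
                raise ValueError("Duplicate index in sparse vector.")
            seen.add(index)
            value = 1.0 if values is None else values[k]
            total += value * value
        return total

    print(squared_norm([0, 3, 7], [0.5, 2.0, 1.0]))  # 5.25
    print(squared_norm([0, 3, 7]))                   # 3.0
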
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index 6d35ff2de6..0a0f8d4dcf 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -24,6 +24,12 @@ limitations under the License.
// non-GPU targets. This only breaks in clang, because it's more strict for
// template code and CudaAtomicMax is used in template context.
// This file requires the following include because it uses CudaAtomicMax:
// #include "tensorflow/core/util/cuda_kernel_helper.h"
diff --git a/tensorflow/core/kernels/string_to_hash_bucket_op.h b/tensorflow/core/kernels/string_to_hash_bucket_op.h
index 2fd22c3f4e..62ef35bbba 100644
--- a/tensorflow/core/kernels/string_to_hash_bucket_op.h
+++ b/tensorflow/core/kernels/string_to_hash_bucket_op.h
@@ -26,7 +26,7 @@ limitations under the License.
namespace tensorflow {
-template <uint64 hash(const string&)>
+template <uint64 hash(StringPiece)>
class StringToHashBucketOp : public OpKernel {
public:
explicit StringToHashBucketOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
index 8af4903418..88d6eaf819 100644
--- a/tensorflow/core/ops/boosted_trees_ops.cc
+++ b/tensorflow/core/ops/boosted_trees_ops.cc
@@ -37,9 +37,10 @@ REGISTER_OP("IsBoostedTreesEnsembleInitialized")
REGISTER_OP("BoostedTreesCalculateBestGainsPerFeature")
.Input("node_id_range: int32")
.Input("stats_summary_list: num_features * float32")
- .Attr("l1: float")
- .Attr("l2: float")
- .Attr("tree_complexity: float")
+ .Input("l1: float")
+ .Input("l2: float")
+ .Input("tree_complexity: float")
+ .Input("min_node_weight: float")
.Attr("max_splits: int >= 1")
.Attr("num_features: int >= 1") // not passed but populated automatically.
.Output("node_ids_list: num_features * int32")
@@ -51,19 +52,6 @@ REGISTER_OP("BoostedTreesCalculateBestGainsPerFeature")
// Confirms the rank of the inputs and sets the shape of the outputs.
int max_splits;
int num_features;
- float l1, l2, tree_complexity;
- TF_RETURN_IF_ERROR(c->GetAttr("l1", &l1));
- if (l1 < 0) {
- return errors::InvalidArgument("l1 must be non-negative.");
- }
- TF_RETURN_IF_ERROR(c->GetAttr("l2", &l2));
- if (l2 < 0) {
- return errors::InvalidArgument("l2 must be non-negative.");
- }
- TF_RETURN_IF_ERROR(c->GetAttr("tree_complexity", &tree_complexity));
- if (tree_complexity < 0) {
- return errors::InvalidArgument("Tree complexity must be non-negative.");
- }
TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits));
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
shape_inference::ShapeHandle node_id_range_shape;
@@ -83,6 +71,12 @@ REGISTER_OP("BoostedTreesCalculateBestGainsPerFeature")
TF_RETURN_IF_ERROR(
c->Merge(summary_shape_base, summary_shape, &unused_shape));
}
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(num_features + 1), 0, &unused_shape));
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(num_features + 2), 0, &unused_shape));
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(num_features + 3), 0, &unused_shape));
// Sets the output lists.
std::vector<shape_inference::ShapeHandle> output_shapes_vec(
num_features, c->MakeShape({-1}));
@@ -185,9 +179,8 @@ REGISTER_OP("BoostedTreesMakeStatsSummary")
REGISTER_OP("BoostedTreesPredict")
.Input("tree_ensemble_handle: resource")
.Input("bucketized_features: num_bucketized_features * int32")
- .Attr("num_bucketized_features: int >= 1")
+ .Attr("num_bucketized_features: int >= 1") // Inferred.
.Attr("logits_dimension: int")
- .Attr("max_depth: int >= 1")
.Output("logits: float")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle feature_shape;
@@ -229,7 +222,6 @@ REGISTER_OP("BoostedTreesTrainingPredict")
.Input("bucketized_features: num_bucketized_features * int32")
.Attr("num_bucketized_features: int >= 1")
.Attr("logits_dimension: int")
- .Attr("max_depth: int >= 1")
.Output("partial_logits: float")
.Output("tree_ids: int32")
.Output("node_ids: int32")
@@ -239,9 +231,6 @@ REGISTER_OP("BoostedTreesTrainingPredict")
TF_RETURN_IF_ERROR(
c->GetAttr("num_bucketized_features", &num_bucketized_features));
- int max_depth;
- TF_RETURN_IF_ERROR(c->GetAttr("max_depth", &max_depth));
-
shape_inference::ShapeHandle unused_input;
for (int i = 0; i < num_bucketized_features; ++i) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 3), 1, &feature_shape));
@@ -273,8 +262,8 @@ REGISTER_OP("BoostedTreesUpdateEnsemble")
.Input("thresholds: num_features * int32")
.Input("left_node_contribs: num_features * float")
.Input("right_node_contribs: num_features * float")
- .Attr("max_depth: int >= 1")
- .Attr("learning_rate: float")
+ .Input("max_depth: int32")
+ .Input("learning_rate: float")
.Attr("pruning_mode: int >=0")
.Attr("num_features: int >= 0") // Inferred.
.SetShapeFn([](shape_inference::InferenceContext* c) {
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index a45a95ae09..5bd37efac8 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -10855,6 +10855,22 @@ op {
type: DT_FLOAT
number_attr: "num_features"
}
+ input_arg {
+ name: "l1"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "l2"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "tree_complexity"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "min_node_weight"
+ type: DT_FLOAT
+ }
output_arg {
name: "node_ids_list"
type: DT_INT32
@@ -10881,18 +10897,6 @@ op {
number_attr: "num_features"
}
attr {
- name: "l1"
- type: "float"
- }
- attr {
- name: "l2"
- type: "float"
- }
- attr {
- name: "tree_complexity"
- type: "float"
- }
- attr {
name: "max_splits"
type: "int"
has_minimum: true
@@ -11054,12 +11058,6 @@ op {
name: "logits_dimension"
type: "int"
}
- attr {
- name: "max_depth"
- type: "int"
- has_minimum: true
- minimum: 1
- }
is_stateful: true
}
op {
@@ -11119,12 +11117,6 @@ op {
name: "logits_dimension"
type: "int"
}
- attr {
- name: "max_depth"
- type: "int"
- has_minimum: true
- minimum: 1
- }
is_stateful: true
}
op {
@@ -11162,15 +11154,13 @@ op {
type: DT_FLOAT
number_attr: "num_features"
}
- attr {
+ input_arg {
name: "max_depth"
- type: "int"
- has_minimum: true
- minimum: 1
+ type: DT_INT32
}
- attr {
+ input_arg {
name: "learning_rate"
- type: "float"
+ type: DT_FLOAT
}
attr {
name: "pruning_mode"
@@ -11762,6 +11752,50 @@ op {
}
}
op {
+ name: "ClipByValue"
+ input_arg {
+ name: "t"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "clip_value_min"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "clip_value_max"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+}
+op {
name: "CloseSummaryWriter"
input_arg {
name: "writer"
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index afb3dab3fe..1659adc9fe 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -4009,6 +4009,18 @@ op {
type: DT_FLOAT
number_attr: "num_features"
}
+ input_arg {
+ name: "l1"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "l2"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "tree_complexity"
+ type: DT_FLOAT
+ }
output_arg {
name: "node_ids_list"
type: DT_INT32
@@ -4035,18 +4047,6 @@ op {
number_attr: "num_features"
}
attr {
- name: "l1"
- type: "float"
- }
- attr {
- name: "l2"
- type: "float"
- }
- attr {
- name: "tree_complexity"
- type: "float"
- }
- attr {
name: "max_splits"
type: "int"
has_minimum: true
@@ -4208,12 +4208,6 @@ op {
name: "logits_dimension"
type: "int"
}
- attr {
- name: "max_depth"
- type: "int"
- has_minimum: true
- minimum: 1
- }
is_stateful: true
}
op {
@@ -4273,12 +4267,6 @@ op {
name: "logits_dimension"
type: "int"
}
- attr {
- name: "max_depth"
- type: "int"
- has_minimum: true
- minimum: 1
- }
is_stateful: true
}
op {
@@ -4316,15 +4304,13 @@ op {
type: DT_FLOAT
number_attr: "num_features"
}
- attr {
+ input_arg {
name: "max_depth"
- type: "int"
- has_minimum: true
- minimum: 1
+ type: DT_INT32
}
- attr {
+ input_arg {
name: "learning_rate"
- type: "float"
+ type: DT_FLOAT
}
attr {
name: "pruning_mode"
@@ -4728,6 +4714,50 @@ op {
}
}
op {
+ name: "ClipByValue"
+ input_arg {
+ name: "t"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "clip_value_min"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "clip_value_max"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+}
+op {
name: "CloseSummaryWriter"
input_arg {
name: "writer"
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 4cfa25bf66..44356e3438 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -1,7 +1,6 @@
# Platform-specific build configurations.
load("@protobuf_archive//:protobuf.bzl", "proto_gen")
-load("@protobuf_archive//:protobuf.bzl", "py_proto_library")
load("//tensorflow:tensorflow.bzl", "if_not_mobile")
load("//tensorflow:tensorflow.bzl", "if_windows")
load("//tensorflow:tensorflow.bzl", "if_not_windows")
@@ -110,6 +109,12 @@ def _proto_cc_srcs(srcs, use_grpc_plugin=False):
ret += [s[:-len(".proto")] + ".grpc.pb.cc" for s in srcs]
return ret
+def _proto_py_outs(srcs, use_grpc_plugin=False):
+ ret = [s[:-len(".proto")] + "_pb2.py" for s in srcs]
+ if use_grpc_plugin:
+ ret += [s[:-len(".proto")] + "_pb2_grpc.py" for s in srcs]
+ return ret
+
# Re-defined protocol buffer rule to allow building "header only" protocol
# buffers, to avoid duplicate registrations. Also allows non-iterable cc_libs
# containing select() statements.
@@ -217,6 +222,80 @@ def cc_proto_library(
hdrs=gen_hdrs,
**kargs)
+# Re-defined protocol buffer rule to bring in the change introduced in commit
+# https://github.com/google/protobuf/commit/294b5758c373cbab4b72f35f4cb62dc1d8332b68
+# which was not part of a stable protobuf release as of April 2018.
+# TODO(jsimsa): Remove this once the protobuf dependency version is updated
+# to include the above commit.
+def py_proto_library(
+ name,
+ srcs=[],
+ deps=[],
+ py_libs=[],
+ py_extra_srcs=[],
+ include=None,
+ default_runtime="@protobuf_archive//:protobuf_python",
+ protoc="@protobuf_archive//:protoc",
+ use_grpc_plugin=False,
+ **kargs):
+ """Bazel rule to create a Python protobuf library from proto source files
+
+ NOTE: the rule is only an internal workaround to generate protos. The
+ interface may change and the rule may be removed when bazel has introduced
+ the native rule.
+
+ Args:
+ name: the name of the py_proto_library.
+ srcs: the .proto files of the py_proto_library.
+ deps: a list of dependency labels; must be py_proto_library.
+ py_libs: a list of other py_library targets depended on by the generated
+ py_library.
+ py_extra_srcs: extra source files that will be added to the output
+ py_library. This attribute is used for internal bootstrapping.
+ include: a string indicating the include path of the .proto files.
+ default_runtime: the implicit default runtime on which the generated
+ py_library target will depend.
+ protoc: the label of the protocol compiler to generate the sources.
+ use_grpc_plugin: a flag to indicate whether to call the gRPC Python plugin
+ when processing the proto files.
+ **kargs: other keyword arguments that are passed to py_library.
+ """
+ outs = _proto_py_outs(srcs, use_grpc_plugin)
+
+ includes = []
+ if include != None:
+ includes = [include]
+
+ grpc_python_plugin = None
+ if use_grpc_plugin:
+ grpc_python_plugin = "//external:grpc_python_plugin"
+ # Note: Generated gRPC code depends on the Python grpc module. This
+ # dependency is not explicitly listed in py_libs. Instead, the host system
+ # is assumed to have grpc installed.
+
+ proto_gen(
+ name=name + "_genproto",
+ srcs=srcs,
+ deps=[s + "_genproto" for s in deps],
+ includes=includes,
+ protoc=protoc,
+ gen_py=1,
+ outs=outs,
+ visibility=["//visibility:public"],
+ plugin=grpc_python_plugin,
+ plugin_language="grpc"
+ )
+
+ if default_runtime and not default_runtime in py_libs + deps:
+ py_libs = py_libs + [default_runtime]
+
+ native.py_library(
+ name=name,
+ srcs=outs+py_extra_srcs,
+ deps=py_libs+deps,
+ imports=includes,
+ **kargs)
+
def tf_proto_library_cc(name, srcs = [], has_services = None,
protodeps = [],
visibility = [], testonly = 0,
@@ -261,8 +340,7 @@ def tf_proto_library_cc(name, srcs = [], has_services = None,
)
def tf_proto_library_py(name, srcs=[], protodeps=[], deps=[], visibility=[],
- testonly=0,
- srcs_version="PY2AND3"):
+ testonly=0, srcs_version="PY2AND3", use_grpc_plugin=False):
py_proto_library(
name = name + "_py",
srcs = srcs,
@@ -272,6 +350,7 @@ def tf_proto_library_py(name, srcs=[], protodeps=[], deps=[], visibility=[],
default_runtime = "@protobuf_archive//:protobuf_python",
visibility = visibility,
testonly = testonly,
+ use_grpc_plugin = use_grpc_plugin,
)
def tf_jspb_proto_library(**kwargs):
@@ -310,6 +389,7 @@ def tf_proto_library(name, srcs = [], has_services = None,
srcs_version = "PY2AND3",
testonly = testonly,
visibility = visibility,
+ use_grpc_plugin = has_services,
)
def tf_additional_lib_hdrs(exclude = []):
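
With use_grpc_plugin now plumbed from tf_proto_library through tf_proto_library_py, a proto target that declares services also gets _pb2_grpc.py stubs. A hypothetical BUILD usage sketch (Starlark, which shares Python syntax; the load label and file names are illustrative):

    load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")

    # has_services = 1 is forwarded to py_proto_library as use_grpc_plugin, so
    # the Python target bundles both foo_service_pb2.py and the gRPC stub
    # foo_service_pb2_grpc.py emitted by //external:grpc_python_plugin.
    tf_proto_library(
        name = "foo_service_proto",
        srcs = ["foo_service.proto"],
        has_services = 1,
    )
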
diff --git a/tensorflow/core/platform/default/fingerprint.h b/tensorflow/core/platform/default/fingerprint.h
index 71f9951e53..f901befc16 100644
--- a/tensorflow/core/platform/default/fingerprint.h
+++ b/tensorflow/core/platform/default/fingerprint.h
@@ -18,14 +18,16 @@ limitations under the License.
#include <farmhash.h>
+#include "tensorflow/core/lib/core/stringpiece.h"
+
namespace tensorflow {
-inline uint64 Fingerprint64(const string& s) {
- return ::util::Fingerprint64(s);
+inline uint64 Fingerprint64(StringPiece s) {
+ return ::util::Fingerprint64(s.data(), s.size());
}
-inline Fprint128 Fingerprint128(const string& s) {
- const auto fingerprint = ::util::Fingerprint128(s);
+inline Fprint128 Fingerprint128(StringPiece s) {
+ const auto fingerprint = ::util::Fingerprint128(s.data(), s.size());
return {::util::Uint128Low64(fingerprint),
::util::Uint128High64(fingerprint)};
}
diff --git a/tensorflow/core/platform/fingerprint.h b/tensorflow/core/platform/fingerprint.h
index fd0347a10b..b47dcdedd7 100644
--- a/tensorflow/core/platform/fingerprint.h
+++ b/tensorflow/core/platform/fingerprint.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_PLATFORM_FINGERPRINT_H_
#define TENSORFLOW_CORE_PLATFORM_FINGERPRINT_H_
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -36,15 +37,12 @@ struct Fprint128Hasher {
}
};
-// TODO(sibyl-Mooth6ku): Change these to accept StringPiece (or make them templated
-// on any kind of byte array?).
-
// This is a portable fingerprint interface for strings whose output will
// never change. However, it is not suitable for cryptography.
-uint64 Fingerprint64(const string& s);
+uint64 Fingerprint64(StringPiece s);
// 128-bit variant of Fingerprint64 above (same properties and caveats apply).
-Fprint128 Fingerprint128(const string& s);
+Fprint128 Fingerprint128(StringPiece s);
namespace internal {
// Mixes some of the bits that got propagated to the high bits back into the
diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD
index 3d3203cdaa..af034bdd7d 100644
--- a/tensorflow/core/profiler/BUILD
+++ b/tensorflow/core/profiler/BUILD
@@ -1,6 +1,4 @@
-package(
- default_visibility = ["//visibility:public"],
-)
+package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 3963d5faa7..8373a1219d 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -1417,12 +1417,12 @@ Applies a reduction function to an array.
| `dimensions` | `int64` array | unordered array of dimensions to |
: : : reduce :
-Conceptually, this operation reduces one or more dimensions in the input array
-into scalars. The rank of the result array is `rank(operand) - len(dimensions)`.
-`init_value` is the initial value used for every reduction and may also be
-inserted anywhere during computation if the back-end chooses to do so. So in
-most cases `init_value` should be an identity of the reduction function (for
-example, 0 for addition).
+This operation reduces one or more dimensions of the input array into scalars.
+The rank of the returned array is `rank(operand) - len(dimensions)`.
+`init_value` is the initial value used for every reduction and may be inserted
+anywhere during computation by the back-end. In most cases, `init_value` is an
+identity of the reduction function (for example, 0 for addition). The applied
+`computation` is always passed the `init_value` on the left-hand side.
The evaluation order of the reduction function is arbitrary and may be
non-deterministic. Therefore, the reduction function should not be overly
@@ -1442,8 +1442,7 @@ could be computed as
but there are also many other possibilities, e.g.
-`f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(13,
-init_value))))`
+`f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(init_value, 13))))`
The following is a rough pseudo-code example of how reduction could be
implemented, using summation as the reduction computation with an initial value
@@ -1561,7 +1560,9 @@ See also
Applies a reduction function to all elements in each window of the input
multi-dimensional array, producing an output multi-dimensional array with the
same number of elements as the number of valid positions of the window. A
-pooling layer can be expressed as a `ReduceWindow`.
+pooling layer can be expressed as a `ReduceWindow`. Similar to
+[`Reduce`](#reduce), the applied `computation` is always passed the `init_value`
+on the left-hand side.
<b> `ReduceWindow(operand, init_value, computation, window_dimensions,
window_strides, padding)` </b>
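
The left-hand-side guarantee above matters for non-commutative computations: whatever grouping of elements the back-end picks, each application receives the accumulator (seeded with `init_value`) as its first argument. A rough plain-Python sketch of one legal evaluation order:

    def reduce_1d(operand, init_value, computation):
        # One of many orders XLA may choose; the grouping is unspecified, but
        # the computation is always called with the accumulator on the left.
        result = init_value
        for x in operand:
            result = computation(result, x)
        return result

    print(reduce_1d([10, 11, 12, 13], 0, lambda acc, x: acc + x))  # 46
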
diff --git a/tensorflow/docs_src/programmers_guide/datasets.md b/tensorflow/docs_src/programmers_guide/datasets.md
index 9ccdbde627..67be41b1a6 100644
--- a/tensorflow/docs_src/programmers_guide/datasets.md
+++ b/tensorflow/docs_src/programmers_guide/datasets.md
@@ -540,7 +540,7 @@ batched into a fixed size.
# to a fixed shape.
def _parse_function(filename, label):
image_string = tf.read_file(filename)
- image_decoded = tf.image.decode_image(image_string)
+ image_decoded = tf.image.decode_jpeg(image_string)
image_resized = tf.image.resize_images(image_decoded, [28, 28])
return image_resized, label
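
The decode_image → decode_jpeg switch is not cosmetic: tf.image.decode_image also accepts GIFs and therefore returns a tensor of statically unknown rank, which tf.image.resize_images rejects, while decode_jpeg always yields a rank-3 image. A sketch of the fixed parse function (the channels=3 argument is an illustrative addition, not part of the change):

    import tensorflow as tf

    def _parse_function(filename, label):
        image_string = tf.read_file(filename)
        # decode_jpeg pins the output to rank 3, so resize_images can infer
        # the shape it needs; decode_image makes no such static-rank promise.
        image_decoded = tf.image.decode_jpeg(image_string, channels=3)
        image_resized = tf.image.resize_images(image_decoded, [28, 28])
        return image_resized, label
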
diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml
index c99d04869a..9c1601753b 100644
--- a/tensorflow/java/maven/libtensorflow/pom.xml
+++ b/tensorflow/java/maven/libtensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.7.0</version>
+ <version>1.8.0-rc0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml
index 4561c2c8ad..3d013e12b0 100644
--- a/tensorflow/java/maven/libtensorflow_jni/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.7.0</version>
+ <version>1.8.0-rc0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
index 82a2b8e769..40e44af1f5 100644
--- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.7.0</version>
+ <version>1.8.0-rc0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni_gpu</artifactId>
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index 4c1ec0cc80..82bfd0c73a 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.7.0</version>
+ <version>1.8.0-rc0</version>
<packaging>pom</packaging>
<url>https://www.tensorflow.org</url>
diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml
index fcd8236bad..0a2775a500 100644
--- a/tensorflow/java/maven/proto/pom.xml
+++ b/tensorflow/java/maven/proto/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.7.0</version>
+ <version>1.8.0-rc0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>proto</artifactId>
diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml
index 241581713a..61961432a7 100644
--- a/tensorflow/java/maven/tensorflow/pom.xml
+++ b/tensorflow/java/maven/tensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.7.0</version>
+ <version>1.8.0-rc0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>tensorflow</artifactId>
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index a683c8cfa6..569d3eb2ce 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -4,16 +4,14 @@
# Public targets:
# ":platform" - Low-level and platform-specific Python code.
-package(
- default_visibility = [
- "//engedu/ml/tf_from_scratch:__pkg__",
- "//tensorflow:internal",
- "//tensorflow/contrib/lite/toco/python:__pkg__",
- "//tensorflow_models:__subpackages__",
- # TODO(aselle): to pass open source test.
- "//bazel_pip/tensorflow/contrib/lite/toco/python:__pkg__",
- ],
-)
+package(default_visibility = [
+ "//engedu/ml/tf_from_scratch:__pkg__",
+ "//tensorflow:internal",
+ "//tensorflow/contrib/lite/toco/python:__pkg__",
+ "//tensorflow_models:__subpackages__",
+ # TODO(aselle): to pass open source test.
+ "//bazel_pip/tensorflow/contrib/lite/toco/python:__pkg__",
+])
licenses(["notice"]) # Apache 2.0
@@ -1795,6 +1793,16 @@ py_library(
)
py_library(
+ name = "cudnn_rnn_grad",
+ srcs = ["ops/cudnn_rnn_grad.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":framework_for_generated_wrappers",
+ "//tensorflow/python:cudnn_rnn_ops_gen",
+ ],
+)
+
+py_library(
name = "data_flow_grad",
srcs = ["ops/data_flow_grad.py"],
srcs_version = "PY2AND3",
@@ -2467,6 +2475,7 @@ py_library(
":clip_ops",
":confusion_matrix",
":control_flow_ops",
+ ":cudnn_rnn_grad",
":data_flow_grad",
":data_flow_ops",
":framework_for_generated_wrappers",
diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
index 6442eb9ff5..f7d7d085c9 100644
--- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
@@ -69,6 +69,54 @@ class ListFilesDatasetOpTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(itr.get_next())
+ def testSimpleDirectoryNotShuffled(self):
+ filenames = ['b', 'c', 'a']
+ self._touchTempFiles(filenames)
+
+ dataset = dataset_ops.Dataset.list_files(
+ path.join(self.tmp_dir, '*'), shuffle=False)
+ with self.test_session() as sess:
+ itr = dataset.make_one_shot_iterator()
+ next_element = itr.get_next()
+
+ for filename in sorted(filenames):
+ self.assertEqual(compat.as_bytes(path.join(self.tmp_dir, filename)),
+ sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(itr.get_next())
+
+ def testFixedSeedResultsInRepeatableOrder(self):
+ filenames = ['a', 'b', 'c']
+ self._touchTempFiles(filenames)
+
+ dataset = dataset_ops.Dataset.list_files(
+ path.join(self.tmp_dir, '*'), shuffle=True, seed=37)
+ with self.test_session() as sess:
+ itr = dataset.make_initializable_iterator()
+ next_element = itr.get_next()
+
+ full_filenames = [compat.as_bytes(path.join(self.tmp_dir, filename))
+ for filename in filenames]
+
+ all_produced_filenames = []
+ for _ in range(3):
+ produced_filenames = []
+ sess.run(itr.initializer)
+ try:
+ while True:
+ produced_filenames.append(sess.run(next_element))
+ except errors.OutOfRangeError:
+ pass
+ all_produced_filenames.append(produced_filenames)
+
+ # Each run should produce the same set of filenames, though possibly in a
+ # different order from `full_filenames`.
+ self.assertItemsEqual(full_filenames, all_produced_filenames[0])
+ # However, the different runs should produce filenames in the same order
+ # as each other.
+ self.assertEqual(all_produced_filenames[0], all_produced_filenames[1])
+ self.assertEqual(all_produced_filenames[0], all_produced_filenames[2])
+
def testEmptyDirectoryInitializer(self):
filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
dataset = dataset_ops.Dataset.list_files(filename_placeholder)
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 406f172e59..bd9686f692 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -571,9 +571,13 @@ class Dataset(object):
return PrefetchDataset(self, buffer_size)
@staticmethod
- def list_files(file_pattern, shuffle=None):
+ def list_files(file_pattern, shuffle=None, seed=None):
"""A dataset of all files matching a pattern.
+ NOTE: By default, this method returns filenames in a non-deterministic,
+ randomly shuffled order. Pass a `seed` or `shuffle=False` to get results
+ in a deterministic order.
+
Example:
If we had the following files on our filesystem:
- /path/to/dir/a.txt
@@ -584,20 +588,18 @@ class Dataset(object):
- /path/to/dir/b.py
- /path/to/dir/c.py
- NOTE: The order of the file names returned can be non-deterministic even
- when `shuffle` is `False`.
-
Args:
file_pattern: A string or scalar string `tf.Tensor`, representing
the filename pattern that will be matched.
shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
Defaults to `True`.
+ seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ random seed that will be used to create the distribution. See
+ @{tf.set_random_seed} for behavior.
Returns:
Dataset: A `Dataset` of strings corresponding to file names.
"""
- # TODO(b/73959787): Add a `seed` argument and make the `shuffle=False`
- # behavior deterministic (e.g. by sorting the filenames).
if shuffle is None:
shuffle = True
matching_files = gen_io_ops.matching_files(file_pattern)
@@ -607,7 +609,7 @@ class Dataset(object):
# list of files might be empty.
buffer_size = math_ops.maximum(
array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
- dataset = dataset.shuffle(buffer_size)
+ dataset = dataset.shuffle(buffer_size, seed=seed)
return dataset
def repeat(self, count=None):
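
Taken together with the MatchingFilesOp sorting change, `list_files` now offers three distinct behaviors, sketched below with an illustrative pattern:

    import tensorflow as tf

    # Deterministic: filenames arrive in lexicographic order.
    ds_sorted = tf.data.Dataset.list_files("/path/to/dir/*.py", shuffle=False)

    # Repeatable: a fixed seed yields the same shuffled order every time the
    # iterator is re-initialized (see testFixedSeedResultsInRepeatableOrder).
    ds_seeded = tf.data.Dataset.list_files("/path/to/dir/*.py", seed=37)

    # Default: shuffled with a non-deterministic seed.
    ds_random = tf.data.Dataset.list_files("/path/to/dir/*.py")
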
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index a34405c702..7bf4447491 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -7,7 +7,6 @@ package(
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "py_test")
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_library(
name = "estimator_py",
@@ -25,7 +24,6 @@ py_library(
":linear",
":model_fn",
":parsing_utils",
- ":replicate_model_fn",
":run_config",
":training",
"//tensorflow/python:util",
@@ -909,68 +907,3 @@ py_test(
"//tensorflow/python:training",
],
)
-
-py_library(
- name = "replicate_model_fn",
- srcs = [
- "replicate_model_fn.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":export_output",
- ":model_fn",
- ":util",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:device",
- "//tensorflow/python:device_lib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:sparse_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/ops/losses",
- "@six_archive//:six",
- ],
-)
-
-cuda_py_test(
- name = "replicate_model_fn_test",
- size = "medium",
- srcs = ["replicate_model_fn_test.py"],
- additional_deps = [
- "//tensorflow/python/estimator",
- ":dnn",
- ":export_export",
- ":export_output",
- ":model_fn",
- ":numpy_io",
- ":optimizers",
- ":prediction_keys",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
- "//tensorflow/python/saved_model:signature_constants",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:metrics",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- ":replicate_model_fn",
- ],
- tags = [
- "multi_gpu",
- "noasan", # flaky time outs
- "notsan", # flaky
- ],
-)
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 0ecc8c7089..085dace1b3 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
@@ -40,14 +41,42 @@ from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util.tf_export import tf_export
-_TreeHParams = collections.namedtuple(
- 'TreeHParams',
- ['n_trees', 'max_depth', 'learning_rate', 'l1', 'l2', 'tree_complexity'])
+# TODO(nponomareva): Reveal pruning params here.
+_TreeHParams = collections.namedtuple('TreeHParams', [
+ 'n_trees', 'max_depth', 'learning_rate', 'l1', 'l2', 'tree_complexity',
+ 'min_node_weight'
+])
_HOLD_FOR_MULTI_CLASS_SUPPORT = object()
_HOLD_FOR_MULTI_DIM_SUPPORT = object()
+def _get_max_buckets(feature_columns):
+ """Gets the maximum number of buckets from feature_columns.
+
+ Args:
+ feature_columns: a list/set of tf.feature_column.
+
+ Returns:
+ max_buckets: the maximum number of buckets among bucketized_columns.
+
+ Raises:
+ ValueError: when unsupported feature_columns are given.
+ """
+ if not feature_columns:
+ raise ValueError('feature_columns must be a non-empty list/set of '
+ 'tf.feature_column.')
+ max_buckets = 1
+ for fc in feature_columns:
+ if isinstance(fc, feature_column_lib._BucketizedColumn): # pylint:disable=protected-access
+      # N boundaries create (N+1) buckets.
+ max_buckets = max(max_buckets, len(fc.boundaries) + 1)
+ else:
+ raise ValueError('For now, only bucketized_column is supported but '
+ 'got: {}'.format(fc))
+ return max_buckets
+
+
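For reference, the bucket arithmetic in `_get_max_buckets` in plain Python (the boundary values are made up):

```python
# Three boundaries split the real line into four intervals:
# (-inf, 0), [0, 1), [1, 2), [2, +inf)
boundaries = [0.0, 1.0, 2.0]
num_buckets = len(boundaries) + 1  # the per-column count the function maxes over
assert num_buckets == 4
```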
def _get_transformed_features(features, feature_columns):
"""Gets the transformed features from features/feature_columns pair.
@@ -57,36 +86,31 @@ def _get_transformed_features(features, feature_columns):
Returns:
result_features: a list of the transformed features, sorted by the name.
- num_buckets: the maximum number of buckets across bucketized_columns.
Raises:
    ValueError: when unsupported features/columns are given.
"""
- num_buckets = 1
# pylint:disable=protected-access
for fc in feature_columns:
- if isinstance(fc, feature_column_lib._BucketizedColumn):
- # N boundaries creates (N+1) buckets.
- num_buckets = max(num_buckets, len(fc.boundaries) + 1)
- else:
+ if not isinstance(fc, feature_column_lib._BucketizedColumn):
raise ValueError('For now, only bucketized_column is supported but '
'got: {}'.format(fc))
- transformed = feature_column_lib._transform_features(features,
- feature_columns)
+ transformed_features = feature_column_lib._transform_features(
+ features, feature_columns)
# pylint:enable=protected-access
result_features = []
- for column in sorted(transformed, key=lambda tc: tc.name):
+ for column in sorted(transformed_features, key=lambda tc: tc.name):
source_name = column.source_column.name
- squeezed_tensor = array_ops.squeeze(transformed[column], axis=1)
+ squeezed_tensor = array_ops.squeeze(transformed_features[column], axis=1)
if len(squeezed_tensor.shape) > 1:
raise ValueError('For now, only supports features equivalent to rank 1 '
'but column `{}` got: {}'.format(
source_name, features[source_name].shape))
result_features.append(squeezed_tensor)
- return result_features, num_buckets
+ return result_features
-def _keep_as_local_variable(tensor, name=None):
+def _local_variable(tensor, name=None):
"""Stores a tensor as a local Variable for faster read."""
return variable_scope.variable(
initial_value=tensor,
@@ -96,6 +120,48 @@ def _keep_as_local_variable(tensor, name=None):
name=name)
+def _cache_transformed_features(features, feature_columns, batch_size):
+  """Transforms and caches features; returns (cached_features, cache_op)."""
+ num_features = len(feature_columns)
+ cached_features = [
+ _local_variable(
+ array_ops.zeros([batch_size], dtype=dtypes.int32),
+ name='cached_feature_{}'.format(i))
+ for i in range(num_features)
+ ]
+ are_features_cached = _local_variable(False, name='are_features_cached')
+
+ def cache_features_and_return():
+    """Caches transformed features.
+
+    The intention is to hide _get_transformed_features() from the graph after
+    the first step by caching its result, since the bucketize operation
+    (inside _get_transformed_features) is expensive.
+
+ Returns:
+ input_feature_list: a list of input features.
+      cache_flip_op: op to add to the graph to make sure the cache update is
+        included in the graph.
+ """
+
+ transformed_features = _get_transformed_features(features, feature_columns)
+ cached = [
+ state_ops.assign(cached_features[i], transformed_features[i])
+ for i in range(num_features)
+ ]
+ # TODO(youngheek): Try other combination of dependencies so that the
+ # function returns a single result, not a tuple.
+ with ops.control_dependencies(cached):
+ cache_flip_op = are_features_cached.assign(True)
+ return cached, cache_flip_op
+
+ input_feature_list, cache_flip_op = control_flow_ops.cond(
+ are_features_cached,
+ lambda: (cached_features, control_flow_ops.no_op()),
+ cache_features_and_return)
+ return input_feature_list, cache_flip_op
+
+
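The function above follows a general cache-on-first-use pattern. A minimal, self-contained sketch of that pattern in TF 1.x graph mode (the names `cached_value` and `compute_fn` are hypothetical; the flag tensor is returned instead of a flip op so both `cond` branches have matching outputs):

```python
import tensorflow as tf

def cached_value(compute_fn, shape, dtype):
  """Runs compute_fn() once, then serves the cached copy (illustrative).

  compute_fn must return a tensor of the given shape and dtype.
  """
  cache = tf.Variable(tf.zeros(shape, dtype), trainable=False,
                      collections=[tf.GraphKeys.LOCAL_VARIABLES])
  is_cached = tf.Variable(False, trainable=False,
                          collections=[tf.GraphKeys.LOCAL_VARIABLES])

  def fill_and_return():
    filled = cache.assign(compute_fn())  # expensive path, first run only
    with tf.control_dependencies([filled]):
      return filled, is_cached.assign(True)

  # Later runs take the cheap branch: a variable read, no recomputation.
  return tf.cond(is_cached,
                 lambda: (cache.read_value(), is_cached.read_value()),
                 fill_and_return)
```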
class _CacheTrainingStatesUsingHashTable(object):
"""Caching logits, etc. using MutableHashTable."""
@@ -184,13 +250,13 @@ class _CacheTrainingStatesUsingVariables(object):
logits_dimension: a constant (int) for the dimension of logits.
"""
self._logits_dimension = logits_dimension
- self._tree_ids = _keep_as_local_variable(
+ self._tree_ids = _local_variable(
array_ops.zeros([batch_size], dtype=dtypes.int32),
name='tree_ids_cache')
- self._node_ids = _keep_as_local_variable(
+ self._node_ids = _local_variable(
array_ops.zeros([batch_size], dtype=dtypes.int32),
name='node_ids_cache')
- self._logits = _keep_as_local_variable(
+ self._logits = _local_variable(
array_ops.zeros([batch_size, logits_dimension], dtype=dtypes.float32),
name='logits_cache')
@@ -288,33 +354,38 @@ def _bt_model_fn(
'When train_in_memory is enabled, input_fn should return the entire '
'dataset as a single batch, and n_batches_per_layer should be set as '
'1.')
+ if (not config.is_chief or config.num_worker_replicas > 1 or
+ config.num_ps_replicas > 0):
+ raise ValueError('train_in_memory is supported only for '
+ 'non-distributed training.')
worker_device = control_flow_ops.no_op().device
  # Maximum number of splits possible in the whole tree = 2^max_depth - 1.
# TODO(youngheek): perhaps storage could be optimized by storing stats with
# the dimension max_splits_per_layer, instead of max_splits (for the entire
# tree).
max_splits = (1 << tree_hparams.max_depth) - 1
+ max_buckets = _get_max_buckets(feature_columns)
+ train_op = []
with ops.name_scope(name) as name:
# Prepare.
global_step = training_util.get_or_create_global_step()
- input_feature_list, num_buckets = _get_transformed_features(
- features, feature_columns)
- if train_in_memory and mode == model_fn.ModeKeys.TRAIN:
- input_feature_list = [
- _keep_as_local_variable(feature) for feature in input_feature_list
- ]
- num_features = len(input_feature_list)
-
- cache = None
- if mode == model_fn.ModeKeys.TRAIN:
- if train_in_memory and is_single_machine: # maybe just train_in_memory?
- batch_size = array_ops.shape(input_feature_list[0])[0]
- cache = _CacheTrainingStatesUsingVariables(batch_size,
- head.logits_dimension)
- elif example_id_column_name:
+ num_features = len(feature_columns)
+ # Extract input features and set up cache for training.
+ training_state_cache = None
+ if mode == model_fn.ModeKeys.TRAIN and train_in_memory:
+ # cache transformed features as well for in-memory training.
+ batch_size = array_ops.shape(labels)[0]
+ input_feature_list, input_cache_op = _cache_transformed_features(
+ features, feature_columns, batch_size)
+ train_op.append(input_cache_op)
+ training_state_cache = _CacheTrainingStatesUsingVariables(
+ batch_size, head.logits_dimension)
+ else:
+ input_feature_list = _get_transformed_features(features, feature_columns)
+ if mode == model_fn.ModeKeys.TRAIN and example_id_column_name:
example_ids = features[example_id_column_name]
- cache = _CacheTrainingStatesUsingHashTable(example_ids,
- head.logits_dimension)
+ training_state_cache = _CacheTrainingStatesUsingHashTable(
+ example_ids, head.logits_dimension)
# Create Ensemble resources.
tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
@@ -325,8 +396,7 @@ def _bt_model_fn(
# so no local copy is needed; using tree_ensemble directly.
tree_ensemble_handle=tree_ensemble.resource_handle,
bucketized_features=input_feature_list,
- logits_dimension=head.logits_dimension,
- max_depth=tree_hparams.max_depth)
+ logits_dimension=head.logits_dimension)
else:
if is_single_machine:
local_tree_ensemble = tree_ensemble
@@ -339,11 +409,12 @@ def _bt_model_fn(
# TODO(soroush): Do partial updates if this becomes a bottleneck.
ensemble_reload = local_tree_ensemble.deserialize(
*tree_ensemble.serialize())
- if cache:
- cached_tree_ids, cached_node_ids, cached_logits = cache.lookup()
+ if training_state_cache:
+ cached_tree_ids, cached_node_ids, cached_logits = (
+ training_state_cache.lookup())
else:
# Always start from the beginning when no cache is set up.
- batch_size = array_ops.shape(input_feature_list[0])[0]
+ batch_size = array_ops.shape(labels)[0]
cached_tree_ids, cached_node_ids, cached_logits = (
array_ops.zeros([batch_size], dtype=dtypes.int32),
array_ops.zeros([batch_size], dtype=dtypes.int32),
@@ -361,16 +432,14 @@ def _bt_model_fn(
cached_tree_ids=cached_tree_ids,
cached_node_ids=cached_node_ids,
bucketized_features=input_feature_list,
- logits_dimension=head.logits_dimension,
- max_depth=tree_hparams.max_depth)
+ logits_dimension=head.logits_dimension)
logits = cached_logits + partial_logits
# Create training graph.
def _train_op_fn(loss):
"""Run one training iteration."""
- train_op = []
- if cache:
- train_op.append(cache.insert(tree_ids, node_ids, logits))
+ if training_state_cache:
+ train_op.append(training_state_cache.insert(tree_ids, node_ids, logits))
if closed_form_grad_and_hess_fn:
gradients, hessians = closed_form_grad_and_hess_fn(logits, labels)
else:
@@ -385,7 +454,7 @@ def _bt_model_fn(
hessians=hessians,
bucketized_features_list=[input_feature_list[f]],
max_splits=max_splits,
- num_buckets=num_buckets),
+ num_buckets=max_buckets),
axis=0) for f in range(num_features)
]
@@ -399,6 +468,7 @@ def _bt_model_fn(
l1=tree_hparams.l1,
l2=tree_hparams.l2,
tree_complexity=tree_hparams.tree_complexity,
+ min_node_weight=tree_hparams.min_node_weight,
max_splits=max_splits))
grow_op = boosted_trees_ops.update_ensemble(
# Confirm if local_tree_ensemble or tree_ensemble should be used.
@@ -421,7 +491,7 @@ def _bt_model_fn(
summary_accumulator = data_flow_ops.ConditionalAccumulator(
dtype=dtypes.float32,
# The stats consist of gradients and hessians (the last dimension).
- shape=[num_features, max_splits, num_buckets, 2],
+ shape=[num_features, max_splits, max_buckets, 2],
shared_name='stats_summary_accumulator')
apply_grad = summary_accumulator.apply_grad(
array_ops.stack(stats_summary_list, axis=0), stamp_token)
@@ -517,21 +587,21 @@ def _create_regression_head(label_dimension, weight_column=None):
class BoostedTreesClassifier(estimator.Estimator):
"""A Classifier for Tensorflow Boosted Trees models."""
- def __init__(
- self,
- feature_columns,
- n_batches_per_layer,
- model_dir=None,
- n_classes=_HOLD_FOR_MULTI_CLASS_SUPPORT,
- weight_column=None,
- label_vocabulary=None,
- n_trees=100,
- max_depth=6,
- learning_rate=0.1,
- l1_regularization=0.,
- l2_regularization=0.,
- tree_complexity=0.,
- config=None):
+ def __init__(self,
+ feature_columns,
+ n_batches_per_layer,
+ model_dir=None,
+ n_classes=_HOLD_FOR_MULTI_CLASS_SUPPORT,
+ weight_column=None,
+ label_vocabulary=None,
+ n_trees=100,
+ max_depth=6,
+ learning_rate=0.1,
+ l1_regularization=0.,
+ l2_regularization=0.,
+ tree_complexity=0.,
+ min_node_weight=0.,
+ config=None):
"""Initializes a `BoostedTreesClassifier` instance.
Example:
@@ -595,6 +665,9 @@ class BoostedTreesClassifier(estimator.Estimator):
l2_regularization: regularization multiplier applied to the square weights
of the tree leafs.
tree_complexity: regularization factor to penalize trees with more leaves.
+    min_node_weight: minimum hessian a node must have for a split to be
+      considered. The value will be compared with
+      sum(leaf_hessian)/(batch_size * n_batches_per_layer).
config: `RunConfig` object to configure the runtime settings.
Raises:
@@ -608,9 +681,9 @@ class BoostedTreesClassifier(estimator.Estimator):
n_classes, weight_column, label_vocabulary=label_vocabulary)
# HParams for the model.
- tree_hparams = _TreeHParams(
- n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity)
+ tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate,
+ l1_regularization, l2_regularization,
+ tree_complexity, min_node_weight)
def _model_fn(features, labels, mode, config):
return _bt_model_fn( # pylint: disable=protected-access
@@ -632,20 +705,20 @@ class BoostedTreesClassifier(estimator.Estimator):
class BoostedTreesRegressor(estimator.Estimator):
"""A Regressor for Tensorflow Boosted Trees models."""
- def __init__(
- self,
- feature_columns,
- n_batches_per_layer,
- model_dir=None,
- label_dimension=_HOLD_FOR_MULTI_DIM_SUPPORT,
- weight_column=None,
- n_trees=100,
- max_depth=6,
- learning_rate=0.1,
- l1_regularization=0.,
- l2_regularization=0.,
- tree_complexity=0.,
- config=None):
+ def __init__(self,
+ feature_columns,
+ n_batches_per_layer,
+ model_dir=None,
+ label_dimension=_HOLD_FOR_MULTI_DIM_SUPPORT,
+ weight_column=None,
+ n_trees=100,
+ max_depth=6,
+ learning_rate=0.1,
+ l1_regularization=0.,
+ l2_regularization=0.,
+ tree_complexity=0.,
+ min_node_weight=0.,
+ config=None):
"""Initializes a `BoostedTreesRegressor` instance.
Example:
@@ -702,6 +775,9 @@ class BoostedTreesRegressor(estimator.Estimator):
l2_regularization: regularization multiplier applied to the square weights
of the tree leafs.
tree_complexity: regularization factor to penalize trees with more leaves.
+    min_node_weight: minimum hessian a node must have for a split to be
+      considered. The value will be compared with
+      sum(leaf_hessian)/(batch_size * n_batches_per_layer).
config: `RunConfig` object to configure the runtime settings.
Raises:
@@ -714,9 +790,9 @@ class BoostedTreesRegressor(estimator.Estimator):
head = _create_regression_head(label_dimension, weight_column)
# HParams for the model.
- tree_hparams = _TreeHParams(
- n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity)
+ tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate,
+ l1_regularization, l2_regularization,
+ tree_complexity, min_node_weight)
def _model_fn(features, labels, mode, config):
return _bt_model_fn( # pylint: disable=protected-access
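For context, a hypothetical construction after this change (assuming the canned estimator is exported as `tf.estimator.BoostedTreesClassifier`; all values are illustrative):

```python
import tensorflow as tf

# One bucketized column with 3 boundaries -> 4 buckets (see _get_max_buckets).
age_buckets = tf.feature_column.bucketized_column(
    tf.feature_column.numeric_column('age'),
    boundaries=[30.0, 40.0, 50.0])

# min_node_weight is compared against
# sum(leaf_hessian) / (batch_size * n_batches_per_layer); 0.0 (the default)
# disables the constraint.
classifier = tf.estimator.BoostedTreesClassifier(
    feature_columns=[age_buckets],
    n_batches_per_layer=1,
    n_trees=50,
    max_depth=4,
    min_node_weight=0.01)
```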
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 7823ef8410..c8c52d3bc6 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator import run_config
from tensorflow.python.estimator.canned import boosted_trees
@@ -58,13 +59,32 @@ def _make_train_input_fn(is_classification):
"""Makes train input_fn for classification/regression."""
def _input_fn():
- features = dict(FEATURES_DICT)
- features[EXAMPLE_ID_COLUMN] = constant_op.constant(EXAMPLE_IDS)
- if is_classification:
- labels = CLASSIFICATION_LABELS
+ features_dict = dict(FEATURES_DICT)
+ features_dict[EXAMPLE_ID_COLUMN] = constant_op.constant(EXAMPLE_IDS)
+ labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS
+ return features_dict, labels
+
+ return _input_fn
+
+
+def _make_train_input_fn_dataset(is_classification, batch=None, repeat=None):
+ """Makes input_fn using Dataset."""
+
+ def _input_fn():
+ features_dict = dict(FEATURES_DICT)
+ features_dict[EXAMPLE_ID_COLUMN] = constant_op.constant(EXAMPLE_IDS)
+ labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS
+ if batch:
+ ds = dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensor_slices(features_dict),
+ dataset_ops.Dataset.from_tensor_slices(labels))).batch(batch)
else:
- labels = REGRESSION_LABELS
- return features, labels
+ ds = dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensors(features_dict),
+ dataset_ops.Dataset.from_tensors(labels)))
+    # Repeat indefinitely by default, or stop after the given number of repeats.
+ ds = ds.repeat(repeat)
+ return ds
return _input_fn
@@ -125,9 +145,28 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
num_steps = 100
# Train for a few steps, and validate final checkpoint.
est.train(train_input_fn, steps=num_steps)
+ self._assert_checkpoint(
+ est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertAllClose([[0], [1], [1], [0], [0]],
+ [pred['class_ids'] for pred in predictions])
+ def testTrainClassifierWithDataset(self):
+ train_input_fn = _make_train_input_fn_dataset(is_classification=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+ est.train(train_input_fn, steps=100) # will stop after 5 steps anyway.
+ self._assert_checkpoint(
+ est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+ eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+ self.assertAllClose(eval_res['accuracy'], 1.0)
predictions = list(est.predict(input_fn=predict_input_fn))
- # All labels are correct.
self.assertAllClose([[0], [1], [1], [0], [0]],
[pred['class_ids'] for pred in predictions])
@@ -166,12 +205,126 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
est.train(train_input_fn, steps=num_steps)
self._assert_checkpoint(
est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertAllClose(
+ [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+ [pred['predictions'] for pred in predictions])
+
+ def testTrainRegressorWithDataset(self):
+ train_input_fn = _make_train_input_fn_dataset(is_classification=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.BoostedTreesRegressor(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+ est.train(train_input_fn, steps=100) # will stop after 5 steps anyway.
+ self._assert_checkpoint(
+ est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+ eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+ self.assertAllClose(eval_res['average_loss'], 2.478283)
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertAllClose(
+ [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+ [pred['predictions'] for pred in predictions])
+
+ def testTrainRegressorWithDatasetBatch(self):
+    # Using the entire data size as the batch_size should yield the same
+    # result as a dataset without batching.
+ train_input_fn = _make_train_input_fn_dataset(
+ is_classification=False, batch=5)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.BoostedTreesRegressor(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+ est.train(train_input_fn, steps=100) # will stop after 5 steps anyway.
+ self._assert_checkpoint(
+ est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+ eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+ self.assertAllClose(eval_res['average_loss'], 2.478283)
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertAllClose(
+ [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+ [pred['predictions'] for pred in predictions])
+
+ def testTrainRegressorWithDatasetLargerBatch(self):
+    # A batch_size that is a multiple of the entire data size should still
+    # yield the same result.
+ train_input_fn = _make_train_input_fn_dataset(
+ is_classification=False, batch=15)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.BoostedTreesRegressor(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+ est.train(train_input_fn, steps=100) # will stop after 5 steps anyway.
+ self._assert_checkpoint(
+ est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+ eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+ self.assertAllClose(eval_res['average_loss'], 2.478283)
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertAllClose(
+ [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+ [pred['predictions'] for pred in predictions])
+
+ def testTrainRegressorWithDatasetSmallerBatch(self):
+    # Even when using small batches, if n_batches_per_layer * batch_size adds
+    # up to the same entire data size, the result should be the same.
+ train_input_fn = _make_train_input_fn_dataset(
+ is_classification=False, batch=1)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+ est = boosted_trees.BoostedTreesRegressor(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=5,
+ n_trees=1,
+ max_depth=5)
+ # Train stops after (n_batches_per_layer * n_trees * max_depth) steps.
+ est.train(train_input_fn, steps=100)
+ self._assert_checkpoint(
+ est.model_dir, global_step=25, finalized_trees=1, attempted_layers=5)
+ # 5 batches = one epoch.
+ eval_res = est.evaluate(input_fn=train_input_fn, steps=5)
+ self.assertAllClose(eval_res['average_loss'], 2.478283)
predictions = list(est.predict(input_fn=predict_input_fn))
self.assertAllClose(
[[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
[pred['predictions'] for pred in predictions])
+ def testTrainRegressorWithDatasetWhenInputIsOverEarlier(self):
+ train_input_fn = _make_train_input_fn_dataset(
+ is_classification=False, repeat=3) # to stop input after 3 steps.
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.BoostedTreesRegressor(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+    # Note that training will stop when the input is exhausted.
+ # This might not be a typical pattern, but dataset.repeat(3) causes
+ # the input stream to cease after 3 steps.
+ est.train(train_input_fn, steps=100)
+ self._assert_checkpoint(
+ est.model_dir, global_step=3, finalized_trees=0, attempted_layers=3)
+ eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+ self.assertAllClose(eval_res['average_loss'], 3.777295)
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertAllClose(
+ [[0.353850], [0.254100], [0.106850], [0.712100], [1.012100]],
+ [pred['predictions'] for pred in predictions])
+
class ModelFnTests(test_util.TensorFlowTestCase):
"""Tests bt_model_fn including unexposed internal functionalities."""
@@ -188,7 +341,8 @@ class ModelFnTests(test_util.TensorFlowTestCase):
learning_rate=0.1,
l1=0.,
l2=0.01,
- tree_complexity=0.)
+ tree_complexity=0.,
+ min_node_weight=0.)
def _get_expected_ensembles_for_classification(self):
first_round = """
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index 5e61c30ea2..efa4bdf598 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -1042,7 +1042,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
vocabulary_list=tuple(self._label_vocabulary),
name='class_id_lookup').lookup(labels)
labels = math_ops.to_float(labels)
- labels = _assert_range(labels, 2)
+ labels = _assert_range(labels, n_classes=2)
if self._loss_fn:
unweighted_loss = _call_loss_fn(
loss_fn=self._loss_fn, labels=labels, logits=logits,
@@ -1450,12 +1450,12 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
def _assert_range(labels, n_classes, message=None):
with ops.name_scope(None, 'assert_range', (labels,)):
- assert_less = check_ops.assert_less(
+ assert_less = check_ops.assert_less_equal(
labels,
- ops.convert_to_tensor(n_classes, dtype=labels.dtype),
- message=message or 'Label IDs must < n_classes')
+ ops.convert_to_tensor(n_classes - 1, dtype=labels.dtype),
+        message=message or 'Labels must be <= n_classes - 1')
assert_greater = check_ops.assert_non_negative(
- labels, message=message or 'Label IDs must >= 0')
+        labels, message=message or 'Labels must be >= 0')
with ops.control_dependencies((assert_less, assert_greater)):
return array_ops.identity(labels)
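The practical effect of moving from `assert_less(labels, n_classes)` to `assert_less_equal(labels, n_classes - 1)` shows up with float labels; a plain-Python illustration:

```python
n_classes = 2
label = 1.2  # an invalid float label for binary classification

assert label < n_classes             # old check: 1.2 < 2 passes incorrectly
assert not (label <= n_classes - 1)  # new check: 1.2 <= 1 fails, as it should
```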
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index fe6ee07529..7da3df01dc 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -255,14 +255,14 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
logits=logits_placeholder,
labels=labels_placeholder)[0]
with self.test_session():
- with self.assertRaisesOpError('Label IDs must < n_classes'):
+      with self.assertRaisesOpError('Labels must be <= n_classes - 1'):
training_loss.eval({
labels_placeholder: labels_2x1_with_large_id,
logits_placeholder: logits_2x3
})
with self.test_session():
- with self.assertRaisesOpError('Label IDs must >= 0'):
+      with self.assertRaisesOpError('Labels must be >= 0'):
training_loss.eval({
labels_placeholder: labels_2x1_with_negative_id,
logits_placeholder: logits_2x3
@@ -2090,6 +2090,24 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
expected_regularization_loss),
}, summary_str)
+ def test_float_labels_invalid_values(self):
+ head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()
+
+ logits = np.array([[0.5], [-0.3]], dtype=np.float32)
+ labels = np.array([[1.2], [0.4]], dtype=np.float32)
+ features = {'x': np.array([[42]], dtype=np.float32)}
+ training_loss = head.create_loss(
+ features=features,
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels)[0]
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+        r'Labels must be <= n_classes - 1'):
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ training_loss.eval()
+
def test_float_labels_train_create_loss(self):
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()
diff --git a/tensorflow/python/estimator/replicate_model_fn.py b/tensorflow/python/estimator/replicate_model_fn.py
deleted file mode 100644
index 144d89abf3..0000000000
--- a/tensorflow/python/estimator/replicate_model_fn.py
+++ /dev/null
@@ -1,824 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utilities to replicate model_fn's over local GPUs.
-
-This file contains utilities that allow replicating `Estimator.model_fn` over
-GPUs. A replicated version of a `model_fn` is returned that can subsequently
-be used with `Estimator`.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from collections import defaultdict
-from contextlib import contextmanager
-import copy
-
-import six
-
-from tensorflow.core.framework import node_def_pb2
-from tensorflow.python.client import device_lib
-from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator import util
-from tensorflow.python.estimator.export import export_output as export_output_lib
-from tensorflow.python.framework import device as framework_device
-from tensorflow.python.framework import ops as ops_lib
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops.losses import losses
-from tensorflow.python.platform import tf_logging
-from tensorflow.python.training import device_setter as device_setter_lib
-from tensorflow.python.training import optimizer as optimizer_lib
-
-
-def _replicate_model_fn(model_fn,
- devices=None):
- """Replicate `Estimator.model_fn` over GPUs.
-
- The given `model_fn` specifies a single forward pass of a model. To replicate
- such a model over GPUs, each GPU gets its own instance of the forward pass
- (a.k.a. a tower). The input features and labels get sharded into chunks,
- one chunk per GPU. Each tower computes a loss based on its input. For each
- such loss, gradients are computed. After that, the available losses are
- aggregated to form a single aggregated loss, and the available gradients
- are summed. The aggregated gradients are then used to update the weights
- via the specified optimizer.
-
- If `devices` is `None`, then all available GPUs are going to be used for
- replication. If no GPUs are available, then the model is going to be
- placed on the CPU.
-
- Two modes of local replication over available GPUs are supported:
- 1) If exactly 1 GPU is detected, then variables and operations are placed
- onto the GPU.
- 2) If more than 1 GPU is detected, then variables are going to be placed on
- the CPU. Replicas of operations are placed on each individual GPU.
-
- Here is an example of how one might use their `model_fn` to run over GPUs:
- ```python
- ...
- def model_fn(...): # See `model_fn` in `Estimator`.
- loss = ...
- optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
- optimizer = tf.contrib.estimator._TowerOptimizer(optimizer)
- if mode == tf.estimator.ModeKeys.TRAIN:
- # See the section below on `EstimatorSpec.train_op`.
- return EstimatorSpec(mode=mode, loss=loss,
- train_op=optimizer.minimize(loss))
-
- # No change for `ModeKeys.EVAL` or `ModeKeys.PREDICT`.
- return EstimatorSpec(...)
- ...
- classifier = tf.estimator.Estimator(
- model_fn=tf.contrib.estimator.replicate_model_fn(model_fn))
- ```
-
- Please see `DNNClassifierIntegrationTest` for an example with a canned
- Estimator.
-
- On `EstimatorSpec.train_op`:
- `model_fn` returns `EstimatorSpec.train_op` for
- `tf.estimator.GraphKeys.TRAIN`. It is typically derived using an optimizer.
- Towers are expected to populate it in the same way. Gradients from all towers
- are reduced and applied in the last tower. To achieve that in the case of
- multiple towers, `_TowerOptimizer` needs to be used. See `_TowerOptimizer`.
-
- On sharding input features and labels:
- Input features and labels are split for consumption by each tower. They are
- split across the dimension 0. Features and labels need to be batch major.
-
- On reduction algorithms:
- Certain algorithms were chosen for aggregating results of computations on
- multiple towers:
-  - Losses from all towers are reduced according to the `loss_reduction`
-    argument of `_TowerOptimizer`.
- - Gradients from all towers are reduced according to the `loss_reduction`
- for each trainable variable.
- - `eval_metrics_ops` are reduced per metric using `reduce_mean`.
- - `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are
- reduced using concatenation.
- - For all other fields of `EstimatorSpec` the values of the first tower
- are taken.
-
- On distribution of variables:
- Variables are not duplicated between towers. Instead, they are placed on a
- single device as defined above and shared across towers.
-
- On overhead:
- If only one device is specified, then aggregation of loss and gradients
- doesn't happen. Replication consists of placing `model_fn` onto the
- specified device.
-
- On current limitations:
- - `predictions` are not supported for `ModeKeys.EVAL`. They are required
- for `tf.contrib.estimator.add_metrics`.
-
- Args:
- model_fn: `model_fn` as defined in `Estimator`. See the section above about
- the train_op argument of `EstimatorSpec`.
- devices: Optional list of devices to replicate the model across. This
-    argument can be used to replicate only on a subset of the available GPUs.
- If `None`, then all available GPUs are going to be used for replication.
- If no GPUs are available, then the model is going to be placed on the CPU.
-
- Returns:
-    A replicated version of the supplied `model_fn`. The returned function
- conforms to the requirements of `Estimator`'s `model_fn` and can be used
- instead of the supplied `model_fn`.
- """
- return _replicate_model_fn_with_mode(
- model_fn,
- devices,
- # TODO(isaprykin): Query the system configuration to choose modes other
- # than `SHARED_LOCAL_PARAMETER_SERVER`, even though it is often
- # appropriate.
- mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER)
-
-
-class _VariableDistributionMode(object):
-  """Modes for variable distribution, used to force a particular one.
-
- Forcing a mode is meant for performance experimentation purposes rather than
- for general use cases.
- """
-
- SHARED_LOCAL_PARAMETER_SERVER = 1
- """Variables are placed on a single device and shared across all devices.
-
- Two ways to achieve this distribution over available GPUs are supported:
- 1) If exactly 1 GPU is detected, then variables and operations are placed
- onto GPU.
- 2) If more than 1 GPU is detected, then variables are going to be placed on
- the CPU. Replicas of operations are placed on each individual GPU.
- """
-
- SHARED_ROUND_ROBIN = 2
- """Variables are placed on all devices in a round-robin fashion.
-
- Every subsequent variable is placed on the next device. There is only one
- copy of each variable that is shared across all devices.
- """
-
-
-def _replicate_model_fn_with_mode(
- model_fn,
- devices=None,
- mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER):
-  """A version of `replicate_model_fn` that allows specifying a `mode`."""
- if not devices:
- devices = _get_local_devices('GPU') or _get_local_devices('CPU')
-
- is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0].upper()
- consolidation_device = devices[0] if is_a_single_gpu_case else '/CPU:0'
-
- ps_devices = [consolidation_device]
- if mode == _VariableDistributionMode.SHARED_ROUND_ROBIN:
- ps_devices = devices
-
- tf_logging.info('Replicating the `model_fn` across {}. Variables are going '
- 'to be placed on {}. Consolidation device is going to be {}.'
- .format(devices, ps_devices, consolidation_device))
-
- def single_device_model_fn(features, labels, mode, params=None, config=None):
- """`model_fn` on a single device without reduction overhead."""
- return _get_loss_towers(
- model_fn=model_fn,
- mode=mode,
- features=[features],
- labels=[labels],
- params=params,
- config=config,
- devices=devices,
- local_ps_devices=ps_devices)[0] # One device, so one spec is out.
-
- def replicated_model_fn(features, labels, mode, params=None, config=None):
- """Replicated version of `model_fn` to be used instead."""
- feature_shards, label_shards = _split_batch(
- features, labels, len(devices), device=consolidation_device)
- tower_specs = _get_loss_towers(
- model_fn=model_fn,
- mode=mode,
- features=feature_shards,
- labels=label_shards,
- params=params,
- config=config,
- devices=devices,
- local_ps_devices=ps_devices)
-
- if mode == model_fn_lib.ModeKeys.TRAIN:
- train_op = _minimize_towers(tower_specs)
- return _train_spec(
- tower_specs, train_op, aggregation_device=consolidation_device)
- elif mode == model_fn_lib.ModeKeys.EVAL:
- return _eval_spec(tower_specs, aggregation_device=consolidation_device)
- elif mode == model_fn_lib.ModeKeys.PREDICT:
- return _predict_spec(tower_specs, aggregation_device=consolidation_device)
-
- if len(devices) == 1:
- return single_device_model_fn
- else:
- return replicated_model_fn
-
-
-class _TowerOptimizer(optimizer_lib.Optimizer):
- """Gathers gradients from all towers and reduces them in the last one."""
-
- COLLECTION_FOR_GRAPH_STATES = 'replicate_model_fn_graph_states'
-
- def __init__(self, optimizer_or_optimizer_fn,
- loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE):
- """Wrap an existing optimizer for gathering gradients across towers.
-
- Each invocation of model_fn has to call the same optimizers in the same
- order.
-
- Multiple optimizers that use the same or different losses are supported.
-
- If _TowerOptimizer is used but `replicate_model_fn` isn't, then no
- aggregation will happen. All calls will simply be forwarded to the
- underlying optimizer. The behavior is similar if there is only one tower.
-
- If _TowerOptimizer is used together with SyncReplicasOptimizer that wraps
- the user's optimizer, then it's the SyncReplicasOptimizer that needs to be
- wrapped with _TowerOptimizer.
-
- Args:
- optimizer_or_optimizer_fn: an instance of optimizer to wrap. That
- instance is going to be used for optimizer-specific logic. This can
- also be a no-argument function that returns such an optimizer instance.
- loss_reduction: controls whether losses are summed or averaged.
- """
- self._optimizer_or_optimizer_fn = optimizer_or_optimizer_fn
- self._loss_reduction = loss_reduction
-
- @staticmethod
- def has_been_used():
- return _TowerOptimizer._graph_state().has_tower_optimizer_been_used
-
- def get_slot(self, *args, **kwargs):
- return self._get_optimizer().get_slot(*args, **kwargs)
-
- def get_slot_names(self, *args, **kwargs):
- return self._get_optimizer().get_slot_names(*args, **kwargs)
-
- def get_name(self, *args, **kwargs):
- return self._get_optimizer().get_name(*args, **kwargs)
-
- def variables(self, *args, **kwargs):
- return self._get_optimizer().variables(*args, **kwargs)
-
- def compute_gradients(self, loss, *args, **kwargs):
- """Compute gradients, but first, if needed, scale the loss."""
- _TowerOptimizer._graph_state().set_loss_reduction(self._loss_reduction)
- loss = _scale_loss(loss,
- self._loss_reduction,
- self._graph_state().number_of_towers)
- return self._get_optimizer().compute_gradients(loss, *args, **kwargs)
-
- def apply_gradients(self, grads_and_vars, global_step=None, **kwargs):
-    """Collect gradient updates to apply them with the last tower."""
- if self._graph_state().number_of_towers == 1:
- # Avoid the overhead of reduction if there's only one tower.
- #
-    # There is assumed to be only one tower if aggregation-related methods were
- # not called by `_get_loss_towers`, for example if the model_fn uses
- # TowerEstimator, but `replicate_model_fn` isn't used.
- return self._get_optimizer().apply_gradients(grads_and_vars, global_step,
- **kwargs)
-
- self._graph_state().collect_gradients(grads_and_vars)
-
- if not self._graph_state().is_the_last_tower:
- with ops_lib.control_dependencies(_extract_tensors(grads_and_vars)):
- return self._construct_no_op_train_op()
- else:
- # Gradients need to be gathered and applied in the scope of the first
- # tower, so that the tensors are accessible via names without prefixes.
- var_scope, name_scope = self._graph_state().scopes_of_the_first_tower
- with variable_scope.variable_scope(var_scope):
- with ops_lib.name_scope(name_scope):
- return self._apply_gathered_gradients(global_step, **kwargs)
-
- def _apply_gathered_gradients(self, global_step, **kwargs):
- graph_state = self._graph_state()
- optimizer = self._get_optimizer()
-
- grad_lists = {}
- for grad, var in graph_state.get_latest_gradients_from_all_towers():
- if grad is not None:
- grad_lists.setdefault(var, []).append(grad)
-
- aggregated_grads = []
- with ops_lib.name_scope('gradient_aggregating'):
- for var, grads in six.iteritems(grad_lists):
- grad = _compute_sum_on_device(grads, var.device)
- aggregated_grads.append((grad, var))
- return optimizer.apply_gradients(
- aggregated_grads, global_step=global_step, **kwargs)
-
- def _get_optimizer(self):
- if callable(self._optimizer_or_optimizer_fn):
- # If optimizer is given as a function then we need to wait till we are
- # under the right graph context before constructing it. That's why the
- # optimizer is constructed in _get_optimizer() rather than __init__().
- self._optimizer_or_optimizer_fn = self._optimizer_or_optimizer_fn()
- self._graph_state().has_tower_optimizer_been_used = True
- return self._optimizer_or_optimizer_fn
-
- def _construct_no_op_train_op(self):
- return control_flow_ops.no_op(name='train_op_placeholder')
-
- @staticmethod
- def _graph_state():
- graph_states = ops_lib.get_default_graph().get_collection_ref(
- _TowerOptimizer.COLLECTION_FOR_GRAPH_STATES)
- if not graph_states:
- graph_states.append(_TowerOptimizer._PerGraphState())
- return graph_states[-1]
-
- @staticmethod
- def _did_towers_have_same_optimizer_calls():
- graph_state = _TowerOptimizer._graph_state()
- return graph_state.did_towers_have_same_optimizer_calls()
-
- @staticmethod
- def _clear_graph_state():
- # Clearing the Graph collection will prevent _PerGraphState from being
- # serialized.
- ops_lib.get_default_graph().clear_collection(
- _TowerOptimizer.COLLECTION_FOR_GRAPH_STATES)
-
- class _PerGraphState(object):
-    """Gradient-reduction-related state of a TensorFlow graph."""
-
- def __init__(self):
- self._collected_grads_and_vars = defaultdict(list)
- self._current_tower_index = 0
- self._number_of_towers = 1
- self._loss_reduction = None
- # Scopes of the first tower that don't have a prefix:
- self._variable_scope = None
- self._name_scope = None
- # If needed, alert that _TowerOptimizer needs to be used with model_fn.
- self._has_tower_optimizer_been_used = False
-
- def collect_gradients(self, grads_and_vars):
- self._collected_grads_and_vars[self._current_tower_index].append(
- grads_and_vars)
-
- def get_latest_gradients_from_all_towers(self):
- """Get gradients across towers for the last called optimizer."""
- grads_and_vars = []
- index_of_last_gradients = len(
- self._collected_grads_and_vars[self._current_tower_index]) - 1
- for tower_id in range(self._current_tower_index + 1):
- grads_and_vars.extend(
- self._collected_grads_and_vars[tower_id][index_of_last_gradients])
- return grads_and_vars
-
- def set_number_of_towers(self, number_of_towers):
- self._number_of_towers = number_of_towers
-
- def set_loss_reduction(self, loss_reduction):
- self._loss_reduction = loss_reduction
-
- @contextmanager
- def tower(self, tower_id, var_scope, name_scope):
- if tower_id == 0:
- self._variable_scope = var_scope
- self._name_scope = name_scope
- self._current_tower_index = tower_id
- yield
-
- @property
- def scopes_of_the_first_tower(self):
- return self._variable_scope, self._name_scope
-
- @property
- def is_the_last_tower(self):
- return self._current_tower_index == (self._number_of_towers - 1)
-
- @property
- def number_of_towers(self):
- return self._number_of_towers
-
- @property
- def loss_reduction(self):
- return self._loss_reduction
-
- @property
- def has_tower_optimizer_been_used(self):
- return self._has_tower_optimizer_been_used
-
- @has_tower_optimizer_been_used.setter
- def has_tower_optimizer_been_used(self, value):
- self._has_tower_optimizer_been_used = value
-
- def did_towers_have_same_optimizer_calls(self):
- total_number_of_grads = sum([
- len(grads)
- for _, grads in six.iteritems(self._collected_grads_and_vars)
- ])
- return total_number_of_grads % self._number_of_towers == 0
-
-
-def _get_local_devices(device_type):
- local_device_protos = device_lib.list_local_devices()
- return [
- device.name
- for device in local_device_protos
- if device.device_type == device_type
- ]
-
-
-def _split_batch(features, labels, number_of_shards, device):
-  """Split input features and labels into per-tower shards."""
-
- def ensure_divisible_by_shards(sequence):
- batch_size = ops_lib.convert_to_tensor(sequence).get_shape()[0]
- if batch_size % number_of_shards != 0:
- raise ValueError(
- 'Batch size {} needs to be divisible by the number of GPUs, which '
- 'is {}.'.format(batch_size, number_of_shards))
-
- def split_dictionary(dictionary):
- """Split a dictionary into shards."""
- shards = [{} for _ in range(number_of_shards)]
- for name, tensor in six.iteritems(dictionary):
- if isinstance(tensor, sparse_tensor.SparseTensor):
- for i, shard in enumerate(
- sparse_ops.sparse_split(
- sp_input=tensor, num_split=number_of_shards, axis=0)):
- shards[i][name] = shard
- else:
- ensure_divisible_by_shards(tensor)
- for i, shard in enumerate(array_ops.split(tensor, number_of_shards)):
- shards[i][name] = shard
- return shards
-
- with ops_lib.name_scope('split_inputs'):
- with ops_lib.device(device):
- if isinstance(features, dict):
- feature_shards = split_dictionary(features)
- else:
- ensure_divisible_by_shards(features)
- feature_shards = array_ops.split(features, number_of_shards)
-
- if labels is None:
- label_shards = None
- elif isinstance(labels, dict):
- label_shards = split_dictionary(labels)
- else:
- ensure_divisible_by_shards(labels)
- label_shards = array_ops.split(labels, number_of_shards)
- return feature_shards, label_shards
-
-
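For intuition, the even dimension-0 split that `_split_batch` performed, sketched with NumPy (shapes are illustrative):

```python
import numpy as np

batch = np.arange(12).reshape(6, 2)  # batch_size=6, two towers
shards = np.split(batch, 2, axis=0)  # must divide evenly, as checked above
assert [s.shape for s in shards] == [(3, 2), (3, 2)]
```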
-_DEFAULT_NAME_SCOPE_PATTERN = 'tower_{}'
-
-
-def _get_loss_towers(model_fn,
- mode,
- features,
- labels,
- params,
- config,
- devices,
- local_ps_devices,
- name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
- """Replicate the loss computation across devices."""
- tower_specs = []
-
- model_fn_args = util.fn_args(model_fn)
- optional_params = {}
- if 'params' in model_fn_args:
- optional_params['params'] = copy.deepcopy(params)
- if 'config' in model_fn_args:
- optional_params['config'] = copy.deepcopy(config)
-
- # pylint: disable=protected-access
- round_robin_strategy = device_setter_lib._RoundRobinStrategy(
- num_tasks=len(local_ps_devices))
- _TowerOptimizer._graph_state().set_number_of_towers(len(devices))
-
- for i, device in enumerate(devices):
- is_the_first_tower = (i == 0)
-
- device_setter = _local_device_setter(
- worker_device=device,
- ps_devices=local_ps_devices,
- ps_strategy=round_robin_strategy)
-
- # We would like to preserve the names of the variables and ops that the user
- # might be relying on. Names without a prefix are going to resolve to
- # variables and ops of the first tower.
- name_scope = name_scope_pattern
- if is_the_first_tower:
- name_scope = ''
-
- with variable_scope.variable_scope(
- '', reuse=not is_the_first_tower) as var_scope:
- with ops_lib.name_scope(name_scope.format(i)) as name_scope:
- with _TowerOptimizer._graph_state().tower(
- tower_id=i, var_scope=var_scope, name_scope=name_scope):
- with ops_lib.device(device_setter):
- labels_shard = None
- if labels:
- labels_shard = labels[i]
-
- tower_spec = model_fn(
- mode=mode,
- features=features[i],
- labels=labels_shard,
- **optional_params)
-
- if (tower_spec.train_op is not None and len(devices) > 1 and
- not _TowerOptimizer.has_been_used()):
- raise ValueError('Please wrap optimizers with _TowerOptimizer'
- ' in order to use replicate_model_fn with'
- ' multiple `devices`.')
-
- # Scaling the loss here doesn't actually affect gradients. Another
- # instance of scaling happens inside the _TowerOptimizer.
- tower_spec = _scale_tower_loss(
- tower_spec,
- _TowerOptimizer._graph_state().loss_reduction,
- number_of_towers=len(devices))
- tower_specs.append(tower_spec)
-
- if not _TowerOptimizer._did_towers_have_same_optimizer_calls():
- raise ValueError('Each invocation of model_fn was supposed to make the same'
- ' optimizer calls.')
- _TowerOptimizer._clear_graph_state()
- # pylint: enable=protected-access
- return tower_specs
-
-
-def _local_device_setter(worker_device, ps_devices, ps_strategy):
-  """A device setter that distributes Vars/Ops across PS/worker devices."""
- ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']
-
- def local_device_chooser(op):
- current_device = framework_device.DeviceSpec.from_string(op.device or '')
-
- node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
- if node_def.op in ps_ops:
- ps_device_spec = framework_device.DeviceSpec.from_string(
- '{}'.format(ps_devices[ps_strategy(op)]))
-
- ps_device_spec.merge_from(current_device)
- return ps_device_spec.to_string()
- else:
- worker_device_spec = framework_device.DeviceSpec.from_string(
- worker_device or '')
- worker_device_spec.merge_from(current_device)
- return worker_device_spec.to_string()
-
- return local_device_chooser
-
-
-def _scale_tower_loss(tower_spec, loss_reduction, number_of_towers):
-  """Produce an EstimatorSpec with an appropriately scaled loss."""
- if tower_spec.loss is None:
- return tower_spec
-
- estimator_spec = _asdict(tower_spec)
- estimator_spec['loss'] = _scale_loss(
- tower_spec.loss,
- loss_reduction,
- number_of_towers,
- reduced_loss_name='averaged_loss')
- return model_fn_lib.EstimatorSpec(**estimator_spec)
-
-
-def _scale_loss(loss, loss_reduction, number_of_towers, reduced_loss_name=None):
-  """If needed, scale down the per-tower loss so that summing averages it."""
- if loss is None:
- return None
- if number_of_towers == 1:
- return loss
-
- if loss_reduction == losses.Reduction.NONE:
- raise ValueError('Tower losses need to be reduced in some way, yet {} '
- 'reduction is specified.'.format(loss_reduction))
-
- if loss_reduction != losses.Reduction.SUM:
- return math_ops.div(loss, 1.0 * number_of_towers, name=reduced_loss_name)
- else:
- return loss
-
-
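The scaling rule reduces to simple arithmetic; for example, with four towers and a mean-style reduction (values are made up):

```python
number_of_towers = 4
tower_losses = [0.8, 1.2, 1.0, 1.0]

# Each tower divides its loss by the tower count, so summing the scaled
# losses downstream recovers the cross-tower average.
scaled = [loss / number_of_towers for loss in tower_losses]
assert abs(sum(scaled) - sum(tower_losses) / number_of_towers) < 1e-12
```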
-def _minimize_towers(tower_specs):
- """`train_op` of the last tower applies aggregated gradients."""
- return tower_specs[-1].train_op
-
-
-def _compute_sum_on_device(values, device, name=None):
- with ops_lib.device(device):
- if isinstance(values[0], ops_lib.IndexedSlices):
- if name:
- raise ValueError('The name {} is not expected to be given to '
- 'IndexedSlices {}'.format(name, values))
-
- values_concat = array_ops.concat([v.values for v in values], axis=0)
- indices_concat = array_ops.concat([v.indices for v in values], axis=0)
- return ops_lib.IndexedSlices(values_concat, indices_concat,
- values[0].dense_shape)
- else:
- return math_ops.add_n(values, name=name)
-
-
-def _train_spec(tower_specs,
- train_op,
- aggregation_device,
- aggregated_loss_name='loss'):
- """Populate replicated EstimatorSpec for `GraphKeys.TRAIN`."""
- # Spec of the last tower is used as the template for the final spec, because
- # some `EstimatorSpec.training_hooks` rely on calls made in model_fn. For
- # example, `SyncReplicasOptimizerHook` validates the
- # `SyncReplicasOptimizer.apply_gradients` call. `TowerEstimator` makes that
- # call only in the last tower.
- estimator_spec = _asdict(tower_specs[-1])
- estimator_spec['mode'] = model_fn_lib.ModeKeys.TRAIN
- estimator_spec['train_op'] = train_op
- estimator_spec['loss'] = _compute_sum_on_device(
- [spec.loss for spec in tower_specs], aggregation_device,
- aggregated_loss_name)
- return model_fn_lib.EstimatorSpec(**estimator_spec)
-
-
-def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'):
- """Populate replicated EstimatorSpec for `GraphKeys.EVAL`."""
- estimator_spec = _asdict(tower_specs[0])
- estimator_spec['mode'] = model_fn_lib.ModeKeys.EVAL
- estimator_spec['loss'] = _compute_sum_on_device(
- [spec.loss for spec in tower_specs], aggregation_device,
- aggregated_loss_name)
-
- update_ops = []
- for tower_spec in tower_specs:
- for name, (_, update_op) in six.iteritems(tower_spec.eval_metric_ops):
- update_ops.append(update_op)
-
- with ops_lib.control_dependencies(update_ops):
- reduced_update_op = _reduce_metric_variables(len(tower_specs))
-
- eval_metric_ops = {}
- for name, (metric_tensor, _) in six.iteritems(tower_specs[0].eval_metric_ops):
- eval_metric_ops[name] = (metric_tensor, reduced_update_op)
- estimator_spec['eval_metric_ops'] = eval_metric_ops
- return model_fn_lib.EstimatorSpec(**estimator_spec)
-
-
-def _reduce_metric_variables(number_of_towers):
- """Aggregate local variables used in metrics into the first tower."""
- if number_of_towers == 1:
- return control_flow_ops.no_op(name='no_eval_metric_reduction')
-
- metric_variables = ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)
- variables_per_tower = len(metric_variables) // number_of_towers
-
- if len(metric_variables) % number_of_towers != 0:
- raise ValueError(
- 'Different `EstimatorSpec.eval_metric_ops` across `model_fn()` calls.'
- ' Expected {} local variables, but got {} instead.'.format(
- variables_per_tower * number_of_towers, len(metric_variables)))
-
- # `metric_variables` has the size of `variables_per_tower` x
- # number_of_towers. Each tower is produced by calling the same model_fn.
- # First `variables_per_tower` correspond to the first tower. Each such
-  # variable has a replica at the `(variables_per_tower * i)` position, where
-  # `i` is in `[1..number_of_towers)`. We are going to add values from replicas
-  # to each variable of the first tower. We then zero out the replica values,
-  # so that the `_reduce_metric_variables` operation is idempotent. If a metric
- # is then computed based on local variables from the first tower, then the
- # resulting metric is an estimate for all `number_of_towers` towers.
- ops = []
- for i in range(0, variables_per_tower):
- next_replica_id = i + variables_per_tower
- replicas = [
- metric_variables[replica_id]
- for replica_id in range(next_replica_id, len(metric_variables),
- variables_per_tower)
- ] # `replicas` doesn't contain the first-tower variable.
-
- reduce_op = state_ops.assign_add(metric_variables[i],
- math_ops.add_n(replicas))
-
- with ops_lib.control_dependencies([reduce_op]):
- for replica in replicas:
- zeros_for_replica = array_ops.zeros(
- array_ops.shape(replica), dtype=replica.dtype)
- zero_out_replica_op = state_ops.assign(replica, zeros_for_replica)
- ops.append(zero_out_replica_op)
-
- return control_flow_ops.group(*ops)
-
-
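A numeric sketch of why the zero-out step made this reduction idempotent (two towers, one metric variable each; values are made up):

```python
first_tower, replica = 3.0, 5.0

# One reduction: fold the replica into the first tower, then zero it out.
first_tower, replica = first_tower + replica, 0.0
assert (first_tower, replica) == (8.0, 0.0)

# Running the same reduction again changes nothing.
first_tower, replica = first_tower + replica, 0.0
assert (first_tower, replica) == (8.0, 0.0)
```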
-def _predict_spec(tower_specs, aggregation_device):
- """Populate replicated EstimatorSpec for `GraphKeys.PREDICT`."""
- estimator_spec = _asdict(tower_specs[0])
- estimator_spec['mode'] = model_fn_lib.ModeKeys.PREDICT
-
- with ops_lib.device(aggregation_device):
- estimator_spec['predictions'] = _concat_tensor_dicts(
- *[tower_spec.predictions for tower_spec in tower_specs])
-
- export_outputs_dict = _dict_concat(
- *[tower_spec.export_outputs for tower_spec in tower_specs])
-
- export_outputs = {}
- for name, export_output_list in six.iteritems(export_outputs_dict):
- if isinstance(export_output_list[0], export_output_lib.PredictOutput):
- export_outputs[name] = export_output_lib.PredictOutput(
- outputs=_concat_tensor_dicts(*[
- export_output.outputs for export_output in export_output_list
- ]))
- elif isinstance(export_output_list[0],
- export_output_lib.RegressionOutput):
- export_outputs[name] = export_output_lib.RegressionOutput(
- value=array_ops.concat(
- [export_output.value for export_output in export_output_list],
- axis=0))
- elif isinstance(export_output_list[0],
- export_output_lib.ClassificationOutput):
- scores = None
- if export_output_list[0].scores is not None:
- scores = array_ops.concat(
- [export_output.scores for export_output in export_output_list],
- axis=0)
-
- classes = None
- if export_output_list[0].classes is not None:
- classes = array_ops.stack(
- [export_output.classes for export_output in export_output_list],
- axis=0)
-
- export_outputs[name] = export_output_lib.ClassificationOutput(
- scores=scores, classes=classes)
-
- estimator_spec['export_outputs'] = export_outputs
- return model_fn_lib.EstimatorSpec(**estimator_spec)
-
-
-def _concat_tensor_dicts(*tensor_dicts):
- return {
- name: array_ops.concat(tensors, axis=0, name=name)
- for name, tensors in six.iteritems(_dict_concat(*tensor_dicts))
- }
-
-
-def _extract_tensors(tensors_and_vars):
- tensors = []
- for tensor_and_var in tensors_and_vars:
- tensor, _ = tensor_and_var
- if isinstance(tensor, ops_lib.IndexedSlices):
- tensors.append(tensor.values)
- elif tensor is not None:
- tensors.append(tensor)
- return tensors
-
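-# For example (hypothetical values), given gradients-and-variables pairs
-# [(g0, v0), (None, v1), (IndexedSlices(values=g2, ...), v2)],
-# _extract_tensors returns [g0, g2]: `None` gradients are dropped and
-# `IndexedSlices` are unwrapped to their dense `values`.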
-
-def _dict_concat(*dicts):
- list_dict = {}
- for d in dicts:
- if d is None:
- continue
-
- for k, v in six.iteritems(d):
- list_dict.setdefault(k, []).append(v)
- return list_dict
-
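-# For example, _dict_concat({'a': 1}, {'a': 2, 'b': 3}) returns
-# {'a': [1, 2], 'b': [3]}; `None` dictionaries are skipped.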
-
-def _asdict(namedtuple):
- """Returns a namedtuple as a dictionary.
-
-  This is required because `_asdict()` is broken in some Python 3 releases
-  for classes that inherit from `collections.namedtuple`. See
-  https://bugs.python.org/issue24931 for more details.
-
- Args:
- namedtuple: An object that inherits from `collections.namedtuple`.
-
- Returns:
- A dictionary version of the tuple.
- """
- return {k: getattr(namedtuple, k) for k in namedtuple._fields}
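-
-# For example (hypothetical namedtuple):
-#   Point = collections.namedtuple('Point', ['x', 'y'])
-#   _asdict(Point(x=1, y=2)) == {'x': 1, 'y': 2}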
diff --git a/tensorflow/python/estimator/replicate_model_fn_test.py b/tensorflow/python/estimator/replicate_model_fn_test.py
deleted file mode 100644
index ad1f9c02b9..0000000000
--- a/tensorflow/python/estimator/replicate_model_fn_test.py
+++ /dev/null
@@ -1,1739 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for utilities that replicate `Estimator.model_fn` over GPUs."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import re
-import shutil
-import tempfile
-import numpy as np
-import six
-
-from tensorflow.python.estimator import estimator as estimator_lib
-from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator import replicate_model_fn
-from tensorflow.python.estimator.canned import dnn
-from tensorflow.python.estimator.canned import optimizers
-from tensorflow.python.estimator.canned import prediction_keys
-from tensorflow.python.estimator.export import export
-from tensorflow.python.estimator.export import export_output
-from tensorflow.python.estimator.inputs import numpy_io
-from tensorflow.python.feature_column import feature_column
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops as ops_lib
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import metrics as metrics_lib
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.ops.losses import losses
-from tensorflow.python.platform import gfile
-from tensorflow.python.platform import test
-from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.summary.writer import writer_cache
-from tensorflow.python.training import adam
-from tensorflow.python.training import device_setter
-from tensorflow.python.training import gradient_descent
-from tensorflow.python.training import training
-
-
-# TODO(isaprykin): Parametrize all the tests on
-# replicate_model_fn._VariableDistributionMode when it's supported.
-class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
-
- def setUp(self):
- self._model_dir = tempfile.mkdtemp()
-
- def test_complete_flow_with_public_version(self):
- return self._complete_flow_with_mode(mode=None)
-
- def test_complete_flow_with_mode_local_ps_server(self):
- return self._complete_flow_with_mode(
- replicate_model_fn._VariableDistributionMode.
- SHARED_LOCAL_PARAMETER_SERVER)
-
- def test_complete_flow_with_mode_round_robin(self):
- return self._complete_flow_with_mode(
- replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN)
-
- def _complete_flow_with_mode(self, mode):
- n_classes = 3
- input_dimension = 2
- batch_size = 12
-
- data = np.linspace(
- 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
- x_data = data.reshape(batch_size, input_dimension)
-    # `np.random.random_integers` is deprecated; `randint` with an exclusive
-    # upper bound of `len(x_data) + 1` draws from the same inclusive range.
-    categorical_data = np.random.randint(
-        0, len(x_data) + 1, size=len(x_data))
- y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
- train_input_fn = numpy_io.numpy_input_fn(
- x={'x': x_data,
- 'categories': categorical_data},
- y=y_data,
- batch_size=batch_size,
- num_epochs=None,
- shuffle=True)
- eval_input_fn = numpy_io.numpy_input_fn(
- x={'x': x_data,
- 'categories': categorical_data},
- y=y_data,
- batch_size=batch_size,
- shuffle=False)
- predict_input_fn = numpy_io.numpy_input_fn(
- x={'x': x_data,
- 'categories': categorical_data},
- batch_size=batch_size,
- shuffle=False)
-
- feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,)),
- feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
- 'categories',
- vocabulary_list=np.linspace(
- 0., len(x_data), len(x_data), dtype=np.int64)), 1)
- ]
-
- def optimizer_fn():
- return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05)
-
- estimator = dnn.DNNClassifier(
- hidden_units=(2, 2),
- # Adagrad is configured with `get_optimizer_instance`, so the function
- # form of `TowerOptimizer.__init__` is used.
- optimizer=replicate_model_fn._TowerOptimizer(
- optimizer_fn, loss_reduction=losses.Reduction.SUM),
- feature_columns=feature_columns,
- n_classes=n_classes,
- model_dir=self._model_dir)
-
- if not mode: # Use the public `replicate_model_fn`.
- model_fn = replicate_model_fn._replicate_model_fn(
- estimator.model_fn, devices=['/gpu:0', '/gpu:1', '/gpu:2'])
- else:
- model_fn = replicate_model_fn._replicate_model_fn_with_mode(
- estimator.model_fn,
- devices=['/gpu:0', '/gpu:1', '/gpu:2'],
- mode=mode)
-
- estimator = estimator_lib.Estimator(
- model_fn=model_fn,
- model_dir=estimator.model_dir,
- config=estimator.config,
- params=estimator.params)
-
- num_steps = 10
- estimator.train(train_input_fn, steps=num_steps)
-
- scores = estimator.evaluate(eval_input_fn)
- self.assertEqual(num_steps, scores[ops_lib.GraphKeys.GLOBAL_STEP])
- self.assertIn('loss', six.iterkeys(scores))
-
- predicted_proba = np.array([
- x[prediction_keys.PredictionKeys.PROBABILITIES]
- for x in estimator.predict(predict_input_fn)
- ])
- self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
-
- feature_spec = feature_column.make_parse_example_spec(feature_columns)
- serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
- feature_spec)
- export_dir = estimator.export_savedmodel(tempfile.mkdtemp(),
- serving_input_receiver_fn)
- self.assertTrue(gfile.Exists(export_dir))
-
-    # Nothing should be left in the collection, so that tower-optimizer state
-    # doesn't get serialized with the graph.
- self.assertFalse(ops_lib.get_default_graph().get_collection_ref(
- replicate_model_fn._TowerOptimizer.COLLECTION_FOR_GRAPH_STATES))
-
- def _as_label(self, data_in_float):
- return np.rint(data_in_float).astype(np.int64)
-
- def tearDown(self):
- if self._model_dir:
- writer_cache.FileWriterCache.clear()
- shutil.rmtree(self._model_dir)
-
-
-class ReplicateModelTest(test_util.TensorFlowTestCase):
-
- def create_model_fn_with_loss_reduction(self, loss_reduction):
-
- def model_fn(mode, features, labels, params):
- c = variable_scope.get_variable(
- 'c',
- initializer=constant_op.constant(10, dtype=dtypes.float64),
- dtype=dtypes.float64)
-
- predictions = math_ops.multiply(features, c)
-
- loss = losses.absolute_difference(
- labels=labels,
- predictions=predictions,
- reduction=losses.Reduction.SUM)
- loss = math_ops.reduce_sum(loss)
-
- metrics = {
- 'accuracy': metrics_lib.accuracy(labels, predictions),
- 'auc': metrics_lib.auc(labels, predictions)
- }
-
- optimizer = replicate_model_fn._TowerOptimizer(
- gradient_descent.GradientDescentOptimizer(params['learning_rate']),
- loss_reduction=loss_reduction)
-
- return model_fn_lib.EstimatorSpec(
- mode=mode,
- loss=loss,
- eval_metric_ops=metrics,
- predictions={'probabilities': predictions},
- train_op=optimizer.minimize(loss))
-
- return model_fn
-
- @property
- def params(self):
- params = {}
- params['learning_rate'] = 1.0
- return params
-
- def test_train(self):
- features = np.array([[1.0], [2.0]])
- labels = np.array([[1.0], [2.0]])
-
- with self.test_session() as session:
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
- devices=['/gpu:0', '/gpu:1'])
- estimator_spec = replicated_model_fn(
- features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
- session.run(variables.global_variables_initializer())
-
- # loss = feature * c - label
- total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
- self.assertEqual(total_loss, session.run(estimator_spec.loss))
-
-      # The derivative of loss = (1*c - 1) + (2*c - 2) with respect to c
-      # is 1 + 2 = 3.
-      # New value of c = 10 - learning_rate * 3 = 7.0.
- session.run(estimator_spec.train_op)
- with variable_scope.variable_scope('', reuse=True):
- c = variable_scope.get_variable('c', dtype=dtypes.float64)
- self.assertEqual(7.0, session.run(c))
-
- def test_train_with_mean_reduction(self):
- features = np.array([[1.0], [2.0]])
- labels = np.array([[1.0], [2.0]])
-
- with self.test_session() as session:
- # Add another trainable variable that doesn't produce a gradient to
- # verify that None gradients are supported.
- _ = variable_scope.get_variable(
- 'another_variable',
- initializer=constant_op.constant(1, dtype=dtypes.float64),
- dtype=dtypes.float64)
-
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.create_model_fn_with_loss_reduction(losses.Reduction.MEAN),
- devices=['/gpu:0', '/gpu:1'])
- estimator_spec = replicated_model_fn(
- features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
- session.run(variables.global_variables_initializer())
-
- # loss = feature * c - label
- total_loss = ((1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)) / 2.0
- self.assertEqual(total_loss, session.run(estimator_spec.loss))
-
-      # The derivative of loss = (1*c - 1)/2 + (2*c - 2)/2 with respect to c
-      # is (1 + 2)/2 = 1.5. It's the same computation as without mean
-      # reduction, but the loss from every tower is scaled by
-      # 1/<number of towers>.
-      # New value of c = 10 - learning_rate * 1.5 = 8.5.
- session.run(estimator_spec.train_op)
- with variable_scope.variable_scope('', reuse=True):
- c = variable_scope.get_variable('c', dtype=dtypes.float64)
- self.assertEqual(8.5, session.run(c))
-
- def test_train_two_steps_collected_gradients_are_reset_between_steps(self):
- with ops_lib.Graph().as_default():
- features = array_ops.placeholder(dtypes.float64)
- labels = array_ops.placeholder(dtypes.float64)
-
- feature_inputs = np.array([[1.0], [2.0]]), np.array([[1.5], [2.5]])
- label_inputs = np.array([[1.0], [2.0]]), np.array([[1.5], [2.5]])
-
- # loss = feature * c - label
- expected_losses = ((1.0 * 10 - 1.0) + (2.0 * 10 - 2.0),
- (1.5 * 7.0 - 1.5) + (2.5 * 7.0 - 2.5))
- # Derivative of the loss is 1.0 + 2.0 for the first step and 1.5 + 2.5
- # for the second.
- expected_c = 10.0 - 3.0, 7.0 - 4.0
-
- with self.test_session() as session, variable_scope.variable_scope(
- '', reuse=variable_scope.AUTO_REUSE):
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
- devices=['/gpu:0', '/gpu:1'])
- estimator_spec = replicated_model_fn(
- features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
- session.run(variables.global_variables_initializer())
-
- for feature_input, label_input, loss, weight in zip(
- feature_inputs, label_inputs, expected_losses, expected_c):
- feeds = {features: feature_input, labels: label_input}
-
- self.assertEqual(loss, session.run(estimator_spec.loss, feeds))
-
- session.run(estimator_spec.train_op, feeds)
- c = variable_scope.get_variable('c', dtype=dtypes.float64)
- self.assertEqual(weight, session.run(c, feeds))
-
- def test_eval(self):
- features = np.array([[0.01], [0.002]])
- labels = np.array([[0.01], [0.02]])
-
- with self.test_session() as session:
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
- devices=['/gpu:0', '/gpu:1'])
- estimator_spec = replicated_model_fn(
- features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
- session.run(variables.local_variables_initializer())
- session.run(variables.global_variables_initializer())
-
- accuracy, a = estimator_spec.eval_metric_ops['accuracy']
- auc, b = estimator_spec.eval_metric_ops['auc']
-
- session.run([a, b])
- accuracy = session.run(accuracy)
- auc = session.run(auc)
-
- # loss[i] = features[i] * 10 - labels[i].
- # Accuracy is 0.0 (no match) in the first tower.
- # Accuracy is 1.0 (match) in the second tower, since the feature
- # times weight "c" happened to be equal to the label.
- total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02))
-
- self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
- self.assertEqual(0, auc)
- self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
-
- def test_eval_with_mean_reduction(self):
- features = np.array([[0.01], [0.002]])
- labels = np.array([[0.01], [0.02]])
-
- with self.test_session() as session:
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.create_model_fn_with_loss_reduction(losses.Reduction.MEAN),
- devices=['/gpu:0', '/gpu:1'])
- estimator_spec = replicated_model_fn(
- features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
- session.run(variables.local_variables_initializer())
- session.run(variables.global_variables_initializer())
-
- accuracy, a = estimator_spec.eval_metric_ops['accuracy']
- auc, b = estimator_spec.eval_metric_ops['auc']
-
- session.run([a, b])
- accuracy = session.run(accuracy)
- auc = session.run(auc)
-
- # loss[i] = features[i] * 10 - labels[i].
- # Accuracy is 0.0 (no match) in the first tower.
- # Accuracy is 1.0 (match) in the second tower, since the feature
- # times weight "c" happened to be equal to the label.
- total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02)) / 2.0
-
- self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
- self.assertEqual(0, auc)
- self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
-
- def test_predict(self):
- features = np.array([[0.01], [0.002]])
- labels = np.array([[0.01], [0.02]])
-
- with self.test_session() as session:
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
- devices=['/gpu:0', '/gpu:1'])
- estimator_spec = replicated_model_fn(
- features, labels, model_fn_lib.ModeKeys.PREDICT, self.params)
- session.run(variables.global_variables_initializer())
-
- self.assertAllClose({
- 'probabilities': np.array([[0.1], [0.02]])
- }, session.run(estimator_spec.predictions))
-
- def test_train_single_tower(self):
- features = np.array([[1.0], [2.0]])
- labels = np.array([[1.0], [2.0]])
-
- with self.test_session() as session:
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
- devices=['/gpu:0'])
- estimator_spec = replicated_model_fn(
- features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
- session.run(variables.global_variables_initializer())
-
- # loss = feature * c - label
- total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
- self.assertEqual(total_loss, session.run(estimator_spec.loss))
-
-      # The derivative of the loss with respect to c is 3.
-      # New value of c = 10 - learning_rate * 3 = 7.0.
- session.run(estimator_spec.train_op)
- with variable_scope.variable_scope('', reuse=True):
- c = variable_scope.get_variable('c', dtype=dtypes.float64)
- self.assertEqual(7.0, session.run(c))
-
- def test_eval_single_tower(self):
- features = np.array([[0.01], [0.002]])
- labels = np.array([[0.01], [0.02]])
-
- with self.test_session() as session:
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
- devices=['/gpu:0'])
- estimator_spec = replicated_model_fn(
- features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
- session.run(variables.local_variables_initializer())
- session.run(variables.global_variables_initializer())
-
- accuracy, a = estimator_spec.eval_metric_ops['accuracy']
- auc, b = estimator_spec.eval_metric_ops['auc']
-
- session.run([a, b])
- accuracy = session.run(accuracy)
- auc = session.run(auc)
-
-      # With a single tower, both examples land on it: the first prediction
-      # (0.1) doesn't match its label (0.01), while the second
-      # (0.002 * 10 = 0.02) does, so accuracy is 0.5.
- total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02))
-
- self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
- self.assertEqual(0, auc)
- self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
-
- def test_predict_single_tower(self):
- features = np.array([[0.01], [0.002]])
- labels = np.array([[0.01], [0.02]])
-
- with self.test_session() as session:
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
- devices=['/gpu:0'])
- estimator_spec = replicated_model_fn(
- features, labels, model_fn_lib.ModeKeys.PREDICT, self.params)
- session.run(variables.global_variables_initializer())
-
- self.assertAllClose({
- 'probabilities': np.array([[0.1], [0.02]])
- }, session.run(estimator_spec.predictions))
-
- def test_batch_size_that_is_not_divisible_by_the_number_of_gpus(self):
- features = np.array([[1.0], [2.0], [3.0]])
- labels = np.array([[1.0], [2.0], [3.0]])
-
- with self.assertRaisesRegexp(
- ValueError, '.*Batch.+size.+needs.+to.+be.+divisible.+by.+GPUs.+'):
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
- devices=['/gpu:0', '/gpu:1'])
- _ = replicated_model_fn(
- features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
-
- def test_unsupported_loss_reduction(self):
- features = np.array([[1.0], [2.0], [3.0]])
- labels = np.array([[1.0], [2.0], [3.0]])
-
- with self.assertRaisesRegexp(ValueError,
- '.+none.+reduction.+is.+specified.+'):
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.create_model_fn_with_loss_reduction(losses.Reduction.NONE),
- devices=['/gpu:0', '/gpu:1', '/gpu:2'])
- _ = replicated_model_fn(
- features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
-
- def test_places_on_gpu_with_upper_case_spelling(self):
- features = np.array([[0.01], [0.002]])
- labels = np.array([[0.01], [0.02]])
-
- with self.test_session():
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
- devices=['/GPU:0'])
- _ = replicated_model_fn(
- features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
-
- with variable_scope.variable_scope('', reuse=True):
- c = variable_scope.get_variable('c', dtype=dtypes.float64)
- self.assertEqual('/device:GPU:0', c.device)
-
- def test_places_on_gpu_with_lower_case_spelling(self):
- features = np.array([[0.01], [0.002]])
- labels = np.array([[0.01], [0.02]])
-
- with self.test_session():
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
- devices=['/gpu:0'])
- _ = replicated_model_fn(
- features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
-
- with variable_scope.variable_scope('', reuse=True):
- c = variable_scope.get_variable('c', dtype=dtypes.float64)
- self.assertEqual('/device:GPU:0', c.device)
-
-
-class ReplicateAcrossASingleDeviceWithoutTowerOptimizer(
- test_util.TensorFlowTestCase):
-
- def model_fn(self, mode, features, labels, params):
- c = variable_scope.get_variable(
- 'c',
- initializer=constant_op.constant(10, dtype=dtypes.float64),
- dtype=dtypes.float64)
-
- predictions = math_ops.multiply(features, c)
-
- loss = losses.absolute_difference(
- labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
- loss = math_ops.reduce_sum(loss)
-
- metrics = {
- 'accuracy': metrics_lib.accuracy(labels, predictions),
- 'auc': metrics_lib.auc(labels, predictions)
- }
-
- optimizer = gradient_descent.GradientDescentOptimizer(
- params['learning_rate'])
-
- return model_fn_lib.EstimatorSpec(
- mode=mode,
- loss=loss,
- eval_metric_ops=metrics,
- predictions={'probabilities': predictions},
- train_op=optimizer.minimize(loss))
-
- @property
- def params(self):
- params = {}
- params['learning_rate'] = 1.0
- return params
-
- def test_train_single_tower(self):
- features = np.array([[1.0], [2.0]])
- labels = np.array([[1.0], [2.0]])
-
- with self.test_session() as session:
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn, devices=['/gpu:0'])
- estimator_spec = replicated_model_fn(
- features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
- session.run(variables.global_variables_initializer())
-
- # loss = feature * c - label
- total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
- self.assertEqual(total_loss, session.run(estimator_spec.loss))
-
-      # The derivative of the loss with respect to c is 3.
-      # New value of c = 10 - learning_rate * 3 = 7.0.
- session.run(estimator_spec.train_op)
- with variable_scope.variable_scope('', reuse=True):
- c = variable_scope.get_variable('c', dtype=dtypes.float64)
- self.assertEqual(7.0, session.run(c))
-
-
-class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase):
-
- def model_fn(self, mode, features, labels, params):
- c = variable_scope.get_variable(
- 'c',
- initializer=constant_op.constant(10, dtype=dtypes.float64),
- dtype=dtypes.float64)
-
- features = features['features']
- predictions = math_ops.multiply(features, c)
-
- loss = losses.absolute_difference(
- labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
- loss = math_ops.reduce_sum(loss)
-
- metrics = {
- 'accuracy': metrics_lib.accuracy(labels, predictions),
- 'auc': metrics_lib.auc(labels, predictions)
- }
-
- optimizer = replicate_model_fn._TowerOptimizer(
- gradient_descent.GradientDescentOptimizer(params['learning_rate']))
-
- return model_fn_lib.EstimatorSpec(
- mode=mode,
- loss=loss,
- eval_metric_ops=metrics,
- predictions={'probabilities': predictions},
- train_op=optimizer.minimize(loss))
-
- @property
- def params(self):
- params = {}
- params['learning_rate'] = 1.0
- return params
-
- def test_train_single_tower(self):
- features = np.array([[1.0], [2.0]])
- labels = np.array([[1.0], [2.0]])
-
- train_input_fn = numpy_io.numpy_input_fn(
- x={'features': features}, y=labels, batch_size=2, shuffle=False)
-
- with self.test_session():
- estimator = estimator_lib.Estimator(
- model_fn=self.model_fn,
- model_dir=tempfile.mkdtemp(),
- params=self.params)
- estimator.train(train_input_fn, steps=1)
-
- self.assertEqual(7.0, estimator.get_variable_value('c'))
-
-
-class MakeSureSyncReplicasOptimizerWorks(test_util.TensorFlowTestCase):
-
- def model_fn(self, mode, features, labels, params):
- c = variable_scope.get_variable(
- 'c',
- initializer=constant_op.constant(10, dtype=dtypes.float64),
- dtype=dtypes.float64)
-
- features = features['features']
- predictions = math_ops.multiply(features, c)
-
- loss = losses.absolute_difference(
- labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
- loss = math_ops.reduce_sum(loss)
-
- metrics = {
- 'accuracy': metrics_lib.accuracy(labels, predictions),
- 'auc': metrics_lib.auc(labels, predictions)
- }
-
- optimizer = gradient_descent.GradientDescentOptimizer(
- params['learning_rate'])
- optimizer = training.SyncReplicasOptimizer(
- optimizer, replicas_to_aggregate=1)
- sync_hook = optimizer.make_session_run_hook(True)
- optimizer = replicate_model_fn._TowerOptimizer(
- optimizer, loss_reduction=losses.Reduction.SUM)
-
- return model_fn_lib.EstimatorSpec(
- mode=mode,
- loss=loss,
- eval_metric_ops=metrics,
- training_hooks=[sync_hook],
- predictions={'probabilities': predictions},
- train_op=optimizer.minimize(
- loss, global_step=training.get_global_step()))
-
- @property
- def params(self):
- params = {}
- params['learning_rate'] = 1.0
- return params
-
- def test_train_multiple_towers(self):
- features = np.array([[1.0], [2.0]])
- labels = np.array([[1.0], [2.0]])
-
- train_input_fn = numpy_io.numpy_input_fn(
- x={'features': features}, y=labels, batch_size=2, shuffle=False)
-
- model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn,
- devices=['/gpu:0', '/gpu:1'])
-
- estimator = estimator_lib.Estimator(
- model_fn=model_fn, model_dir=tempfile.mkdtemp(), params=self.params)
- estimator.train(train_input_fn, steps=1)
-
- self.assertEqual(7.0, estimator.get_variable_value('c'))
-
-
-class ReplicateWithTwoOptimizersTest(test_util.TensorFlowTestCase):
-
- def model_fn(self, mode, features, labels, params):
- c = variable_scope.get_variable(
- 'c',
- initializer=constant_op.constant(10, dtype=dtypes.float64),
- dtype=dtypes.float64)
-
- side_effects = variable_scope.get_variable(
- 'side_effects',
- initializer=constant_op.constant(0, dtype=dtypes.float64),
- dtype=dtypes.float64,
- use_resource=True,
- trainable=False)
-
- predictions = math_ops.multiply(features, c)
-
- loss = losses.absolute_difference(
- labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
- loss = math_ops.reduce_sum(loss)
-
- metrics = {
- 'accuracy': metrics_lib.accuracy(labels, predictions),
- 'auc': metrics_lib.auc(labels, predictions)
- }
-
- first_optimizer = replicate_model_fn._TowerOptimizer(
- gradient_descent.GradientDescentOptimizer(1.0),
- loss_reduction=losses.Reduction.SUM)
- second_optimizer = replicate_model_fn._TowerOptimizer(
- adam.AdamOptimizer(1.0), loss_reduction=losses.Reduction.SUM)
-
- with ops_lib.control_dependencies([side_effects.assign_add(1.0)]):
- first_grads_and_vars = first_optimizer.compute_gradients(loss)
-
- train_op = control_flow_ops.group(
- [first_optimizer.apply_gradients(first_grads_and_vars),
- second_optimizer.minimize(loss)])
-
- return model_fn_lib.EstimatorSpec(
- mode=mode,
- loss=loss,
- eval_metric_ops=metrics,
- predictions={'probabilities': predictions},
- train_op=train_op)
-
- def test_train(self):
- features = np.array([[1.0], [2.0]])
- labels = np.array([[1.0], [2.0]])
-
- with self.test_session() as session:
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn,
- devices=['/gpu:0', '/gpu:1'])
- estimator_spec = replicated_model_fn(features, labels,
- model_fn_lib.ModeKeys.TRAIN, {})
- session.run(variables.global_variables_initializer())
-
- # loss = feature * c - label
- total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
- self.assertEqual(total_loss, session.run(estimator_spec.loss))
-
-      # The derivative of the loss with respect to c is 3, so gradient
-      # descent moves c from 10 to 10 - learning_rate * 3 = 7.0. Adam then
-      # subtracts another ~1, since its first-step update is approximately
-      # the learning rate times the sign of the gradient.
- session.run(estimator_spec.train_op)
- with variable_scope.variable_scope('', reuse=True):
- c = variable_scope.get_variable('c', dtype=dtypes.float64)
- self.assertNear(6.0, session.run(c), 0.000001)
-
- side_effects = variable_scope.get_variable(
- 'side_effects', dtype=dtypes.float64)
- self.assertNear(2.0, session.run(side_effects), 0.000001)
-
-
-class ReplicateWithTwoLossesAndOneOptimizer(test_util.TensorFlowTestCase):
-
- def setUp(self):
- self._should_skip_optimizer = False
- self._towers_left_before_skipping_optimizer = -1
-
-  def incorrectly_skip_optimizer_for_tower(self, tower_number):
-    """Makes `model_fn` skip its second optimizer for the given tower."""
-    self._should_skip_optimizer = True
-    self._towers_left_before_skipping_optimizer = tower_number
-
- def should_skip_optimizer(self):
- if not self._should_skip_optimizer:
- return False
- if self._towers_left_before_skipping_optimizer == 0:
- return True
- else:
- self._towers_left_before_skipping_optimizer -= 1
- return False
-
- def model_fn(self, mode, features, labels, params):
- c = variable_scope.get_variable(
- 'c',
- initializer=constant_op.constant(10, dtype=dtypes.float64),
- dtype=dtypes.float64)
- d = variable_scope.get_variable(
- 'd',
- initializer=constant_op.constant(2, dtype=dtypes.float64),
- dtype=dtypes.float64)
-
- predictions = math_ops.multiply(features, c)
-
- loss = losses.absolute_difference(
- labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
- loss = math_ops.reduce_sum(loss)
-
- another_predictions = math_ops.multiply(features, d)
- another_loss = losses.absolute_difference(
- labels=labels,
- predictions=another_predictions,
- reduction=losses.Reduction.SUM)
- another_loss = math_ops.reduce_sum(another_loss)
-
- total_loss = math_ops.add(loss, another_loss)
-
- metrics = {
- 'accuracy': metrics_lib.accuracy(labels, predictions),
- 'auc': metrics_lib.auc(labels, predictions)
- }
-
- train_ops = []
-
- optimizer = replicate_model_fn._TowerOptimizer(
- gradient_descent.GradientDescentOptimizer(1.0),
- loss_reduction=losses.Reduction.SUM)
- train_ops.append(optimizer.minimize(loss, var_list=[c]))
- if not self.should_skip_optimizer():
- another_optimizer = replicate_model_fn._TowerOptimizer(
- gradient_descent.GradientDescentOptimizer(1.0),
- loss_reduction=losses.Reduction.SUM)
- train_ops.append(another_optimizer.minimize(another_loss, var_list=[d]))
-
- train_op = control_flow_ops.group(train_ops)
- return model_fn_lib.EstimatorSpec(
- mode=mode,
- loss=total_loss,
- eval_metric_ops=metrics,
- predictions={'probabilities': predictions},
- train_op=train_op)
-
- def test_train(self):
- features = np.array([[1.0], [2.0]])
- labels = np.array([[1.0], [2.0]])
-
- with ops_lib.Graph().as_default(), self.test_session() as session:
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn,
- devices=['/gpu:0', '/gpu:1'])
- estimator_spec = replicated_model_fn(features, labels,
- model_fn_lib.ModeKeys.TRAIN, {})
- session.run(variables.global_variables_initializer())
-
- # For each tower, loss = (feature * c - label) + (feature * d - label).
- total_loss = (1.0 * 10 - 1.0 + 1.0 * 2.0 - 1.0) + (
- 2.0 * 10 - 2.0 + 2.0 * 2.0 - 2.0)
- self.assertEqual(total_loss, session.run(estimator_spec.loss))
-
- session.run(estimator_spec.train_op)
-
-      # The derivative of the loss with respect to either c or d is 3.
-      # New value of c = 10 - learning_rate * 3 = 7.0.
-      # New value of d = 2 - learning_rate * 3 = -1.0.
- with variable_scope.variable_scope('', reuse=True):
- c = variable_scope.get_variable('c', dtype=dtypes.float64)
- self.assertNear(7.0, session.run(c), 0.000001)
- d = variable_scope.get_variable('d', dtype=dtypes.float64)
- self.assertNear(-1.0, session.run(d), 0.000001)
-
- def test_different_optimizer_calls_within_towers(self):
- self.incorrectly_skip_optimizer_for_tower(1)
-
- features = np.array([[1.0], [2.0]])
- labels = np.array([[1.0], [2.0]])
-
- with self.test_session(), ops_lib.Graph().as_default():
- with self.assertRaisesRegexp(
- ValueError, '.+was.+supposed.+to.+make.+same.+optimizer.+calls.+'):
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn, devices=['/gpu:0', '/gpu:1'])
- _ = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN,
- {})
-
-
-class FailToWrapOptimizerInTheModelFn(test_util.TensorFlowTestCase):
-
- def model_fn(self, mode, features, labels, params):
- c = variable_scope.get_variable(
- 'c',
- initializer=constant_op.constant(10, dtype=dtypes.float64),
- dtype=dtypes.float64)
-
- predictions = math_ops.multiply(features, c)
-
- loss = losses.absolute_difference(
- labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
- loss = math_ops.reduce_sum(loss)
-
- metrics = {
- 'accuracy': metrics_lib.accuracy(labels, predictions),
- 'auc': metrics_lib.auc(labels, predictions)
- }
-
- optimizer = gradient_descent.GradientDescentOptimizer(1.0)
- train_op = optimizer.minimize(loss)
-
- return model_fn_lib.EstimatorSpec(
- mode=mode,
- loss=loss,
- eval_metric_ops=metrics,
- predictions={'probabilities': predictions},
- train_op=train_op)
-
- def test_train(self):
- features = np.array([[1.0], [2.0]])
- labels = np.array([[1.0], [2.0]])
-
- with self.test_session():
- with self.assertRaisesRegexp(ValueError,
- 'Please.+wrap.+with.+TowerOptimizer'):
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn, devices=['/gpu:0', '/gpu:1'])
- _ = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN,
- {})
-
-
-class GetLossTowersTest(test_util.TensorFlowTestCase):
-
- def create_model_fn_with_loss_reduction(self, loss_reduction):
-
- def model_fn(mode, features, labels, params):
- del params
- c = variable_scope.get_variable(
- 'c',
- initializer=constant_op.constant(0.25, dtype=dtypes.float64),
- dtype=dtypes.float64)
-
- predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c)
- labels = np.array([0.1, 0.2, 0.3, labels[0]])
-
- loss = losses.absolute_difference(
- labels=labels,
- predictions=predictions,
- reduction=losses.Reduction.SUM)
-
- optimizer = replicate_model_fn._TowerOptimizer(
- gradient_descent.GradientDescentOptimizer(1.0),
- loss_reduction)
-
- return model_fn_lib.EstimatorSpec(
- mode=mode,
- loss=math_ops.reduce_sum(loss),
- train_op=optimizer.minimize(loss))
-
- return model_fn
-
- def test_gradients_are_computed(self):
- with self.test_session() as session:
- tower_specs = replicate_model_fn._get_loss_towers(
- self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
- mode=None,
- features=[[0.6], [1.6]],
- labels=[[0.6], [0.6]],
- params=None,
- config=None,
- devices=['/gpu:0', '/gpu:1'],
- local_ps_devices=['/gpu:0'],
- name_scope_pattern='test_tower_{}')
- session.run(variables.global_variables_initializer())
-
- self.assertEqual(len(tower_specs), 2)
-
- self.assertEqual('/device:GPU:0', tower_specs[0].loss.device)
- self.assertEqual('Sum:0', tower_specs[0].loss.name)
- self.assertEqual(1.0, session.run(tower_specs[0].loss))
-
- self.assertEqual('/device:GPU:1', tower_specs[1].loss.device)
- self.assertEqual('test_tower_1/Sum:0', tower_specs[1].loss.name)
-      # The second tower's input feature is 1.0 bigger (1.6 vs 0.6), so its
-      # loss is 1.0 bigger.
- self.assertEqual(2.0, session.run(tower_specs[1].loss))
-
- self.assertEqual(1, len(variables.global_variables()))
- self.assertEqual(1, len(variables.trainable_variables()))
-
- with variable_scope.variable_scope('', reuse=True):
- c = variable_scope.get_variable('c', dtype=dtypes.float64)
- self.assertEqual(0.25, session.run(c))
-
- def test_gradients_are_computed_with_mean_reduction(self):
- with self.test_session() as session:
- tower_specs = replicate_model_fn._get_loss_towers(
- self.create_model_fn_with_loss_reduction(losses.Reduction.MEAN),
- mode=model_fn_lib.ModeKeys.EVAL,
- features=[[0.6], [1.6]],
- labels=[[0.6], [0.6]],
- params=None,
- config=None,
- devices=['/gpu:0', '/gpu:1'],
- local_ps_devices=['/gpu:0'],
- name_scope_pattern='test_tower_{}')
- session.run(variables.global_variables_initializer())
-
- self.assertEqual(len(tower_specs), 2)
-
- self.assertEqual('/device:GPU:0', tower_specs[0].loss.device)
- self.assertEqual('averaged_loss:0', tower_specs[0].loss.name)
- self.assertEqual(0.5, session.run(tower_specs[0].loss))
-
- self.assertEqual('/device:GPU:1', tower_specs[1].loss.device)
- self.assertEqual('test_tower_1/averaged_loss:0', tower_specs[1].loss.name)
-      # The second tower's input feature is 1.0 bigger (1.6 vs 0.6); after
-      # mean reduction its loss is 0.5 bigger (1.0 vs 0.5).
- self.assertEqual(1.0, session.run(tower_specs[1].loss))
-
- self.assertEqual(1, len(variables.global_variables()))
- self.assertEqual(1, len(variables.trainable_variables()))
-
- with variable_scope.variable_scope('', reuse=True):
- c = variable_scope.get_variable('c', dtype=dtypes.float64)
- self.assertEqual(0.25, session.run(c))
-
- def test_variables_are_round_robined_correctly(self):
- """Test that creates multiple variables and tests round-robin placement."""
-
- def model_fn(mode, features, labels, params):
- del params
- for variable_name in ['a', 'b', 'c', 'd']:
- c = variable_scope.get_variable(
- variable_name,
- initializer=constant_op.constant(0.25, dtype=dtypes.float64),
- dtype=dtypes.float64)
-
- predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c)
- labels = np.array([0.1, 0.2, 0.3, labels[0]])
- loss = losses.absolute_difference(
- labels=labels,
- predictions=predictions,
- reduction=losses.Reduction.SUM)
- return model_fn_lib.EstimatorSpec(
- mode=mode, loss=math_ops.reduce_sum(loss))
-
- with self.test_session() as session:
- tower_specs = replicate_model_fn._get_loss_towers(
- model_fn,
- mode=None,
- features=[[0.6], [1.6], [2.6]],
- labels=[[0.6], [0.6], [2.6]],
- params=None,
- config=None,
- devices=['/gpu:0', '/gpu:1', '/gpu:3'],
- local_ps_devices=['/gpu:0', '/gpu:1', '/gpu:3'],
- name_scope_pattern='test_tower_{}')
- session.run(variables.global_variables_initializer())
-
- self.assertEqual(len(tower_specs), 3)
- self.assertEqual('/device:GPU:0', tower_specs[0].loss.device)
- self.assertEqual('/device:GPU:1', tower_specs[1].loss.device)
- self.assertEqual('/device:GPU:3', tower_specs[2].loss.device)
-
- with variable_scope.variable_scope('', reuse=True):
- a = variable_scope.get_variable('a', dtype=dtypes.float64)
- self.assertEqual('/device:GPU:0', a.device)
- b = variable_scope.get_variable('b', dtype=dtypes.float64)
- self.assertEqual('/device:GPU:1', b.device)
- c = variable_scope.get_variable('c', dtype=dtypes.float64)
- self.assertEqual('/device:GPU:3', c.device)
- d = variable_scope.get_variable('d', dtype=dtypes.float64)
- self.assertEqual('/device:GPU:0', d.device)
-
-
-class SplitBatchTest(test_util.TensorFlowTestCase):
-
- def evaluate_shards(self, first_list, second_list):
- evaluate_items = lambda x: x.eval()
- return list(map(evaluate_items, first_list)), list(
- map(evaluate_items, second_list))
-
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
- def test_simple_half_split(self):
- with self.test_session():
- features = [0.0, 1.0, 2.0, 3.0]
- labels = [10.0, 11.0, 12.0, 13.0]
- feature_shards, label_shards = replicate_model_fn._split_batch(
- features, labels, 2, device='/gpu:0')
-
- feature_shards, label_shards = self.evaluate_shards(
- feature_shards, label_shards)
-
- self.assertAllEqual([[0.0, 1.0], [2.0, 3.0]], feature_shards)
- self.assertAllEqual([[10.0, 11.0], [12.0, 13.0]], label_shards)
-
- def test_to_each_their_own(self):
- with self.test_session():
- features = [0.0, 1.0, 2.0, 3.0]
- labels = [10.0, 11.0, 12.0, 13.0]
- feature_shards, label_shards = replicate_model_fn._split_batch(
- features, labels, 4, device='/gpu:0')
-
- feature_shards, label_shards = self.evaluate_shards(
- feature_shards, label_shards)
-
- self.assertAllEqual([[0.0], [1.0], [2.0], [3.0]], feature_shards)
- self.assertAllEqual([[10.0], [11.0], [12.0], [13.0]], label_shards)
-
- def test_one_batch(self):
- with self.test_session():
- features = [0.0, 1.0, 2.0, 3.0]
- labels = [10.0, 11.0, 12.0, 13.0]
- feature_shards, label_shards = replicate_model_fn._split_batch(
- features, labels, 1, device='/gpu:0')
-
- feature_shards, label_shards = self.evaluate_shards(
- feature_shards, label_shards)
-
- self.assertAllEqual([[0.0, 1.0, 2.0, 3.0]], feature_shards)
- self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards)
-
- def test_half_split_in_dictionary(self):
- with self.test_session():
- features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
- labels = [10.0, 11.0, 12.0, 13.0]
-
- feature_shards, label_shards = replicate_model_fn._split_batch(
- features, labels, 2, device='/gpu:0')
-
- self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval())
- self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval())
- self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval())
- self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval())
- self.assertAllEqual([10.0, 11.0], label_shards[0].eval())
- self.assertAllEqual([12.0, 13.0], label_shards[1].eval())
-
- def test_sparse_tensor_can_be_split_unevenly(self):
- with self.test_session():
- features = {
- 'x':
- sparse_tensor.SparseTensor(
- indices=[[0, 0], [1, 2], [2, 2]],
- values=[1.0, 2.0, 3.0],
- dense_shape=[3, 4])
- }
- labels = np.array([[1.0], [2.0]])
-
- feature_shards, label_shards = replicate_model_fn._split_batch(
- features, labels, 2, device='/gpu:0')
-
- self.assertSparseValuesEqual(
- sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 2]], values=[1., 2.], dense_shape=[2, 4]),
- feature_shards[0]['x'].eval())
- self.assertSparseValuesEqual(
- sparse_tensor.SparseTensorValue(
- indices=[[0, 2]], values=[3.], dense_shape=[1, 4]),
- feature_shards[1]['x'].eval())
- self.assertAllEqual([[1.0]], label_shards[0].eval())
- self.assertAllEqual([[2.0]], label_shards[1].eval())
-
- def test_sparse_tensor_can_be_split_unevenly_repeated_row(self):
- with self.test_session():
- features = {
- 'x':
- sparse_tensor.SparseTensor(
- indices=[[0, 0], [1, 0], [1, 1]],
- values=[1.0, 2.0, 3.0],
- dense_shape=[3, 4])
- }
- labels = np.array([[1.0], [2.0]])
-
- feature_shards, label_shards = replicate_model_fn._split_batch(
- features, labels, 2, device='/gpu:0')
-
- self.assertSparseValuesEqual(
- sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 0], [1, 1]],
- values=[1., 2., 3.],
- dense_shape=[2, 4]), feature_shards[0]['x'].eval())
-
- second_batch = feature_shards[1]['x'].eval()
- self.assertFalse(len(second_batch.indices))
- self.assertFalse(len(second_batch.values))
- self.assertAllEqual([1, 4], second_batch.dense_shape)
- self.assertAllEqual([[1.0]], label_shards[0].eval())
- self.assertAllEqual([[2.0]], label_shards[1].eval())
-
- def test_one_batch_in_dictionary(self):
- with self.test_session() as session: # pylint: disable=unused-variable
- features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
- labels = [10.0, 11.0, 12.0, 13.0]
-
- feature_shards, label_shards = replicate_model_fn._split_batch(
- features, labels, 1, device='/gpu:0')
-
- self.assertAllEqual([0.0, 1.0, 2.0, 3.0],
- feature_shards[0]['first'].eval())
- self.assertAllEqual([4.0, 5.0, 6.0, 7.0],
- feature_shards[0]['second'].eval())
- self.assertAllEqual([10.0, 11.0, 12.0, 13.0], label_shards[0].eval())
-
- def test_feature_and_label_dictionaries(self):
- with self.test_session() as session: # pylint: disable=unused-variable
- features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
- labels = {'first': [10.0, 11.0], 'second': [12.0, 13.0]}
-
- feature_shards, label_shards = replicate_model_fn._split_batch(
- features, labels, 2, device='/gpu:0')
-
- self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval())
- self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval())
- self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval())
- self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval())
- self.assertAllEqual([10.0], label_shards[0]['first'].eval())
- self.assertAllEqual([12.0], label_shards[0]['second'].eval())
-      self.assertAllEqual([11.0], label_shards[1]['first'].eval())
- self.assertAllEqual([13.0], label_shards[1]['second'].eval())
-
-
-class TrainSpecTest(test_util.TensorFlowTestCase):
-
- expected_predictions = {}
-
- def create_estimator_spec(self, loss):
- return model_fn_lib.EstimatorSpec(
- mode=model_fn_lib.ModeKeys.TRAIN,
- loss=loss,
- train_op=loss, # Not used; currently required.
- predictions=self.expected_predictions)
-
- def create_constant_loss(self, loss_value):
- return constant_op.constant(loss_value, dtype=dtypes.float64)
-
- def test_example(self):
- with self.test_session() as session:
- tower_losses = list(map(self.create_constant_loss, [2, 4, 6]))
- tower_specs = list(map(self.create_estimator_spec, tower_losses))
-
- expected_train_op = tower_losses[1]
-
- estimator_spec = replicate_model_fn._train_spec(
- tower_specs, expected_train_op, aggregation_device='/gpu:0')
-
- self.assertEqual(expected_train_op, estimator_spec.train_op)
- self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
- self.assertEqual(self.expected_predictions, estimator_spec.predictions)
-
-
-class EvalSpecTest(test_util.TensorFlowTestCase):
-
- def create_estimator_spec(self, loss, metrics):
- return model_fn_lib.EstimatorSpec(
- mode=model_fn_lib.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics)
-
- def create_constant_loss(self, loss_value):
- return constant_op.constant(loss_value, dtype=dtypes.float64)
-
- def create_eval_metrics(self, noise):
- predictions = np.array([0.1, 0.2, 0.3, 0.6 + noise])
- labels = np.array([0.1, 0.2, 0.3, 0.6])
-
- metrics = {
- 'accuracy': metrics_lib.accuracy(labels, predictions),
- 'auc': metrics_lib.auc(labels, predictions)
- }
- return metrics
-
- def test_example(self):
- with self.test_session() as session:
- tower_losses = map(self.create_constant_loss, [2, 4, 6])
- tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3])
- tower_specs = [
- self.create_estimator_spec(l, m)
- for l, m in zip(tower_losses, tower_metrics)
- ]
- session.run(variables.local_variables_initializer())
-
- estimator_spec = replicate_model_fn._eval_spec(
- tower_specs, aggregation_device='/device:GPU:0')
-
- accuracy, a = estimator_spec.eval_metric_ops['accuracy']
- auc, b = estimator_spec.eval_metric_ops['auc']
-
- self.assertEqual('/device:CPU:0', accuracy.device)
- self.assertEqual('/device:CPU:0', auc.device)
-
- session.run([a, b])
- accuracy, auc = session.run([accuracy, auc])
-
- self.assertNear((12 - 2) / 12, accuracy, 0.01)
- self.assertEqual(0, auc)
- self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
-
- def test_handles_single_tower(self):
- with self.test_session() as session:
- tower_losses = map(self.create_constant_loss, [5])
- tower_metrics = map(self.create_eval_metrics, [0.2])
- tower_specs = [
- self.create_estimator_spec(l, m)
- for l, m in zip(tower_losses, tower_metrics)
- ]
- session.run(variables.local_variables_initializer())
-
- estimator_spec = replicate_model_fn._eval_spec(
- tower_specs, aggregation_device='/device:GPU:0')
-
- accuracy, a = estimator_spec.eval_metric_ops['accuracy']
- auc, b = estimator_spec.eval_metric_ops['auc']
-
- self.assertEqual('/device:CPU:0', accuracy.device)
- self.assertEqual('/device:CPU:0', auc.device)
-
- session.run([a, b])
- accuracy = session.run(accuracy)
- auc = session.run(auc)
-
- self.assertNear((4 - 1) / 4, accuracy, 0.01)
- self.assertEqual(0, auc)
- self.assertEqual(5, session.run(estimator_spec.loss))
-
-
-class PredictSpecTest(test_util.TensorFlowTestCase):
-
- def model_fn(self, mode, features, labels, params):
- c = variable_scope.get_variable(
- 'c',
- initializer=constant_op.constant(0.25, dtype=dtypes.float64),
- dtype=dtypes.float64)
-
- predictions = math_ops.add(np.array([features[0], features[0]]), c)
-
- return model_fn_lib.EstimatorSpec(
- mode=model_fn_lib.ModeKeys.PREDICT,
- predictions={
- 'probabilities': predictions
- })
-
- def test_example(self):
- with self.test_session() as session:
- tower_specs = replicate_model_fn._get_loss_towers(
- self.model_fn,
- mode=None,
- features=[[0.1], [0.2]],
- labels=[[], []],
- params=None,
- config=None,
- devices=['/gpu:0', '/gpu:1'],
- local_ps_devices=['/gpu:0'],
- )
- session.run(variables.global_variables_initializer())
-
- estimator_spec = replicate_model_fn._predict_spec(
- tower_specs, aggregation_device='/gpu:0')
-
- self.assertEqual('/device:GPU:0',
- estimator_spec.predictions['probabilities'].device)
- self.assertAllClose({
- 'probabilities': np.array([0.35, 0.35, 0.45, 0.45])
- }, session.run(estimator_spec.predictions))
-
-
-class ReduceMetricVariablesTest(test_util.TensorFlowTestCase):
-
- def create_metric_variable(self, initial_value, name):
- return variable_scope.variable(
- initial_value,
- trainable=False,
- collections=[ops_lib.GraphKeys.METRIC_VARIABLES],
- validate_shape=True,
- name=name)
-
- def create_tower_metrics(self, tower_id):
- with variable_scope.variable_scope('', reuse=(tower_id != 0)):
- self.create_metric_variable(1.3 * (tower_id + 1), 'total')
- self.create_metric_variable(2.3 * (tower_id + 1), 'count')
- self.create_metric_variable(
- np.array([3.3, 3.5, 3.7]) * (tower_id + 1), 'total')
-
- def test_example(self):
- with self.test_session() as session:
- for tower_id in range(3):
- self.create_tower_metrics(tower_id)
-
- session.run(
- variables.variables_initializer(
- ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
-
- session.run(
- replicate_model_fn._reduce_metric_variables(number_of_towers=3))
-
- # 1st tower = 1.3, 2.3, [3.3, 3.5, 3.7]
- # 2nd tower = 2.6, 4.6, [6.6, 7.0, 7.4]
- # 3rd tower = 3.9, 6.9, [9.9, 10.5, 11.1]
- # Reduced = 7.8, 13.8, [19.8, 21.0, 22.2]
- # Towers are accumulated in the first tower.
- local_metrics = session.run(
- ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
-
- self.assertNear(7.8, local_metrics[0], 0.01)
- self.assertNear(13.8, local_metrics[1], 0.01)
-      self.assertAllClose([19.8, 21.0, 22.2], local_metrics[2], 0.01)
- self.assertNear(0.0, local_metrics[3], 0.01)
- self.assertNear(0.0, local_metrics[4], 0.01)
- self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
- self.assertNear(0.0, local_metrics[6], 0.01)
- self.assertNear(0.0, local_metrics[7], 0.01)
- self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
-
- def test_reduce_is_idempotent(self):
- with self.test_session() as session:
- for tower_id in range(3):
- self.create_tower_metrics(tower_id)
-
- session.run(
- variables.variables_initializer(
- ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
-
- for _ in range(20):
- session.run(
- replicate_model_fn._reduce_metric_variables(number_of_towers=3))
-
- local_metrics = session.run(
- ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
-
- self.assertNear(7.8, local_metrics[0], 0.01)
- self.assertNear(13.8, local_metrics[1], 0.01)
-      self.assertAllClose([19.8, 21.0, 22.2], local_metrics[2], 0.01)
- self.assertNear(0.0, local_metrics[3], 0.01)
- self.assertNear(0.0, local_metrics[4], 0.01)
- self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
- self.assertNear(0.0, local_metrics[6], 0.01)
- self.assertNear(0.0, local_metrics[7], 0.01)
- self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
-
- def test_handles_single_tower(self):
- with self.test_session() as session:
- self.create_tower_metrics(0)
- session.run(
- variables.variables_initializer(
- ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
-
- session.run(
- replicate_model_fn._reduce_metric_variables(number_of_towers=1))
-
- local_metrics = session.run(
- ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
-
- self.assertNear(1.3, local_metrics[0], 0.01)
- self.assertNear(2.3, local_metrics[1], 0.01)
- self.assertAllClose([3.3, 3.5, 3.7], local_metrics[2], 0.01)
-
- def test_doesnt_accept_uneven_number_of_variables(self):
- with self.test_session() as session:
- for tower_id in range(3):
- self.create_tower_metrics(tower_id)
- self.create_metric_variable(-1.0, 'oddball')
-
- session.run(
- variables.variables_initializer(
- ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
-
- with self.assertRaisesRegexp(
- ValueError, '.+Expected.+local.+variables.+but.+got.+instead.+'):
- session.run(
- replicate_model_fn._reduce_metric_variables(number_of_towers=3))
-
-
-class MergeExportOutputsTest(test_util.TensorFlowTestCase):
-
- def model_fn(self, mode, features, labels, params):
- c = variable_scope.get_variable(
- 'c',
- initializer=constant_op.constant(10, dtype=dtypes.float64),
- dtype=dtypes.float64)
-
- predictions = {'probabilities': math_ops.multiply(features, c)}
- loss = losses.absolute_difference(
- labels=labels,
- predictions=predictions['probabilities'],
- reduction=losses.Reduction.SUM)
-
- metrics = {
- 'accuracy': metrics_lib.accuracy(labels, predictions['probabilities']),
- 'auc': metrics_lib.auc(labels, predictions['probabilities'])
- }
- tensor_string_repr = str(features)
- classes = constant_op.constant(
- re.search('(split_inputs/split:[0-9])', tensor_string_repr).group(1),
- dtype=dtypes.string)
-
- export_outputs = {
- signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
- export_output.PredictOutput(predictions),
- 'classification_output':
- export_output.ClassificationOutput(predictions['probabilities'],
- classes),
- 'classification_scores':
- export_output.ClassificationOutput(
- scores=predictions['probabilities']),
- 'classification_classes':
- export_output.ClassificationOutput(classes=classes),
- 'regression_output':
- export_output.RegressionOutput(predictions['probabilities']),
- }
-
- return model_fn_lib.EstimatorSpec(
- mode=mode,
- loss=math_ops.reduce_sum(loss),
- eval_metric_ops=metrics,
- predictions=predictions,
- export_outputs=export_outputs)
-
- def replicate_estimator_spec(self, session):
- features = np.array([0.01, 0.002])
- labels = np.array([0.01, 0.02])
-
- replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn, devices=['/gpu:0', '/gpu:1'])
- estimator_spec = replicated_model_fn(features, labels,
- model_fn_lib.ModeKeys.PREDICT, {})
- session.run(variables.global_variables_initializer())
- return estimator_spec
-
- def test_merge_predict_output(self):
- with self.test_session() as session:
- estimator_spec = self.replicate_estimator_spec(session)
- self.assertAllClose(
- {
- 'probabilities': np.array([0.1, 0.02])
- },
- session.run(estimator_spec.export_outputs[
- signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs))
-
- def test_merge_classification_output_scores_classes(self):
- with self.test_session() as session:
- estimator_spec = self.replicate_estimator_spec(session)
- self.assertAllClose(
- [0.1, 0.02],
- session.run(
- estimator_spec.export_outputs['classification_output'].scores))
- self.assertAllEqual(
- [b'split_inputs/split:0', b'split_inputs/split:1'],
- session.run(
- estimator_spec.export_outputs['classification_output'].classes))
-
- def test_merge_classification_output_scores(self):
- with self.test_session() as session:
- estimator_spec = self.replicate_estimator_spec(session)
- self.assertAllClose(
- [0.1, 0.02],
- session.run(
- estimator_spec.export_outputs['classification_scores'].scores))
- self.assertEqual(
- None, estimator_spec.export_outputs['classification_scores'].classes)
-
- def test_merge_classification_output_classes(self):
- with self.test_session() as session:
- estimator_spec = self.replicate_estimator_spec(session)
- self.assertAllEqual(
- [b'split_inputs/split:0', b'split_inputs/split:1'],
- session.run(
- estimator_spec.export_outputs['classification_classes'].classes))
- self.assertEqual(
- None, estimator_spec.export_outputs['classification_classes'].scores)
-
- def test_merge_regression_output(self):
- with self.test_session() as session:
- estimator_spec = self.replicate_estimator_spec(session)
- self.assertAllClose(
- [0.1, 0.02],
- session.run(estimator_spec.export_outputs['regression_output'].value))
-
-
-class GetLocalDevicesTest(test_util.TensorFlowTestCase):
-
- def test_there_is_at_least_a_cpu(self):
- self.assertTrue(replicate_model_fn._get_local_devices('CPU'))
-
- def test_there_is_no_xpu(self):
- self.assertFalse(
- replicate_model_fn._get_local_devices('XPU')) # XPU doesn't exist.
-
- def test_whether_there_is_a_gpu(self):
- if test.is_gpu_available():
- self.assertTrue(len(replicate_model_fn._get_local_devices('GPU')))
-
-
-class LocalDeviceSetterTest(test_util.TensorFlowTestCase):
-
- def test_vars_are_on_ps_but_ops_are_on_workers(self):
- ps_devices = ['/device:GPU:3']
- round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices))
-
- local_device_setter = replicate_model_fn._local_device_setter(
- ps_devices=ps_devices,
- ps_strategy=round_robin,
- worker_device='/device:GPU:2')
-
- with ops_lib.device(local_device_setter):
- a = variables.Variable(0.01)
- self.assertEqual('/device:GPU:3', a.device)
-
- b = variables.Variable(0.02)
- self.assertEqual('/device:GPU:3', b.device)
-
- c = variables.Variable(0.03)
- self.assertEqual('/device:GPU:3', c.device)
-
- a_op = array_ops.concat(a, axis=0)
- self.assertEqual('/device:GPU:2', a_op.device)
-
- b_op = array_ops.concat(b, axis=0)
- self.assertEqual('/device:GPU:2', b_op.device)
-
- def test_round_robin_placement(self):
- ps_devices = [
- '/device:GPU:0', '/device:GPU:1', '/device:GPU:3', '/device:GPU:4'
- ]
- round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices))
-
- local_device_setter = replicate_model_fn._local_device_setter(
- ps_devices=ps_devices,
- ps_strategy=round_robin,
- worker_device='/device:GPU:2')
-
- with ops_lib.device(local_device_setter):
- a = variables.Variable(0.01)
- self.assertEqual('/device:GPU:0', a.device)
-
- b = variables.Variable(0.02)
- self.assertEqual('/device:GPU:1', b.device)
-
- c = variables.Variable(0.03)
- self.assertEqual('/device:GPU:3', c.device)
-
- a_op = array_ops.concat(a, axis=0)
- self.assertEqual('/device:GPU:2', a_op.device)
-
- b_op = array_ops.concat(b, axis=0)
- self.assertEqual('/device:GPU:2', b_op.device)
-
- c = variables.Variable(0.03)
- self.assertEqual('/device:GPU:4', c.device)
-
- d = variables.Variable(0.03)
- self.assertEqual('/device:GPU:0', d.device)
-
- c_op = array_ops.concat(c, axis=0)
- self.assertEqual('/device:GPU:2', c_op.device)
-
-
-class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
-
- def test_vectors(self):
- with self.test_session() as session:
- total = replicate_model_fn._compute_sum_on_device(
- [1.0, 2.0, 3.0, 4.0], device='/device:GPU:0', name='test_sum')
-
- self.assertEqual('/device:GPU:0', total.device)
- self.assertEqual('test_sum', total.op.name)
- self.assertEqual(10.0, session.run(total))
-
- def test_tensors(self):
- with self.test_session() as session:
- total = replicate_model_fn._compute_sum_on_device(
- [[1.0, 2.0], [3.0, 4.0]], device='/device:GPU:0', name='test_sum')
-
- self.assertEqual('/device:GPU:0', total.device)
- self.assertEqual('test_sum', total.op.name)
- self.assertAllEqual([4.0, 6.0], session.run(total))
-
- def test_indexedslices(self):
- with self.test_session() as session:
- a = ops_lib.IndexedSlices(
- constant_op.constant([1.0, 2.0]), [0, 1],
- dense_shape=constant_op.constant([2]))
- b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
-
- total = replicate_model_fn._compute_sum_on_device(
- [a, b], device='/device:GPU:0')
-
- self.assertEqual('/device:GPU:0', total.device)
- self.assertAllEqual([4.0, 6.0],
- session.run(ops_lib.convert_to_tensor(total)))
-
- def test_indexedslices_higher_dimensions(self):
- with self.test_session() as session:
- a = ops_lib.IndexedSlices(
- constant_op.constant([[1.0, 5.0], [2.0, 6.0]]), [0, 1],
- dense_shape=constant_op.constant([2, 4]))
- b = ops_lib.IndexedSlices(
- constant_op.constant([[3.0, 7.0], [4.0, 8.0]]), [0, 1])
-
- total = replicate_model_fn._compute_sum_on_device(
- [a, b], device='/device:GPU:0')
-
- self.assertEqual('/device:GPU:0', total.device)
- self.assertAllEqual([[4.0, 12.0], [6.0, 14.0]],
- session.run(ops_lib.convert_to_tensor(total)))
-
- def test_indexedslices_some_dont_overlap(self):
- with self.test_session() as session:
- a = ops_lib.IndexedSlices(
- constant_op.constant([1.0, 2.0]), [0, 3],
- dense_shape=constant_op.constant([4]))
- b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
-
- total = replicate_model_fn._compute_sum_on_device(
- [a, b], device='/device:GPU:0')
-
- self.assertEqual('/device:GPU:0', total.device)
- self.assertAllEqual([4.0, 4.0, 0.0, 2.0],
- session.run(ops_lib.convert_to_tensor(total)))
-
- def test_no_name_for_indexslices(self):
- a = ops_lib.IndexedSlices(
- constant_op.constant([1.0, 2.0]), [0, 1],
- dense_shape=constant_op.constant([2]))
- b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
-
- with self.assertRaisesRegexp(ValueError, '.+name.+not.+expected.+'):
- _ = replicate_model_fn._compute_sum_on_device(
- [a, b], device='/device:GPU:0', name='cant_name_indexslices')
-
-
-class ConcatTensorDictsTest(test_util.TensorFlowTestCase):
-
- def test_example(self):
- tensor_dicts = [
- {
- 'a': np.array([1.0, 2.0]),
- 'b': np.array([11.0]),
- 'c': np.array([21.0]),
- },
- {
- 'a': np.array([3.0]),
- 'b': np.array([12.0, 13.0]),
- },
- {
- 'b': np.array([14.0]),
- },
- ]
-
- with self.test_session() as session:
- self.assertAllClose({
- 'a': np.array([1.0, 2.0, 3.0]),
- 'b': np.array([11.0, 12.0, 13.0, 14.0]),
- 'c': np.array([21.0]),
- }, session.run(replicate_model_fn._concat_tensor_dicts(*tensor_dicts)))
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index eda713641d..7f9ef53457 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -656,6 +656,7 @@ _PYTHON_TO_TF = {
bool: bool,
}
+
@tf_export("as_dtype")
def as_dtype(type_value):
"""Converts the given `type_value` to a `DType`.
diff --git a/tensorflow/python/framework/dtypes_test.py b/tensorflow/python/framework/dtypes_test.py
index 7c2169b2af..a873670e04 100644
--- a/tensorflow/python/framework/dtypes_test.py
+++ b/tensorflow/python/framework/dtypes_test.py
@@ -311,3 +311,4 @@ class TypesTest(test_util.TensorFlowTestCase):
if __name__ == "__main__":
googletest.main()
+
diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc
index e5e3b82199..ad6c36b4b1 100644
--- a/tensorflow/python/framework/python_op_gen.cc
+++ b/tensorflow/python/framework/python_op_gen.cc
@@ -98,7 +98,7 @@ bool IsOpWithUnderscorePrefix(const string& s) {
// TODO(annarev): reduce usage of '*' imports and remove these from the
// list.
"fused_batch_norm", "histogram_fixed_width", "stack",
- "batch_norm_with_global_normalization"});
+ "batch_norm_with_global_normalization", "clip_by_value"});
return kUnderscoreOps->count(s) > 0;
}
diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py
index a00e82d470..9232d99a1f 100644
--- a/tensorflow/python/framework/tensor_shape_test.py
+++ b/tensorflow/python/framework/tensor_shape_test.py
@@ -188,7 +188,7 @@ class DimensionTest(test_util.TensorFlowTestCase):
def testUnsupportedType(self):
with self.assertRaises(TypeError):
tensor_shape.Dimension(dtypes.string)
-
+
def testMod(self):
four = tensor_shape.Dimension(4)
nine = tensor_shape.Dimension(9)
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index c277c56b8d..210b571449 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1615,6 +1615,7 @@ cuda_py_test(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
+ tags = ["noasan"],
)
cuda_py_test(
@@ -2882,7 +2883,10 @@ tf_py_test(
"//tensorflow/python:variables",
],
shard_count = 10,
- tags = ["no_windows_gpu"],
+ tags = [
+ "no_windows_gpu",
+ "noasan",
+ ],
)
tf_py_test(
diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
index d132f15e51..54f33f3360 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
@@ -49,7 +49,6 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
# Grow tree ensemble.
predict_op = boosted_trees_ops.training_predict(
tree_ensemble_handle,
- max_depth=2,
cached_tree_ids=cached_tree_ids,
cached_node_ids=cached_node_ids,
bucketized_features=[feature_0_values, feature_1_values],
@@ -116,7 +115,6 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
# Grow tree ensemble.
predict_op = boosted_trees_ops.training_predict(
tree_ensemble_handle,
- max_depth=2,
cached_tree_ids=cached_tree_ids,
cached_node_ids=cached_node_ids,
bucketized_features=[feature_0_values],
@@ -189,7 +187,6 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
# Grow tree ensemble.
predict_op = boosted_trees_ops.training_predict(
tree_ensemble_handle,
- max_depth=4,
cached_tree_ids=cached_tree_ids,
cached_node_ids=cached_node_ids,
bucketized_features=[feature_0_values, feature_1_values],
@@ -299,7 +296,6 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
# Grow tree ensemble.
predict_op = boosted_trees_ops.training_predict(
tree_ensemble_handle,
- max_depth=4,
cached_tree_ids=cached_tree_ids,
cached_node_ids=cached_node_ids,
bucketized_features=[feature_0_values, feature_1_values],
@@ -429,7 +425,6 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
# Grow tree ensemble.
predict_op = boosted_trees_ops.training_predict(
tree_ensemble_handle,
- max_depth=2,
cached_tree_ids=cached_tree_ids,
cached_node_ids=cached_node_ids,
bucketized_features=[feature_0_values, feature_1_values],
@@ -562,7 +557,6 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
# Grow tree ensemble.
predict_op = boosted_trees_ops.training_predict(
tree_ensemble_handle,
- max_depth=3,
cached_tree_ids=cached_tree_ids,
cached_node_ids=cached_node_ids,
bucketized_features=[feature_0_values, feature_1_values],
@@ -705,7 +699,6 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
# Grow tree ensemble.
predict_op = boosted_trees_ops.training_predict(
tree_ensemble_handle,
- max_depth=3,
cached_tree_ids=cached_tree_ids,
cached_node_ids=cached_node_ids,
bucketized_features=[feature_0_values, feature_1_values],
@@ -782,7 +775,6 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
# Grow tree ensemble.
predict_op = boosted_trees_ops.training_predict(
tree_ensemble_handle,
- max_depth=1,
cached_tree_ids=cached_tree_ids,
cached_node_ids=cached_node_ids,
bucketized_features=[feature_0_values, feature_1_values],
@@ -905,8 +897,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
predict_op = boosted_trees_ops.predict(
tree_ensemble_handle,
bucketized_features=[feature_0_values, feature_1_values],
- logits_dimension=1,
- max_depth=2)
+ logits_dimension=1)
logits = session.run(predict_op)
self.assertAllClose(expected_logits, logits)
@@ -915,8 +906,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
predict_op = boosted_trees_ops.predict(
tree_ensemble_handle,
bucketized_features=[feature_0_values, feature_1_values],
- logits_dimension=1,
- max_depth=2)
+ logits_dimension=1)
logits = session.run(predict_op)
self.assertAllClose(expected_logits, logits)
diff --git a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
index 4d09cf94d4..f0bb84e69a 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
@@ -59,6 +59,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
l1=0.0,
l2=0.0,
tree_complexity=0.0,
+ min_node_weight=0,
max_splits=max_splits)
self.assertAllEqual([[1, 2], [1, 2]], sess.run(node_ids_list))
@@ -106,6 +107,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
l1=0.0,
l2=0.1,
tree_complexity=0.0,
+ min_node_weight=0,
max_splits=max_splits)
self.assertAllEqual([[1, 2], [1, 2]], sess.run(node_ids_list))
@@ -154,6 +156,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
l1=l1,
l2=0.0,
tree_complexity=0.0,
+ min_node_weight=0,
max_splits=max_splits)
self.assertAllEqual([[0, 1], [1, 1]], sess.run(thresholds_list))
@@ -205,6 +208,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
l1=0.0,
l2=l2,
tree_complexity=tree_complexity,
+ min_node_weight=0,
max_splits=max_splits)
self.assertAllEqual([[1, 2], [1, 2]], sess.run(node_ids_list))
@@ -220,6 +224,53 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
self.assertAllClose([[[-.424658], [-.6]], [[-.043478], [.485294]]],
sess.run(right_node_contribs_list))
+ def testCalculateBestGainsWithMinNodeWeight(self):
+ """Testing Gain calculation with min node weight."""
+ with self.test_session() as sess:
+ max_splits = 7
+ node_id_range = [1, 3] # node 1 through 2 will be processed.
+ stats_summary_list = [
+ [
+ [[0., 0.], [.08, .09], [0., 0.], [0., 0.]], # node 0; ignored
+ [[0., 0.], [.15, .036], [.06, .07], [.1, .2]], # node 1
+ [[0., 0.], [-.33, .68], [0., 0.], [.3, .4]], # node 2
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
+ ], # feature 0
+ [
+ [[0., 0.], [0., 0.], [.08, .09], [0., 0.]], # node 0; ignored
+ [[0., 0.], [.3, .5], [-.05, .6], [.06, .07]], # node 1
+ [[.1, .1], [.2, .03], [-.4, .05], [.07, .08]], # node 2
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
+ ], # feature 1
+ ] # num_features * shape=[max_splits, num_buckets, 2]
+
+ (node_ids_list, gains_list, thresholds_list, left_node_contribs_list,
+ right_node_contribs_list
+ ) = boosted_trees_ops.calculate_best_gains_per_feature(
+ node_id_range,
+ stats_summary_list,
+ l1=0.0,
+ l2=0.0,
+ tree_complexity=0.0,
+ min_node_weight=1,
+ max_splits=max_splits)
+
+ # We can't split node 1 on feature 0 and node 2 on feature 1 because of
+ # the min node weight.
+ self.assertAllEqual([[2], [1]], sess.run(node_ids_list))
+ self.assertAllClose([[0.384314], [0.098013]], sess.run(gains_list))
+ self.assertAllEqual([[1], [1]], sess.run(thresholds_list))
+ self.assertAllClose([[[0.4852941]], [[-.6]]],
+ sess.run(left_node_contribs_list))
+ self.assertAllClose([[[-0.75]], [[-0.014925]]],
+ sess.run(right_node_contribs_list))
+
def testMakeStatsSummarySimple(self):
"""Simple test for MakeStatsSummary."""
with self.test_session():
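
The expected values in testCalculateBestGainsWithMinNodeWeight can be reproduced by hand with the usual second-order split gain. The following NumPy sketch of the arithmetic (not the op's actual kernel) checks the feature-0/node-2 entry with l1 = l2 = tree_complexity = 0:

    import numpy as np

    # Node 2, feature 0: (grad, hess) buckets [[0, 0], [-.33, .68], [0, 0], [.3, .4]].
    # Threshold 1 puts buckets 0..1 on the left and the rest on the right.
    left_g, left_h = -0.33, 0.68
    right_g, right_h = 0.30, 0.40

    def split_gain(lg, lh, rg, rh):
      parent = (lg + rg) ** 2 / (lh + rh)
      return lg ** 2 / lh + rg ** 2 / rh - parent

    print(split_gain(left_g, left_h, right_g, right_h))  # ~0.384314
    print(-left_g / left_h, -right_g / right_h)          # ~0.485294, -0.75
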
diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py
index cb1359be15..c3f44f385e 100644
--- a/tensorflow/python/kernel_tests/clip_ops_test.py
+++ b/tensorflow/python/kernel_tests/clip_ops_test.py
@@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gradient_checker
@@ -29,19 +28,19 @@ from tensorflow.python.platform import test
class ClipTest(test.TestCase):
- def testClipByValueGradient(self):
+ def DISABLED_testClipByValueGradient(self):
inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32)
outputs_1 = clip_ops.clip_by_value(inputs, 0.5, 3.5)
min_val = constant_op.constant([0.5, 0.5, 0.5, 0.5], dtype=dtypes.float32)
max_val = constant_op.constant([3.5, 3.5, 3.5, 3.5], dtype=dtypes.float32)
outputs_2 = clip_ops.clip_by_value(inputs, min_val, max_val)
with self.test_session():
- error_1 = gradient_checker.compute_gradient_error(inputs, [4],
- outputs_1, [4])
+ error_1 = gradient_checker.compute_gradient_error(inputs, [4], outputs_1,
+ [4])
self.assertLess(error_1, 1e-4)
- error_2 = gradient_checker.compute_gradient_error(inputs, [4],
- outputs_2, [4])
+ error_2 = gradient_checker.compute_gradient_error(inputs, [4], outputs_2,
+ [4])
self.assertLess(error_2, 1e-4)
# ClipByValue test
@@ -56,10 +55,11 @@ class ClipTest(test.TestCase):
self.assertAllClose(np_ans, tf_ans)
# [Tensor, Scalar, Scalar]
- def testClipByValue0Type(self):
- for dtype in [dtypes.float16, dtypes.float32, dtypes.float64,
- dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
- dtypes.uint8, dtypes.uint16]:
+ def DISABLED_testClipByValue0Type(self):
+ for dtype in [
+ dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int8,
+ dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16
+ ]:
with self.test_session(use_gpu=True):
x = constant_op.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=dtype)
np_ans = [[2, 2, 3], [4, 4, 4]]
@@ -71,15 +71,16 @@ class ClipTest(test.TestCase):
self.assertAllClose(np_ans, tf_ans)
# [Tensor, Tensor, Scalar]
- def testClipByValue1Type(self):
- for dtype in [dtypes.float16, dtypes.float32, dtypes.float64,
- dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
- dtypes.uint8, dtypes.uint16]:
+ def DISABLED_testClipByValue1Type(self):
+ for dtype in [
+ dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int8,
+ dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16
+ ]:
with self.test_session(use_gpu=True):
x = constant_op.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=dtype)
np_ans = [[2, 2, 3], [4, 4, 4]]
- clip_value_min = constant_op.constant([2, 2, 2, 3, 3, 3], shape=[2, 3],
- dtype=dtype)
+ clip_value_min = constant_op.constant(
+ [2, 2, 2, 3, 3, 3], shape=[2, 3], dtype=dtype)
clip_value_max = 4
ans = clip_ops.clip_by_value(x, clip_value_min, clip_value_max)
tf_ans = ans.eval()
@@ -87,33 +88,35 @@ class ClipTest(test.TestCase):
self.assertAllClose(np_ans, tf_ans)
# [Tensor, Scalar, Tensor]
- def testClipByValue2Type(self):
- for dtype in [dtypes.float16, dtypes.float32, dtypes.float64,
- dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
- dtypes.uint8, dtypes.uint16]:
+ def DISABLED_testClipByValue2Type(self):
+ for dtype in [
+ dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int8,
+ dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16
+ ]:
with self.test_session(use_gpu=True):
x = constant_op.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=dtype)
np_ans = [[4, 4, 4], [4, 5, 6]]
clip_value_min = 4
- clip_value_max = constant_op.constant([6, 6, 6, 6, 6, 6], shape=[2, 3],
- dtype=dtype)
+ clip_value_max = constant_op.constant(
+ [6, 6, 6, 6, 6, 6], shape=[2, 3], dtype=dtype)
ans = clip_ops.clip_by_value(x, clip_value_min, clip_value_max)
tf_ans = ans.eval()
self.assertAllClose(np_ans, tf_ans)
# [Tensor, Tensor, Tensor]
- def testClipByValue3Type(self):
- for dtype in [dtypes.float16, dtypes.float32, dtypes.float64,
- dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
- dtypes.uint8, dtypes.uint16]:
+ def DISABLED_testClipByValue3Type(self):
+ for dtype in [
+ dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int8,
+ dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16
+ ]:
with self.test_session(use_gpu=True):
x = constant_op.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=dtype)
np_ans = [[2, 2, 3], [5, 5, 6]]
- clip_value_min = constant_op.constant([2, 2, 2, 5, 5, 5], shape=[2, 3],
- dtype=dtype)
- clip_value_max = constant_op.constant([5, 5, 5, 7, 7, 7], shape=[2, 3],
- dtype=dtype)
+ clip_value_min = constant_op.constant(
+ [2, 2, 2, 5, 5, 5], shape=[2, 3], dtype=dtype)
+ clip_value_max = constant_op.constant(
+ [5, 5, 5, 7, 7, 7], shape=[2, 3], dtype=dtype)
ans = clip_ops.clip_by_value(x, clip_value_min, clip_value_max)
tf_ans = ans.eval()
@@ -132,7 +135,8 @@ class ClipTest(test.TestCase):
tf_ans = ans.eval()
def testClipByValueNonFinite(self):
- with self.test_session(use_gpu=True):
+ # TODO(b/78016351): Enable test on GPU once the bug is fixed.
+ with self.test_session():
x = constant_op.constant([float('NaN'), float('Inf'), -float('Inf')])
np_ans = [float('NaN'), 4.0, -4.0]
clip_value = 4.0
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 34e7751243..87da89831c 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -398,14 +398,17 @@ class UnaryOpTest(test.TestCase):
self._compareCpu(x, np.abs, _ABS)
self._compareCpu(x, np.negative, math_ops.negative)
self._compareCpu(x, np.negative, _NEG)
- self._compareCpu(x, np.square, math_ops.square)
self._compareCpu(x, np.sign, math_ops.sign)
self._compareBothSparse(x, np.abs, math_ops.abs)
self._compareBothSparse(x, np.negative, math_ops.negative)
- self._compareBothSparse(x, np.square, math_ops.square)
self._compareBothSparse(x, np.sign, math_ops.sign)
+ def testInt64Square(self):
+ x = np.arange(-6 << 20, 6 << 20, 2 << 20).reshape(1, 3, 2).astype(np.int64)
+ self._compareCpu(x, np.square, math_ops.square)
+ self._compareBothSparse(x, np.square, math_ops.square)
+
def testComplex64Basic(self):
x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype(
np.complex64)
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py
index 9a94692569..a2fcd751df 100644
--- a/tensorflow/python/kernel_tests/gather_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_op_test.py
@@ -149,6 +149,15 @@ class GatherTest(test.TestCase):
self.assertAllEqual([b"asdf", b"qwer"],
array_ops.gather(params, 0, axis=1).eval())
+ def testUInt32AndUInt64(self):
+ for unsigned_type in (dtypes.uint32, dtypes.uint64):
+ params = self._buildParams(
+ np.array([[1, 2, 3], [7, 8, 9]]), unsigned_type)
+ with self.test_session():
+ self.assertAllEqual([7, 8, 9],
+ array_ops.gather(params, 1, axis=0).eval())
+ self.assertAllEqual([1, 7], array_ops.gather(params, 0, axis=1).eval())
+
def testUnknownIndices(self):
params = constant_op.constant([[0, 1, 2]])
indices = array_ops.placeholder(dtypes.int32)
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index f7a7119b34..a9b55854f1 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -613,10 +613,12 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
return np.allclose(t1, t2 / 3.14, rtol=1e-15, atol=1e-15)
def testShapesValues(self):
+ gain = 3.14
for dtype in [dtypes.float32]:
for kernel_size in [[3], [8], [3, 5], [2, 4], [3, 3, 3], [2, 2, 2]]:
tol = 1e-2
- # Check orthogonality by computing the 2-norms of the inputs and outputs.
+ # Check orthogonality by computing ratio between
+ # the 2-norms of the inputs and outputs.
if len(kernel_size) == 1:
shape = [4, 32, 64]
convolution = convolutional.conv1d
@@ -632,9 +634,10 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
inputs, padding="same", filters=128,
kernel_size=kernel_size, use_bias=False,
kernel_initializer=init_ops.convolutional_delta_orthogonal(
- gain=3.14))
+ gain=gain))
outputs_shape = shape[0:-1] + [128]
outputs_2norm = linalg_ops.norm(outputs)
+ ratio = outputs_2norm / inputs_2norm
my_ops = variables.global_variables_initializer()
with self.test_session(use_gpu=True) as sess:
sess.run(my_ops)
@@ -642,10 +645,8 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
t = outputs.eval()
self.assertAllEqual(t.shape, outputs_shape)
# Check isometry of the delta-orthogonal kernel.
- self.assertAllClose(
- sess.run(inputs_2norm)/np.sqrt(np.prod(shape)),
- sess.run(outputs_2norm)/(np.sqrt(np.prod(shape))*np.sqrt(3.14)),
- rtol=tol, atol=tol)
+ self.assertAllClose(sess.run(ratio), np.sqrt(gain),
+ rtol=tol, atol=tol)
def testNonuniformity(self):
value = 0
@@ -653,7 +654,7 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
shape = [3, 3, 10, 10]
count = 70
tol = 1e-5
- with self.test_session(use_gpu=True): # as sess:
+ with self.test_session(use_gpu=True):
for i in range(count):
x = variable_scope.get_variable("{}".format(i), shape=shape,
initializer=
@@ -672,6 +673,120 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
self.assertAllClose(abs_value, count, rtol=tol, atol=tol)
+class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
+
+ def testInitializerIdentical(self):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ init1 = init_ops.convolutional_orthogonal_1d(seed=1, dtype=dtype)
+ init2 = init_ops.convolutional_orthogonal_1d(seed=1, dtype=dtype)
+ self.assertTrue(identicaltest(self, init1, init2, (3, 10, 10)))
+
+ def testInitializerDifferent(self):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ init1 = init_ops.convolutional_orthogonal_1d(seed=1, dtype=dtype)
+ init2 = init_ops.convolutional_orthogonal_1d(seed=2, dtype=dtype)
+ self.assertFalse(identicaltest(self, init1, init2, (3, 10, 10)))
+
+ def testDuplicatedInitializer(self):
+ init = init_ops.convolutional_orthogonal_1d()
+ self.assertFalse(duplicated_initializer(self, init, 1, (3, 10, 10)))
+
+ def testInvalidDataType(self):
+ self.assertRaises(
+ ValueError, init_ops.convolutional_orthogonal_1d,
+ dtype=dtypes.string)
+
+ def testInvalidShape(self):
+ init1 = init_ops.convolutional_orthogonal_1d()
+ with self.test_session(graph=ops.Graph(), use_gpu=True):
+ self.assertRaises(ValueError, init1, shape=[3, 6, 5])
+
+ def testGain(self):
+ shape = (3, 10, 10)
+ for dtype in [dtypes.float32, dtypes.float64]:
+ init1 = init_ops.convolutional_orthogonal_1d(seed=1, dtype=dtype)
+ init2 = init_ops.convolutional_orthogonal_1d(gain=3.14,
+ seed=1, dtype=dtype)
+ with self.test_session(graph=ops.Graph(), use_gpu=True):
+ t1 = init1(shape).eval()
+ t2 = init2(shape).eval()
+ self.assertTrue(np.allclose(t1, t2 / 3.14, rtol=1e-15, atol=1e-15))
+
+ def testNonuniformity(self):
+ value = 0
+ abs_value = 0
+ shape = [3, 10, 10]
+ count = 70
+ tol = 1e-5
+ with self.test_session(use_gpu=True):
+ for i in range(count):
+ x = variable_scope.get_variable("{}".format(i), shape=shape,
+ initializer=
+ init_ops.convolutional_orthogonal_1d)
+ x.initializer.run()
+ y = np.sum(x.eval(), axis=0)
+ determinant = np.linalg.det(y)
+ value += determinant
+ abs_value += np.abs(determinant)
+
+ # Check there is some variation in the signs of the determinants.
+ self.assertLess(value, count - tol)
+ self.assertLess(-count + tol, value)
+ # Check all determinants have absolute value 1
+ # Compute the sum of the absolute values of 'count' determinants
+ self.assertAllClose(abs_value, count, rtol=tol, atol=tol)
+
+ def testShapesValues(self):
+ def circular_pad(input_, width, kernel_size):
+ """Pad input_ for computing (circular) convolution.
+
+ Args:
+ input_: the input tensor
+ width: the width of the tensor.
+ kernel_size: the kernel size of the filter.
+ Returns:
+ a tensor whose width is (width + kernel_size - 1).
+ """
+
+ beginning = kernel_size // 2
+ end = kernel_size - 1 - beginning
+
+ tmp_up = array_ops.slice(input_, [0, width - beginning, 0],
+ [-1, beginning, -1])
+ tmp_down = array_ops.slice(input_, [0, 0, 0], [-1, end, -1])
+ tmp = array_ops.concat([tmp_up, input_, tmp_down], 1)
+
+ return tmp
+
+ cout = 64
+ shape = [10, 20, 32]
+ outputs_shape = shape[0:-1] + [cout]
+ dtype = dtypes.float32
+ tol = 1e-3
+ gain = 3.14
+ # Check orthogonality/isometry by computing the ratio between
+ # the 2-norms of the inputs and outputs.
+ for kernel_size in [[1], [2], [3], [4], [5], [6]]:
+ convolution = convolutional.conv1d
+ inputs = random_ops.random_normal(shape, dtype=dtype)
+ inputs_2norm = linalg_ops.norm(inputs)
+ input_with_circular_pad = circular_pad(inputs, shape[1], kernel_size[0])
+ outputs = convolution(
+ input_with_circular_pad, padding="valid", filters=cout,
+ kernel_size=kernel_size[0], use_bias=False,
+ kernel_initializer=init_ops.convolutional_orthogonal_1d(gain=gain))
+ outputs_2norm = linalg_ops.norm(outputs)
+ ratio = outputs_2norm / inputs_2norm
+ my_ops = variables.global_variables_initializer()
+ with self.test_session(use_gpu=True) as sess:
+ sess.run(my_ops)
+ # Check the shape of the outputs
+ t = outputs.eval()
+ self.assertAllEqual(t.shape, outputs_shape)
+ # Check isometry of the orthogonal kernel.
+ self.assertAllClose(sess.run(ratio), np.sqrt(gain), rtol=tol, atol=tol)
+
+
class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
def testInitializerIdentical(self):
@@ -722,17 +837,17 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
Returns:
a tensor whose width is (width + kernel_size - 1).
"""
- beg = kernel_size // 2
- end = kernel_size - 1 - beg
+ beginning = kernel_size // 2
+ end = kernel_size - 1 - beginning
- tmp_up = array_ops.slice(input_, [0, width - beg, 0, 0],
- [-1, beg, width, -1])
+ tmp_up = array_ops.slice(input_, [0, width - beginning, 0, 0],
+ [-1, beginning, width, -1])
tmp_down = array_ops.slice(input_, [0, 0, 0, 0], [-1, end, width, -1])
tmp = array_ops.concat([tmp_up, input_, tmp_down], 1)
new_width = width + kernel_size - 1
- tmp_left = array_ops.slice(tmp, [0, 0, width - beg, 0],
- [-1, new_width, beg, -1])
+ tmp_left = array_ops.slice(tmp, [0, 0, width - beginning, 0],
+ [-1, new_width, beginning, -1])
tmp_right = array_ops.slice(tmp, [0, 0, 0, 0], [-1, new_width, end, -1])
final = array_ops.concat([tmp_left, tmp, tmp_right], 2)
@@ -756,6 +871,132 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
kernel_size=kernel_size, use_bias=False,
kernel_initializer=init_ops.convolutional_orthogonal_2d(gain=gain))
outputs_2norm = linalg_ops.norm(outputs)
+ ratio = outputs_2norm / inputs_2norm
+ my_ops = variables.global_variables_initializer()
+ with self.test_session(use_gpu=True) as sess:
+ sess.run(my_ops)
+ # Check the shape of the outputs
+ t = outputs.eval()
+ self.assertAllEqual(t.shape, outputs_shape)
+ # Check isometry of the orthogonal kernel.
+ self.assertAllClose(sess.run(ratio), np.sqrt(gain), rtol=tol, atol=tol)
+
+
+class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
+
+ def testInitializerIdentical(self):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ init1 = init_ops.convolutional_orthogonal_3d(seed=1, dtype=dtype)
+ init2 = init_ops.convolutional_orthogonal_3d(seed=1, dtype=dtype)
+ self.assertTrue(identicaltest(self, init1, init2, (3, 3, 3, 10, 10)))
+
+ def testInitializerDifferent(self):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ init1 = init_ops.convolutional_orthogonal_3d(seed=1, dtype=dtype)
+ init2 = init_ops.convolutional_orthogonal_3d(seed=2, dtype=dtype)
+ self.assertFalse(identicaltest(self, init1, init2, (3, 3, 3, 10, 10)))
+
+ def testDuplicatedInitializer(self):
+ init = init_ops.convolutional_orthogonal_3d()
+ self.assertFalse(duplicated_initializer(self, init, 1, (3, 3, 3, 10, 10)))
+
+ def testInvalidDataType(self):
+ self.assertRaises(
+ ValueError, init_ops.convolutional_orthogonal_3d,
+ dtype=dtypes.string)
+
+ def testInvalidShape(self):
+ init1 = init_ops.convolutional_orthogonal_3d()
+ with self.test_session(graph=ops.Graph(), use_gpu=True):
+ self.assertRaises(ValueError, init1, shape=[3, 3, 3, 6, 5])
+
+ def testGain(self):
+ shape = (3, 3, 3, 10, 10)
+ for dtype in [dtypes.float32, dtypes.float64]:
+ init1 = init_ops.convolutional_orthogonal_3d(seed=1, dtype=dtype)
+ init2 = init_ops.convolutional_orthogonal_3d(gain=3.14,
+ seed=1, dtype=dtype)
+ with self.test_session(graph=ops.Graph(), use_gpu=True):
+ t1 = init1(shape).eval()
+ t2 = init2(shape).eval()
+ self.assertTrue(np.allclose(t1, t2 / 3.14, rtol=1e-15, atol=1e-15))
+
+ def testNonuniformity(self):
+ value = 0
+ abs_value = 0
+ shape = [3, 3, 3, 5, 5]
+ count = 20
+ tol = 1e-5
+ with self.test_session(use_gpu=True):
+ for i in range(count):
+ x = variable_scope.get_variable("{}".format(i), shape=shape,
+ initializer=
+ init_ops.convolutional_orthogonal_3d)
+ x.initializer.run()
+ y = np.sum(x.eval(), axis=(0, 1, 2))
+ determinant = np.linalg.det(y)
+ value += determinant
+ abs_value += np.abs(determinant)
+
+ # Check there is some variation in the signs of the determinants
+ self.assertLess(value, count - tol)
+ self.assertLess(-count + tol, value)
+ # Check all determinants have absolute value 1
+ # Compute the sum of the absolute values of 'count' determinants
+ self.assertAllClose(abs_value, count, rtol=tol, atol=tol)
+
+ def testShapesValues(self):
+ def circular_pad(input_, width, kernel_size):
+ """Padding input_ for computing circular convolution.
+
+ Args:
+ input_: the input tensor
+ width: the width of the tensor.
+ kernel_size: the kernel size of the filter.
+
+ Returns:
+ a tensor whose width is (width + kernel_size - 1).
+ """
+
+ beginning = kernel_size // 2
+ end = kernel_size - 1 - beginning
+
+ tmp_up = array_ops.slice(input_, [0, width - beginning, 0, 0, 0],
+ [-1, beginning, -1, -1, -1])
+ tmp_down = array_ops.slice(input_, [0, 0, 0, 0, 0],
+ [-1, end, -1, -1, -1])
+ tmp = array_ops.concat([tmp_up, input_, tmp_down], 1)
+
+ tmp_left = array_ops.slice(tmp, [0, 0, width - beginning, 0, 0],
+ [-1, -1, beginning, -1, -1])
+ tmp_right = array_ops.slice(tmp, [0, 0, 0, 0, 0],
+ [-1, -1, end, -1, -1])
+ tmp = array_ops.concat([tmp_left, tmp, tmp_right], 2)
+
+ tmp_front = array_ops.slice(tmp, [0, 0, 0, width - beginning, 0],
+ [-1, -1, -1, beginning, -1])
+ tmp_back = array_ops.slice(tmp, [0, 0, 0, 0, 0], [-1, -1, -1, end, -1])
+ return array_ops.concat([tmp_front, tmp, tmp_back], 3)
+
+ cout = 32
+ shape = [1, 7, 7, 7, 16]
+ outputs_shape = shape[0:-1] + [cout]
+ dtype = dtypes.float32
+ tol = 1e-3
+ gain = 3.14
+ # Check orthogonality/isometry by computing the ratio between
+ # the 2-norms of the inputs and outputs.
+ for kernel_size in [[1, 1, 1], [2, 2, 2], [3, 3, 3]]:
+ convolution = convolutional.conv3d
+ inputs = random_ops.random_normal(shape, dtype=dtype)
+ inputs_2norm = linalg_ops.norm(inputs)
+ input_with_circular_pad = circular_pad(inputs, shape[1], kernel_size[0])
+ outputs = convolution(
+ input_with_circular_pad, padding="valid", filters=cout,
+ kernel_size=kernel_size[0], use_bias=False,
+ kernel_initializer=init_ops.convolutional_orthogonal_3d(gain=gain))
+ outputs_2norm = linalg_ops.norm(outputs)
+ ratio = outputs_2norm / inputs_2norm
my_ops = variables.global_variables_initializer()
with self.test_session(use_gpu=True) as sess:
sess.run(my_ops)
@@ -763,10 +1004,7 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
t = outputs.eval()
self.assertAllEqual(t.shape, outputs_shape)
# Check isometry of the orthogonal kernel.
- self.assertAllClose(
- sess.run(inputs_2norm)/np.sqrt(np.prod(shape)),
- sess.run(outputs_2norm)/(np.sqrt(np.prod(shape))*np.sqrt(gain)),
- rtol=tol, atol=tol)
+ self.assertAllClose(sess.run(ratio), np.sqrt(gain), rtol=tol, atol=tol)
class IdentityInitializerTest(test.TestCase):
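
The circular_pad helpers added above wrap the spatial axes so that a "valid" convolution on the padded input behaves like a circular convolution, which is the regime where the orthogonal initializers are exactly isometric. For the 1D case the same padding can be written with NumPy's wrap mode (a sketch under that equivalence, not the test's code):

    import numpy as np

    def circular_pad_1d(x, kernel_size):
      # x: [batch, width, channels]; wrap the width axis.
      beginning = kernel_size // 2
      end = kernel_size - 1 - beginning
      return np.pad(x, ((0, 0), (beginning, end), (0, 0)), mode="wrap")

    x = np.arange(10).reshape(2, 5, 1)
    assert circular_pad_1d(x, 3).shape == (2, 7, 1)
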
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 9a0409c796..fe5ad84c10 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -206,6 +206,28 @@ class RNNTest(test.TestCase):
self.assertAllEqual(4, state[0])
self.assertAllEqual([[[1]], [[2]], [[3]], [[4]]], state[1])
+ def _assert_cell_builds(self, cell_class, dtype, batch_size, in_size,
+ out_size):
+ cell = cell_class(out_size, dtype=dtype)
+ in_shape = tensor_shape.TensorShape((batch_size, in_size))
+ cell.build(in_shape)
+ state_output = cell.zero_state(batch_size, dtype)
+ cell_output, _ = cell(array_ops.zeros(in_shape, dtype), state_output)
+ self.assertAllEqual([batch_size, out_size], cell_output.shape.as_list())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testCellsBuild(self):
+ f32 = dtypes.float32
+ f64 = dtypes.float64
+ self._assert_cell_builds(rnn_cell_impl.BasicRNNCell, f32, 5, 7, 3)
+ self._assert_cell_builds(rnn_cell_impl.BasicRNNCell, f64, 5, 7, 3)
+ self._assert_cell_builds(rnn_cell_impl.BasicLSTMCell, f32, 5, 7, 3)
+ self._assert_cell_builds(rnn_cell_impl.BasicLSTMCell, f64, 5, 7, 3)
+ self._assert_cell_builds(rnn_cell_impl.GRUCell, f32, 5, 7, 3)
+ self._assert_cell_builds(rnn_cell_impl.GRUCell, f64, 5, 7, 3)
+ self._assert_cell_builds(rnn_cell_impl.LSTMCell, f32, 5, 7, 3)
+ self._assert_cell_builds(rnn_cell_impl.LSTMCell, f64, 5, 7, 3)
+
######### Benchmarking RNN code
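
testCellsBuild exercises the new dtype constructor argument (added in rnn_cell_impl.py further down): with an explicit dtype, a cell can be built and given a zero state before it ever sees an input. A usage sketch, assuming the public tf.nn.rnn_cell alias:

    import tensorflow as tf

    cell = tf.nn.rnn_cell.BasicRNNCell(3, dtype=tf.float32)
    cell.build(tf.TensorShape([5, 7]))   # build before the first call
    state = cell.zero_state(5, tf.float32)
    output, _ = cell(tf.zeros([5, 7]), state)
    print(output.shape)  # (5, 3)
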
diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py
index 981f96b74d..dc4d4dbeab 100644
--- a/tensorflow/python/kernel_tests/softmax_op_test.py
+++ b/tensorflow/python/kernel_tests/softmax_op_test.py
@@ -39,6 +39,10 @@ class SoftmaxTest(test.TestCase):
dim = len(features.shape) - 1
one_only_on_dim = list(features.shape)
one_only_on_dim[dim] = 1
+ is_fp16 = features.dtype == np.float16
+ if is_fp16:
+ # Do the compute in fp32 and cast the result back to fp16.
+ features = features.astype(np.float32)
e = np.exp(features - np.reshape(
np.amax(
features, axis=dim), one_only_on_dim))
@@ -47,6 +51,8 @@ class SoftmaxTest(test.TestCase):
res = np.log(softmax)
else:
res = softmax
+ if is_fp16:
+ res = res.astype(np.float16)
return res
def _testSoftmax(self, np_features, dim=-1, log=False, use_gpu=False):
@@ -125,8 +131,8 @@ class SoftmaxTest(test.TestCase):
"Test only applicable when running on GPUs")
def testFloatGPU(self):
if test.is_gpu_available(cuda_only=True):
- rows = [2**x + np.random.randint(0, 1024) for x in range(1, 10)]
- cols = [2**x + np.random.randint(0, 1024) for x in range(1, 10)]
+ rows = [2**x + np.random.randint(0, 16) for x in range(1, 4)]
+ cols = [2**x + np.random.randint(0, 16) for x in range(1, 4)]
for row, col in zip(rows, cols):
logging.info("Testing softmax float dtype in shape [%d, %d]", row, col)
data = np.random.rand(row, col)
@@ -140,8 +146,8 @@ class SoftmaxTest(test.TestCase):
"Test only applicable when running on GPUs")
def testHalfGPU(self):
if test.is_gpu_available(cuda_only=True):
- rows = [2**x + np.random.randint(0, 1024) for x in range(1, 8)]
- cols = [2**x + np.random.randint(0, 1024) for x in range(1, 8)]
+ rows = [2**x + np.random.randint(0, 16) for x in range(1, 4)]
+ cols = [2**x + np.random.randint(0, 16) for x in range(1, 4)]
for row, col in zip(rows, cols):
logging.info("Testing softmax half dtype in shape [%d, %d]", row, col)
data = np.random.rand(row, col)
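
The fp16 branch added to the reference helper computes in fp32 and only casts the result back, since accumulating exponentials directly in fp16 loses enough precision to fail the comparison against the op. A standalone sketch of the same trick:

    import numpy as np

    def softmax_ref(features):
      is_fp16 = features.dtype == np.float16
      if is_fp16:
        features = features.astype(np.float32)  # accumulate in fp32
      e = np.exp(features - features.max(axis=-1, keepdims=True))
      res = e / e.sum(axis=-1, keepdims=True)
      return res.astype(np.float16) if is_fp16 else res

    print(softmax_ref(np.random.rand(4, 8).astype(np.float16)).dtype)  # float16
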
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index 0829aa67ed..75c459a9cf 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -27,7 +27,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
-from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
@@ -60,13 +59,26 @@ def clip_by_value(t, clip_value_min, clip_value_max,
"""
with ops.name_scope(name, "clip_by_value",
[t, clip_value_min, clip_value_max]) as name:
- return gen_math_ops.clip_by_value(t,
- clip_value_min,
- clip_value_max,
- name=name)
+ t = ops.convert_to_tensor(t, name="t")
+
+ # Clip to the upper bound first; the lower bound is applied below.
+ t_min = math_ops.minimum(t, clip_value_max)
+ # Assert that the shape is compatible with the initial shape,
+ # to prevent unintentional broadcasting.
+ _ = t.shape.merge_with(t_min.shape)
+
+ t_max = math_ops.maximum(t_min, clip_value_min, name=name)
+ _ = t.shape.merge_with(t_max.shape)
+
+ return t_max
+ # TODO(scottzhu): switch to use new implementation in 2 weeks.
+ # return gen_math_ops.clip_by_value(
+ # t, clip_value_min, clip_value_max, name=name)
+
-@ops.RegisterGradient("ClipByValue")
-def _ClipByValueGrad(op, grad):
+# TODO(scottzhu): switch to use new implementation in 2 weeks.
+# @ops.RegisterGradient("ClipByValue")
+def _clip_by_value_grad(op, grad):
"""Returns grad of clip_by_value."""
x = op.inputs[0]
y = op.inputs[1]
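
The interim clip_by_value above is a minimum followed by a maximum, with shape merges guarding against accidental broadcasting of the bounds. A NumPy sketch of the same composition (applying the maximum last means the lower bound wins when min > max, matching the TF order):

    import numpy as np

    def clip_by_value_ref(t, clip_value_min, clip_value_max):
      t_min = np.minimum(t, clip_value_max)
      # Mirror the shape merge: broadcasting must not change t's shape.
      assert t_min.shape == np.shape(t)
      return np.maximum(t_min, clip_value_min)

    x = np.array([1.0, 2.0, 3.0, 4.0])
    print(clip_by_value_ref(x, 1.5, 3.5))  # [1.5 2.  3.  3.5]
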
diff --git a/tensorflow/python/ops/cudnn_rnn_grad.py b/tensorflow/python/ops/cudnn_rnn_grad.py
new file mode 100644
index 0000000000..97331bb5b5
--- /dev/null
+++ b/tensorflow/python/ops/cudnn_rnn_grad.py
@@ -0,0 +1,47 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Gradients for CuudnnRNN operators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_cudnn_rnn_ops
+
+
+@ops.RegisterGradient("CudnnRNN")
+def _cudnn_rnn_backward(op, *grads):
+ """Gradients for the CudnnRNN op."""
+ if not op.get_attr("is_training"):
+ raise ValueError(
+ "CudnnRNN must set is_training to True to be used in gradients")
+ return gen_cudnn_rnn_ops.cudnn_rnn_backprop(
+ input=op.inputs[0],
+ input_h=op.inputs[1],
+ input_c=op.inputs[2],
+ params=op.inputs[3],
+ output=op.outputs[0],
+ output_h=op.outputs[1],
+ output_c=op.outputs[2],
+ output_backprop=grads[0],
+ output_h_backprop=grads[1],
+ output_c_backprop=grads[2],
+ reserve_space=op.outputs[3],
+ dropout=op.get_attr("dropout"),
+ seed=op.get_attr("seed"),
+ seed2=op.get_attr("seed2"),
+ rnn_mode=op.get_attr("rnn_mode"),
+ input_mode=op.get_attr("input_mode"),
+ direction=op.get_attr("direction"))
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 5ded3f7cc2..39b7295124 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -549,12 +549,11 @@ class ConvolutionDeltaOrthogonal(Initializer):
tensor form an orthogonal matrix. Other pixels are set to be zero.
Args:
- gain: multiplicative factor to apply to the orthogonal matrix. Default is 1.
+ gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
The 2-norm of an input is multiplied by a factor of 'sqrt(gain)' after
applying this convolution.
seed: A Python integer. Used to create random seeds. See
- @{tf.set_random_seed}
- for behavior.
+ @{tf.set_random_seed} for behavior.
dtype: The data type.
"""
@@ -600,21 +599,17 @@ class ConvolutionDeltaOrthogonal(Initializer):
return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}
-class ConvolutionOrthogonal2D(Initializer):
- """Initializer that generates a 2D orthogonal kernel for ConvNets.
+class ConvolutionOrthogonal(Initializer):
+ """Initializer that generates orthogonal kernel for ConvNets.
- The shape of the tensor must have length 2. The number of input
- filters must not exceed the number of output filters.
- The orthogonality(==isometry) is exact when the inputs are circular padded.
- There are finite-width effects with non-circular padding (e.g. zero padding).
+ Base class used to construct 1D, 2D and 3D orthogonal kernels for convolution.
Args:
gain: multiplicative factor to apply to the orthogonal matrix. Default is 1.
The 2-norm of an input is multiplied by a factor of 'sqrt(gain)' after
applying this convolution.
seed: A Python integer. Used to create random seeds. See
- @{tf.set_random_seed}
- for behavior.
+ @{tf.set_random_seed} for behavior.
dtype: The data type.
"""
@@ -624,21 +619,7 @@ class ConvolutionOrthogonal2D(Initializer):
self.seed = seed
def __call__(self, shape, dtype=None, partition_info=None):
- if dtype is None:
- dtype = self.dtype
- # Check the shape
- if len(shape) != 4:
- raise ValueError("The tensor to initialize must be four-dimensional")
-
- if shape[-2] > shape[-1]:
- raise ValueError("In_filters cannot be greater than out_filters.")
-
- if shape[0] != shape[1]:
- raise ValueError("Kernel sizes must be equal.")
-
- kernel = self._orthogonal_kernel(shape[0], shape[2], shape[3])
- kernel *= math_ops.sqrt(math_ops.cast(self.gain, dtype=dtype))
- return kernel
+ raise NotImplementedError
def get_config(self):
return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}
@@ -648,9 +629,9 @@ class ConvolutionOrthogonal2D(Initializer):
"""Construct an n x n orthogonal matrix.
Args:
- n: dimension.
+ n: Dimension.
Returns:
- a n x n orthogonal matrix.
+ An n x n orthogonal matrix.
"""
a = random_ops.random_normal([n, n], dtype=self.dtype, seed=self.seed)
if self.seed:
@@ -665,9 +646,9 @@ class ConvolutionOrthogonal2D(Initializer):
"""Compute a n x n symmetric projection matrix.
Args:
- n: dimension.
+ n: Dimension.
Returns:
- a n x n symmetric projection matrix, i.e. a matrix P s.t. P=P*P, P=P^T.
+ An n x n symmetric projection matrix, i.e. a matrix P s.t. P=P*P, P=P^T.
"""
q = self._orthogonal_matrix(n)
# randomly zeroing out some columns
@@ -678,15 +659,49 @@ class ConvolutionOrthogonal2D(Initializer):
c = math_ops.multiply(q, mask)
return math_ops.matmul(c, array_ops.matrix_transpose(c))
+
+class ConvolutionOrthogonal2D(ConvolutionOrthogonal):
+ """Initializer that generates a 2D orthogonal kernel for ConvNets.
+
+ The shape of the tensor must have length 4. The number of input
+ filters must not exceed the number of output filters.
+ The orthogonality(==isometry) is exact when the inputs are circular padded.
+ There are finite-width effects with non-circular padding (e.g. zero padding).
+
+ Args:
+ gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
+ This has the effect of scaling the output 2-norm by a factor of
+ `sqrt(gain)`.
+ seed: A Python integer. Used to create random seeds. See
+ @{tf.set_random_seed} for behavior.
+ dtype: The data type.
+ """
+
+ def __call__(self, shape, dtype=None, partition_info=None):
+ if dtype is None:
+ dtype = self.dtype
+ if len(shape) != 4:
+ raise ValueError("The tensor to initialize must be four-dimensional")
+
+ if shape[-2] > shape[-1]:
+ raise ValueError("In_filters cannot be greater than out_filters.")
+
+ if shape[0] != shape[1]:
+ raise ValueError("Kernel sizes must be equal.")
+
+ kernel = self._orthogonal_kernel(shape[0], shape[2], shape[3])
+ kernel *= math_ops.sqrt(math_ops.cast(self.gain, dtype=dtype))
+ return kernel
+
def _dict_to_tensor(self, x, k1, k2):
"""Convert a dictionary to a tensor.
Args:
- x: a k1 * k2 dictionary.
- k1: first dimension of x.
- k2: second dimension of x.
+ x: A k1 * k2 dictionary.
+ k1: First dimension of x.
+ k2: Second dimension of x.
Returns:
- a k1 * k2 tensor.
+ A k1 * k2 tensor.
"""
return array_ops.stack([array_ops.stack([x[i, j] for j in range(k2)])
@@ -696,13 +711,13 @@ class ConvolutionOrthogonal2D(Initializer):
"""Construct a 2 x 2 kernel. Used to construct orthgonal kernel.
Args:
- p1: a symmetric projection matrix
- p2: a symmetric projection matrix
+ p1: A symmetric projection matrix.
+ p2: A symmetric projection matrix.
Returns:
- a 2 x 2 kernel [[p1p2, p1(1-p2)],
+ A 2 x 2 kernel [[p1p2, p1(1-p2)],
[(1-p1)p2, (1-p1)(1-p2)]].
Raises:
- ValueError: if the dimensions of p1 and p2 are different.
+ ValueError: If the dimensions of p1 and p2 are different.
"""
if p1.shape.as_list() != p2.shape.as_list():
raise ValueError("The dimension of the matrices must be the same.")
@@ -720,8 +735,8 @@ class ConvolutionOrthogonal2D(Initializer):
"""Matrix convolution.
Args:
- m1: is a k x k dictionary, each element is a n x n matrix.
- m2: is a l x l dictionary, each element is a n x n matrix.
+ m1: A k x k dictionary, each element is a n x n matrix.
+ m2: A l x l dictionary, each element is a n x n matrix.
Returns:
(k + l - 1) * (k + l - 1) dictionary each element is a n x n matrix.
@@ -752,13 +767,13 @@ class ConvolutionOrthogonal2D(Initializer):
"""Construct orthogonal kernel for convolution.
Args:
- ksize: kernel size
- cin: number of input channels
- cout: number of output channels
+ ksize: Kernel size.
+ cin: Number of input channels.
+ cout: Number of output channels.
Returns:
- an [ksize, ksize, cin, cout] orthogonal kernel.
+ An [ksize, ksize, cin, cout] orthogonal kernel.
Raises:
- ValueError: if cin > cout.
+ ValueError: If cin > cout.
"""
if cin > cout:
raise ValueError("The number of input channels cannot exceed "
@@ -780,6 +795,273 @@ class ConvolutionOrthogonal2D(Initializer):
return self._dict_to_tensor(p, ksize, ksize)
+class ConvolutionOrthogonal1D(ConvolutionOrthogonal):
+ """Initializer that generates a 1D orthogonal kernel for ConvNets.
+
+ The shape of the tensor must have length 3. The number of input
+ filters must not exceed the number of output filters.
+ The orthogonality(==isometry) is exact when the inputs are circular padded.
+ There are finite-width effects with non-circular padding (e.g. zero padding).
+
+ Args:
+ gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
+ The 2-norm of an input is multiplied by a factor of 'sqrt(gain)' after
+ applying this convolution.
+ seed: A Python integer. Used to create random seeds. See
+ @{tf.set_random_seed} for behavior.
+ dtype: The data type.
+ """
+
+ def __call__(self, shape, dtype=None, partition_info=None):
+ if dtype is None:
+ dtype = self.dtype
+ if len(shape) != 3:
+ raise ValueError("The tensor to initialize must be three-dimensional")
+
+ if shape[-2] > shape[-1]:
+ raise ValueError("In_filters cannot be greater than out_filters.")
+
+ kernel = self._orthogonal_kernel(shape[0], shape[-2], shape[-1])
+ kernel *= math_ops.sqrt(math_ops.cast(self.gain, dtype=dtype))
+ return kernel
+
+ def _dict_to_tensor(self, x, k):
+ """Convert a dictionary to a tensor.
+
+ Args:
+ x: A dictionary of length k.
+ k: Dimension of x.
+ Returns:
+ A tensor with the same dimension.
+ """
+
+ return array_ops.stack([x[i] for i in range(k)])
+
+ def _block_orth(self, projection_matrix):
+ """Construct a kernel. Used to construct orthgonal kernel.
+
+ Args:
+ projection_matrix: A symmetric projection matrix of size n x n.
+ Returns:
+ [projection_matrix, (1 - projection_matrix)].
+ """
+ n = projection_matrix.shape.as_list()[0]
+ kernel = {}
+ eye = linalg_ops.eye(n, dtype=self.dtype)
+ kernel[0] = projection_matrix
+ kernel[1] = eye - projection_matrix
+ return kernel
+
+ def _matrix_conv(self, m1, m2):
+ """Matrix convolution.
+
+ Args:
+ m1: A dictionary of length k, each element is a n x n matrix.
+ m2: A dictionary of length l, each element is a n x n matrix.
+
+ Returns:
+ (k + l - 1) dictionary each element is a n x n matrix.
+ Raises:
+ ValueError: If the entries of m1 and m2 are of different dimensions.
+ """
+
+ n = (m1[0]).shape.as_list()[0]
+ if n != (m2[0]).shape.as_list()[0]:
+ raise ValueError("The entries in matrices m1 and m2 "
+ "must have the same dimensions!")
+ k = len(m1)
+ l = len(m2)
+ result = {}
+ size = k + l - 1
+ # Compute matrix convolution between m1 and m2.
+ for i in range(size):
+ result[i] = array_ops.zeros([n, n], self.dtype)
+ for index in range(min(k, i + 1)):
+ if (i - index) < l:
+ result[i] += math_ops.matmul(m1[index], m2[i - index])
+ return result
+
+ def _orthogonal_kernel(self, ksize, cin, cout):
+ """Construct orthogonal kernel for convolution.
+
+ Args:
+ ksize: Kernel size.
+ cin: Number of input channels.
+ cout: Number of output channels.
+ Returns:
+ A [ksize, cin, cout] orthogonal kernel.
+ Raises:
+ ValueError: If cin > cout.
+ """
+ if cin > cout:
+ raise ValueError("The number of input channels cannot exceed "
+ "the number of output channels.")
+ orth = self._orthogonal_matrix(cout)[0:cin, :]
+ if ksize == 1:
+ return array_ops.expand_dims(orth, 0)
+
+ p = self._block_orth(self._symmetric_projection(cout))
+ for _ in range(ksize - 2):
+ temp = self._block_orth(self._symmetric_projection(cout))
+ p = self._matrix_conv(p, temp)
+ for i in range(ksize):
+ p[i] = math_ops.matmul(orth, p[i])
+
+ return self._dict_to_tensor(p, ksize)
+
+
+class ConvolutionOrthogonal3D(ConvolutionOrthogonal):
+ """Initializer that generates a 3D orthogonal kernel for ConvNets.
+
+ The shape of the tensor must have length 5. The number of input
+ filters must not exceed the number of output filters.
+ The orthogonality(==isometry) is exact when the inputs are circular padded.
+ There are finite-width effects with non-circular padding (e.g. zero padding).
+
+ Args:
+ gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
+ The 2-norm of an input is multiplied by a factor of 'sqrt(gain)' after
+ applying this convolution.
+ seed: A Python integer. Used to create random seeds. See
+ @{tf.set_random_seed} for behavior.
+ dtype: The data type.
+ """
+
+ def __call__(self, shape, dtype=None, partition_info=None):
+ if dtype is None:
+ dtype = self.dtype
+ if len(shape) != 5:
+ raise ValueError("The tensor to initialize must be five-dimensional")
+
+ if shape[-2] > shape[-1]:
+ raise ValueError("In_filters cannot be greater than out_filters.")
+
+ if shape[0] != shape[1] or shape[0] != shape[2]:
+ raise ValueError("Kernel sizes must be equal.")
+
+ kernel = self._orthogonal_kernel(shape[0], shape[-2], shape[-1])
+ kernel *= math_ops.sqrt(math_ops.cast(self.gain, dtype=dtype))
+ return kernel
+
+ def _dict_to_tensor(self, x, k1, k2, k3):
+ """Convert a dictionary to a tensor.
+
+ Args:
+ x: A k1 * k2 * k3 dictionary.
+ k1: First dimension of x.
+ k2: Second dimension of x.
+ k3: Third dimension of x.
+ Returns:
+ A k1 * k2 * k3 tensor.
+ """
+
+ return array_ops.stack([array_ops.stack(
+ [array_ops.stack([x[i, j, k] for k in range(k3)])
+ for j in range(k2)]) for i in range(k1)])
+
+ def _block_orth(self, p1, p2, p3):
+ """Construct a 3 x 3 kernel. Used to construct orthgonal kernel.
+
+ Args:
+ p1: A symmetric projection matrix.
+ p2: A symmetric projection matrix.
+ p3: A symmetric projection matrix.
+ Returns:
+ A 2 x 2 x 2 kernel.
+ Raises:
+ ValueError: If the dimensions of p1, p2 and p3 are different.
+ """
+ p1_shape = p1.shape.as_list()
+ if p1_shape != p2.shape.as_list() or p1_shape != p3.shape.as_list():
+ raise ValueError("The dimension of the matrices must be the same.")
+ n = p1_shape[0]
+ eye = linalg_ops.eye(n, dtype=self.dtype)
+ kernel2x2x2 = {}
+ def matmul(p1, p2, p3):
+ return math_ops.matmul(math_ops.matmul(p1, p2), p3)
+ def cast(i, p):
+ """Return p or (1-p)."""
+ return i * p + (1-i) * (eye - p)
+ for i in [0, 1]:
+ for j in [0, 1]:
+ for k in [0, 1]:
+ kernel2x2x2[i, j, k] = matmul(cast(i, p1), cast(j, p2), cast(k, p3))
+ return kernel2x2x2
+
+ def _matrix_conv(self, m1, m2):
+ """Matrix convolution.
+
+ Args:
+ m1: A k x k x k dictionary, each element is a n x n matrix.
+ m2: A l x l x l dictionary, each element is a n x n matrix.
+
+ Returns:
+ (k + l - 1) x (k + l - 1) x (k + l - 1) dictionary each
+ element is a n x n matrix.
+ Raises:
+ ValueError: If the entries of m1 and m2 are of different dimensions.
+ """
+
+ n = (m1[0, 0, 0]).shape.as_list()[0]
+ if n != (m2[0, 0, 0]).shape.as_list()[0]:
+ raise ValueError("The entries in matrices m1 and m2 "
+ "must have the same dimensions!")
+ k = int(np.cbrt(len(m1)))
+ l = int(np.cbrt(len(m2)))
+ result = {}
+ size = k + l - 1
+ # Compute matrix convolution between m1 and m2.
+ for i in range(size):
+ for j in range(size):
+ for r in range(size):
+ result[i, j, r] = array_ops.zeros([n, n], self.dtype)
+ for index1 in range(min(k, i + 1)):
+ for index2 in range(min(k, j + 1)):
+ for index3 in range(min(k, r + 1)):
+ if (i - index1) < l and (j - index2) < l and (r - index3) < l:
+ result[i, j, r] += math_ops.matmul(m1[index1, index2, index3],
+ m2[i - index1, j - index2,
+ r - index3])
+ return result
+
+ def _orthogonal_kernel(self, ksize, cin, cout):
+ """Construct orthogonal kernel for convolution.
+
+ Args:
+ ksize: Kernel size.
+ cin: Number of input channels.
+ cout: Number of output channels.
+ Returns:
+ An [ksize, ksize, ksize, cin, cout] orthogonal kernel.
+ Raises:
+ ValueError: If cin > cout.
+ """
+ if cin > cout:
+ raise ValueError("The number of input channels cannot exceed "
+ "the number of output channels.")
+ orth = self._orthogonal_matrix(cout)[0:cin, :]
+ if ksize == 1:
+ return array_ops.expand_dims(
+ array_ops.expand_dims(
+ array_ops.expand_dims(orth, 0), 0), 0)
+
+ p = self._block_orth(self._symmetric_projection(cout),
+ self._symmetric_projection(cout),
+ self._symmetric_projection(cout))
+ for _ in range(ksize - 2):
+ temp = self._block_orth(self._symmetric_projection(cout),
+ self._symmetric_projection(cout),
+ self._symmetric_projection(cout))
+ p = self._matrix_conv(p, temp)
+ for i in range(ksize):
+ for j in range(ksize):
+ for k in range(ksize):
+ p[i, j, k] = math_ops.matmul(orth, p[i, j, k])
+
+ return self._dict_to_tensor(p, ksize, ksize, ksize)
+
+
@tf_export("keras.initializers.Identity", "initializers.identity")
class Identity(Initializer):
"""Initializer that generates the identity matrix.
@@ -825,7 +1107,9 @@ variance_scaling_initializer = VarianceScaling
orthogonal_initializer = Orthogonal
identity_initializer = Identity
convolutional_delta_orthogonal = ConvolutionDeltaOrthogonal
+convolutional_orthogonal_1d = ConvolutionOrthogonal1D
convolutional_orthogonal_2d = ConvolutionOrthogonal2D
+convolutional_orthogonal_3d = ConvolutionOrthogonal3D
# pylint: enable=invalid-name
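
The block construction behind ConvolutionOrthogonal1D is easy to sanity-check when cin == cout: with a symmetric projection P and an orthogonal O, the two taps K0 = O.P and K1 = O.(I - P) satisfy K0^T K0 + K1^T K1 = I and K0^T K1 = 0, which is exactly the condition for the circularly padded convolution to be an isometry. A NumPy sketch of that check:

    import numpy as np

    n = 6
    rng = np.random.RandomState(0)
    o, _ = np.linalg.qr(rng.randn(n, n))   # orthogonal matrix
    u, _ = np.linalg.qr(rng.randn(n, 3))   # orthonormal basis of a subspace
    p = u @ u.T                            # symmetric projection: P = P^2 = P^T
    k0, k1 = o @ p, o @ (np.eye(n) - p)    # the two kernel taps (ksize = 2)

    assert np.allclose(k0.T @ k0 + k1.T @ k1, np.eye(n))  # energy preserved
    assert np.allclose(k0.T @ k1, np.zeros((n, n)))       # no cross-tap leakage
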
diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py
index 9c8abb9740..7e4fb6a6fc 100644
--- a/tensorflow/python/ops/linalg/linear_operator_test_util.py
+++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py
@@ -233,6 +233,12 @@ class LinearOperatorDerivedClassTest(test.TestCase):
def _test_matmul(self, with_batch):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
+ # If batch dimensions are omitted, but there are
+ # no batch dimensions for the linear operator, then
+ # skip the test case. This is already checked with
+ # with_batch=True.
+ if not with_batch and len(build_info.shape) <= 2:
+ continue
for dtype in self._dtypes_to_test:
for adjoint in self._adjoint_options:
for adjoint_arg in self._adjoint_arg_options:
@@ -270,6 +276,12 @@ class LinearOperatorDerivedClassTest(test.TestCase):
def _test_solve(self, with_batch):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
+ # If batch dimensions are to be omitted but the linear operator
+ # has no batch dimensions, skip the test case: it is already
+ # covered by the with_batch=True run.
+ if not with_batch and len(build_info.shape) <= 2:
+ continue
for dtype in self._dtypes_to_test:
for adjoint in self._adjoint_options:
for adjoint_arg in self._adjoint_arg_options:
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 54f4e0f240..86dc053c0f 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -352,10 +352,17 @@ class BasicRNNCell(LayerRNNCell):
name: String, the name of the layer. Layers with the same name will
share weights, but to avoid mistakes we require reuse=True in such
cases.
+ dtype: Default dtype of the layer (default of `None` means use the type
+ of the first input). Required when `build` is called before `call`.
"""
- def __init__(self, num_units, activation=None, reuse=None, name=None):
- super(BasicRNNCell, self).__init__(_reuse=reuse, name=name)
+ def __init__(self,
+ num_units,
+ activation=None,
+ reuse=None,
+ name=None,
+ dtype=None):
+ super(BasicRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
# Inputs must be 2-dimensional.
self.input_spec = base_layer.InputSpec(ndim=2)
@@ -413,6 +420,8 @@ class GRUCell(LayerRNNCell):
name: String, the name of the layer. Layers with the same name will
share weights, but to avoid mistakes we require reuse=True in such
cases.
+ dtype: Default dtype of the layer (default of `None` means use the type
+ of the first input). Required when `build` is called before `call`.
"""
def __init__(self,
@@ -421,8 +430,9 @@ class GRUCell(LayerRNNCell):
reuse=None,
kernel_initializer=None,
bias_initializer=None,
- name=None):
- super(GRUCell, self).__init__(_reuse=reuse, name=name)
+ name=None,
+ dtype=None):
+ super(GRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
# Inputs must be 2-dimensional.
self.input_spec = base_layer.InputSpec(ndim=2)
@@ -531,8 +541,14 @@ class BasicLSTMCell(LayerRNNCell):
that follows.
"""
- def __init__(self, num_units, forget_bias=1.0,
- state_is_tuple=True, activation=None, reuse=None, name=None):
+ def __init__(self,
+ num_units,
+ forget_bias=1.0,
+ state_is_tuple=True,
+ activation=None,
+ reuse=None,
+ name=None,
+ dtype=None):
"""Initialize the basic LSTM cell.
Args:
@@ -550,11 +566,13 @@ class BasicLSTMCell(LayerRNNCell):
name: String, the name of the layer. Layers with the same name will
share weights, but to avoid mistakes we require reuse=True in such
cases.
+ dtype: Default dtype of the layer (default of `None` means use the type
+ of the first input). Required when `build` is called before `call`.
When restoring from CudnnLSTM-trained checkpoints, must use
`CudnnCompatibleLSTMCell` instead.
"""
- super(BasicLSTMCell, self).__init__(_reuse=reuse, name=name)
+ super(BasicLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
if not state_is_tuple:
logging.warn("%s: Using a concatenated state is slower and will soon be "
"deprecated. Use state_is_tuple=True.", self)
@@ -668,7 +686,7 @@ class LSTMCell(LayerRNNCell):
initializer=None, num_proj=None, proj_clip=None,
num_unit_shards=None, num_proj_shards=None,
forget_bias=1.0, state_is_tuple=True,
- activation=None, reuse=None, name=None):
+ activation=None, reuse=None, name=None, dtype=None):
"""Initialize the parameters for an LSTM cell.
Args:
@@ -701,11 +719,13 @@ class LSTMCell(LayerRNNCell):
name: String, the name of the layer. Layers with the same name will
share weights, but to avoid mistakes we require reuse=True in such
cases.
+ dtype: Default dtype of the layer (default of `None` means use the type
+ of the first input). Required when `build` is called before `call`.
When restoring from CudnnLSTM-trained checkpoints, use
`CudnnCompatibleLSTMCell` instead.
"""
- super(LSTMCell, self).__init__(_reuse=reuse, name=name)
+ super(LSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
if not state_is_tuple:
logging.warn("%s: Using a concatenated state is slower and will soon be "
"deprecated. Use state_is_tuple=True.", self)
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index e90ff0746a..f71f98aa12 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -22,12 +22,13 @@ from __future__ import print_function
import sys as _sys
+# pylint: disable=g-bad-import-order
# Imports the following modules so that @RegisterGradient gets executed.
from tensorflow.python.ops import array_grad
+from tensorflow.python.ops import cudnn_rnn_grad
from tensorflow.python.ops import data_flow_grad
from tensorflow.python.ops import manip_grad
from tensorflow.python.ops import math_grad
-from tensorflow.python.ops import manip_grad
from tensorflow.python.ops import sparse_grad
from tensorflow.python.ops import spectral_grad
from tensorflow.python.ops import state_grad
@@ -96,6 +97,7 @@ from tensorflow.python.ops.tensor_array_ops import *
from tensorflow.python.ops.variable_scope import *
from tensorflow.python.ops.variables import *
# pylint: enable=wildcard-import
+# pylint: enable=g-bad-import-order
#### For use in remove_undocumented below:
from tensorflow.python.framework import constant_op as _constant_op
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py
index 0294ecee54..9b6b8c508f 100644
--- a/tensorflow/python/ops/template.py
+++ b/tensorflow/python/ops/template.py
@@ -452,8 +452,7 @@ class Template(checkpointable.CheckpointableBase):
# Only reuse variables if they were already created.
with variable_scope.variable_scope(
self._variable_scope, reuse=self._variables_created):
- result = self._call_func(args, kwargs)
- return result
+ return self._call_func(args, kwargs)
else:
# The scope was not created at construction time, so create it here.
# Subsequent calls should reuse variables.
@@ -461,8 +460,7 @@ class Template(checkpointable.CheckpointableBase):
self._unique_name, self._name,
custom_getter=self._custom_getter) as vs:
self._variable_scope = vs
- result = self._call_func(args, kwargs)
- return result
+ return self._call_func(args, kwargs)
@property
def name(self):
@@ -730,8 +728,7 @@ class EagerTemplate(Template):
self._variable_scope, reuse=variable_scope.AUTO_REUSE)
with self._variable_scope_context_manager:
with self._template_store.as_default():
- result = self._call_func(args, kwargs)
- return result
+ return self._call_func(args, kwargs)
else:
# The scope was not created at construction time, so create it here.
# Subsequent calls should reuse variables.
@@ -743,8 +740,7 @@ class EagerTemplate(Template):
# store's variable scope name is unset; set it here.
self._template_store.set_variable_scope_name(vs.name)
with self._template_store.as_default():
- result = self._call_func(args, kwargs)
- return result
+ return self._call_func(args, kwargs)
@property
def name(self):
diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py
index 286028b8bb..663036de8a 100644
--- a/tensorflow/python/util/tf_inspect.py
+++ b/tensorflow/python/util/tf_inspect.py
@@ -17,21 +17,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import inspect as _inspect
-import six
from collections import namedtuple
+import inspect as _inspect
from tensorflow.python.util import tf_decorator
ArgSpec = _inspect.ArgSpec
-if six.PY3:
- FullArgSpec = _inspect.FullArgSpec
+if hasattr(_inspect, 'FullArgSpec'):
+ FullArgSpec = _inspect.FullArgSpec # pylint: disable=invalid-name
else:
- FullArgSpec = namedtuple(
- 'FullArgSpec', ['args', 'varargs', 'varkw', 'defaults',
- 'kwonlyargs', 'kwonlydefaults', 'annotations'])
+ FullArgSpec = namedtuple('FullArgSpec', [
+ 'args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', 'kwonlydefaults',
+ 'annotations'
+ ])
def currentframe():
@@ -70,8 +70,20 @@ def getfullargspec(obj): # pylint: disable=redefined-builtin
callable is not decorated, `inspect.getfullargspec()` will be called
directly on the callable.
"""
- if six.PY2:
+ if hasattr(_inspect, 'getfullargspec'):
+ spec_fn = _inspect.getfullargspec
+ else:
def spec_fn(target):
+ """Spec function that adding default value from FullArgSpec.
+
+ It is used when getfullargspec is not available (eg in PY2).
+
+ Args:
+ target: the target object to inspect.
+ Returns:
+ The full argument specs with empty kwonlyargs, kwonlydefaults and
+ annotations.
+ """
argspecs = _inspect.getargspec(target)
fullargspecs = FullArgSpec(
args=argspecs.args,
@@ -82,8 +94,6 @@ def getfullargspec(obj): # pylint: disable=redefined-builtin
kwonlydefaults=None,
annotations={})
return fullargspecs
- else:
- spec_fn = _inspect.getfullargspec
decorators, target = tf_decorator.unwrap(obj)
return next((d.decorator_argspec for d in decorators
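The net effect of the rework above, sketched (illustration only): `getfullargspec` now keys off `hasattr` rather than `six`, and on PY2 the fallback fills the keyword-only fields with empty values:

    from tensorflow.python.util import tf_inspect

    def f(a, b=1):
      return a + b

    spec = tf_inspect.getfullargspec(f)
    print(spec.args)        # ['a', 'b']
    print(spec.defaults)    # (1,)
    print(spec.kwonlyargs)  # [] -- empty via the namedtuple fallback on PY2.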
diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py
index c7748f5b7a..c06a39bfbd 100644
--- a/tensorflow/tools/api/generator/create_python_api.py
+++ b/tensorflow/tools/api/generator/create_python_api.py
@@ -160,7 +160,7 @@ def get_api_init_text():
# we want to traverse over TensorFlow Python modules.
for module in sys.modules.values():
# Only look at tensorflow modules.
- if (not module or not hasattr(module, "__name__") or
+ if (not module or not hasattr(module, '__name__') or
'tensorflow.' not in module.__name__):
continue
# Do not generate __init__.py files for contrib modules for now.
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt
index 0900adaf76..cbbd077c97 100644
--- a/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt
@@ -64,7 +64,7 @@ tf_class {
}
member_method {
name: "list_files"
- argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "make_initializable_iterator"
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 7b16ac90c9..9a56ae8675 100644
--- a/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -65,7 +65,7 @@ tf_class {
}
member_method {
name: "list_files"
- argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "make_initializable_iterator"
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt
index 9cf5f2ae20..e5ec824bb8 100644
--- a/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -65,7 +65,7 @@ tf_class {
}
member_method {
name: "list_files"
- argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "make_initializable_iterator"
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt
index 8c3d669143..008239789c 100644
--- a/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt
@@ -65,7 +65,7 @@ tf_class {
}
member_method {
name: "list_files"
- argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "make_initializable_iterator"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index fd9be8c759..53a903c239 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'None\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\'], "
}
member_method {
name: "evaluate"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index 6b305be43f..ba17c90de2 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'None\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\'], "
}
member_method {
name: "evaluate"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
index f909cd8756..e1abd43ab5 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
@@ -101,7 +101,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'num_units\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\', \'name\'], varargs=None, keywords=None, defaults=[\'1.0\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'num_units\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
index 173d2eae63..93e7e40199 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
@@ -101,7 +101,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
index d7f658aaee..465fc1cd9c 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
@@ -101,7 +101,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\', \'kernel_initializer\', \'bias_initializer\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\', \'kernel_initializer\', \'bias_initializer\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
index b9ab487c77..38a387d55a 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
@@ -101,7 +101,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'num_units\', \'use_peepholes\', \'cell_clip\', \'initializer\', \'num_proj\', \'proj_clip\', \'num_unit_shards\', \'num_proj_shards\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'1.0\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'num_units\', \'use_peepholes\', \'cell_clip\', \'initializer\', \'num_proj\', \'proj_clip\', \'num_unit_shards\', \'num_proj_shards\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'1.0\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py
index cec23b1a36..fb0bd2c2ff 100644
--- a/tensorflow/tools/docs/parser.py
+++ b/tensorflow/tools/docs/parser.py
@@ -657,12 +657,14 @@ def _get_arg_spec(func):
argspec_defaults.pop(i-first_default_arg)
else:
first_default_arg -= 1
- return tf_inspect.FullArgSpec(args=argspec_args,
- varargs=argspec.varargs,
- varkw=argspec.varkw,
- defaults=tuple(argspec_defaults),
- kwonlyargs=[], kwonlydefaults=None,
- annotations={})
+ return tf_inspect.FullArgSpec(
+ args=argspec_args,
+ varargs=argspec.varargs,
+ varkw=argspec.varkw,
+ defaults=tuple(argspec_defaults),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
else: # Regular function or method, getargspec will work fine.
return tf_inspect.getfullargspec(func)
@@ -672,7 +674,7 @@ def _remove_first_line_indent(string):
return '\n'.join([line[indent:] for line in string.split('\n')])
-PAREN_NUMBER_RE = re.compile("^\(([0-9.e-]+)\)")
+PAREN_NUMBER_RE = re.compile(r'^\(([0-9.e-]+)\)')
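Behavior is unchanged by the raw-string conversion; for reference (illustration only), the pattern captures a leading parenthesized number:

    import re

    PAREN_NUMBER_RE = re.compile(r'^\(([0-9.e-]+)\)')
    print(PAREN_NUMBER_RE.match('(1.5e-3) some text').group(1))  # 1.5e-3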
def _generate_signature(func, reverse_index):
@@ -1145,10 +1147,11 @@ class _ClassPageInfo(object):
for short_name in parser_config.tree[self.full_name]:
# Remove builtin members that we never want to document.
- if short_name in ['__class__', '__base__', '__weakref__', '__doc__',
- '__module__', '__dict__', '__abstractmethods__',
- '__slots__', '__getnewargs__', '__str__',
- '__repr__', '__hash__']:
+ if short_name in [
+ '__class__', '__base__', '__weakref__', '__doc__', '__module__',
+ '__dict__', '__abstractmethods__', '__slots__', '__getnewargs__',
+ '__str__', '__repr__', '__hash__'
+ ]:
continue
child_name = '.'.join([self.full_name, short_name])
@@ -1193,7 +1196,8 @@ class _ClassPageInfo(object):
# obvious what they do, don't include them in the docs if there's no
# docstring.
if not child_doc.brief.strip() and short_name in [
- '__del__', '__copy__']:
+ '__del__', '__copy__'
+ ]:
print('Skipping %s, defined in %s, no docstring.' % (child_name,
defining_class))
continue
diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py
index d7757d78ed..274d48ef66 100644
--- a/tensorflow/tools/docs/parser_test.py
+++ b/tensorflow/tools/docs/parser_test.py
@@ -408,67 +408,98 @@ class ParserTest(googletest.TestCase):
# pylint: disable=protected-access
# Make sure everything works for regular functions.
- expected = tf_inspect.FullArgSpec(args=['arg1', 'arg2', 'kwarg1', 'kwarg2'],
- varargs=None, varkw=None, defaults=(1, 2),
- kwonlyargs=[], kwonlydefaults=None,
- annotations={})
+ expected = tf_inspect.FullArgSpec(
+ args=['arg1', 'arg2', 'kwarg1', 'kwarg2'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 2),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
self.assertEqual(expected, parser._get_arg_spec(test_function_for_partial1))
# Make sure doing nothing works.
- expected = tf_inspect.FullArgSpec(args=['arg1', 'arg2', 'kwarg1', 'kwarg2'],
- varargs=None, varkw=None, defaults=(1, 2),
- kwonlyargs=[], kwonlydefaults=None,
- annotations={})
+ expected = tf_inspect.FullArgSpec(
+ args=['arg1', 'arg2', 'kwarg1', 'kwarg2'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 2),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
partial = functools.partial(test_function_for_partial1)
self.assertEqual(expected, parser._get_arg_spec(partial))
# Make sure setting args from the front works.
- expected = tf_inspect.FullArgSpec(args=['arg2', 'kwarg1', 'kwarg2'],
- varargs=None, varkw=None, defaults=(1, 2),
- kwonlyargs=[], kwonlydefaults=None,
- annotations={})
+ expected = tf_inspect.FullArgSpec(
+ args=['arg2', 'kwarg1', 'kwarg2'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 2),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
partial = functools.partial(test_function_for_partial1, 1)
self.assertEqual(expected, parser._get_arg_spec(partial))
- expected = tf_inspect.FullArgSpec(args=['kwarg2'],
- varargs=None, varkw=None, defaults=(2,),
- kwonlyargs=[], kwonlydefaults=None,
- annotations={})
+ expected = tf_inspect.FullArgSpec(
+ args=['kwarg2'],
+ varargs=None,
+ varkw=None,
+ defaults=(2,),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
partial = functools.partial(test_function_for_partial1, 1, 2, 3)
self.assertEqual(expected, parser._get_arg_spec(partial))
# Make sure setting kwargs works.
- expected = tf_inspect.FullArgSpec(args=['arg1', 'arg2', 'kwarg2'],
- varargs=None, varkw=None, defaults=(2,),
- kwonlyargs=[], kwonlydefaults=None,
- annotations={})
+ expected = tf_inspect.FullArgSpec(
+ args=['arg1', 'arg2', 'kwarg2'],
+ varargs=None,
+ varkw=None,
+ defaults=(2,),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
partial = functools.partial(test_function_for_partial1, kwarg1=0)
self.assertEqual(expected, parser._get_arg_spec(partial))
- expected = tf_inspect.FullArgSpec(args=['arg1', 'arg2', 'kwarg1'],
- varargs=None, varkw=None, defaults=(1,),
- kwonlyargs=[], kwonlydefaults=None,
- annotations={})
+ expected = tf_inspect.FullArgSpec(
+ args=['arg1', 'arg2', 'kwarg1'],
+ varargs=None,
+ varkw=None,
+ defaults=(1,),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
partial = functools.partial(test_function_for_partial1, kwarg2=0)
self.assertEqual(expected, parser._get_arg_spec(partial))
- expected = tf_inspect.FullArgSpec(args=['arg1'],
- varargs=None, varkw=None, defaults=(),
- kwonlyargs=[], kwonlydefaults=None,
- annotations={})
+ expected = tf_inspect.FullArgSpec(
+ args=['arg1'],
+ varargs=None,
+ varkw=None,
+ defaults=(),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
partial = functools.partial(test_function_for_partial1,
arg2=0, kwarg1=0, kwarg2=0)
self.assertEqual(expected, parser._get_arg_spec(partial))
# Make sure *args, *kwargs is accounted for.
- expected = tf_inspect.FullArgSpec(args=[],
- varargs='my_args', varkw='my_kwargs',
- defaults=(),
- kwonlyargs=[], kwonlydefaults=None,
- annotations={})
+ expected = tf_inspect.FullArgSpec(
+ args=[],
+ varargs='my_args',
+ varkw='my_kwargs',
+ defaults=(),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
partial = functools.partial(test_function_for_partial2, 0, 1)
self.assertEqual(expected, parser._get_arg_spec(partial))
-
+
# pylint: enable=protected-access
def testSaveReferenceResolver(self):
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 376644718f..2ef105755f 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -74,7 +74,9 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/labeled_tensor:labeled_tensor_pip",
"//tensorflow/contrib/nn:nn_py",
"//tensorflow/contrib/predictor:predictor_pip",
+ "//tensorflow/contrib/proto:proto_pip",
"//tensorflow/contrib/receptive_field:receptive_field_pip",
+ "//tensorflow/contrib/rpc:rpc_pip",
"//tensorflow/contrib/session_bundle:session_bundle_pip",
"//tensorflow/contrib/signal:signal_py",
"//tensorflow/contrib/signal:test_util",
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index f775491e4a..d7bd2a2be0 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -427,11 +427,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "grpc",
urls = [
- "https://mirror.bazel.build/github.com/grpc/grpc/archive/09386db3939cae1ac12e5f09b735adfa8958c68e.tar.gz",
- "https://github.com/grpc/grpc/archive/09386db3939cae1ac12e5f09b735adfa8958c68e.tar.gz",
+ "https://mirror.bazel.build/github.com/grpc/grpc/archive/d184fa229d75d336aedea0041bd59cb93e7e267f.tar.gz",
+ "https://github.com/grpc/grpc/archive/d184fa229d75d336aedea0041bd59cb93e7e267f.tar.gz",
],
- sha256 = "b857969c667c14f37faa507afc07a3f39a47fbf73203be889d55925622e7b317",
- strip_prefix = "grpc-09386db3939cae1ac12e5f09b735adfa8958c68e",
+ sha256 = "895b31310e718a61f7335759a778c068a6edde1c089883598a0830cbb7075673",
+ strip_prefix = "grpc-d184fa229d75d336aedea0041bd59cb93e7e267f",
)
@@ -451,11 +451,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/15535accd9e1e9d7772202ce51c8428c1994a04b.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/15535accd9e1e9d7772202ce51c8428c1994a04b.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/3210e64b499a31193051208f2f8922dadfc4bb6f.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/3210e64b499a31193051208f2f8922dadfc4bb6f.tar.gz",
],
- sha256 = "3470c2dde055dc974e859e707aa6cd1d22eadd4f3a1f282e74c3cf1f7dc9510a",
- strip_prefix = "llvm-15535accd9e1e9d7772202ce51c8428c1994a04b",
+ sha256 = "017d7db029cc175634d75416c326770139c76590575ed44a3794c11ab160c955",
+ strip_prefix = "llvm-3210e64b499a31193051208f2f8922dadfc4bb6f",
build_file = clean_dep("//third_party/llvm:llvm.BUILD"),
)
@@ -752,6 +752,10 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
name = "grpc_cpp_plugin",
actual = "@grpc//:grpc_cpp_plugin",
)
+ native.bind(
+ name = "grpc_python_plugin",
+ actual = "@grpc//:grpc_python_plugin",
+ )
# gRPC has three empty C++ functions which it wants the user to define
# at build time. https://github.com/grpc/grpc/issues/13590
diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD
index 097bbf5d42..cbb1b2fe42 100644
--- a/third_party/llvm/llvm.BUILD
+++ b/third_party/llvm/llvm.BUILD
@@ -2006,7 +2006,6 @@ cc_library(
]) + [
"include/llvm/BinaryFormat/MachO.def",
"include/llvm/Support/VCSRevision.h",
- "include/llvm/ExecutionEngine/ObjectMemoryBuffer.h",
],
deps = [
":config",