aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rwxr-xr-xconfigure2
-rw-r--r--tensorflow/c/BUILD1
-rw-r--r--tensorflow/cc/BUILD1
-rw-r--r--tensorflow/compiler/jit/xla_device.cc2
-rw-r--r--tensorflow/compiler/tf2xla/xla_compilation_device.cc3
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.h8
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc2
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc3
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.h9
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc7
-rw-r--r--tensorflow/contrib/cmake/tf_core_framework.cmake1
-rw-r--r--tensorflow/contrib/cmake/tf_core_ops.cmake1
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake2
-rw-r--r--tensorflow/contrib/distributions/BUILD19
-rw-r--r--tensorflow/contrib/distributions/__init__.py4
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py2
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py2
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py4
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py4
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py4
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py4
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py4
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py2
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/__init__.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/identity.py29
-rw-r--r--tensorflow/contrib/distributions/python/ops/conditional_distribution.py (renamed from tensorflow/python/ops/distributions/conditional_distribution.py)0
-rw-r--r--tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_student_t.py2
-rw-r--r--tensorflow/contrib/layers/BUILD1
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column.py28
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_ops_test.py123
-rw-r--r--tensorflow/contrib/layers/python/layers/initializers.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head_test.py10
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment_test.py7
-rw-r--r--tensorflow/contrib/learn/python/learn/graph_actions.py11
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/export.py18
-rw-r--r--tensorflow/contrib/lookup/BUILD13
-rw-r--r--tensorflow/contrib/lookup/__init__.py2
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py64
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_cc_files.txt1
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_h_files.txt1
-rw-r--r--tensorflow/contrib/makefile/tf_pb_text_files.txt1
-rw-r--r--tensorflow/contrib/makefile/tf_proto_files.txt1
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py84
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py275
-rw-r--r--tensorflow/contrib/rnn/BUILD25
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py63
-rw-r--r--tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py48
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py3
-rw-r--r--tensorflow/contrib/rnn/python/tools/checkpoint_convert.py231
-rw-r--r--tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py108
-rw-r--r--tensorflow/contrib/slim/python/slim/learning.py4
-rw-r--r--tensorflow/core/BUILD4
-rw-r--r--tensorflow/core/common_runtime/device.cc3
-rw-r--r--tensorflow/core/common_runtime/device.h3
-rw-r--r--tensorflow/core/common_runtime/device_mgr.cc19
-rw-r--r--tensorflow/core/common_runtime/device_mgr.h2
-rw-r--r--tensorflow/core/common_runtime/device_set.h5
-rw-r--r--tensorflow/core/common_runtime/device_set_test.cc3
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc7
-rw-r--r--tensorflow/core/common_runtime/local_device.cc6
-rw-r--r--tensorflow/core/common_runtime/local_device.h4
-rw-r--r--tensorflow/core/common_runtime/renamed_device.cc54
-rw-r--r--tensorflow/core/common_runtime/renamed_device.h119
-rw-r--r--tensorflow/core/common_runtime/simple_placer_test.cc2
-rw-r--r--tensorflow/core/common_runtime/threadpool_device.cc6
-rw-r--r--tensorflow/core/distributed_runtime/BUILD4
-rw-r--r--tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc102
-rw-r--r--tensorflow/core/distributed_runtime/base_rendezvous_mgr.h43
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.cc37
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.h11
-rw-r--r--tensorflow/core/distributed_runtime/master.cc121
-rw-r--r--tensorflow/core/distributed_runtime/master_env.h34
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc160
-rw-r--r--tensorflow/core/distributed_runtime/master_session.h21
-rw-r--r--tensorflow/core/distributed_runtime/message_wrappers.cc21
-rw-r--r--tensorflow/core/distributed_runtime/message_wrappers.h11
-rw-r--r--tensorflow/core/distributed_runtime/remote_device.cc49
-rw-r--r--tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h22
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc83
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h15
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session.cc7
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc14
-rw-r--r--tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc100
-rw-r--r--tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h13
-rw-r--r--tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc19
-rw-r--r--tensorflow/core/distributed_runtime/session_mgr.cc108
-rw-r--r--tensorflow/core/distributed_runtime/session_mgr.h44
-rw-r--r--tensorflow/core/distributed_runtime/session_mgr_test.cc81
-rw-r--r--tensorflow/core/distributed_runtime/worker.cc32
-rw-r--r--tensorflow/core/distributed_runtime/worker_env.h11
-rw-r--r--tensorflow/core/distributed_runtime/worker_interface.h5
-rw-r--r--tensorflow/core/distributed_runtime/worker_session.cc84
-rw-r--r--tensorflow/core/distributed_runtime/worker_session.h12
-rw-r--r--tensorflow/core/framework/device_base.h8
-rw-r--r--tensorflow/core/framework/op_kernel.cc16
-rw-r--r--tensorflow/core/grappler/costs/BUILD27
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc554
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h143
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc113
-rw-r--r--tensorflow/core/grappler/costs/utils.cc4
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.h2
-rw-r--r--tensorflow/core/grappler/op_types.cc10
-rw-r--r--tensorflow/core/grappler/op_types.h2
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD17
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer.cc188
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer.h6
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer_test.cc147
-rw-r--r--tensorflow/core/kernels/BUILD10
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op.cc565
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op.h8
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op_test.cc6
-rw-r--r--tensorflow/core/kernels/sparse_tensor_dense_add_op.cc27
-rw-r--r--tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc65
-rw-r--r--tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h9
-rw-r--r--tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc146
-rw-r--r--tensorflow/core/kernels/unique_op.cc5
-rw-r--r--tensorflow/core/kernels/variable_ops.h12
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc598
-rw-r--r--tensorflow/core/ops/lookup_ops.cc666
-rw-r--r--tensorflow/core/protobuf/cluster.proto82
-rw-r--r--tensorflow/core/protobuf/config.proto6
-rw-r--r--tensorflow/core/protobuf/master.proto3
-rw-r--r--tensorflow/core/protobuf/tensorflow_server.proto64
-rw-r--r--tensorflow/core/protobuf/worker.proto12
-rw-r--r--tensorflow/docs_src/get_started/get_started.md51
-rw-r--r--tensorflow/docs_src/install/install_java.md6
-rw-r--r--tensorflow/docs_src/programmers_guide/index.md5
-rw-r--r--tensorflow/docs_src/programmers_guide/supervisor.md4
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java3
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java5
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java8
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java4
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java4
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java88
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java54
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java26
-rw-r--r--tensorflow/go/README.md25
-rw-r--r--tensorflow/go/op/wrappers.go3546
-rw-r--r--tensorflow/python/BUILD31
-rw-r--r--tensorflow/python/__init__.py2
-rw-r--r--tensorflow/python/client/session_test.py267
-rw-r--r--tensorflow/python/estimator/estimator_test.py9
-rw-r--r--tensorflow/python/estimator/export/export.py6
-rw-r--r--tensorflow/python/estimator/export/export_output.py4
-rw-r--r--tensorflow/python/estimator/export/export_output_test.py29
-rw-r--r--tensorflow/python/estimator/export/export_test.py67
-rw-r--r--tensorflow/python/feature_column/BUILD34
-rw-r--r--tensorflow/python/feature_column/feature_column.py379
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py702
-rw-r--r--tensorflow/python/feature_column/lookup_ops.py (renamed from tensorflow/contrib/lookup/lookup_ops.py)60
-rw-r--r--tensorflow/python/feature_column/testdata/warriors_vocabulary.txt5
-rw-r--r--tensorflow/python/feature_column/testdata/wire_vocabulary.txt3
-rw-r--r--tensorflow/python/framework/ops.py35
-rw-r--r--tensorflow/python/framework/ops_test.py22
-rw-r--r--tensorflow/python/kernel_tests/distributions/BUILD17
-rw-r--r--tensorflow/python/kernel_tests/distributions/identity_bijector_test.py (renamed from tensorflow/contrib/distributions/python/kernel_tests/bijectors/identity_test.py)10
-rw-r--r--tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py87
-rw-r--r--tensorflow/python/ops/data_flow_ops.py42
-rw-r--r--tensorflow/python/ops/distributions/bijector_test_util.py (renamed from tensorflow/contrib/distributions/python/ops/bijectors/bijector_test_util.py)0
-rw-r--r--tensorflow/python/ops/distributions/identity_bijector.py (renamed from tensorflow/contrib/distributions/python/ops/bijectors/identity_impl.py)0
-rw-r--r--tensorflow/python/ops/distributions/transformed_distribution.py (renamed from tensorflow/contrib/distributions/python/ops/transformed_distribution.py)4
-rw-r--r--tensorflow/python/ops/lookup_ops.py77
-rw-r--r--tensorflow/python/ops/metrics_impl.py69
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py40
-rw-r--r--tensorflow/python/ops/sparse_ops.py2
-rw-r--r--tensorflow/python/ops/standard_ops.py1
-rw-r--r--tensorflow/python/saved_model/main_op_impl.py4
-rw-r--r--tensorflow/python/tools/saved_model_cli.py64
-rw-r--r--tensorflow/python/training/monitored_session.py4
-rw-r--r--tensorflow/python/training/saver_test_utils.py14
-rw-r--r--tensorflow/python/training/server_lib.py7
-rw-r--r--tensorflow/python/training/supervisor.py8
-rw-r--r--tensorflow/python/training/training.py30
-rw-r--r--tensorflow/tensorboard/package.json2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-operation.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt2
-rw-r--r--tensorflow/tools/pip_package/pip_smoke_test.py20
-rw-r--r--tensorflow/workspace.bzl12
192 files changed, 8209 insertions, 4114 deletions
diff --git a/configure b/configure
index 4104651cbb..4e66e952c2 100755
--- a/configure
+++ b/configure
@@ -385,7 +385,7 @@ fi
# Append CC optimization flags to bazel.rc
for opt in $CC_OPT_FLAGS; do
- write_to_bazelrc 'build:opt --cxxopt=$opt --copt=$opt'
+ write_to_bazelrc "build:opt --cxxopt=$opt --copt=$opt"
done
# Run the gen_git_source to create links where bazel can track dependencies for
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 4ad69ae3fb..3ab4e8efcd 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -58,6 +58,7 @@ tf_cuda_library(
"//tensorflow/cc/saved_model:loader",
"//tensorflow/cc:gradients",
"//tensorflow/cc:ops",
+ "//tensorflow/cc:grad_ops",
"//tensorflow/cc:scope_internal",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index 8810b8731a..8d4260a0b9 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -91,6 +91,7 @@ cc_library(
deps = [
":array_grad",
":math_grad",
+ ":nn_grad",
],
)
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 93f487c36c..5e336c5287 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -125,7 +125,7 @@ XlaDevice::XlaDevice(const SessionOptions& options,
const DeviceType& jit_device_name,
perftools::gputools::Platform* platform,
Allocator* xla_allocator)
- : LocalDevice(options, attrs, xla_allocator),
+ : LocalDevice(options, attrs),
device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name),
xla_allocator_(xla_allocator),
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
index d86e741b69..362a101895 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
@@ -76,8 +76,7 @@ XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options,
options,
Device::BuildDeviceAttributes(
"", type, Bytes(256 << 20), DeviceLocality(),
- strings::StrCat("device: XLA compilation device ", type.type())),
- cpu_allocator()),
+ strings::StrCat("device: XLA compilation device ", type.type()))),
allocator_(new XlaCompilationAllocator()) {}
XlaCompilationDevice::~XlaCompilationDevice() {}
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index 87ceb43d1f..6af69eeec1 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -668,6 +668,14 @@ class ComputationBuilder {
// then Build() should be used instead.
Computation BuildAndNoteError();
+ // Returns the first error that was encountered while building the
+ // computation. When an error is encountered, by default we return a vacuous
+ // ComputationDataHandle and inform the user of the error that occurred while
+ // building the computation when they make a final call to Build().
+ //
+ // See also set_die_immediately_on_error().
+ Status first_error() const { return first_error_; }
+
private:
using PopulateLiteral = std::function<void(Literal*)>;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 1c704fd1ee..1e34de9e4b 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -201,7 +201,8 @@ void IrEmitter::InitializeIrFunction(const string& function_name,
if (&argument == retval) {
continue;
}
- compute_function_->setDoesNotAlias(argument.getArgNo() + 1);
+ compute_function_->addAttribute(argument.getArgNo() + 1,
+ llvm::Attribute::NoAlias);
}
ir_builder_.SetInsertPoint(llvm::BasicBlock::Create(
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 04babcca0c..e52e55a1a8 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -196,7 +196,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
ir_emitter_context_->buffer_assignment().GetTempAllocation()) {
kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, allocation->size());
}
- kernel->setDoesNotAlias(temp_buffer_arg_no + 1);
+ kernel->addAttribute(temp_buffer_arg_no + 1, llvm::Attribute::NoAlias);
// Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX
// treats it as a CUDA kernel.
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 5e7bd4a7ce..d413621cfe 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -705,7 +705,8 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
CHECK(ShapeUtil::IsArray(instruction->shape()) &&
ShapeUtil::IsArray(operand->shape()));
- if (instruction->IsElementwiseOnOperand(operand_no) &&
+ if ((instruction->IsElementwiseOnOperand(operand_no) ||
+ InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) &&
!ShapeUtil::IsScalar(operand->shape()) &&
ShapeUtil::Rank(operand->shape()) ==
ShapeUtil::Rank(instruction->shape())) {
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index 61dc7b1207..4f586c334d 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -248,6 +248,15 @@ class LayoutAssignment : public HloPassInterface {
return Status::OK();
}
+ // This method can be overriden to mark instructions as requiring the operands
+ // to have the same layout as the result, for performance or correctness. This
+ // will propagate constraints through the instruction from the result into the
+ // operands.
+ virtual bool InstructionRequiresInputLayoutEqualToOutputLayout(
+ const HloInstruction* instruction) {
+ return false;
+ }
+
// Construct contraints and assign layouts to all instructions in the
// computation satisfying the given ComputationLayout. Layouts constraints are
// added, then propagated until all LogicalBuffers in the computation are
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 338d63f1a0..b2ef8ed486 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -244,8 +244,11 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
}
if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
return InvalidArgument(
- "cannot concatenate arrays with different ranks: %lld vs %lld",
- ShapeUtil::Rank(*arg_shape), ShapeUtil::Rank(*shape));
+ "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld "
+ "(%s)",
+ ShapeUtil::Rank(*arg_shape),
+ ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape),
+ ShapeUtil::HumanString(*shape).c_str());
}
if (arg_shape->element_type() != shape->element_type()) {
return InvalidArgument(
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index 6fd1ae0814..560e45fc13 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -118,6 +118,7 @@ set(tf_proto_text_srcs
"tensorflow/core/framework/types.proto"
"tensorflow/core/framework/versions.proto"
"tensorflow/core/lib/core/error_codes.proto"
+ "tensorflow/core/protobuf/cluster.proto"
"tensorflow/core/protobuf/config.proto"
"tensorflow/core/protobuf/debug.proto"
"tensorflow/core/protobuf/rewriter_config.proto"
diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake
index 2a19433a7b..eae00ab875 100644
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_core_ops.cmake
@@ -22,6 +22,7 @@ set(tf_op_lib_names
"image_ops"
"io_ops"
"linalg_ops"
+ "lookup_ops"
"logging_ops"
"math_ops"
"nn_ops"
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 53ebfbb57d..9e2eb71b4c 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -203,6 +203,7 @@ add_python_module("tensorflow/python/estimator")
add_python_module("tensorflow/python/estimator/export")
add_python_module("tensorflow/python/estimator/inputs")
add_python_module("tensorflow/python/estimator/inputs/queues")
+add_python_module("tensorflow/python/feature_column")
add_python_module("tensorflow/python/framework")
add_python_module("tensorflow/python/grappler")
add_python_module("tensorflow/python/kernel_tests")
@@ -596,6 +597,7 @@ GENERATE_PYTHON_OP_LIB("image_ops")
GENERATE_PYTHON_OP_LIB("io_ops")
GENERATE_PYTHON_OP_LIB("linalg_ops")
GENERATE_PYTHON_OP_LIB("logging_ops")
+GENERATE_PYTHON_OP_LIB("lookup_ops")
GENERATE_PYTHON_OP_LIB("nn_ops")
GENERATE_PYTHON_OP_LIB("parsing_ops")
GENERATE_PYTHON_OP_LIB("random_ops")
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 9f675c6613..0c818dee03 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -711,25 +711,6 @@ cuda_py_test(
)
cuda_py_test(
- name = "identity_test",
- size = "small",
- srcs = ["python/kernel_tests/bijectors/identity_test.py"],
- additional_deps = [
- ":bijectors_py",
- ":distributions_py",
- "//third_party/py/numpy",
- "@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform_test",
- ],
-)
-
-cuda_py_test(
name = "inline_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/inline_test.py"],
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 6ea74fab0e..ea12e13010 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -25,6 +25,7 @@ from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops.binomial import *
from tensorflow.contrib.distributions.python.ops.chi2 import *
+from tensorflow.contrib.distributions.python.ops.conditional_distribution import *
from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import *
from tensorflow.contrib.distributions.python.ops.deterministic import *
from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform
@@ -44,12 +45,10 @@ from tensorflow.contrib.distributions.python.ops.quantized_distribution import *
from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import *
from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import *
from tensorflow.contrib.distributions.python.ops.sample_stats import *
-from tensorflow.contrib.distributions.python.ops.transformed_distribution import *
from tensorflow.contrib.distributions.python.ops.wishart import *
from tensorflow.python.ops.distributions.bernoulli import *
from tensorflow.python.ops.distributions.beta import *
from tensorflow.python.ops.distributions.categorical import *
-from tensorflow.python.ops.distributions.conditional_distribution import *
from tensorflow.python.ops.distributions.dirichlet import *
from tensorflow.python.ops.distributions.dirichlet_multinomial import *
from tensorflow.python.ops.distributions.distribution import *
@@ -60,6 +59,7 @@ from tensorflow.python.ops.distributions.laplace import *
from tensorflow.python.ops.distributions.multinomial import *
from tensorflow.python.ops.distributions.normal import *
from tensorflow.python.ops.distributions.student_t import *
+from tensorflow.python.ops.distributions.transformed_distribution import *
from tensorflow.python.ops.distributions.uniform import *
# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py
index 13554f7664..e8fd6aa2f7 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py
@@ -23,9 +23,9 @@ import itertools
import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine
-from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
index 994e21dd48..20e7543084 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
@@ -20,12 +20,12 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.contrib.distributions.python.ops.bijectors.chain import Chain
from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
index a4688829f1..0ff3530428 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
@@ -19,11 +19,11 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import bijectors
-from tensorflow.contrib.distributions.python.ops import transformed_distribution as transformed_distribution_lib
-from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import gamma as gamma_lib
+from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib
+from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py
index c30ce60cac..9970c0b4d8 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py
@@ -20,9 +20,9 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
-from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp
+from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
+from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py
index a4688829f1..0ff3530428 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py
@@ -19,11 +19,11 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import bijectors
-from tensorflow.contrib.distributions.python.ops import transformed_distribution as transformed_distribution_lib
-from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import gamma as gamma_lib
+from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib
+from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py
index b30a3b599b..de1659aa9f 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py
@@ -20,9 +20,9 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
-from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import PowerTransform
+from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
+from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py
index 6f1a6b1cf4..e4f9d72785 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py
@@ -21,9 +21,9 @@ from __future__ import print_function
import numpy as np
from scipy import special
-from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
-from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid
+from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
+from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py
index 173d52686d..62e3869db0 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py
@@ -20,9 +20,9 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py
index 214b196b54..d9af9aec50 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py
@@ -20,9 +20,9 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
-from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus
+from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
+from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test
rng = np.random.RandomState(42)
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
index e1d31e373c..1684a5fffe 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
@@ -43,7 +43,6 @@ from tensorflow.contrib.distributions.python.ops.bijectors.chain import *
from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import *
from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import *
from tensorflow.contrib.distributions.python.ops.bijectors.exp import *
-from tensorflow.contrib.distributions.python.ops.bijectors.identity import *
from tensorflow.contrib.distributions.python.ops.bijectors.inline import *
from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import *
@@ -52,6 +51,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered impo
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import *
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import *
from tensorflow.python.ops.distributions.bijector import *
+from tensorflow.python.ops.distributions.identity_bijector import Identity
# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/identity.py b/tensorflow/contrib/distributions/python/ops/bijectors/identity.py
deleted file mode 100644
index 749dd268f9..0000000000
--- a/tensorflow/contrib/distributions/python/ops/bijectors/identity.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Identity bijector."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# go/tf-wildcard-import
-# pylint: disable=wildcard-import
-from tensorflow.contrib.distributions.python.ops.bijectors.identity_impl import *
-# pylint: enable=wildcard-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = ["Identity"]
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/ops/distributions/conditional_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_distribution.py
index ef25d4aedd..ef25d4aedd 100644
--- a/tensorflow/python/ops/distributions/conditional_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/conditional_distribution.py
diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py
index b0967802bd..2e1e68cf05 100644
--- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py
@@ -17,9 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.distributions.python.ops import transformed_distribution
+from tensorflow.contrib.distributions.python.ops import conditional_distribution
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.distributions import conditional_distribution
+from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import util as distribution_util
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
index a66eb1674c..fbd623ed3a 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
@@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import bijectors
-from tensorflow.contrib.distributions.python.ops import transformed_distribution
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
@@ -29,6 +28,7 @@ from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal
+from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import util as distribution_util
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
index 581e190f73..5b57a95c55 100644
--- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import logistic
-from tensorflow.contrib.distributions.python.ops import transformed_distribution
# Bijectors must be directly imported because `remove_undocumented` prevents
# individual file imports.
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid
@@ -27,6 +26,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import util as distribution_util
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
index 00415f5e1a..da1cd72a6f 100644
--- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops import bijectors
-from tensorflow.contrib.distributions.python.ops import transformed_distribution
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -30,6 +29,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import util as distribution_util
diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
index 299ff36962..ae804b6172 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
@@ -19,13 +19,13 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import bijectors
-from tensorflow.contrib.distributions.python.ops import transformed_distribution
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import student_t
+from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import util as distribution_util
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD
index aba8eabe10..fe661a5625 100644
--- a/tensorflow/contrib/layers/BUILD
+++ b/tensorflow/contrib/layers/BUILD
@@ -108,6 +108,7 @@ tf_custom_op_py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
+ "//tensorflow/python/feature_column",
"@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py
index d6d5bf2294..04fe2370d1 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column.py
@@ -136,8 +136,10 @@ from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.layers.python.ops import bucketization_op
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops
+from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -1497,9 +1499,12 @@ def _real_valued_var_len_column(column_name,
is_sparse)
-class _RealValuedColumn(_FeatureColumn, collections.namedtuple(
- "_RealValuedColumn",
- ["column_name", "dimension", "default_value", "dtype", "normalizer"])):
+class _RealValuedColumn(
+ _FeatureColumn,
+ fc_core._DenseColumn, # pylint: disable=protected-access
+ collections.namedtuple(
+ "_RealValuedColumn",
+ ["column_name", "dimension", "default_value", "dtype", "normalizer"])):
"""Represents a real valued feature column also known as continuous features.
Instances of this class are immutable. The dictionary returned by InputBuilder
@@ -1569,6 +1574,23 @@ class _RealValuedColumn(_FeatureColumn, collections.namedtuple(
def _to_dense_tensor(self, input_tensor):
return input_tensor
+ @property
+ def _variable_shape(self):
+ return tensor_shape.TensorShape((self.dimension))
+
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ del weight_collections
+ del trainable
+ return inputs.get(self)
+
+ def _transform_feature(self, inputs):
+ return math_ops.to_float(
+ self._normalized_input_tensor(inputs.get(self.name)))
+
+ @property
+ def _parse_example_config(self):
+ return self.config
+
def real_valued_column(column_name,
dimension=1,
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
index 632836fee4..a09cc53571 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
@@ -27,14 +27,15 @@ from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.layers.python.layers import feature_column_ops
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
@@ -223,7 +224,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(keys_sparse, output)
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertEqual(output[keys_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0])
self.assertAllEqual(output[keys_sparse].indices.eval(),
@@ -241,7 +242,7 @@ class TransformerTest(test.TestCase):
output = feature_column_ops._Transformer(features).transform(keys_sparse)
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
# While the input is a dense Tensor, the output should be a SparseTensor.
self.assertIsInstance(output, sparse_tensor.SparseTensor)
self.assertEqual(output.dtype, dtypes.int64)
@@ -310,7 +311,7 @@ class TransformerTest(test.TestCase):
self.assertIn(weighted_ids, output)
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(),
ids_tensor.dense_shape.eval())
self.assertAllEqual(output[weighted_ids][0].indices.eval(),
@@ -340,7 +341,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
self.assertAllEqual(output[vocab_sparse].indices.eval(),
@@ -362,7 +363,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
self.assertAllEqual(output[vocab_sparse].indices.eval(),
@@ -386,7 +387,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
self.assertAllEqual(output[vocab_sparse].indices.eval(),
@@ -408,7 +409,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
self.assertAllEqual(output[vocab_sparse].indices.eval(),
@@ -600,7 +601,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
one_hot_column, embedding_column, real_valued_column])
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10])
def testRealValuedColumn(self):
@@ -610,6 +611,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[real_valued])
with self.test_session():
self.assertAllClose(output.eval(), features["price"].eval())
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllClose(output.eval(),
+ fc_core.make_input_layer(features,
+ [real_valued]).eval())
def testRealValuedColumnWithMultiDimensions(self):
real_valued = feature_column.real_valued_column("price", 2)
@@ -620,6 +625,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[real_valued])
with self.test_session():
self.assertAllClose(output.eval(), features["price"].eval())
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllClose(output.eval(),
+ fc_core.make_input_layer(features,
+ [real_valued]).eval())
def testRealValuedColumnSparse(self):
sparse_real_valued = feature_column._real_valued_var_len_column(
@@ -640,6 +649,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[real_valued])
with self.test_session():
self.assertAllClose(output.eval(), features["price"].eval() - 2)
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllClose(output.eval(),
+ fc_core.make_input_layer(features,
+ [real_valued]).eval())
def testRealValuedColumnWithMultiDimensionsAndNormalizer(self):
real_valued = feature_column.real_valued_column(
@@ -651,6 +664,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[real_valued])
with self.test_session():
self.assertAllClose(output.eval(), features["price"].eval() - 2)
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllClose(output.eval(),
+ fc_core.make_input_layer(features,
+ [real_valued]).eval())
def testBucketizedColumnWithNormalizerSucceedsForDNN(self):
bucket = feature_column.bucketized_column(
@@ -697,7 +714,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[one_hot_column])
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]],
output.eval())
@@ -715,7 +732,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
output.eval())
@@ -733,7 +750,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
output.eval())
@@ -767,7 +784,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[one_hot_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual([3, 10], output.eval().shape)
def testEmbeddingColumnSucceedsForDNN(self):
@@ -874,7 +891,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10])
def testEmbeddingColumnWithIntegerWeightedSparseColumnSucceedsForDNN(self):
@@ -897,7 +914,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10])
def testEmbeddingColumnWithCrossedColumnSucceedsForDNN(self):
@@ -948,7 +965,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
"Error creating input layer for column: ids_weighted_by_weights"):
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
feature_column_ops.input_from_feature_columns(features, [weighted_ids])
def testCrossedColumnFailsForDNN(self):
@@ -1055,7 +1072,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
# score: (sum of weights)
self.assertAllEqual(output.eval(), [[10.], [50.], [0.]])
@@ -1293,7 +1310,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
expected_input_shape = np.array([4, 3, 4])
@@ -1327,7 +1344,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
expected_input_shape = np.array([4, 3, hash_buckets])
@@ -1357,7 +1374,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
self.assertAllEqual(expected_input_shape, model_input.shape)
@@ -1386,7 +1403,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
self.assertAllEqual(expected_input_shape, model_input.shape)
@@ -1416,7 +1433,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
embedding_weights)
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
model_input, gradients = sess.run([model_input_tensor, gradient_tensor])
expected_input_shape = [4, 3, embedding_dimension]
@@ -1483,7 +1500,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
expected_input_shape = [
@@ -1564,7 +1581,7 @@ class WeightedSumTest(test.TestCase):
features, [weighted_ids], num_outputs=5)
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
def testWeightedSparseColumnWithDenseInputTensor(self):
@@ -1580,7 +1597,7 @@ class WeightedSumTest(test.TestCase):
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
def testCrossedColumn(self):
@@ -1634,7 +1651,7 @@ class WeightedSumTest(test.TestCase):
features, [movies], num_outputs=1))
with self.test_session() as sess:
variables_lib.initialize_all_variables().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[movies][0]
self.assertEqual(weights.get_shape(), (3, 1))
@@ -1709,7 +1726,7 @@ class WeightedSumTest(test.TestCase):
features, [age, language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllClose(output.eval(), [[0.], [0.]])
@@ -1749,7 +1766,7 @@ class WeightedSumTest(test.TestCase):
self.assertEqual(len(variables), 1)
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllClose(output.eval(), [[0.], [0.]])
@@ -1813,7 +1830,7 @@ class WeightedSumTest(test.TestCase):
features, [weighted_language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllClose(output.eval(), [[0.], [0.]])
@@ -1841,7 +1858,7 @@ class WeightedSumTest(test.TestCase):
features, [language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
# score: 0.1 + language_weight['hindi'] + language_weight['english']
sess.run(bias.assign([0.1]))
@@ -1864,7 +1881,7 @@ class WeightedSumTest(test.TestCase):
features, [movies], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[movies][0]
self.assertEqual(weights.get_shape(), (15, 1))
@@ -1898,7 +1915,7 @@ class WeightedSumTest(test.TestCase):
features, [country_language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[country_language][0]
sess.run(weights.assign(weights + 0.4))
@@ -1922,7 +1939,7 @@ class WeightedSumTest(test.TestCase):
features, [language_language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[language_language][0]
sess.run(weights.assign(weights + 0.4))
@@ -1955,7 +1972,7 @@ class WeightedSumTest(test.TestCase):
features, [country_language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[country_language][0]
sess.run(weights.assign(weights + 0.4))
@@ -1996,7 +2013,7 @@ class WeightedSumTest(test.TestCase):
scope=scope))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertEqual(2, len(column_to_variable[country]))
self.assertEqual(3, len(column_to_variable[language]))
@@ -2033,7 +2050,7 @@ class WeightedSumTest(test.TestCase):
features, [country, age, incomes], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
incomes_weights = column_to_variable[incomes][0]
sess.run(incomes_weights.assign([[0.1], [0.2], [0.3]]))
@@ -2069,7 +2086,7 @@ class WeightedSumTest(test.TestCase):
features, [country, age, height, incomes], num_outputs=5))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
height_weights = column_to_variable[height][0]
sess.run(
@@ -2099,7 +2116,7 @@ class WeightedSumTest(test.TestCase):
features, [bucket], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
sess.run(column_to_variable[bucket][0].assign([[0.1], [0.2], [0.3],
[0.4]]))
@@ -2127,7 +2144,7 @@ class WeightedSumTest(test.TestCase):
features, [bucket, country], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
# dimension = 2, bucket_size = 4, num_classes = 1
sess.run(column_to_variable[bucket][0].assign(
@@ -2156,7 +2173,7 @@ class WeightedSumTest(test.TestCase):
features, [bucket, country], num_outputs=5))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
# dimension = 2, bucket_size = 4, num_classes = 5
sess.run(column_to_variable[bucket][0].assign(
@@ -2192,7 +2209,7 @@ class WeightedSumTest(test.TestCase):
features, [country_price], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[country_price][0]
sess.run(weights.assign(weights + 0.4))
@@ -2231,7 +2248,7 @@ class WeightedSumTest(test.TestCase):
features, [country_language_price], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[country_language_price][0]
sess.run(weights.assign(weights + 0.4))
@@ -2255,7 +2272,7 @@ class WeightedSumTest(test.TestCase):
features, [product], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
@@ -2270,7 +2287,7 @@ class WeightedSumTest(test.TestCase):
features, [product], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
@@ -2285,7 +2302,7 @@ class WeightedSumTest(test.TestCase):
features, [product], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
self.assertAllClose(output.eval(), [[0.6], [0.7]])
@@ -2306,7 +2323,7 @@ class WeightedSumTest(test.TestCase):
features, [product], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
@@ -2318,7 +2335,7 @@ class WeightedSumTest(test.TestCase):
features, [feature_column.real_valued_column("age")], num_outputs=3)
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
sess.run(bias.assign([0.1, 0.2, 0.3]))
self.assertAllClose(output.eval(), [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3],
[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]])
@@ -2332,7 +2349,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (1, 3))
sess.run(weights.assign([[0.01, 0.03, 0.05]]))
@@ -2356,7 +2373,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (5, 3))
sess.run(
@@ -2382,7 +2399,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (5, 3))
@@ -2422,7 +2439,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (5, 3))
@@ -2451,7 +2468,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (5, 3))
@@ -2516,7 +2533,7 @@ class ParseExampleTest(test.TestCase):
self.assertIn(bucket, output)
self.assertIn(wire_cast, output)
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(output[bucket].eval(), [[2, 3, 0]])
self.assertAllEqual(output[wire_cast].indices.eval(), [[0, 0], [0, 1]])
self.assertAllEqual(output[wire_cast].values.eval(), [2, 0])
diff --git a/tensorflow/contrib/layers/python/layers/initializers.py b/tensorflow/contrib/layers/python/layers/initializers.py
index 9fb9a3e257..1926cbe7b3 100644
--- a/tensorflow/contrib/layers/python/layers/initializers.py
+++ b/tensorflow/contrib/layers/python/layers/initializers.py
@@ -46,7 +46,7 @@ def xavier_initializer(uniform=True, seed=None, dtype=dtypes.float32):
Args:
uniform: Whether to use uniform or normal distributed random initialization.
seed: A Python integer. Used to create random seeds. See
- @{set_random_seed} for behavior.
+ @{tf.set_random_seed} for behavior.
dtype: The data type. Only floating point types are supported.
Returns:
@@ -96,7 +96,7 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False,
mode: String. 'FAN_IN', 'FAN_OUT', 'FAN_AVG'.
uniform: Whether to use uniform or normal distributed random initialization.
seed: A Python integer. Used to create random seeds. See
- @{set_random_seed} for behavior.
+ @{tf.set_random_seed} for behavior.
dtype: The data type. Only floating point types are supported.
Returns:
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
index 61a6168a9e..6fc028ab70 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
@@ -38,8 +38,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
@@ -157,7 +157,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
self.context_feature_columns)
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
- sess.run(data_flow_ops.tables_initializer())
+ sess.run(lookup_ops.tables_initializer())
sequence_input_val = sess.run(sequence_input)
expected_shape = np.array([
3, # expected batch size
@@ -178,7 +178,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
# Obtain values of activations and final state.
with session.Session() as sess:
sess.run(variables.global_variables_initializer())
- sess.run(data_flow_ops.tables_initializer())
+ sess.run(lookup_ops.tables_initializer())
activations, final_state = sess.run([activations_t, final_state_t])
expected_activations_shape = np.array([3, 2, self.NUM_LABEL_COLUMNS])
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 74a6da20d4..36f843ba8e 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -57,7 +57,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
@@ -1292,7 +1292,7 @@ class Estimator(BaseEstimator):
init_op = control_flow_ops.group(
variables.local_variables_initializer(),
resources.initialize_resources(resources.shared_resources()),
- data_flow_ops.tables_initializer())
+ lookup_ops.tables_initializer())
# Perform the export
builder = saved_model_builder.SavedModelBuilder(export_dir)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
index 207a189a94..d5777088de 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
@@ -32,7 +32,7 @@ from tensorflow.core.framework import summary_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses as losses_lib
from tensorflow.python.platform import test
@@ -1214,7 +1214,7 @@ class MultiClassHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=((1., 0., 0.), (0., 0., 1.),))
with session.Session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(
[0, 2],
model_fn_ops.predictions["classes"].eval())
@@ -1266,7 +1266,7 @@ class MultiClassHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=((1., 0., 0.), (0., 0., 1.),))
with session.Session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(
[b"key0", b"key2"],
model_fn_ops.predictions["classes"].eval())
@@ -1301,7 +1301,7 @@ class MultiClassHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=((1., 0., 0.),))
with session.Session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
@@ -1327,7 +1327,7 @@ class MultiClassHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=((0., 0., 1.),))
with session.Session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
index e7470a544f..69469b577d 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
@@ -35,8 +35,8 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
@@ -55,7 +55,7 @@ class PrepareInputsForRnnTest(test.TestCase):
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
- sess.run(data_flow_ops.initialize_all_tables())
+ sess.run(lookup_ops.tables_initializer())
features_val = sess.run(features_by_time)
self.assertAllEqual(expected, features_val)
@@ -316,7 +316,7 @@ class StateSavingRnnEstimatorTest(test.TestCase):
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
- sess.run(data_flow_ops.initialize_all_tables())
+ sess.run(lookup_ops.tables_initializer())
actual_sequence, actual_context = sess.run(
[sequence, context])
assert_equal(expected_sequence, actual_sequence)
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index 602d33e5f9..85d45aef7a 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -647,6 +647,10 @@ class Experiment(object):
if _sentinel is not None:
raise ValueError("_call_train should be called with keyword args only")
+ # Estimator in core cannot work with monitors. We need to convert them
+ # to hooks. For Estimator in contrib, it is converted internally. So, it is
+ # safe to convert for both cases.
+ hooks = monitors.replace_monitors_with_hooks(hooks, self._estimator)
if self._core_estimator_used:
return self._estimator.train(input_fn=input_fn,
steps=steps,
diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py
index 4b5f3a195c..9ecfc73299 100644
--- a/tensorflow/contrib/learn/python/learn/experiment_test.py
+++ b/tensorflow/contrib/learn/python/learn/experiment_test.py
@@ -24,7 +24,6 @@ import time
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import experiment
-from tensorflow.contrib.learn.python.learn import monitors
from tensorflow.contrib.learn.python.learn import run_config
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
@@ -461,7 +460,8 @@ class ExperimentTest(test.TestCase):
self.assertEqual(1, est.eval_count)
self.assertEqual(1, len(est.monitors))
self.assertEqual([noop_hook], est.eval_hooks)
- self.assertTrue(isinstance(est.monitors[0], monitors.ValidationMonitor))
+ self.assertTrue(isinstance(est.monitors[0],
+ session_run_hook.SessionRunHook))
def test_train_hooks_extend_does_not_mutate_input_hooks(self):
for est in self._estimators_for_tests():
@@ -563,7 +563,8 @@ class ExperimentTest(test.TestCase):
self.assertEqual(1, est.export_count)
self.assertEqual(1, len(est.monitors))
self.assertEqual([noop_hook], est.eval_hooks)
- self.assertTrue(isinstance(est.monitors[0], monitors.ValidationMonitor))
+ self.assertTrue(isinstance(est.monitors[0],
+ session_run_hook.SessionRunHook))
def test_train_and_evaluate_with_no_eval_during_training(self):
for est in self._estimators_for_tests():
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py
index 4b7867f2d0..98365c05f6 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions.py
@@ -37,8 +37,8 @@ from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
@@ -429,11 +429,14 @@ def _get_ready_op():
def _get_local_init_op():
+ """Returns the local init ops to initialize tables and local variables."""
local_init_op = _get_first_op_from_collection(
ops.GraphKeys.LOCAL_INIT_OP)
if local_init_op is None:
- op_list = [variables.local_variables_initializer(),
- data_flow_ops.tables_initializer()]
+ op_list = [
+ variables.local_variables_initializer(),
+ lookup_ops.tables_initializer()
+ ]
if op_list:
local_init_op = control_flow_ops.group(*op_list)
ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
@@ -680,7 +683,7 @@ def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
else:
session.run(variables.global_variables_initializer())
session.run(variables.local_variables_initializer())
- session.run(data_flow_ops.tables_initializer())
+ session.run(lookup_ops.tables_initializer())
coord = coordinator.Coordinator()
threads = None
try:
diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py
index b53be29283..36a1f5f60c 100644
--- a/tensorflow/contrib/learn/python/learn/utils/export.py
+++ b/tensorflow/contrib/learn/python/learn/utils/export.py
@@ -28,7 +28,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as tf_saver
@@ -67,17 +67,17 @@ def _export_graph(graph, saver, checkpoint_path, export_dir,
with graph.as_default():
with tf_session.Session('') as session:
variables.local_variables_initializer()
- data_flow_ops.tables_initializer()
+ lookup_ops.tables_initializer()
saver.restore(session, checkpoint_path)
export = exporter.Exporter(saver)
- export.init(init_op=control_flow_ops.group(
- variables.local_variables_initializer(),
- data_flow_ops.tables_initializer()),
- default_graph_signature=default_graph_signature,
- named_graph_signatures=named_graph_signatures,
- assets_collection=ops.get_collection(
- ops.GraphKeys.ASSET_FILEPATHS))
+ export.init(
+ init_op=control_flow_ops.group(
+ variables.local_variables_initializer(),
+ lookup_ops.tables_initializer()),
+ default_graph_signature=default_graph_signature,
+ named_graph_signatures=named_graph_signatures,
+ assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS))
return export.export(export_dir, contrib_variables.get_global_step(),
session, exports_to_keep=exports_to_keep)
diff --git a/tensorflow/contrib/lookup/BUILD b/tensorflow/contrib/lookup/BUILD
index b3316ee8c4..bbbd340352 100644
--- a/tensorflow/contrib/lookup/BUILD
+++ b/tensorflow/contrib/lookup/BUILD
@@ -13,19 +13,10 @@ py_library(
name = "lookup_py",
srcs = [
"__init__.py",
- "lookup_ops.py",
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:data_flow_ops_gen",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
+ "//tensorflow/python/feature_column:lookup_ops",
],
)
@@ -39,11 +30,11 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:data_flow_ops",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:lookup_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
"//tensorflow/python:variables",
diff --git a/tensorflow/contrib/lookup/__init__.py b/tensorflow/contrib/lookup/__init__.py
index dbd64cf042..a5fcdc7b42 100644
--- a/tensorflow/contrib/lookup/__init__.py
+++ b/tensorflow/contrib/lookup/__init__.py
@@ -47,7 +47,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
-from tensorflow.contrib.lookup.lookup_ops import *
+from tensorflow.python.feature_column.lookup_ops import *
# pylint: enable=unused-import,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 0ec40a63f2..5ec169b6db 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -31,7 +31,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
@@ -125,7 +125,7 @@ class HashTableOpTest(test.TestCase):
table3 = lookup.HashTable(
lookup.KeyValueTensorInitializer(keys, values), default_val)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(3, table1.size().eval())
self.assertAllEqual(3, table2.size().eval())
self.assertAllEqual(3, table3.size().eval())
@@ -1184,7 +1184,7 @@ class IndexTableFromFile(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int32_index_table_from_file(self):
@@ -1198,7 +1198,7 @@ class IndexTableFromFile(test.TestCase):
constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int64_index_table_from_file(self):
@@ -1212,7 +1212,7 @@ class IndexTableFromFile(test.TestCase):
constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_index_table_from_file_with_default_value(self):
@@ -1224,7 +1224,7 @@ class IndexTableFromFile(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, default_value), ids.eval())
def test_index_table_from_file_with_oov_buckets(self):
@@ -1236,7 +1236,7 @@ class IndexTableFromFile(test.TestCase):
constant_op.constant(["salad", "surgery", "tarkus", "toccata"]))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(
(
1, # From vocabulary file.
@@ -1259,7 +1259,7 @@ class IndexTableFromFile(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, -1, -1), ids.eval())
self.assertEqual(2, table.size().eval())
@@ -1286,7 +1286,7 @@ class IndexTableFromFile(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, -1), ids.eval())
self.assertEqual(3, table.size().eval())
@@ -1345,7 +1345,7 @@ class IndexTableFromTensor(test.TestCase):
ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int32_index_table_from_tensor_with_tensor_init(self):
@@ -1356,7 +1356,7 @@ class IndexTableFromTensor(test.TestCase):
constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int64_index_table_from_tensor_with_tensor_init(self):
@@ -1367,7 +1367,7 @@ class IndexTableFromTensor(test.TestCase):
constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_index_table_from_tensor_with_default_value(self):
@@ -1378,7 +1378,7 @@ class IndexTableFromTensor(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, default_value), ids.eval())
def test_index_table_from_tensor_missing_mapping(self):
@@ -1394,7 +1394,7 @@ class IndexTableFromTensor(test.TestCase):
self.assertRaises(errors_impl.OpError, ids.eval)
with self.assertRaisesRegexp(
errors_impl.OpError, "keys and values cannot be empty"):
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
def test_index_table_from_tensor_with_invalid_hashers(self):
with self.test_session():
@@ -1422,7 +1422,7 @@ class StringToIndexTest(test.TestCase):
indices = lookup.string_to_index(feats, mapping=mapping_strings)
self.assertRaises(errors_impl.OpError, indices.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, -1), indices.eval())
@@ -1433,7 +1433,7 @@ class StringToIndexTest(test.TestCase):
_ = lookup.string_to_index(feats, mapping=mapping_strings)
self.assertRaises(errors_impl.OpError,
- data_flow_ops.tables_initializer().run)
+ lookup_ops.tables_initializer().run)
def test_string_to_index_with_default_value(self):
default_value = -42
@@ -1444,7 +1444,7 @@ class StringToIndexTest(test.TestCase):
feats, mapping=mapping_strings, default_value=default_value)
self.assertRaises(errors_impl.OpError, indices.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, default_value), indices.eval())
@@ -1463,7 +1463,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
vocabulary_file=vocabulary_file)
features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
features.eval())
@@ -1475,7 +1475,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
vocabulary_file=vocabulary_file, default_value=default_value)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", b"surgery", default_value),
features.eval())
@@ -1489,7 +1489,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
default_value=default_value)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", default_value, default_value),
features.eval())
@@ -1501,7 +1501,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval)
- init = data_flow_ops.tables_initializer()
+ init = lookup_ops.tables_initializer()
self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"Invalid vocab_size", init.run)
@@ -1513,7 +1513,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", b"surgery", b"UNK"), features.eval())
@@ -1528,7 +1528,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
features = table.lookup(indices)
self.assertRaises(errors_impl.OpError, features.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
features.eval())
@@ -1540,7 +1540,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
mapping=mapping_strings)
indices = constant_op.constant([0, 1, 4], dtypes.int64)
features = table.lookup(indices)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"hello", b"hello", b"UNK"), features.eval())
def test_index_to_string_with_default_value(self):
@@ -1553,7 +1553,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
features = table.lookup(indices)
self.assertRaises(errors_impl.OpError, features.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", b"surgery", default_value),
features.eval())
@@ -1567,7 +1567,7 @@ class IndexToStringTest(test.TestCase):
feats = lookup.index_to_string(indices, mapping=mapping_strings)
self.assertRaises(errors_impl.OpError, feats.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
feats.eval())
@@ -1577,11 +1577,11 @@ class IndexToStringTest(test.TestCase):
mapping_strings = constant_op.constant(["hello", "hello"])
indices = constant_op.constant([0, 1, 4], dtypes.int64)
feats = lookup.index_to_string(indices, mapping=mapping_strings)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"hello", b"hello", b"UNK"), feats.eval())
self.assertRaises(errors_impl.OpError,
- data_flow_ops.tables_initializer().run)
+ lookup_ops.tables_initializer().run)
def test_index_to_string_with_default_value(self):
default_value = b"NONE"
@@ -1592,7 +1592,7 @@ class IndexToStringTest(test.TestCase):
indices, mapping=mapping_strings, default_value=default_value)
self.assertRaises(errors_impl.OpError, feats.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", b"surgery", default_value), feats.eval())
@@ -1755,7 +1755,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
default_value,
shared_name=shared_name)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
input_string = constant_op.constant(["brain", "salad", "tank"])
@@ -2081,7 +2081,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
hasher_spec=lookup.StrongHashSpec((1, 2)),
name="table2")
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
input_string = constant_op.constant(
["fruit", "brain", "salad", "surgery", "UNK"])
@@ -2167,7 +2167,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
default_value2),
oov_buckets)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
input_string_1 = constant_op.constant(
["brain", "salad", "surgery", "UNK"])
diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
index c0969e6dee..2f1fcb149e 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
@@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.cc
tensorflow/core/protobuf/queue_runner.pb.cc
tensorflow/core/protobuf/named_tensor.pb.cc
tensorflow/core/protobuf/meta_graph.pb.cc
+tensorflow/core/protobuf/cluster.pb.cc
tensorflow/core/protobuf/config.pb.cc
tensorflow/core/protobuf/rewriter_config.pb.cc
tensorflow/core/protobuf/debug.pb.cc
diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
index 132b477596..6087a45168 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
@@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.h
tensorflow/core/protobuf/queue_runner.pb.h
tensorflow/core/protobuf/named_tensor.pb.h
tensorflow/core/protobuf/meta_graph.pb.h
+tensorflow/core/protobuf/cluster.pb.h
tensorflow/core/protobuf/config.pb.h
tensorflow/core/protobuf/debug.pb.h
tensorflow/core/protobuf/rewriter_config.pb.h
diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt
index f1da05e4c6..c39257ffa9 100644
--- a/tensorflow/contrib/makefile/tf_pb_text_files.txt
+++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt
@@ -1,6 +1,7 @@
tensorflow/core/util/saved_tensor_slice.pb_text.cc
tensorflow/core/util/memmapped_file_system.pb_text.cc
tensorflow/core/protobuf/saver.pb_text.cc
+tensorflow/core/protobuf/cluster.pb_text.cc
tensorflow/core/protobuf/config.pb_text.cc
tensorflow/core/protobuf/debug.pb_text.cc
tensorflow/core/protobuf/rewriter_config.pb_text.cc
diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt
index 2a78ea6101..5eadf5d55b 100644
--- a/tensorflow/contrib/makefile/tf_proto_files.txt
+++ b/tensorflow/contrib/makefile/tf_proto_files.txt
@@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/queue_runner.proto
tensorflow/core/protobuf/named_tensor.proto
tensorflow/core/protobuf/meta_graph.proto
+tensorflow/core/protobuf/cluster.proto
tensorflow/core/protobuf/config.proto
tensorflow/core/protobuf/debug.proto
tensorflow/core/protobuf/rewriter_config.proto
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index d57203c042..727cdd9597 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -1338,6 +1338,87 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
name=name_scope)
+def sparse_recall_at_top_k(labels,
+ top_k_predictions,
+ class_id=None,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
+ """Computes recall@k of top-k predictions with respect to sparse labels.
+
+ If `class_id` is specified, we calculate recall by considering only the
+ entries in the batch for which `class_id` is in the label, and computing
+ the fraction of them for which `class_id` is in the top-k `predictions`.
+ If `class_id` is not specified, we'll calculate recall as how often on
+ average a class among the labels of a batch entry is in the top-k
+ `predictions`.
+
+ `sparse_recall_at_top_k` creates two local variables, `true_positive_at_<k>`
+ and `false_negative_at_<k>`, that are used to compute the recall_at_k
+ frequency. This frequency is ultimately returned as `recall_at_<k>`: an
+ idempotent operation that simply divides `true_positive_at_<k>` by total
+ (`true_positive_at_<k>` + `false_negative_at_<k>`).
+
+ For estimation of the metric over a stream of data, the function creates an
+ `update_op` operation that updates these variables and returns the
+ `recall_at_<k>`. Set operations applied to `top_k` and `labels` calculate the
+ true positives and false negatives weighted by `weights`. Then `update_op`
+ increments `true_positive_at_<k>` and `false_negative_at_<k>` using these
+ values.
+
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+ Args:
+ labels: `int64` `Tensor` or `SparseTensor` with shape
+ [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
+ target classes for the associated prediction. Commonly, N=1 and `labels`
+ has shape [batch_size, num_labels]. [D1, ... DN] must match
+ `top_k_predictions`. Values should be in range [0, num_classes), where
+ num_classes is the last dimension of `predictions`. Values outside this
+ range always count towards `false_negative_at_<k>`.
+ top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
+ N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
+ The final dimension contains the indices of top-k labels. [D1, ... DN]
+ must match `labels`.
+ class_id: Integer class ID for which we want binary metrics. This should be
+ in range [0, num_classes), where num_classes is the last dimension of
+ `predictions`. If class_id is outside this range, the method returns NAN.
+ weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+ `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+ dimensions must be either `1`, or the same as the corresponding `labels`
+ dimension).
+ metrics_collections: An optional list of collections that values should
+ be added to.
+ updates_collections: An optional list of collections that updates should
+ be added to.
+ name: Name of new update operation, and namespace for other dependent ops.
+
+ Returns:
+ recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
+ by the sum of `true_positives` and `false_negatives`.
+ update_op: `Operation` that increments `true_positives` and
+ `false_negatives` variables appropriately, and whose value matches
+ `recall`.
+
+ Raises:
+ ValueError: If `weights` is not `None` and its shape doesn't match
+ `predictions`, or if either `metrics_collections` or `updates_collections`
+ are not a list or tuple.
+ """
+ default_name = _at_k_name('recall', class_id=class_id)
+ with ops.name_scope(name, default_name, (top_k_predictions, labels,
+ weights)) as name_scope:
+ return metrics_impl._sparse_recall_at_top_k( # pylint: disable=protected-access
+ labels=labels,
+ predictions_idx=top_k_predictions,
+ class_id=class_id,
+ weights=weights,
+ metrics_collections=metrics_collections,
+ updates_collections=updates_collections,
+ name=name_scope)
+
+
def streaming_sparse_average_precision_at_k(predictions,
labels,
k,
@@ -2288,6 +2369,7 @@ def _remove_squeezable_dimensions(predictions, labels, weights):
__all__ = [
'aggregate_metric_map',
'aggregate_metrics',
+ 'sparse_recall_at_top_k',
'streaming_accuracy',
'streaming_auc',
'streaming_false_negatives',
@@ -2310,7 +2392,9 @@ __all__ = [
'streaming_root_mean_squared_error',
'streaming_sensitivity_at_specificity',
'streaming_sparse_average_precision_at_k',
+ 'streaming_sparse_average_precision_at_top_k',
'streaming_sparse_precision_at_k',
+ 'streaming_sparse_precision_at_top_k',
'streaming_sparse_recall_at_k',
'streaming_specificity_at_sensitivity',
'streaming_true_negatives',
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index b960e1310e..f42e974e23 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -2958,8 +2958,38 @@ class StreamingSparseRecallTest(test.TestCase):
self.assertEqual(expected, update.eval())
self.assertEqual(expected, metric.eval())
+ def _test_sparse_recall_at_top_k(self,
+ labels,
+ top_k_predictions,
+ expected,
+ class_id=None,
+ weights=None):
+ with ops.Graph().as_default() as g, self.test_session(g):
+ if weights is not None:
+ weights = constant_op.constant(weights, dtypes_lib.float32)
+ metric, update = metric_ops.sparse_recall_at_top_k(
+ labels=labels,
+ top_k_predictions=constant_op.constant(top_k_predictions,
+ dtypes_lib.int32),
+ class_id=class_id,
+ weights=weights)
+
+ # Fails without initialized vars.
+ self.assertRaises(errors_impl.OpError, metric.eval)
+ self.assertRaises(errors_impl.OpError, update.eval)
+ variables.variables_initializer(variables.local_variables()).run()
+
+ # Run per-step op and assert expected values.
+ if math.isnan(expected):
+ self.assertTrue(math.isnan(update.eval()))
+ self.assertTrue(math.isnan(metric.eval()))
+ else:
+ self.assertEqual(expected, update.eval())
+ self.assertEqual(expected, metric.eval())
+
def test_one_label_at_k1_nan(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+ top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
@@ -2970,9 +3000,12 @@ class StreamingSparseRecallTest(test.TestCase):
for class_id in (-1, 0, 1, 4):
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=NAN, class_id=class_id)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, class_id=class_id)
def test_one_label_at_k1_no_predictions(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+ top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
@@ -2981,9 +3014,12 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 2: 0 predictions.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=0.0, class_id=2)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.0, class_id=2)
def test_one_label_at_k1(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+ top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
@@ -2992,13 +3028,18 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 3: 1 label, 2 predictions, 1 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 1, class_id=3)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 1, class_id=3)
# All classes: 2 labels, 2 predictions, 1 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 2)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 2)
def test_one_label_at_k1_weighted(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+ top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
@@ -3007,6 +3048,8 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 3: 1 label, 2 predictions, 1 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3014,6 +3057,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1,
class_id=3,
weights=(1.0,))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0,))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3021,6 +3070,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1,
class_id=3,
weights=(2.0,))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(2.0,))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3028,6 +3083,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN,
class_id=3,
weights=(0.0, 0.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=NAN,
+ class_id=3,
+ weights=(0.0, 0.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3035,6 +3096,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN,
class_id=3,
weights=(0.0, 1.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=NAN,
+ class_id=3,
+ weights=(0.0, 1.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3042,6 +3109,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1,
class_id=3,
weights=(1.0, 0.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0, 0.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3049,6 +3122,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1,
class_id=3,
weights=(1.0, 1.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0, 1.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3056,6 +3135,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=2.0 / 2,
class_id=3,
weights=(2.0, 3.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=2.0 / 2,
+ class_id=3,
+ weights=(2.0, 3.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3063,6 +3148,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=3.0 / 3,
class_id=3,
weights=(3.0, 2.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=3.0 / 3,
+ class_id=3,
+ weights=(3.0, 2.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3070,6 +3161,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=0.3 / 0.3,
class_id=3,
weights=(0.3, 0.6))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=0.3 / 0.3,
+ class_id=3,
+ weights=(0.3, 0.6))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3077,32 +3174,70 @@ class StreamingSparseRecallTest(test.TestCase):
expected=0.6 / 0.6,
class_id=3,
weights=(0.6, 0.3))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=0.6 / 0.6,
+ class_id=3,
+ weights=(0.6, 0.3))
# All classes: 2 labels, 2 predictions, 1 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=NAN, weights=(0.0,))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, weights=(0.0,))
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,))
+
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,))
+
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0))
+
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0))
+
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0))
+
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0))
+
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0))
+
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6))
+
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3))
def test_three_labels_at_k5_nan(self):
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
+ top_k_predictions = [
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
@@ -3112,10 +3247,16 @@ class StreamingSparseRecallTest(test.TestCase):
for class_id in (0, 3, 4, 6, 9, 10):
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, class_id=class_id)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, class_id=class_id)
def test_three_labels_at_k5_no_predictions(self):
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
+ top_k_predictions = [
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
@@ -3124,10 +3265,16 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 8: 1 label, no predictions.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=0.0 / 1, class_id=8)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.0 / 1, class_id=8)
def test_three_labels_at_k5(self):
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
+ top_k_predictions = [
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
@@ -3136,23 +3283,35 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 2: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=2.0 / 2, class_id=2)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=2.0 / 2, class_id=2)
# Class 5: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=1.0 / 1, class_id=5)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 1, class_id=5)
# Class 7: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=0.0 / 1, class_id=7)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.0 / 1, class_id=7)
# All classes: 6 labels, 3 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=3.0 / 6)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=3.0 / 6)
def test_three_labels_at_k5_some_out_of_range(self):
"""Tests that labels outside the [0, n_classes) count in denominator."""
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
+ top_k_predictions = [
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ]
sp_labels = sparse_tensor.SparseTensorValue(
indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
[1, 3]],
@@ -3167,6 +3326,11 @@ class StreamingSparseRecallTest(test.TestCase):
k=5,
expected=2.0 / 2,
class_id=2)
+ self._test_sparse_recall_at_top_k(
+ sp_labels,
+ top_k_predictions,
+ expected=2.0 / 2,
+ class_id=2)
# Class 5: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
@@ -3175,6 +3339,11 @@ class StreamingSparseRecallTest(test.TestCase):
k=5,
expected=1.0 / 1,
class_id=5)
+ self._test_sparse_recall_at_top_k(
+ sp_labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=5)
# Class 7: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
@@ -3183,16 +3352,30 @@ class StreamingSparseRecallTest(test.TestCase):
k=5,
expected=0.0 / 1,
class_id=7)
+ self._test_sparse_recall_at_top_k(
+ sp_labels,
+ top_k_predictions,
+ expected=0.0 / 1,
+ class_id=7)
# All classes: 8 labels, 3 correct.
self._test_streaming_sparse_recall_at_k(
predictions=predictions, labels=sp_labels, k=5, expected=3.0 / 8)
+ self._test_sparse_recall_at_top_k(
+ sp_labels, top_k_predictions, expected=3.0 / 8)
def test_3d_nan(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
+ top_k_predictions = [[
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ], [
+ [5, 7, 2, 9, 6],
+ [9, 4, 6, 2, 0],
+ ]]
sparse_labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
[[0, 1, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 1, 1, 0]]])
@@ -3207,12 +3390,21 @@ class StreamingSparseRecallTest(test.TestCase):
for class_id in (0, 3, 4, 6, 9, 10):
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, class_id=class_id)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, class_id=class_id)
def test_3d_no_predictions(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
+ top_k_predictions = [[
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ], [
+ [5, 7, 2, 9, 6],
+ [9, 4, 6, 2, 0],
+ ]]
sparse_labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@@ -3229,12 +3421,21 @@ class StreamingSparseRecallTest(test.TestCase):
for class_id in (1, 8):
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=0.0, class_id=class_id)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.0, class_id=class_id)
def test_3d(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
+ top_k_predictions = [[
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ], [
+ [5, 7, 2, 9, 6],
+ [9, 4, 6, 2, 0],
+ ]]
labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@@ -3244,24 +3445,39 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 2: 4 labels, all correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=4.0 / 4, class_id=2)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=4.0 / 4, class_id=2)
# Class 5: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=2.0 / 2, class_id=5)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=2.0 / 2, class_id=5)
# Class 7: 2 labels, 1 incorrect.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=1.0 / 2, class_id=7)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 2, class_id=7)
# All classes: 12 labels, 7 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=7.0 / 12)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=7.0 / 12)
def test_3d_ignore_all(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
+ top_k_predictions = [[
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ], [
+ [5, 7, 2, 9, 6],
+ [9, 4, 6, 2, 0],
+ ]]
labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@@ -3276,6 +3492,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN,
class_id=class_id,
weights=[[0], [0]])
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=NAN,
+ class_id=class_id,
+ weights=[[0], [0]])
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3283,16 +3505,33 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN,
class_id=class_id,
weights=[[0, 0], [0, 0]])
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=NAN,
+ class_id=class_id,
+ weights=[[0, 0], [0, 0]])
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, weights=[[0], [0]])
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, weights=[[0], [0]])
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, weights=[[0, 0], [0, 0]])
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, weights=[[0, 0], [0, 0]])
def test_3d_ignore_some(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
+ top_k_predictions = [[
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ], [
+ [5, 7, 2, 9, 6],
+ [9, 4, 6, 2, 0],
+ ]]
labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@@ -3307,6 +3546,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=2.0 / 2.0,
class_id=2,
weights=[[1], [0]])
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=2.0 / 2.0,
+ class_id=2,
+ weights=[[1], [0]])
# Class 2: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k(
@@ -3316,6 +3561,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=2.0 / 2.0,
class_id=2,
weights=[[0], [1]])
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=2.0 / 2.0,
+ class_id=2,
+ weights=[[0], [1]])
# Class 7: 1 label, correct.
self._test_streaming_sparse_recall_at_k(
@@ -3325,6 +3576,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1.0,
class_id=7,
weights=[[0], [1]])
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1.0,
+ class_id=7,
+ weights=[[0], [1]])
# Class 7: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
@@ -3334,6 +3591,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=0.0 / 1.0,
class_id=7,
weights=[[1], [0]])
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=0.0 / 1.0,
+ class_id=7,
+ weights=[[1], [0]])
# Class 7: 2 labels, 1 correct.
self._test_streaming_sparse_recall_at_k(
@@ -3343,6 +3606,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 2.0,
class_id=7,
weights=[[1, 0], [1, 0]])
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 2.0,
+ class_id=7,
+ weights=[[1, 0], [1, 0]])
# Class 7: No labels.
self._test_streaming_sparse_recall_at_k(
@@ -3352,6 +3621,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN,
class_id=7,
weights=[[0, 1], [0, 1]])
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=NAN,
+ class_id=7,
+ weights=[[0, 1], [0, 1]])
def test_sparse_tensor_value(self):
predictions = [[0.1, 0.3, 0.2, 0.4],
diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD
index ab443eab6f..9d67563edd 100644
--- a/tensorflow/contrib/rnn/BUILD
+++ b/tensorflow/contrib/rnn/BUILD
@@ -304,6 +304,7 @@ filegroup(
exclude = [
"**/METADATA",
"**/OWNERS",
+ "tools/**",
],
),
visibility = ["//tensorflow:__subpackages__"],
@@ -351,3 +352,27 @@ tf_kernel_library(
"//third_party/eigen3",
],
)
+
+py_binary(
+ name = "checkpoint_convert",
+ srcs = ["python/tools/checkpoint_convert.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ ],
+)
+
+py_test(
+ name = "checkpoint_convert_test",
+ size = "small",
+ srcs = ["python/tools/checkpoint_convert_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":checkpoint_convert",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 15afac9823..f4589e3d9e 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -74,7 +74,41 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 2])
- g, _ = core_rnn_cell_impl.BasicRNNCell(2)(x, m)
+ cell = core_rnn_cell_impl.BasicRNNCell(2)
+ g, _ = cell(x, m)
+ self.assertEqual(
+ ["root/basic_rnn_cell/%s:0"
+ % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/basic_rnn_cell/%s:0"
+ % core_rnn_cell_impl._BIAS_VARIABLE_NAME],
+ [v.name for v in cell.trainable_variables])
+ self.assertFalse(cell.non_trainable_variables)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run(
+ [g], {x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])})
+ self.assertEqual(res[0].shape, (1, 2))
+
+ def testBasicRNNCellNotTrainable(self):
+ with self.test_session() as sess:
+ def not_trainable_getter(getter, *args, **kwargs):
+ kwargs["trainable"] = False
+ return getter(*args, **kwargs)
+
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5),
+ custom_getter=not_trainable_getter):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 2])
+ cell = core_rnn_cell_impl.BasicRNNCell(2)
+ g, _ = cell(x, m)
+ self.assertFalse(cell.trainable_variables)
+ self.assertEqual(
+ ["root/basic_rnn_cell/%s:0"
+ % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/basic_rnn_cell/%s:0"
+ % core_rnn_cell_impl._BIAS_VARIABLE_NAME],
+ [v.name for v in cell.non_trainable_variables])
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(
[g], {x.name: np.array([[1., 1.]]),
@@ -114,10 +148,23 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 8])
- g, out_m = core_rnn_cell_impl.MultiRNNCell(
+ cell = core_rnn_cell_impl.MultiRNNCell(
[core_rnn_cell_impl.BasicLSTMCell(
2, state_is_tuple=False) for _ in range(2)],
- state_is_tuple=False)(x, m)
+ state_is_tuple=False)
+ g, out_m = cell(x, m)
+ expected_variable_names = [
+ "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0"
+ % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0"
+ % core_rnn_cell_impl._BIAS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0"
+ % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0"
+ % core_rnn_cell_impl._BIAS_VARIABLE_NAME]
+ self.assertEqual(
+ expected_variable_names, [v.name for v in cell.trainable_variables])
+ self.assertFalse(cell.non_trainable_variables)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(
[g, out_m],
@@ -125,15 +172,7 @@ class RNNCellTest(test.TestCase):
m.name: 0.1 * np.ones([1, 8])})
self.assertEqual(len(res), 2)
variables = variables_lib.global_variables()
- self.assertEqual(4, len(variables))
- self.assertEquals(variables[0].op.name,
- "root/multi_rnn_cell/cell_0/basic_lstm_cell/weights")
- self.assertEquals(variables[1].op.name,
- "root/multi_rnn_cell/cell_0/basic_lstm_cell/biases")
- self.assertEquals(variables[2].op.name,
- "root/multi_rnn_cell/cell_1/basic_lstm_cell/weights")
- self.assertEquals(variables[3].op.name,
- "root/multi_rnn_cell/cell_1/basic_lstm_cell/biases")
+ self.assertEqual(expected_variable_names, [v.name for v in variables])
# The numbers in results were not calculated, this is just a smoke test.
self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
expected_mem = np.array([[
diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
index 884b51926e..eba2c0d2ac 100644
--- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
+++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
@@ -27,7 +27,6 @@ from __future__ import division
from __future__ import print_function
import collections
-import contextlib
import hashlib
import math
import numbers
@@ -57,53 +56,6 @@ _BIAS_VARIABLE_NAME = "biases"
_WEIGHTS_VARIABLE_NAME = "weights"
-@contextlib.contextmanager
-def _checked_scope(cell, scope, reuse=None, **kwargs):
- if reuse is not None:
- kwargs["reuse"] = reuse
- with vs.variable_scope(scope, **kwargs) as checking_scope:
- scope_name = checking_scope.name
- if hasattr(cell, "_scope"):
- cell_scope = cell._scope # pylint: disable=protected-access
- if cell_scope.name != checking_scope.name:
- raise ValueError(
- "Attempt to reuse RNNCell %s with a different variable scope than "
- "its first use. First use of cell was with scope '%s', this "
- "attempt is with scope '%s'. Please create a new instance of the "
- "cell if you would like it to use a different set of weights. "
- "If before you were using: MultiRNNCell([%s(...)] * num_layers), "
- "change to: MultiRNNCell([%s(...) for _ in range(num_layers)]). "
- "If before you were using the same cell instance as both the "
- "forward and reverse cell of a bidirectional RNN, simply create "
- "two instances (one for forward, one for reverse). "
- "In May 2017, we will start transitioning this cell's behavior "
- "to use existing stored weights, if any, when it is called "
- "with scope=None (which can lead to silent model degradation, so "
- "this error will remain until then.)"
- % (cell, cell_scope.name, scope_name, type(cell).__name__,
- type(cell).__name__))
- else:
- weights_found = False
- try:
- with vs.variable_scope(checking_scope, reuse=True):
- vs.get_variable(_WEIGHTS_VARIABLE_NAME)
- weights_found = True
- except ValueError:
- pass
- if weights_found and reuse is None:
- raise ValueError(
- "Attempt to have a second RNNCell use the weights of a variable "
- "scope that already has weights: '%s'; and the cell was not "
- "constructed as %s(..., reuse=True). "
- "To share the weights of an RNNCell, simply "
- "reuse it in your second calculation, or create a new one with "
- "the argument reuse=True." % (scope_name, type(cell).__name__))
-
- # Everything is OK. Update the cell's scope and yield it.
- cell._scope = checking_scope # pylint: disable=protected-access
- yield checking_scope
-
-
class BasicRNNCell(RNNCell):
"""The most basic RNN cell."""
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index df36dd2bf9..9672b8b85f 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -39,9 +39,6 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
-_checked_scope = core_rnn_cell_impl._checked_scope # pylint: disable=protected-access
-
-
def _get_concat_variable(name, shape, dtype, num_shards):
"""Get a sharded variable concatenated into one tensor."""
sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
diff --git a/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py b/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
new file mode 100644
index 0000000000..1e29114b0c
--- /dev/null
+++ b/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
@@ -0,0 +1,231 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Convert checkpoints using RNNCells to new name convention.
+
+Usage:
+
+ python checkpoint_convert [--write_v1_checkpoint] \
+ '/path/to/checkpoint' '/path/to/new_checkpoint'
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import collections
+import re
+import sys
+
+from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.client import session
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import app
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import saver as saver_lib
+
+_RNN_NAME_REPLACEMENTS = collections.OrderedDict([
+ ############################################################################
+ # contrib/rnn/python/ops/core_rnn_cell_impl.py
+ # BasicRNNCell
+ ('basic_rnn_cell/weights', 'basic_rnn_cell/kernel'),
+ ('basic_rnn_cell/biases', 'basic_rnn_cell/bias'),
+ # GRUCell
+ ('gru_cell/weights', 'gru_cell/kernel'),
+ ('gru_cell/biases', 'gru_cell/bias'),
+ ('gru_cell/gates/weights', 'gru_cell/gates/kernel'),
+ ('gru_cell/gates/biases', 'gru_cell/gates/bias'),
+ ('gru_cell/candidate/weights', 'gru_cell/candidate/kernel'),
+ ('gru_cell/candidate/biases', 'gru_cell/candidate/bias'),
+ # BasicLSTMCell
+ ('basic_lstm_cell/weights', 'basic_lstm_cell/kernel'),
+ ('basic_lstm_cell/biases', 'basic_lstm_cell/bias'),
+ # LSTMCell
+ ('lstm_cell/weights', 'lstm_cell/kernel'),
+ ('lstm_cell/biases', 'lstm_cell/bias'),
+ ('lstm_cell/projection/weights', 'lstm_cell/projection/kernel'),
+ ('lstm_cell/projection/biases', 'lstm_cell/projection/bias'),
+ # OutputProjectionWrapper
+ ('output_projection_wrapper/weights', 'output_projection_wrapper/kernel'),
+ ('output_projection_wrapper/biases', 'output_projection_wrapper/bias'),
+ # InputProjectionWrapper
+ ('input_projection_wrapper/weights', 'input_projection_wrapper/kernel'),
+ ('input_projection_wrapper/biases', 'input_projection_wrapper/bias'),
+ ############################################################################
+ # contrib/rnn/python/ops/lstm_ops.py
+ # LSTMBlockFusedCell ??
+ ('lstm_block_wrapper/weights', 'lstm_block_wrapper/kernel'),
+ ('lstm_block_wrapper/biases', 'lstm_block_wrapper/bias'),
+ ############################################################################
+ # contrib/rnn/python/ops/rnn_cell.py
+ # LayerNormBasicLSTMCell
+ ('layer_norm_basic_lstm_cell/weights', 'layer_norm_basic_lstm_cell/kernel'),
+ ('layer_norm_basic_lstm_cell/biases', 'layer_norm_basic_lstm_cell/bias'),
+ # UGRNNCell, not found in g3, but still need it?
+ ('ugrnn_cell/weights', 'ugrnn_cell/kernel'),
+ ('ugrnn_cell/biases', 'ugrnn_cell/bias'),
+ # NASCell
+ ('nas_rnn/weights', 'nas_rnn/kernel'),
+ ('nas_rnn/recurrent_weights', 'nas_rnn/recurrent_kernel'),
+ # IntersectionRNNCell
+ ('intersection_rnn_cell/weights', 'intersection_rnn_cell/kernel'),
+ ('intersection_rnn_cell/biases', 'intersection_rnn_cell/bias'),
+ ('intersection_rnn_cell/in_projection/weights',
+ 'intersection_rnn_cell/in_projection/kernel'),
+ ('intersection_rnn_cell/in_projection/biases',
+ 'intersection_rnn_cell/in_projection/bias'),
+ # PhasedLSTMCell
+ ('phased_lstm_cell/mask_gates/weights',
+ 'phased_lstm_cell/mask_gates/kernel'),
+ ('phased_lstm_cell/mask_gates/biases', 'phased_lstm_cell/mask_gates/bias'),
+ ('phased_lstm_cell/new_input/weights', 'phased_lstm_cell/new_input/kernel'),
+ ('phased_lstm_cell/new_input/biases', 'phased_lstm_cell/new_input/bias'),
+ ('phased_lstm_cell/output_gate/weights',
+ 'phased_lstm_cell/output_gate/kernel'),
+ ('phased_lstm_cell/output_gate/biases',
+ 'phased_lstm_cell/output_gate/bias'),
+ # AttentionCellWrapper
+ ('attention_cell_wrapper/weights', 'attention_cell_wrapper/kernel'),
+ ('attention_cell_wrapper/biases', 'attention_cell_wrapper/bias'),
+ ('attention_cell_wrapper/attn_output_projection/weights',
+ 'attention_cell_wrapper/attn_output_projection/kernel'),
+ ('attention_cell_wrapper/attn_output_projection/biases',
+ 'attention_cell_wrapper/attn_output_projection/bias'),
+ ('attention_cell_wrapper/attention/weights',
+ 'attention_cell_wrapper/attention/kernel'),
+ ('attention_cell_wrapper/attention/biases',
+ 'attention_cell_wrapper/attention/bias'),
+])
+
+_RNN_SHARDED_NAME_REPLACEMENTS = collections.OrderedDict([
+ ('LSTMCell/W_', 'lstm_cell/weights/part_'),
+ ('BasicLSTMCell/Linear/Matrix_', 'basic_lstm_cell/weights/part_'),
+ ('GRUCell/W_', 'gru_cell/weights/part_'),
+ ('MultiRNNCell/Cell', 'multi_rnn_cell/cell_'),
+])
+
+
+def _rnn_name_replacement(var_name):
+ for pattern in _RNN_NAME_REPLACEMENTS:
+ if pattern in var_name:
+ old_var_name = var_name
+ var_name = var_name.replace(pattern, _RNN_NAME_REPLACEMENTS[pattern])
+ logging.info('Converted: %s --> %s' % (old_var_name, var_name))
+ break
+ return var_name
+
+
+def _rnn_name_replacement_sharded(var_name):
+ for pattern in _RNN_SHARDED_NAME_REPLACEMENTS:
+ if pattern in var_name:
+ old_var_name = var_name
+ var_name = var_name.replace(pattern,
+ _RNN_SHARDED_NAME_REPLACEMENTS[pattern])
+ logging.info('Converted: %s --> %s' % (old_var_name, var_name))
+ return var_name
+
+
+def _split_sharded_vars(name_shape_map):
+ """Split shareded variables.
+
+ Args:
+ name_shape_map: A dict from variable name to variable shape.
+
+ Returns:
+ not_sharded: Names of the non-sharded variables.
+ sharded: Names of the sharded varibales.
+ """
+ sharded = []
+ not_sharded = []
+ for name in name_shape_map:
+ if re.match(name, '_[0-9]+$'):
+ if re.sub('_[0-9]+$', '_1', name) in name_shape_map:
+ sharded.append(name)
+ else:
+ not_sharded.append(name)
+ else:
+ not_sharded.append(name)
+ return not_sharded, sharded
+
+
+def convert_names(checkpoint_from_path,
+ checkpoint_to_path,
+ write_v1_checkpoint=False):
+ """Migrates the names of variables within a checkpoint.
+
+ Args:
+ checkpoint_from_path: Path to source checkpoint to be read in.
+ checkpoint_to_path: Path to checkpoint to be written out.
+ write_v1_checkpoint: Whether the output checkpoint will be in V1 format.
+
+ Returns:
+ A dictionary that maps the new variable names to the Variable objects.
+ A dictionary that maps the old variable names to the new variable names.
+ """
+ with ops.Graph().as_default():
+ logging.info('Reading checkpoint_from_path %s' % checkpoint_from_path)
+ reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_from_path)
+ name_shape_map = reader.get_variable_to_shape_map()
+ not_sharded, sharded = _split_sharded_vars(name_shape_map)
+ new_variable_map = {}
+ conversion_map = {}
+ for var_name in not_sharded:
+ new_var_name = _rnn_name_replacement(var_name)
+ tensor = reader.get_tensor(var_name)
+ var = variables.Variable(tensor, name=var_name)
+ new_variable_map[new_var_name] = var
+ if new_var_name != var_name:
+ conversion_map[var_name] = new_var_name
+ for var_name in sharded:
+ new_var_name = _rnn_name_replacement_sharded(var_name)
+ var = variables.Variable(tensor, name=var_name)
+ new_variable_map[new_var_name] = var
+ if new_var_name != var_name:
+ conversion_map[var_name] = new_var_name
+
+ write_version = (saver_pb2.SaverDef.V1
+ if write_v1_checkpoint else saver_pb2.SaverDef.V2)
+ saver = saver_lib.Saver(new_variable_map, write_version=write_version)
+
+ with session.Session() as sess:
+ sess.run(variables.global_variables_initializer())
+ logging.info('Writing checkpoint_to_path %s' % checkpoint_to_path)
+ saver.save(sess, checkpoint_to_path)
+
+ logging.info('Summary:')
+ logging.info(' Converted %d variable name(s).' % len(new_variable_map))
+ return new_variable_map, conversion_map
+
+
+def main(_):
+ convert_names(
+ FLAGS.checkpoint_from_path,
+ FLAGS.checkpoint_to_path,
+ write_v1_checkpoint=FLAGS.write_v1_checkpoint)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument('checkpoint_from_path', type=str,
+ help='Path to source checkpoint to be read in.')
+ parser.add_argument('checkpoint_to_path', type=str,
+ help='Path to checkpoint to be written out.')
+ parser.add_argument('--write_v1_checkpoint', action='store_true',
+ help='Write v1 checkpoint')
+ FLAGS, unparsed = parser.parse_known_args()
+
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py b/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py
new file mode 100644
index 0000000000..e2fc2fa80e
--- /dev/null
+++ b/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py
@@ -0,0 +1,108 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit tests for checkpoint converter."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob
+import os
+import tempfile
+
+from tensorflow.contrib.rnn.python.tools import checkpoint_convert
+from tensorflow.python.client import session
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver as saver_lib
+
+
+class CheckpointConvertTest(test.TestCase):
+
+ def setUp(self):
+ self._old_ckpt_path = tempfile.mktemp()
+ self._new_ckpt_path = tempfile.mktemp()
+ ops.reset_default_graph()
+
+ def tearDown(self):
+ for file_name in glob.glob(self._old_ckpt_path + "*"):
+ os.remove(file_name)
+ for file_name in glob.glob(self._new_ckpt_path + "*"):
+ os.remove(file_name)
+
+ def testReplacementDictsContainUniqueAndNonEmptyVariableNames(self):
+ for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS:
+ new_name = checkpoint_convert._RNN_NAME_REPLACEMENTS[old_name]
+ self.assertTrue(old_name)
+ self.assertTrue(new_name)
+ self.assertNotEqual(old_name, new_name)
+ for old_name in checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS:
+ new_name = checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS[old_name]
+ self.assertTrue(old_name)
+ self.assertTrue(new_name)
+ self.assertNotEqual(old_name, new_name)
+
+ def testConversionFromV2WithConvertedVariableNamesSucceeds(self):
+ variables.Variable(10.0, name="a")
+ for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS:
+ variables.Variable(20.0, name=old_name)
+ with session.Session() as sess:
+ saver = saver_lib.Saver()
+ sess.run(variables.global_variables_initializer())
+ saver.save(sess, self._old_ckpt_path)
+
+ new_var_map, conversion_map = checkpoint_convert.convert_names(
+ self._old_ckpt_path, self._new_ckpt_path)
+ self.assertTrue(glob.glob(self._new_ckpt_path + "*"))
+ self.assertItemsEqual(
+ ["a"] + list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values()),
+ new_var_map.keys())
+ self.assertEqual(checkpoint_convert._RNN_NAME_REPLACEMENTS, conversion_map)
+
+ def testConversionFromV2WithoutConvertedVariableNamesSucceeds(self):
+ variables.Variable(10.0, name="a")
+ with session.Session() as sess:
+ saver = saver_lib.Saver()
+ sess.run(variables.global_variables_initializer())
+ saver.save(sess, self._old_ckpt_path)
+
+ new_var_map, conversion_map = checkpoint_convert.convert_names(
+ self._old_ckpt_path, self._new_ckpt_path)
+ self.assertItemsEqual(["a"], new_var_map.keys())
+ self.assertFalse(conversion_map)
+
+ def testConversionToV1Succeeds(self):
+ variables.Variable(10.0, name="a")
+ variables.Variable(
+ 20.0, name=list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1])
+
+ with session.Session() as sess:
+ saver = saver_lib.Saver()
+ sess.run(variables.global_variables_initializer())
+ saver.save(sess, self._old_ckpt_path)
+
+ new_var_map, conversion_map = checkpoint_convert.convert_names(
+ self._old_ckpt_path, self._new_ckpt_path, write_v1_checkpoint=True)
+ self.assertItemsEqual(
+ ["a", list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]],
+ new_var_map.keys())
+ self.assertEqual(
+ {list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1]:
+ list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]},
+ conversion_map)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py
index 5ced8a4f08..b70d612f55 100644
--- a/tensorflow/contrib/slim/python/slim/learning.py
+++ b/tensorflow/contrib/slim/python/slim/learning.py
@@ -261,7 +261,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
@@ -657,7 +657,7 @@ def train(train_op,
if local_init_op == _USE_DEFAULT:
local_init_op = control_flow_ops.group(
tf_variables.local_variables_initializer(),
- data_flow_ops.tables_initializer())
+ lookup_ops.tables_initializer())
if sync_optimizer is not None and isinstance(
sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer):
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 4cfdf844ce..14deffc71b 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -154,6 +154,7 @@ CORE_PROTO_SRCS = [
"framework/versions.proto",
"lib/core/error_codes.proto",
"protobuf/config.proto",
+ "protobuf/cluster.proto",
"protobuf/debug.proto",
"protobuf/queue_runner.proto",
"protobuf/rewriter_config.proto",
@@ -506,6 +507,7 @@ tf_gen_op_libs(
"image_ops",
"io_ops",
"linalg_ops",
+ "lookup_ops",
"logging_ops",
"math_ops",
"nn_ops",
@@ -582,6 +584,7 @@ cc_library(
":image_ops_op_lib",
":io_ops_op_lib",
":linalg_ops_op_lib",
+ ":lookup_ops_op_lib",
":logging_ops_op_lib",
":math_ops_op_lib",
":nn_ops_op_lib",
@@ -708,6 +711,7 @@ cc_library(
"//tensorflow/core/kernels:image",
"//tensorflow/core/kernels:io",
"//tensorflow/core/kernels:linalg",
+ "//tensorflow/core/kernels:lookup",
"//tensorflow/core/kernels:logging",
"//tensorflow/core/kernels:math",
"//tensorflow/core/kernels:multinomial_op",
diff --git a/tensorflow/core/common_runtime/device.cc b/tensorflow/core/common_runtime/device.cc
index 78649afeb9..aa8a2d989b 100644
--- a/tensorflow/core/common_runtime/device.cc
+++ b/tensorflow/core/common_runtime/device.cc
@@ -23,8 +23,7 @@ limitations under the License.
namespace tensorflow {
-Device::Device(Env* env, const DeviceAttributes& device_attributes,
- Allocator* device_allocator)
+Device::Device(Env* env, const DeviceAttributes& device_attributes)
: DeviceBase(env), device_attributes_(device_attributes) {
CHECK(DeviceNameUtils::ParseFullName(name(), &parsed_name_))
<< "Invalid device name: " << name();
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
index 07c6bdd683..c0e58f143e 100644
--- a/tensorflow/core/common_runtime/device.h
+++ b/tensorflow/core/common_runtime/device.h
@@ -53,8 +53,7 @@ namespace tensorflow {
class Device : public DeviceBase {
public:
- Device(Env* env, const DeviceAttributes& device_attributes,
- Allocator* device_allocator);
+ Device(Env* env, const DeviceAttributes& device_attributes);
~Device() override;
// Full name of this device (see top comment).
diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc
index 7807656cb2..31f12d4833 100644
--- a/tensorflow/core/common_runtime/device_mgr.cc
+++ b/tensorflow/core/common_runtime/device_mgr.cc
@@ -29,10 +29,18 @@ DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
for (Device* d : devices) {
devices_.push_back(d);
- // Register under both the full name and the local name.
+ // Register under the (1) full name, (2) canonical name, and (3) local name.
string full_name = d->name();
device_map_[CopyToBackingStore(full_name)] = d;
+ DeviceNameUtils::ParsedName parsed_name = d->parsed_name();
+ if (parsed_name.has_job && parsed_name.has_replica &&
+ parsed_name.has_task && parsed_name.has_type && parsed_name.has_id) {
+ string canonical_name = DeviceNameUtils::FullName(
+ parsed_name.job, parsed_name.replica, parsed_name.task,
+ parsed_name.type, parsed_name.id);
+ device_map_[CopyToBackingStore(canonical_name)] = d;
+ }
string lname = DeviceNameUtils::LocalName(d->name());
device_map_[CopyToBackingStore(lname)] = d;
device_type_counts_[d->device_type()]++;
@@ -40,7 +48,8 @@ DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
}
DeviceMgr::~DeviceMgr() {
- for (auto p : devices_) delete p;
+ // TODO(b/37437134): Remove destructor after converting to std::unique_ptr.
+ for (Device* p : devices_) delete p;
}
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
@@ -85,6 +94,12 @@ Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
Status s;
auto iter = device_map_.find(name);
if (iter == device_map_.end()) {
+ std::vector<StringPiece> device_names;
+ for (auto&& itr : device_map_) {
+ device_names.push_back(itr.first);
+ }
+ LOG(WARNING) << "Unknown device: " << name
+ << " all devices: " << str_util::Join(device_names, ", ");
return errors::InvalidArgument(name, " unknown device.");
}
*device = iter->second;
diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h
index bb1ed72640..d16681ac59 100644
--- a/tensorflow/core/common_runtime/device_mgr.h
+++ b/tensorflow/core/common_runtime/device_mgr.h
@@ -36,6 +36,7 @@ class DeviceMgr {
public:
// Takes ownership of each device in 'devices'.
// TODO(zhifengc): Other initialization information.
+ // TODO(b/37437134): Use std::unique_ptr's to track ownership.
explicit DeviceMgr(const std::vector<Device*>& devices);
~DeviceMgr();
@@ -61,6 +62,7 @@ class DeviceMgr {
int NumDeviceType(const string& type) const;
private:
+ // TODO(b/37437134): Use std::unique_ptr's to track ownership.
typedef gtl::InlinedVector<Device*, 8> DeviceVec;
DeviceVec devices_;
diff --git a/tensorflow/core/common_runtime/device_set.h b/tensorflow/core/common_runtime/device_set.h
index b0540dfa95..4cd56e583c 100644
--- a/tensorflow/core/common_runtime/device_set.h
+++ b/tensorflow/core/common_runtime/device_set.h
@@ -39,7 +39,10 @@ class DeviceSet {
// Set the device designated as the "client". This device
// must also be registered via AddDevice().
- void set_client_device(Device* device) { client_device_ = device; }
+ void set_client_device(Device* device) {
+ DCHECK(client_device_ == nullptr);
+ client_device_ = device;
+ }
// Returns a pointer to the device designated as the "client".
Device* client_device() const { return client_device_; }
diff --git a/tensorflow/core/common_runtime/device_set_test.cc b/tensorflow/core/common_runtime/device_set_test.cc
index ff20ee94a7..0507076c8c 100644
--- a/tensorflow/core/common_runtime/device_set_test.cc
+++ b/tensorflow/core/common_runtime/device_set_test.cc
@@ -27,8 +27,7 @@ namespace {
static Device* Dev(const char* type, const char* name) {
class FakeDevice : public Device {
public:
- explicit FakeDevice(const DeviceAttributes& attr)
- : Device(nullptr, attr, nullptr) {}
+ explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
Status Sync() override { return Status::OK(); }
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
};
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 0e2343cfe3..02f70d835d 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -179,10 +179,9 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
int gpu_id, const string& physical_device_desc,
Allocator* gpu_allocator, Allocator* cpu_allocator,
bool sync_every_op, int32 max_streams)
- : LocalDevice(options,
- Device::BuildDeviceAttributes(name, DEVICE_GPU, memory_limit,
- locality, physical_device_desc),
- gpu_allocator),
+ : LocalDevice(options, Device::BuildDeviceAttributes(name, DEVICE_GPU,
+ memory_limit, locality,
+ physical_device_desc)),
gpu_allocator_(gpu_allocator),
cpu_allocator_(cpu_allocator),
gpu_id_(gpu_id),
diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc
index 0a6342ed73..3f7c9f68db 100644
--- a/tensorflow/core/common_runtime/local_device.cc
+++ b/tensorflow/core/common_runtime/local_device.cc
@@ -60,10 +60,8 @@ struct LocalDevice::EigenThreadPoolInfo {
};
LocalDevice::LocalDevice(const SessionOptions& options,
- const DeviceAttributes& attributes,
- Allocator* device_allocator)
- : Device(options.env, attributes, device_allocator),
- owned_tp_info_(nullptr) {
+ const DeviceAttributes& attributes)
+ : Device(options.env, attributes), owned_tp_info_(nullptr) {
// If we're running on the CPU, log warnings if we're not compiled using the
// best flags for performance.
port::WarnAboutUnusedCPUFeatures();
diff --git a/tensorflow/core/common_runtime/local_device.h b/tensorflow/core/common_runtime/local_device.h
index d1c27c6248..84a4f66db4 100644
--- a/tensorflow/core/common_runtime/local_device.h
+++ b/tensorflow/core/common_runtime/local_device.h
@@ -33,8 +33,8 @@ struct SessionOptions;
// GPUDevice into more 'process-wide' abstractions.
class LocalDevice : public Device {
public:
- LocalDevice(const SessionOptions& options, const DeviceAttributes& attributes,
- Allocator* device_allocator);
+ LocalDevice(const SessionOptions& options,
+ const DeviceAttributes& attributes);
~LocalDevice() override;
private:
diff --git a/tensorflow/core/common_runtime/renamed_device.cc b/tensorflow/core/common_runtime/renamed_device.cc
new file mode 100644
index 0000000000..fa9713735e
--- /dev/null
+++ b/tensorflow/core/common_runtime/renamed_device.cc
@@ -0,0 +1,54 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/renamed_device.h"
+
+namespace tensorflow {
+
+// TODO(saeta): Convert to returning a std::unique_ptr?
+/* static */
+Device* RenamedDevice::NewRenamedDevice(const string& new_base,
+ Device* underlying,
+ bool owns_underlying) {
+ DeviceNameUtils::ParsedName parsed_name;
+ CHECK(DeviceNameUtils::ParseFullName(new_base, &parsed_name));
+ DeviceNameUtils::ParsedName underlying_parsed_name =
+ underlying->parsed_name();
+ CHECK(underlying_parsed_name.has_type);
+ CHECK(underlying_parsed_name.has_id);
+ parsed_name.type = underlying_parsed_name.type;
+ parsed_name.id = underlying_parsed_name.id;
+ string name = DeviceNameUtils::FullName(parsed_name.job, parsed_name.replica,
+ parsed_name.task, parsed_name.type,
+ parsed_name.id);
+ DeviceAttributes attributes(underlying->attributes());
+ attributes.set_name(name);
+ return new RenamedDevice(underlying, attributes, owns_underlying);
+}
+
+RenamedDevice::RenamedDevice(Device* underlying,
+ const DeviceAttributes& attributes,
+ bool owns_underlying)
+ : Device(underlying->env(), attributes),
+ underlying_(underlying),
+ owns_underlying_(owns_underlying) {}
+
+RenamedDevice::~RenamedDevice() {
+ if (owns_underlying_) {
+ delete underlying_;
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h
new file mode 100644
index 0000000000..0158e18ced
--- /dev/null
+++ b/tensorflow/core/common_runtime/renamed_device.h
@@ -0,0 +1,119 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+// Wraps a device with a new name, delegating work to the wrapped device.
+//
+// This class is used to wrap local devices when using clusterspec propagation
+// where the name of a particular device may change in the context of a given
+// session.
+class RenamedDevice : public Device {
+ public:
+ static Device* NewRenamedDevice(const string& new_base, Device* underlying,
+ bool owns_underlying);
+ ~RenamedDevice() override;
+
+ // Below are virtual methods defined on DeviceBase
+ bool RequiresRecordingAccessedTensors() const override {
+ return underlying_->RequiresRecordingAccessedTensors();
+ }
+
+ const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override {
+ return underlying_->tensorflow_cpu_worker_threads();
+ }
+
+ const GpuDeviceInfo* tensorflow_gpu_device_info() const override {
+ return underlying_->tensorflow_gpu_device_info();
+ }
+
+ Allocator* GetAllocator(AllocatorAttributes attr) override {
+ return underlying_->GetAllocator(attr);
+ }
+
+ Allocator* GetStepAllocator(AllocatorAttributes attr,
+ ResourceMgr* step_resource_manager) override {
+ return underlying_->GetStepAllocator(attr, step_resource_manager);
+ }
+
+ const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
+ return underlying_->eigen_cpu_device();
+ }
+
+#ifdef TENSORFLOW_USE_SYCL
+ const Eigen::SyclDevice* eigen_sycl_device() const override {
+ return underlying_->eigen_sycl_device();
+ }
+#endif
+
+ PerOpGpuDevice* MakeGpuDevice() override {
+ return underlying_->MakeGpuDevice();
+ }
+
+ void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
+ DeviceContext* dc, Allocator* allocator) override {
+ underlying_->ReinitializeGpuDevice(context, device, dc, allocator);
+ }
+
+ Status MakeTensorFromProto(const TensorProto& tensor_proto,
+ const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) override {
+ return underlying_->MakeTensorFromProto(tensor_proto, alloc_attrs, tensor);
+ }
+
+ // Below are virtual methods defined on Device
+
+ void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
+ underlying_->Compute(op_kernel, context);
+ }
+
+ void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
+ AsyncOpKernel::DoneCallback done) override {
+ underlying_->ComputeAsync(op_kernel, context, std::move(done));
+ }
+
+ void ConsumeListOfAccessedTensors(
+ DeviceContext* context, const TensorReferenceVector& tensors) override {
+ underlying_->ConsumeListOfAccessedTensors(context, tensors);
+ }
+
+ Status Sync() override { return underlying_->Sync(); }
+
+ Status MaybeRewriteGraph(const FunctionDefLibrary& library,
+ std::unique_ptr<Graph>* graph) override {
+ return underlying_->MaybeRewriteGraph(library, graph);
+ }
+
+ Status FillContextMap(const Graph* graph,
+ DeviceContextMap* device_context_map) override {
+ return underlying_->FillContextMap(graph, device_context_map);
+ }
+
+ private:
+ RenamedDevice(Device* underlying, const DeviceAttributes& attributes,
+ bool owns_underlying);
+ Device* const underlying_;
+ const bool owns_underlying_;
+};
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc
index bd84417b10..24f27af5f1 100644
--- a/tensorflow/core/common_runtime/simple_placer_test.cc
+++ b/tensorflow/core/common_runtime/simple_placer_test.cc
@@ -66,7 +66,7 @@ class DummyOp : public OpKernel {
class FakeDevice : public Device {
private:
explicit FakeDevice(const DeviceAttributes& device_attributes)
- : Device(nullptr, device_attributes, nullptr) {}
+ : Device(nullptr, device_attributes) {}
public:
Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc
index 60348e885f..f5f8aab694 100644
--- a/tensorflow/core/common_runtime/threadpool_device.cc
+++ b/tensorflow/core/common_runtime/threadpool_device.cc
@@ -38,10 +38,8 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
const string& name, Bytes memory_limit,
const DeviceLocality& locality,
Allocator* allocator)
- : LocalDevice(options,
- Device::BuildDeviceAttributes(name, DEVICE_CPU, memory_limit,
- locality),
- allocator),
+ : LocalDevice(options, Device::BuildDeviceAttributes(
+ name, DEVICE_CPU, memory_limit, locality)),
allocator_(allocator) {}
ThreadPoolDevice::~ThreadPoolDevice() {}
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 0f5eb0cb32..d2a828f39f 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -77,7 +77,6 @@ cc_library(
],
deps = [
":graph_mgr",
- ":rendezvous_mgr_interface",
":worker_cache",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:protos_all_cc",
@@ -92,9 +91,9 @@ cc_library(
deps = [
":graph_mgr",
":worker_session",
+ "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
],
)
@@ -237,6 +236,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc",
],
)
diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
index 5863727f19..e68aea46ec 100644
--- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
+++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
@@ -35,9 +35,8 @@ limitations under the License.
namespace tensorflow {
-BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env,
- const string& worker_name)
- : worker_env_(worker_env), worker_name_(worker_name) {}
+BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
+ : worker_env_(worker_env) {}
BaseRendezvousMgr::~BaseRendezvousMgr() {
for (auto& p : table_) {
@@ -47,7 +46,7 @@ BaseRendezvousMgr::~BaseRendezvousMgr() {
}
}
-Rendezvous* BaseRendezvousMgr::Find(int64 step_id) {
+RemoteRendezvous* BaseRendezvousMgr::Find(int64 step_id) {
return FindOrCreate(step_id);
}
@@ -55,7 +54,7 @@ BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
mutex_lock l(mu_);
Table::iterator iter = table_.find(step_id);
if (iter == table_.end()) {
- auto rr = Create(step_id, worker_env_, worker_name_);
+ auto rr = Create(step_id, worker_env_);
iter = table_.insert({step_id, rr}).first;
}
iter->second->Ref();
@@ -128,14 +127,12 @@ void BaseRendezvousMgr::CleanupAll() {
}
}
-BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env,
- const string& worker_name,
- int64 step_id,
+BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
bool tolerate_dup_recv)
: env_(env),
- worker_name_(worker_name),
step_id_(step_id),
- local_(NewLocalRendezvous(tolerate_dup_recv)) {}
+ local_(NewLocalRendezvous(tolerate_dup_recv)),
+ session_(nullptr) {}
BaseRemoteRendezvous::~BaseRemoteRendezvous() {
CHECK(active_.empty());
@@ -150,6 +147,41 @@ static bool IsLocalDevice(const string& worker_name,
return device_name.starts_with(worker_name);
}
+Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
+ CHECK_NE(session, nullptr) << "session must not be null!";
+ std::vector<DeferredCall> deferred_calls;
+ {
+ mutex_lock l(mu_);
+ if (session_ != nullptr) {
+ if (session_->worker_name == session->worker_name) {
+ LOG(INFO) << "Skipping rendezvous re-initialization.";
+ return Status::OK();
+ }
+ Status s = errors::Internal(
+ "Double init! Worker names would have changed from: ",
+ session_->worker_name, " -> ", session->worker_name);
+ LOG(WARNING) << s;
+ return s;
+ }
+ session_ = session;
+ std::swap(deferred_calls, deferred_calls_);
+ }
+ for (DeferredCall& call : deferred_calls) {
+ RecvLocalAsyncInternal(call.parsed, std::move(call.done));
+ }
+ return Status::OK();
+}
+
+WorkerSession* BaseRemoteRendezvous::session() {
+ mutex_lock l(mu_);
+ return session_;
+}
+
+bool BaseRemoteRendezvous::is_initialized() {
+ mutex_lock l(mu_);
+ return is_initialized_locked();
+}
+
Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
const Tensor& val, const bool is_dead) {
@@ -157,10 +189,12 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
{
mutex_lock l(mu_);
if (!status_.ok()) return status_;
- }
- if (!IsLocalDevice(worker_name_, parsed.src_device)) {
- return errors::InvalidArgument("Invalid rendezvous key (src): ",
- parsed.FullKey(), " @ ", worker_name_);
+ DCHECK(is_initialized_locked());
+ if (!IsLocalDevice(session_->worker_name, parsed.src_device)) {
+ return errors::InvalidArgument(
+ "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
+ session_->worker_name);
+ }
}
// Buffers "val" and "device_context" in local_.
return local_->Send(parsed, args, val, is_dead);
@@ -168,17 +202,24 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
bool is_src) {
+ // Cache session pointer to avoid repeatedly taking & releasing the lock
+ // (e.g. calling session())
+ WorkerSession* sess = nullptr;
{
mutex_lock l(mu_);
if (!status_.ok()) return status_;
+ if (!is_initialized_locked()) {
+ return errors::Internal("ValidateDevices called before initialization.");
+ }
+ sess = session_;
}
- if (is_src && !IsLocalDevice(worker_name_, parsed.src_device)) {
+ if (is_src && !IsLocalDevice(sess->worker_name, parsed.src_device)) {
return errors::InvalidArgument("Invalid rendezvous key (src): ",
- parsed.FullKey(), " @ ", worker_name_);
+ parsed.FullKey(), " @ ", sess->worker_name);
}
- if (!is_src && !IsLocalDevice(worker_name_, parsed.dst_device)) {
+ if (!is_src && !IsLocalDevice(sess->worker_name, parsed.dst_device)) {
return errors::InvalidArgument("Invalid rendezvous key (dst): ",
- parsed.FullKey(), " @ ", worker_name_);
+ parsed.FullKey(), " @ ", sess->worker_name);
}
return Status::OK();
}
@@ -244,6 +285,7 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
const Rendezvous::Args& recv_args,
DoneCallback done) {
VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
+ CHECK(is_initialized()) << "RecvAsync called when uninitialized.";
Status s = ValidateDevices(parsed, false /*!is_src*/);
if (!s.ok()) {
done(s, Args(), recv_args, Tensor(), false);
@@ -280,6 +322,26 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
DoneCallback done) {
+ {
+ mutex_lock l(mu_);
+ if (!is_initialized_locked()) {
+ // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
+ // remote worker) before the RunStep (or PartialRunStep) RPC from the
+ // master arrives. RecvLocalAsync thus buffers the arguments until after
+ // the RemoteRendezvous is Initialize()'d, when it completes the
+ // rendezvous logic. At some point after Initialize() is called, a Tensor
+ // is produced locally that will then be sent in response to the incoming
+ // RPC.
+ DeferredCall call(parsed, std::move(done));
+ deferred_calls_.push_back(call);
+ return;
+ }
+ }
+ RecvLocalAsyncInternal(parsed, std::move(done));
+}
+
+void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
+ DoneCallback done) {
Status s = ValidateDevices(parsed, true /* is_src */);
if (!s.ok()) {
done(s, Args(), Args(), Tensor(), false);
@@ -318,4 +380,8 @@ void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
active_.erase(call);
}
+BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
+ DoneCallback done)
+ : parsed(parsed), done(std::move(done)) {}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h
index 447a75913d..b252f45fe9 100644
--- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h
+++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h
@@ -59,15 +59,17 @@ class BaseRecvTensorCall;
// RendezvousMgr must have keys generated by Rendezvous::CreateKey().
class BaseRendezvousMgr : public RendezvousMgrInterface {
public:
- explicit BaseRendezvousMgr(const WorkerEnv* worker_env,
- const string& worker_name);
+ explicit BaseRendezvousMgr(const WorkerEnv* worker_env);
~BaseRendezvousMgr() override;
// Returns Rendezvous supporting send and recv among workers in the
// "step_id". The caller takes ownership of one reference on the
// returned Rendezvous instance.
- Rendezvous* Find(int64 step_id) override;
+ //
+ // Note: the caller must guarantee to eventually call Initialize on the
+ // returned RemoteRendezvous
+ RemoteRendezvous* Find(int64 step_id) override;
// Finds the local rendezvous instance for the "step_id". Runs
// "done" when the tensor for "key" is produced or an error occurs.
@@ -91,8 +93,7 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
protected:
virtual BaseRemoteRendezvous* Create(int64 step_id,
- const WorkerEnv* worker_env,
- const string& worker_name) = 0;
+ const WorkerEnv* worker_env) = 0;
private:
// Maps step_id to rendezvous.
@@ -100,7 +101,6 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
// Not owned.
const WorkerEnv* const worker_env_;
- const string worker_name_;
mutex mu_;
Table table_ GUARDED_BY(mu_);
@@ -116,10 +116,13 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
// Buffering of Tensor values is delegated to a "local" Rendezvous
// obtained from NewLocalRendezvous(). This class just adds
// functionality to coordinate with remote workers.
-class BaseRemoteRendezvous : public Rendezvous {
+class BaseRemoteRendezvous : public RemoteRendezvous {
public:
- BaseRemoteRendezvous(const WorkerEnv* env, const string& worker_name,
- int64 step_id, bool tolerate_dup_recv);
+ BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
+ bool tolerate_dup_recv);
+
+ // Upgrades the BaseRemoteRendezvous to full initialization.
+ Status Initialize(WorkerSession* session) override;
// Forwards to local_, where the Tensor "val" will be buffered and
// any waiting callback stored.
@@ -163,10 +166,13 @@ class BaseRemoteRendezvous : public Rendezvous {
// Removes "call" from active_ if "call" is in active_.
void DeregisterCall(BaseRecvTensorCall* call);
+ WorkerSession* session();
+
+ bool is_initialized();
+
~BaseRemoteRendezvous() override;
const WorkerEnv* const env_; // Not owned.
- const string worker_name_;
const int64 step_id_;
private:
@@ -176,10 +182,24 @@ class BaseRemoteRendezvous : public Rendezvous {
// Status given by StartAbort() if any.
Status status_ GUARDED_BY(mu_);
+ WorkerSession* session_ GUARDED_BY(mu_); // Not owned.
+
+ // Data structures to handle calls when partially initialized.
+ struct DeferredCall {
+ const ParsedKey parsed;
+ DoneCallback done;
+
+ DeferredCall(const ParsedKey& parsed, DoneCallback done);
+ };
+ std::vector<DeferredCall> deferred_calls_ GUARDED_BY(mu_);
// Active outstanding RecvTensor calls.
gtl::FlatSet<BaseRecvTensorCall*> active_ GUARDED_BY(mu_);
+ bool is_initialized_locked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ return session_ != nullptr;
+ }
+
// If "is_src" is true, checks that the rendezvous key "parsed"'s
// source is in this process. If "is_src" is false, checks that the
// rendezvous key "parsed"'s destination is in this process.
@@ -194,6 +214,9 @@ class BaseRemoteRendezvous : public Rendezvous {
const Rendezvous::Args& out_args, const Tensor& in,
Tensor* out, StatusCallback done);
+ // Must be called only if fully initialized.
+ void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done);
+
TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous);
};
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index ce7ce372e8..5bde771e8d 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -46,10 +46,8 @@ limitations under the License.
namespace tensorflow {
-GraphMgr::GraphMgr(const WorkerEnv* worker_env,
- RendezvousMgrInterface* rendezvous_mgr)
- : worker_env_(worker_env), rendezvous_mgr_(rendezvous_mgr), table_(5) {
- CHECK(rendezvous_mgr) << "Rendezvous mgr was null";
+GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr)
+ : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) {
// The default value of sync_on_finish will be flipped soon and this
// environment variable will be removed as well.
Status status =
@@ -148,7 +146,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
};
popts.get_incarnation = [this](const string& name) -> int64 {
Device* device = nullptr;
- Status s = worker_env_->device_mgr->LookupDevice(name, &device);
+ Status s = device_mgr_->LookupDevice(name, &device);
if (s.ok()) {
return device->attributes().incarnation();
} else {
@@ -193,8 +191,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
ExecutionUnit* unit = &(item->units.back());
// Find the device.
- Status s =
- worker_env_->device_mgr->LookupDevice(device_name, &unit->device);
+ Status s = device_mgr_->LookupDevice(device_name, &unit->device);
if (!s.ok()) {
// Remove the empty unit from the item as the item destructor wants all
// units to have valid devices.
@@ -214,7 +211,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
// Function library runtime.
unit->lib = NewFunctionLibraryRuntime(
- worker_env_->device_mgr, worker_env_->env, unit->device,
+ device_mgr_, worker_env_->env, unit->device,
subgraph->versions().producer(), item->lib_def,
graph_options.optimizer_options());
@@ -419,14 +416,14 @@ void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
}
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
- Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id);
+ Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = SendInputsToRendezvous(rendezvous, in);
rendezvous->Unref();
return s;
}
Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
- Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id);
+ Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = RecvOutputsFromRendezvous(rendezvous, out);
rendezvous->Unref();
return s;
@@ -434,7 +431,7 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
StatusCallback done) {
- Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id);
+ Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
RecvOutputsFromRendezvousAsync(rendezvous, out,
[done, rendezvous](const Status s) {
rendezvous->Unref();
@@ -443,7 +440,8 @@ void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
}
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
- const ExecutorOpts& opts,
+ WorkerSession* session,
+ const ExecutorOpts& /*opts*/,
StepStatsCollector* collector,
CostGraphDef* cost_graph,
CancellationManager* cancellation_manager,
@@ -464,10 +462,14 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
return;
}
- Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id);
+ RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
+ Status s = rendezvous->Initialize(session);
// Sends values specified by the caller.
- Status s = SendInputsToRendezvous(rendezvous, in);
+ if (s.ok()) {
+ s = SendInputsToRendezvous(rendezvous, in);
+ }
+
if (!s.ok()) {
done(s);
item->Unref();
@@ -492,10 +494,9 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
StatusCallback done) {
const int num_units = item->units.size();
CHECK_GE(num_units, 1);
- ScopedStepContainer* step_container =
- new ScopedStepContainer(step_id, [this](const string& name) {
- worker_env_->device_mgr->ClearContainers({name});
- });
+ ScopedStepContainer* step_container = new ScopedStepContainer(
+ step_id,
+ [this](const string& name) { device_mgr_->ClearContainers({name}); });
// NOTE: Transfer one ref of rendezvous and item.
ExecutorBarrier* barrier =
new ExecutorBarrier(num_units, rendezvous,
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h
index 349af6c54e..50391f47e4 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.h
+++ b/tensorflow/core/distributed_runtime/graph_mgr.h
@@ -37,6 +37,8 @@ namespace tensorflow {
class ExecutorOpts;
class StepStatsCollector;
class RendezvousMgrInterface;
+class DeviceMgr;
+struct WorkerSession;
// GraphMgr keeps track of a set of graphs that are registered with a
// TensorFlow worker. Each registered graph is identified by a handle
@@ -62,8 +64,7 @@ class RendezvousMgrInterface;
// EXPECT_EQ(out["c"], Tensor({4, 6}));
class GraphMgr {
public:
- explicit GraphMgr(const WorkerEnv* worker_env,
- RendezvousMgrInterface* rendezvous_mgr);
+ explicit GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr);
~GraphMgr();
// Registers a graph. Fills in "handle"
@@ -78,8 +79,8 @@ class GraphMgr {
typedef std::map<string, Tensor> NamedTensors;
typedef std::function<void(const Status&)> StatusCallback;
void ExecuteAsync(const string& handle, const int64 step_id,
- const ExecutorOpts& opts, StepStatsCollector* collector,
- CostGraphDef* cost_graph,
+ WorkerSession* session, const ExecutorOpts& opts,
+ StepStatsCollector* collector, CostGraphDef* cost_graph,
CancellationManager* cancellation_manager,
const NamedTensors& in, StatusCallback done);
@@ -131,7 +132,7 @@ class GraphMgr {
};
const WorkerEnv* worker_env_; // Not owned.
- RendezvousMgrInterface* rendezvous_mgr_; // Not owned.
+ DeviceMgr* device_mgr_;
CostModelManager cost_model_manager_;
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc
index b4adee3bf6..e860c99d95 100644
--- a/tensorflow/core/distributed_runtime/master.cc
+++ b/tensorflow/core/distributed_runtime/master.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
@@ -48,12 +49,17 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
+namespace {
+const char* const kGrpcProtocol = "grpc://";
+} // namespace
+
Master::Master(MasterEnv* env, double session_gc_seconds)
: env_(env),
last_1000_steps_(1000),
@@ -290,25 +296,122 @@ void Master::CreateSession(const CreateSessionRequest* req,
CreateSessionResponse* resp, MyClosure done) {
SchedClosure([this, req, resp, done]() {
Status status;
+ WorkerCacheFactoryOptions worker_cache_factory_options;
+ string grpc_protocol("grpc");
+ worker_cache_factory_options.protocol = &grpc_protocol;
auto call_done = gtl::MakeCleanup([&status, &done] { done(status); });
status = ValidateExternalGraphDefSyntax(req->graph_def());
if (!status.ok()) return;
- // Ping all the workers and build the list of devices that the
- // session will use.
+
+ // The following 4 variables are set differently, depending on whether this
+ // session uses a client-provided clusterspec or not.
+ WorkerCacheInterface* worker_cache = nullptr;
+ // Note: worker_cache_ptr will be null except if this session is using a
+ // client-supplied ClusterDef (ClusterSpec propagation).
+ std::unique_ptr<WorkerCacheInterface> worker_cache_ptr;
+ std::unique_ptr<DeviceSet> device_set;
// TODO(saeta): Convert to std::make_unique when available.
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices(
new std::vector<std::unique_ptr<Device>>());
- status = DeviceFinder::GetRemoteDevices(req->config().device_filters(),
- env_, env_->worker_cache,
- remote_devices.get());
- if (!status.ok()) return;
+
+ if (req->config().has_cluster_def()) {
+ worker_cache_factory_options.cluster_def = &req->config().cluster_def();
+
+ // Set the server_def's job_name and task_index fields.
+ string normalized_string;
+ string grpc_protocol(kGrpcProtocol);
+ if (req->target().compare(0, grpc_protocol.length(), grpc_protocol) ==
+ 0) {
+ normalized_string =
+ req->target().substr(grpc_protocol.length(), string::npos);
+ } else {
+ normalized_string = req->target();
+ }
+ for (auto&& job : req->config().cluster_def().job()) {
+ for (auto&& task : job.tasks()) {
+ if (task.second == normalized_string) {
+ if (worker_cache_factory_options.job_name != nullptr) {
+ status = errors::InvalidArgument(
+ "Found multiple matching tasks that correspond to "
+ "to the master. Master target: '",
+ req->target(), "'. ClusterDef: ",
+ req->config().cluster_def().ShortDebugString());
+ LOG(ERROR) << status;
+ return;
+ }
+ if (env_->local_devices[0]->parsed_name().job == job.name() &&
+ env_->local_devices[0]->parsed_name().task == task.first) {
+ // TODO(b/37868888): Remove this limitation when resolved
+ status = errors::InvalidArgument(
+ "The ClusterSpec names the job and task index to be the same "
+ "names that were provided when the server booted. This is "
+ "currently not allowed. Job: ",
+ job.name(), ", task index: ", task.first);
+ return;
+ }
+ worker_cache_factory_options.job_name = &job.name();
+ worker_cache_factory_options.task_index = task.first;
+ }
+ }
+ }
+
+ // Create the worker cache from the computed server_def.
+ status = env_->worker_cache_factory(worker_cache_factory_options,
+ &worker_cache);
+ if (!status.ok()) return;
+ worker_cache_ptr = std::unique_ptr<WorkerCacheInterface>(worker_cache);
+ // Ping all the workers and build the list of devices that the
+ // session will use.
+ status =
+ DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
+ worker_cache, remote_devices.get());
+ if (!status.ok()) return;
+ device_set.reset(new DeviceSet);
+ for (auto&& d : *remote_devices) {
+ device_set->AddDevice(d.get());
+ DeviceNameUtils::ParsedName name = d->parsed_name();
+ if (name.job == *worker_cache_factory_options.job_name &&
+ name.task == worker_cache_factory_options.task_index &&
+ name.type == "CPU") {
+ device_set->set_client_device(d.get());
+ }
+ }
+ } else {
+ worker_cache = env_->worker_cache;
+ // Ping all the workers and build the list of devices that the
+ // session will use.
+ status =
+ DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
+ worker_cache, remote_devices.get());
+ if (!status.ok()) return;
+ device_set.reset(new DeviceSet);
+ for (auto&& d : *remote_devices) {
+ device_set->AddDevice(d.get());
+ }
+ int num_local_devices = 0;
+ for (Device* d : env_->local_devices) {
+ device_set->AddDevice(d);
+ if (num_local_devices == 0) {
+ // Uses the first local device as the client device.
+ device_set->set_client_device(d);
+ }
+ num_local_devices++;
+ }
+ }
+
+ CHECK(device_set->client_device());
+
SessionOptions options;
options.config = req->config();
- MasterSession* session =
- env_->master_session_factory(options, env_, std::move(remote_devices));
+
+ MasterSession* session = env_->master_session_factory(
+ options, env_, std::move(remote_devices), std::move(worker_cache_ptr),
+ std::move(device_set));
+
GraphDef* gdef =
const_cast<CreateSessionRequest*>(req)->mutable_graph_def();
- status = session->Create(gdef);
+
+ status = session->Create(gdef, worker_cache_factory_options);
if (!status.ok()) {
session->Close().IgnoreError();
session->Unref();
diff --git a/tensorflow/core/distributed_runtime/master_env.h b/tensorflow/core/distributed_runtime/master_env.h
index a155bd384d..bb548adda1 100644
--- a/tensorflow/core/distributed_runtime/master_env.h
+++ b/tensorflow/core/distributed_runtime/master_env.h
@@ -19,17 +19,41 @@ limitations under the License.
#include <functional>
#include <vector>
-#include "tensorflow/core/distributed_runtime/master_session.h"
+#include "tensorflow/core/protobuf/cluster.pb.h"
+#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
class Device;
+class DeviceSet;
class Env;
class MasterSession;
class OpRegistryInterface;
class WorkerCacheInterface;
+// Options passed to the worker_cache_factory function.
+struct WorkerCacheFactoryOptions {
+ const ClusterDef* cluster_def = nullptr;
+ const string* job_name = nullptr;
+ int task_index;
+ const string* protocol = nullptr;
+
+ WorkerCacheFactoryOptions() {}
+
+ // Construct from a ServerDef proto.
+ //
+ // Note: server_def must outlive WorkerCacheFactoryOptions!
+ WorkerCacheFactoryOptions(const ServerDef& server_def) {
+ if (server_def.has_cluster() && !server_def.job_name().empty()) {
+ cluster_def = &server_def.cluster();
+ job_name = &server_def.job_name();
+ task_index = server_def.task_index();
+ protocol = &server_def.protocol();
+ }
+ }
+};
+
// The master environment class, which holds a bag of pointers to
// per-master state.
//
@@ -57,8 +81,14 @@ struct MasterEnv {
// `MasterEnv*` is retained by the caller.
std::function<MasterSession*(
SessionOptions, MasterEnv*,
- std::unique_ptr<std::vector<std::unique_ptr<Device>>>)>
+ std::unique_ptr<std::vector<std::unique_ptr<Device>>>,
+ std::unique_ptr<WorkerCacheInterface>,
+ std::unique_ptr<DeviceSet> device_set)>
master_session_factory;
+
+ std::function<Status(const WorkerCacheFactoryOptions&,
+ WorkerCacheInterface**)>
+ worker_cache_factory;
};
} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index f7b422b70e..50c5d90fc9 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -36,11 +36,13 @@ limitations under the License.
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -162,7 +164,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
// Partitions the graph into subgraphs and registers them on
// workers.
Status RegisterPartitions(const PartitionOptions& popts,
- const FunctionDefLibrary& func_def_lib);
+ const FunctionLibraryDefinition& flib_def);
// Runs one step of all partitions.
Status RunPartitions(const MasterEnv* env, int64 step_id,
@@ -273,7 +275,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
};
Status MasterSession::ReffedClientGraph::RegisterPartitions(
- const PartitionOptions& popts, const FunctionDefLibrary& func_def_lib) {
+ const PartitionOptions& popts, const FunctionLibraryDefinition& flib_def) {
{ // Ensure register once.
mu_.lock();
if (!init_started_) {
@@ -292,7 +294,8 @@ Status MasterSession::ReffedClientGraph::RegisterPartitions(
graph_defs_for_publishing.push_back(&name_def.second);
}
stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
- s = DoRegisterPartitions(popts, func_def_lib, std::move(graph_defs));
+ s = DoRegisterPartitions(popts, flib_def.ToProto(),
+ std::move(graph_defs));
}
mu_.lock();
init_result_ = s;
@@ -527,6 +530,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
c->req->set_is_partial(is_partial_);
c->req->set_is_last_partial_run(is_last_partial_run);
}
+ c->req->set_session_handle(session_handle_);
c->req->set_graph_handle(part.graph_handle);
c->req->set_step_id(step_id);
*c->req->mutable_exec_opts() = exec_opts;
@@ -870,6 +874,7 @@ void MasterSession::ReffedClientGraph::DeregisterPartitions() {
// The graph handle may be empty if we failed during partition registration.
if (!part.graph_handle.empty()) {
Call* c = new Call;
+ c->req.set_session_handle(session_handle_);
c->req.set_graph_handle(part.graph_handle);
// NOTE(mrry): We must capture `worker_cache_` since `this`
// could be deleted before the callback is called.
@@ -972,31 +977,25 @@ string BuildGraphOptionsString(const BuildGraphOptions& opts) {
MasterSession::MasterSession(
const SessionOptions& opt, const MasterEnv* env,
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
+ std::unique_ptr<WorkerCacheInterface> worker_cache,
+ std::unique_ptr<DeviceSet> device_set,
StatsPublisherFactory stats_publisher_factory)
: session_opts_(opt),
env_(env),
handle_(strings::FpToString(random::New64())),
remote_devs_(std::move(remote_devs)),
+ worker_cache_(std::move(worker_cache)),
+ devices_(std::move(device_set)),
stats_publisher_factory_(std::move(stats_publisher_factory)),
graph_version_(0),
run_graphs_(5),
partial_run_graphs_(5) {
UpdateLastAccessTime();
+ CHECK(devices_) << "device_set was null!";
VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
<< " #remote " << remote_devs_->size();
- for (auto&& d : *remote_devs_) {
- devices_.AddDevice(d.get());
- }
- int num_local_devices = 0;
- for (Device* d : env->local_devices) {
- devices_.AddDevice(d);
- if (num_local_devices == 0) {
- // Uses the first local device as the client device.
- devices_.set_client_device(d);
- }
- num_local_devices++;
- }
+
LOG(INFO) << "Start master session " << handle_
<< " with config: " << std::endl
<< session_opts_.config.DebugString();
@@ -1011,7 +1010,8 @@ void MasterSession::UpdateLastAccessTime() {
last_access_time_usec_.store(Env::Default()->NowMicros());
}
-Status MasterSession::Create(GraphDef* graph_def) {
+Status MasterSession::Create(GraphDef* graph_def,
+ const WorkerCacheFactoryOptions& options) {
if (session_opts_.config.graph_options().place_pruned_graph()) {
// TODO(b/29900832): Fix this or remove the option.
LOG(WARNING) << "Distributed session does not support the "
@@ -1019,17 +1019,93 @@ Status MasterSession::Create(GraphDef* graph_def) {
session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
}
- SimpleGraphExecutionStateOptions options;
- options.device_set = &devices_;
- options.session_options = &session_opts_;
+ SimpleGraphExecutionStateOptions execution_options;
+ execution_options.device_set = devices_.get();
+ execution_options.session_options = &session_opts_;
{
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
- graph_def, options, &execution_state_));
+ graph_def, execution_options, &execution_state_));
+ }
+ if (options.cluster_def != nullptr) {
+ return CreateWorkerSessions(options);
}
return Status::OK();
}
+Status MasterSession::CreateWorkerSessions(
+ const WorkerCacheFactoryOptions& options) {
+ CHECK(worker_cache_) << "CreateWorkerSessions should be called only with "
+ << "dynamic cluster membership.";
+ std::vector<string> worker_names;
+ worker_cache_->ListWorkers(&worker_names);
+
+ struct WorkerGroup {
+ // The worker name. (Not owned.)
+ const string* name;
+
+ // The worker referenced by name. (Not owned.)
+ WorkerInterface* worker = nullptr;
+
+ // Request and responses used for a given worker.
+ CreateWorkerSessionRequest request;
+ CreateWorkerSessionResponse response;
+ Status status = Status::OK();
+ };
+ BlockingCounter done(worker_names.size());
+ std::vector<WorkerGroup> workers(worker_names.size());
+
+ // Release the workers.
+ auto cleanup = gtl::MakeCleanup([this, &workers] {
+ for (auto&& worker_group : workers) {
+ if (worker_group.worker != nullptr) {
+ worker_cache_->ReleaseWorker(*worker_group.name, worker_group.worker);
+ }
+ }
+ });
+
+ Status status = Status::OK();
+ // Create all the workers & kick off the computations.
+ for (size_t i = 0; i < worker_names.size(); ++i) {
+ workers[i].name = &worker_names[i];
+ workers[i].worker = worker_cache_->CreateWorker(worker_names[i]);
+ workers[i].request.set_session_handle(handle_);
+ *workers[i].request.mutable_server_def()->mutable_cluster() =
+ *options.cluster_def;
+ workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
+
+ DeviceNameUtils::ParsedName name;
+ if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
+ status = errors::Internal("Could not parse name ", worker_names[i]);
+ LOG(WARNING) << status;
+ return status;
+ }
+ if (!name.has_job || !name.has_task) {
+ status = errors::Internal("Incomplete worker name ", worker_names[i]);
+ LOG(WARNING) << status;
+ return status;
+ }
+
+ workers[i].request.mutable_server_def()->set_job_name(name.job);
+ workers[i].request.mutable_server_def()->set_task_index(name.task);
+ }
+
+ for (size_t i = 0; i < worker_names.size(); ++i) {
+ auto cb = [i, &workers, &done](const Status& s) {
+ workers[i].status = s;
+ done.DecrementCount();
+ };
+ workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
+ &workers[i].response, cb);
+ }
+
+ done.Wait();
+ for (size_t i = 0; i < workers.size(); ++i) {
+ status.Update(workers[i].status);
+ }
+ return status;
+}
+
Status MasterSession::Extend(const ExtendSessionRequest* req,
ExtendSessionResponse* resp) {
UpdateLastAccessTime();
@@ -1059,6 +1135,13 @@ Status MasterSession::Extend(const ExtendSessionRequest* req,
return Status::OK();
}
+WorkerCacheInterface* MasterSession::get_worker_cache() const {
+ if (worker_cache_) {
+ return worker_cache_.get();
+ }
+ return env_->worker_cache;
+}
+
Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
ReffedClientGraph** rcg, bool is_partial) {
const uint64 hash = HashBuildGraphOptions(opts);
@@ -1082,11 +1165,11 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
<< "\n";
std::unique_ptr<SimpleClientGraph> client_graph;
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
+ WorkerCacheInterface* worker_cache = get_worker_cache();
auto entry = new ReffedClientGraph(
handle_, opts, std::move(client_graph), session_opts_,
stats_publisher_factory_, execution_state_.get(), is_partial,
- env_->worker_cache);
-
+ worker_cache);
iter = m->insert({hash, entry}).first;
VLOG(1) << "Preparing to execute new graph";
}
@@ -1161,6 +1244,8 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
return errors::FailedPrecondition("Session is closed.");
}
++num_running_;
+ // Note: all code paths must eventually call MarkRunCompletion()
+ // in order to appropriate decrement the num_running_ counter.
}
Status status;
if (!req.partial_run_handle().empty()) {
@@ -1168,16 +1253,18 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
} else {
status = DoRunWithLocalExecution(opts, req, resp);
}
- {
- mutex_lock l(mu_);
- --num_running_;
- if (num_running_ == 0) {
- num_running_is_zero_.notify_all();
- }
- }
return status;
}
+// Decrements num_running_ and broadcasts if num_running_ is zero.
+void MasterSession::MarkRunCompletion() {
+ mutex_lock l(mu_);
+ --num_running_;
+ if (num_running_ == 0) {
+ num_running_is_zero_.notify_all();
+ }
+}
+
Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
// Registers subgraphs if haven't done so.
PartitionOptions popts;
@@ -1187,7 +1274,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
return strings::StrCat(prefix, "_S", next_node_id_++);
};
popts.get_incarnation = [this](const string& name) -> int64 {
- Device* d = devices_.FindDeviceByName(name);
+ Device* d = devices_->FindDeviceByName(name);
if (d == nullptr) {
return PartitionOptions::kIllegalIncarnation;
} else {
@@ -1214,7 +1301,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
}
TF_RETURN_IF_ERROR(
- rcg->RegisterPartitions(popts, rcg->client_graph()->flib_def->ToProto()));
+ rcg->RegisterPartitions(popts, *rcg->client_graph()->flib_def));
return Status::OK();
}
@@ -1222,6 +1309,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
Status MasterSession::DoPartialRun(CallOptions* opts,
const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp) {
+ auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
const string& prun_handle = req.partial_run_handle();
RunState* run_state = nullptr;
{
@@ -1320,12 +1408,14 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
rcg->Ref();
rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
req.options(), resp->mutable_metadata());
+ cleanup.release(); // MarkRunCompletion called in done closure.
rcg->CleanupPartitionsAsync(
run_state->step_id, [this, rcg, prun_handle](const Status& s) {
if (!s.ok()) {
LOG(ERROR) << "Cleanup partition error: " << s;
}
rcg->Unref();
+ MarkRunCompletion();
});
mutex_lock l(mu_);
partial_runs_.erase(prun_handle);
@@ -1367,10 +1457,10 @@ Status MasterSession::CreateDebuggerState(
Status MasterSession::DoRunWithLocalExecution(
CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp) {
- VLOG(2) << "DoRunWithLocalExecution "
- << "req: " << req.DebugString();
+ VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString();
PerStepState pss;
pss.start_micros = Env::Default()->NowMicros();
+ auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
// Prepare.
BuildGraphOptions bgopts;
@@ -1437,11 +1527,13 @@ Status MasterSession::DoRunWithLocalExecution(
}
}
rcg->Ref();
- rcg->CleanupPartitionsAsync(step_id, [rcg](const Status& s) {
+ cleanup.release(); // MarkRunCompletion called in done closure.
+ rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
if (!s.ok()) {
LOG(ERROR) << "Cleanup partition error: " << s;
}
rcg->Unref();
+ MarkRunCompletion();
});
return s;
}
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index d47125be99..3acc5bc5f0 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/master.pb.h"
@@ -49,13 +50,15 @@ class MasterSession : public core::RefCounted {
MasterSession(
const SessionOptions& options, const MasterEnv* env,
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
+ std::unique_ptr<WorkerCacheInterface> worker_cache,
+ std::unique_ptr<DeviceSet> device_set,
StatsPublisherFactory stats_publisher_factory);
// Initialize the MasterSession for "def". Must be called before Extend(),
// Run(), or Close().
//
// After this method returns, `def` will no longer be valid.
- Status Create(GraphDef* def);
+ Status Create(GraphDef* def, const WorkerCacheFactoryOptions& options);
// Returns the session handle.
const string& handle() const { return handle_; }
@@ -107,8 +110,14 @@ class MasterSession : public core::RefCounted {
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_;
+ // The optional session-specific worker cluster.
+ // TODO(saeta): Convert to std::optional when available.
+ std::unique_ptr<WorkerCacheInterface> worker_cache_;
+ // Retrieves either worker_cache_ or the env_->worker_cache as appropriate.
+ WorkerCacheInterface* get_worker_cache() const;
+
// The device set used by this session.
- DeviceSet devices_;
+ std::unique_ptr<DeviceSet> devices_;
StatsPublisherFactory stats_publisher_factory_;
@@ -181,6 +190,13 @@ class MasterSession : public core::RefCounted {
// Private dtor. The client must call Close().
virtual ~MasterSession();
+ // Creates sessions on all workers.
+ //
+ // If this session is operating using the new ClusterSpec propagation behavior
+ // call this method in order to propagate the cluster membership to all
+ // workers.
+ Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def);
+
Status StartStep(const BuildGraphOptions& opts, int64* count,
ReffedClientGraph** graph, bool is_partial);
void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
@@ -190,6 +206,7 @@ class MasterSession : public core::RefCounted {
MutableRunStepResponseWrapper* resp);
Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp);
+ void MarkRunCompletion();
void UpdateLastAccessTime();
Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc
index 7b58feb93c..b077975ea5 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers.cc
+++ b/tensorflow/core/distributed_runtime/message_wrappers.cc
@@ -252,6 +252,14 @@ string ProtoRunStepRequest::DebugString() const {
const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; }
+const string& InMemoryRunGraphRequest::session_handle() const {
+ return session_handle_;
+}
+
+void InMemoryRunGraphRequest::set_session_handle(const string& handle) {
+ session_handle_ = handle;
+}
+
const string& InMemoryRunGraphRequest::graph_handle() const {
return graph_handle_;
}
@@ -320,6 +328,7 @@ void InMemoryRunGraphRequest::set_is_last_partial_run(
const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
if (!proto_version_) {
proto_version_.reset(new RunGraphRequest);
+ proto_version_->set_session_handle(session_handle());
proto_version_->set_graph_handle(graph_handle());
proto_version_->set_step_id(step_id());
*proto_version_->mutable_exec_opts() = exec_opts();
@@ -337,6 +346,14 @@ const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
return *proto_version_;
}
+const string& MutableProtoRunGraphRequest::session_handle() const {
+ return request_.session_handle();
+}
+
+void MutableProtoRunGraphRequest::set_session_handle(const string& handle) {
+ request_.set_session_handle(handle);
+}
+
const string& MutableProtoRunGraphRequest::graph_handle() const {
return request_.graph_handle();
}
@@ -423,6 +440,10 @@ const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request)
: request_(request) {}
+const string& ProtoRunGraphRequest::session_handle() const {
+ return request_->session_handle();
+}
+
const string& ProtoRunGraphRequest::graph_handle() const {
return request_->graph_handle();
}
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h
index 02516eabb4..795a6add0e 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers.h
+++ b/tensorflow/core/distributed_runtime/message_wrappers.h
@@ -223,6 +223,10 @@ class RunGraphRequestWrapper {
public:
virtual ~RunGraphRequestWrapper() {}
+ // The session handle used to register the graph. If empty, a single global
+ // namespace is used.
+ virtual const string& session_handle() const = 0;
+
// REQUIRED: graph_handle must be returned by a RegisterGraph call
// to the same WorkerService.
virtual const string& graph_handle() const = 0;
@@ -262,6 +266,7 @@ class RunGraphRequestWrapper {
// See `RunGraphRequestWrapper` above for a description of the fields.
class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
public:
+ virtual void set_session_handle(const string& handle) = 0;
virtual void set_graph_handle(const string& handle) = 0;
virtual void set_step_id(int64 step_id) = 0;
virtual ExecutorOpts* mutable_exec_opts() = 0;
@@ -280,6 +285,7 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
public:
// RunGraphRequestWrapper methods.
+ const string& session_handle() const override;
const string& graph_handle() const override;
int64 step_id() const override;
const ExecutorOpts& exec_opts() const override;
@@ -293,6 +299,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
const RunGraphRequest& ToProto() const override;
// MutableRunGraphRequestWrapper methods.
+ void set_session_handle(const string& handle) override;
void set_graph_handle(const string& handle) override;
void set_step_id(int64 step_id) override;
ExecutorOpts* mutable_exec_opts() override;
@@ -304,6 +311,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
void set_is_last_partial_run(bool is_last_partial_run) override;
private:
+ string session_handle_;
string graph_handle_;
int64 step_id_;
ExecutorOpts exec_opts_;
@@ -325,6 +333,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
public:
// RunGraphRequestWrapper methods.
+ const string& session_handle() const override;
const string& graph_handle() const override;
int64 step_id() const override;
const ExecutorOpts& exec_opts() const override;
@@ -338,6 +347,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
const RunGraphRequest& ToProto() const override;
// MutableRunGraphRequestWrapper methods.
+ void set_session_handle(const string& handle) override;
void set_graph_handle(const string& handle) override;
void set_step_id(int64 step_id) override;
ExecutorOpts* mutable_exec_opts() override;
@@ -357,6 +367,7 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper {
ProtoRunGraphRequest(const RunGraphRequest* request);
// RunGraphRequestWrapper methods.
+ const string& session_handle() const override;
const string& graph_handle() const override;
int64 step_id() const override;
const ExecutorOpts& exec_opts() const override;
diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc
index 9632e9c439..91c1fb99fe 100644
--- a/tensorflow/core/distributed_runtime/remote_device.cc
+++ b/tensorflow/core/distributed_runtime/remote_device.cc
@@ -16,11 +16,13 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include <vector>
+
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/protobuf/worker.pb.h"
@@ -43,8 +45,7 @@ string GetLocalDeviceName(StringPiece fullname) {
class RemoteDevice : public Device {
public:
RemoteDevice(Env* env, const DeviceAttributes& da)
- : Device(env, da, nullptr),
- local_dev_name_(GetLocalDeviceName(da.name())) {}
+ : Device(env, da), local_dev_name_(GetLocalDeviceName(da.name())) {}
Status Sync() override { return Status::OK(); }
Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
@@ -68,18 +69,50 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
GetStatusResponse resp;
};
Call* call = new Call;
- auto cb = [env, worker_cache, worker_name, done, wi, call](const Status& s) {
+ auto cb = [env, worker_cache, worker_name, done, wi,
+ call](const Status& status) {
+ Status s = status;
std::vector<Device*> remote_devices;
+ auto cleanup = gtl::MakeCleanup(
+ [&worker_cache, &worker_name, &wi, &done, &remote_devices, &s, call] {
+ worker_cache->ReleaseWorker(worker_name, wi);
+ done(s, &remote_devices);
+ delete call;
+ });
if (s.ok()) {
+ DeviceNameUtils::ParsedName worker_name_parsed;
+ if (!DeviceNameUtils::ParseFullName(worker_name, &worker_name_parsed) ||
+ !worker_name_parsed.has_job || !worker_name_parsed.has_replica ||
+ !worker_name_parsed.has_task) {
+ s = errors::InvalidArgument("Could not parse worker name: ",
+ worker_name);
+ LOG(WARNING) << s;
+ return;
+ }
remote_devices.reserve(call->resp.device_attributes_size());
for (const DeviceAttributes& da : call->resp.device_attributes()) {
- auto d = new RemoteDevice(env, da);
- remote_devices.push_back(d);
+ DeviceNameUtils::ParsedName device_name_parsed;
+ CHECK(DeviceNameUtils::ParseFullName(da.name(), &device_name_parsed))
+ << "Device attribute name '" << da.name() << "' could not be "
+ << "parsed. Device Attribute: " << da.DebugString();
+ // Preserve the exact name, if possible.
+ // TODO(b/37868888): Simplify when legacy device name formats removed.
+ if (device_name_parsed.job == worker_name_parsed.job &&
+ device_name_parsed.replica == worker_name_parsed.replica &&
+ device_name_parsed.task == worker_name_parsed.task) {
+ auto d = new RemoteDevice(env, da);
+ remote_devices.push_back(d);
+ } else {
+ DeviceAttributes da_rewritten = da;
+ da_rewritten.set_name(DeviceNameUtils::FullName(
+ worker_name_parsed.job, worker_name_parsed.replica,
+ worker_name_parsed.task, device_name_parsed.type,
+ device_name_parsed.id));
+ auto d = new RemoteDevice(env, da_rewritten);
+ remote_devices.push_back(d);
+ }
}
}
- worker_cache->ReleaseWorker(worker_name, wi);
- done(s, &remote_devices);
- delete call;
};
wi->GetStatusAsync(&call->req, &call->resp, cb);
}
diff --git a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h
index 04c1fc248e..43267d4362 100644
--- a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h
+++ b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h
@@ -25,6 +25,23 @@ limitations under the License.
namespace tensorflow {
+struct WorkerSession;
+
+// RemoteRendezvous follow a 2-part initialization. First the objects are
+// constructed. Eventually, they will be initialized. Clients of the
+// RendezvousMgrInterface must guarantee to call Initialize on the returned
+// RemoteRendezvous eventually.
+//
+// Partially initialized RemoteRendezvous must respect the Rendezvous interface
+// (i.e. Send() must never block), however implementations are not expected to
+// actually perform the underlying operations until after the RemoteRendezvous
+// has been Initialize'd.
+class RemoteRendezvous : public Rendezvous {
+ public:
+ // Fully construct the RemoteRendezvous.
+ virtual Status Initialize(WorkerSession* session) = 0;
+};
+
// RendezvousMgr keeps track of a set of local rendezvous instances.
// All tensors sent by this worker are buffered in a RendezvousMgr
// until the tensor is received. Each global unique "step_id"
@@ -51,7 +68,10 @@ class RendezvousMgrInterface {
// Returns Rendezvous supporting send and recv among workers in the
// "step_id". The caller takes ownership of one reference on the
// returned Rendezvous instance.
- virtual Rendezvous* Find(int64 step_id) = 0;
+ //
+ // Note: the caller must guarantee to eventually call Initialize on the
+ // returned RemoteRendezvous
+ virtual RemoteRendezvous* Find(int64 step_id) = 0;
// Finds the local rendezvous instance for the "step_id". Runs
// "done" when the tensor for "key" is produced or an error occurs.
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index 7160962b16..3867dd1f4d 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -63,10 +63,8 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption {
};
// static utility function
-RendezvousMgrInterface* NewRpcRendezvousMgr(
- const WorkerEnv* env, const string& worker_name,
- WorkerCacheInterface* worker_cache) {
- return new RpcRendezvousMgr(env, worker_name, worker_cache);
+RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
+ return new RpcRendezvousMgr(env);
}
} // namespace
@@ -84,6 +82,9 @@ GrpcServer::~GrpcServer() {
// TODO(mrry): Refactor the *Env classes so that it is less fiddly
// to destroy them.
+ // Shut down all outstanding rendezvous.
+ delete worker_env_.rendezvous_mgr;
+
// We must delete graph_mgr before device_mgr, due to shared
// ownership of OpKernels in the executors. (The graph_mgr will
// free all stateless OpKernels, and pass over borrowed stateful
@@ -91,8 +92,10 @@ GrpcServer::~GrpcServer() {
// OpSegments.)
if (worker_env_.session_mgr != nullptr) {
delete worker_env_.session_mgr; // Deletes graph_mgr's.
+ } else {
+ // Note: session_mgr's legacy_session_ deletes device_mgr now.
+ delete worker_env_.device_mgr;
}
- delete worker_env_.device_mgr;
// Do not delete (as these are not owned by the server):
// - master_env_.env
@@ -100,8 +103,9 @@ GrpcServer::~GrpcServer() {
// - worker_env_.compute_pool
}
-Status GrpcServer::Init(ServiceInitFunction service_func,
- RendezvousMgrCreationFunction rendevous_mgr_func) {
+Status GrpcServer::Init(
+ ServiceInitFunction service_func,
+ const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
mutex_lock l(mu_);
CHECK_EQ(state_, NEW);
master_env_.env = env_;
@@ -117,7 +121,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
"/task:", server_def_.task_index());
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix,
&master_env_.local_devices));
- worker_env_.device_mgr = new DeviceMgr(master_env_.local_devices);
+ worker_env_.local_devices = master_env_.local_devices;
+ worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices);
+ worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr
+ ? new RpcRendezvousMgr(&worker_env_)
+ : rendezvous_mgr_func(&worker_env_);
string unused;
string default_worker_name;
if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
@@ -189,20 +197,18 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
}
WorkerCacheInterface* worker_cache;
- TF_RETURN_IF_ERROR(WorkerCacheFactory(server_def_, &worker_cache));
+ WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
+ TF_RETURN_IF_ERROR(
+ WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
CHECK_NE(nullptr, worker_cache);
// Set up worker environment.
- std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr(
- rendevous_mgr_func == nullptr ?
- new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache) :
- rendevous_mgr_func(&worker_env_, name_prefix, worker_cache));
worker_env_.session_mgr = new SessionMgr(
&worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
std::unique_ptr<WorkerCacheInterface>(worker_cache),
- std::move(rendezvous_mgr),
[this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
- return WorkerCacheFactory(server_def, worker_cache);
+ WorkerCacheFactoryOptions options(server_def);
+ return WorkerCacheFactory(options, worker_cache);
});
worker_env_.compute_pool = ComputePool(sess_opts);
@@ -212,11 +218,19 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
master_env_.master_session_factory =
[config](
SessionOptions options, const MasterEnv* env,
- std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs) {
+ std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
+ std::unique_ptr<WorkerCacheInterface> worker_cache,
+ std::unique_ptr<DeviceSet> device_set) {
options.config.MergeFrom(config);
return new MasterSession(options, env, std::move(remote_devs),
+ std::move(worker_cache), std::move(device_set),
CreateNoOpStatsPublisher);
};
+ master_env_.worker_cache_factory =
+ [this](const WorkerCacheFactoryOptions& options,
+ WorkerCacheInterface** worker_cache) {
+ return WorkerCacheFactory(options, worker_cache);
+ };
// Provide direct access to the master from in-process clients.
LocalMaster::Register(target(), master_impl_.get(),
@@ -225,13 +239,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
return Status::OK();
}
-Status GrpcServer::Init() {
- return Init(nullptr, nullptr);
-}
+Status GrpcServer::Init() { return Init(nullptr, nullptr); }
-Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
+Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
GrpcChannelSpec* channel_spec) {
- for (const auto& job : server_def.cluster().job()) {
+ for (const auto& job : options.cluster_def->job()) {
std::map<int, string> host_ports;
for (const auto& task : job.tasks()) {
string& host_port = host_ports[task.first];
@@ -241,8 +253,7 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
task.first, "\": ", host_port, " and ",
task.second);
}
- if (job.name() == server_def.job_name() &&
- task.first == server_def.task_index()) {
+ if (job.name() == *options.job_name && task.first == options.task_index) {
host_port = strings::StrCat("localhost:", bound_port_);
} else {
host_port = task.second;
@@ -253,17 +264,26 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
return Status::OK();
}
-Status GrpcServer::WorkerCacheFactory(const ServerDef& server_def,
+Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
WorkerCacheInterface** worker_cache) {
- string name_prefix =
- strings::StrCat("/job:", server_def.job_name(), "/replica:0",
- "/task:", server_def.task_index());
+ if (options.job_name == nullptr || options.job_name->empty()) {
+ Status s = errors::InvalidArgument(
+ "The master (current machine) is not included in the provided "
+ "cluster_def. ",
+ options.cluster_def->DebugString());
+ LOG(WARNING) << s;
+ return s;
+ }
GrpcChannelSpec channel_spec;
- TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec));
+ TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
+
+ std::unique_ptr<GrpcChannelCache> channel_cache(
+ NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));
+
+ string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
+ "/task:", options.task_index);
- std::unique_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
- channel_spec, GetChannelCreationFunction(server_def)));
const string host_port = channel_cache->TranslateTask(name_prefix);
int requested_port;
@@ -349,8 +369,7 @@ std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
return ::grpc::InsecureServerCredentials();
}
-ChannelCreationFunction GrpcServer::GetChannelCreationFunction(
- const ServerDef& server_def) const {
+ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
// We can do this because SparseGrpcChannelCache is robust to nullptr being
// returned by the channel creation function
return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
index 3b66291a9a..7b54bb84c8 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
@@ -37,9 +37,7 @@ class GrpcWorker;
class Master;
// function that creates a RendezvousMgr.
-typedef std::function<RendezvousMgrInterface*(
- const WorkerEnv*, const std::string& worker_name,
- WorkerCacheInterface* worker_cache)>
+typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
RendezvousMgrCreationFunction;
// function that registers a service to the server. The service needs to
@@ -67,7 +65,7 @@ class GrpcServer : public ServerInterface {
protected:
Status Init(ServiceInitFunction service_func,
- RendezvousMgrCreationFunction rendezvous_mgr_func);
+ const RendezvousMgrCreationFunction& rendezvous_mgr_func);
Status Init();
@@ -75,17 +73,16 @@ class GrpcServer : public ServerInterface {
virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(
const ServerDef& server_def) const;
- virtual ChannelCreationFunction GetChannelCreationFunction(
- const ServerDef& server_def) const;
+ virtual ChannelCreationFunction GetChannelCreationFunction() const;
virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env);
// Creates a WorkerCacheInterface for a session.
- Status WorkerCacheFactory(const ServerDef& server_def,
+ Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
WorkerCacheInterface** worker_cache);
- // Parses a ServerDef into a GrpcChannelSpec.
- Status ParseChannelSpec(const ServerDef& server_def,
+ // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec.
+ Status ParseChannelSpec(const WorkerCacheFactoryOptions& options,
GrpcChannelSpec* channel_spec);
// Returns the port to which this server is bound.
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
index 1aacef8a26..38d59d5bb5 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
@@ -43,7 +43,7 @@ const size_t kSchemePrefixLength = strlen(kSchemePrefix);
/* static */
Status GrpcSession::Create(const SessionOptions& options,
std::unique_ptr<GrpcSession>* out_session) {
- std::unique_ptr<GrpcSession> ret(new GrpcSession(options));
+ std::unique_ptr<GrpcSession> session(new GrpcSession(options));
std::unique_ptr<MasterInterface> master;
// For testing, we enable the client to disable the use of the local
// master registry, so that the RPC stack is exercised.
@@ -56,8 +56,8 @@ Status GrpcSession::Create(const SessionOptions& options,
options.target.substr(kSchemePrefixLength), &master_channel));
master.reset(NewGrpcMaster(master_channel));
}
- ret->SetRemoteMaster(std::move(master));
- *out_session = std::move(ret);
+ session->SetRemoteMaster(std::move(master));
+ *out_session = std::move(session);
return Status::OK();
}
@@ -102,6 +102,7 @@ Status GrpcSession::CreateImpl(CallOptions* call_options,
CreateSessionRequest req;
*req.mutable_config() = options_.config;
*req.mutable_graph_def() = graph;
+ req.set_target(options_.target);
ReEncodeConsts(req.mutable_graph_def());
CreateSessionResponse resp;
Status s = master_->CreateSession(call_options, &req, &resp);
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
index c11266587d..873ef8588f 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
@@ -113,6 +113,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
// completes, and we may decide to bound some of the request
// types.
ENQUEUE_REQUEST(GetStatus, false);
+ ENQUEUE_REQUEST(CreateWorkerSession, false);
ENQUEUE_REQUEST(CleanupAll, false);
ENQUEUE_REQUEST(RegisterGraph, false);
ENQUEUE_REQUEST(DeregisterGraph, false);
@@ -181,6 +182,16 @@ class GrpcWorkerService : public AsyncServiceInterface {
ENQUEUE_REQUEST(GetStatus, false);
}
+ void CreateWorkerSessionHandler(
+ WorkerCall<CreateWorkerSessionRequest, CreateWorkerSessionResponse>*
+ call) {
+ Schedule([this, call]() {
+ Status s = worker_->CreateWorkerSession(&call->request, &call->response);
+ call->SendResponse(ToGrpcStatus(s));
+ });
+ ENQUEUE_REQUEST(CreateWorkerSession, false);
+ }
+
void CleanupAllHandler(
WorkerCall<CleanupAllRequest, CleanupAllResponse>* call) {
Schedule([this, call]() {
@@ -298,7 +309,6 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts,
::grpc::ByteBuffer* response,
StatusCallback done) {
const int64 step_id = request->step_id();
- WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id);
const string& key = request->rendezvous_key();
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
Rendezvous::ParsedKey parsed;
@@ -317,7 +327,7 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts,
// of execution of the callback lambda body below, an RPC
// cancellation should abort the rendezvous.
opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
- session->rendezvous_mgr->RecvLocalAsync(
+ env_->rendezvous_mgr->RecvLocalAsync(
step_id, parsed,
[opts, response, done, src_dev](const Status& status,
const Rendezvous::Args& send_args,
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
index 7518a289fd..8265100061 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
@@ -38,9 +38,8 @@ namespace {
class RpcRemoteRendezvous : public BaseRemoteRendezvous {
public:
- RpcRemoteRendezvous(const WorkerEnv* env, const string& worker_name,
- WorkerCacheInterface* cache, int64 step_id)
- : BaseRemoteRendezvous(env, worker_name, step_id, false), cache_(cache) {}
+ RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id)
+ : BaseRemoteRendezvous(env, step_id, false) {}
protected:
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
@@ -50,7 +49,6 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous {
private:
~RpcRemoteRendezvous() override {}
- WorkerCacheInterface* const cache_; // Not owned.
TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
};
@@ -204,75 +202,10 @@ static RpcRecvTensorFreeList* get_call_freelist() {
return call_freelist;
}
-// A private cache that wraps worker_cache and allows reuse of
-// WorkerInterface objects.
-class WorkerFreeListCache : public WorkerCacheInterface {
- public:
- explicit WorkerFreeListCache(WorkerCacheInterface* w) : wrapped_(w) {}
-
- ~WorkerFreeListCache() {
- for (auto p : workers_) {
- wrapped_->ReleaseWorker(p.first, p.second.worker);
- }
- }
-
- void ListWorkers(std::vector<string>* workers) const override {
- wrapped_->ListWorkers(workers);
- }
-
- WorkerInterface* CreateWorker(const string& target) override {
- mutex_lock l(mu_);
- auto p = workers_.find(target);
- if (p != workers_.end()) {
- return p->second.worker;
- }
- WorkerState state;
- state.worker = wrapped_->CreateWorker(target);
- if (state.worker != nullptr) {
- workers_.insert(std::make_pair(target, state));
- }
- return state.worker;
- }
-
- void ReleaseWorker(const string& target, WorkerInterface* worker) override {
- // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
- }
-
- bool GetDeviceLocalityNonBlocking(const string& device,
- DeviceLocality* locality) override {
- return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
- }
-
- void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
- StatusCallback done) override {
- wrapped_->GetDeviceLocalityAsync(device, locality, done);
- }
-
- void SetLogging(bool active) override { wrapped_->SetLogging(active); }
-
- void ClearLogs() override { wrapped_->ClearLogs(); }
-
- bool RetrieveLogs(int64 step_id, StepStats* ss) override {
- return wrapped_->RetrieveLogs(step_id, ss);
- }
-
- private:
- WorkerCacheInterface* wrapped_;
-
- // Information kept per created WorkerInterface.
- struct WorkerState {
- WorkerInterface* worker;
- // TODO(jeff,sanjay): Add reference count if we support eviction.
- };
-
- // TODO(jeff,sanjay): Eviction when the map becomes too big.
- mutex mu_;
- std::unordered_map<string, WorkerState> workers_ GUARDED_BY(mu_);
-};
-
void RpcRemoteRendezvous::RecvFromRemoteAsync(
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
DoneCallback done) {
+ CHECK(is_initialized());
Status s;
// Prepare a RecvTensor call that can handle being aborted.
@@ -284,17 +217,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
s = errors::Internal(parsed.src_device,
" is invalid remote source device.");
}
- WorkerInterface* rwi = cache_->CreateWorker(call->src_worker_);
+ WorkerSession* sess = session();
+ WorkerInterface* rwi = sess->worker_cache->CreateWorker(call->src_worker_);
if (s.ok() && rwi == nullptr) {
s = errors::Internal("No worker known as ", call->src_worker_);
}
Device* dst_device;
if (s.ok()) {
- s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
+ s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
}
if (!s.ok()) {
- get_call_freelist()->Release(call, cache_);
+ if (rwi != nullptr) {
+ sess->worker_cache->ReleaseWorker(call->src_worker_, rwi);
+ }
+ get_call_freelist()->Release(call, sess->worker_cache.get());
done(s, Args(), recv_args, Tensor{}, false);
return;
}
@@ -314,26 +251,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
// current status should be bad.
Status s = call->status();
call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
- cache_->ReleaseWorker(call->src_worker_, call->wi_);
+ session()->worker_cache->ReleaseWorker(call->src_worker_, call->wi_);
call->wi_ = nullptr;
- get_call_freelist()->Release(call, cache_);
+ get_call_freelist()->Release(call, session()->worker_cache.get());
Unref();
});
}
} // namespace
-RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env,
- const string& worker_name,
- WorkerCacheInterface* worker_cache)
- : BaseRendezvousMgr(env, worker_name),
- cache_(new WorkerFreeListCache(worker_cache)) {}
+RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env)
+ : BaseRendezvousMgr(env) {}
BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
- const WorkerEnv* worker_env,
- const string& worker_name) {
- return new RpcRemoteRendezvous(worker_env, worker_name, cache_.get(),
- step_id);
+ const WorkerEnv* worker_env) {
+ return new RpcRemoteRendezvous(worker_env, step_id);
}
} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h
index 75dc62d98f..34c48a7917 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h
@@ -17,13 +17,13 @@ limitations under the License.
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
-#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
-#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
+class DeviceMgr;
+
// RendezvousMgr keeps track of a set of local rendezvous instances.
// All tensors sent by this worker are buffered in a RendezvousMgr
// until the tensor is received. Each global unique "step_id"
@@ -44,17 +44,12 @@ namespace tensorflow {
// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
class RpcRendezvousMgr : public BaseRendezvousMgr {
public:
- explicit RpcRendezvousMgr(const WorkerEnv* env, const string& worker_name,
- WorkerCacheInterface* worker_cache);
+ explicit RpcRendezvousMgr(const WorkerEnv* env);
protected:
- BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env,
- const string& session_name) override;
+ BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env);
private:
- // Private cache_ that allows us to reuse WorkerInterface objects.
- std::unique_ptr<WorkerCacheInterface> cache_;
-
TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr);
};
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
index 9b778eab3a..2d0d76623d 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
@@ -68,9 +68,9 @@ class RpcRendezvousMgrTest : public ::testing::Test {
: cache_(new DummyWorkerCache),
worker_session_("/job:mnist/replica:1/task:2",
std::unique_ptr<WorkerCacheInterface>(cache_),
- std::unique_ptr<RendezvousMgrInterface>(),
+ std::unique_ptr<DeviceMgr>(),
std::unique_ptr<GraphMgr>()),
- rmgr_(&env, worker_session_.worker_name, cache_) {
+ rmgr_(&env) {
env.env = Env::Default();
}
@@ -87,7 +87,8 @@ TEST_F(RpcRendezvousMgrTest, LocalSendRecv) {
"/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{
- Rendezvous* rendez = rmgr_.Find(step_id);
+ RemoteRendezvous* rendez = rmgr_.Find(step_id);
+ TF_ASSERT_OK(rendez->Initialize(&worker_session_));
core::ScopedUnref unref(rendez);
Rendezvous::Args args;
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
@@ -107,7 +108,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{ // Explicit Abort().
const int64 step_id = 123;
- Rendezvous* rendez = rmgr_.Find(step_id);
+ RemoteRendezvous* rendez = rmgr_.Find(step_id);
core::ScopedUnref unref(rendez);
SchedClosure([this, rendez]() {
env.env->SleepForMicroseconds(100 * 1000);
@@ -116,11 +117,12 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
Tensor val(DT_STRING);
bool val_dead = false;
Rendezvous::Args args;
+ TF_ASSERT_OK(rendez->Initialize(&worker_session_));
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
}
{ // Cleanup causes Abort().
const int64 step_id = 321;
- Rendezvous* rendez = rmgr_.Find(step_id);
+ RemoteRendezvous* rendez = rmgr_.Find(step_id);
core::ScopedUnref unref(rendez);
SchedClosure([this, step_id]() {
env.env->SleepForMicroseconds(100 * 1000);
@@ -129,6 +131,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
Tensor val(DT_STRING);
bool val_dead = false;
Rendezvous::Args args;
+ TF_ASSERT_OK(rendez->Initialize(&worker_session_));
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
}
}
@@ -139,7 +142,8 @@ TEST_F(RpcRendezvousMgrTest, CleanupAll) {
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{
const int64 step_id = 123;
- Rendezvous* rendez = rmgr_.Find(step_id);
+ RemoteRendezvous* rendez = rmgr_.Find(step_id);
+ TF_ASSERT_OK(rendez->Initialize(&worker_session_));
core::ScopedUnref unref(rendez);
Rendezvous::Args args;
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
@@ -168,10 +172,11 @@ TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
"/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{
- Rendezvous* rendez = rmgr_.Find(step_id);
+ RemoteRendezvous* rendez = rmgr_.Find(step_id);
core::ScopedUnref unref(rendez);
Rendezvous::Args args;
args.device_context = dc;
+ TF_ASSERT_OK(rendez->Initialize(&worker_session_));
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
}
{
diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc
index e2be62f816..22551d5482 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr.cc
@@ -17,8 +17,9 @@ limitations under the License.
#include <utility>
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
-#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
@@ -26,23 +27,12 @@ namespace tensorflow {
SessionMgr::SessionMgr(
WorkerEnv* worker_env, const string& default_worker_name,
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
- std::unique_ptr<RendezvousMgrInterface> default_rendezvous_mgr,
- WorkerCacheFactory worker_cache_factory)
- : SessionMgr(
- worker_env, default_worker_name, std::move(default_worker_cache),
- default_rendezvous_mgr.release(), std::move(worker_cache_factory)) {}
-
-SessionMgr::SessionMgr(
- WorkerEnv* worker_env, const string& default_worker_name,
- std::unique_ptr<WorkerCacheInterface> default_worker_cache,
- RendezvousMgrInterface* default_rendezvous_mgr,
WorkerCacheFactory worker_cache_factory)
: worker_env_(worker_env),
- legacy_session_(
- default_worker_name, std::move(default_worker_cache),
- std::unique_ptr<RendezvousMgrInterface>(default_rendezvous_mgr),
- std::unique_ptr<GraphMgr>(
- new GraphMgr(worker_env, default_rendezvous_mgr))),
+ legacy_session_(default_worker_name, std::move(default_worker_cache),
+ std::unique_ptr<DeviceMgr>(worker_env->device_mgr),
+ std::unique_ptr<GraphMgr>(
+ new GraphMgr(worker_env, worker_env->device_mgr))),
worker_cache_factory_(std::move(worker_cache_factory)) {}
string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
@@ -53,20 +43,28 @@ string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
Status SessionMgr::CreateSession(const string& session,
const ServerDef& server_def) {
mutex_lock l(mu_);
+ if (session.empty()) {
+ return errors::InvalidArgument("Session must be non-empty.");
+ }
+
const string worker_name = WorkerNameFromServerDef(server_def);
WorkerCacheInterface* worker_cache = nullptr;
TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
- std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr(
- new RpcRendezvousMgr(worker_env_, worker_name, worker_cache));
+ std::vector<Device*> renamed_devices;
+ for (Device* d : worker_env_->local_devices) {
+ renamed_devices.push_back(
+ RenamedDevice::NewRenamedDevice(worker_name, d, false));
+ }
+ std::unique_ptr<DeviceMgr> device_mgr(new DeviceMgr(renamed_devices));
std::unique_ptr<GraphMgr> graph_mgr(
- new GraphMgr(worker_env_, rendezvous_mgr.get()));
+ new GraphMgr(worker_env_, device_mgr.get()));
std::unique_ptr<WorkerSession> worker_session(new WorkerSession(
worker_name, std::unique_ptr<WorkerCacheInterface>(worker_cache),
- std::move(rendezvous_mgr), std::move(graph_mgr)));
+ std::move(device_mgr), std::move(graph_mgr)));
sessions_.insert(std::make_pair(session, std::move(worker_session)));
return Status::OK();
@@ -78,22 +76,6 @@ Status SessionMgr::DeleteSession(const string& session) {
if (it != sessions_.end()) {
sessions_.erase(it);
}
- std::set<string> graph_handles;
- for (auto graph_handle_it = sessions_by_graph_handle_.begin();
- graph_handle_it != sessions_by_graph_handle_.end(); ++graph_handle_it) {
- if (graph_handle_it->second == session) {
- graph_handles.insert(graph_handle_it->first);
- graph_handle_it = sessions_by_graph_handle_.erase(graph_handle_it);
- if (graph_handle_it == sessions_by_graph_handle_.end()) break;
- }
- }
- for (auto step_id_it = graphs_by_step_id_.begin();
- step_id_it != graphs_by_step_id_.end(); ++step_id_it) {
- if (graph_handles.find(step_id_it->second) != graph_handles.end()) {
- step_id_it = graphs_by_step_id_.erase(step_id_it);
- if (step_id_it == graphs_by_step_id_.end()) break;
- }
- }
return Status::OK();
}
@@ -114,58 +96,4 @@ WorkerSession* SessionMgr::WorkerSessionForSession(const string& session) {
WorkerSession* SessionMgr::LegacySession() { return &legacy_session_; }
-WorkerSession* SessionMgr::WorkerSessionForGraphHandleUnlocked(
- const string& graph_handle) {
- auto it = sessions_by_graph_handle_.find(graph_handle);
- if (it == sessions_by_graph_handle_.end()) {
- return &legacy_session_;
- } else {
- return WorkerSessionForSessionUnlocked(it->second);
- }
-}
-
-WorkerSession* SessionMgr::WorkerSessionForGraphHandle(
- const string& graph_handle) {
- mutex_lock l(mu_);
- return WorkerSessionForGraphHandleUnlocked(graph_handle);
-}
-
-WorkerSession* SessionMgr::WorkerSessionForStepId(const int64 step_id) {
- mutex_lock l(mu_);
- auto it = graphs_by_step_id_.find(step_id);
- if (it == graphs_by_step_id_.end()) {
- return &legacy_session_;
- } else {
- return WorkerSessionForGraphHandleUnlocked(it->second);
- }
-}
-
-void SessionMgr::AssociateGraphWithSession(const string& session,
- const string& graph_handle) {
- mutex_lock l(mu_);
- sessions_by_graph_handle_[graph_handle] = session;
-}
-
-void SessionMgr::DisassociateGraphFromSession(const string& graph_handle) {
- mutex_lock l(mu_);
- auto it = sessions_by_graph_handle_.find(graph_handle);
- if (it != sessions_by_graph_handle_.end()) {
- sessions_by_graph_handle_.erase(it);
- }
-}
-
-void SessionMgr::AssociateStepIdWithGraph(const string& graph_handle,
- const int64 step_id) {
- mutex_lock l(mu_);
- graphs_by_step_id_[step_id] = graph_handle;
-}
-
-void SessionMgr::DisassociateStepIdFromGraph(const int64 step_id) {
- mutex_lock l(mu_);
- auto it = graphs_by_step_id_.find(step_id);
- if (it != graphs_by_step_id_.end()) {
- graphs_by_step_id_.erase(it);
- }
-}
-
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h
index 455b5c8d9d..c44bca7b7a 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.h
+++ b/tensorflow/core/distributed_runtime/session_mgr.h
@@ -30,6 +30,8 @@ struct WorkerEnv;
// SessionMgr keeps track of information related to a given session.
//
+// SessionMgr runs on the workers.
+//
// SessionMgr is threadsafe.
class SessionMgr {
public:
@@ -39,7 +41,6 @@ class SessionMgr {
explicit SessionMgr(
WorkerEnv* worker_env, const string& default_worker_name,
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
- std::unique_ptr<RendezvousMgrInterface> default_rendezvous_mgr,
WorkerCacheFactory worker_cache_factory);
~SessionMgr() {}
@@ -50,49 +51,36 @@ class SessionMgr {
WorkerSession* WorkerSessionForSession(const string& session);
WorkerSession* LegacySession();
- // Locates the worker session for a given graph handle
- WorkerSession* WorkerSessionForGraphHandle(const string& graph_handle);
- void AssociateGraphWithSession(const string& session,
- const string& graph_handle);
- void DisassociateGraphFromSession(const string& graph_handle);
-
- // Locates a worker session for a given step id
- WorkerSession* WorkerSessionForStepId(const int64 step_id);
- void AssociateStepIdWithGraph(const string& graph_handle,
- const int64 step_id);
- void DisassociateStepIdFromGraph(const int64 step_id);
-
Status DeleteSession(const string& session);
static string WorkerNameFromServerDef(const ServerDef& server_def);
private:
- // Private constructor to work around std::unique_ptr ownership issues.
- explicit SessionMgr(
- WorkerEnv* worker_env, const string& default_worker_name,
- std::unique_ptr<WorkerCacheInterface> default_worker_cache,
- RendezvousMgrInterface* default_rendezvous_mgr,
- WorkerCacheFactory worker_cache_factory);
-
const WorkerEnv* const worker_env_; // Not owned.
+
+ // A note about destruction:
+ // We must delete graph_mgr before device_mgr, due to shared
+ // ownership of OpKernels in the executors. (The graph_mgr will
+ // free all stateless OpKernels, and pass over borrowed stateful
+ // OpKernels, which are also held in their respective devices'
+ // OpSegments.)
+ //
+ // legacy_session_ owns the worker_env_.device_mgr, and so we must ensure
+ // that sessions_'s WorkerSessions are deleted (which do not own the
+ // underlying devices, but instead own RenamedDevices) before
+ // legacy_session_ is deleted. Further, we must ensure that WorkerSession's
+ // device_mgr is deleted after WorkerSession's graph_mgr.
+
WorkerSession legacy_session_;
const WorkerCacheFactory worker_cache_factory_;
WorkerSession* WorkerSessionForSessionUnlocked(const string& session)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
- WorkerSession* WorkerSessionForGraphHandleUnlocked(const string& graph_handle)
- EXCLUSIVE_LOCKS_REQUIRED(mu_);
mutex mu_;
// A map from session identifier to internal session structure.
std::map<string, std::unique_ptr<WorkerSession>> sessions_ GUARDED_BY(mu_);
-
- // A map from graph handles to the session that they belong to.
- std::map<string, string> sessions_by_graph_handle_ GUARDED_BY(mu_);
-
- // A map from globally-unique step id's to the corresponding graph handles.
- std::map<int64, string> graphs_by_step_id_ GUARDED_BY(mu_);
};
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/session_mgr_test.cc b/tensorflow/core/distributed_runtime/session_mgr_test.cc
index d3f3fa8395..7132f123a5 100644
--- a/tensorflow/core/distributed_runtime/session_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr_test.cc
@@ -27,8 +27,6 @@ class SessionMgrTest : public ::testing::Test {
SessionMgrTest()
: mgr_(&env_, "/job:mnist/replica:0/task:0",
std::unique_ptr<WorkerCacheInterface>(),
- std::unique_ptr<RendezvousMgrInterface>(new RpcRendezvousMgr(
- &env_, "/job:mnist/replica:0/task:0", nullptr)),
factory_),
legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {}
@@ -48,90 +46,19 @@ TEST_F(SessionMgrTest, CreateSessionSimple) {
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null";
-
- TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
-}
-
-TEST_F(SessionMgrTest, AssociateGraphWithSession) {
- ServerDef server_def;
- string session_handle = "test_session_handle";
- TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
- WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
- ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null";
-
- string graph_handle = "test_graph_handle";
- mgr_.AssociateGraphWithSession(session_handle, graph_handle);
- WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
- ASSERT_EQ(session, graph_session);
-
+ EXPECT_NE(mgr_.LegacySession(), session);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
}
-TEST_F(SessionMgrTest, AssociateStepWithGraph) {
+TEST_F(SessionMgrTest, LegacySession) {
ServerDef server_def;
- string session_handle = "test_session_handle";
- TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
+ string session_handle = "";
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
- ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null";
-
- string graph_handle = "test_graph_handle";
- mgr_.AssociateGraphWithSession(session_handle, graph_handle);
- WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
- ASSERT_EQ(session, graph_session);
-
- int64 step_id = 1234567890L;
- mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
- WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
- ASSERT_EQ(session, step_session);
- ASSERT_EQ(graph_session, step_session);
+ EXPECT_EQ(mgr_.LegacySession(), session);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
}
-TEST_F(SessionMgrTest, AssociateGraphWithSession_MissingSession) {
- string session_handle = "test_session_handle";
- string graph_handle = "test_graph_handle";
- mgr_.AssociateGraphWithSession(session_handle, graph_handle);
- WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
- ASSERT_EQ(legacy_session_, graph_session);
-}
-
-TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingGraph) {
- ServerDef server_def;
- string session_handle = "test_session_handle";
- TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
- WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
- ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null";
-
- string graph_handle = "test_graph_handle";
- int64 step_id = 1234567890L;
- mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
- WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
- ASSERT_EQ(legacy_session_, step_session);
-}
-
-TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSession) {
- string session_handle = "test_session_handle";
- string graph_handle = "test_graph_handle";
- mgr_.AssociateGraphWithSession(session_handle, graph_handle);
- WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
- ASSERT_EQ(legacy_session_, graph_session);
-
- int64 step_id = 1234567890L;
- mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
- WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
- ASSERT_EQ(legacy_session_, step_session);
-}
-
-TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSessionAndGraph) {
- string session_handle = "test_session_handle";
- string graph_handle = "test_graph_handle";
- int64 step_id = 1234567890L;
- mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
- WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
- ASSERT_EQ(legacy_session_, step_session);
-}
-
TEST_F(SessionMgrTest, WorkerNameFromServerDef) {
ServerDef server_def;
server_def.set_job_name("worker");
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index 89639e21b5..07bb17981d 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -56,10 +56,6 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
Status s = session->graph_mgr->Register(
request->session_handle(), request->graph_def(), request->graph_options(),
request->debug_options(), response->mutable_graph_handle());
- if (s.ok()) {
- env_->session_mgr->AssociateGraphWithSession(request->session_handle(),
- response->graph_handle());
- }
done(s);
}
@@ -67,9 +63,8 @@ void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
DeregisterGraphResponse* response,
StatusCallback done) {
WorkerSession* session =
- env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle());
+ env_->session_mgr->WorkerSessionForSession(request->session_handle());
Status s = session->graph_mgr->Deregister(request->graph_handle());
- env_->session_mgr->DisassociateGraphFromSession(request->graph_handle());
done(s);
}
@@ -141,8 +136,7 @@ void Worker::SetOrCallFinalCallback(const string& graph_handle, int step_id,
}
void Worker::AbortStep(int64 step_id) {
- WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id);
- Rendezvous* rendez = session->rendezvous_mgr->Find(step_id);
+ Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
// Delay a bit before aborting the step. This way, the root
// cause may return first back to the client instead of this
@@ -193,8 +187,7 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
const int64 step_id = request->step_id();
TRACEPRINTF("RunGraph: %lld", step_id);
WorkerSession* session =
- env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle());
- env_->session_mgr->AssociateStepIdWithGraph(request->graph_handle(), step_id);
+ env_->session_mgr->WorkerSessionForSession(request->session_handle());
GraphMgr::NamedTensors in;
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
Status s = PrepareRunGraph(request, &in, out);
@@ -231,8 +224,8 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
}
CostGraphDef* cost_graph = response->mutable_cost_graph();
session->graph_mgr->ExecuteAsync(
- request->graph_handle(), step_id, request->exec_opts(), collector,
- cost_graph, cm, in,
+ request->graph_handle(), step_id, session, request->exec_opts(),
+ collector, cost_graph, cm, in,
[this, step_id, response, session, cm, out, token, collector, opts,
done](Status s) {
if (s.ok()) {
@@ -267,8 +260,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
const string& graph_handle = request->graph_handle();
TRACEPRINTF("PartialRunGraph: %lld", step_id);
WorkerSession* session =
- env_->session_mgr->WorkerSessionForGraphHandle(graph_handle);
- env_->session_mgr->AssociateStepIdWithGraph(graph_handle, step_id);
+ env_->session_mgr->WorkerSessionForSession(request->session_handle());
+
GraphMgr::NamedTensors in;
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
Status s = PrepareRunGraph(request, &in, out);
@@ -315,8 +308,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
[cm]() { cm->StartCancel(); });
}
session->graph_mgr->ExecuteAsync(
- graph_handle, step_id, request->exec_opts(), nullptr /* collector */,
- nullptr /* cost_graph */, cm, in,
+ graph_handle, step_id, session, request->exec_opts(),
+ nullptr /* collector */, nullptr /* cost_graph */, cm, in,
[this, token, graph_handle, step_id, cm](Status s) {
{
mutex_lock l(mu_);
@@ -365,8 +358,7 @@ void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
CleanupGraphResponse* response,
StatusCallback done) {
const int64 step_id = request->step_id();
- WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id);
- session->rendezvous_mgr->Cleanup(step_id);
+ env_->rendezvous_mgr->Cleanup(step_id);
done(Status::OK());
}
@@ -394,8 +386,8 @@ void Worker::TracingAsync(const TracingRequest* request,
Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
Device** src_dev) {
// Figures out which device the tensor is hosted on.
- TF_RETURN_IF_ERROR(
- env_->device_mgr->LookupDevice(parsed.src_device, src_dev));
+ string local_name = DeviceNameUtils::LocalName(parsed.src_device);
+ TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev));
// Does the device have the right incarnation number we expect?
if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
diff --git a/tensorflow/core/distributed_runtime/worker_env.h b/tensorflow/core/distributed_runtime/worker_env.h
index 24fb5948a7..f09bea328f 100644
--- a/tensorflow/core/distributed_runtime/worker_env.h
+++ b/tensorflow/core/distributed_runtime/worker_env.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
+#include <vector>
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -24,8 +25,10 @@ namespace thread {
class ThreadPool;
} // namespace thread
+class Device;
class DeviceMgr;
class Env;
+class RendezvousMgrInterface;
class SessionMgr;
// The worker environment class, which holds a bag of pointers to
@@ -38,10 +41,18 @@ struct WorkerEnv {
// session_mgr encapsulates state for each session.
SessionMgr* session_mgr = nullptr;
+ // The local devices of this worker. Devices are owned by the device_mgr.
+ //
+ // REQUIRES: !local_devices.empty().
+ std::vector<Device*> local_devices;
+
// device_mgr manages local devices (cpu and gpu). The WorkerService
// is the network interface for managed devices.
DeviceMgr* device_mgr = nullptr;
+ // A set of rendezvous keyed by step ids.
+ RendezvousMgrInterface* rendezvous_mgr = nullptr;
+
// A pool of threads for scheduling compute work.
thread::ThreadPool* compute_pool = nullptr;
};
diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h
index 508bc7f468..c9db28ec67 100644
--- a/tensorflow/core/distributed_runtime/worker_interface.h
+++ b/tensorflow/core/distributed_runtime/worker_interface.h
@@ -113,6 +113,11 @@ class WorkerInterface {
return CallAndWait(&ME::GetStatusAsync, request, response);
}
+ Status CreateWorkerSession(const CreateWorkerSessionRequest* request,
+ CreateWorkerSessionResponse* response) {
+ return CallAndWait(&ME::CreateWorkerSessionAsync, request, response);
+ }
+
Status RegisterGraph(const RegisterGraphRequest* request,
RegisterGraphResponse* response) {
return CallAndWait(&ME::RegisterGraphAsync, request, response);
diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc
index 8298e16959..8691450e9b 100644
--- a/tensorflow/core/distributed_runtime/worker_session.cc
+++ b/tensorflow/core/distributed_runtime/worker_session.cc
@@ -17,14 +17,84 @@ limitations under the License.
namespace tensorflow {
-WorkerSession::WorkerSession(
- const string& worker_name,
- std::unique_ptr<WorkerCacheInterface> worker_cache,
- std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr,
- std::unique_ptr<GraphMgr> graph_mgr)
+namespace {
+
+// A private cache that wraps worker_cache and allows reuse of
+// WorkerInterface objects.
+class WorkerFreeListCache : public WorkerCacheInterface {
+ public:
+ explicit WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)
+ : wrapped_(std::move(w)) {}
+
+ ~WorkerFreeListCache() final {
+ for (auto p : workers_) {
+ wrapped_->ReleaseWorker(p.first, p.second.worker);
+ }
+ }
+
+ void ListWorkers(std::vector<string>* workers) const override {
+ wrapped_->ListWorkers(workers);
+ }
+
+ WorkerInterface* CreateWorker(const string& target) override {
+ mutex_lock l(mu_);
+ auto p = workers_.find(target);
+ if (p != workers_.end()) {
+ return p->second.worker;
+ }
+ WorkerState state;
+ state.worker = wrapped_->CreateWorker(target);
+ if (state.worker != nullptr) {
+ workers_.insert(std::make_pair(target, state));
+ }
+ return state.worker;
+ }
+
+ void ReleaseWorker(const string& target, WorkerInterface* worker) override {
+ // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
+ }
+
+ bool GetDeviceLocalityNonBlocking(const string& device,
+ DeviceLocality* locality) override {
+ return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
+ }
+
+ void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
+ StatusCallback done) override {
+ wrapped_->GetDeviceLocalityAsync(device, locality, done);
+ }
+
+ void SetLogging(bool active) override { wrapped_->SetLogging(active); }
+
+ void ClearLogs() override { wrapped_->ClearLogs(); }
+
+ bool RetrieveLogs(int64 step_id, StepStats* ss) override {
+ return wrapped_->RetrieveLogs(step_id, ss);
+ }
+
+ private:
+ std::unique_ptr<WorkerCacheInterface> wrapped_;
+
+ // Information kept per created WorkerInterface.
+ struct WorkerState {
+ WorkerInterface* worker;
+ // TODO(jeff,sanjay): Add reference count if we support eviction.
+ };
+
+ // TODO(jeff,sanjay): Eviction when the map becomes too big.
+ mutex mu_;
+ std::unordered_map<string, WorkerState> workers_ GUARDED_BY(mu_);
+};
+
+} // namespace
+
+WorkerSession::WorkerSession(const string& worker_name,
+ std::unique_ptr<WorkerCacheInterface> worker_cache,
+ std::unique_ptr<DeviceMgr> device_mgr,
+ std::unique_ptr<GraphMgr> graph_mgr)
: worker_name(worker_name),
- worker_cache(std::move(worker_cache)),
- rendezvous_mgr(std::move(rendezvous_mgr)),
+ worker_cache(new WorkerFreeListCache(std::move(worker_cache))),
+ device_mgr(std::move(device_mgr)),
graph_mgr(std::move(graph_mgr)) {}
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/worker_session.h b/tensorflow/core/distributed_runtime/worker_session.h
index e6ebe88329..77cf4de8f7 100644
--- a/tensorflow/core/distributed_runtime/worker_session.h
+++ b/tensorflow/core/distributed_runtime/worker_session.h
@@ -18,14 +18,13 @@ limitations under the License.
#include <string>
+#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
-#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
namespace tensorflow {
class GraphMgr;
-class RendezvousMgrInterface;
class WorkerCacheInterface;
// WorkerSession encapsulates all of the state relating to a given session.
@@ -36,17 +35,20 @@ struct WorkerSession {
// Object from which WorkerInterface instances can be obtained.
const std::unique_ptr<WorkerCacheInterface> worker_cache;
- // A set of rendezvous keyed by step ids.
- const std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr;
+ // Collection of local devices. These devices are typically RenamedDevices
+ // in all except the SessionMgr.legacy_session_. legacy_session_.device_mgr
+ // == worker_env_.device_mgr, which holds the true devices.
+ const std::unique_ptr<DeviceMgr> device_mgr;
// graph_mgr keeps track of the registered graphs of this session.
//
// Note: graph_mgr must be deleted before rendezvous_mgr!
+ // Note: graph_mgr must be deleted before device_mgr!
const std::unique_ptr<GraphMgr> graph_mgr;
WorkerSession(const string& worker_name,
std::unique_ptr<WorkerCacheInterface> worker_cache,
- std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr,
+ std::unique_ptr<DeviceMgr> device_mgr,
std::unique_ptr<GraphMgr> graph_mgr);
};
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index 8894671fdf..27fe28fe60 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -115,7 +115,7 @@ class DeviceBase {
cpu_worker_threads_ = t;
}
- const CpuWorkerThreads* tensorflow_cpu_worker_threads() const {
+ virtual const CpuWorkerThreads* tensorflow_cpu_worker_threads() const {
CHECK(cpu_worker_threads_ != nullptr);
return cpu_worker_threads_;
}
@@ -140,7 +140,7 @@ class DeviceBase {
gpu_device_info_ = g;
}
- const GpuDeviceInfo* tensorflow_gpu_device_info() const {
+ virtual const GpuDeviceInfo* tensorflow_gpu_device_info() const {
return gpu_device_info_;
}
@@ -170,13 +170,13 @@ class DeviceBase {
return GetAllocator(attr);
}
- const Eigen::ThreadPoolDevice* eigen_cpu_device() {
+ virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() {
CHECK(eigen_cpu_device_ != nullptr);
return eigen_cpu_device_;
}
#ifdef TENSORFLOW_USE_SYCL
- const Eigen::SyclDevice* eigen_sycl_device() const {
+ virtual const Eigen::SyclDevice* eigen_sycl_device() const {
CHECK(eigen_sycl_device_ != nullptr);
return eigen_sycl_device_;
}
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 3d913cdaf0..6fad379b76 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -656,22 +656,6 @@ Status OpKernelContext::allocate_persistent(DataType type,
*out_tensor = out_persistent->AccessTensor(this);
}
}
- if (track_allocations() && persistent.TotalBytes() > 0) {
- // TODO(yuefengz): some allocators allocate memory even if the requested
- // size is 0.
- Allocator* a = get_allocator(attr);
- if (a->TracksAllocationSizes()) {
- int64 alloc_size =
- a->AllocatedSize(const_cast<char*>(persistent.tensor_data().data()));
- int64 alloc_id =
- a->AllocationId(const_cast<char*>(persistent.tensor_data().data()));
- if (allocate_on_host(attr)) {
- record_host_persistent_memory_allocation(alloc_size, alloc_id);
- } else {
- record_device_persistent_memory_allocation(alloc_size, alloc_id);
- }
- }
- }
return s;
}
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index 22f4708d03..372092f42a 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -111,6 +111,7 @@ cc_library(
name = "utils",
srcs = ["utils.cc"],
hdrs = ["utils.h"],
+ defines = if_cuda(["GOOGLE_CUDA=1"]),
visibility = ["//visibility:public"],
deps = [
":op_performance_data_cc",
@@ -167,3 +168,29 @@ cc_library(
"//tensorflow/core/kernels:ops_util",
],
)
+
+cc_library(
+ name = "op_level_cost_estimator",
+ srcs = ["op_level_cost_estimator.cc"],
+ hdrs = ["op_level_cost_estimator.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cost_estimator",
+ ":op_performance_data_cc",
+ ":utils",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ ],
+)
+
+cc_test(
+ name = "op_level_cost_estimator_test",
+ srcs = ["op_level_cost_estimator_test.cc"],
+ deps = [
+ ":op_level_cost_estimator",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
new file mode 100644
index 0000000000..baed7a8899
--- /dev/null
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -0,0 +1,554 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/grappler/costs/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+constexpr int kOpsPerMac = 2;
+constexpr char kConv2d[] = "Conv2D";
+constexpr char kConv2dBackPropFilter[] = "Conv2DBackpropFilter";
+constexpr char kConv2dBackPropInput[] = "Conv2DBackpropInput";
+constexpr char kMatMul[] = "MatMul";
+constexpr char kSparseMatMul[] = "SparseMatMul";
+constexpr char kIdentity[] = "Identity";
+constexpr char kNoOp[] = "NoOp";
+constexpr char kReshape[] = "Reshape";
+
+OpLevelCostEstimator::OpLevelCostEstimator() {
+ // Syntactic sugar to build and return a lambda that takes an OpInfo and
+ // returns a cost.
+ typedef Costs (OpLevelCostEstimator::*CostImpl)(const OpInfo& op_feature)
+ const;
+ auto wrap = [this](CostImpl impl) -> std::function<Costs(const OpInfo&)> {
+ return [this, impl](const OpInfo& op) { return (this->*impl)(op); };
+ };
+
+ device_cost_impl_ = {
+ {kConv2d, wrap(&OpLevelCostEstimator::PredictConv2D)},
+ {kConv2dBackPropFilter,
+ wrap(&OpLevelCostEstimator::PredictConv2DBackPropFilter)},
+ {kConv2dBackPropInput,
+ wrap(&OpLevelCostEstimator::PredictConv2DBackPropInput)},
+ {kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
+ {kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
+ {kIdentity, wrap(&OpLevelCostEstimator::PredictNoOp)},
+ {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
+ {kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)}};
+}
+
+Costs OpLevelCostEstimator::PredictCosts(const OpInfo& op_features) const {
+ auto it = device_cost_impl_.find(op_features.op());
+ if (it == device_cost_impl_.end()) {
+ VLOG(1) << "Missing implementation for op: " << op_features.op();
+ Costs costs;
+ costs = DummyExecutionTime(op_features);
+ return costs;
+ }
+
+ std::function<Costs(const OpInfo&)> estimator = it->second;
+ Costs costs = estimator(op_features);
+ VLOG(1) << "Operation " << op_features.op() << " takes "
+ << costs.execution_time.count() << " ns.";
+ return costs;
+}
+
+std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
+ const OpInfo::DeviceProperties& device) const {
+ double gflops = -1;
+ double bandwidth = -1;
+ if (device.bandwidth() > 0) {
+ bandwidth = device.bandwidth() / 1e6;
+ }
+
+ if (device.type() == "CPU") {
+ const OpInfo::DeviceProperties local_cpu = GetLocalCPUInfo();
+ // Check if vector instructions are available, and refine performance
+ // prediction based on this.
+ gflops = local_cpu.num_cores() * local_cpu.frequency();
+ if (bandwidth < 0) {
+ if (local_cpu.bandwidth() > 0) {
+ bandwidth = local_cpu.bandwidth() / 1e6;
+ } else {
+ bandwidth = 32;
+ }
+ }
+ } else if (device.type() == "GPU") {
+ const OpInfo::DeviceProperties local_gpu = GetLocalGPUInfo(0);
+ const string architecture = local_gpu.environment().at("architecture");
+ int cores_per_multiprocessor;
+ if (architecture < "3") {
+ // Fermi
+ cores_per_multiprocessor = 32;
+ } else if (architecture < "4") {
+ // Kepler
+ cores_per_multiprocessor = 192;
+ } else if (architecture < "6") {
+ // Maxwell
+ cores_per_multiprocessor = 128;
+ } else {
+ // Pascal.
+ cores_per_multiprocessor = 64;
+ }
+ gflops = local_gpu.num_cores() * local_gpu.frequency() *
+ cores_per_multiprocessor * kOpsPerMac;
+ if (bandwidth < 0) {
+ CHECK(local_gpu.bandwidth() > 0);
+ bandwidth = local_gpu.bandwidth() / 1e6;
+ }
+ }
+
+ return std::make_pair(gflops, bandwidth);
+}
+
+Costs OpLevelCostEstimator::DummyExecutionTime(
+ const OpInfo& op_features) const {
+ Costs costs = PredictOpCountBasedCost(0, op_features);
+ costs.inaccurate = true;
+ return costs;
+}
+
+Costs OpLevelCostEstimator::PredictOpCountBasedCost(
+ double operations, const OpInfo& op_features) const {
+ std::pair<double, double> device_perf = GetDeviceInfo(op_features.device());
+ Costs::NanoSeconds compute_cost(operations / device_perf.first);
+ VLOG(1) << "Op:" << op_features.op() << " GOps:" << operations / 1e9
+ << " Execution Time (ns):" << compute_cost.count();
+
+ bool found_unknown_shapes = false;
+ double total_input_size =
+ CalculateInputSize(op_features, &found_unknown_shapes);
+ double total_output_size =
+ CalculateOutputSize(op_features, &found_unknown_shapes);
+ double total_io_size = total_input_size + total_output_size;
+
+ Costs::NanoSeconds memory_cost(total_io_size / device_perf.second);
+ VLOG(1) << "Op:" << op_features.op() << " Size (KB):" << (total_io_size) / 1e3
+ << " Memory Time (ns):" << memory_cost.count();
+
+ Costs costs;
+ costs.compute_time = compute_cost;
+ costs.memory_time = memory_cost;
+ costs.execution_time = compute_cost + memory_cost;
+ costs.inaccurate = found_unknown_shapes;
+ return costs;
+}
+
+int64 OpLevelCostEstimator::CountConv2DOperations(
+ const OpInfo& op_features, bool* found_unknown_shapes) const {
+ return CountConv2DOperations(op_features, nullptr, found_unknown_shapes);
+}
+
+namespace {
+
+string GetDataFormat(const OpInfo& op_features) {
+ string data_format = "NHWC"; // Default format.
+ if (op_features.attr().find("data_format") != op_features.attr().end()) {
+ data_format = op_features.attr().at("data_format").s();
+ }
+ return data_format;
+}
+
+Padding GetPadding(const OpInfo& op_features) {
+ if (op_features.attr().find("padding") != op_features.attr().end() &&
+ op_features.attr().at("padding").s() == "VALID") {
+ return Padding::VALID;
+ }
+ return Padding::SAME; // Default padding.
+}
+
+std::vector<int64> GetStrides(const OpInfo& op_features) {
+ if (op_features.attr().find("strides") != op_features.attr().end()) {
+ const auto strides = op_features.attr().at("strides").list().i();
+ return {strides[0], strides[1], strides[2], strides[3]};
+ }
+ return {1, 1, 1, 1};
+}
+
+int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride,
+ const Padding& padding) {
+ // Logic for calculating output shape is from GetWindowedOutputSizeVerbose()
+ // function in third_party/tensorflow/core/framework/common_shape_fns.cc.
+ if (padding == Padding::VALID) {
+ return (input - filter + stride) / stride;
+ } else { // SAME.
+ return (input + stride - 1) / stride;
+ }
+}
+
+// Return a minimum shape if the shape is unknown. If known, return the original
+// shape.
+TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
+ int rank, bool* found_unknown_shapes) {
+ auto shape = original_shape;
+ if (shape.unknown_rank()) {
+ *found_unknown_shapes = true;
+ }
+ if (shape.unknown_rank() || shape.dim_size() == 0) {
+ TensorShapeProto::Dim dim;
+ VLOG(1) << "WARNING: Use minimum shape because the shape is unknown.";
+ // The size of each dimension is at least 1, if unknown.
+ dim.set_size(1);
+ for (int i = 0; i < rank; i++) {
+ *shape.add_dim() = dim;
+ }
+ } else {
+ CHECK_EQ(shape.dim_size(), rank);
+ for (int i = 0; i < rank; i++) {
+ if (shape.dim(i).size() == -1) {
+ *found_unknown_shapes = true;
+ VLOG(1)
+ << "WARNING: Use minimum dim size 1 because the shape is unknown.";
+ // The size of each dimension is at least 1, if unknown.
+ shape.mutable_dim(i)->set_size(1);
+ }
+ }
+ }
+ return shape;
+}
+} // namespace
+
+// Helper to translate the positional arguments into named fields.
+OpLevelCostEstimator::ConvolutionDimensions
+OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
+ const TensorShapeProto& original_image_shape,
+ const TensorShapeProto& original_filter_shape, const OpInfo& op_features,
+ bool* found_unknown_shapes) {
+ auto image_shape =
+ MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes);
+ auto filter_shape =
+ MaybeGetMinimumShape(original_filter_shape, 4, found_unknown_shapes);
+
+ int x_index, y_index, channel_index;
+ const string& data_format = GetDataFormat(op_features);
+ if (data_format == "NCHW") {
+ x_index = 2;
+ y_index = 3;
+ channel_index = 1;
+ } else {
+ x_index = 1;
+ y_index = 2;
+ channel_index = 3;
+ }
+ int64 batch = image_shape.dim(0).size();
+ int64 ix = image_shape.dim(x_index).size();
+ int64 iy = image_shape.dim(y_index).size();
+ int64 iz = image_shape.dim(channel_index).size();
+ int64 kx = filter_shape.dim(0).size();
+ int64 ky = filter_shape.dim(1).size();
+ std::vector<int64> strides = GetStrides(op_features);
+ const auto padding = GetPadding(op_features);
+ int64 sx = strides[x_index];
+ int64 sy = strides[y_index];
+ int64 ox = GetOutputSize(ix, kx, sx, padding);
+ int64 oy = GetOutputSize(iy, ky, sy, padding);
+ int64 oz = filter_shape.dim(3).size();
+ // Only check equality when both sizes are known (in other words, when
+ // neither is set to a minimum dimension size of 1).
+ if (iz != 1 && filter_shape.dim(2).size() != 1) {
+ CHECK_EQ(iz, filter_shape.dim(2).size());
+ } else {
+ iz = std::max<int64>(iz, filter_shape.dim(2).size());
+ }
+ OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
+ batch, ix, iy, iz, kx, ky, oz, ox, oy, sx, sy, padding};
+
+ VLOG(1) << "Batch Size:" << batch;
+ VLOG(1) << "Image Dims:" << ix << "," << iy;
+ VLOG(1) << "Input Features:" << iz;
+ VLOG(1) << "Kernel Dims:" << kx << "," << ky;
+ VLOG(1) << "Output Features:" << oz;
+ VLOG(1) << "Output Dims:" << ox << "," << oy;
+ VLOG(1) << "Strides:" << sx << "," << sy;
+ VLOG(1) << "Padding:" << (padding == Padding::VALID ? "VALID" : "SAME");
+ return conv_dims;
+}
+
+int64 OpLevelCostEstimator::CountConv2DOperations(
+ const OpInfo& op_features, ConvolutionDimensions* conv_info,
+ bool* found_unknown_shapes) const {
+ if (op_features.op() != kConv2d) {
+ LOG(ERROR) << "Invalid Operation";
+ return 0;
+ }
+ ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
+ op_features.inputs(0).shape(), op_features.inputs(1).shape(), op_features,
+ found_unknown_shapes);
+
+ int64 ops = conv_dims.batch;
+ ops *= conv_dims.ox * conv_dims.oy;
+ ops *= conv_dims.kx * conv_dims.ky;
+ ops *= conv_dims.iz * conv_dims.oz;
+ ops *= kOpsPerMac;
+ VLOG(1) << "Operations for Conv2D" << ops;
+
+ if (conv_info != nullptr) {
+ *conv_info = conv_dims;
+ }
+ return ops;
+}
+
+int64 OpLevelCostEstimator::CountMatMulOperations(
+ const OpInfo& op_features, bool* found_unknown_shapes) const {
+ return CountMatMulOperations(op_features, nullptr, found_unknown_shapes);
+}
+
+int64 OpLevelCostEstimator::CountMatMulOperations(
+ const OpInfo& op_features, MatMulDimensions* mat_mul,
+ bool* found_unknown_shapes) const {
+ double ops = 0;
+
+ // TODO(nishantpatil): Create separate estimator for Sparse Matmul
+ if ((op_features.op() != kMatMul) && (op_features.op() != kSparseMatMul)) {
+ LOG(ERROR) << "Invalid Operation";
+ return ops;
+ }
+
+ // first matrix
+ auto& a_matrix = op_features.inputs(0);
+ auto& b_matrix = op_features.inputs(1);
+
+ bool transpose_a = false;
+ bool transpose_b = false;
+
+ double m_dim, n_dim, k_dim, k_dim_b = 0;
+
+ for (const auto& item : op_features.attr()) {
+ VLOG(1) << "Key:" << item.first
+ << " Value:" << SummarizeAttrValue(item.second);
+ if (item.first == "transpose_a" && item.second.b() == true)
+ transpose_a = true;
+ if (item.first == "transpose_b" && item.second.b() == true)
+ transpose_b = true;
+ }
+ VLOG(1) << "transpose_a:" << transpose_a;
+ VLOG(1) << "transpose_b:" << transpose_b;
+ auto a_matrix_shape =
+ MaybeGetMinimumShape(a_matrix.shape(), 2, found_unknown_shapes);
+ auto b_matrix_shape =
+ MaybeGetMinimumShape(b_matrix.shape(), 2, found_unknown_shapes);
+ if (transpose_a) {
+ m_dim = a_matrix_shape.dim(1).size();
+ k_dim = a_matrix_shape.dim(0).size();
+ } else {
+ m_dim = a_matrix_shape.dim(0).size();
+ k_dim = a_matrix_shape.dim(1).size();
+ }
+ if (transpose_b) {
+ k_dim_b = b_matrix_shape.dim(1).size();
+ n_dim = b_matrix_shape.dim(0).size();
+ } else {
+ k_dim_b = b_matrix_shape.dim(0).size();
+ n_dim = b_matrix_shape.dim(1).size();
+ }
+
+ VLOG(1) << "M, N, K: " << m_dim << "," << n_dim << "," << k_dim;
+ // Only check equality when both sizes are known (in other words, when
+ // neither is set to a minimum dimension size of 1).
+ if (k_dim_b != 1 && k_dim != 1 && k_dim_b != k_dim) {
+ LOG(ERROR) << "Incompatible Matrix dimensions";
+ return ops;
+ } else {
+ // One of k_dim and k_dim_b might be 1 (mininum dimension size).
+ k_dim = std::max(k_dim, k_dim_b);
+ }
+
+ ops = m_dim * n_dim * k_dim * 2;
+ VLOG(1) << "Operations for Matmul" << ops;
+
+ if (mat_mul != nullptr) {
+ mat_mul->m = m_dim;
+ mat_mul->n = n_dim;
+ mat_mul->k = k_dim;
+ }
+ return ops;
+}
+
+// TODO(cliffy): Dedup this method and CountConv2DBackPropFilterOperations.
+int64 OpLevelCostEstimator::CountConv2DBackPropInputOperations(
+ const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims,
+ bool* found_unknown_shapes) const {
+ int64 ops = 0;
+
+ if (op_features.op() != kConv2dBackPropInput) {
+ LOG(ERROR) << "Invalid Operation";
+ return ops;
+ }
+
+ if (op_features.attr().find("_output_shapes") == op_features.attr().end()) {
+ // Need _output_shapes for input shape.
+ LOG(ERROR) << "No output shape in Conv2DBackPropInput op feaure.";
+ return ops;
+ }
+
+ const auto& input_shape =
+ op_features.attr().at("_output_shapes").list().shape(0);
+ ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
+ input_shape, op_features.inputs(1).shape(), op_features,
+ found_unknown_shapes);
+
+ ops = conv_dims.batch;
+ ops *= conv_dims.ox * conv_dims.oy;
+ ops *= conv_dims.kx * conv_dims.ky;
+ ops *= conv_dims.iz * conv_dims.oz;
+ ops *= kOpsPerMac;
+
+ VLOG(1) << "Operations for Conv2DBackPropInput" << ops;
+
+ if (returned_conv_dims != nullptr) {
+ *returned_conv_dims = conv_dims;
+ }
+ return ops;
+}
+
+int64 OpLevelCostEstimator::CountConv2DBackPropFilterOperations(
+ const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims,
+ bool* found_unknown_shapes) const {
+ int64 ops = 0;
+ if (op_features.op() != kConv2dBackPropFilter) {
+ LOG(ERROR) << "Invalid Operation";
+ return ops;
+ }
+
+ if (op_features.attr().find("_output_shapes") == op_features.attr().end()) {
+ // Need _output_shapes for filter shape.
+ LOG(ERROR) << "No output shape in Conv2DBackPropFilter op feaure.";
+ return ops;
+ }
+
+ const auto& filter_shape =
+ op_features.attr().at("_output_shapes").list().shape(0);
+ ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
+ op_features.inputs(0).shape(), filter_shape, op_features,
+ found_unknown_shapes);
+
+ ops = conv_dims.batch;
+ ops *= conv_dims.ox * conv_dims.oy;
+ ops *= conv_dims.kx * conv_dims.ky;
+ ops *= conv_dims.iz * conv_dims.oz;
+ ops *= kOpsPerMac;
+
+ VLOG(1) << "Operations for Conv2DBackPropFilter" << ops;
+
+ if (returned_conv_dims != nullptr) {
+ *returned_conv_dims = conv_dims;
+ }
+ return ops;
+}
+
+int64 OpLevelCostEstimator::CalculateSingleInputSize(
+ const OpInfo::TensorProperties& input, bool* found_unknown_shapes) const {
+ VLOG(1) << " with " << input.dtype() << " input of shape "
+ << input.shape().DebugString();
+ int64 input_size = 1;
+ int num_dims = std::max(1, input.shape().dim_size());
+ auto input_shape =
+ MaybeGetMinimumShape(input.shape(), num_dims, found_unknown_shapes);
+ for (const auto& dim : input_shape.dim()) {
+ input_size *= dim.size();
+ }
+ return input_size * DataTypeSize(input.dtype());
+}
+
+int64 OpLevelCostEstimator::CalculateInputSize(
+ const OpInfo& op_features, bool* found_unknown_shapes) const {
+ int64 total_input_size = 0;
+ for (auto& input : op_features.inputs()) {
+ int64 input_size = CalculateSingleInputSize(input, found_unknown_shapes);
+ total_input_size += input_size;
+ VLOG(1) << "Input Size: " << input_size
+ << " Total Input Size:" << total_input_size;
+ }
+ return total_input_size;
+}
+
+int64 OpLevelCostEstimator::CalculateOutputSize(
+ const OpInfo& op_features, bool* found_unknown_shapes) const {
+ int64 total_output_size = 0;
+ // use float as default for calculations
+ DataType dt = DT_FLOAT;
+ for (const auto& item : op_features.attr()) {
+ VLOG(1) << "Key:" << item.first
+ << " Value:" << SummarizeAttrValue(item.second);
+ if (item.first == "_output_shapes") {
+ for (const auto& original_output_shape : item.second.list().shape()) {
+ int64 output_size = 1;
+ int num_dims = std::max(1, original_output_shape.dim_size());
+ auto output_shape = MaybeGetMinimumShape(
+ original_output_shape, num_dims, found_unknown_shapes);
+ for (const auto& dim : output_shape.dim()) {
+ output_size *= dim.size();
+ }
+ output_size *= DataTypeSize(dt);
+ total_output_size += output_size;
+ VLOG(1) << "Output Size: " << output_size
+ << " Total Output Size:" << total_output_size;
+ }
+ }
+ if (item.first == "T") {
+ dt = item.second.type();
+ }
+ }
+ return total_output_size;
+}
+
+Costs OpLevelCostEstimator::PredictConv2D(const OpInfo& op_features) const {
+ bool found_unknown_shapes = false;
+ auto costs = PredictOpCountBasedCost(
+ CountConv2DOperations(op_features, &found_unknown_shapes), op_features);
+ costs.inaccurate = found_unknown_shapes;
+ return costs;
+}
+
+Costs OpLevelCostEstimator::PredictConv2DBackPropInput(
+ const OpInfo& op_features) const {
+ bool found_unknown_shapes = false;
+ auto costs =
+ PredictOpCountBasedCost(CountConv2DBackPropInputOperations(
+ op_features, nullptr, &found_unknown_shapes),
+ op_features);
+ costs.inaccurate = found_unknown_shapes;
+ return costs;
+}
+
+Costs OpLevelCostEstimator::PredictConv2DBackPropFilter(
+ const OpInfo& op_features) const {
+ bool found_unknown_shapes = false;
+ auto costs =
+ PredictOpCountBasedCost(CountConv2DBackPropFilterOperations(
+ op_features, nullptr, &found_unknown_shapes),
+ op_features);
+ costs.inaccurate = found_unknown_shapes;
+ return costs;
+}
+
+Costs OpLevelCostEstimator::PredictMatMul(const OpInfo& op_features) const {
+ bool found_unknown_shapes = false;
+ auto costs = PredictOpCountBasedCost(
+ CountMatMulOperations(op_features, &found_unknown_shapes), op_features);
+ costs.inaccurate = found_unknown_shapes;
+ return costs;
+}
+
+Costs OpLevelCostEstimator::PredictNoOp(const OpInfo& op_features) const {
+ VLOG(1) << "Op:" << op_features.op() << " Execution Time 0 (ns)";
+ return Costs::ZeroCosts();
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
new file mode 100644
index 0000000000..5bb20cc6bb
--- /dev/null
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -0,0 +1,143 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_
+#define TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_
+
+#include <functional>
+#include <map>
+#include <string>
+
+#include "tensorflow/core/graph/types.h"
+#include "tensorflow/core/grappler/costs/cost_estimator.h"
+#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
+#include "tensorflow/core/util/padding.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class OpLevelCostEstimator {
+ public:
+ OpLevelCostEstimator();
+ virtual ~OpLevelCostEstimator() {}
+
+ Costs PredictCosts(const OpInfo& op_features) const;
+
+ protected:
+ // Returns an estimate of device performance (in billions of operations
+ // executed per second) and memory bandwith (in GigaBytes/second) for the
+ // specified device.
+ virtual std::pair<double, double> GetDeviceInfo(
+ const OpInfo::DeviceProperties& device) const;
+
+ // For operations for which we haven't yet built estimates, returns a dummy
+ // value based on input size.
+ Costs DummyExecutionTime(const OpInfo& op_features) const;
+
+ // Naive cost estimate based on operations divided by device ops/sec.
+ Costs PredictOpCountBasedCost(double operations,
+ const OpInfo& op_features) const;
+
+ // This family of routines counts the number of operations to perform the
+ // specified TensorFlow Op.
+ struct MatMulDimensions {
+ int m;
+ int n;
+ int k;
+ };
+ struct ConvolutionDimensions {
+ int64 batch; // Batch size.
+ int64 ix; // Input size x.
+ int64 iy; // Input size y.
+ int64 iz; // Input depth.
+ int64 kx; // Kernel x.
+ int64 ky; // Kernel y.
+ int64 oz; // Output depth.
+ int64 ox; // Output size x.
+ int64 oy; // Output size y.
+ int64 sx; // Stride x.
+ int64 sy; // Stride y.
+ Padding padding; // SAME or VALID.
+ };
+ int64 CountConv2DOperations(const OpInfo& op_features,
+ bool* found_unknown_shapes) const;
+ int64 CountConv2DOperations(const OpInfo& op_features,
+ ConvolutionDimensions* conv_info,
+ bool* found_unknown_shapes) const;
+ int64 CountMatMulOperations(const OpInfo& op_features,
+ bool* found_unknown_shapes) const;
+ int64 CountMatMulOperations(const OpInfo& op_features,
+ MatMulDimensions* mat_mul,
+ bool* found_unknown_shapes) const;
+ int64 CountConv2DBackPropInputOperations(const OpInfo& op_features,
+ ConvolutionDimensions* conv_info,
+ bool* found_unknown_shapes) const;
+ int64 CountConv2DBackPropFilterOperations(const OpInfo& op_features,
+ ConvolutionDimensions* conv_info,
+ bool* found_unknown_shapes) const;
+
+ // Calculate the total size in bytes of a single input to a TensorFlow op.
+ int64 CalculateSingleInputSize(const OpInfo::TensorProperties& input,
+ bool* found_unknown_shapes) const;
+
+ // Calculate the total size in bytes of the all
+ // the inputs of specified TensorFlow Op
+ int64 CalculateInputSize(const OpInfo& op_features,
+ bool* found_unknown_shapes) const;
+
+ // Calculate the total size in bytes of the all
+ // the outputs of specified TensorFlow Op
+ int64 CalculateOutputSize(const OpInfo& op_features,
+ bool* found_unknown_shapes) const;
+
+ // This family of routines predicts the costs to
+ // perform the specified TensorFlow Op on the
+ // device represented by a subclass. The default
+ // implementation just divides the operations to
+ // perform the op (from the "Count" routines,
+ // above) by the device peak operations per
+ // second. Override to supply a better estimate.
+ // Implementation of costs other than
+ // execution_time is optional, depending on the
+ // device.
+ Costs PredictConv2D(const OpInfo& op_features) const;
+ Costs PredictConv2DBackPropInput(const OpInfo& op_features) const;
+ Costs PredictConv2DBackPropFilter(const OpInfo& op_features) const;
+ Costs PredictMatMul(const OpInfo& op_features) const;
+ Costs PredictNoOp(const OpInfo& op_features) const;
+
+ // Utility function for safe division. Returns 0
+ // if rhs is 0 or negative.
+ static double SafeDiv(const double lhs, const double rhs) {
+ if (rhs > 0) {
+ return lhs / rhs;
+ } else {
+ return 0.0;
+ }
+ }
+
+ static ConvolutionDimensions ConvolutionDimensionsFromInputs(
+ const TensorShapeProto& original_image_shape,
+ const TensorShapeProto& original_filter_shape, const OpInfo& op_features,
+ bool* found_unknown_shapes);
+
+ private:
+ typedef std::function<Costs(const OpInfo& op_feature)> CostImpl;
+ std::map<string, CostImpl> device_cost_impl_;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
new file mode 100644
index 0000000000..e0b0348c8e
--- /dev/null
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -0,0 +1,113 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+
+namespace {
+// Wrangles the minimum number of proto fields to set up a matrix.
+void DescribeMatrix(int rows, int columns, OpInfo *op_features) {
+ auto input = op_features->add_inputs();
+ auto shape = input->mutable_shape();
+ auto shape_rows = shape->add_dim();
+ shape_rows->set_size(rows);
+ auto shape_columns = shape->add_dim();
+ shape_columns->set_size(columns);
+ input->set_dtype(DT_FLOAT);
+}
+
+// Returns an OpInfo for MatMul with the minimum set of fields set up.
+OpInfo DescribeMatMul(int m, int n, int l, int k) {
+ OpInfo op_features;
+ auto device = op_features.mutable_device();
+ device->set_type("CPU");
+ op_features.set_op("MatMul");
+
+ DescribeMatrix(m, l, &op_features);
+ DescribeMatrix(k, n, &op_features);
+ return op_features;
+}
+
+// Returns an OpInfo for MatMul with unknown input shapes.
+OpInfo DescribeMatMulUnknownShape() {
+ OpInfo op_features;
+ auto device = op_features.mutable_device();
+ device->set_type("CPU");
+ op_features.set_op("MatMul");
+
+ auto input = op_features.add_inputs();
+ auto shape = input->mutable_shape();
+ shape->set_unknown_rank(true);
+
+ input = op_features.add_inputs();
+ shape = input->mutable_shape();
+ shape->set_unknown_rank(true);
+
+ return op_features;
+}
+
+// Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
+// estimation purposes.
+void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3,
+ OpInfo *op_features) {
+ auto input = op_features->add_inputs();
+ auto shape = input->mutable_shape();
+ shape->add_dim()->set_size(dim0);
+ shape->add_dim()->set_size(dim1);
+ shape->add_dim()->set_size(dim2);
+ shape->add_dim()->set_size(dim3);
+}
+
+// Returns an OpInfo for Conv2D with the minimum set of fields set up.
+OpInfo DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2, int kx,
+ int ky, int oz) {
+ OpInfo op_features;
+ auto device = op_features.mutable_device();
+ device->set_type("CPU");
+ op_features.set_op("Conv2D");
+
+ DescribeTensor4D(batch, ix, iy, iz1, &op_features);
+ DescribeTensor4D(kx, ky, iz2, oz, &op_features);
+ return op_features;
+}
+} // namespace
+
+TEST(OpLevelCostEstimatorTest, UnknownOrPartialShape) {
+ OpLevelCostEstimator estimator;
+
+ EXPECT_EQ(false,
+ estimator.PredictCosts(DescribeMatMul(2, 4, 7, 7)).inaccurate);
+ EXPECT_EQ(true,
+ estimator.PredictCosts(DescribeMatMul(-1, 4, 7, 7)).inaccurate);
+ EXPECT_EQ(true,
+ estimator.PredictCosts(DescribeMatMul(2, 4, -1, 7)).inaccurate);
+
+ EXPECT_EQ(
+ false,
+ estimator.PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256))
+ .inaccurate);
+ EXPECT_EQ(
+ true,
+ estimator.PredictCosts(DescribeConvolution(16, -1, 19, 48, 48, 5, 5, 256))
+ .inaccurate);
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index 4e35de9d4a..0852cb4fd3 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -147,7 +147,7 @@ OpInfo::DeviceProperties GetLocalCPUInfo() {
// Combine cpu family and model into the model string.
device.set_model(
strings::StrCat((port::CPUFamily() << 4) + port::CPUModelNum()));
- device.set_frequency(port::NominalCPUFrequency());
+ device.set_frequency(port::NominalCPUFrequency() * 1e-9);
device.set_num_cores(port::NumSchedulableCPUs());
device.set_l1_cache_size(Eigen::l1CacheSize());
device.set_l2_cache_size(Eigen::l2CacheSize());
@@ -195,6 +195,8 @@ OpInfo::DeviceProperties GetLocalGPUInfo(int gpu_id) {
properties.memoryClockRate * 2);
}
+ (*device.mutable_environment())["architecture"] =
+ strings::StrCat(properties.major, ".", properties.minor);
(*device.mutable_environment())["cuda"] = strings::StrCat(CUDA_VERSION);
(*device.mutable_environment())["cudnn"] = strings::StrCat(CUDNN_VERSION);
#endif
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h
index b7785c94e0..5d437dff50 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.h
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.h
@@ -26,7 +26,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace {
struct NodeState {
std::vector<const NodeDef*> inputs;
std::vector<const NodeDef*> outputs;
@@ -86,7 +85,6 @@ class FIFOManager : public ReadyNodeManager {
private:
std::list<const NodeDef*> nodes_;
};
-} // namespace
// The virtual scheduler emulates execution of nodes in a graph, considering
// dependencies, device, etc.
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index bafbcc200c..64bdd91077 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -18,6 +18,11 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
+bool IsConcat(const NodeDef& node) {
+ const auto op = node.op();
+ return op == "Concat" || op == "ConcatV2";
+}
+
bool IsDequeueOp(const NodeDef& node) {
static const std::set<std::string> dequeue_ops = {
"QueueDequeueManyV2", "QueueDequeueMany", "QueueDequeueV2",
@@ -30,6 +35,11 @@ bool IsPlaceholder(const NodeDef& node) {
return op == "Placeholder" || op == "PlaceholderV2";
}
+bool IsTranspose(const NodeDef& node) {
+ const auto op = node.op();
+ return op == "Transpose";
+}
+
bool IsVariable(const NodeDef& node) {
const auto op = node.op();
return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 2f58835628..4f2bb2bc05 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -21,8 +21,10 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
+bool IsConcat(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node);
bool IsPlaceholder(const NodeDef& node);
+bool IsTranspose(const NodeDef& node);
bool IsVariable(const NodeDef& node);
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index e3b36c8412..5f30dfbaa2 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -205,11 +205,28 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:devices",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
],
)
+cc_test(
+ name = "layout_optimizer_test",
+ srcs = ["layout_optimizer_test.cc"],
+ deps = [
+ ":layout_optimizer",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+ ],
+)
+
cc_library(
name = "meta_optimizer",
srcs = ["meta_optimizer.cc"],
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 9570ec17d0..e37c4a5b36 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/strings/numbers.h"
@@ -68,8 +69,7 @@ std::set<string> GetOpsFormatAgnostic() {
"Slice",
"SquaredDifference",
"Squeeze",
- "Sub",
- "Sum"};
+ "Sub"};
return ops_format_agnostic;
}
@@ -110,9 +110,9 @@ class NodeProcessor {
}
protected:
- bool IsDimsN(NodeDef* node, int n) const {
- if (node->attr().find("_output_shapes") != node->attr().end()) {
- auto shape = node->attr().at("_output_shapes").list().shape(0);
+ bool IsDimsN(const NodeDef& node, int n) const {
+ if (node.attr().find("_output_shapes") != node.attr().end()) {
+ auto shape = node.attr().at("_output_shapes").list().shape(0);
if (shape.dim_size() == n) {
return true;
}
@@ -120,7 +120,7 @@ class NodeProcessor {
return false;
}
- bool IsDimsFour(NodeDef* node) const { return IsDimsN(node, 4); }
+ bool IsDimsFour(const NodeDef& node) const { return IsDimsN(node, 4); }
bool IsNHWC() const {
if (node_->attr().find("data_format") != node_->attr().end()) {
@@ -145,7 +145,7 @@ class NodeProcessor {
}
virtual bool ShouldProcess() const {
- return IsNHWC() && IsDimsFour(node_) && HasOutputs();
+ return IsNHWC() && IsDimsFour(*node_) && HasOutputs();
}
void UpdateAttrDataFormat() {
@@ -268,6 +268,8 @@ class NodeProcessor {
for (const auto& output : outputs) {
string node_name_NCHWToNHWC = strings::StrCat(
kTransposeNCHWToNHWC, "-", node_->name(), "-", output->name());
+ // TODO (yaozhang): handle the rare case where node A is connected to more
+ // than one input of node B.
auto it = std::find_if(output->mutable_input()->begin(),
output->mutable_input()->end(),
[this](const string& input) {
@@ -341,7 +343,7 @@ class BiasAddGradProcessor : public NodeProcessor {
bool ShouldProcess() const override {
auto input = node_map_->GetNode(node_->input(0));
if (input) {
- if ((IsNHWC() && IsDimsFour(input)) || IsNodeNCHWToNHWC(input->name())) {
+ if ((IsNHWC() && IsDimsFour(*input)) || IsNodeNCHWToNHWC(input->name())) {
return true;
}
}
@@ -351,13 +353,89 @@ class BiasAddGradProcessor : public NodeProcessor {
Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
};
-class Conv2DBackpropFilterProcessor : public NodeProcessor {
+class Conv2DProcessor : public NodeProcessor {
+ public:
+ Conv2DProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ bool no_gemm)
+ : NodeProcessor(graph, node, node_map), no_gemm_(no_gemm) {}
+
+ protected:
+ bool ShouldProcess() const override {
+ return IsNHWC() && IsDimsFour(*node_) && HasOutputs() &&
+ (!IsGemmUsed() || no_gemm_);
+ }
+
+ TensorShapeProto GetShape(const string& input_name) const {
+ string node_name;
+ int output_pos;
+ node_name = ParseNodeName(input_name, &output_pos);
+ NodeDef* node = node_map_->GetNode(node_name);
+ if (node->attr().find("_output_shapes") != node->attr().end()) {
+ return node->attr().at("_output_shapes").list().shape(output_pos);
+ }
+ TensorShapeProto shape;
+ return shape;
+ }
+
+ bool IsStrideOne() const {
+ if (node_->attr().find("strides") != node_->attr().end()) {
+ auto list = node_->attr().at("strides").list();
+ return list.i(1) == 1 && list.i(2) == 1;
+ }
+ return false;
+ }
+
+ bool IsValidPadding() const {
+ if (node_->attr().find("padding") != node_->attr().end()) {
+ auto padding = node_->attr().at("padding").s();
+ return padding == "VALID";
+ }
+ return false;
+ }
+
+ // The logic inside this function is based on the internal implementation of
+ // Conv2D, Conv2DBackpropInput, and Conv2DBackpropFilter ops, and thus
+ // needs to be updated accordingly if the internal implementation changes.
+ bool IsGemmUsed(const TensorShapeProto& filter_shape,
+ const TensorShapeProto& input_shape) const {
+ if (filter_shape.dim_size() == 4) {
+ if (filter_shape.dim(0).size() == 1 && filter_shape.dim(1).size() == 1 &&
+ IsStrideOne()) {
+ return true;
+ }
+ }
+ if (input_shape.dim_size() == 4 && filter_shape.dim_size() == 4) {
+ if (input_shape.dim(1).size() == filter_shape.dim(0).size() &&
+ input_shape.dim(2).size() == filter_shape.dim(1).size() &&
+ IsValidPadding()) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ virtual bool IsGemmUsed() const {
+ auto filter_shape = GetShape(node_->input(1));
+ auto input_shape = GetShape(node_->input(0));
+ return IsGemmUsed(filter_shape, input_shape);
+ }
+
+ bool no_gemm_;
+};
+
+class Conv2DBackpropFilterProcessor : public Conv2DProcessor {
public:
Conv2DBackpropFilterProcessor(GraphDef* graph, NodeDef* node,
- NodeMap* node_map)
- : NodeProcessor(graph, node, node_map) {}
+ NodeMap* node_map, bool no_gemm)
+ : Conv2DProcessor(graph, node, node_map, no_gemm) {}
protected:
+ bool IsGemmUsed() const override {
+ auto filter_shape = GetShape(node_->name());
+ auto input_shape = GetShape(node_->input(0));
+ return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape);
+ }
+
std::vector<int> GetInputPos() const override {
std::vector<int> input_pos = {0, 2};
return input_pos;
@@ -370,17 +448,24 @@ class Conv2DBackpropFilterProcessor : public NodeProcessor {
void UpdateAttrShape() override {}
};
-class Conv2DBackpropInputProcessor : public NodeProcessor {
+class Conv2DBackpropInputProcessor : public Conv2DProcessor {
public:
Conv2DBackpropInputProcessor(GraphDef* graph, NodeDef* node,
- NodeMap* node_map)
- : NodeProcessor(graph, node, node_map) {}
+ NodeMap* node_map, bool no_gemm)
+ : Conv2DProcessor(graph, node, node_map, no_gemm) {}
protected:
+ bool IsGemmUsed() const override {
+ auto filter_shape = GetShape(node_->input(1));
+ auto input_shape = GetShape(node_->name());
+ return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape);
+ }
+
std::vector<int> GetInputPos() const override {
std::vector<int> input_pos = {2};
return input_pos;
}
+
Status CustomizedProcessing() override {
NodeDef* node = node_map_->GetNode(node_->input(0));
return UpdateAttrValue(node);
@@ -418,7 +503,7 @@ class AgnosticNodeProcessor : public NodeProcessor {
protected:
bool ShouldProcess() const override {
- return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC();
+ return IsDimsFour(*node_) && HasOutputs() && IsNodeAfterNCHWToNHWC();
}
bool IsNodeAfterNCHWToNHWC() const {
@@ -467,7 +552,7 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
protected:
bool ShouldProcess() const override {
- return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
+ return IsDimsFour(*node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
(Is4DOperateWithND(4) || Is4DOperateWithScalar() ||
Is4DOperateWithVector());
}
@@ -484,10 +569,10 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
auto input0 = node_map_->GetNode(node_->input(0));
auto input1 = node_map_->GetNode(node_->input(1));
if (input0 && input1) {
- return (IsDimsFour(input0) || IsNodeNCHWToNHWC(input0->name())) &&
+ return (IsDimsFour(*input0) || IsNodeNCHWToNHWC(input0->name())) &&
((n == 4)
- ? (IsDimsFour(input1) || IsNodeNCHWToNHWC(input1->name()))
- : IsDimsN(input1, n));
+ ? (IsDimsFour(*input1) || IsNodeNCHWToNHWC(input1->name()))
+ : IsDimsN(*input1, n));
}
return false;
}
@@ -571,7 +656,7 @@ class ConcatProcessor : public AgnosticNodeProcessor {
protected:
bool ShouldProcess() const override {
- return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
+ return IsDimsFour(*node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
IsAlongDimC();
}
@@ -739,7 +824,7 @@ class SqueezeProcessor : public AgnosticNodeProcessor {
protected:
bool ShouldProcess() const override {
- return IsDimsN(node_, 2) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
+ return IsDimsN(*node_, 2) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
IsInputConvertible() && IsAlongDimHW();
}
@@ -790,7 +875,7 @@ class SumProcessor : public AgnosticNodeProcessor {
bool ShouldProcess() const override {
auto input0 = node_map_->GetNode(node_->input(0));
return HasOutputs() && IsNodeAfterNCHWToNHWC() &&
- (IsDimsFour(input0) || IsNodeNCHWToNHWC(input0->name())) &&
+ (IsDimsFour(*input0) || IsNodeNCHWToNHWC(input0->name())) &&
IsAlongDimNHW();
}
@@ -825,10 +910,21 @@ class SumProcessor : public AgnosticNodeProcessor {
}
};
+struct TuningConfig {
+ // If true, do not use the NHWC GEMM implementation. When filter size is
+ // one or filter size is equal to input image size,
+ // the NHWC implementation of Conv2D, Conv2DBackpropInput, and
+ // Conv2DBackpropFilter will use a specialized GEMM implementation, which is
+ // usually faster than the NCHW implementation. The downside is that this
+ // might result in more non-cancellable layout conversion nodes (implemented
+ // by the Tranpose op).
+ bool no_gemm;
+};
+
class DataLayoutOptimizer {
public:
- explicit DataLayoutOptimizer(GraphDef* graph)
- : graph_(graph), node_map_(graph_) {}
+ explicit DataLayoutOptimizer(GraphDef* graph, TuningConfig config)
+ : graph_(graph), node_map_(graph_), config_(config) {}
Status Optimize() {
LOG(INFO) << "Number of nodes for original graph: " << graph_->node_size();
@@ -908,12 +1004,15 @@ class DataLayoutOptimizer {
} else if (node->op().compare("BiasAddGrad") == 0) {
node_processor.reset(
new BiasAddGradProcessor(graph_, node, &node_map_));
- } else if (node->op().compare("Conv2DBackpropFilter") == 0) {
+ } else if (node->op().compare("Conv2D") == 0) {
node_processor.reset(
- new Conv2DBackpropFilterProcessor(graph_, node, &node_map_));
+ new Conv2DProcessor(graph_, node, &node_map_, config_.no_gemm));
+ } else if (node->op().compare("Conv2DBackpropFilter") == 0) {
+ node_processor.reset(new Conv2DBackpropFilterProcessor(
+ graph_, node, &node_map_, config_.no_gemm));
} else if (node->op().compare("Conv2DBackpropInput") == 0) {
- node_processor.reset(
- new Conv2DBackpropInputProcessor(graph_, node, &node_map_));
+ node_processor.reset(new Conv2DBackpropInputProcessor(
+ graph_, node, &node_map_, config_.no_gemm));
} else if (node->op().compare("FusedBatchNormGrad") == 0) {
node_processor.reset(
new FusedBatchNormGradProcessor(graph_, node, &node_map_));
@@ -1025,17 +1124,46 @@ class DataLayoutOptimizer {
GraphDef* graph_;
NodeMap node_map_;
+ TuningConfig config_;
};
+int GetNumTranspose(const GraphDef& graph) {
+ int number = 0;
+ for (const auto& node : graph.node()) {
+ if (IsTranspose(node)) {
+ number++;
+ }
+ }
+ LOG(INFO) << "Number of Transpose nodes: " << number;
+ return number;
+}
+
Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
- if (GetNumAvailableGPUs() < 1) {
+ if (num_gpus_ == 0) {
+ num_gpus_ = GetNumAvailableGPUs();
+ }
+ if (num_gpus_ < 1) {
// LayoutOptimizer is currently only tuned for GPU.
return Status::OK();
}
+
*output = item.graph;
- DataLayoutOptimizer layout_optimizer(output);
+ TuningConfig config;
+ config.no_gemm = false;
+ DataLayoutOptimizer layout_optimizer(output, config);
auto status = layout_optimizer.Optimize();
+
+ // This is based on an empirical observation that if the introduced Transpose
+ // nodes is more than 30, not using GEMM implementation would result in better
+ // performance.
+ if (status.ok() && GetNumTranspose(*output) > 30) {
+ *output = item.graph;
+ config.no_gemm = true;
+ DataLayoutOptimizer layout_optimizer(output, config);
+ status = layout_optimizer.Optimize();
+ }
+
if (!status.ok()) {
*output = item.graph;
}
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.h b/tensorflow/core/grappler/optimizers/layout_optimizer.h
index 66dec17a35..1bd6f9544b 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.h
@@ -29,11 +29,17 @@ class LayoutOptimizer : public GraphOptimizer {
string name() const override { return "layout"; };
+ // This is for testing only.
+ void set_num_gpus(int num_gpus) { num_gpus_ = num_gpus; };
+
Status Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) override;
void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output, double result) override;
+
+ private:
+ int num_gpus_ = 0;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
new file mode 100644
index 0000000000..be38ca1a69
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
@@ -0,0 +1,147 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+void AddOutputShape(Node* node, const TensorShape& shape) {
+ std::vector<TensorShapeProto> output_shapes;
+ TensorShapeProto shape_proto;
+ shape.AsProto(&shape_proto);
+ output_shapes.push_back(shape_proto);
+ node->AddAttr("_output_shapes", output_shapes);
+}
+
+class LayoutOptimizerTest : public ::testing::Test {
+ protected:
+ Output SimpleConv(tensorflow::Scope* s, int input_size, int filter_size,
+ const string& padding) {
+ int batch_size = 128;
+ int input_height = input_size;
+ int input_width = input_size;
+ int input_depth = 3;
+ int filter_count = 2;
+ int stride = 1;
+ TensorShape input_shape(
+ {batch_size, input_height, input_width, input_depth});
+ Tensor input_data(DT_FLOAT, input_shape);
+ test::FillIota<float>(&input_data, 1.0f);
+ Output input =
+ ops::Const(s->WithOpName("Input"), Input::Initializer(input_data));
+ AddOutputShape(input.node(), input_shape);
+
+ TensorShape filter_shape(
+ {filter_size, filter_size, input_depth, filter_count});
+ Tensor filter_data(DT_FLOAT, filter_shape);
+ test::FillIota<float>(&filter_data, 1.0f);
+ Output filter =
+ ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
+ AddOutputShape(filter.node(), filter_shape);
+
+ Output conv = ops::Conv2D(s->WithOpName("Conv2D"), input, filter,
+ {1, stride, stride, 1}, padding);
+ AddOutputShape(conv.node(), input_shape);
+ return conv;
+ }
+};
+
+TEST_F(LayoutOptimizerTest, FilterSizeIsOne) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv(&s, 2, 1, "SAME");
+ Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ optimizer.set_num_gpus(1);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ NodeMap node_map(&output);
+ EXPECT_FALSE(
+ node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input"));
+}
+
+TEST_F(LayoutOptimizerTest, FilterSizeNotOne) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv(&s, 2, 1, "SAME");
+ Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ optimizer.set_num_gpus(1);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ NodeMap node_map(&output);
+ EXPECT_FALSE(
+ node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input"));
+}
+
+TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv(&s, 2, 2, "VALID");
+ Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ optimizer.set_num_gpus(1);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ NodeMap node_map(&output);
+ EXPECT_FALSE(
+ node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input"));
+}
+
+TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv(&s, 2, 2, "SAME");
+ Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ optimizer.set_num_gpus(1);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ NodeMap node_map(&output);
+ EXPECT_TRUE(
+ node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input"));
+}
+
+TEST_F(LayoutOptimizerTest, NotEqualSizeWithValidPadding) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv(&s, 2, 3, "VALID");
+ Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ optimizer.set_num_gpus(1);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ NodeMap node_map(&output);
+ EXPECT_TRUE(
+ node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input"));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index abce506aba..2776b95a3c 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1327,6 +1327,14 @@ cc_library(
],
)
+cc_library(
+ name = "lookup",
+ deps = [
+ ":lookup_table_init_op",
+ ":lookup_table_op",
+ ],
+)
+
DATA_FLOW_DEPS = [
":bounds_check",
":concat_lib",
@@ -1450,10 +1458,10 @@ LOOKUP_DEPS = [
":initializable_lookup_table",
":lookup_util",
"//tensorflow/core:core_cpu",
- "//tensorflow/core:data_flow_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:lookup_ops_op_lib",
]
tf_kernel_library(
diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc
index 1c7afcf866..746fe63e2a 100644
--- a/tensorflow/core/kernels/crop_and_resize_op.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op.cc
@@ -19,9 +19,6 @@ limitations under the License.
#include "tensorflow/core/kernels/crop_and_resize_op.h"
-#include <functional>
-#include <string>
-
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -29,13 +26,10 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
-#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/types.h"
#if GOOGLE_CUDA
-#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
@@ -43,67 +37,41 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-using Callback = std::function<void()>;
-
-namespace {
-static inline Status ParseAndCheckBoxSizes(const Tensor& boxes,
- const Tensor& box_index,
- int* num_boxes) {
- if (boxes.NumElements() == 0 && box_index.NumElements() == 0) {
+static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
+ const Tensor& boxes,
+ const Tensor& box_ind,
+ int* num_boxes) {
+ if (boxes.NumElements() == 0 && box_ind.NumElements() == 0) {
*num_boxes = 0;
- return Status::OK();
+ return;
}
// The shape of 'boxes' is [num_boxes, 4].
- if (boxes.dims() != 2) {
- return errors::InvalidArgument("boxes must be 2-D",
- boxes.shape().DebugString());
- }
+ OP_REQUIRES(context, boxes.dims() == 2,
+ errors::InvalidArgument("boxes must be 2-D",
+ boxes.shape().DebugString()));
*num_boxes = boxes.dim_size(0);
- if (boxes.dim_size(1) != 4) {
- return errors::InvalidArgument("boxes must have 4 columns");
- }
- // The shape of 'box_index' is [num_boxes].
- if (box_index.dims() != 1) {
- return errors::InvalidArgument("box_index must be 1-D",
- box_index.shape().DebugString());
- }
- if (box_index.dim_size(0) != *num_boxes) {
- return errors::InvalidArgument("box_index has incompatible shape");
- }
- return Status::OK();
+ OP_REQUIRES(context, boxes.dim_size(1) == 4,
+ errors::InvalidArgument("boxes must have 4 columns"));
+
+ // The shape of 'box_ind' is [num_boxes].
+ OP_REQUIRES(context, box_ind.dims() == 1,
+ errors::InvalidArgument("box_ind must be 1-D",
+ box_ind.shape().DebugString()));
+ OP_REQUIRES(context, box_ind.dim_size(0) == *num_boxes,
+ errors::InvalidArgument("box_ind has incompatible shape"));
}
-// Conditionally calls the compute callback if all values in box_index are in
-// [0, batch_size) then calls done.
+// Verifies that all values in box_ind are in [0, batch).
template <typename Device>
-inline void RunIfBoxIndexIsValid(
- OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
- int batch_size, Callback compute, Callback done);
-
-// Specialization of CheckValidBoxIndex for a CPUDevice.
-template <>
-inline void RunIfBoxIndexIsValid<CPUDevice>(
- OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
- int batch_size, Callback compute, Callback done) {
- const int num_boxes = box_index.dimension(0);
- for (int b = 0; b < num_boxes; ++b) {
- OP_REQUIRES_ASYNC(
- context, FastBoundsCheck(box_index(b), batch_size),
- errors::OutOfRange("box_index has values outside [0, batch_size)"),
- done);
- }
- compute();
- done();
-}
-
-} // namespace
+inline void CheckValidBoxInd(
+ OpKernelContext* context,
+ typename TTypes<int32, 1>::ConstTensor box_ind_data, int batch);
template <typename Device, typename T>
-class CropAndResizeOp : public AsyncOpKernel {
+class CropAndResizeOp : public OpKernel {
public:
- explicit CropAndResizeOp(OpKernelConstruction* context)
- : AsyncOpKernel(context) {
+ explicit CropAndResizeOp(OpKernelConstruction* context) : OpKernel(context) {
string method;
OP_REQUIRES_OK(context, context->GetAttr("method", &method));
OP_REQUIRES(context, method == "bilinear",
@@ -112,77 +80,69 @@ class CropAndResizeOp : public AsyncOpKernel {
&extrapolation_value_));
}
- void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
- // The shape of 'image' is [batch_size, image_height, image_width,
- // channels].
+ void Compute(OpKernelContext* context) override {
+ // The shape of 'image' is [batch, image_height, image_width, channels].
const Tensor& image = context->input(0);
- // The shape of 'boxes' is [num_boxes, 4].
- const Tensor& boxes = context->input(1);
- // The shape of 'box_index' is [num_boxes].
- const Tensor& box_index = context->input(2);
- // The shape of 'crop_size' is [2].
- const Tensor& crop_size = context->input(3);
+ OP_REQUIRES(context, image.dims() == 4,
+ errors::InvalidArgument("input image must be 4-D",
+ image.shape().DebugString()));
- // Validate inputs dimensions.
- OP_REQUIRES_ASYNC(context, image.dims() == 4,
- errors::InvalidArgument("input image must be 4-D",
- image.shape().DebugString()),
- done);
- const int batch_size = image.dim_size(0);
+ const int batch = image.dim_size(0);
const int image_height = image.dim_size(1);
const int image_width = image.dim_size(2);
const int depth = image.dim_size(3);
- OP_REQUIRES_ASYNC(
- context, image_height > 0 && image_width > 0,
- errors::InvalidArgument("image dimensions must be positive"), done);
+ OP_REQUIRES(context, image_height > 0 && image_width > 0,
+ errors::InvalidArgument("image dimensions must be positive"));
+
+ // The shape of 'boxes' is [num_boxes, 4].
+ const Tensor& boxes = context->input(1);
+
+ // The shape of 'box_ind' is [num_boxes].
+ const Tensor& box_ind = context->input(2);
+
int num_boxes = 0;
- OP_REQUIRES_OK_ASYNC(
- context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
-
- OP_REQUIRES_ASYNC(context, crop_size.dims() == 1,
- errors::InvalidArgument("crop_size must be 1-D",
- crop_size.shape().DebugString()),
- done);
- OP_REQUIRES_ASYNC(
- context, crop_size.dim_size(0) == 2,
- errors::InvalidArgument("crop_size must have two elements",
- crop_size.shape().DebugString()),
- done);
-
- // Copy and validate crop sizes.
+ ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
+
+ // The shape of 'crop_size' is [2].
+ const Tensor& crop_size = context->input(3);
+
+ OP_REQUIRES(context, crop_size.dims() == 1,
+ errors::InvalidArgument("crop_size must be 1-D",
+ crop_size.shape().DebugString()));
+ OP_REQUIRES(context, crop_size.dim_size(0) == 2,
+ errors::InvalidArgument("crop_size must have two elements",
+ crop_size.shape().DebugString()));
+
auto crop_size_vec = crop_size.vec<int32>();
const int crop_height = internal::SubtleMustCopy(crop_size_vec(0));
const int crop_width = internal::SubtleMustCopy(crop_size_vec(1));
- OP_REQUIRES_ASYNC(
- context, crop_height > 0 && crop_width > 0,
- errors::InvalidArgument("crop dimensions must be positive"), done);
+ OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
+ errors::InvalidArgument("crop dimensions must be positive"));
// Allocate output tensor.
Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(
+ OP_REQUIRES_OK(
context,
context->allocate_output(
0, TensorShape({num_boxes, crop_height, crop_width, depth}),
- &output),
- done);
-
- auto compute_callback = [this, context, output]() {
- const Tensor& image = context->input(0);
- const Tensor& boxes = context->input(1);
- const Tensor& box_index = context->input(2);
- const bool status = functor::CropAndResize<Device, T>()(
- context->eigen_device<Device>(), image.tensor<T, 4>(),
- boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
- extrapolation_value_, output->tensor<float, 4>());
- if (!status) {
- context->SetStatus(
- errors::Internal("Failed launch CropAndResizeKernel."));
- }
- };
-
- RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
- batch_size, std::move(compute_callback),
- std::move(done));
+ &output));
+
+ typename TTypes<T, 4>::ConstTensor image_data = image.tensor<T, 4>();
+ typename TTypes<float, 2>::ConstTensor boxes_data =
+ boxes.tensor<float, 2>();
+ typename TTypes<int32, 1>::ConstTensor box_ind_data =
+ box_ind.tensor<int32, 1>();
+ typename TTypes<float, 4>::Tensor crops_data = output->tensor<float, 4>();
+
+ CheckValidBoxInd<Device>(context, box_ind_data, batch);
+
+ bool status = functor::CropAndResize<Device, T>()(
+ context->eigen_device<Device>(), image_data, boxes_data, box_ind_data,
+ extrapolation_value_, crops_data);
+ if (!status) {
+ context->SetStatus(
+ errors::Internal("Failed launch CropAndResizeKernel."));
+ }
}
private:
@@ -195,10 +155,10 @@ template <typename T>
struct CropAndResize<CPUDevice, T> {
bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
- typename TTypes<int32, 1>::ConstTensor box_index,
+ typename TTypes<int32, 1>::ConstTensor box_ind,
float extrapolation_value,
typename TTypes<float, 4>::Tensor crops) {
- const int batch_size = image.dimension(0);
+ const int batch = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@@ -213,8 +173,8 @@ struct CropAndResize<CPUDevice, T> {
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
- const int32 b_in = box_index(b);
- if (!FastBoundsCheck(b_in, batch_size)) {
+ const int32 b_in = box_ind(b);
+ if (b_in < 0 || b_in >= batch) {
continue;
}
@@ -275,94 +235,89 @@ struct CropAndResize<CPUDevice, T> {
return true;
}
};
-
} // namespace functor
template <typename Device, typename T>
-class CropAndResizeGradImageOp : public AsyncOpKernel {
+class CropAndResizeGradImageOp : public OpKernel {
public:
explicit CropAndResizeGradImageOp(OpKernelConstruction* context)
- : AsyncOpKernel(context) {
+ : OpKernel(context) {
string method;
OP_REQUIRES_OK(context, context->GetAttr("method", &method));
OP_REQUIRES(context, method == "bilinear",
errors::InvalidArgument("method must be 'bilinear'", method));
}
- void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
+ void Compute(OpKernelContext* context) override {
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
const Tensor& grads = context->input(0);
- // The shape of 'boxes' is [num_boxes, 4].
- const Tensor& boxes = context->input(1);
- // The shape of 'box_index' is [num_boxes].
- const Tensor& box_index = context->input(2);
- // The shape of 'image_size' is [4].
- const Tensor& image_size = context->input(3);
- // Validate input shapes.
- OP_REQUIRES_ASYNC(context, grads.dims() == 4,
- errors::InvalidArgument("grads image must be 4-D",
- grads.shape().DebugString()),
- done);
+ OP_REQUIRES(context, grads.dims() == 4,
+ errors::InvalidArgument("grads image must be 4-D",
+ grads.shape().DebugString()));
const int crop_height = grads.dim_size(1);
const int crop_width = grads.dim_size(2);
- OP_REQUIRES_ASYNC(
- context, crop_height > 0 && crop_width > 0,
- errors::InvalidArgument("grads dimensions must be positive"), done);
+ OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
+ errors::InvalidArgument("grads dimensions must be positive"));
+
+ // The shape of 'boxes' is [num_boxes, 4].
+ const Tensor& boxes = context->input(1);
+
+ // The shape of 'box_ind' is [num_boxes].
+ const Tensor& box_ind = context->input(2);
+
int num_boxes = 0;
- OP_REQUIRES_OK_ASYNC(
- context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
- OP_REQUIRES_ASYNC(
+ ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
+
+ OP_REQUIRES(
context, grads.dim_size(0) == num_boxes,
- errors::InvalidArgument("boxes and grads have incompatible shape"),
- done);
-
- OP_REQUIRES_ASYNC(context, image_size.dims() == 1,
- errors::InvalidArgument("image_size must be 1-D",
- image_size.shape().DebugString()),
- done);
- OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4,
- errors::InvalidArgument("image_size must have 4 elements",
- image_size.shape().DebugString()),
- done);
+ errors::InvalidArgument("boxes and grads have incompatible shape"));
+
+ // The shape of 'image_size' is [4].
+ const Tensor& image_size = context->input(3);
+ OP_REQUIRES(context, image_size.dims() == 1,
+ errors::InvalidArgument("image_size must be 1-D",
+ image_size.shape().DebugString()));
+ OP_REQUIRES(context, image_size.dim_size(0) == 4,
+ errors::InvalidArgument("image_size must have 4 elements",
+ image_size.shape().DebugString()));
+
auto image_size_vec = image_size.vec<int32>();
- const int batch_size = internal::SubtleMustCopy(image_size_vec(0));
+ const int batch = internal::SubtleMustCopy(image_size_vec(0));
const int image_height = internal::SubtleMustCopy(image_size_vec(1));
const int image_width = internal::SubtleMustCopy(image_size_vec(2));
const int depth = internal::SubtleMustCopy(image_size_vec(3));
- OP_REQUIRES_ASYNC(
- context, image_height > 0 && image_width > 0,
- errors::InvalidArgument("image dimensions must be positive"), done);
- OP_REQUIRES_ASYNC(
+
+ OP_REQUIRES(context, image_height > 0 && image_width > 0,
+ errors::InvalidArgument("image dimensions must be positive"));
+ OP_REQUIRES(
context, grads.dim_size(3) == depth,
- errors::InvalidArgument("image_size and grads are incompatible"), done);
+ errors::InvalidArgument("image_size and grads are incompatible"));
// Allocate output tensor.
Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(
- context,
- context->allocate_output(
- 0, TensorShape({batch_size, image_height, image_width, depth}),
- &output),
- done);
-
- auto compute_callback = [context, output]() {
- const Tensor& grads = context->input(0);
- const Tensor& boxes = context->input(1);
- const Tensor& box_index = context->input(2);
- const bool status = functor::CropAndResizeBackpropImage<Device, T>()(
- context->eigen_device<Device>(), grads.tensor<float, 4>(),
- boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
- output->tensor<T, 4>());
- if (!status) {
- context->SetStatus(errors::Internal(
- "Failed launch CropAndResizeBackpropImage kernel."));
- }
- };
-
- RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
- batch_size, std::move(compute_callback),
- std::move(done));
+ OP_REQUIRES_OK(
+ context, context->allocate_output(
+ 0, TensorShape({batch, image_height, image_width, depth}),
+ &output));
+
+ typename TTypes<float, 4>::ConstTensor grads_data =
+ grads.tensor<float, 4>();
+ typename TTypes<float, 2>::ConstTensor boxes_data =
+ boxes.tensor<float, 2>();
+ typename TTypes<int32, 1>::ConstTensor box_ind_data =
+ box_ind.tensor<int32, 1>();
+ typename TTypes<T, 4>::Tensor output_data = output->tensor<T, 4>();
+
+ CheckValidBoxInd<Device>(context, box_ind_data, batch);
+
+ bool status = functor::CropAndResizeBackpropImage<Device, T>()(
+ context->eigen_device<Device>(), grads_data, boxes_data, box_ind_data,
+ output_data);
+ if (!status) {
+ context->SetStatus(
+ errors::Internal("Failed launch CropAndResizeBackpropImageKernel."));
+ }
}
};
@@ -373,9 +328,9 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
bool operator()(const CPUDevice& d,
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes,
- typename TTypes<int32, 1>::ConstTensor box_index,
+ typename TTypes<int32, 1>::ConstTensor box_ind,
typename TTypes<T, 4>::Tensor grads_image) {
- const int batch_size = grads_image.dimension(0);
+ const int batch = grads_image.dimension(0);
const int image_height = grads_image.dimension(1);
const int image_width = grads_image.dimension(2);
@@ -392,8 +347,8 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
- const int32 b_in = box_index(b);
- if (!FastBoundsCheck(b_in, batch_size)) {
+ const int32 b_in = box_ind(b);
+ if (b_in < 0 || b_in >= batch) {
continue;
}
@@ -444,90 +399,83 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
return true;
}
};
-
} // namespace functor
template <typename Device, typename T>
-class CropAndResizeGradBoxesOp : public AsyncOpKernel {
+class CropAndResizeGradBoxesOp : public OpKernel {
public:
explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context)
- : AsyncOpKernel(context) {
+ : OpKernel(context) {
string method;
OP_REQUIRES_OK(context, context->GetAttr("method", &method));
OP_REQUIRES(context, method == "bilinear",
errors::InvalidArgument("method must be 'bilinear'", method));
}
- void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
+ void Compute(OpKernelContext* context) override {
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
const Tensor& grads = context->input(0);
- // The shape of 'boxes' is [num_boxes, 4].
- const Tensor& boxes = context->input(2);
- // The shape of 'box_index' is [num_boxes].
- const Tensor& box_index = context->input(3);
- // The shape of 'image' is [batch_size, image_height, image_width, depth].
- const Tensor& image = context->input(1);
- // Validate input shapes.
- OP_REQUIRES_ASYNC(context, grads.dims() == 4,
- errors::InvalidArgument("grads image must be 4-D",
- grads.shape().DebugString()),
- done);
+ OP_REQUIRES(context, grads.dims() == 4,
+ errors::InvalidArgument("grads image must be 4-D",
+ grads.shape().DebugString()));
+
const int crop_height = grads.dim_size(1);
const int crop_width = grads.dim_size(2);
const int depth = grads.dim_size(3);
- OP_REQUIRES_ASYNC(
- context, crop_height > 0 && crop_width > 0,
- errors::InvalidArgument("grads dimensions must be positive"), done);
-
- OP_REQUIRES_ASYNC(context, image.dims() == 4,
- errors::InvalidArgument("input image must be 4-D",
- image.shape().DebugString()),
- done);
- const int batch_size = image.dim_size(0);
+ OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
+ errors::InvalidArgument("grads dimensions must be positive"));
+
+ // The shape of 'image' is [batch, image_height, image_width, depth].
+ const Tensor& image = context->input(1);
+ OP_REQUIRES(context, image.dims() == 4,
+ errors::InvalidArgument("input image must be 4-D",
+ image.shape().DebugString()));
+
+ const int batch = image.dim_size(0);
const int image_height = image.dim_size(1);
const int image_width = image.dim_size(2);
- OP_REQUIRES_ASYNC(
- context, image_height > 0 && image_width > 0,
- errors::InvalidArgument("image dimensions must be positive"), done);
- OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth,
- errors::InvalidArgument("image, grads depth differ"),
- done);
+ OP_REQUIRES(context, image_height > 0 && image_width > 0,
+ errors::InvalidArgument("image dimensions must be positive"));
+ OP_REQUIRES(context, image.dim_size(3) == depth,
+ errors::InvalidArgument("image, grads depth differ"));
+
+ // The shape of 'boxes' is [num_boxes, 4].
+ const Tensor& boxes = context->input(2);
+
+ // The shape of 'box_ind' is [num_boxes].
+ const Tensor& box_ind = context->input(3);
int num_boxes = 0;
- OP_REQUIRES_OK_ASYNC(
- context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
+ ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
- OP_REQUIRES_ASYNC(
+ OP_REQUIRES(
context, grads.dim_size(0) == num_boxes,
- errors::InvalidArgument("boxes and grads have incompatible shape"),
- done);
+ errors::InvalidArgument("boxes and grads have incompatible shape"));
// Allocate output tensor.
Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(
- context,
- context->allocate_output(0, TensorShape({num_boxes, 4}), &output),
- done);
-
- auto compute_callback = [context, output]() {
- const Tensor& grads = context->input(0);
- const Tensor& image = context->input(1);
- const Tensor& boxes = context->input(2);
- const Tensor& box_index = context->input(3);
- const bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
- context->eigen_device<Device>(), grads.tensor<float, 4>(),
- image.tensor<T, 4>(), boxes.tensor<float, 2>(),
- box_index.tensor<int32, 1>(), output->tensor<float, 2>());
- if (!status) {
- context->SetStatus(errors::Internal(
- "Failed launch CropAndResizeBackpropBoxes kernel."));
- }
- };
-
- RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
- batch_size, std::move(compute_callback),
- std::move(done));
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0, TensorShape({num_boxes, 4}), &output));
+
+ typename TTypes<float, 4>::ConstTensor grads_data =
+ grads.tensor<float, 4>();
+ typename TTypes<T, 4>::ConstTensor image_data = image.tensor<T, 4>();
+ typename TTypes<float, 2>::ConstTensor boxes_data =
+ boxes.tensor<float, 2>();
+ typename TTypes<int32, 1>::ConstTensor box_ind_data =
+ box_ind.tensor<int32, 1>();
+ typename TTypes<float, 2>::Tensor output_data = output->tensor<float, 2>();
+
+ CheckValidBoxInd<Device>(context, box_ind_data, batch);
+
+ bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
+ context->eigen_device<Device>(), grads_data, image_data, boxes_data,
+ box_ind_data, output_data);
+ if (!status) {
+ context->SetStatus(
+ errors::Internal("Failed launch CropAndResizeBackpropBoxesKernel."));
+ }
}
};
@@ -539,9 +487,9 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
- typename TTypes<int32, 1>::ConstTensor box_index,
+ typename TTypes<int32, 1>::ConstTensor box_ind,
typename TTypes<float, 2>::Tensor grads_boxes) {
- const int batch_size = image.dimension(0);
+ const int batch = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@@ -558,8 +506,8 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
- const int32 b_in = box_index(b);
- if (!FastBoundsCheck(b_in, batch_size)) {
+ const int32 b_in = box_ind(b);
+ if (b_in < 0 || b_in >= batch) {
continue;
}
@@ -641,19 +589,30 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
return true;
}
};
-
} // namespace functor
-#define REGISTER_KERNEL(T) \
- REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .HostMemory("crop_size"), \
- CropAndResizeOp<CPUDevice, T>); \
- \
- REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T"), \
+// Specialization of CheckValidBoxInd for a CPUDevice.
+template <>
+inline void CheckValidBoxInd<CPUDevice>(
+ OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind,
+ int batch) {
+ const int num_boxes = box_ind.dimension(0);
+ for (int b = 0; b < num_boxes; ++b) {
+ OP_REQUIRES(context, box_ind(b) >= 0 && box_ind(b) < batch,
+ errors::OutOfRange("box_ind has values outside [0, batch)"));
+ }
+}
+
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("crop_size"), \
+ CropAndResizeOp<CPUDevice, T>); \
+ \
+ REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
CropAndResizeGradBoxesOp<CPUDevice, T>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
@@ -675,85 +634,49 @@ TF_CALL_double(REGISTER_KERNEL);
#if GOOGLE_CUDA
-// Forward declaration of the CheckValidBoxIndexHelper specialization for GPU.
+// Forward declaration of the CheckValidBoxIndHelper specialization for GPU.
namespace functor {
template <>
-void CheckValidBoxIndexHelper<GPUDevice>::operator()(
- const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_index,
- int batch_size, typename TTypes<bool, 0>::Tensor isvalid);
-extern template struct CheckValidBoxIndexHelper<GPUDevice>;
+void CheckValidBoxIndHelper<GPUDevice>::operator()(
+ const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_ind,
+ int batch, typename TTypes<bool, 0>::Tensor isvalid);
+extern template struct CheckValidBoxIndHelper<GPUDevice>;
} // namespace functor
-namespace {
-
-// Specialization of CheckValidBoxIndex for a GPUDevice.
+// Specialization of CheckValidBoxInd for a GPUDevice.
template <>
-inline void RunIfBoxIndexIsValid<GPUDevice>(
- OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
- int batch_size, Callback compute, Callback done) {
- const int num_boxes = box_index.dimension(0);
+inline void CheckValidBoxInd<GPUDevice>(
+ OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind,
+ int batch) {
+ const int num_boxes = box_ind.dimension(0);
if (num_boxes == 0) {
- compute();
- done();
return;
}
+ Tensor isvalid_tensor;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<bool>::value,
+ TensorShape({}), &isvalid_tensor));
- Tensor isvalid_dev_tensor;
- OP_REQUIRES_OK_ASYNC(
- context,
- context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
- &isvalid_dev_tensor),
- done);
- typename TTypes<bool, 0>::Tensor isvalid_dev =
- isvalid_dev_tensor.tensor<bool, 0>();
+ typename TTypes<bool, 0>::Tensor isvalid = isvalid_tensor.tensor<bool, 0>();
- // Run the actual box check on the device.
- functor::CheckValidBoxIndexHelper<GPUDevice>()(
- context->eigen_device<GPUDevice>(), box_index, batch_size, isvalid_dev);
+ functor::CheckValidBoxIndHelper<GPUDevice>()(
+ context->eigen_device<GPUDevice>(), box_ind, batch, isvalid);
- // Copy the result back to the host.
auto* stream = context->op_device_context()->stream();
- OP_REQUIRES_ASYNC(context, stream,
- errors::Internal("No GPU stream available."), done);
- Tensor isvalid_host_tensor;
- // Use pinned host memory on the host to avoid unnecessary
- // synchronization.
- AllocatorAttributes alloc_attr;
- alloc_attr.set_on_host(true);
- alloc_attr.set_gpu_compatible(true);
- OP_REQUIRES_OK_ASYNC(
- context,
- context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
- &isvalid_host_tensor, alloc_attr),
- done);
- typename TTypes<bool, 0>::Tensor isvalid_host =
- isvalid_host_tensor.tensor<bool, 0>();
-
- perftools::gputools::DeviceMemoryBase wrapped(isvalid_dev.data(),
- sizeof(bool));
- const bool status = stream
- ->ThenMemcpy(isvalid_host.data() /* destination */,
- wrapped /* source */, sizeof(bool))
- .ok();
- OP_REQUIRES_ASYNC(
- context, status,
- errors::Internal("Failed to launch copy of isvalid from device to host."),
- done);
-
- auto wrapped_callback = [context, isvalid_host, compute, done]() {
- OP_REQUIRES_ASYNC(
- context, isvalid_host(),
- errors::OutOfRange("box_index has values outside [0, batch_size)"),
- done);
- compute();
- done();
- };
-
- context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
- stream, wrapped_callback);
-}
+ OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+ bool isvalid_host = false;
+ perftools::gputools::DeviceMemoryBase isvalid_gpu(isvalid.data(),
+ sizeof(bool));
+ stream->ThenMemcpy(&isvalid_host, isvalid_gpu, sizeof(bool));
+ stream->BlockHostUntilDone();
-} // namespace
+ OP_REQUIRES(context, stream->ok(),
+ errors::Internal("cudaMemcpy from device to host failed"));
+
+ OP_REQUIRES(context, isvalid_host,
+ errors::OutOfRange("box_ind has values outside [0, batch)"));
+}
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
diff --git a/tensorflow/core/kernels/crop_and_resize_op.h b/tensorflow/core/kernels/crop_and_resize_op.h
index 460dbad22b..22df1bdd56 100644
--- a/tensorflow/core/kernels/crop_and_resize_op.h
+++ b/tensorflow/core/kernels/crop_and_resize_op.h
@@ -53,12 +53,12 @@ struct CropAndResizeBackpropBoxes {
};
template <typename Device>
-struct CheckValidBoxIndexHelper {
- // Checks if all values in box_index are in [0, batch).
+struct CheckValidBoxIndHelper {
+ // Checks if all values in box_ind are in [0, batch).
void operator()(const Device& d,
- typename TTypes<int32, 1>::ConstTensor box_index, int batch,
+ typename TTypes<int32, 1>::ConstTensor box_ind, int batch,
typename TTypes<bool, 0>::Tensor isvalid) {
- isvalid.device(d) = ((box_index >= 0) && (box_index < batch)).all();
+ isvalid.device(d) = ((box_ind >= 0) && (box_ind < batch)).all();
}
};
diff --git a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
index c1235fda89..254475db46 100644
--- a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
@@ -440,7 +440,7 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
#undef DEFINE_GPU_SPECS
-template struct CheckValidBoxIndexHelper<GPUDevice>;
+template struct CheckValidBoxIndHelper<GPUDevice>;
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/crop_and_resize_op_test.cc b/tensorflow/core/kernels/crop_and_resize_op_test.cc
index d6139dae96..3a7f180598 100644
--- a/tensorflow/core/kernels/crop_and_resize_op_test.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op_test.cc
@@ -251,7 +251,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
Status s = RunOpKernel();
ASSERT_FALSE(s.ok());
EXPECT_TRUE(
- StringPiece(s.ToString()).contains("box_index has incompatible shape"))
+ StringPiece(s.ToString()).contains("box_ind has incompatible shape"))
<< s;
}
@@ -264,10 +264,8 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) {
Status s = RunOpKernel();
ASSERT_FALSE(s.ok());
EXPECT_TRUE(StringPiece(s.ToString())
- .contains("box_index has values outside [0, batch_size)"))
+ .contains("box_ind has values outside [0, batch)"))
<< s;
}
-// TODO(zhengxq, rmlarsen): Add a benchmark.
-
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc
index b5093d59fc..48f38872e2 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc
@@ -47,16 +47,26 @@ class SparseTensorDenseAddOp : public OpKernel {
"Input a_indices should be a matrix but received shape: ",
a_indices_t->shape().DebugString()));
OP_REQUIRES(
- ctx, TensorShapeUtils::IsVector(a_values_t->shape()) &&
- TensorShapeUtils::IsVector(a_shape_t->shape()),
+ ctx,
+ TensorShapeUtils::IsVector(a_values_t->shape()) &&
+ TensorShapeUtils::IsVector(a_shape_t->shape()),
errors::InvalidArgument("Inputs a_values and a_shape should be vectors "
"but received shapes: ",
a_values_t->shape().DebugString(), " and ",
a_shape_t->shape().DebugString()));
- OP_REQUIRES(ctx, a_shape_t->NumElements() == b->dims(),
- errors::InvalidArgument(
- "Two operands have different dimensions; received: ",
- a_shape_t->NumElements(), " and ", b->dims()));
+ OP_REQUIRES(
+ ctx, a_shape_t->NumElements() == b->dims(),
+ errors::InvalidArgument("Two operands have different ranks; received: ",
+ a_shape_t->NumElements(), " and ", b->dims()));
+ const auto a_shape_flat = a_shape_t->flat<Index>();
+ for (int i = 0; i < b->dims(); ++i) {
+ OP_REQUIRES(
+ ctx, a_shape_flat(i) == b->dim_size(i),
+ errors::InvalidArgument(
+ "Dimension ", i,
+ " does not equal (no broadcasting is supported): sparse side ",
+ a_shape_flat(i), " vs dense side ", b->dim_size(i)));
+ }
Tensor *out_t;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, b->shape(), &out_t));
@@ -82,8 +92,9 @@ class SparseTensorDenseAddOp : public OpKernel {
NDIMS_CASE(4);
NDIMS_CASE(5);
default:
- OP_REQUIRES(ctx, false, errors::InvalidArgument(
- "Only tensors with ranks between 1 and 5 "
+ OP_REQUIRES(
+ ctx, false,
+ errors::InvalidArgument("Only tensors with ranks between 1 and 5 "
"are currently supported. Tensor rank: ",
ndims));
#undef NDIMS_CASE
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
index 30026f222a..30c57ef287 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
@@ -65,7 +65,8 @@ class SparseTensorDenseMatMulOp : public OpKernel {
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a_indices->shape()),
errors::InvalidArgument("Tensor 'a_indices' is not a matrix"));
- OP_REQUIRES(ctx, a_indices->shape().dim_size(0) == a_values->NumElements(),
+ const int64 nnz = a_indices->shape().dim_size(0);
+ OP_REQUIRES(ctx, nnz == a_values->NumElements(),
errors::InvalidArgument("Number of rows of a_indices does not "
"match number of entries in a_values"));
@@ -89,8 +90,28 @@ class SparseTensorDenseMatMulOp : public OpKernel {
inner_left, " vs. ", inner_right,
". Did you forget a transpose? "
"Dimensions of A: [",
- a_shape_t(0), ", ", a_shape_t(1), "). Dimensions of B: ",
- b->shape().DebugString()));
+ a_shape_t(0), ", ", a_shape_t(1),
+ "). Dimensions of B: ", b->shape().DebugString()));
+
+ if (std::is_same<Device, GPUDevice>::value) {
+ // The GPU implementation is optimized to use 32 bit indexing, so
+ // give a friendly error to the programmer early on if they
+ // exceed.
+ const int int32max = std::numeric_limits<int>::max();
+ OP_REQUIRES(
+ ctx,
+ (FastBoundsCheck(inner_left, int32max) &&
+ FastBoundsCheck(inner_right, int32max) &&
+ FastBoundsCheck(outer_left, int32max) &&
+ FastBoundsCheck(outer_right, int32max) &&
+ FastBoundsCheck(b->NumElements(), int32max) &&
+ FastBoundsCheck(outer_left * outer_right, int32max) &&
+ FastBoundsCheck(a_values->NumElements(), int32max)),
+ errors::InvalidArgument("Cannot use GPU for > 2^31 entry inputs"));
+ OP_REQUIRES(ctx, FastBoundsCheck(nnz * outer_right, int32max),
+ errors::InvalidArgument(
+ "Cannot use GPU when output.shape[1] * nnz(a) > 2^31"));
+ }
TensorShape out_shape({outer_left, outer_right});
Tensor* out = nullptr;
@@ -111,41 +132,13 @@ class SparseTensorDenseMatMulOp : public OpKernel {
return;
}
- Tensor scratch;
-
- if (std::is_same<Device, GPUDevice>::value) {
- // The GPU implementation is optimized to use 32 bit indexing, so
- // give a friendly error to the programmer early on if they exceed.
- OP_REQUIRES(
- ctx,
- FastBoundsCheck(inner_left, std::numeric_limits<int>::max()) &&
- FastBoundsCheck(inner_right, std::numeric_limits<int>::max()) &&
- FastBoundsCheck(outer_left, std::numeric_limits<int>::max()) &&
- FastBoundsCheck(outer_right, std::numeric_limits<int>::max()) &&
- FastBoundsCheck(b->NumElements(),
- std::numeric_limits<int>::max()) &&
- FastBoundsCheck(out->NumElements(),
- std::numeric_limits<int>::max()) &&
- FastBoundsCheck(a_values->NumElements(),
- std::numeric_limits<int>::max()),
- errors::InvalidArgument("Cannot use GPU for > 2^31 entry inputs"));
- const int nnz = static_cast<const int>(a_values->NumElements());
- // Need nnz length vec scratch space on the GPU.
- OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({nnz}), &scratch));
- } else {
- // We don't need scratch space on the CPU.
- OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({0}), &scratch));
- }
-
#define MAYBE_ADJOINT(ADJ_A, ADJ_B) \
if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) { \
Status functor_status = functor::SparseTensorDenseMatMulFunctor< \
Device, T, Tindices, ADJ_A, \
ADJ_B>::Compute(ctx->eigen_device<Device>(), out->matrix<T>(), \
a_indices->matrix<Tindices>(), a_values->vec<T>(), \
- b->matrix<T>(), scratch.vec<T>()); \
+ b->matrix<T>()); \
OP_REQUIRES_OK(ctx, functor_status); \
}
@@ -189,10 +182,9 @@ namespace functor {
Status SparseTensorDenseMatMulFunctor< \
GPUDevice, T, Tindices, ADJ_A, \
ADJ_B>::Compute(const GPUDevice& d, typename TTypes<T>::Matrix out, \
- typename TTypes<Tindices>::ConstMatrix a_indices, \
+ TTypes<Tindices>::ConstMatrix a_indices, \
typename TTypes<T>::ConstVec a_values, \
- typename TTypes<T>::ConstMatrix b, \
- typename TTypes<T>::Vec scratch); \
+ typename TTypes<T>::ConstMatrix b); \
extern template struct SparseTensorDenseMatMulFunctor< \
GPUDevice, T, Tindices, ADJ_A, ADJ_B>;
@@ -255,8 +247,7 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, Tindices, ADJ_A, ADJ_B> {
static Status Compute(const CPUDevice& d, typename TTypes<T>::Matrix out,
typename TTypes<Tindices>::ConstMatrix a_indices,
typename TTypes<T>::ConstVec a_values,
- typename TTypes<T>::ConstMatrix b,
- typename TTypes<T>::Vec scratch) {
+ typename TTypes<T>::ConstMatrix b) {
const std::size_t nnz = a_values.size();
const std::size_t rhs_right = (ADJ_B ? b.dimension(0) : b.dimension(1));
const std::size_t lhs_right = (ADJ_B ? b.dimension(1) : b.dimension(0));
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
index e707743f78..da13190494 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
@@ -28,11 +28,10 @@ namespace functor {
template <typename Device, typename T, typename Tindices, bool ADJ_A,
bool ADJ_B>
struct SparseTensorDenseMatMulFunctor {
- static EIGEN_ALWAYS_INLINE Status
- Compute(const Device& d, typename TTypes<T>::Matrix out,
- typename TTypes<Tindices>::ConstMatrix a_indices,
- typename TTypes<T>::ConstVec a_values,
- typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch);
+ static EIGEN_ALWAYS_INLINE Status Compute(
+ const Device& d, typename TTypes<T>::Matrix out,
+ typename TTypes<Tindices>::ConstMatrix a_indices,
+ typename TTypes<T>::ConstVec a_values, typename TTypes<T>::ConstMatrix b);
};
template <typename MATRIX, bool ADJ>
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
index 7266e0cf81..e261e42e0d 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
@@ -20,71 +20,45 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h"
#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
-namespace generator {
-
template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
-class SparseTensorDenseMatMulGPUGenerator {
- public:
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseTensorDenseMatMulGPUGenerator(
- typename TTypes<T, 2>::Tensor32Bit out,
- typename TTypes<const Tindices, 2>::Tensor32Bit a_indices,
- typename TTypes<const T, 1>::Tensor32Bit a_values,
- typename TTypes<const T, 2>::Tensor32Bit b)
- : out_(out),
- lhs_index_a_(ADJ_A ? 1 : 0),
- rhs_index_a_(ADJ_A ? 0 : 1),
- a_indices_(a_indices),
- a_values_(a_values),
- lhs_right_size(ADJ_B ? b.dimension(1) : b.dimension(0)),
- maybe_adjoint_b_(
- functor::MaybeAdjoint<typename TTypes<const T, 2>::Tensor32Bit,
- ADJ_B>(b)) {}
-
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
- operator()(const Eigen::array<int, 2>& j_and_ix) const {
-#ifdef __CUDA_ARCH__
- const int j = j_and_ix[0];
- const int ix = j_and_ix[1];
- int m = a_indices_(ix, lhs_index_a_);
- int k = a_indices_(ix, rhs_index_a_);
- assert(k < lhs_right_size);
- assert(m < out_.dimension(0));
- // If asserts are disabled, the caller is violating the sparse
- // tensor index contract, and so we return invalid results.
- // Force returning NaNs to try to signal that something is amiss.
- T b_value;
- if (k >= lhs_right_size || m >= out_.dimension(0)) {
- m = 0;
- k = 0;
- b_value = std::numeric_limits<T>::quiet_NaN();
- } else {
- b_value = maybe_adjoint_b_(k, j);
+__global__ void SparseTensorDenseMatMulKernel(int nnz, int m, int b_rows,
+ int b_cols, int p,
+ const Tindices* a_indices,
+ const T* a_values, const T* b,
+ T* out) {
+ // out_{ij} = sum_k {a_ik b_kj}
+ // out = A * B', out_{ij} = sum_k {a_ik (b')_kj}; b'_{kj} = b_{jk}
+ const int n = (ADJ_B) ? b_cols : b_rows;
+ CUDA_1D_KERNEL_LOOP(index, nnz * p) {
+ const int a_ix = index / p;
+ const int j = index % p;
+ const int i = ldg(a_indices + 2 * a_ix + ((ADJ_A) ? 1 : 0));
+ const int k = ldg(a_indices + 2 * a_ix + ((ADJ_A) ? 0 : 1));
+ if (!FastBoundsCheck(i, m)) {
+ continue; // Nowhere to signal an error :(
+ }
+ // out[i, j]
+ T* out_location = out + i * p + j;
+ if (!FastBoundsCheck(k, n)) {
+ CudaAtomicAdd(out_location, std::numeric_limits<T>::quiet_NaN());
+ continue;
}
- atomicAdd(&out_(m, j), a_values_(ix) * b_value);
-#else
- assert(false && "This should only be run on the device");
-#endif
- // Return something
- return T(0);
- }
- private:
- mutable typename TTypes<T, 2>::Tensor32Bit out_;
- const int lhs_index_a_;
- const int rhs_index_a_;
- typename TTypes<const Tindices, 2>::Tensor32Bit a_indices_;
- typename TTypes<const T, 1>::Tensor32Bit a_values_;
- const int lhs_right_size;
- functor::MaybeAdjoint<typename TTypes<const T, 2>::Tensor32Bit, ADJ_B>
- maybe_adjoint_b_;
-};
+ // a_value == (ADJ_A) ? a[k, i] : a[i, k]
+ const T a_value = ldg(a_values + a_ix);
-} // namespace generator
+ // b_value == (ADJ_B) ? b[j, k] : b[k, j]
+ const T b_value = ldg(b + ((ADJ_B) ? j * b_cols + k : k * b_cols + j));
+ CudaAtomicAdd(out_location, a_value * b_value);
+ }
+}
namespace functor {
@@ -94,51 +68,23 @@ struct SparseTensorDenseMatMulFunctor<GPUDevice, T, Tindices, ADJ_A, ADJ_B> {
Compute(const GPUDevice& d, typename TTypes<T>::Matrix out,
typename TTypes<Tindices>::ConstMatrix a_indices,
typename TTypes<T>::ConstVec a_values,
- typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch) {
- generator::SparseTensorDenseMatMulGPUGenerator<T, Tindices, ADJ_A, ADJ_B>
- sparse_tensor_dense_matmul_generator(To32Bit(out), To32Bit(a_indices),
- To32Bit(a_values), To32Bit(b));
- To32Bit(out).device(d) = To32Bit(out).constant(T(0));
+ typename TTypes<T>::ConstMatrix b) {
+ out.device(d) = out.constant(T(0));
int nnz = a_values.size();
- int n = (ADJ_B) ? b.dimension(0) : b.dimension(1);
-
-#if !defined(EIGEN_HAS_INDEX_LIST)
- Eigen::Tensor<int, 2>::Dimensions matrix_1_by_nnz{{ 1, nnz }};
- Eigen::array<int, 2> n_by_1{{ n, 1 }};
- Eigen::array<int, 1> reduce_on_rows{{ 0 }};
-#else
- Eigen::IndexList<Eigen::type2index<1>, int> matrix_1_by_nnz;
- matrix_1_by_nnz.set(1, nnz);
- Eigen::IndexList<int, Eigen::type2index<1> > n_by_1;
- n_by_1.set(0, n);
- Eigen::IndexList<Eigen::type2index<0> > reduce_on_rows;
-#endif
-
- // How this works: the generator iterates over (j, ix) where j
- // iterates from 0 .. n - 1 and ix iterates from
- // 0 .. nnz - 1. A side effect of the generator is to accumulate
- // the products of values in A and B into the appropriate location
- // in the dense matrix out. In order to run the iteration,
- // we take a smaller variable and broadcast to a size (n, nnz).
- // This is the scratch variable. In order to enforce execution,
- // we have to perform assignment back into scratch (taking the sum).
- // We don't care what gets assigned to scratch - only the side effect
- // of the execution in the generator.
- //
- // Note it's not sufficient that scratch be a scalar, and to
- // broadcast it to a matrix. Eigen splits the computation not
- // based on the largest intermediate shape (the size of the
- // broadcast of scratch) but based on the output shape. So
- // scratch needs to be a vector at least.
- //
- // Note also that only float type is supported because the
- // atomicAdd operation is only supported for floats in hardware.
- To32Bit(scratch).device(d) =
- To32Bit(scratch)
- .reshape(matrix_1_by_nnz)
- .broadcast(n_by_1)
- .generate(sparse_tensor_dense_matmul_generator)
- .sum(reduce_on_rows);
+ // out = A * B, A is [m x n] and B is [n x p], out is [m x p]
+ int m = out.dimension(0);
+ int p = out.dimension(1);
+ int b_rows = b.dimension(0);
+ int b_cols = b.dimension(1);
+
+ // TODO(ebrevdo): Should this be alpha * nnz instead of
+ // out.size()? Perhaps p * nnz ?
+ CudaLaunchConfig config = GetCudaLaunchConfig(p * nnz, d);
+
+ SparseTensorDenseMatMulKernel<T, Tindices, ADJ_A, ADJ_B>
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ nnz, m, b_rows, b_cols, p, a_indices.data(), a_values.data(),
+ b.data(), out.data());
return Status::OK();
}
diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc
index f5d4fcec84..d50e2060ac 100644
--- a/tensorflow/core/kernels/unique_op.cc
+++ b/tensorflow/core/kernels/unique_op.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unordered_map>
#include <utility>
#include "tensorflow/core/framework/op_kernel.h"
@@ -21,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
namespace tensorflow {
@@ -50,8 +50,7 @@ class UniqueOp : public OpKernel {
{0}, 1, input.shape(), &idx));
auto idx_vec = idx->template vec<int32>();
- std::unordered_map<T, int32> uniq;
- uniq.reserve(2 * N);
+ gtl::FlatMap<T, int32> uniq(N);
for (int64 i = 0, j = 0; i < N; ++i) {
auto it = uniq.insert(std::make_pair(Tin(i), j));
idx_vec(i) = it.first->second;
diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h
index 8c173a4ba3..25b17b26c8 100644
--- a/tensorflow/core/kernels/variable_ops.h
+++ b/tensorflow/core/kernels/variable_ops.h
@@ -76,6 +76,18 @@ class VariableOp : public OpKernel {
// As long as the resource manager hasn't been cleared the ref we return
// here is valid because it owns a ref on var.
ctx->set_output_ref(0, var->mu(), var->tensor());
+ if (ctx->track_allocations() && var->tensor()->IsInitialized()) {
+ AllocatorAttributes attr;
+ attr.set_gpu_compatible(true);
+ attr.set_nic_compatible(true);
+ if (ctx->allocate_on_host(attr)) {
+ ctx->record_host_persistent_memory_allocation(
+ var->tensor()->AllocatedBytes());
+ } else {
+ ctx->record_device_persistent_memory_allocation(
+ var->tensor()->AllocatedBytes());
+ }
+ }
var->Unref();
}
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index f35a1bb648..032ede6459 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -1876,604 +1876,6 @@ size: The number of incomplete elements (i.e. those with some of their value
// --------------------------------------------------------------------------
-REGISTER_OP("LookupTableFind")
- .Input("table_handle: Ref(string)")
- .Input("keys: Tin")
- .Input("default_value: Tout")
- .Output("values: Tout")
- .Attr("Tin: type")
- .Attr("Tout: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
- DimensionHandle unused_dim;
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
-
- // Default value must be scalar or vector.
- ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
- c->set_output(0, c->UnknownShape());
- return Status::OK();
- })
- .Doc(R"doc(
-Looks up keys in a table, outputs the corresponding values.
-
-The tensor `keys` must of the same type as the keys of the table.
-The output `values` is of the type of the table values.
-
-The scalar `default_value` is the value output for keys not present in the
-table. It must also be of the same type as the table values.
-
-table_handle: Handle to the table.
-keys: Any shape. Keys to look up.
-values: Same shape as `keys`. Values found in the table, or `default_values`
- for missing keys.
-)doc");
-
-REGISTER_OP("LookupTableFindV2")
- .Input("table_handle: resource")
- .Input("keys: Tin")
- .Input("default_value: Tout")
- .Output("values: Tout")
- .Attr("Tin: type")
- .Attr("Tout: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
-
- // Default value must be scalar or vector.
- ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
- c->set_output(0, c->UnknownShape());
- return Status::OK();
- })
- .Doc(R"doc(
-Looks up keys in a table, outputs the corresponding values.
-
-The tensor `keys` must of the same type as the keys of the table.
-The output `values` is of the type of the table values.
-
-The scalar `default_value` is the value output for keys not present in the
-table. It must also be of the same type as the table values.
-
-table_handle: Handle to the table.
-keys: Any shape. Keys to look up.
-values: Same shape as `keys`. Values found in the table, or `default_values`
- for missing keys.
-)doc");
-
-REGISTER_OP("LookupTableInsert")
- .Input("table_handle: Ref(string)")
- .Input("keys: Tin")
- .Input("values: Tout")
- .Attr("Tin: type")
- .Attr("Tout: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
- DimensionHandle unused_dim;
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
-
- // TODO(ebrevdo): Validate keys and values shape.
- return Status::OK();
- })
- .Doc(R"doc(
-Updates the table to associates keys with values.
-
-The tensor `keys` must be of the same type as the keys of the table.
-The tensor `values` must be of the type of the table values.
-
-table_handle: Handle to the table.
-keys: Any shape. Keys to look up.
-values: Values to associate with keys.
-)doc");
-
-REGISTER_OP("LookupTableInsertV2")
- .Input("table_handle: resource")
- .Input("keys: Tin")
- .Input("values: Tout")
- .Attr("Tin: type")
- .Attr("Tout: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
-
- // TODO: Validate keys and values shape.
- return Status::OK();
- })
- .Doc(R"doc(
-Updates the table to associates keys with values.
-
-The tensor `keys` must be of the same type as the keys of the table.
-The tensor `values` must be of the type of the table values.
-
-table_handle: Handle to the table.
-keys: Any shape. Keys to look up.
-values: Values to associate with keys.
-)doc");
-
-REGISTER_OP("LookupTableSize")
- .Input("table_handle: Ref(string)")
- .Output("size: int64")
- .SetShapeFn(TwoElementVectorInputsAndScalarOutputs)
- .Doc(R"doc(
-Computes the number of elements in the given table.
-
-table_handle: Handle to the table.
-size: Scalar that contains number of elements in the table.
-)doc");
-
-REGISTER_OP("LookupTableSizeV2")
- .Input("table_handle: resource")
- .Output("size: int64")
- .SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs)
- .Doc(R"doc(
-Computes the number of elements in the given table.
-
-table_handle: Handle to the table.
-size: Scalar that contains number of elements in the table.
-)doc");
-
-REGISTER_OP("LookupTableExport")
- .Input("table_handle: Ref(string)")
- .Output("keys: Tkeys")
- .Output("values: Tvalues")
- .Attr("Tkeys: type")
- .Attr("Tvalues: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
- DimensionHandle unused_dim;
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
-
- ShapeHandle values = c->UnknownShape();
- TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
- ShapeHandle keys = c->Vector(c->Dim(values, 0));
- c->set_output(0, keys);
- c->set_output(1, values);
- return Status::OK();
- })
- .Doc(R"doc(
-Outputs all keys and values in the table.
-
-table_handle: Handle to the table.
-keys: Vector of all keys present in the table.
-values: Tensor of all values in the table. Indexed in parallel with `keys`.
-)doc");
-
-REGISTER_OP("LookupTableExportV2")
- .Input("table_handle: resource")
- .Output("keys: Tkeys")
- .Output("values: Tvalues")
- .Attr("Tkeys: type")
- .Attr("Tvalues: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
-
- ShapeHandle values = c->UnknownShape();
- TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
- ShapeHandle keys = c->Vector(c->Dim(values, 0));
- c->set_output(0, keys);
- c->set_output(1, values);
- return Status::OK();
- })
- .Doc(R"doc(
-Outputs all keys and values in the table.
-
-table_handle: Handle to the table.
-keys: Vector of all keys present in the table.
-values: Tensor of all values in the table. Indexed in parallel with `keys`.
-)doc");
-
-REGISTER_OP("LookupTableImport")
- .Input("table_handle: Ref(string)")
- .Input("keys: Tin")
- .Input("values: Tout")
- .Attr("Tin: type")
- .Attr("Tout: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
- DimensionHandle unused_dim;
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
-
- // TODO(ebrevdo): Validate keys and values shape.
- return Status::OK();
- })
- .Doc(R"doc(
-Replaces the contents of the table with the specified keys and values.
-
-The tensor `keys` must be of the same type as the keys of the table.
-The tensor `values` must be of the type of the table values.
-
-table_handle: Handle to the table.
-keys: Any shape. Keys to look up.
-values: Values to associate with keys.
-)doc");
-
-REGISTER_OP("LookupTableImportV2")
- .Input("table_handle: resource")
- .Input("keys: Tin")
- .Input("values: Tout")
- .Attr("Tin: type")
- .Attr("Tout: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
-
- // TODO: Validate keys and values shape.
- return Status::OK();
- })
- .Doc(R"doc(
-Replaces the contents of the table with the specified keys and values.
-
-The tensor `keys` must be of the same type as the keys of the table.
-The tensor `values` must be of the type of the table values.
-
-table_handle: Handle to the table.
-keys: Any shape. Keys to look up.
-values: Values to associate with keys.
-)doc");
-
-REGISTER_OP("HashTable")
- .Output("table_handle: Ref(string)")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("use_node_name_sharing: bool = false")
- .Attr("key_dtype: type")
- .Attr("value_dtype: type")
- .SetIsStateful()
- .SetShapeFn(TwoElementOutput)
- .Doc(R"doc(
-Creates a non-initialized hash table.
-
-This op creates a hash table, specifying the type of its keys and values.
-Before using the table you will have to initialize it. After initialization the
-table will be immutable.
-
-table_handle: Handle to a table.
-container: If non-empty, this table is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this table is shared under the given name across
- multiple sessions.
-use_node_name_sharing: If true and shared_name is empty, the table is shared
- using the node name.
-key_dtype: Type of the table keys.
-value_dtype: Type of the table values.
-)doc");
-
-REGISTER_OP("HashTableV2")
- .Output("table_handle: resource")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("use_node_name_sharing: bool = false")
- .Attr("key_dtype: type")
- .Attr("value_dtype: type")
- .SetIsStateful()
- .SetShapeFn(ScalarOutput)
- .Doc(R"doc(
-Creates a non-initialized hash table.
-
-This op creates a hash table, specifying the type of its keys and values.
-Before using the table you will have to initialize it. After initialization the
-table will be immutable.
-
-table_handle: Handle to a table.
-container: If non-empty, this table is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this table is shared under the given name across
- multiple sessions.
-use_node_name_sharing: If true and shared_name is empty, the table is shared
- using the node name.
-key_dtype: Type of the table keys.
-value_dtype: Type of the table values.
-)doc");
-
-REGISTER_OP("MutableHashTable")
- .Output("table_handle: Ref(string)")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("use_node_name_sharing: bool = false")
- .Attr("key_dtype: type")
- .Attr("value_dtype: type")
- .SetIsStateful()
- .SetShapeFn(TwoElementOutput)
- .Doc(R"doc(
-Creates an empty hash table.
-
-This op creates a mutable hash table, specifying the type of its keys and
-values. Each value must be a scalar. Data can be inserted into the table using
-the insert operations. It does not support the initialization operation.
-
-table_handle: Handle to a table.
-container: If non-empty, this table is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this table is shared under the given name across
- multiple sessions.
-use_node_name_sharing: If true and shared_name is empty, the table is shared
- using the node name.
-key_dtype: Type of the table keys.
-value_dtype: Type of the table values.
-)doc");
-
-REGISTER_OP("MutableHashTableV2")
- .Output("table_handle: resource")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("use_node_name_sharing: bool = false")
- .Attr("key_dtype: type")
- .Attr("value_dtype: type")
- .SetIsStateful()
- .SetShapeFn(ScalarOutput)
- .Doc(R"doc(
-Creates an empty hash table.
-
-This op creates a mutable hash table, specifying the type of its keys and
-values. Each value must be a scalar. Data can be inserted into the table using
-the insert operations. It does not support the initialization operation.
-
-table_handle: Handle to a table.
-container: If non-empty, this table is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this table is shared under the given name across
- multiple sessions.
-use_node_name_sharing: If true and shared_name is empty, the table is shared
- using the node name.
-key_dtype: Type of the table keys.
-value_dtype: Type of the table values.
-)doc");
-
-REGISTER_OP("MutableHashTableOfTensors")
- .Output("table_handle: Ref(string)")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("use_node_name_sharing: bool = false")
- .Attr("key_dtype: type")
- .Attr("value_dtype: type")
- .Attr("value_shape: shape = {}")
- .SetIsStateful()
- .SetShapeFn(TwoElementOutput)
- .Doc(R"doc(
-Creates an empty hash table.
-
-This op creates a mutable hash table, specifying the type of its keys and
-values. Each value must be a vector. Data can be inserted into the table using
-the insert operations. It does not support the initialization operation.
-
-table_handle: Handle to a table.
-container: If non-empty, this table is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this table is shared under the given name across
- multiple sessions.
-key_dtype: Type of the table keys.
-value_dtype: Type of the table values.
-)doc");
-
-REGISTER_OP("MutableHashTableOfTensorsV2")
- .Output("table_handle: resource")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("use_node_name_sharing: bool = false")
- .Attr("key_dtype: type")
- .Attr("value_dtype: type")
- .Attr("value_shape: shape = {}")
- .SetIsStateful()
- .SetShapeFn(ScalarOutput)
- .Doc(R"doc(
-Creates an empty hash table.
-
-This op creates a mutable hash table, specifying the type of its keys and
-values. Each value must be a vector. Data can be inserted into the table using
-the insert operations. It does not support the initialization operation.
-
-table_handle: Handle to a table.
-container: If non-empty, this table is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this table is shared under the given name across
- multiple sessions.
-key_dtype: Type of the table keys.
-value_dtype: Type of the table values.
-)doc");
-
-REGISTER_OP("MutableDenseHashTable")
- .Input("empty_key: key_dtype")
- .Output("table_handle: Ref(string)")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("use_node_name_sharing: bool = false")
- .Attr("key_dtype: type")
- .Attr("value_dtype: type")
- .Attr("value_shape: shape = {}")
- .Attr("initial_num_buckets: int = 131072") // 2^17
- .Attr("max_load_factor: float = 0.8")
- .SetIsStateful()
- .SetShapeFn(TwoElementOutput)
- .Doc(R"doc(
-Creates an empty hash table that uses tensors as the backing store. It uses
-"open addressing" with quadratic reprobing to resolve collisions.
-
-This op creates a mutable hash table, specifying the type of its keys and
-values. Each value must be a scalar. Data can be inserted into the table using
-the insert operations. It does not support the initialization operation.
-
-empty_key: The key used to represent empty key buckets internally. Must not
- be used in insert or lookup operations.
-table_handle: Handle to a table.
-container: If non-empty, this table is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this table is shared under the given name across
- multiple sessions.
-key_dtype: Type of the table keys.
-value_dtype: Type of the table values.
-value_shape: The shape of each value.
-initial_num_buckets: The initial number of hash table buckets. Must be a power
- to 2.
-max_load_factor: The maximum ratio between number of entries and number of
- buckets before growing the table. Must be between 0 and 1.
-)doc");
-
-REGISTER_OP("MutableDenseHashTableV2")
- .Input("empty_key: key_dtype")
- .Output("table_handle: resource")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("use_node_name_sharing: bool = false")
- .Attr("key_dtype: type")
- .Attr("value_dtype: type")
- .Attr("value_shape: shape = {}")
- .Attr("initial_num_buckets: int = 131072") // 2^17
- .Attr("max_load_factor: float = 0.8")
- .SetIsStateful()
- .SetShapeFn(ScalarOutput)
- .Doc(R"doc(
-Creates an empty hash table that uses tensors as the backing store. It uses
-"open addressing" with quadratic reprobing to resolve collisions.
-
-This op creates a mutable hash table, specifying the type of its keys and
-values. Each value must be a scalar. Data can be inserted into the table using
-the insert operations. It does not support the initialization operation.
-
-empty_key: The key used to represent empty key buckets internally. Must not
- be used in insert or lookup operations.
-table_handle: Handle to a table.
-container: If non-empty, this table is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this table is shared under the given name across
- multiple sessions.
-key_dtype: Type of the table keys.
-value_dtype: Type of the table values.
-value_shape: The shape of each value.
-initial_num_buckets: The initial number of hash table buckets. Must be a power
- to 2.
-max_load_factor: The maximum ratio between number of entries and number of
- buckets before growing the table. Must be between 0 and 1.
-)doc");
-
-REGISTER_OP("InitializeTable")
- .Input("table_handle: Ref(string)")
- .Input("keys: Tkey")
- .Input("values: Tval")
- .Attr("Tkey: type")
- .Attr("Tval: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
- DimensionHandle unused_dim;
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
-
- ShapeHandle keys;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
- TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
- return Status::OK();
- })
- .Doc(R"doc(
-Table initializer that takes two tensors for keys and values respectively.
-
-table_handle: Handle to a table which will be initialized.
-keys: Keys of type Tkey.
-values: Values of type Tval.
-)doc");
-
-REGISTER_OP("InitializeTableV2")
- .Input("table_handle: resource")
- .Input("keys: Tkey")
- .Input("values: Tval")
- .Attr("Tkey: type")
- .Attr("Tval: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
-
- ShapeHandle keys;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
- TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
- return Status::OK();
- })
- .Doc(R"doc(
-Table initializer that takes two tensors for keys and values respectively.
-
-table_handle: Handle to a table which will be initialized.
-keys: Keys of type Tkey.
-values: Values of type Tval.
-)doc");
-
-REGISTER_OP("InitializeTableFromTextFile")
- .Input("table_handle: Ref(string)")
- .Input("filename: string")
- .Attr("key_index: int >= -2")
- .Attr("value_index: int >= -2")
- .Attr("vocab_size: int >= -1 = -1")
- .Attr("delimiter: string = '\t'")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
- DimensionHandle unused_dim;
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
-
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
- return Status::OK();
- })
- .Doc(R"doc(
-Initializes a table from a text file.
-
-It inserts one key-value pair into the table for each line of the file.
-The key and value is extracted from the whole line content, elements from the
-split line based on `delimiter` or the line number (starting from zero).
-Where to extract the key and value from a line is specified by `key_index` and
-`value_index`.
-
-- A value of -1 means use the line number(starting from zero), expects `int64`.
-- A value of -2 means use the whole line content, expects `string`.
-- A value >= 0 means use the index (starting at zero) of the split line based
- on `delimiter`.
-
-table_handle: Handle to a table which will be initialized.
-filename: Filename of a vocabulary text file.
-key_index: Column index in a line to get the table `key` values from.
-value_index: Column index that represents information of a line to get the table
- `value` values from.
-vocab_size: Number of elements of the file, use -1 if unknown.
-delimiter: Delimiter to separate fields in a line.
-)doc");
-
-REGISTER_OP("InitializeTableFromTextFileV2")
- .Input("table_handle: resource")
- .Input("filename: string")
- .Attr("key_index: int >= -2")
- .Attr("value_index: int >= -2")
- .Attr("vocab_size: int >= -1 = -1")
- .Attr("delimiter: string = '\t'")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
-
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
- return Status::OK();
- })
- .Doc(R"doc(
-Initializes a table from a text file.
-
-It inserts one key-value pair into the table for each line of the file.
-The key and value is extracted from the whole line content, elements from the
-split line based on `delimiter` or the line number (starting from zero).
-Where to extract the key and value from a line is specified by `key_index` and
-`value_index`.
-
-- A value of -1 means use the line number(starting from zero), expects `int64`.
-- A value of -2 means use the whole line content, expects `string`.
-- A value >= 0 means use the index (starting at zero) of the split line based
- on `delimiter`.
-
-table_handle: Handle to a table which will be initialized.
-filename: Filename of a vocabulary text file.
-key_index: Column index in a line to get the table `key` values from.
-value_index: Column index that represents information of a line to get the table
- `value` values from.
-vocab_size: Number of elements of the file, use -1 if unknown.
-delimiter: Delimiter to separate fields in a line.
-)doc");
-
REGISTER_OP("GetSessionHandle")
.Input("value: T")
.Output("handle: string")
diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc
new file mode 100644
index 0000000000..498a65690d
--- /dev/null
+++ b/tensorflow/core/ops/lookup_ops.cc
@@ -0,0 +1,666 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
+// --------------------------------------------------------------------------
+
+namespace {
+Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
+ ShapeHandle handle;
+ DimensionHandle unused_handle;
+ for (int i = 0; i < c->num_inputs(); ++i) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
+ }
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->Scalar());
+ }
+ return Status::OK();
+}
+
+Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
+ ShapeHandle handle;
+ DimensionHandle unused_handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+ for (int i = 1; i < c->num_inputs(); ++i) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
+ }
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->Scalar());
+ }
+ return Status::OK();
+}
+
+Status TwoElementOutput(InferenceContext* c) {
+ c->set_output(0, c->Vector(2));
+ return Status::OK();
+}
+
+Status ScalarOutput(InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+}
+} // namespace
+
+REGISTER_OP("LookupTableFind")
+ .Input("table_handle: Ref(string)")
+ .Input("keys: Tin")
+ .Input("default_value: Tout")
+ .Output("values: Tout")
+ .Attr("Tin: type")
+ .Attr("Tout: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
+
+ // Default value must be scalar or vector.
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
+ c->set_output(0, c->UnknownShape());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Looks up keys in a table, outputs the corresponding values.
+
+The tensor `keys` must of the same type as the keys of the table.
+The output `values` is of the type of the table values.
+
+The scalar `default_value` is the value output for keys not present in the
+table. It must also be of the same type as the table values.
+
+table_handle: Handle to the table.
+keys: Any shape. Keys to look up.
+values: Same shape as `keys`. Values found in the table, or `default_values`
+ for missing keys.
+)doc");
+
+REGISTER_OP("LookupTableFindV2")
+ .Input("table_handle: resource")
+ .Input("keys: Tin")
+ .Input("default_value: Tout")
+ .Output("values: Tout")
+ .Attr("Tin: type")
+ .Attr("Tout: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+ // Default value must be scalar or vector.
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
+ c->set_output(0, c->UnknownShape());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Looks up keys in a table, outputs the corresponding values.
+
+The tensor `keys` must of the same type as the keys of the table.
+The output `values` is of the type of the table values.
+
+The scalar `default_value` is the value output for keys not present in the
+table. It must also be of the same type as the table values.
+
+table_handle: Handle to the table.
+keys: Any shape. Keys to look up.
+values: Same shape as `keys`. Values found in the table, or `default_values`
+ for missing keys.
+)doc");
+
+REGISTER_OP("LookupTableInsert")
+ .Input("table_handle: Ref(string)")
+ .Input("keys: Tin")
+ .Input("values: Tout")
+ .Attr("Tin: type")
+ .Attr("Tout: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
+
+ // TODO(ebrevdo): Validate keys and values shape.
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Updates the table to associates keys with values.
+
+The tensor `keys` must be of the same type as the keys of the table.
+The tensor `values` must be of the type of the table values.
+
+table_handle: Handle to the table.
+keys: Any shape. Keys to look up.
+values: Values to associate with keys.
+)doc");
+
+REGISTER_OP("LookupTableInsertV2")
+ .Input("table_handle: resource")
+ .Input("keys: Tin")
+ .Input("values: Tout")
+ .Attr("Tin: type")
+ .Attr("Tout: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+ // TODO: Validate keys and values shape.
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Updates the table to associates keys with values.
+
+The tensor `keys` must be of the same type as the keys of the table.
+The tensor `values` must be of the type of the table values.
+
+table_handle: Handle to the table.
+keys: Any shape. Keys to look up.
+values: Values to associate with keys.
+)doc");
+
+REGISTER_OP("LookupTableSize")
+ .Input("table_handle: Ref(string)")
+ .Output("size: int64")
+ .SetShapeFn(TwoElementVectorInputsAndScalarOutputs)
+ .Doc(R"doc(
+Computes the number of elements in the given table.
+
+table_handle: Handle to the table.
+size: Scalar that contains number of elements in the table.
+)doc");
+
+REGISTER_OP("LookupTableSizeV2")
+ .Input("table_handle: resource")
+ .Output("size: int64")
+ .SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs)
+ .Doc(R"doc(
+Computes the number of elements in the given table.
+
+table_handle: Handle to the table.
+size: Scalar that contains number of elements in the table.
+)doc");
+
+REGISTER_OP("LookupTableExport")
+ .Input("table_handle: Ref(string)")
+ .Output("keys: Tkeys")
+ .Output("values: Tvalues")
+ .Attr("Tkeys: type")
+ .Attr("Tvalues: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
+
+ ShapeHandle values = c->UnknownShape();
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
+ ShapeHandle keys = c->Vector(c->Dim(values, 0));
+ c->set_output(0, keys);
+ c->set_output(1, values);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Outputs all keys and values in the table.
+
+table_handle: Handle to the table.
+keys: Vector of all keys present in the table.
+values: Tensor of all values in the table. Indexed in parallel with `keys`.
+)doc");
+
+REGISTER_OP("LookupTableExportV2")
+ .Input("table_handle: resource")
+ .Output("keys: Tkeys")
+ .Output("values: Tvalues")
+ .Attr("Tkeys: type")
+ .Attr("Tvalues: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+ ShapeHandle values = c->UnknownShape();
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
+ ShapeHandle keys = c->Vector(c->Dim(values, 0));
+ c->set_output(0, keys);
+ c->set_output(1, values);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Outputs all keys and values in the table.
+
+table_handle: Handle to the table.
+keys: Vector of all keys present in the table.
+values: Tensor of all values in the table. Indexed in parallel with `keys`.
+)doc");
+
+REGISTER_OP("LookupTableImport")
+ .Input("table_handle: Ref(string)")
+ .Input("keys: Tin")
+ .Input("values: Tout")
+ .Attr("Tin: type")
+ .Attr("Tout: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
+
+ // TODO(ebrevdo): Validate keys and values shape.
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Replaces the contents of the table with the specified keys and values.
+
+The tensor `keys` must be of the same type as the keys of the table.
+The tensor `values` must be of the type of the table values.
+
+table_handle: Handle to the table.
+keys: Any shape. Keys to look up.
+values: Values to associate with keys.
+)doc");
+
+REGISTER_OP("LookupTableImportV2")
+ .Input("table_handle: resource")
+ .Input("keys: Tin")
+ .Input("values: Tout")
+ .Attr("Tin: type")
+ .Attr("Tout: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+ // TODO: Validate keys and values shape.
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Replaces the contents of the table with the specified keys and values.
+
+The tensor `keys` must be of the same type as the keys of the table.
+The tensor `values` must be of the type of the table values.
+
+table_handle: Handle to the table.
+keys: Any shape. Keys to look up.
+values: Values to associate with keys.
+)doc");
+
+REGISTER_OP("HashTable")
+ .Output("table_handle: Ref(string)")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .SetIsStateful()
+ .SetShapeFn(TwoElementOutput)
+ .Doc(R"doc(
+Creates a non-initialized hash table.
+
+This op creates a hash table, specifying the type of its keys and values.
+Before using the table you will have to initialize it. After initialization the
+table will be immutable.
+
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+use_node_name_sharing: If true and shared_name is empty, the table is shared
+ using the node name.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+)doc");
+
+REGISTER_OP("HashTableV2")
+ .Output("table_handle: resource")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .SetIsStateful()
+ .SetShapeFn(ScalarOutput)
+ .Doc(R"doc(
+Creates a non-initialized hash table.
+
+This op creates a hash table, specifying the type of its keys and values.
+Before using the table you will have to initialize it. After initialization the
+table will be immutable.
+
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+use_node_name_sharing: If true and shared_name is empty, the table is shared
+ using the node name.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+)doc");
+
+REGISTER_OP("MutableHashTable")
+ .Output("table_handle: Ref(string)")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .SetIsStateful()
+ .SetShapeFn(TwoElementOutput)
+ .Doc(R"doc(
+Creates an empty hash table.
+
+This op creates a mutable hash table, specifying the type of its keys and
+values. Each value must be a scalar. Data can be inserted into the table using
+the insert operations. It does not support the initialization operation.
+
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+use_node_name_sharing: If true and shared_name is empty, the table is shared
+ using the node name.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+)doc");
+
+REGISTER_OP("MutableHashTableV2")
+ .Output("table_handle: resource")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .SetIsStateful()
+ .SetShapeFn(ScalarOutput)
+ .Doc(R"doc(
+Creates an empty hash table.
+
+This op creates a mutable hash table, specifying the type of its keys and
+values. Each value must be a scalar. Data can be inserted into the table using
+the insert operations. It does not support the initialization operation.
+
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+use_node_name_sharing: If true and shared_name is empty, the table is shared
+ using the node name.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+)doc");
+
+REGISTER_OP("MutableHashTableOfTensors")
+ .Output("table_handle: Ref(string)")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .Attr("value_shape: shape = {}")
+ .SetIsStateful()
+ .SetShapeFn(TwoElementOutput)
+ .Doc(R"doc(
+Creates an empty hash table.
+
+This op creates a mutable hash table, specifying the type of its keys and
+values. Each value must be a vector. Data can be inserted into the table using
+the insert operations. It does not support the initialization operation.
+
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+)doc");
+
+REGISTER_OP("MutableHashTableOfTensorsV2")
+ .Output("table_handle: resource")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .Attr("value_shape: shape = {}")
+ .SetIsStateful()
+ .SetShapeFn(ScalarOutput)
+ .Doc(R"doc(
+Creates an empty hash table.
+
+This op creates a mutable hash table, specifying the type of its keys and
+values. Each value must be a vector. Data can be inserted into the table using
+the insert operations. It does not support the initialization operation.
+
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+)doc");
+
+REGISTER_OP("MutableDenseHashTable")
+ .Input("empty_key: key_dtype")
+ .Output("table_handle: Ref(string)")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .Attr("value_shape: shape = {}")
+ .Attr("initial_num_buckets: int = 131072") // 2^17
+ .Attr("max_load_factor: float = 0.8")
+ .SetIsStateful()
+ .SetShapeFn(TwoElementOutput)
+ .Doc(R"doc(
+Creates an empty hash table that uses tensors as the backing store. It uses
+"open addressing" with quadratic reprobing to resolve collisions.
+
+This op creates a mutable hash table, specifying the type of its keys and
+values. Each value must be a scalar. Data can be inserted into the table using
+the insert operations. It does not support the initialization operation.
+
+empty_key: The key used to represent empty key buckets internally. Must not
+ be used in insert or lookup operations.
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+value_shape: The shape of each value.
+initial_num_buckets: The initial number of hash table buckets. Must be a power
+ to 2.
+max_load_factor: The maximum ratio between number of entries and number of
+ buckets before growing the table. Must be between 0 and 1.
+)doc");
+
+REGISTER_OP("MutableDenseHashTableV2")
+ .Input("empty_key: key_dtype")
+ .Output("table_handle: resource")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .Attr("value_shape: shape = {}")
+ .Attr("initial_num_buckets: int = 131072") // 2^17
+ .Attr("max_load_factor: float = 0.8")
+ .SetIsStateful()
+ .SetShapeFn(ScalarOutput)
+ .Doc(R"doc(
+Creates an empty hash table that uses tensors as the backing store. It uses
+"open addressing" with quadratic reprobing to resolve collisions.
+
+This op creates a mutable hash table, specifying the type of its keys and
+values. Each value must be a scalar. Data can be inserted into the table using
+the insert operations. It does not support the initialization operation.
+
+empty_key: The key used to represent empty key buckets internally. Must not
+ be used in insert or lookup operations.
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+value_shape: The shape of each value.
+initial_num_buckets: The initial number of hash table buckets. Must be a power
+ to 2.
+max_load_factor: The maximum ratio between number of entries and number of
+ buckets before growing the table. Must be between 0 and 1.
+)doc");
+
+REGISTER_OP("InitializeTable")
+ .Input("table_handle: Ref(string)")
+ .Input("keys: Tkey")
+ .Input("values: Tval")
+ .Attr("Tkey: type")
+ .Attr("Tval: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
+
+ ShapeHandle keys;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
+ TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Table initializer that takes two tensors for keys and values respectively.
+
+table_handle: Handle to a table which will be initialized.
+keys: Keys of type Tkey.
+values: Values of type Tval.
+)doc");
+
+REGISTER_OP("InitializeTableV2")
+ .Input("table_handle: resource")
+ .Input("keys: Tkey")
+ .Input("values: Tval")
+ .Attr("Tkey: type")
+ .Attr("Tval: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+ ShapeHandle keys;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
+ TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Table initializer that takes two tensors for keys and values respectively.
+
+table_handle: Handle to a table which will be initialized.
+keys: Keys of type Tkey.
+values: Values of type Tval.
+)doc");
+
+REGISTER_OP("InitializeTableFromTextFile")
+ .Input("table_handle: Ref(string)")
+ .Input("filename: string")
+ .Attr("key_index: int >= -2")
+ .Attr("value_index: int >= -2")
+ .Attr("vocab_size: int >= -1 = -1")
+ .Attr("delimiter: string = '\t'")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
+
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Initializes a table from a text file.
+
+It inserts one key-value pair into the table for each line of the file.
+The key and value is extracted from the whole line content, elements from the
+split line based on `delimiter` or the line number (starting from zero).
+Where to extract the key and value from a line is specified by `key_index` and
+`value_index`.
+
+- A value of -1 means use the line number(starting from zero), expects `int64`.
+- A value of -2 means use the whole line content, expects `string`.
+- A value >= 0 means use the index (starting at zero) of the split line based
+ on `delimiter`.
+
+table_handle: Handle to a table which will be initialized.
+filename: Filename of a vocabulary text file.
+key_index: Column index in a line to get the table `key` values from.
+value_index: Column index that represents information of a line to get the table
+ `value` values from.
+vocab_size: Number of elements of the file, use -1 if unknown.
+delimiter: Delimiter to separate fields in a line.
+)doc");
+
+REGISTER_OP("InitializeTableFromTextFileV2")
+ .Input("table_handle: resource")
+ .Input("filename: string")
+ .Attr("key_index: int >= -2")
+ .Attr("value_index: int >= -2")
+ .Attr("vocab_size: int >= -1 = -1")
+ .Attr("delimiter: string = '\t'")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Initializes a table from a text file.
+
+It inserts one key-value pair into the table for each line of the file.
+The key and value is extracted from the whole line content, elements from the
+split line based on `delimiter` or the line number (starting from zero).
+Where to extract the key and value from a line is specified by `key_index` and
+`value_index`.
+
+- A value of -1 means use the line number(starting from zero), expects `int64`.
+- A value of -2 means use the whole line content, expects `string`.
+- A value >= 0 means use the index (starting at zero) of the split line based
+ on `delimiter`.
+
+table_handle: Handle to a table which will be initialized.
+filename: Filename of a vocabulary text file.
+key_index: Column index in a line to get the table `key` values from.
+value_index: Column index that represents information of a line to get the table
+ `value` values from.
+vocab_size: Number of elements of the file, use -1 if unknown.
+delimiter: Delimiter to separate fields in a line.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/protobuf/cluster.proto b/tensorflow/core/protobuf/cluster.proto
new file mode 100644
index 0000000000..33c87eefe0
--- /dev/null
+++ b/tensorflow/core/protobuf/cluster.proto
@@ -0,0 +1,82 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "ClusterProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.distruntime";
+
+// This file contains protos to be used when defining a TensorFlow
+// cluster.
+//
+// EXAMPLES
+// --------
+//
+// 1. A single-process cluster, containing "/job:local/task:0".
+//
+// Cluster:
+// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } }
+//
+// Server:
+// cluster { $CLUSTER } job_name: 'local' task_index: 0
+//
+// 2. A two-process cluster, containing "/job:local/task:{0,1}".
+//
+// Cluster:
+// job { name: 'local' tasks { key: 0 value: 'localhost:2222' }
+// tasks { key: 1 value: 'localhost:2223' } }
+//
+// Servers:
+// cluster { $CLUSTER } job_name: 'local' task_index: 0
+// cluster { $CLUSTER } job_name: 'local' task_index: 1
+//
+// 3. A two-job cluster, containing "/job:worker/task:{0,1,2}" and
+// "/job:ps/task:{0,1}".
+//
+// Cluster:
+// job { name: 'worker' tasks { key: 0 value: 'worker1:2222' }
+// tasks { key: 1 value: 'worker2:2222' }
+// tasks { key: 2 value: 'worker3:2222' } }
+// job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
+// tasks { key: 1 value: 'ps1:2222' } }
+//
+// Servers:
+// cluster { $CLUSTER } job_name: 'worker' task_index: 0
+// cluster { $CLUSTER } job_name: 'worker' task_index: 1
+// cluster { $CLUSTER } job_name: 'worker' task_index: 2
+// cluster { $CLUSTER } job_name: 'ps' task_index: 0
+// cluster { $CLUSTER } job_name: 'ps' task_index: 1
+
+// Defines a single job in a TensorFlow cluster.
+message JobDef {
+ // The name of this job.
+ string name = 1;
+
+ // Mapping from task ID to "hostname:port" string.
+ //
+ // If the `name` field contains "worker", and the `tasks` map contains a
+ // mapping from 7 to "example.org:2222", then the device prefix
+ // "/job:worker/task:7" will be assigned to "example.org:2222".
+ map<int32, string> tasks = 2;
+}
+
+// Defines a TensorFlow cluster as a set of jobs.
+message ClusterDef {
+ // The jobs that comprise the cluster.
+ repeated JobDef job = 1;
+}
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index 5c0f7232eb..630f47633f 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -10,6 +10,7 @@ import "tensorflow/core/framework/cost_graph.proto";
import "tensorflow/core/framework/graph.proto";
import "tensorflow/core/framework/step_stats.proto";
import "tensorflow/core/protobuf/debug.proto";
+import "tensorflow/core/protobuf/cluster.proto";
import "tensorflow/core/protobuf/rewriter_config.proto";
message GPUOptions {
@@ -259,6 +260,11 @@ message ConfigProto {
// Options that apply when this session uses the distributed runtime.
RPCOptions rpc_options = 13;
+
+ // Optional list of all workers to use in this session.
+ ClusterDef cluster_def = 14;
+
+ // Next: 15
};
// Options for a single Run() call.
diff --git a/tensorflow/core/protobuf/master.proto b/tensorflow/core/protobuf/master.proto
index de91b6133e..e607b1c42a 100644
--- a/tensorflow/core/protobuf/master.proto
+++ b/tensorflow/core/protobuf/master.proto
@@ -38,6 +38,9 @@ message CreateSessionRequest {
// Configuration options.
ConfigProto config = 2;
+
+ // The target string used from the client's perspective.
+ string target = 3;
}
message CreateSessionResponse {
diff --git a/tensorflow/core/protobuf/tensorflow_server.proto b/tensorflow/core/protobuf/tensorflow_server.proto
index c4077bd98e..6199e707e5 100644
--- a/tensorflow/core/protobuf/tensorflow_server.proto
+++ b/tensorflow/core/protobuf/tensorflow_server.proto
@@ -16,6 +16,7 @@ limitations under the License.
syntax = "proto3";
import "tensorflow/core/protobuf/config.proto";
+import "tensorflow/core/protobuf/cluster.proto";
package tensorflow;
option cc_enable_arenas = true;
@@ -23,69 +24,6 @@ option java_outer_classname = "ServerProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.distruntime";
-// This file contains protos to be used when defining a TensorFlow
-// cluster, and a server within that cluster.
-//
-// EXAMPLES
-// --------
-//
-// 1. A single-process cluster, containing "/job:local/task:0".
-//
-// Cluster:
-// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } }
-//
-// Server:
-// cluster { $CLUSTER } job_name: 'local' task_index: 0
-//
-// 2. A two-process cluster, containing "/job:local/task:{0,1}".
-//
-// Cluster:
-// job { name: 'local' tasks { key: 0 value: 'localhost:2222' }
-// tasks { key: 1 value: 'localhost:2223' } }
-//
-// Servers:
-// cluster { $CLUSTER } job_name: 'local' task_index: 0
-// cluster { $CLUSTER } job_name: 'local' task_index: 1
-//
-// 3. A two-job cluster, containing "/job:worker/task:{0,1,2}" and
-// "/job:ps/task:{0,1}".
-//
-// Cluster:
-// job { name: 'worker' tasks { key: 0 value: 'worker1:2222' }
-// tasks { key: 1 value: 'worker2:2222' }
-// tasks { key: 2 value: 'worker3:2222' } }
-// job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
-// tasks { key: 1 value: 'ps1:2222' } }
-//
-// Servers:
-// cluster { $CLUSTER } job_name: 'worker' task_index: 0
-// cluster { $CLUSTER } job_name: 'worker' task_index: 1
-// cluster { $CLUSTER } job_name: 'worker' task_index: 2
-// cluster { $CLUSTER } job_name: 'ps' task_index: 0
-// cluster { $CLUSTER } job_name: 'ps' task_index: 1
-
-// Defines a single job in a TensorFlow cluster.
-message JobDef {
- // The name of this job.
- string name = 1;
-
- // Mapping from task ID to "hostname:port" string.
- //
- // If the `name` field contains "worker", and the `tasks` map contains a
- // mapping from 7 to "example.org:2222", then the device prefix
- // "/job:worker/task:7" will be assigned to "example.org:2222".
- //
- // NOTE(mrry): Currently, only a dense task ID space starting at 0 is
- // supported.
- map<int32, string> tasks = 2;
-}
-
-// Defines a TensorFlow cluster as a set of jobs.
-message ClusterDef {
- // The jobs that comprise the cluster.
- repeated JobDef job = 1;
-}
-
// Defines the configuration of a single TensorFlow server.
message ServerDef {
// The cluster of which this server is a member.
diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto
index 661327847c..cf05aece39 100644
--- a/tensorflow/core/protobuf/worker.proto
+++ b/tensorflow/core/protobuf/worker.proto
@@ -119,6 +119,10 @@ message RegisterGraphResponse {
////////////////////////////////////////////////////////////////////////////////
message DeregisterGraphRequest {
+ // The session_handle used when registering the graph. If session_handle is
+ // empty, a single global namespace is used.
+ string session_handle = 2;
+
// REQUIRED: graph_handle must be returned by a RegisterGraph call
// to the same WorkerService.
string graph_handle = 1;
@@ -167,6 +171,12 @@ message ExecutorOpts {
};
message RunGraphRequest {
+ // session_handle is the the master-generated unique id for this session.
+ // If session_handle is non-empty, it must be the same as used when
+ // registering the graph. If it is empty, a single global namespace is used to
+ // search for the graph_handle.
+ string session_handle = 8;
+
// REQUIRED: graph_handle must be returned by a RegisterGraph call
// to the same WorkerService.
string graph_handle = 1;
@@ -193,6 +203,8 @@ message RunGraphRequest {
bool is_partial = 6;
// True if this is the last partial run request in a sequence of requests.
bool is_last_partial_run = 7;
+
+ // Next: 9
}
message RunGraphResponse {
diff --git a/tensorflow/docs_src/get_started/get_started.md b/tensorflow/docs_src/get_started/get_started.md
index b52adc3790..00cc10cd34 100644
--- a/tensorflow/docs_src/get_started/get_started.md
+++ b/tensorflow/docs_src/get_started/get_started.md
@@ -372,25 +372,36 @@ features = [tf.contrib.layers.real_valued_column("x", dimension=1)]
estimator = tf.contrib.learn.LinearRegressor(feature_columns=features)
# TensorFlow provides many helper methods to read and set up data sets.
-# Here we use `numpy_input_fn`. We have to tell the function how many batches
+# Here we use two data sets: one for training and one for evaluation
+# We have to tell the function how many batches
# of data (num_epochs) we want and how big each batch should be.
-x = np.array([1., 2., 3., 4.])
-y = np.array([0., -1., -2., -3.])
-input_fn = tf.contrib.learn.io.numpy_input_fn({"x":x}, y, batch_size=4,
+x_train = np.array([1., 2., 3., 4.])
+y_train = np.array([0., -1., -2., -3.])
+x_eval = np.array([2., 5., 8., 1.])
+y_eval = np.array([-1.01, -4.1, -7, 0.])
+input_fn = tf.contrib.learn.io.numpy_input_fn({"x":x_train}, y_train,
+ batch_size=4,
num_epochs=1000)
+eval_input_fn = tf.contrib.learn.io.numpy_input_fn(
+ {"x":x_eval}, y_eval, batch_size=4, num_epochs=1000)
-# We can invoke 1000 training steps by invoking the `fit` method and passing the
+# We can invoke 1000 training steps by invoking the method and passing the
# training data set.
estimator.fit(input_fn=input_fn, steps=1000)
-# Here we evaluate how well our model did. In a real example, we would want
-# to use a separate validation and testing data set to avoid overfitting.
-print(estimator.evaluate(input_fn=input_fn))
+# Here we evaluate how well our model did.
+train_loss = estimator.evaluate(input_fn=input_fn)
+eval_loss = estimator.evaluate(input_fn=eval_input_fn)
+print("train loss: %r"% train_loss)
+print("eval loss: %r"% eval_loss)
```
When run, it produces
```
- {'global_step': 1000, 'loss': 1.9650059e-11}
+ train loss: {'global_step': 1000, 'loss': 4.3049088e-08}
+ eval loss: {'global_step': 1000, 'loss': 0.0025487561}
```
+Notice how our eval data has a higher loss, but it is still close to zero.
+That means we are learning properly.
### A custom model
@@ -432,19 +443,25 @@ def model(features, labels, mode):
train_op=train)
estimator = tf.contrib.learn.Estimator(model_fn=model)
-# define our data set
-x = np.array([1., 2., 3., 4.])
-y = np.array([0., -1., -2., -3.])
-input_fn = tf.contrib.learn.io.numpy_input_fn({"x": x}, y, 4, num_epochs=1000)
+# define our data sets
+x_train = np.array([1., 2., 3., 4.])
+y_train = np.array([0., -1., -2., -3.])
+x_eval = np.array([2., 5., 8., 1.])
+y_eval = np.array([-1.01, -4.1, -7, 0.])
+input_fn = tf.contrib.learn.io.numpy_input_fn({"x": x_train}, y_train, 4, num_epochs=1000)
# train
estimator.fit(input_fn=input_fn, steps=1000)
-# evaluate our model
-print(estimator.evaluate(input_fn=input_fn, steps=10))
+# Here we evaluate how well our model did.
+train_loss = estimator.evaluate(input_fn=input_fn)
+eval_loss = estimator.evaluate(input_fn=eval_input_fn)
+print("train loss: %r"% train_loss)
+print("eval loss: %r"% eval_loss)
```
When run, it produces
-```python
-{'loss': 5.9819476e-11, 'global_step': 1000}
+```
+train loss: {'global_step': 1000, 'loss': 4.9380226e-11}
+eval loss: {'global_step': 1000, 'loss': 0.01010081}
```
Notice how the contents of the custom `model()` function are very similar
diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md
index 5304779c00..55d9c2c08f 100644
--- a/tensorflow/docs_src/install/install_java.md
+++ b/tensorflow/docs_src/install/install_java.md
@@ -218,11 +218,7 @@ and Mac OS X:
And the following comand line executes the `HelloTF` program on Windows:
-<pre><b>java -cp libtensorflow-1.1.0-rc2.jar;. -Djava.library.path=jni HelloTF</b></pre>
-
-And the following comand line executes the `HelloTF` program on Windows:
-
-<pre><b>java -cp libtensorflow-1.1.0-rc2.jar;. -Djava.library.path=jni HelloTF</b></pre>
+<pre><b>java -cp libtensorflow-1.1.0.jar;. -Djava.library.path=jni HelloTF</b></pre>
If the program prints <tt>Hello from <i>version</i></tt>, you've successfully
installed TensorFlow for Java and are ready to use the API. If the program
diff --git a/tensorflow/docs_src/programmers_guide/index.md b/tensorflow/docs_src/programmers_guide/index.md
index 309b39451f..acdca2bad4 100644
--- a/tensorflow/docs_src/programmers_guide/index.md
+++ b/tensorflow/docs_src/programmers_guide/index.md
@@ -39,6 +39,11 @@ trained graph. The following guide details `MetaGraph` objects:
* @{$meta_graph$Exporting and Importing a MetaGraph}.
+`SavedModel` is the universal serialization format for Tensorflow models. TensorFlow provides SavedModel CLI (command-line interface) as a tool to inspect and execute a MetaGraph in a SavedModel. The detailed usages and examples are
+documented in the following guide:
+
+ * @{$saved_model_cli$SavedModel CLI (Command-Line Interface)}.
+
To learn about the TensorFlow versioning scheme, consult the following two
guides:
diff --git a/tensorflow/docs_src/programmers_guide/supervisor.md b/tensorflow/docs_src/programmers_guide/supervisor.md
index 82ed1c2cf7..55a090df58 100644
--- a/tensorflow/docs_src/programmers_guide/supervisor.md
+++ b/tensorflow/docs_src/programmers_guide/supervisor.md
@@ -362,8 +362,8 @@ following keyword arguments to the `Supervisor()` constructor:
If not specified, the supervisor uses the first op in the
`tf.GraphKeys.LOCAL_INIT_OP` collection. If the collection is empty the
supervisor adds an op to initialize all the tables and local variables in
- the graph by calling `tf.initialize_all_tables()` and
- `tf.initialize_all_local_variables()`.
+ the graph by calling `tf.tables_initializer()` and
+ `tf.local_variables_initializer()`.
Pass `None` to not use a local init op.
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java
index b26a231678..bc39126925 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java
@@ -194,13 +194,12 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab
yuvBytes[0],
yuvBytes[1],
yuvBytes[2],
- rgbBytes,
previewWidth,
previewHeight,
yRowStride,
uvRowStride,
uvPixelStride,
- false);
+ rgbBytes);
image.close();
} catch (final Exception e) {
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java
index 206a99f3e3..5800f80651 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java
@@ -124,7 +124,7 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
borderedText = new BorderedText(textSizePx);
borderedText.setTypeface(Typeface.MONOSPACE);
- tracker = new MultiBoxTracker(getResources().getDisplayMetrics());
+ tracker = new MultiBoxTracker(this);
if (USE_YOLO) {
detector =
@@ -273,13 +273,12 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
yuvBytes[0],
yuvBytes[1],
yuvBytes[2],
- rgbBytes,
previewWidth,
previewHeight,
yRowStride,
uvRowStride,
uvPixelStride,
- false);
+ rgbBytes);
image.close();
} catch (final Exception e) {
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java
index 7634be5c02..7afe2bf541 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java
@@ -65,10 +65,6 @@ import org.tensorflow.demo.R;
* Artistic Style" (https://arxiv.org/abs/1610.07629)
*/
public class StylizeActivity extends CameraActivity implements OnImageAvailableListener {
- static {
- System.loadLibrary("tensorflow_demo");
- }
-
private static final Logger LOGGER = new Logger();
private static final String MODEL_FILE = "file:///android_asset/stylize_quantized.pb";
@@ -509,17 +505,17 @@ public class StylizeActivity extends CameraActivity implements OnImageAvailableL
final int yRowStride = planes[0].getRowStride();
final int uvRowStride = planes[1].getRowStride();
final int uvPixelStride = planes[1].getPixelStride();
+
ImageUtils.convertYUV420ToARGB8888(
yuvBytes[0],
yuvBytes[1],
yuvBytes[2],
- rgbBytes,
previewWidth,
previewHeight,
yRowStride,
uvRowStride,
uvPixelStride,
- false);
+ rgbBytes);
image.close();
} catch (final Exception e) {
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java
index f3e7114335..1dcf9f55ef 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java
@@ -41,10 +41,6 @@ import org.tensorflow.demo.env.Logger;
public class TensorFlowMultiBoxDetector implements Classifier {
private static final Logger LOGGER = new Logger();
- static {
- System.loadLibrary("tensorflow_demo");
- }
-
// Only return this many results with at least this confidence.
private static final int MAX_RESULTS = Integer.MAX_VALUE;
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java
index 174723071d..b7e36a2379 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java
@@ -31,10 +31,6 @@ import org.tensorflow.demo.env.SplitTimer;
public class TensorFlowYoloDetector implements Classifier {
private static final Logger LOGGER = new Logger();
- static {
- System.loadLibrary("tensorflow_demo");
- }
-
// Only return this many results with at least this confidence.
private static final int MAX_RESULTS = 5;
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java b/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java
index db929e5e08..5f2ff9164c 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java
@@ -27,6 +27,14 @@ import java.io.FileOutputStream;
public class ImageUtils {
@SuppressWarnings("unused")
private static final Logger LOGGER = new Logger();
+
+ static {
+ try {
+ System.loadLibrary("tensorflow_demo");
+ } catch (UnsatisfiedLinkError e) {
+ LOGGER.w("Native library not found, native RGB -> YUV conversion may be unavailable.");
+ }
+ }
/**
* Utility method to compute the allocated size in bytes of a YUV420SP image
@@ -83,10 +91,84 @@ public class ImageUtils {
}
}
+ // This value is 2 ^ 18 - 1, and is used to clamp the RGB values before their ranges
+ // are normalized to eight bits.
+ static final int kMaxChannelValue = 262143;
+
+ // Always prefer the native implementation if available.
+ private static boolean useNativeConversion = true;
+
+ public static void convertYUV420ToARGB8888(
+ byte[] yData,
+ byte[] uData,
+ byte[] vData,
+ int width,
+ int height,
+ int yRowStride,
+ int uvRowStride,
+ int uvPixelStride,
+ int[] out) {
+ if (useNativeConversion) {
+ try {
+ convertYUV420ToARGB8888(
+ yData, uData, vData, out, width, height, yRowStride, uvRowStride, uvPixelStride, false);
+ return;
+ } catch (UnsatisfiedLinkError e) {
+ LOGGER.w("Native YUV -> RGB implementation not found, falling back to Java implementation");
+ useNativeConversion = false;
+ }
+ }
+
+ int i = 0;
+ for (int y = 0; y < height; y++) {
+ int pY = yRowStride * y;
+ int uv_row_start = uvRowStride * (y >> 1);
+ int pUV = uv_row_start;
+ int pV = uv_row_start;
+
+ for (int x = 0; x < width; x++) {
+ int uv_offset = pUV + (x >> 1) * uvPixelStride;
+ out[i++] =
+ YUV2RGB(
+ convertByteToInt(yData, pY + x),
+ convertByteToInt(uData, uv_offset),
+ convertByteToInt(vData, uv_offset));
+ }
+ }
+ }
+
+ private static int convertByteToInt(byte[] arr, int pos) {
+ return arr[pos] & 0xFF;
+ }
+
+ private static int YUV2RGB(int nY, int nU, int nV) {
+ nY -= 16;
+ nU -= 128;
+ nV -= 128;
+ if (nY < 0) nY = 0;
+
+ // This is the floating point equivalent. We do the conversion in integer
+ // because some Android devices do not have floating point in hardware.
+ // nR = (int)(1.164 * nY + 2.018 * nU);
+ // nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU);
+ // nB = (int)(1.164 * nY + 1.596 * nV);
+
+ final int foo = 1192 * nY;
+ int nR = foo + 1634 * nV;
+ int nG = foo - 833 * nV - 400 * nU;
+ int nB = foo + 2066 * nU;
+
+ nR = Math.min(kMaxChannelValue, Math.max(0, nR));
+ nG = Math.min(kMaxChannelValue, Math.max(0, nG));
+ nB = Math.min(kMaxChannelValue, Math.max(0, nB));
+
+ return 0xff000000 | ((nR << 6) & 0x00ff0000) | ((nG >> 2) & 0x0000FF00) | ((nB >> 10) & 0xff);
+ }
+
/**
- * Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width
- * and height. The input and output must already be allocated and non-null.
- * For efficiency, no error checking is performed.
+ * Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width and height. The
+ * input and output must already be allocated and non-null. For efficiency, no error checking is
+ * performed.
*
* @param input The array of YUV 4:2:0 input data.
* @param output A pre-allocated array for the ARGB 8:8:8:8 output data.
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java
index 49c91d600d..91d1f9feb1 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java
@@ -15,6 +15,7 @@ limitations under the License.
package org.tensorflow.demo.tracking;
+import android.content.Context;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
@@ -24,9 +25,9 @@ import android.graphics.Paint.Join;
import android.graphics.Paint.Style;
import android.graphics.RectF;
import android.text.TextUtils;
-import android.util.DisplayMetrics;
import android.util.Pair;
import android.util.TypedValue;
+import android.widget.Toast;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
@@ -69,6 +70,7 @@ public class MultiBoxTracker {
private static class TrackedRecognition {
ObjectTracker.TrackedObject trackedObject;
+ RectF location;
float detectionConfidence;
int color;
String title;
@@ -87,8 +89,10 @@ public class MultiBoxTracker {
private int frameHeight;
private int sensorOrientation;
+ private Context context;
- public MultiBoxTracker(final DisplayMetrics metrics) {
+ public MultiBoxTracker(final Context context) {
+ this.context = context;
for (final int color : COLORS) {
availableColors.add(color);
}
@@ -100,7 +104,9 @@ public class MultiBoxTracker {
boxPaint.setStrokeJoin(Join.ROUND);
boxPaint.setStrokeMiter(100);
- textSizePx = TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, metrics);
+ textSizePx =
+ TypedValue.applyDimension(
+ TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, context.getResources().getDisplayMetrics());
borderedText = new BorderedText(textSizePx);
}
@@ -152,10 +158,6 @@ public class MultiBoxTracker {
}
public synchronized void draw(final Canvas canvas) {
- if (objectTracker == null) {
- return;
- }
-
// TODO(andrewharp): This may not work for non-90 deg rotations.
final float multiplier =
Math.min(canvas.getWidth() / (float) frameHeight, canvas.getHeight() / (float) frameWidth);
@@ -168,9 +170,11 @@ public class MultiBoxTracker {
sensorOrientation,
false);
for (final TrackedRecognition recognition : trackedObjects) {
- final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
+ final RectF trackedPos =
+ (objectTracker != null)
+ ? recognition.trackedObject.getTrackedPositionInPreviewFrame()
+ : new RectF(recognition.location);
- final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame();
getFrameToCanvasMatrix().mapRect(trackedPos);
boxPaint.setColor(recognition.color);
@@ -185,6 +189,8 @@ public class MultiBoxTracker {
}
}
+ private boolean initialized = false;
+
public synchronized void onFrame(
final int w,
final int h,
@@ -192,7 +198,7 @@ public class MultiBoxTracker {
final int sensorOrienation,
final byte[] frame,
final long timestamp) {
- if (objectTracker == null) {
+ if (objectTracker == null && !initialized) {
ObjectTracker.clearInstance();
logger.i("Initializing ObjectTracker: %dx%d", w, h);
@@ -200,6 +206,19 @@ public class MultiBoxTracker {
frameWidth = w;
frameHeight = h;
this.sensorOrientation = sensorOrienation;
+ initialized = true;
+
+ if (objectTracker == null) {
+ String message =
+ "Object tracking support not found. "
+ + "See tensorflow/examples/android/README.md for details.";
+ Toast.makeText(context, message, Toast.LENGTH_LONG).show();
+ logger.e(message);
+ }
+ }
+
+ if (objectTracker == null) {
+ return;
}
objectTracker.nextFrame(frame, null, timestamp, null, true);
@@ -255,7 +274,20 @@ public class MultiBoxTracker {
}
if (objectTracker == null) {
- logger.w("No ObjectTracker, can't track anything!");
+ trackedObjects.clear();
+ for (final Pair<Float, Recognition> potential : rectsToTrack) {
+ final TrackedRecognition trackedRecognition = new TrackedRecognition();
+ trackedRecognition.detectionConfidence = potential.first;
+ trackedRecognition.location = new RectF(potential.second.getLocation());
+ trackedRecognition.trackedObject = null;
+ trackedRecognition.title = potential.second.getTitle();
+ trackedRecognition.color = COLORS[trackedObjects.size()];
+ trackedObjects.add(trackedRecognition);
+
+ if (trackedObjects.size() >= COLORS.length) {
+ break;
+ }
+ }
return;
}
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java
index 82de634baf..69f202b568 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java
@@ -48,7 +48,18 @@ import org.tensorflow.demo.env.Size;
* ObjectTracker still exists.
*/
public class ObjectTracker {
- private final Logger logger = new Logger();
+ private static final Logger LOGGER = new Logger();
+
+ private static boolean libraryFound = false;
+
+ static {
+ try {
+ System.loadLibrary("tensorflow_demo");
+ libraryFound = true;
+ } catch (UnsatisfiedLinkError e) {
+ LOGGER.e("libtensorflow_demo.so not found, tracking unavailable");
+ }
+ }
private static final boolean DRAW_TEXT = false;
@@ -194,6 +205,13 @@ public class ObjectTracker {
public static synchronized ObjectTracker getInstance(
final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) {
+ if (!libraryFound) {
+ LOGGER.e(
+ "Native object tracking support not found. "
+ + "See tensorflow/examples/android/README.md for details.");
+ return null;
+ }
+
if (instance == null) {
instance = new ObjectTracker(frameWidth, frameHeight, rowStride, alwaysTrack);
instance.init();
@@ -519,7 +537,7 @@ public class ObjectTracker {
checkValidObject();
synchronized (ObjectTracker.this) {
if (lastExternalPositionTime > timestamp) {
- logger.w("Tried to use older position time!");
+ LOGGER.w("Tried to use older position time!");
return;
}
final RectF externalPosition = downscaleRect(position);
@@ -640,8 +658,4 @@ public class ObjectTracker {
protected static native void downsampleImageNative(
int width, int height, int rowStride, byte[] input, int factor, byte[] output);
-
- static {
- System.loadLibrary("tensorflow_demo");
- }
}
diff --git a/tensorflow/go/README.md b/tensorflow/go/README.md
index e32c21ca72..a1b4255292 100644
--- a/tensorflow/go/README.md
+++ b/tensorflow/go/README.md
@@ -9,24 +9,22 @@ Construct and execute TensorFlow graphs in Go.
> (`github.com/tensorflow/tensorflow/tensorflow/go`).
## Quickstart
-
1. Download and extract the TensorFlow C library, preferably into `/usr/local`.
GPU-enabled versions require CUDA 8.0 and cuDNN 5.1. For other versions, the
TensorFlow C library will have to be built from source (see below).
- Linux:
- [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.0.0.tar.gz),
- [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-1.0.0.tar.gz)
+ [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.1.0.tar.gz),
+ [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-1.1.0.tar.gz)
- OS X
- [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-1.0.0.tar.gz),
- [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-darwin-x86_64-1.0.0.tar.gz)
+ [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-1.1.0.tar.gz),
The following shell snippet downloads and extracts into `/usr/local`:
```sh
TF_TYPE="cpu" # Set to "gpu" for GPU support
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.0.0.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.1.0.tar.gz" |
sudo tar -C /usr/local -xz
```
@@ -41,20 +39,7 @@ Construct and execute TensorFlow graphs in Go.
### Installing into locations other than `/usr/local`
-The TensorFlow C library (`libtensorflow.so`) needs to be available at build
-time (e.g., `go build`) and run time (`go test` or executing binaries). If the
-library has not been extracted into `/usr/local`, then it needs to be made
-available through the `LIBRARY_PATH` environment variable at build time and the
-`LD_LIBRARY_PATH` environment variable (`DYLD_LIBRARY_PATH` on OS X) at run
-time.
-
-For example, if the TensorFlow C library was extracted into `/dir`, then:
-
-```sh
-export LIBRARY_PATH=/dir/lib
-export LD_LIBRARY_PATH=/dir/lib # For Linux
-export DYLD_LIBRARY_PATH=/dir/lib # For OS X
-```
+Refer to [Installing TensorFlow for Go](https://www.tensorflow.org/install/install_go)
## Building the TensorFlow C library from source
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index c63be8bc5e..eb4789a182 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -3522,256 +3522,6 @@ func Stage(scope *Scope, values []tf.Output, optional ...StageAttr) (o *tf.Opera
return scope.AddOperation(opspec)
}
-// Table initializer that takes two tensors for keys and values respectively.
-//
-// Arguments:
-// table_handle: Handle to a table which will be initialized.
-// keys: Keys of type Tkey.
-// values: Values of type Tval.
-//
-// Returns the created operation.
-func InitializeTableV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "InitializeTableV2",
- Input: []tf.Input{
- table_handle, keys, values,
- },
- }
- return scope.AddOperation(opspec)
-}
-
-// MutableHashTableV2Attr is an optional argument to MutableHashTableV2.
-type MutableHashTableV2Attr func(optionalAttr)
-
-// MutableHashTableV2Container sets the optional container attribute to value.
-//
-// value: If non-empty, this table is placed in the given container.
-// Otherwise, a default container is used.
-// If not specified, defaults to ""
-func MutableHashTableV2Container(value string) MutableHashTableV2Attr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// MutableHashTableV2SharedName sets the optional shared_name attribute to value.
-//
-// value: If non-empty, this table is shared under the given name across
-// multiple sessions.
-// If not specified, defaults to ""
-func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value.
-//
-// value: If true and shared_name is empty, the table is shared
-// using the node name.
-// If not specified, defaults to false
-func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr {
- return func(m optionalAttr) {
- m["use_node_name_sharing"] = value
- }
-}
-
-// Creates an empty hash table.
-//
-// This op creates a mutable hash table, specifying the type of its keys and
-// values. Each value must be a scalar. Data can be inserted into the table using
-// the insert operations. It does not support the initialization operation.
-//
-// Arguments:
-// key_dtype: Type of the table keys.
-// value_dtype: Type of the table values.
-//
-// Returns Handle to a table.
-func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "MutableHashTableV2",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// HashTableV2Attr is an optional argument to HashTableV2.
-type HashTableV2Attr func(optionalAttr)
-
-// HashTableV2Container sets the optional container attribute to value.
-//
-// value: If non-empty, this table is placed in the given container.
-// Otherwise, a default container is used.
-// If not specified, defaults to ""
-func HashTableV2Container(value string) HashTableV2Attr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// HashTableV2SharedName sets the optional shared_name attribute to value.
-//
-// value: If non-empty, this table is shared under the given name across
-// multiple sessions.
-// If not specified, defaults to ""
-func HashTableV2SharedName(value string) HashTableV2Attr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// HashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value.
-//
-// value: If true and shared_name is empty, the table is shared
-// using the node name.
-// If not specified, defaults to false
-func HashTableV2UseNodeNameSharing(value bool) HashTableV2Attr {
- return func(m optionalAttr) {
- m["use_node_name_sharing"] = value
- }
-}
-
-// Creates a non-initialized hash table.
-//
-// This op creates a hash table, specifying the type of its keys and values.
-// Before using the table you will have to initialize it. After initialization the
-// table will be immutable.
-//
-// Arguments:
-// key_dtype: Type of the table keys.
-// value_dtype: Type of the table values.
-//
-// Returns Handle to a table.
-func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...HashTableV2Attr) (table_handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "HashTableV2",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Replaces the contents of the table with the specified keys and values.
-//
-// The tensor `keys` must be of the same type as the keys of the table.
-// The tensor `values` must be of the type of the table values.
-//
-// Arguments:
-// table_handle: Handle to the table.
-// keys: Any shape. Keys to look up.
-// values: Values to associate with keys.
-//
-// Returns the created operation.
-func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "LookupTableImportV2",
- Input: []tf.Input{
- table_handle, keys, values,
- },
- }
- return scope.AddOperation(opspec)
-}
-
-// Outputs all keys and values in the table.
-//
-// Arguments:
-// table_handle: Handle to the table.
-//
-//
-//
-// Returns Vector of all keys present in the table.Tensor of all values in the table. Indexed in parallel with `keys`.
-func LookupTableExportV2(scope *Scope, table_handle tf.Output, Tkeys tf.DataType, Tvalues tf.DataType) (keys tf.Output, values tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"Tkeys": Tkeys, "Tvalues": Tvalues}
- opspec := tf.OpSpec{
- Type: "LookupTableExportV2",
- Input: []tf.Input{
- table_handle,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
-// Updates the table to associates keys with values.
-//
-// The tensor `keys` must be of the same type as the keys of the table.
-// The tensor `values` must be of the type of the table values.
-//
-// Arguments:
-// table_handle: Handle to the table.
-// keys: Any shape. Keys to look up.
-// values: Values to associate with keys.
-//
-// Returns the created operation.
-func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "LookupTableInsertV2",
- Input: []tf.Input{
- table_handle, keys, values,
- },
- }
- return scope.AddOperation(opspec)
-}
-
-// Looks up keys in a table, outputs the corresponding values.
-//
-// The tensor `keys` must of the same type as the keys of the table.
-// The output `values` is of the type of the table values.
-//
-// The scalar `default_value` is the value output for keys not present in the
-// table. It must also be of the same type as the table values.
-//
-// Arguments:
-// table_handle: Handle to the table.
-// keys: Any shape. Keys to look up.
-//
-//
-// Returns Same shape as `keys`. Values found in the table, or `default_values`
-// for missing keys.
-func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, default_value tf.Output) (values tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "LookupTableFindV2",
- Input: []tf.Input{
- table_handle, keys, default_value,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// FakeQuantWithMinMaxArgsAttr is an optional argument to FakeQuantWithMinMaxArgs.
type FakeQuantWithMinMaxArgsAttr func(optionalAttr)
@@ -5404,6 +5154,435 @@ func ExtractGlimpse(scope *Scope, input tf.Output, size tf.Output, offsets tf.Ou
return op.Output(0)
}
+// Draw bounding boxes on a batch of images.
+//
+// Outputs a copy of `images` but draws on top of the pixels zero or more bounding
+// boxes specified by the locations in `boxes`. The coordinates of the each
+// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The
+// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
+// height of the underlying image.
+//
+// For example, if an image is 100 x 200 pixels and the bounding box is
+// `[0.1, 0.2, 0.5, 0.9]`, the bottom-left and upper-right coordinates of the
+// bounding box will be `(10, 40)` to `(50, 180)`.
+//
+// Parts of the bounding box may fall outside the image.
+//
+// Arguments:
+// images: 4-D with shape `[batch, height, width, depth]`. A batch of images.
+// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding
+// boxes.
+//
+// Returns 4-D with the same shape as `images`. The batch of input images with
+// bounding boxes drawn on the images.
+func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "DrawBoundingBoxes",
+ Input: []tf.Input{
+ images, boxes,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Convert one or more images from HSV to RGB.
+//
+// Outputs a tensor of the same shape as the `images` tensor, containing the RGB
+// value of the pixels. The output is only well defined if the value in `images`
+// are in `[0,1]`.
+//
+// See `rgb_to_hsv` for a description of the HSV encoding.
+//
+// Arguments:
+// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3.
+//
+// Returns `images` converted to RGB.
+func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "HSVToRGB",
+ Input: []tf.Input{
+ images,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Decode the first frame of a GIF-encoded image to a uint8 tensor.
+//
+// GIF with frame or transparency compression are not supported
+// convert animated GIF from compressed to uncompressed by:
+//
+// convert $src.gif -coalesce $dst.gif
+//
+// Arguments:
+// contents: 0-D. The GIF-encoded image.
+//
+// Returns 4-D with shape `[num_frames, height, width, 3]`. RGB order
+func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "DecodeGif",
+ Input: []tf.Input{
+ contents,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// DecodePngAttr is an optional argument to DecodePng.
+type DecodePngAttr func(optionalAttr)
+
+// DecodePngChannels sets the optional channels attribute to value.
+//
+// value: Number of color channels for the decoded image.
+// If not specified, defaults to 0
+func DecodePngChannels(value int64) DecodePngAttr {
+ return func(m optionalAttr) {
+ m["channels"] = value
+ }
+}
+
+// DecodePngDtype sets the optional dtype attribute to value.
+// If not specified, defaults to DT_UINT8
+func DecodePngDtype(value tf.DataType) DecodePngAttr {
+ return func(m optionalAttr) {
+ m["dtype"] = value
+ }
+}
+
+// Decode a PNG-encoded image to a uint8 or uint16 tensor.
+//
+// The attr `channels` indicates the desired number of color channels for the
+// decoded image.
+//
+// Accepted values are:
+//
+// * 0: Use the number of channels in the PNG-encoded image.
+// * 1: output a grayscale image.
+// * 3: output an RGB image.
+// * 4: output an RGBA image.
+//
+// If needed, the PNG-encoded image is transformed to match the requested number
+// of color channels.
+//
+// Arguments:
+// contents: 0-D. The PNG-encoded image.
+//
+// Returns 3-D with shape `[height, width, channels]`.
+func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (image tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "DecodePng",
+ Input: []tf.Input{
+ contents,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Adjust the contrast of one or more images.
+//
+// `images` is a tensor of at least 3 dimensions. The last 3 dimensions are
+// interpreted as `[height, width, channels]`. The other dimensions only
+// represent a collection of images, such as `[batch, height, width, channels].`
+//
+// Contrast is adjusted independently for each channel of each image.
+//
+// For each channel, the Op first computes the mean of the image pixels in the
+// channel and then adjusts each component of each pixel to
+// `(x - mean) * contrast_factor + mean`.
+//
+// Arguments:
+// images: Images to adjust. At least 3-D.
+// contrast_factor: A float multiplier for adjusting contrast.
+//
+// Returns The contrast-adjusted image or images.
+func AdjustContrastv2(scope *Scope, images tf.Output, contrast_factor tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "AdjustContrastv2",
+ Input: []tf.Input{
+ images, contrast_factor,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// DecodeJpegAttr is an optional argument to DecodeJpeg.
+type DecodeJpegAttr func(optionalAttr)
+
+// DecodeJpegChannels sets the optional channels attribute to value.
+//
+// value: Number of color channels for the decoded image.
+// If not specified, defaults to 0
+func DecodeJpegChannels(value int64) DecodeJpegAttr {
+ return func(m optionalAttr) {
+ m["channels"] = value
+ }
+}
+
+// DecodeJpegRatio sets the optional ratio attribute to value.
+//
+// value: Downscaling ratio.
+// If not specified, defaults to 1
+func DecodeJpegRatio(value int64) DecodeJpegAttr {
+ return func(m optionalAttr) {
+ m["ratio"] = value
+ }
+}
+
+// DecodeJpegFancyUpscaling sets the optional fancy_upscaling attribute to value.
+//
+// value: If true use a slower but nicer upscaling of the
+// chroma planes (yuv420/422 only).
+// If not specified, defaults to true
+func DecodeJpegFancyUpscaling(value bool) DecodeJpegAttr {
+ return func(m optionalAttr) {
+ m["fancy_upscaling"] = value
+ }
+}
+
+// DecodeJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value.
+//
+// value: If true try to recover an image from truncated input.
+// If not specified, defaults to false
+func DecodeJpegTryRecoverTruncated(value bool) DecodeJpegAttr {
+ return func(m optionalAttr) {
+ m["try_recover_truncated"] = value
+ }
+}
+
+// DecodeJpegAcceptableFraction sets the optional acceptable_fraction attribute to value.
+//
+// value: The minimum required fraction of lines before a truncated
+// input is accepted.
+// If not specified, defaults to 1
+func DecodeJpegAcceptableFraction(value float32) DecodeJpegAttr {
+ return func(m optionalAttr) {
+ m["acceptable_fraction"] = value
+ }
+}
+
+// DecodeJpegDctMethod sets the optional dct_method attribute to value.
+//
+// value: string specifying a hint about the algorithm used for
+// decompression. Defaults to "" which maps to a system-specific
+// default. Currently valid values are ["INTEGER_FAST",
+// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal
+// jpeg library changes to a version that does not have that specific
+// option.)
+// If not specified, defaults to ""
+func DecodeJpegDctMethod(value string) DecodeJpegAttr {
+ return func(m optionalAttr) {
+ m["dct_method"] = value
+ }
+}
+
+// Decode a JPEG-encoded image to a uint8 tensor.
+//
+// The attr `channels` indicates the desired number of color channels for the
+// decoded image.
+//
+// Accepted values are:
+//
+// * 0: Use the number of channels in the JPEG-encoded image.
+// * 1: output a grayscale image.
+// * 3: output an RGB image.
+//
+// If needed, the JPEG-encoded image is transformed to match the requested number
+// of color channels.
+//
+// The attr `ratio` allows downscaling the image by an integer factor during
+// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than
+// downscaling the image later.
+//
+// Arguments:
+// contents: 0-D. The JPEG-encoded image.
+//
+// Returns 3-D with shape `[height, width, channels]`..
+func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "DecodeJpeg",
+ Input: []tf.Input{
+ contents,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ResizeNearestNeighborGradAttr is an optional argument to ResizeNearestNeighborGrad.
+type ResizeNearestNeighborGradAttr func(optionalAttr)
+
+// ResizeNearestNeighborGradAlignCorners sets the optional align_corners attribute to value.
+//
+// value: If true, rescale grads by (orig_height - 1) / (height - 1), which
+// exactly aligns the 4 corners of grads and original_image. If false, rescale by
+// orig_height / height. Treat similarly the width dimension.
+// If not specified, defaults to false
+func ResizeNearestNeighborGradAlignCorners(value bool) ResizeNearestNeighborGradAttr {
+ return func(m optionalAttr) {
+ m["align_corners"] = value
+ }
+}
+
+// Computes the gradient of nearest neighbor interpolation.
+//
+// Arguments:
+// grads: 4-D with shape `[batch, height, width, channels]`.
+// size: = A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The
+// original input size.
+//
+// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients
+// with respect to the input image.
+func ResizeNearestNeighborGrad(scope *Scope, grads tf.Output, size tf.Output, optional ...ResizeNearestNeighborGradAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResizeNearestNeighborGrad",
+ Input: []tf.Input{
+ grads, size,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ResizeNearestNeighborAttr is an optional argument to ResizeNearestNeighbor.
+type ResizeNearestNeighborAttr func(optionalAttr)
+
+// ResizeNearestNeighborAlignCorners sets the optional align_corners attribute to value.
+//
+// value: If true, rescale input by (new_height - 1) / (height - 1), which
+// exactly aligns the 4 corners of images and resized images. If false, rescale
+// by new_height / height. Treat similarly the width dimension.
+// If not specified, defaults to false
+func ResizeNearestNeighborAlignCorners(value bool) ResizeNearestNeighborAttr {
+ return func(m optionalAttr) {
+ m["align_corners"] = value
+ }
+}
+
+// Resize `images` to `size` using nearest neighbor interpolation.
+//
+// Arguments:
+// images: 4-D with shape `[batch, height, width, channels]`.
+// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
+// new size for the images.
+//
+// Returns 4-D with shape
+// `[batch, new_height, new_width, channels]`.
+func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeNearestNeighborAttr) (resized_images tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResizeNearestNeighbor",
+ Input: []tf.Input{
+ images, size,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns the set of files matching one or more glob patterns.
+//
+// Note that this routine only supports wildcard characters in the
+// basename portion of the pattern, not in the directory portion.
+//
+// Arguments:
+// pattern: Shell wildcard pattern(s). Scalar or vector of type string.
+//
+// Returns A vector of matching filenames.
+func MatchingFiles(scope *Scope, pattern tf.Output) (filenames tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "MatchingFiles",
+ Input: []tf.Input{
+ pattern,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Shuffle dimensions of x according to a permutation.
+//
+// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
+// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
+func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Transpose",
+ Input: []tf.Input{
+ x, perm,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Reads and outputs the entire contents of the input filename.
+func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ReadFile",
+ Input: []tf.Input{
+ filename,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes softmax cross entropy cost and gradients to backpropagate.
//
// Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept
@@ -6560,6 +6739,95 @@ func Softsign(scope *Scope, features tf.Output) (activations tf.Output) {
return op.Output(0)
}
+// ResizeBilinearAttr is an optional argument to ResizeBilinear.
+type ResizeBilinearAttr func(optionalAttr)
+
+// ResizeBilinearAlignCorners sets the optional align_corners attribute to value.
+//
+// value: If true, rescale input by (new_height - 1) / (height - 1), which
+// exactly aligns the 4 corners of images and resized images. If false, rescale
+// by new_height / height. Treat similarly the width dimension.
+// If not specified, defaults to false
+func ResizeBilinearAlignCorners(value bool) ResizeBilinearAttr {
+ return func(m optionalAttr) {
+ m["align_corners"] = value
+ }
+}
+
+// Resize `images` to `size` using bilinear interpolation.
+//
+// Input images can be of different types but output images are always float.
+//
+// Arguments:
+// images: 4-D with shape `[batch, height, width, channels]`.
+// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
+// new size for the images.
+//
+// Returns 4-D with shape
+// `[batch, new_height, new_width, channels]`.
+func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBilinearAttr) (resized_images tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResizeBilinear",
+ Input: []tf.Input{
+ images, size,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ProdAttr is an optional argument to Prod.
+type ProdAttr func(optionalAttr)
+
+// ProdKeepDims sets the optional keep_dims attribute to value.
+//
+// value: If true, retain reduced dimensions with length 1.
+// If not specified, defaults to false
+func ProdKeepDims(value bool) ProdAttr {
+ return func(m optionalAttr) {
+ m["keep_dims"] = value
+ }
+}
+
+// Computes the product of elements across dimensions of a tensor.
+//
+// Reduces `input` along the dimensions given in `reduction_indices`. Unless
+// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+// `reduction_indices`. If `keep_dims` is true, the reduced dimensions are
+// retained with length 1.
+//
+// Arguments:
+// input: The tensor to reduce.
+// reduction_indices: The dimensions to reduce.
+//
+// Returns The reduced tensor.
+func Prod(scope *Scope, input tf.Output, reduction_indices tf.Output, optional ...ProdAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Prod",
+ Input: []tf.Input{
+ input, reduction_indices,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// DepthwiseConv2dNativeAttr is an optional argument to DepthwiseConv2dNative.
type DepthwiseConv2dNativeAttr func(optionalAttr)
@@ -6770,6 +7038,181 @@ func BiasAddV1(scope *Scope, value tf.Output, bias tf.Output) (output tf.Output)
return op.Output(0)
}
+// EncodeJpegAttr is an optional argument to EncodeJpeg.
+type EncodeJpegAttr func(optionalAttr)
+
+// EncodeJpegFormat sets the optional format attribute to value.
+//
+// value: Per pixel image format.
+// If not specified, defaults to ""
+func EncodeJpegFormat(value string) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["format"] = value
+ }
+}
+
+// EncodeJpegQuality sets the optional quality attribute to value.
+//
+// value: Quality of the compression from 0 to 100 (higher is better and slower).
+// If not specified, defaults to 95
+func EncodeJpegQuality(value int64) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["quality"] = value
+ }
+}
+
+// EncodeJpegProgressive sets the optional progressive attribute to value.
+//
+// value: If True, create a JPEG that loads progressively (coarse to fine).
+// If not specified, defaults to false
+func EncodeJpegProgressive(value bool) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["progressive"] = value
+ }
+}
+
+// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value.
+//
+// value: If True, spend CPU/RAM to reduce size with no quality change.
+// If not specified, defaults to false
+func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["optimize_size"] = value
+ }
+}
+
+// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value.
+//
+// value: See http://en.wikipedia.org/wiki/Chroma_subsampling.
+// If not specified, defaults to true
+func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["chroma_downsampling"] = value
+ }
+}
+
+// EncodeJpegDensityUnit sets the optional density_unit attribute to value.
+//
+// value: Unit used to specify `x_density` and `y_density`:
+// pixels per inch (`'in'`) or centimeter (`'cm'`).
+// If not specified, defaults to "in"
+func EncodeJpegDensityUnit(value string) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["density_unit"] = value
+ }
+}
+
+// EncodeJpegXDensity sets the optional x_density attribute to value.
+//
+// value: Horizontal pixels per density unit.
+// If not specified, defaults to 300
+func EncodeJpegXDensity(value int64) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["x_density"] = value
+ }
+}
+
+// EncodeJpegYDensity sets the optional y_density attribute to value.
+//
+// value: Vertical pixels per density unit.
+// If not specified, defaults to 300
+func EncodeJpegYDensity(value int64) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["y_density"] = value
+ }
+}
+
+// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value.
+//
+// value: If not empty, embed this XMP metadata in the image header.
+// If not specified, defaults to ""
+func EncodeJpegXmpMetadata(value string) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["xmp_metadata"] = value
+ }
+}
+
+// JPEG-encode an image.
+//
+// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`.
+//
+// The attr `format` can be used to override the color format of the encoded
+// output. Values can be:
+//
+// * `''`: Use a default format based on the number of channels in the image.
+// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension
+// of `image` must be 1.
+// * `rgb`: Output an RGB JPEG image. The `channels` dimension
+// of `image` must be 3.
+//
+// If `format` is not specified or is the empty string, a default format is picked
+// in function of the number of channels in `image`:
+//
+// * 1: Output a grayscale image.
+// * 3: Output an RGB image.
+//
+// Arguments:
+// image: 3-D with shape `[height, width, channels]`.
+//
+// Returns 0-D. JPEG-encoded image.
+func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "EncodeJpeg",
+ Input: []tf.Input{
+ image,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Gradients for batch normalization.
+//
+// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization()
+//
+// This op is deprecated. See `tf.nn.batch_normalization`.
+//
+// Arguments:
+// t: A 4D input Tensor.
+// m: A 1D mean Tensor with size matching the last dimension of t.
+// This is the first output from tf.nn.moments,
+// or a saved moving average thereof.
+// v: A 1D variance Tensor with size matching the last dimension of t.
+// This is the second output from tf.nn.moments,
+// or a saved moving average thereof.
+// gamma: A 1D gamma Tensor with size matching the last dimension of t.
+// If "scale_after_normalization" is true, this Tensor will be multiplied
+// with the normalized Tensor.
+// backprop: 4D backprop Tensor.
+// variance_epsilon: A small float number to avoid dividing by 0.
+// scale_after_normalization: A bool indicating whether the resulted tensor
+// needs to be multiplied with gamma.
+//
+// Returns 4D backprop tensor for input.1D backprop tensor for mean.1D backprop tensor for variance.1D backprop tensor for beta.1D backprop tensor for gamma.
+func BatchNormWithGlobalNormalizationGrad(scope *Scope, t tf.Output, m tf.Output, v tf.Output, gamma tf.Output, backprop tf.Output, variance_epsilon float32, scale_after_normalization bool) (dx tf.Output, dm tf.Output, dv tf.Output, db tf.Output, dg tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization}
+ opspec := tf.OpSpec{
+ Type: "BatchNormWithGlobalNormalizationGrad",
+ Input: []tf.Input{
+ t, m, v, gamma, backprop,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
+}
+
// Conv2DBackpropInputAttr is an optional argument to Conv2DBackpropInput.
type Conv2DBackpropInputAttr func(optionalAttr)
@@ -7160,6 +7603,51 @@ func SaveSlices(scope *Scope, filename tf.Output, tensor_names tf.Output, shapes
return scope.AddOperation(opspec)
}
+// Writes contents to the file at input filename. Creates file if not existing.
+//
+// Arguments:
+// filename: scalar. The name of the file to which we write the contents.
+// contents: scalar. The content to be written to the output file.
+//
+// Returns the created operation.
+func WriteFile(scope *Scope, filename tf.Output, contents tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "WriteFile",
+ Input: []tf.Input{
+ filename, contents,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Computes the Cholesky decomposition of one or more square matrices.
+//
+// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+// form square matrices, with the same constraints as the single matrix Cholesky
+// decomposition above. The output is a tensor of the same shape as the input
+// containing the Cholesky decompositions for all input submatrices `[..., :, :]`.
+//
+// Arguments:
+// input: Shape is `[..., M, M]`.
+//
+// Returns Shape is `[..., M, M]`.
+func Cholesky(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Cholesky",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Returns the rank of a tensor.
//
// This operation returns an integer representing the rank of `input`.
@@ -7243,54 +7731,6 @@ func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, opt
return output
}
-// BiasAddGradAttr is an optional argument to BiasAddGrad.
-type BiasAddGradAttr func(optionalAttr)
-
-// BiasAddGradDataFormat sets the optional data_format attribute to value.
-//
-// value: Specify the data format of the input and output data. With the
-// default format "NHWC", the bias tensor will be added to the last dimension
-// of the value tensor.
-// Alternatively, the format could be "NCHW", the data storage order of:
-// [batch, in_channels, in_height, in_width].
-// The tensor will be added to "in_channels", the third-to-the-last
-// dimension.
-// If not specified, defaults to "NHWC"
-func BiasAddGradDataFormat(value string) BiasAddGradAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// The backward operation for "BiasAdd" on the "bias" tensor.
-//
-// It accumulates all the values from out_backprop into the feature dimension.
-// For NHWC data format, the feature dimension is the last. For NCHW data format,
-// the feature dimension is the third-to-last.
-//
-// Arguments:
-// out_backprop: Any number of dimensions.
-//
-// Returns 1-D with size the feature dimension of `out_backprop`.
-func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "BiasAddGrad",
- Input: []tf.Input{
- out_backprop,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Convert JSON-encoded Example records to binary protocol buffer strings.
//
// This op translates a tensor containing Example records, encoded using
@@ -8024,27 +8464,51 @@ func ParameterizedTruncatedNormal(scope *Scope, shape tf.Output, means tf.Output
return op.Output(0)
}
-// Convert one or more images from HSV to RGB.
+// EncodePngAttr is an optional argument to EncodePng.
+type EncodePngAttr func(optionalAttr)
+
+// EncodePngCompression sets the optional compression attribute to value.
//
-// Outputs a tensor of the same shape as the `images` tensor, containing the RGB
-// value of the pixels. The output is only well defined if the value in `images`
-// are in `[0,1]`.
+// value: Compression level.
+// If not specified, defaults to -1
+func EncodePngCompression(value int64) EncodePngAttr {
+ return func(m optionalAttr) {
+ m["compression"] = value
+ }
+}
+
+// PNG-encode an image.
//
-// See `rgb_to_hsv` for a description of the HSV encoding.
+// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]`
+// where `channels` is:
+//
+// * 1: for grayscale.
+// * 2: for grayscale + alpha.
+// * 3: for RGB.
+// * 4: for RGBA.
+//
+// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder
+// default or a value from 0 to 9. 9 is the highest compression level, generating
+// the smallest output, but is slower.
//
// Arguments:
-// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3.
+// image: 3-D with shape `[height, width, channels]`.
//
-// Returns `images` converted to RGB.
-func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) {
+// Returns 0-D. PNG-encoded image.
+func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "HSVToRGB",
+ Type: "EncodePng",
Input: []tf.Input{
- images,
+ image,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
@@ -8976,29 +9440,6 @@ func SparseSparseMinimum(scope *Scope, a_indices tf.Output, a_values tf.Output,
return op.Output(0), op.Output(1)
}
-// Returns the set of files matching one or more glob patterns.
-//
-// Note that this routine only supports wildcard characters in the
-// basename portion of the pattern, not in the directory portion.
-//
-// Arguments:
-// pattern: Shell wildcard pattern(s). Scalar or vector of type string.
-//
-// Returns A vector of matching filenames.
-func MatchingFiles(scope *Scope, pattern tf.Output) (filenames tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "MatchingFiles",
- Input: []tf.Input{
- pattern,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes the gradient of the sigmoid of `x` wrt its input.
//
// Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and
@@ -10269,117 +10710,6 @@ func QuantizedRelu(scope *Scope, features tf.Output, min_features tf.Output, max
return op.Output(0), op.Output(1), op.Output(2)
}
-// InitializeTableFromTextFileV2Attr is an optional argument to InitializeTableFromTextFileV2.
-type InitializeTableFromTextFileV2Attr func(optionalAttr)
-
-// InitializeTableFromTextFileV2VocabSize sets the optional vocab_size attribute to value.
-//
-// value: Number of elements of the file, use -1 if unknown.
-// If not specified, defaults to -1
-//
-// REQUIRES: value >= -1
-func InitializeTableFromTextFileV2VocabSize(value int64) InitializeTableFromTextFileV2Attr {
- return func(m optionalAttr) {
- m["vocab_size"] = value
- }
-}
-
-// InitializeTableFromTextFileV2Delimiter sets the optional delimiter attribute to value.
-//
-// value: Delimiter to separate fields in a line.
-// If not specified, defaults to "\t"
-func InitializeTableFromTextFileV2Delimiter(value string) InitializeTableFromTextFileV2Attr {
- return func(m optionalAttr) {
- m["delimiter"] = value
- }
-}
-
-// Initializes a table from a text file.
-//
-// It inserts one key-value pair into the table for each line of the file.
-// The key and value is extracted from the whole line content, elements from the
-// split line based on `delimiter` or the line number (starting from zero).
-// Where to extract the key and value from a line is specified by `key_index` and
-// `value_index`.
-//
-// - A value of -1 means use the line number(starting from zero), expects `int64`.
-// - A value of -2 means use the whole line content, expects `string`.
-// - A value >= 0 means use the index (starting at zero) of the split line based
-// on `delimiter`.
-//
-// Arguments:
-// table_handle: Handle to a table which will be initialized.
-// filename: Filename of a vocabulary text file.
-// key_index: Column index in a line to get the table `key` values from.
-// value_index: Column index that represents information of a line to get the table
-// `value` values from.
-//
-// Returns the created operation.
-func InitializeTableFromTextFileV2(scope *Scope, table_handle tf.Output, filename tf.Output, key_index int64, value_index int64, optional ...InitializeTableFromTextFileV2Attr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"key_index": key_index, "value_index": value_index}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "InitializeTableFromTextFileV2",
- Input: []tf.Input{
- table_handle, filename,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
-// ResourceSparseApplyProximalGradientDescentAttr is an optional argument to ResourceSparseApplyProximalGradientDescent.
-type ResourceSparseApplyProximalGradientDescentAttr func(optionalAttr)
-
-// ResourceSparseApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value.
-//
-// value: If True, the subtraction will be protected by a lock;
-// otherwise the behavior is undefined, but may exhibit less contention.
-// If not specified, defaults to false
-func ResourceSparseApplyProximalGradientDescentUseLocking(value bool) ResourceSparseApplyProximalGradientDescentAttr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// Sparse update '*var' as FOBOS algorithm with fixed learning rate.
-//
-// That is for rows we have grad for, we update var as follows:
-// prox_v = var - alpha * grad
-// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0}
-//
-// Arguments:
-// var_: Should be from a Variable().
-// alpha: Scaling factor. Must be a scalar.
-// l1: L1 regularization. Must be a scalar.
-// l2: L2 regularization. Must be a scalar.
-// grad: The gradient.
-// indices: A vector of indices into the first dimension of var and accum.
-//
-// Returns the created operation.
-func ResourceSparseApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalGradientDescentAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceSparseApplyProximalGradientDescent",
- Input: []tf.Input{
- var_, alpha, l1, l2, grad, indices,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
// Computes rectified linear gradients for a Relu operation.
//
// Arguments:
@@ -10420,51 +10750,6 @@ func ReciprocalGrad(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
-// Computes the Cholesky decomposition of one or more square matrices.
-//
-// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
-// form square matrices, with the same constraints as the single matrix Cholesky
-// decomposition above. The output is a tensor of the same shape as the input
-// containing the Cholesky decompositions for all input submatrices `[..., :, :]`.
-//
-// Arguments:
-// input: Shape is `[..., M, M]`.
-//
-// Returns Shape is `[..., M, M]`.
-func Cholesky(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Cholesky",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Writes contents to the file at input filename. Creates file if not existing.
-//
-// Arguments:
-// filename: scalar. The name of the file to which we write the contents.
-// contents: scalar. The content to be written to the output file.
-//
-// Returns the created operation.
-func WriteFile(scope *Scope, filename tf.Output, contents tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "WriteFile",
- Input: []tf.Input{
- filename, contents,
- },
- }
- return scope.AddOperation(opspec)
-}
-
// Reverses specific dimensions of a tensor.
//
// NOTE `tf.reverse` has now changed behavior in preparation for 1.0.
@@ -10627,6 +10912,35 @@ func IFFT3D(scope *Scope, input tf.Output) (output tf.Output) {
return op.Output(0)
}
+// Looks up keys in a table, outputs the corresponding values.
+//
+// The tensor `keys` must of the same type as the keys of the table.
+// The output `values` is of the type of the table values.
+//
+// The scalar `default_value` is the value output for keys not present in the
+// table. It must also be of the same type as the table values.
+//
+// Arguments:
+// table_handle: Handle to the table.
+// keys: Any shape. Keys to look up.
+//
+//
+// Returns Same shape as `keys`. Values found in the table, or `default_values`
+// for missing keys.
+func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, default_value tf.Output) (values tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "LookupTableFindV2",
+ Input: []tf.Input{
+ table_handle, keys, default_value,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Given a quantized tensor described by (input, input_min, input_max), outputs a
//
// range that covers the actual values present in that tensor. This op is
@@ -11189,122 +11503,6 @@ func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) {
return op.Output(0)
}
-// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2.
-type MutableHashTableOfTensorsV2Attr func(optionalAttr)
-
-// MutableHashTableOfTensorsV2Container sets the optional container attribute to value.
-//
-// value: If non-empty, this table is placed in the given container.
-// Otherwise, a default container is used.
-// If not specified, defaults to ""
-func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value.
-//
-// value: If non-empty, this table is shared under the given name across
-// multiple sessions.
-// If not specified, defaults to ""
-func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value.
-// If not specified, defaults to false
-func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr {
- return func(m optionalAttr) {
- m["use_node_name_sharing"] = value
- }
-}
-
-// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value.
-// If not specified, defaults to <>
-func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr {
- return func(m optionalAttr) {
- m["value_shape"] = value
- }
-}
-
-// Creates an empty hash table.
-//
-// This op creates a mutable hash table, specifying the type of its keys and
-// values. Each value must be a vector. Data can be inserted into the table using
-// the insert operations. It does not support the initialization operation.
-//
-// Arguments:
-// key_dtype: Type of the table keys.
-// value_dtype: Type of the table values.
-//
-// Returns Handle to a table.
-func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "MutableHashTableOfTensorsV2",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad.
-type ResourceApplyProximalAdagradAttr func(optionalAttr)
-
-// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value.
-//
-// value: If True, updating of the var and accum tensors will be protected by
-// a lock; otherwise the behavior is undefined, but may exhibit less contention.
-// If not specified, defaults to false
-func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate.
-//
-// accum += grad * grad
-// prox_v = var - lr * grad * (1 / sqrt(accum))
-// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0}
-//
-// Arguments:
-// var_: Should be from a Variable().
-// accum: Should be from a Variable().
-// lr: Scaling factor. Must be a scalar.
-// l1: L1 regularization. Must be a scalar.
-// l2: L2 regularization. Must be a scalar.
-// grad: The gradient.
-//
-// Returns the created operation.
-func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceApplyProximalAdagrad",
- Input: []tf.Input{
- var_, accum, lr, l1, l2, grad,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
// TensorArrayV3Attr is an optional argument to TensorArrayV3.
type TensorArrayV3Attr func(optionalAttr)
@@ -11619,54 +11817,6 @@ func FractionalMaxPoolGrad(scope *Scope, orig_input tf.Output, orig_output tf.Ou
return op.Output(0)
}
-// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad.
-type AvgPool3DGradAttr func(optionalAttr)
-
-// AvgPool3DGradDataFormat sets the optional data_format attribute to value.
-//
-// value: The data format of the input and output data. With the
-// default format "NDHWC", the data is stored in the order of:
-// [batch, in_depth, in_height, in_width, in_channels].
-// Alternatively, the format could be "NCDHW", the data storage order is:
-// [batch, in_channels, in_depth, in_height, in_width].
-// If not specified, defaults to "NDHWC"
-func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// Computes gradients of average pooling function.
-//
-// Arguments:
-// orig_input_shape: The original input dimensions.
-// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`.
-// ksize: 1-D tensor of length 5. The size of the window for each dimension of
-// the input tensor. Must have `ksize[0] = ksize[4] = 1`.
-// strides: 1-D tensor of length 5. The stride of the sliding window for each
-// dimension of `input`. Must have `strides[0] = strides[4] = 1`.
-// padding: The type of padding algorithm to use.
-//
-// Returns The backprop for input.
-func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "AvgPool3DGrad",
- Input: []tf.Input{
- orig_input_shape, grad,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// QuantizedRelu6Attr is an optional argument to QuantizedRelu6.
type QuantizedRelu6Attr func(optionalAttr)
@@ -12745,6 +12895,54 @@ func Tanh(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad.
+type AvgPool3DGradAttr func(optionalAttr)
+
+// AvgPool3DGradDataFormat sets the optional data_format attribute to value.
+//
+// value: The data format of the input and output data. With the
+// default format "NDHWC", the data is stored in the order of:
+// [batch, in_depth, in_height, in_width, in_channels].
+// Alternatively, the format could be "NCDHW", the data storage order is:
+// [batch, in_channels, in_depth, in_height, in_width].
+// If not specified, defaults to "NDHWC"
+func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// Computes gradients of average pooling function.
+//
+// Arguments:
+// orig_input_shape: The original input dimensions.
+// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`.
+// ksize: 1-D tensor of length 5. The size of the window for each dimension of
+// the input tensor. Must have `ksize[0] = ksize[4] = 1`.
+// strides: 1-D tensor of length 5. The stride of the sliding window for each
+// dimension of `input`. Must have `strides[0] = strides[4] = 1`.
+// padding: The type of padding algorithm to use.
+//
+// Returns The backprop for input.
+func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "AvgPool3DGrad",
+ Input: []tf.Input{
+ orig_input_shape, grad,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// TextLineReaderV2Attr is an optional argument to TextLineReaderV2.
type TextLineReaderV2Attr func(optionalAttr)
@@ -13390,39 +13588,6 @@ func ResourceApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_
return scope.AddOperation(opspec)
}
-// Shuffle dimensions of x according to a permutation.
-//
-// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
-// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
-func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Transpose",
- Input: []tf.Input{
- x, perm,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Reads and outputs the entire contents of the input filename.
-func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ReadFile",
- Input: []tf.Input{
- filename,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Output a fact about factorials.
func Fact(scope *Scope) (fact tf.Output) {
if scope.Err() != nil {
@@ -14260,37 +14425,6 @@ func Gather(scope *Scope, params tf.Output, indices tf.Output, optional ...Gathe
return op.Output(0)
}
-// Adjust the contrast of one or more images.
-//
-// `images` is a tensor of at least 3 dimensions. The last 3 dimensions are
-// interpreted as `[height, width, channels]`. The other dimensions only
-// represent a collection of images, such as `[batch, height, width, channels].`
-//
-// Contrast is adjusted independently for each channel of each image.
-//
-// For each channel, the Op first computes the mean of the image pixels in the
-// channel and then adjusts each component of each pixel to
-// `(x - mean) * contrast_factor + mean`.
-//
-// Arguments:
-// images: Images to adjust. At least 3-D.
-// contrast_factor: A float multiplier for adjusting contrast.
-//
-// Returns The contrast-adjusted image or images.
-func AdjustContrastv2(scope *Scope, images tf.Output, contrast_factor tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "AdjustContrastv2",
- Input: []tf.Input{
- images, contrast_factor,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes softsign gradients for a softsign operation.
//
// Arguments:
@@ -14386,31 +14520,6 @@ func Dilation2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64
return op.Output(0)
}
-// Decode the first frame of a GIF-encoded image to a uint8 tensor.
-//
-// GIF with frame or transparency compression are not supported
-// convert animated GIF from compressed to uncompressed by:
-//
-// convert $src.gif -coalesce $dst.gif
-//
-// Arguments:
-// contents: 0-D. The GIF-encoded image.
-//
-// Returns 4-D with shape `[num_frames, height, width, 3]`. RGB order
-func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "DecodeGif",
- Input: []tf.Input{
- contents,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// EncodeBase64Attr is an optional argument to EncodeBase64.
type EncodeBase64Attr func(optionalAttr)
@@ -14672,143 +14781,118 @@ func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment
return op.Output(0)
}
-// Component-wise divides a SparseTensor by a dense Tensor.
+// ResizeBilinearGradAttr is an optional argument to ResizeBilinearGrad.
+type ResizeBilinearGradAttr func(optionalAttr)
+
+// ResizeBilinearGradAlignCorners sets the optional align_corners attribute to value.
//
-// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not
-// the other direction.
+// value: If true, rescale grads by (orig_height - 1) / (height - 1), which
+// exactly aligns the 4 corners of grads and original_image. If false, rescale by
+// orig_height / height. Treat similarly the width dimension.
+// If not specified, defaults to false
+func ResizeBilinearGradAlignCorners(value bool) ResizeBilinearGradAttr {
+ return func(m optionalAttr) {
+ m["align_corners"] = value
+ }
+}
+
+// Computes the gradient of bilinear interpolation.
//
// Arguments:
-// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a
-// SparseTensor, possibly not in canonical ordering.
-// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`.
-// sp_shape: 1-D. Shape of the input SparseTensor.
-// dense: `R`-D. The dense Tensor operand.
+// grads: 4-D with shape `[batch, height, width, channels]`.
+// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`,
+// The image tensor that was resized.
//
-// Returns 1-D. The `N` values that are operated on.
-func SparseDenseCwiseDiv(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) {
+// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`.
+// Gradients with respect to the input image. Input image must have been
+// float or double.
+func ResizeBilinearGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBilinearGradAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "SparseDenseCwiseDiv",
+ Type: "ResizeBilinearGrad",
Input: []tf.Input{
- sp_indices, sp_values, sp_shape, dense,
+ grads, original_image,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// Reads the value of a variable.
-//
-// The tensor returned by this operation is immutable.
-//
-// The value returned by this operation is guaranteed to be influenced by all the
-// writes on which this operation depends directly or indirectly, and to not be
-// influenced by any of the writes which depend directly or indirectly on this
-// operation.
+// Computes the number of elements in the given table.
//
// Arguments:
-// resource: handle to the resource in which to store the variable.
-// dtype: the dtype of the value.
-func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value tf.Output) {
+// table_handle: Handle to the table.
+//
+// Returns Scalar that contains number of elements in the table.
+func LookupTableSizeV2(scope *Scope, table_handle tf.Output) (size tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"dtype": dtype}
opspec := tf.OpSpec{
- Type: "ReadVariableOp",
+ Type: "LookupTableSizeV2",
Input: []tf.Input{
- resource,
+ table_handle,
},
- Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// ProdAttr is an optional argument to Prod.
-type ProdAttr func(optionalAttr)
-
-// ProdKeepDims sets the optional keep_dims attribute to value.
-//
-// value: If true, retain reduced dimensions with length 1.
-// If not specified, defaults to false
-func ProdKeepDims(value bool) ProdAttr {
- return func(m optionalAttr) {
- m["keep_dims"] = value
- }
-}
-
-// Computes the product of elements across dimensions of a tensor.
+// Component-wise divides a SparseTensor by a dense Tensor.
//
-// Reduces `input` along the dimensions given in `reduction_indices`. Unless
-// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
-// `reduction_indices`. If `keep_dims` is true, the reduced dimensions are
-// retained with length 1.
+// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not
+// the other direction.
//
// Arguments:
-// input: The tensor to reduce.
-// reduction_indices: The dimensions to reduce.
+// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a
+// SparseTensor, possibly not in canonical ordering.
+// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`.
+// sp_shape: 1-D. Shape of the input SparseTensor.
+// dense: `R`-D. The dense Tensor operand.
//
-// Returns The reduced tensor.
-func Prod(scope *Scope, input tf.Output, reduction_indices tf.Output, optional ...ProdAttr) (output tf.Output) {
+// Returns 1-D. The `N` values that are operated on.
+func SparseDenseCwiseDiv(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
opspec := tf.OpSpec{
- Type: "Prod",
+ Type: "SparseDenseCwiseDiv",
Input: []tf.Input{
- input, reduction_indices,
+ sp_indices, sp_values, sp_shape, dense,
},
- Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// ResizeBilinearAttr is an optional argument to ResizeBilinear.
-type ResizeBilinearAttr func(optionalAttr)
-
-// ResizeBilinearAlignCorners sets the optional align_corners attribute to value.
+// Reads the value of a variable.
//
-// value: If true, rescale input by (new_height - 1) / (height - 1), which
-// exactly aligns the 4 corners of images and resized images. If false, rescale
-// by new_height / height. Treat similarly the width dimension.
-// If not specified, defaults to false
-func ResizeBilinearAlignCorners(value bool) ResizeBilinearAttr {
- return func(m optionalAttr) {
- m["align_corners"] = value
- }
-}
-
-// Resize `images` to `size` using bilinear interpolation.
+// The tensor returned by this operation is immutable.
//
-// Input images can be of different types but output images are always float.
+// The value returned by this operation is guaranteed to be influenced by all the
+// writes on which this operation depends directly or indirectly, and to not be
+// influenced by any of the writes which depend directly or indirectly on this
+// operation.
//
// Arguments:
-// images: 4-D with shape `[batch, height, width, channels]`.
-// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
-// new size for the images.
-//
-// Returns 4-D with shape
-// `[batch, new_height, new_width, channels]`.
-func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBilinearAttr) (resized_images tf.Output) {
+// resource: handle to the resource in which to store the variable.
+// dtype: the dtype of the value.
+func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
+ attrs := map[string]interface{}{"dtype": dtype}
opspec := tf.OpSpec{
- Type: "ResizeBilinear",
+ Type: "ReadVariableOp",
Input: []tf.Input{
- images, size,
+ resource,
},
Attrs: attrs,
}
@@ -14988,6 +15072,108 @@ func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segm
return op.Output(0)
}
+// Converts one or more images from RGB to HSV.
+//
+// Outputs a tensor of the same shape as the `images` tensor, containing the HSV
+// value of the pixels. The output is only well defined if the value in `images`
+// are in `[0,1]`.
+//
+// `output[..., 0]` contains hue, `output[..., 1]` contains saturation, and
+// `output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0
+// corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue.
+//
+// Arguments:
+// images: 1-D or higher rank. RGB data to convert. Last dimension must be size 3.
+//
+// Returns `images` converted to HSV.
+func RGBToHSV(scope *Scope, images tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "RGBToHSV",
+ Input: []tf.Input{
+ images,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// MatrixSolveLsAttr is an optional argument to MatrixSolveLs.
+type MatrixSolveLsAttr func(optionalAttr)
+
+// MatrixSolveLsFast sets the optional fast attribute to value.
+// If not specified, defaults to true
+func MatrixSolveLsFast(value bool) MatrixSolveLsAttr {
+ return func(m optionalAttr) {
+ m["fast"] = value
+ }
+}
+
+// Solves one or more linear least-squares problems.
+//
+// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
+// form matrices of size `[M, N]`. Rhs is a tensor of shape `[..., M, K]`.
+// The output is a tensor shape `[..., N, K]` where each output matrix solves
+// each of the equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]
+// in the least squares sense.
+//
+// matrix and right-hand sides in the batch:
+//
+// `matrix`=\\(A \in \Re^{m \times n}\\),
+// `rhs`=\\(B \in \Re^{m \times k}\\),
+// `output`=\\(X \in \Re^{n \times k}\\),
+// `l2_regularizer`=\\(\lambda\\).
+//
+// If `fast` is `True`, then the solution is computed by solving the normal
+// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
+// \\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares
+// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 +
+// \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as
+// \\(X = A^T (A A^T + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the
+// minimum-norm solution to the under-determined linear system, i.e.
+// \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||Z||_F^2 \\), subject to
+// \\(A Z = B\\). Notice that the fast path is only numerically stable when
+// \\(A\\) is numerically full rank and has a condition number
+// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or\\(\lambda\\) is
+// sufficiently large.
+//
+// If `fast` is `False` an algorithm based on the numerically robust complete
+// orthogonal decomposition is used. This computes the minimum-norm
+// least-squares solution, even when \\(A\\) is rank deficient. This path is
+// typically 6-7 times slower than the fast path. If `fast` is `False` then
+// `l2_regularizer` is ignored.
+//
+// Arguments:
+// matrix: Shape is `[..., M, N]`.
+// rhs: Shape is `[..., M, K]`.
+// l2_regularizer: Scalar tensor.
+//
+// @compatibility(numpy)
+// Equivalent to np.linalg.lstsq
+// @end_compatibility
+//
+// Returns Shape is `[..., N, K]`.
+func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MatrixSolveLs",
+ Input: []tf.Input{
+ matrix, rhs, l2_regularizer,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// QuantizedReluXAttr is an optional argument to QuantizedReluX.
type QuantizedReluXAttr func(optionalAttr)
@@ -15770,6 +15956,30 @@ func TanhGrad(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
+// Outputs all keys and values in the table.
+//
+// Arguments:
+// table_handle: Handle to the table.
+//
+//
+//
+// Returns Vector of all keys present in the table.Tensor of all values in the table. Indexed in parallel with `keys`.
+func LookupTableExportV2(scope *Scope, table_handle tf.Output, Tkeys tf.DataType, Tvalues tf.DataType) (keys tf.Output, values tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"Tkeys": Tkeys, "Tvalues": Tvalues}
+ opspec := tf.OpSpec{
+ Type: "LookupTableExportV2",
+ Input: []tf.Input{
+ table_handle,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
// AddManySparseToTensorsMapAttr is an optional argument to AddManySparseToTensorsMap.
type AddManySparseToTensorsMapAttr func(optionalAttr)
@@ -15877,6 +16087,153 @@ func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (o
return op.Output(0)
}
+// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3.
+type TensorArrayGatherV3Attr func(optionalAttr)
+
+// TensorArrayGatherV3ElementShape sets the optional element_shape attribute to value.
+//
+// value: The expected shape of an element, if known. Used to
+// validate the shapes of TensorArray elements. If this shape is not
+// fully specified, gathering zero-size TensorArrays is an error.
+// If not specified, defaults to <unknown_rank:true >
+func TensorArrayGatherV3ElementShape(value tf.Shape) TensorArrayGatherV3Attr {
+ return func(m optionalAttr) {
+ m["element_shape"] = value
+ }
+}
+
+// Gather specific elements from the TensorArray into output `value`.
+//
+// All elements selected by `indices` must have the same shape.
+//
+// Arguments:
+// handle: The handle to a TensorArray.
+// indices: The locations in the TensorArray from which to read tensor elements.
+// flow_in: A float scalar that enforces proper chaining of operations.
+// dtype: The type of the elem that is returned.
+//
+// Returns All of the elements in the TensorArray, concatenated along a new
+// axis (the new dimension 0).
+func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV3Attr) (value tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtype": dtype}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "TensorArrayGatherV3",
+ Input: []tf.Input{
+ handle, indices, flow_in,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Deprecated. Disallowed in GraphDef version >= 2.
+//
+// DEPRECATED at GraphDef version 2: Use AdjustContrastv2 instead
+func AdjustContrast(scope *Scope, images tf.Output, contrast_factor tf.Output, min_value tf.Output, max_value tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "AdjustContrast",
+ Input: []tf.Input{
+ images, contrast_factor, min_value, max_value,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad.
+type MaxPoolGradGradAttr func(optionalAttr)
+
+// MaxPoolGradGradDataFormat sets the optional data_format attribute to value.
+//
+// value: Specify the data format of the input and output data. With the
+// default format "NHWC", the data is stored in the order of:
+// [batch, in_height, in_width, in_channels].
+// Alternatively, the format could be "NCHW", the data storage order of:
+// [batch, in_channels, in_height, in_width].
+// If not specified, defaults to "NHWC"
+func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// Computes second-order gradients of the maxpooling function.
+//
+// Arguments:
+// orig_input: The original input tensor.
+// orig_output: The original output tensor.
+// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`.
+// ksize: The size of the window for each dimension of the input tensor.
+// strides: The stride of the sliding window for each dimension of the
+// input tensor.
+// padding: The type of padding algorithm to use.
+//
+// Returns Gradients of gradients w.r.t. the input to `max_pool`.
+func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MaxPoolGradGrad",
+ Input: []tf.Input{
+ orig_input, orig_output, grad,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// 3D real-valued fast Fourier transform.
+//
+// Computes the 3-dimensional discrete Fourier transform of a real-valued signal
+// over the inner-most 3 dimensions of `input`.
+//
+// Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the
+// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension
+// of `output`: the zero-frequency term, followed by the `fft_length / 2`
+// positive-frequency terms.
+//
+// Arguments:
+// input: A float32 tensor.
+// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension.
+//
+// Returns A complex64 tensor of the same rank as `input`. The inner-most 3
+// dimensions of `input` are replaced with the their 3D Fourier transform. The
+// inner-most dimension contains `fft_length / 2 + 1` unique frequency
+// components.
+//
+// @compatibility(numpy)
+// Equivalent to np.fft.rfftn with 3 dimensions.
+// @end_compatibility
+func RFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "RFFT3D",
+ Input: []tf.Input{
+ input, fft_length,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// UniqueWithCountsAttr is an optional argument to UniqueWithCounts.
type UniqueWithCountsAttr func(optionalAttr)
@@ -16708,6 +17065,30 @@ func FractionalAvgPool(scope *Scope, value tf.Output, pooling_ratio []float32, o
return op.Output(0), op.Output(1), op.Output(2)
}
+// Updates the table to associates keys with values.
+//
+// The tensor `keys` must be of the same type as the keys of the table.
+// The tensor `values` must be of the type of the table values.
+//
+// Arguments:
+// table_handle: Handle to the table.
+// keys: Any shape. Keys to look up.
+// values: Values to associate with keys.
+//
+// Returns the created operation.
+func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "LookupTableInsertV2",
+ Input: []tf.Input{
+ table_handle, keys, values,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
// Produces the average pool of the input tensor for quantized types.
//
// Arguments:
@@ -16997,41 +17378,6 @@ func ComplexAbs(scope *Scope, x tf.Output, optional ...ComplexAbsAttr) (y tf.Out
return op.Output(0)
}
-// Draw bounding boxes on a batch of images.
-//
-// Outputs a copy of `images` but draws on top of the pixels zero or more bounding
-// boxes specified by the locations in `boxes`. The coordinates of the each
-// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The
-// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
-// height of the underlying image.
-//
-// For example, if an image is 100 x 200 pixels and the bounding box is
-// `[0.1, 0.2, 0.5, 0.9]`, the bottom-left and upper-right coordinates of the
-// bounding box will be `(10, 40)` to `(50, 180)`.
-//
-// Parts of the bounding box may fall outside the image.
-//
-// Arguments:
-// images: 4-D with shape `[batch, height, width, depth]`. A batch of images.
-// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding
-// boxes.
-//
-// Returns 4-D with the same shape as `images`. The batch of input images with
-// bounding boxes drawn on the images.
-func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "DrawBoundingBoxes",
- Input: []tf.Input{
- images, boxes,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Returns the element-wise max of two SparseTensors.
//
// Assumes the two SparseTensors have the same shape, i.e., no broadcasting.
@@ -17501,28 +17847,6 @@ func Log(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
-// Computes rectified linear 6 gradients for a Relu6 operation.
-//
-// Arguments:
-// gradients: The backpropagated gradients to the corresponding Relu6 operation.
-// features: The features passed as input to the corresponding Relu6 operation.
-//
-// Returns The gradients:
-// `gradients * (features > 0) * (features < 6)`.
-func Relu6Grad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Relu6Grad",
- Input: []tf.Input{
- gradients, features,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// ResizeBicubicAttr is an optional argument to ResizeBicubic.
type ResizeBicubicAttr func(optionalAttr)
@@ -17568,6 +17892,28 @@ func ResizeBicubic(scope *Scope, images tf.Output, size tf.Output, optional ...R
return op.Output(0)
}
+// Computes rectified linear 6 gradients for a Relu6 operation.
+//
+// Arguments:
+// gradients: The backpropagated gradients to the corresponding Relu6 operation.
+// features: The features passed as input to the corresponding Relu6 operation.
+//
+// Returns The gradients:
+// `gradients * (features > 0) * (features < 6)`.
+func Relu6Grad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Relu6Grad",
+ Input: []tf.Input{
+ gradients, features,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes natural logarithm of (1 + x) element-wise.
//
// I.e., \\(y = \log_e (1 + x)\\).
@@ -17681,181 +18027,6 @@ func RFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Outp
return op.Output(0)
}
-// Gradients for batch normalization.
-//
-// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization()
-//
-// This op is deprecated. See `tf.nn.batch_normalization`.
-//
-// Arguments:
-// t: A 4D input Tensor.
-// m: A 1D mean Tensor with size matching the last dimension of t.
-// This is the first output from tf.nn.moments,
-// or a saved moving average thereof.
-// v: A 1D variance Tensor with size matching the last dimension of t.
-// This is the second output from tf.nn.moments,
-// or a saved moving average thereof.
-// gamma: A 1D gamma Tensor with size matching the last dimension of t.
-// If "scale_after_normalization" is true, this Tensor will be multiplied
-// with the normalized Tensor.
-// backprop: 4D backprop Tensor.
-// variance_epsilon: A small float number to avoid dividing by 0.
-// scale_after_normalization: A bool indicating whether the resulted tensor
-// needs to be multiplied with gamma.
-//
-// Returns 4D backprop tensor for input.1D backprop tensor for mean.1D backprop tensor for variance.1D backprop tensor for beta.1D backprop tensor for gamma.
-func BatchNormWithGlobalNormalizationGrad(scope *Scope, t tf.Output, m tf.Output, v tf.Output, gamma tf.Output, backprop tf.Output, variance_epsilon float32, scale_after_normalization bool) (dx tf.Output, dm tf.Output, dv tf.Output, db tf.Output, dg tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization}
- opspec := tf.OpSpec{
- Type: "BatchNormWithGlobalNormalizationGrad",
- Input: []tf.Input{
- t, m, v, gamma, backprop,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
-}
-
-// EncodeJpegAttr is an optional argument to EncodeJpeg.
-type EncodeJpegAttr func(optionalAttr)
-
-// EncodeJpegFormat sets the optional format attribute to value.
-//
-// value: Per pixel image format.
-// If not specified, defaults to ""
-func EncodeJpegFormat(value string) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["format"] = value
- }
-}
-
-// EncodeJpegQuality sets the optional quality attribute to value.
-//
-// value: Quality of the compression from 0 to 100 (higher is better and slower).
-// If not specified, defaults to 95
-func EncodeJpegQuality(value int64) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["quality"] = value
- }
-}
-
-// EncodeJpegProgressive sets the optional progressive attribute to value.
-//
-// value: If True, create a JPEG that loads progressively (coarse to fine).
-// If not specified, defaults to false
-func EncodeJpegProgressive(value bool) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["progressive"] = value
- }
-}
-
-// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value.
-//
-// value: If True, spend CPU/RAM to reduce size with no quality change.
-// If not specified, defaults to false
-func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["optimize_size"] = value
- }
-}
-
-// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value.
-//
-// value: See http://en.wikipedia.org/wiki/Chroma_subsampling.
-// If not specified, defaults to true
-func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["chroma_downsampling"] = value
- }
-}
-
-// EncodeJpegDensityUnit sets the optional density_unit attribute to value.
-//
-// value: Unit used to specify `x_density` and `y_density`:
-// pixels per inch (`'in'`) or centimeter (`'cm'`).
-// If not specified, defaults to "in"
-func EncodeJpegDensityUnit(value string) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["density_unit"] = value
- }
-}
-
-// EncodeJpegXDensity sets the optional x_density attribute to value.
-//
-// value: Horizontal pixels per density unit.
-// If not specified, defaults to 300
-func EncodeJpegXDensity(value int64) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["x_density"] = value
- }
-}
-
-// EncodeJpegYDensity sets the optional y_density attribute to value.
-//
-// value: Vertical pixels per density unit.
-// If not specified, defaults to 300
-func EncodeJpegYDensity(value int64) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["y_density"] = value
- }
-}
-
-// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value.
-//
-// value: If not empty, embed this XMP metadata in the image header.
-// If not specified, defaults to ""
-func EncodeJpegXmpMetadata(value string) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["xmp_metadata"] = value
- }
-}
-
-// JPEG-encode an image.
-//
-// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`.
-//
-// The attr `format` can be used to override the color format of the encoded
-// output. Values can be:
-//
-// * `''`: Use a default format based on the number of channels in the image.
-// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension
-// of `image` must be 1.
-// * `rgb`: Output an RGB JPEG image. The `channels` dimension
-// of `image` must be 3.
-//
-// If `format` is not specified or is the empty string, a default format is picked
-// in function of the number of channels in `image`:
-//
-// * 1: Output a grayscale image.
-// * 3: Output an RGB image.
-//
-// Arguments:
-// image: 3-D with shape `[height, width, channels]`.
-//
-// Returns 0-D. JPEG-encoded image.
-func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "EncodeJpeg",
- Input: []tf.Input{
- image,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes sin of x element-wise.
func Sin(scope *Scope, x tf.Output) (y tf.Output) {
if scope.Err() != nil {
@@ -18164,6 +18335,117 @@ func ArgMin(scope *Scope, input tf.Output, dimension tf.Output) (output tf.Outpu
return op.Output(0)
}
+// ResourceSparseApplyProximalGradientDescentAttr is an optional argument to ResourceSparseApplyProximalGradientDescent.
+type ResourceSparseApplyProximalGradientDescentAttr func(optionalAttr)
+
+// ResourceSparseApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value.
+//
+// value: If True, the subtraction will be protected by a lock;
+// otherwise the behavior is undefined, but may exhibit less contention.
+// If not specified, defaults to false
+func ResourceSparseApplyProximalGradientDescentUseLocking(value bool) ResourceSparseApplyProximalGradientDescentAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Sparse update '*var' as FOBOS algorithm with fixed learning rate.
+//
+// That is for rows we have grad for, we update var as follows:
+// prox_v = var - alpha * grad
+// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0}
+//
+// Arguments:
+// var_: Should be from a Variable().
+// alpha: Scaling factor. Must be a scalar.
+// l1: L1 regularization. Must be a scalar.
+// l2: L2 regularization. Must be a scalar.
+// grad: The gradient.
+// indices: A vector of indices into the first dimension of var and accum.
+//
+// Returns the created operation.
+func ResourceSparseApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalGradientDescentAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceSparseApplyProximalGradientDescent",
+ Input: []tf.Input{
+ var_, alpha, l1, l2, grad, indices,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// InitializeTableFromTextFileV2Attr is an optional argument to InitializeTableFromTextFileV2.
+type InitializeTableFromTextFileV2Attr func(optionalAttr)
+
+// InitializeTableFromTextFileV2VocabSize sets the optional vocab_size attribute to value.
+//
+// value: Number of elements of the file, use -1 if unknown.
+// If not specified, defaults to -1
+//
+// REQUIRES: value >= -1
+func InitializeTableFromTextFileV2VocabSize(value int64) InitializeTableFromTextFileV2Attr {
+ return func(m optionalAttr) {
+ m["vocab_size"] = value
+ }
+}
+
+// InitializeTableFromTextFileV2Delimiter sets the optional delimiter attribute to value.
+//
+// value: Delimiter to separate fields in a line.
+// If not specified, defaults to "\t"
+func InitializeTableFromTextFileV2Delimiter(value string) InitializeTableFromTextFileV2Attr {
+ return func(m optionalAttr) {
+ m["delimiter"] = value
+ }
+}
+
+// Initializes a table from a text file.
+//
+// It inserts one key-value pair into the table for each line of the file.
+// The key and value is extracted from the whole line content, elements from the
+// split line based on `delimiter` or the line number (starting from zero).
+// Where to extract the key and value from a line is specified by `key_index` and
+// `value_index`.
+//
+// - A value of -1 means use the line number(starting from zero), expects `int64`.
+// - A value of -2 means use the whole line content, expects `string`.
+// - A value >= 0 means use the index (starting at zero) of the split line based
+// on `delimiter`.
+//
+// Arguments:
+// table_handle: Handle to a table which will be initialized.
+// filename: Filename of a vocabulary text file.
+// key_index: Column index in a line to get the table `key` values from.
+// value_index: Column index that represents information of a line to get the table
+// `value` values from.
+//
+// Returns the created operation.
+func InitializeTableFromTextFileV2(scope *Scope, table_handle tf.Output, filename tf.Output, key_index int64, value_index int64, optional ...InitializeTableFromTextFileV2Attr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"key_index": key_index, "value_index": value_index}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "InitializeTableFromTextFileV2",
+ Input: []tf.Input{
+ table_handle, filename,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
// Computes atan of x element-wise.
func Atan(scope *Scope, x tf.Output) (y tf.Output) {
if scope.Err() != nil {
@@ -18628,33 +18910,36 @@ func Less(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
-// FakeQuantWithMinMaxVarsGradientAttr is an optional argument to FakeQuantWithMinMaxVarsGradient.
-type FakeQuantWithMinMaxVarsGradientAttr func(optionalAttr)
+// BiasAddGradAttr is an optional argument to BiasAddGrad.
+type BiasAddGradAttr func(optionalAttr)
-// FakeQuantWithMinMaxVarsGradientNumBits sets the optional num_bits attribute to value.
+// BiasAddGradDataFormat sets the optional data_format attribute to value.
//
-// value: The bitwidth of the quantization; between 2 and 8, inclusive.
-// If not specified, defaults to 8
-func FakeQuantWithMinMaxVarsGradientNumBits(value int64) FakeQuantWithMinMaxVarsGradientAttr {
+// value: Specify the data format of the input and output data. With the
+// default format "NHWC", the bias tensor will be added to the last dimension
+// of the value tensor.
+// Alternatively, the format could be "NCHW", the data storage order of:
+// [batch, in_channels, in_height, in_width].
+// The tensor will be added to "in_channels", the third-to-the-last
+// dimension.
+// If not specified, defaults to "NHWC"
+func BiasAddGradDataFormat(value string) BiasAddGradAttr {
return func(m optionalAttr) {
- m["num_bits"] = value
+ m["data_format"] = value
}
}
-// Compute gradients for a FakeQuantWithMinMaxVars operation.
-//
-// Arguments:
-// gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation.
-// inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation.
-// min, max: Quantization interval, scalar floats.
+// The backward operation for "BiasAdd" on the "bias" tensor.
//
+// It accumulates all the values from out_backprop into the feature dimension.
+// For NHWC data format, the feature dimension is the last. For NCHW data format,
+// the feature dimension is the third-to-last.
//
+// Arguments:
+// out_backprop: Any number of dimensions.
//
-// Returns Backpropagated gradients w.r.t. inputs:
-// `gradients * (inputs >= min && inputs <= max)`.Backpropagated gradients w.r.t. min parameter:
-// `sum(gradients * (inputs < min))`.Backpropagated gradients w.r.t. max parameter:
-// `sum(gradients * (inputs > max))`.
-func FakeQuantWithMinMaxVarsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsGradientAttr) (backprops_wrt_input tf.Output, backprop_wrt_min tf.Output, backprop_wrt_max tf.Output) {
+// Returns 1-D with size the feature dimension of `out_backprop`.
+func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
@@ -18663,31 +18948,13 @@ func FakeQuantWithMinMaxVarsGradient(scope *Scope, gradients tf.Output, inputs t
a(attrs)
}
opspec := tf.OpSpec{
- Type: "FakeQuantWithMinMaxVarsGradient",
+ Type: "BiasAddGrad",
Input: []tf.Input{
- gradients, inputs, min, max,
+ out_backprop,
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
-// Returns the min of x and y (i.e. x < y ? x : y) element-wise.
-//
-// *NOTE*: `Minimum` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Minimum",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
return op.Output(0)
}
@@ -19996,65 +20263,6 @@ func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Outp
return op.Output(0), op.Output(1), op.Output(2)
}
-// DecodePngAttr is an optional argument to DecodePng.
-type DecodePngAttr func(optionalAttr)
-
-// DecodePngChannels sets the optional channels attribute to value.
-//
-// value: Number of color channels for the decoded image.
-// If not specified, defaults to 0
-func DecodePngChannels(value int64) DecodePngAttr {
- return func(m optionalAttr) {
- m["channels"] = value
- }
-}
-
-// DecodePngDtype sets the optional dtype attribute to value.
-// If not specified, defaults to DT_UINT8
-func DecodePngDtype(value tf.DataType) DecodePngAttr {
- return func(m optionalAttr) {
- m["dtype"] = value
- }
-}
-
-// Decode a PNG-encoded image to a uint8 or uint16 tensor.
-//
-// The attr `channels` indicates the desired number of color channels for the
-// decoded image.
-//
-// Accepted values are:
-//
-// * 0: Use the number of channels in the PNG-encoded image.
-// * 1: output a grayscale image.
-// * 3: output an RGB image.
-// * 4: output an RGBA image.
-//
-// If needed, the PNG-encoded image is transformed to match the requested number
-// of color channels.
-//
-// Arguments:
-// contents: 0-D. The PNG-encoded image.
-//
-// Returns 3-D with shape `[height, width, channels]`.
-func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (image tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "DecodePng",
- Input: []tf.Input{
- contents,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// AudioSummaryV2Attr is an optional argument to AudioSummaryV2.
type AudioSummaryV2Attr func(optionalAttr)
@@ -20219,31 +20427,188 @@ func AudioSummary(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate flo
return op.Output(0)
}
-// ResizeNearestNeighborAttr is an optional argument to ResizeNearestNeighbor.
-type ResizeNearestNeighborAttr func(optionalAttr)
+// Replaces the contents of the table with the specified keys and values.
+//
+// The tensor `keys` must be of the same type as the keys of the table.
+// The tensor `values` must be of the type of the table values.
+//
+// Arguments:
+// table_handle: Handle to the table.
+// keys: Any shape. Keys to look up.
+// values: Values to associate with keys.
+//
+// Returns the created operation.
+func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "LookupTableImportV2",
+ Input: []tf.Input{
+ table_handle, keys, values,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
-// ResizeNearestNeighborAlignCorners sets the optional align_corners attribute to value.
+// HashTableV2Attr is an optional argument to HashTableV2.
+type HashTableV2Attr func(optionalAttr)
+
+// HashTableV2Container sets the optional container attribute to value.
//
-// value: If true, rescale input by (new_height - 1) / (height - 1), which
-// exactly aligns the 4 corners of images and resized images. If false, rescale
-// by new_height / height. Treat similarly the width dimension.
+// value: If non-empty, this table is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func HashTableV2Container(value string) HashTableV2Attr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// HashTableV2SharedName sets the optional shared_name attribute to value.
+//
+// value: If non-empty, this table is shared under the given name across
+// multiple sessions.
+// If not specified, defaults to ""
+func HashTableV2SharedName(value string) HashTableV2Attr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// HashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value.
+//
+// value: If true and shared_name is empty, the table is shared
+// using the node name.
// If not specified, defaults to false
-func ResizeNearestNeighborAlignCorners(value bool) ResizeNearestNeighborAttr {
+func HashTableV2UseNodeNameSharing(value bool) HashTableV2Attr {
return func(m optionalAttr) {
- m["align_corners"] = value
+ m["use_node_name_sharing"] = value
}
}
-// Resize `images` to `size` using nearest neighbor interpolation.
+// Creates a non-initialized hash table.
+//
+// This op creates a hash table, specifying the type of its keys and values.
+// Before using the table you will have to initialize it. After initialization the
+// table will be immutable.
//
// Arguments:
-// images: 4-D with shape `[batch, height, width, channels]`.
-// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
-// new size for the images.
+// key_dtype: Type of the table keys.
+// value_dtype: Type of the table values.
//
-// Returns 4-D with shape
-// `[batch, new_height, new_width, channels]`.
-func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeNearestNeighborAttr) (resized_images tf.Output) {
+// Returns Handle to a table.
+func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...HashTableV2Attr) (table_handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "HashTableV2",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// MutableHashTableV2Attr is an optional argument to MutableHashTableV2.
+type MutableHashTableV2Attr func(optionalAttr)
+
+// MutableHashTableV2Container sets the optional container attribute to value.
+//
+// value: If non-empty, this table is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func MutableHashTableV2Container(value string) MutableHashTableV2Attr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// MutableHashTableV2SharedName sets the optional shared_name attribute to value.
+//
+// value: If non-empty, this table is shared under the given name across
+// multiple sessions.
+// If not specified, defaults to ""
+func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value.
+//
+// value: If true and shared_name is empty, the table is shared
+// using the node name.
+// If not specified, defaults to false
+func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr {
+ return func(m optionalAttr) {
+ m["use_node_name_sharing"] = value
+ }
+}
+
+// Creates an empty hash table.
+//
+// This op creates a mutable hash table, specifying the type of its keys and
+// values. Each value must be a scalar. Data can be inserted into the table using
+// the insert operations. It does not support the initialization operation.
+//
+// Arguments:
+// key_dtype: Type of the table keys.
+// value_dtype: Type of the table values.
+//
+// Returns Handle to a table.
+func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MutableHashTableV2",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad.
+type ResourceApplyProximalAdagradAttr func(optionalAttr)
+
+// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value.
+//
+// value: If True, updating of the var and accum tensors will be protected by
+// a lock; otherwise the behavior is undefined, but may exhibit less contention.
+// If not specified, defaults to false
+func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate.
+//
+// accum += grad * grad
+// prox_v = var - lr * grad * (1 / sqrt(accum))
+// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0}
+//
+// Arguments:
+// var_: Should be from a Variable().
+// accum: Should be from a Variable().
+// lr: Scaling factor. Must be a scalar.
+// l1: L1 regularization. Must be a scalar.
+// l2: L2 regularization. Must be a scalar.
+// grad: The gradient.
+//
+// Returns the created operation.
+func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) {
if scope.Err() != nil {
return
}
@@ -20252,13 +20617,165 @@ func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optio
a(attrs)
}
opspec := tf.OpSpec{
- Type: "ResizeNearestNeighbor",
+ Type: "ResourceApplyProximalAdagrad",
Input: []tf.Input{
- images, size,
+ var_, accum, lr, l1, l2, grad,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2.
+type MutableHashTableOfTensorsV2Attr func(optionalAttr)
+
+// MutableHashTableOfTensorsV2Container sets the optional container attribute to value.
+//
+// value: If non-empty, this table is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value.
+//
+// value: If non-empty, this table is shared under the given name across
+// multiple sessions.
+// If not specified, defaults to ""
+func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value.
+// If not specified, defaults to false
+func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr {
+ return func(m optionalAttr) {
+ m["use_node_name_sharing"] = value
+ }
+}
+
+// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value.
+// If not specified, defaults to <>
+func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr {
+ return func(m optionalAttr) {
+ m["value_shape"] = value
+ }
+}
+
+// Creates an empty hash table.
+//
+// This op creates a mutable hash table, specifying the type of its keys and
+// values. Each value must be a vector. Data can be inserted into the table using
+// the insert operations. It does not support the initialization operation.
+//
+// Arguments:
+// key_dtype: Type of the table keys.
+// value_dtype: Type of the table values.
+//
+// Returns Handle to a table.
+func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MutableHashTableOfTensorsV2",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Table initializer that takes two tensors for keys and values respectively.
+//
+// Arguments:
+// table_handle: Handle to a table which will be initialized.
+// keys: Keys of type Tkey.
+// values: Values of type Tval.
+//
+// Returns the created operation.
+func InitializeTableV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "InitializeTableV2",
+ Input: []tf.Input{
+ table_handle, keys, values,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// FakeQuantWithMinMaxVarsGradientAttr is an optional argument to FakeQuantWithMinMaxVarsGradient.
+type FakeQuantWithMinMaxVarsGradientAttr func(optionalAttr)
+
+// FakeQuantWithMinMaxVarsGradientNumBits sets the optional num_bits attribute to value.
+//
+// value: The bitwidth of the quantization; between 2 and 8, inclusive.
+// If not specified, defaults to 8
+func FakeQuantWithMinMaxVarsGradientNumBits(value int64) FakeQuantWithMinMaxVarsGradientAttr {
+ return func(m optionalAttr) {
+ m["num_bits"] = value
+ }
+}
+
+// Compute gradients for a FakeQuantWithMinMaxVars operation.
+//
+// Arguments:
+// gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation.
+// inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation.
+// min, max: Quantization interval, scalar floats.
+//
+//
+//
+// Returns Backpropagated gradients w.r.t. inputs:
+// `gradients * (inputs >= min && inputs <= max)`.Backpropagated gradients w.r.t. min parameter:
+// `sum(gradients * (inputs < min))`.Backpropagated gradients w.r.t. max parameter:
+// `sum(gradients * (inputs > max))`.
+func FakeQuantWithMinMaxVarsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsGradientAttr) (backprops_wrt_input tf.Output, backprop_wrt_min tf.Output, backprop_wrt_max tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "FakeQuantWithMinMaxVarsGradient",
+ Input: []tf.Input{
+ gradients, inputs, min, max,
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// Returns the min of x and y (i.e. x < y ? x : y) element-wise.
+//
+// *NOTE*: `Minimum` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Minimum",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
return op.Output(0)
}
@@ -20385,6 +20902,84 @@ func TFRecordReaderV2(scope *Scope, optional ...TFRecordReaderV2Attr) (reader_ha
return op.Output(0)
}
+// Adjust the saturation of one or more images.
+//
+// `images` is a tensor of at least 3 dimensions. The last dimension is
+// interpretted as channels, and must be three.
+//
+// The input image is considered in the RGB colorspace. Conceptually, the RGB
+// colors are first mapped into HSV. A scale is then applied all the saturation
+// values, and then remapped back to RGB colorspace.
+//
+// Arguments:
+// images: Images to adjust. At least 3-D.
+// scale: A float scale to add to the saturation.
+//
+// Returns The hue-adjusted image or images.
+func AdjustSaturation(scope *Scope, images tf.Output, scale tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "AdjustSaturation",
+ Input: []tf.Input{
+ images, scale,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// SelfAdjointEigV2Attr is an optional argument to SelfAdjointEigV2.
+type SelfAdjointEigV2Attr func(optionalAttr)
+
+// SelfAdjointEigV2ComputeV sets the optional compute_v attribute to value.
+//
+// value: If `True` then eigenvectors will be computed and returned in `v`.
+// Otherwise, only the eigenvalues will be computed.
+// If not specified, defaults to true
+func SelfAdjointEigV2ComputeV(value bool) SelfAdjointEigV2Attr {
+ return func(m optionalAttr) {
+ m["compute_v"] = value
+ }
+}
+
+// Computes the eigen decomposition of one or more square self-adjoint matrices.
+//
+// Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in
+// `input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`.
+//
+// ```prettyprint
+// # a is a tensor.
+// # e is a tensor of eigenvalues.
+// # v is a tensor of eigenvectors.
+// e, v = self_adjoint_eig(a)
+// e = self_adjoint_eig(a, compute_v=False)
+// ```
+//
+// Arguments:
+// input: `Tensor` input of shape `[N, N]`.
+//
+// Returns Eigenvalues. Shape is `[N]`.Eigenvectors. Shape is `[N, N]`.
+func SelfAdjointEigV2(scope *Scope, input tf.Output, optional ...SelfAdjointEigV2Attr) (e tf.Output, v tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "SelfAdjointEigV2",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
// MatrixSolveAttr is an optional argument to MatrixSolve.
type MatrixSolveAttr func(optionalAttr)
@@ -21033,371 +21628,6 @@ func SoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.O
return op.Output(0), op.Output(1)
}
-// Computes the number of elements in the given table.
-//
-// Arguments:
-// table_handle: Handle to the table.
-//
-// Returns Scalar that contains number of elements in the table.
-func LookupTableSizeV2(scope *Scope, table_handle tf.Output) (size tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "LookupTableSizeV2",
- Input: []tf.Input{
- table_handle,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// ResizeBilinearGradAttr is an optional argument to ResizeBilinearGrad.
-type ResizeBilinearGradAttr func(optionalAttr)
-
-// ResizeBilinearGradAlignCorners sets the optional align_corners attribute to value.
-//
-// value: If true, rescale grads by (orig_height - 1) / (height - 1), which
-// exactly aligns the 4 corners of grads and original_image. If false, rescale by
-// orig_height / height. Treat similarly the width dimension.
-// If not specified, defaults to false
-func ResizeBilinearGradAlignCorners(value bool) ResizeBilinearGradAttr {
- return func(m optionalAttr) {
- m["align_corners"] = value
- }
-}
-
-// Computes the gradient of bilinear interpolation.
-//
-// Arguments:
-// grads: 4-D with shape `[batch, height, width, channels]`.
-// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`,
-// The image tensor that was resized.
-//
-// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`.
-// Gradients with respect to the input image. Input image must have been
-// float or double.
-func ResizeBilinearGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBilinearGradAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResizeBilinearGrad",
- Input: []tf.Input{
- grads, original_image,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// ResizeNearestNeighborGradAttr is an optional argument to ResizeNearestNeighborGrad.
-type ResizeNearestNeighborGradAttr func(optionalAttr)
-
-// ResizeNearestNeighborGradAlignCorners sets the optional align_corners attribute to value.
-//
-// value: If true, rescale grads by (orig_height - 1) / (height - 1), which
-// exactly aligns the 4 corners of grads and original_image. If false, rescale by
-// orig_height / height. Treat similarly the width dimension.
-// If not specified, defaults to false
-func ResizeNearestNeighborGradAlignCorners(value bool) ResizeNearestNeighborGradAttr {
- return func(m optionalAttr) {
- m["align_corners"] = value
- }
-}
-
-// Computes the gradient of nearest neighbor interpolation.
-//
-// Arguments:
-// grads: 4-D with shape `[batch, height, width, channels]`.
-// size: = A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The
-// original input size.
-//
-// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients
-// with respect to the input image.
-func ResizeNearestNeighborGrad(scope *Scope, grads tf.Output, size tf.Output, optional ...ResizeNearestNeighborGradAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResizeNearestNeighborGrad",
- Input: []tf.Input{
- grads, size,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// DecodeJpegAttr is an optional argument to DecodeJpeg.
-type DecodeJpegAttr func(optionalAttr)
-
-// DecodeJpegChannels sets the optional channels attribute to value.
-//
-// value: Number of color channels for the decoded image.
-// If not specified, defaults to 0
-func DecodeJpegChannels(value int64) DecodeJpegAttr {
- return func(m optionalAttr) {
- m["channels"] = value
- }
-}
-
-// DecodeJpegRatio sets the optional ratio attribute to value.
-//
-// value: Downscaling ratio.
-// If not specified, defaults to 1
-func DecodeJpegRatio(value int64) DecodeJpegAttr {
- return func(m optionalAttr) {
- m["ratio"] = value
- }
-}
-
-// DecodeJpegFancyUpscaling sets the optional fancy_upscaling attribute to value.
-//
-// value: If true use a slower but nicer upscaling of the
-// chroma planes (yuv420/422 only).
-// If not specified, defaults to true
-func DecodeJpegFancyUpscaling(value bool) DecodeJpegAttr {
- return func(m optionalAttr) {
- m["fancy_upscaling"] = value
- }
-}
-
-// DecodeJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value.
-//
-// value: If true try to recover an image from truncated input.
-// If not specified, defaults to false
-func DecodeJpegTryRecoverTruncated(value bool) DecodeJpegAttr {
- return func(m optionalAttr) {
- m["try_recover_truncated"] = value
- }
-}
-
-// DecodeJpegAcceptableFraction sets the optional acceptable_fraction attribute to value.
-//
-// value: The minimum required fraction of lines before a truncated
-// input is accepted.
-// If not specified, defaults to 1
-func DecodeJpegAcceptableFraction(value float32) DecodeJpegAttr {
- return func(m optionalAttr) {
- m["acceptable_fraction"] = value
- }
-}
-
-// DecodeJpegDctMethod sets the optional dct_method attribute to value.
-//
-// value: string specifying a hint about the algorithm used for
-// decompression. Defaults to "" which maps to a system-specific
-// default. Currently valid values are ["INTEGER_FAST",
-// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal
-// jpeg library changes to a version that does not have that specific
-// option.)
-// If not specified, defaults to ""
-func DecodeJpegDctMethod(value string) DecodeJpegAttr {
- return func(m optionalAttr) {
- m["dct_method"] = value
- }
-}
-
-// Decode a JPEG-encoded image to a uint8 tensor.
-//
-// The attr `channels` indicates the desired number of color channels for the
-// decoded image.
-//
-// Accepted values are:
-//
-// * 0: Use the number of channels in the JPEG-encoded image.
-// * 1: output a grayscale image.
-// * 3: output an RGB image.
-//
-// If needed, the JPEG-encoded image is transformed to match the requested number
-// of color channels.
-//
-// The attr `ratio` allows downscaling the image by an integer factor during
-// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than
-// downscaling the image later.
-//
-// Arguments:
-// contents: 0-D. The JPEG-encoded image.
-//
-// Returns 3-D with shape `[height, width, channels]`..
-func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "DecodeJpeg",
- Input: []tf.Input{
- contents,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3.
-type TensorArrayGatherV3Attr func(optionalAttr)
-
-// TensorArrayGatherV3ElementShape sets the optional element_shape attribute to value.
-//
-// value: The expected shape of an element, if known. Used to
-// validate the shapes of TensorArray elements. If this shape is not
-// fully specified, gathering zero-size TensorArrays is an error.
-// If not specified, defaults to <unknown_rank:true >
-func TensorArrayGatherV3ElementShape(value tf.Shape) TensorArrayGatherV3Attr {
- return func(m optionalAttr) {
- m["element_shape"] = value
- }
-}
-
-// Gather specific elements from the TensorArray into output `value`.
-//
-// All elements selected by `indices` must have the same shape.
-//
-// Arguments:
-// handle: The handle to a TensorArray.
-// indices: The locations in the TensorArray from which to read tensor elements.
-// flow_in: A float scalar that enforces proper chaining of operations.
-// dtype: The type of the elem that is returned.
-//
-// Returns All of the elements in the TensorArray, concatenated along a new
-// axis (the new dimension 0).
-func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV3Attr) (value tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "TensorArrayGatherV3",
- Input: []tf.Input{
- handle, indices, flow_in,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad.
-type MaxPoolGradGradAttr func(optionalAttr)
-
-// MaxPoolGradGradDataFormat sets the optional data_format attribute to value.
-//
-// value: Specify the data format of the input and output data. With the
-// default format "NHWC", the data is stored in the order of:
-// [batch, in_height, in_width, in_channels].
-// Alternatively, the format could be "NCHW", the data storage order of:
-// [batch, in_channels, in_height, in_width].
-// If not specified, defaults to "NHWC"
-func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// Computes second-order gradients of the maxpooling function.
-//
-// Arguments:
-// orig_input: The original input tensor.
-// orig_output: The original output tensor.
-// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`.
-// ksize: The size of the window for each dimension of the input tensor.
-// strides: The stride of the sliding window for each dimension of the
-// input tensor.
-// padding: The type of padding algorithm to use.
-//
-// Returns Gradients of gradients w.r.t. the input to `max_pool`.
-func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "MaxPoolGradGrad",
- Input: []tf.Input{
- orig_input, orig_output, grad,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// 3D real-valued fast Fourier transform.
-//
-// Computes the 3-dimensional discrete Fourier transform of a real-valued signal
-// over the inner-most 3 dimensions of `input`.
-//
-// Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the
-// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension
-// of `output`: the zero-frequency term, followed by the `fft_length / 2`
-// positive-frequency terms.
-//
-// Arguments:
-// input: A float32 tensor.
-// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension.
-//
-// Returns A complex64 tensor of the same rank as `input`. The inner-most 3
-// dimensions of `input` are replaced with the their 3D Fourier transform. The
-// inner-most dimension contains `fft_length / 2 + 1` unique frequency
-// components.
-//
-// @compatibility(numpy)
-// Equivalent to np.fft.rfftn with 3 dimensions.
-// @end_compatibility
-func RFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "RFFT3D",
- Input: []tf.Input{
- input, fft_length,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Deprecated. Disallowed in GraphDef version >= 2.
-//
-// DEPRECATED at GraphDef version 2: Use AdjustContrastv2 instead
-func AdjustContrast(scope *Scope, images tf.Output, contrast_factor tf.Output, min_value tf.Output, max_value tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "AdjustContrast",
- Input: []tf.Input{
- images, contrast_factor, min_value, max_value,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Store the input tensor in the state of the current session.
//
// Arguments:
@@ -21419,25 +21649,6 @@ func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) {
return op.Output(0)
}
-// Restore a Reader to its initial clean state.
-//
-// Arguments:
-// reader_handle: Handle to a Reader.
-//
-// Returns the created operation.
-func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ReaderResetV2",
- Input: []tf.Input{
- reader_handle,
- },
- }
- return scope.AddOperation(opspec)
-}
-
// Adjust the hue of one or more images.
//
// `images` is a tensor of at least 3 dimensions. The last dimension is
@@ -21466,232 +21677,21 @@ func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Outpu
return op.Output(0)
}
-// SelfAdjointEigV2Attr is an optional argument to SelfAdjointEigV2.
-type SelfAdjointEigV2Attr func(optionalAttr)
-
-// SelfAdjointEigV2ComputeV sets the optional compute_v attribute to value.
-//
-// value: If `True` then eigenvectors will be computed and returned in `v`.
-// Otherwise, only the eigenvalues will be computed.
-// If not specified, defaults to true
-func SelfAdjointEigV2ComputeV(value bool) SelfAdjointEigV2Attr {
- return func(m optionalAttr) {
- m["compute_v"] = value
- }
-}
-
-// Computes the eigen decomposition of one or more square self-adjoint matrices.
-//
-// Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in
-// `input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`.
-//
-// ```prettyprint
-// # a is a tensor.
-// # e is a tensor of eigenvalues.
-// # v is a tensor of eigenvectors.
-// e, v = self_adjoint_eig(a)
-// e = self_adjoint_eig(a, compute_v=False)
-// ```
-//
-// Arguments:
-// input: `Tensor` input of shape `[N, N]`.
-//
-// Returns Eigenvalues. Shape is `[N]`.Eigenvectors. Shape is `[N, N]`.
-func SelfAdjointEigV2(scope *Scope, input tf.Output, optional ...SelfAdjointEigV2Attr) (e tf.Output, v tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "SelfAdjointEigV2",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
-// Adjust the saturation of one or more images.
-//
-// `images` is a tensor of at least 3 dimensions. The last dimension is
-// interpretted as channels, and must be three.
-//
-// The input image is considered in the RGB colorspace. Conceptually, the RGB
-// colors are first mapped into HSV. A scale is then applied all the saturation
-// values, and then remapped back to RGB colorspace.
-//
-// Arguments:
-// images: Images to adjust. At least 3-D.
-// scale: A float scale to add to the saturation.
-//
-// Returns The hue-adjusted image or images.
-func AdjustSaturation(scope *Scope, images tf.Output, scale tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "AdjustSaturation",
- Input: []tf.Input{
- images, scale,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// EncodePngAttr is an optional argument to EncodePng.
-type EncodePngAttr func(optionalAttr)
-
-// EncodePngCompression sets the optional compression attribute to value.
-//
-// value: Compression level.
-// If not specified, defaults to -1
-func EncodePngCompression(value int64) EncodePngAttr {
- return func(m optionalAttr) {
- m["compression"] = value
- }
-}
-
-// PNG-encode an image.
-//
-// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]`
-// where `channels` is:
-//
-// * 1: for grayscale.
-// * 2: for grayscale + alpha.
-// * 3: for RGB.
-// * 4: for RGBA.
-//
-// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder
-// default or a value from 0 to 9. 9 is the highest compression level, generating
-// the smallest output, but is slower.
-//
-// Arguments:
-// image: 3-D with shape `[height, width, channels]`.
-//
-// Returns 0-D. PNG-encoded image.
-func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "EncodePng",
- Input: []tf.Input{
- image,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// MatrixSolveLsAttr is an optional argument to MatrixSolveLs.
-type MatrixSolveLsAttr func(optionalAttr)
-
-// MatrixSolveLsFast sets the optional fast attribute to value.
-// If not specified, defaults to true
-func MatrixSolveLsFast(value bool) MatrixSolveLsAttr {
- return func(m optionalAttr) {
- m["fast"] = value
- }
-}
-
-// Solves one or more linear least-squares problems.
-//
-// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
-// form matrices of size `[M, N]`. Rhs is a tensor of shape `[..., M, K]`.
-// The output is a tensor shape `[..., N, K]` where each output matrix solves
-// each of the equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]
-// in the least squares sense.
-//
-// matrix and right-hand sides in the batch:
-//
-// `matrix`=\\(A \in \Re^{m \times n}\\),
-// `rhs`=\\(B \in \Re^{m \times k}\\),
-// `output`=\\(X \in \Re^{n \times k}\\),
-// `l2_regularizer`=\\(\lambda\\).
-//
-// If `fast` is `True`, then the solution is computed by solving the normal
-// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
-// \\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares
-// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 +
-// \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as
-// \\(X = A^T (A A^T + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the
-// minimum-norm solution to the under-determined linear system, i.e.
-// \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||Z||_F^2 \\), subject to
-// \\(A Z = B\\). Notice that the fast path is only numerically stable when
-// \\(A\\) is numerically full rank and has a condition number
-// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or\\(\lambda\\) is
-// sufficiently large.
-//
-// If `fast` is `False` an algorithm based on the numerically robust complete
-// orthogonal decomposition is used. This computes the minimum-norm
-// least-squares solution, even when \\(A\\) is rank deficient. This path is
-// typically 6-7 times slower than the fast path. If `fast` is `False` then
-// `l2_regularizer` is ignored.
-//
-// Arguments:
-// matrix: Shape is `[..., M, N]`.
-// rhs: Shape is `[..., M, K]`.
-// l2_regularizer: Scalar tensor.
-//
-// @compatibility(numpy)
-// Equivalent to np.linalg.lstsq
-// @end_compatibility
-//
-// Returns Shape is `[..., N, K]`.
-func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "MatrixSolveLs",
- Input: []tf.Input{
- matrix, rhs, l2_regularizer,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Converts one or more images from RGB to HSV.
-//
-// Outputs a tensor of the same shape as the `images` tensor, containing the HSV
-// value of the pixels. The output is only well defined if the value in `images`
-// are in `[0,1]`.
-//
-// `output[..., 0]` contains hue, `output[..., 1]` contains saturation, and
-// `output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0
-// corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue.
+// Restore a Reader to its initial clean state.
//
// Arguments:
-// images: 1-D or higher rank. RGB data to convert. Last dimension must be size 3.
+// reader_handle: Handle to a Reader.
//
-// Returns `images` converted to HSV.
-func RGBToHSV(scope *Scope, images tf.Output) (output tf.Output) {
+// Returns the created operation.
+func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "RGBToHSV",
+ Type: "ReaderResetV2",
Input: []tf.Input{
- images,
+ reader_handle,
},
}
- op := scope.AddOperation(opspec)
- return op.Output(0)
+ return scope.AddOperation(opspec)
}
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 5e938c73f5..9fd5ada71e 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -82,6 +82,7 @@ py_library(
"//third_party/py/numpy",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/feature_column:feature_column",
+ "//tensorflow/python/feature_column:lookup_ops",
"//tensorflow/python/ops/losses",
"//tensorflow/python/ops/distributions",
"//tensorflow/python/saved_model",
@@ -1021,7 +1022,6 @@ tf_gen_op_wrapper_private_py(
require_shape_functions = True,
visibility = [
"//learning/brain/python/ops:__pkg__",
- "//tensorflow/contrib/lookup:__pkg__",
"//tensorflow/python/kernel_tests:__pkg__",
],
)
@@ -1057,6 +1057,16 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "lookup_ops_gen",
+ require_shape_functions = True,
+ visibility = [
+ "//learning/brain/python/ops:__pkg__",
+ "//tensorflow/python/feature_column:__pkg__",
+ "//tensorflow/python/kernel_tests:__pkg__",
+ ],
+)
+
+tf_gen_op_wrapper_private_py(
name = "math_ops_gen",
require_shape_functions = True,
visibility = [
@@ -1474,6 +1484,20 @@ py_library(
)
py_library(
+ name = "lookup_ops",
+ srcs = ["ops/lookup_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":array_ops",
+ ":framework",
+ ":framework_for_generated_wrappers",
+ ":lookup_ops_gen",
+ ":math_ops",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "math_grad",
srcs = ["ops/math_grad.py"],
srcs_version = "PY2AND3",
@@ -1861,6 +1885,7 @@ py_library(
":io_ops",
":linalg_ops",
":logging_ops",
+ ":lookup_ops",
":math_grad",
":math_ops",
":numerics",
@@ -2268,6 +2293,7 @@ py_library(
":io_ops",
":io_ops_gen",
":lib",
+ ":lookup_ops",
":math_ops",
":platform",
":protos_all_py",
@@ -2990,6 +3016,7 @@ cuda_py_tests(
":framework",
":framework_for_generated_wrappers",
":framework_test_lib",
+ ":lookup_ops",
":gradients",
":math_ops",
":nn_grad",
@@ -3020,7 +3047,7 @@ py_library(
srcs = ["training/saver_test_utils.py"],
srcs_version = "PY2AND3",
deps = [
- ":data_flow_ops_gen",
+ ":lookup_ops_gen",
":training",
],
)
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 864a96ef34..6336ca2310 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -55,6 +55,7 @@ from tensorflow.core.framework.summary_pb2 import *
from tensorflow.core.framework.attr_value_pb2 import *
from tensorflow.core.protobuf.meta_graph_pb2 import TensorInfo
from tensorflow.core.protobuf.config_pb2 import *
+from tensorflow.core.protobuf.tensorflow_server_pb2 import *
from tensorflow.core.protobuf.rewriter_config_pb2 import *
from tensorflow.core.util.event_pb2 import *
@@ -131,6 +132,7 @@ _allowed_symbols = [
'AttrValue',
'AutoParallelOptions',
'ConfigProto',
+ 'ClusterDef',
'DeviceSpec',
'Event',
'GPUOptions',
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 9add5bd3cd..040cc33315 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -29,6 +29,7 @@ import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.lib.core import error_codes_pb2
+from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import common_shapes
@@ -1789,7 +1790,7 @@ class SessionTest(test_util.TensorFlowTestCase):
with CaptureStderr() as log:
sess.run(c)
# Ensure that we did log device placement.
- self.assertTrue('/job:local/replica:0/task:0/cpu:0' in str(log))
+ self.assertTrue('/job:local/replica:0/task:0/cpu:0' in str(log), str(log))
def testLocalMasterSessionTimeout(self):
# Test that the timeout passed in a config to the session works correctly.
@@ -1834,6 +1835,270 @@ class SessionTest(test_util.TensorFlowTestCase):
server = server_lib.Server.create_local_server()
self.runTestBuildGraphError(session.Session(server.target))
+ def testClusterSpecPropagationSimple(self):
+ server1 = server_lib.Server.create_local_server()
+ server2 = server_lib.Server.create_local_server()
+ cluster_def = cluster_pb2.ClusterDef()
+ job = cluster_def.job.add()
+ job.name = 'worker'
+ job.tasks[0] = server1.target[len('grpc://'):]
+ job.tasks[1] = server2.target[len('grpc://'):]
+ config = config_pb2.ConfigProto(cluster_def=cluster_def)
+
+ const = constant_op.constant(17)
+ sess = session.Session(server1.target, config=config)
+ output = sess.run(const)
+ self.assertEqual(17, output)
+
+ def testClusterSpecPropagationWorker2Placement(self):
+ server1 = server_lib.Server.create_local_server()
+ server2 = server_lib.Server.create_local_server()
+ cluster_def = cluster_pb2.ClusterDef()
+ job = cluster_def.job.add()
+ job.name = 'worker'
+ job.tasks[0] = server1.target[len('grpc://'):]
+ job.tasks[1] = server2.target[len('grpc://'):]
+ config = config_pb2.ConfigProto(cluster_def=cluster_def)
+
+ with ops.Graph().as_default() as g, ops.device('/job:worker/task:1'):
+ const = constant_op.constant(17)
+ sess = session.Session(server1.target, config=config, graph=g)
+ run_options = config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE)
+ run_metadata = config_pb2.RunMetadata()
+ output = sess.run(const, options=run_options, run_metadata=run_metadata)
+ self.assertEqual(17, output)
+ self.assertEqual(1,
+ len([
+ node_stats
+ for dev_stats in run_metadata.step_stats.dev_stats
+ for node_stats in dev_stats.node_stats
+ if '/job:worker/replica:0/task:1/device:CPU:0' ==
+ dev_stats.device and 'Const' == node_stats.node_name
+ ]))
+
+ def testClusterSpecPropagationWorker1Placement(self):
+ server1 = server_lib.Server.create_local_server()
+ server2 = server_lib.Server.create_local_server()
+ cluster_def = cluster_pb2.ClusterDef()
+ job = cluster_def.job.add()
+ job.name = 'worker'
+ job.tasks[0] = server1.target[len('grpc://'):]
+ job.tasks[1] = server2.target[len('grpc://'):]
+ config = config_pb2.ConfigProto(cluster_def=cluster_def)
+
+ with ops.Graph().as_default() as g, ops.device('/job:worker/task:0'):
+ const = constant_op.constant(17)
+ sess = session.Session(server1.target, config=config, graph=g)
+ output = sess.run(const)
+ self.assertEqual(17, output)
+
+ def testClusterSpecPropagationThreeServers2Graphs(self):
+ """Boots 3 servers, creates 2 sessions, ensures appropriate operations.
+
+ We create 2 clusterspecs:
+ 1. server2 as the master, server1 as a worker
+ 2. server2 as the master, server3 as a worker
+
+ We ensure that variables on the workers are independent.
+ """
+ server1 = server_lib.Server.create_local_server()
+ server2 = server_lib.Server.create_local_server()
+ server3 = server_lib.Server.create_local_server()
+ cluster_def1 = cluster_pb2.ClusterDef()
+ job1 = cluster_def1.job.add()
+ job1.name = 'worker1'
+ job1.tasks[0] = server2.target[len('grpc://'):]
+ job1.tasks[1] = server1.target[len('grpc://'):]
+
+ cluster_def2 = cluster_pb2.ClusterDef()
+ job2 = cluster_def2.job.add()
+ job2.name = 'worker2'
+ job2.tasks[0] = server2.target[len('grpc://'):]
+ job2.tasks[1] = server3.target[len('grpc://'):]
+
+ config1 = config_pb2.ConfigProto(cluster_def=cluster_def1)
+ config2 = config_pb2.ConfigProto(cluster_def=cluster_def2)
+
+ with ops.Graph().as_default() as g1:
+ with ops.device('/job:worker1/task:1'):
+ var1 = variables.Variable(array_ops.zeros([2]), name='var1')
+ update_op1 = state_ops.assign_add(
+ var1, array_ops.ones([2]), name='var1_assign_add')
+ init1 = variables.global_variables_initializer()
+
+ with ops.Graph().as_default() as g2:
+ with ops.device('/job:worker2/task:1'):
+ var2 = variables.Variable(array_ops.zeros([2]), name='var2')
+ update_op2 = state_ops.assign_add(
+ var2, array_ops.ones([2]), name='var2_assign_add')
+ init2 = variables.global_variables_initializer()
+
+ sess1 = session.Session(server2.target, graph=g1, config=config1)
+ sess2 = session.Session(server2.target, graph=g2, config=config2)
+
+ init1.run(session=sess1)
+ init2.run(session=sess2)
+
+ expected_zeros = np.zeros([2])
+ expected_ones = np.ones([2])
+
+ self.assertAllEqual(expected_zeros, sess1.run(var1))
+ self.assertAllEqual(expected_zeros, sess2.run(var2))
+
+ self.assertAllEqual(expected_ones, sess1.run(update_op1))
+ self.assertAllEqual(expected_ones, sess1.run(var1))
+ self.assertAllEqual(expected_zeros, sess2.run(var2))
+ self.assertAllEqual(expected_ones, sess2.run(update_op2))
+ self.assertAllEqual(expected_ones + expected_ones, sess1.run(update_op1))
+ self.assertAllEqual(expected_ones, sess2.run(var2))
+ self.assertAllEqual(expected_ones + expected_ones, sess1.run(var1))
+
+ def testClusterSpecPropagationThreeServers(self):
+ """Boots 3 servers, creates 2 sessions, ensures appropriate operations.
+
+ We create 2 clusterspecs:
+ 1. server2 as the master, server1 as a worker
+ 2. server2 as the master, server3 as a worker
+
+ We ensure that variables on the workers are independent.
+ """
+ server1 = server_lib.Server.create_local_server()
+ server2 = server_lib.Server.create_local_server()
+ server3 = server_lib.Server.create_local_server()
+ cluster_def1 = cluster_pb2.ClusterDef()
+ job1 = cluster_def1.job.add()
+ job1.name = 'worker'
+ job1.tasks[0] = server2.target[len('grpc://'):]
+ job1.tasks[1] = server1.target[len('grpc://'):]
+
+ cluster_def2 = cluster_pb2.ClusterDef()
+ job2 = cluster_def2.job.add()
+ job2.name = 'worker'
+ job2.tasks[0] = server2.target[len('grpc://'):]
+ job2.tasks[1] = server3.target[len('grpc://'):]
+
+ config1 = config_pb2.ConfigProto(cluster_def=cluster_def1)
+ config2 = config_pb2.ConfigProto(cluster_def=cluster_def2)
+
+ with ops.device('/job:worker/task:1'):
+ var = variables.Variable(array_ops.zeros([2]), name='var')
+ feed = array_ops.placeholder(dtypes.float32, shape=(2))
+ update_op = var.assign_add(feed)
+
+ sess1 = session.Session(server2.target, config=config1)
+ sess2 = session.Session(server2.target, config=config2)
+
+ variables.global_variables_initializer().run(session=sess1)
+ variables.global_variables_initializer().run(session=sess2)
+
+ expected_zeros = np.zeros([2])
+ expected_ones = np.ones([2])
+
+ self.assertAllEqual(expected_zeros, sess1.run(var))
+ self.assertAllEqual(expected_zeros, sess2.run(var))
+ self.assertAllEqual(expected_ones,
+ sess1.run(update_op, feed_dict={feed: expected_ones}))
+ self.assertAllEqual(expected_ones, sess1.run(var))
+ self.assertAllEqual(expected_zeros, sess2.run(var))
+ self.assertAllEqual(expected_ones,
+ sess2.run(update_op, feed_dict={feed: expected_ones}))
+ self.assertAllEqual(expected_ones + expected_ones,
+ sess1.run(update_op, feed_dict={feed: expected_ones}))
+ self.assertAllEqual(expected_ones, sess2.run(var))
+ self.assertAllEqual(expected_ones + expected_ones, sess1.run(var))
+
+ def testClusterSpecPropagationThreeServersOneCluster(self):
+ """Boots 3 servers, ensures appropriate communication across workers.
+
+ Additionally, in this cluster, we ensure the master is not the 0-th worker.
+
+ Note: this test only uses one session.
+ """
+ server1 = server_lib.Server.create_local_server()
+ server2 = server_lib.Server.create_local_server()
+ server3 = server_lib.Server.create_local_server()
+ cluster_def = cluster_pb2.ClusterDef()
+ job = cluster_def.job.add()
+ job.name = 'worker'
+ job.tasks[0] = server3.target[len('grpc://'):]
+ job.tasks[1] = server2.target[len('grpc://'):]
+ job.tasks[2] = server1.target[len('grpc://'):]
+ config = config_pb2.ConfigProto(cluster_def=cluster_def)
+
+ # Add ops to the devices in non-linear order.
+
+ with ops.device('/job:worker/task:1'):
+ feed1 = array_ops.placeholder(dtypes.float32, shape=(2))
+ const1 = constant_op.constant(2.0)
+ mul1 = const1 * feed1
+
+ with ops.device('/job:worker/task:2'):
+ feed2 = array_ops.placeholder(dtypes.float32, shape=(2))
+ const2 = constant_op.constant(2.0)
+ mul2 = const2 * feed2
+
+ with ops.device('/job:worker/task:0'):
+ feed0 = array_ops.placeholder(dtypes.float32, shape=(2))
+ const0 = constant_op.constant(2.0)
+ mul0 = const0 * feed0
+
+ sum_op = mul0 + mul1 + mul2
+
+ ones = np.ones([2])
+ run_options = config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE)
+ run_metadata = config_pb2.RunMetadata()
+
+ # Run!
+ with session.Session(server1.target, config=config) as sess:
+ output = sess.run(
+ sum_op,
+ options=run_options,
+ run_metadata=run_metadata,
+ feed_dict={feed1: ones,
+ feed2: ones,
+ feed0: ones})
+ self.assertAllEqual(6 * ones, output)
+
+ self.assertEqual(
+ 3,
+ len([
+ dev_stats.device
+ for dev_stats in run_metadata.step_stats.dev_stats
+ for node_stats in dev_stats.node_stats
+ if '/job:worker/replica:0/task:' in dev_stats.device and
+ node_stats.node_name.startswith('Const')
+ ]), run_metadata)
+
+ def testClusterSpecPropagationPartialRun(self):
+ """Test successful partial run with ClusterSpec propagation."""
+ server1 = server_lib.Server.create_local_server()
+ server2 = server_lib.Server.create_local_server()
+
+ cluster_def = cluster_pb2.ClusterDef()
+ job = cluster_def.job.add()
+ job.name = 'worker'
+ job.tasks[0] = server1.target[len('grpc://'):]
+ job.tasks[1] = server2.target[len('grpc://'):]
+ config = config_pb2.ConfigProto(cluster_def=cluster_def)
+
+ with ops.device('/job:worker/task:0'):
+ a = array_ops.placeholder(dtypes.float32, shape=[])
+ with ops.device('/job:worker/task:1'):
+ b = array_ops.placeholder(dtypes.float32, shape=[])
+ c = array_ops.placeholder(dtypes.float32, shape=[])
+ r1 = math_ops.add(a, b)
+ with ops.device('/job:worker/task:0'):
+ r2 = math_ops.multiply(r1, c)
+
+ with session.Session(server1.target, config=config) as sess:
+ h = sess.partial_run_setup([r1, r2], [a, b, c])
+ res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
+ self.assertEqual(3, res)
+ res = sess.partial_run(h, r2, feed_dict={c: 3})
+ self.assertEqual(9, res)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 7a12ec01d0..3b7e3b1c90 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -42,8 +42,8 @@ from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import state_ops
@@ -1396,9 +1396,10 @@ class EstimatorExportTest(test.TestCase):
my_int = variables.Variable(1, name='my_int',
collections=[ops.GraphKeys.LOCAL_VARIABLES])
scores = constant_op.constant([3.])
- with ops.control_dependencies(
- [variables.local_variables_initializer(),
- data_flow_ops.tables_initializer()]):
+ with ops.control_dependencies([
+ variables.local_variables_initializer(),
+ lookup_ops.tables_initializer()
+ ]):
assign_op = state_ops.assign(my_int, 12345)
# local_initSop must be an Operation, not a Tensor.
diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py
index 37a98cf481..a1ecd794df 100644
--- a/tensorflow/python/estimator/export/export.py
+++ b/tensorflow/python/estimator/export/export.py
@@ -23,6 +23,8 @@ import collections
import os
import time
+import six
+
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@@ -56,7 +58,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver',
if not isinstance(features, dict):
features = {_SINGLE_FEATURE_DEFAULT_NAME: features}
for name, tensor in features.items():
- if not isinstance(name, str):
+ if not isinstance(name, six.string_types):
raise ValueError('feature keys must be strings: {}.'.format(name))
if not (isinstance(tensor, ops.Tensor)
or isinstance(tensor, sparse_tensor.SparseTensor)):
@@ -68,7 +70,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver',
if not isinstance(receiver_tensors, dict):
receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}
for name, tensor in receiver_tensors.items():
- if not isinstance(name, str):
+ if not isinstance(name, six.string_types):
raise ValueError(
'receiver_tensors keys must be strings: {}.'.format(name))
if not isinstance(tensor, ops.Tensor):
diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py
index 69be0f687c..49bcd06d50 100644
--- a/tensorflow/python/estimator/export/export_output.py
+++ b/tensorflow/python/estimator/export/export_output.py
@@ -20,6 +20,8 @@ from __future__ import print_function
import abc
+import six
+
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -171,7 +173,7 @@ class PredictOutput(ExportOutput):
'Prediction outputs must be given as a dict of string to Tensor; '
'got {}'.format(outputs))
for key, value in outputs.items():
- if not isinstance(key, str):
+ if not isinstance(key, six.string_types):
raise ValueError(
'Prediction output key must be a string; got {}.'.format(key))
if not isinstance(value, ops.Tensor):
diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py
index 27a088e551..035a9a143e 100644
--- a/tensorflow/python/estimator/export/export_output_test.py
+++ b/tensorflow/python/estimator/export/export_output_test.py
@@ -22,7 +22,9 @@ from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.estimator.export import export_output as export_output_lib
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
@@ -197,6 +199,33 @@ class ExportOutputTest(test.TestCase):
signature_constants.CLASSIFY_METHOD_NAME)
self.assertEqual(actual_signature_def, expected_signature_def)
+ def test_predict_output_constructor(self):
+ """Tests that no errors are raised when input is expected."""
+ outputs = {
+ "output0": constant_op.constant([0]),
+ u"output1": constant_op.constant([1]),
+ }
+ export_output_lib.PredictOutput(outputs)
+
+ def test_predict_output_outputs_invalid(self):
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Prediction outputs must be given as a dict of string to Tensor"):
+ export_output_lib.PredictOutput(constant_op.constant([0]))
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Prediction output key must be a string"):
+ export_output_lib.PredictOutput({1: constant_op.constant([0])})
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Prediction output value must be a Tensor"):
+ export_output_lib.PredictOutput({
+ "prediction1": sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+ })
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py
index fdd924f2e1..7946bd88ba 100644
--- a/tensorflow/python/estimator/export/export_test.py
+++ b/tensorflow/python/estimator/export/export_test.py
@@ -28,13 +28,11 @@ from tensorflow.core.example import example_pb2
from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
@@ -43,6 +41,69 @@ from tensorflow.python.saved_model import signature_def_utils
class ExportTest(test_util.TensorFlowTestCase):
+ def test_serving_input_receiver_constructor(self):
+ """Tests that no errors are raised when input is expected."""
+ features = {
+ "feature0": constant_op.constant([0]),
+ u"feature1": constant_op.constant([1]),
+ "feature2": sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+ }
+ receiver_tensors = {
+ "example0": array_ops.placeholder(dtypes.string, name="example0"),
+ u"example1": array_ops.placeholder(dtypes.string, name="example1"),
+ }
+ export.ServingInputReceiver(features, receiver_tensors)
+
+ def test_serving_input_receiver_features_invalid(self):
+ receiver_tensors = {
+ "example0": array_ops.placeholder(dtypes.string, name="example0"),
+ u"example1": array_ops.placeholder(dtypes.string, name="example1"),
+ }
+
+ with self.assertRaisesRegexp(ValueError, "features must be defined"):
+ export.ServingInputReceiver(
+ features=None,
+ receiver_tensors=receiver_tensors)
+
+ with self.assertRaisesRegexp(ValueError, "feature keys must be strings"):
+ export.ServingInputReceiver(
+ features={1: constant_op.constant([1])},
+ receiver_tensors=receiver_tensors)
+
+ with self.assertRaisesRegexp(
+ ValueError, "feature feature1 must be a Tensor or SparseTensor"):
+ export.ServingInputReceiver(
+ features={"feature1": [1]},
+ receiver_tensors=receiver_tensors)
+
+ def test_serving_input_receiver_receiver_tensors_invalid(self):
+ features = {
+ "feature0": constant_op.constant([0]),
+ u"feature1": constant_op.constant([1]),
+ "feature2": sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+ }
+
+ with self.assertRaisesRegexp(
+ ValueError, "receiver_tensors must be defined"):
+ export.ServingInputReceiver(
+ features=features,
+ receiver_tensors=None)
+
+ with self.assertRaisesRegexp(
+ ValueError, "receiver_tensors keys must be strings"):
+ export.ServingInputReceiver(
+ features=features,
+ receiver_tensors={
+ 1: array_ops.placeholder(dtypes.string, name="example0")})
+
+ with self.assertRaisesRegexp(
+ ValueError, "receiver_tensor example1 must be a Tensor"):
+ export.ServingInputReceiver(
+ features=features,
+ receiver_tensors={"example1": [1]})
+
def test_single_feature_single_receiver(self):
feature = constant_op.constant(5)
receiver_tensor = array_ops.placeholder(dtypes.string)
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index d5eb20e997..ac7aef96ac 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -29,6 +29,7 @@ py_library(
srcs = ["feature_column.py"],
srcs_version = "PY2AND3",
deps = [
+ ":lookup_ops",
"//tensorflow/python:embedding_ops",
"//tensorflow/python:framework",
"//tensorflow/python:init_ops",
@@ -44,14 +45,47 @@ py_library(
],
)
+filegroup(
+ name = "vocabulary_testdata",
+ srcs = [
+ "testdata/warriors_vocabulary.txt",
+ "testdata/wire_vocabulary.txt",
+ ],
+)
+
py_test(
name = "feature_column_test",
srcs = ["feature_column_test.py"],
+ data = [":vocabulary_testdata"],
srcs_version = "PY2AND3",
+ tags = ["no_pip"],
deps = [
":feature_column",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:training",
+ ],
+)
+
+# TODO(ptucker,yleon): Move along with 3p/tf/contrib/lookup.
+# Test is still in 3p/tf/contrib/lookup.
+py_library(
+ name = "lookup_ops",
+ srcs = [
+ "lookup_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:lookup_ops_gen",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:string_ops",
"//tensorflow/python:training",
+ "//tensorflow/python:util",
],
)
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index a96052a3ae..ffdf8868e2 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -121,6 +121,9 @@ from __future__ import print_function
import abc
import collections
+import numpy as np
+
+from tensorflow.python.feature_column import lookup_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@@ -331,7 +334,9 @@ def numeric_column(key,
```
Args:
- key: A string providing key to look up corresponding `Tensor`.
+ key: A unique string identifying the input feature. It is used as the
+ column name and the dictionary key for feature parsing configs, feature
+ `Tensor` objects, and feature columns.
shape: An iterable of integers specifies the shape of the `Tensor`. An
integer can be given which means a single dimension `Tensor` with given
width. The `Tensor` representing the column will have the shape of
@@ -430,6 +435,12 @@ def bucketized_column(source_column, boundaries):
return _BucketizedColumn(source_column, tuple(boundaries))
+def _assert_string_or_int(dtype, prefix):
+ if (dtype != dtypes.string) and (not dtype.is_integer):
+ raise ValueError(
+ '{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype))
+
+
def categorical_column_with_hash_bucket(key,
hash_bucket_size,
dtype=dtypes.string):
@@ -443,22 +454,22 @@ def categorical_column_with_hash_bucket(key,
```python
keywords = categorical_column_with_hash_bucket("keywords", 10K)
- all_feature_columns = [keywords, ...]
- linear_prediction = make_linear_model(features, all_feature_columns)
+ linear_prediction = make_linear_model(features, [keywords, ...])
# or
keywords_embedded = embedding_column(keywords, 16)
- all_feature_columns = [keywords_embedded, ...]
- dense_tensor = make_input_layer(features, all_feature_columns)
+ dense_tensor = make_input_layer(features, [keywords_embedded, ...])
```
Args:
- key: A string providing key to look up corresponding `Tensor`.
+ key: A unique string identifying the input feature. It is used as the
+ column name and the dictionary key for feature parsing configs, feature
+ `Tensor` objects, and feature columns.
hash_bucket_size: An int > 1. The number of buckets.
dtype: The type of features. Only string and integer types are supported.
Returns:
- A `_CategoricalColumnHashed`.
+ A `_HashedCategoricalColumn`.
Raises:
ValueError: `hash_bucket_size` is not greater than 1.
@@ -472,11 +483,177 @@ def categorical_column_with_hash_bucket(key,
'hash_bucket_size: {}, key: {}'.format(
hash_bucket_size, key))
- if dtype != dtypes.string and not dtype.is_integer:
- raise ValueError('dtype must be string or integer. '
- 'dtype: {}, column_name: {}'.format(dtype, key))
+ _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+
+ return _HashedCategoricalColumn(key, hash_bucket_size, dtype)
+
+
+def categorical_column_with_vocabulary_file(
+ key, vocabulary_file, vocabulary_size, num_oov_buckets=0,
+ default_value=None, dtype=dtypes.string):
+ """A `_CategoricalColumn` with a vocabulary file.
+
+ Use this when your inputs are in string or integer format, and you have a
+ vocabulary file that maps each value to an integer ID. By default,
+ out-of-vocabulary values are ignored. Use either (but not both) of
+ `num_oov_buckets` and `default_value` to specify how to include
+ out-of-vocabulary values.
+
+ Inputs can be either `Tensor` or `SparseTensor`. If `Tensor`, missing values
+ can be represented by `-1` for int and `''` for string. Note that these values
+ are independent of the `default_value` argument.
+
+ Example with `num_oov_buckets`:
+ File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state
+ abbreviation. All inputs with values in that file are assigned an ID 0-49,
+ corresponding to its line number. All other values are hashed and assigned an
+ ID 50-54.
+ ```python
+ states = categorical_column_with_vocabulary_file(
+ key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
+ num_oov_buckets=5)
+ linear_prediction = make_linear_model(features, [states, ...])
+ ```
+
+ Example with `default_value`:
+ File '/us/states.txt' contains 51 lines - the first line is 'XX', and the
+ other 50 each have a 2-character U.S. state abbreviation. Both a literal 'XX'
+ in input, and other values missing from the file, will be assigned ID 0. All
+ others are assigned the corresponding line number 1-50.
+ ```python
+ states = categorical_column_with_vocabulary_file(
+ key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
+ default_value=0)
+ linear_prediction, _, _ = make_linear_model(features, [states, ...])
+
+ And to make an embedding with either:
+ ```python
+ dense_tensor = make_input_layer(features, [embedding_column(states, 3),...])
+ ```
+
+ Args:
+ key: A unique string identifying the input feature. It is used as the
+ column name and the dictionary key for feature parsing configs, feature
+ `Tensor` objects, and feature columns.
+ vocabulary_file: The vocabulary file name.
+ vocabulary_size: Number of the elements in the vocabulary. This must be no
+ greater than length of `vocabulary_file`, if less than length, later
+ values are ignored.
+ num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
+ buckets. All out-of-vocabulary inputs will be assigned IDs in the range
+ `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
+ the input value. A positive `num_oov_buckets` can not be specified with
+ `default_value`.
+ default_value: The integer ID value to return for out-of-vocabulary feature
+ values, defaults to -1. This can not be specified with a positive
+ `num_oov_buckets`.
+ dtype: The type of features. Only string and integer types are supported.
+
+ Returns:
+ A `_CategoricalColumn` with a vocabulary file.
+
+ Raises:
+ ValueError: `vocabulary_file` is missing.
+ ValueError: `vocabulary_size` is missing or < 1.
+ ValueError: `num_oov_buckets` is not a non-negative integer.
+ ValueError: `dtype` is neither string nor integer.
+ """
+ if not vocabulary_file:
+ raise ValueError('Missing vocabulary_file in {}.'.format(key))
+ # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`.
+ # TODO(ptucker): Should we fail for vocabulary_size==1?
+ if (vocabulary_size is None) or (vocabulary_size < 1):
+ raise ValueError('Invalid vocabulary_size in {}.'.format(key))
+ if num_oov_buckets:
+ if default_value is not None:
+ raise ValueError(
+ 'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
+ key))
+ if num_oov_buckets < 0:
+ raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
+ num_oov_buckets, key))
+ _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+ return _VocabularyFileCategoricalColumn(
+ key=key,
+ vocabulary_file=vocabulary_file,
+ vocabulary_size=vocabulary_size,
+ num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets,
+ default_value=-1 if default_value is None else default_value,
+ dtype=dtype)
+
+
+def categorical_column_with_vocabulary_list(
+ key, vocabulary_list, dtype=None, default_value=-1):
+ """A `_CategoricalColumn` with in-memory vocabulary.
+
+ Logic for feature f is:
+ id = f in vocabulary_list ? vocabulary_list.index(f) : default_value
+
+ Use this when your inputs are in string or integer format, and you have an
+ in-memory vocabulary mapping each value to an integer ID. By default,
+ out-of-vocabulary values are ignored. Use `default_value` to specify how to
+ include out-of-vocabulary values.
+
+ Inputs can be either `Tensor` or `SparseTensor`. If `Tensor`, missing values
+ can be represented by `-1` for int and `''` for string. Note that these values
+ are independent of the `default_value` argument.
+
+ In the following examples, each input in `vocabulary_list` is assigned an ID
+ 0-4 corresponding to its index (e.g., input 'B' produces output 2). All other
+ inputs are assigned `default_value` 0.
+
+ Linear model:
+ ```python
+ colors = categorical_column_with_vocabulary_list(
+ key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0)
+ linear_prediction, _, _ = make_linear_model(features, [colors, ...])
+ ```
+
+ Embedding for a DNN model:
+ ```python
+ dense_tensor = make_input_layer(features, [embedding_column(colors, 3),...])
+ ```
+
+ Args:
+ key: A unique string identifying the input feature. It is used as the
+ column name and the dictionary key for feature parsing configs, feature
+ `Tensor` objects, and feature columns.
+ vocabulary_list: An ordered iterable defining the vocabulary. Each feature
+ is mapped to the index of its value (if present) in `vocabulary_list`.
+ Must be castable to `dtype`.
+ dtype: The type of features. Only string and integer types are supported.
+ If `None`, it will be inferred from `vocabulary_list`.
+ default_value: The value to use for values not in `vocabulary_list`.
- return _CategoricalColumnHashed(key, hash_bucket_size, dtype)
+ Returns:
+ A `_CategoricalColumn` with in-memory vocabulary.
+
+ Raises:
+ ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
+ ValueError: if `dtype` is not integer or string.
+ """
+ if (vocabulary_list is None) or (len(vocabulary_list) < 1):
+ raise ValueError(
+ 'vocabulary_list {} must be non-empty, column_name: {}'.format(
+ vocabulary_list, key))
+ if len(set(vocabulary_list)) != len(vocabulary_list):
+ raise ValueError(
+ 'Duplicate keys in vocabulary_list {}, column_name: {}'.format(
+ vocabulary_list, key))
+ vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype)
+ _assert_string_or_int(
+ vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key))
+ if dtype is None:
+ dtype = vocabulary_dtype
+ elif dtype.is_integer != vocabulary_dtype.is_integer:
+ raise ValueError(
+ 'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
+ dtype, vocabulary_dtype, key))
+ _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+
+ return _VocabularyListCategoricalColumn(
+ key=key, vocabulary_list=tuple(vocabulary_list), dtype=dtype,
+ default_value=default_value)
class _FeatureColumn(object):
@@ -764,6 +941,67 @@ class _LazyBuilder(object):
return transformed
+# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
+def _shape_offsets(shape):
+ """Returns moving offset for each dimension given shape."""
+ offsets = []
+ for dim in reversed(shape):
+ if offsets:
+ offsets.append(dim * offsets[-1])
+ else:
+ offsets.append(dim)
+ offsets.reverse()
+ return offsets
+
+
+# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
+def _to_sparse_input(input_tensor, ignore_value=None):
+ """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.
+
+ If `input_tensor` is already a `SparseTensor`, just return it.
+
+ Args:
+ input_tensor: A string or integer `Tensor`.
+ ignore_value: Entries in `dense_tensor` equal to this value will be
+ absent from the resulting `SparseTensor`. If `None`, default value of
+ `dense_tensor`'s dtype will be used ('' for `str`, -1 for `int`).
+
+ Returns:
+ A `SparseTensor` with the same shape as `input_tensor`.
+
+ Raises:
+ ValueError: when `input_tensor`'s rank is `None`.
+ """
+ input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
+ input_tensor)
+ if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
+ return input_tensor
+ with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)):
+ input_rank = input_tensor.get_shape().ndims
+ if input_rank is None:
+ # TODO(b/32318825): Implement dense_to_sparse_tensor for undefined rank.
+ raise ValueError('Undefined input_tensor shape.')
+ if ignore_value is None:
+ ignore_value = '' if input_tensor.dtype == dtypes.string else -1
+ dense_shape = math_ops.cast(array_ops.shape(input_tensor), dtypes.int64)
+ indices = array_ops.where(math_ops.not_equal(
+ input_tensor, math_ops.cast(ignore_value, input_tensor.dtype)))
+ # Flattens the tensor and indices for use with gather.
+ flat_tensor = array_ops.reshape(input_tensor, [-1])
+ flat_indices = indices[:, input_rank - 1]
+ # Computes the correct flattened indices for 2d (or higher) tensors.
+ if input_rank > 1:
+ higher_dims = indices[:, :input_rank - 1]
+ shape_offsets = array_ops.stack(
+ _shape_offsets(array_ops.unstack(dense_shape)[1:]))
+ offsets = math_ops.reduce_sum(
+ math_ops.multiply(higher_dims, shape_offsets),
+ reduction_indices=[1])
+ flat_indices = math_ops.add(flat_indices, offsets)
+ values = array_ops.gather(flat_tensor, flat_indices)
+ return sparse_tensor_lib.SparseTensor(indices, values, dense_shape)
+
+
def _check_feature_columns(feature_columns):
if isinstance(feature_columns, dict):
raise ValueError('Expected feature_columns to be iterable, found dict.')
@@ -951,7 +1189,7 @@ def _check_default_value(shape, default_value, dtype, key):
`shape`.
dtype: defines the type of values. Default value is `tf.float32`. Must be a
non-quantized, real integer or floating point type.
- key: A string providing key to look up corresponding `Tensor`.
+ key: Column name, used only for error messages.
Returns:
A tuple which will be used as default value.
@@ -994,9 +1232,9 @@ def _check_default_value(shape, default_value, dtype, key):
default_value, dtype, key))
-class _CategoricalColumnHashed(
+class _HashedCategoricalColumn(
_CategoricalColumn,
- collections.namedtuple('_CategoricalColumnHashed',
+ collections.namedtuple('_HashedCategoricalColumn',
['key', 'hash_bucket_size', 'dtype'])):
"""see `categorical_column_with_hash_bucket`."""
@@ -1009,15 +1247,13 @@ class _CategoricalColumnHashed(
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
def _transform_feature(self, inputs):
- input_tensor = inputs.get(self.key)
+ input_tensor = _to_sparse_input(inputs.get(self.key))
if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
raise ValueError('SparseColumn input must be a SparseTensor.')
- if (input_tensor.dtype != dtypes.string and
- not input_tensor.dtype.is_integer):
- raise ValueError('input tensors dtype must be string or integer. '
- 'dtype: {}, column_name: {}'.format(
- input_tensor.dtype, self.key))
+ _assert_string_or_int(
+ input_tensor.dtype,
+ prefix='column_name: {} input_tensor'.format(self.key))
if self.dtype.is_integer != input_tensor.dtype.is_integer:
raise ValueError(
@@ -1045,6 +1281,109 @@ class _CategoricalColumnHashed(
return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
+class _VocabularyFileCategoricalColumn(
+ _CategoricalColumn,
+ collections.namedtuple('_VocabularyFileCategoricalColumn', (
+ 'key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets', 'dtype',
+ 'default_value'
+ ))):
+ """See `categorical_column_with_vocabulary_file`."""
+
+ @property
+ def name(self):
+ return self.key
+
+ @property
+ def _parse_example_config(self):
+ return {self.key: parsing_ops.VarLenFeature(self.dtype)}
+
+ def _transform_feature(self, inputs):
+ input_tensor = _to_sparse_input(inputs.get(self.key))
+
+ if self.dtype.is_integer != input_tensor.dtype.is_integer:
+ raise ValueError(
+ 'Column dtype and SparseTensors dtype must be compatible. '
+ 'key: {}, column dtype: {}, tensor dtype: {}'.format(
+ self.key, self.dtype, input_tensor.dtype))
+
+ _assert_string_or_int(
+ input_tensor.dtype,
+ prefix='column_name: {} input_tensor'.format(self.key))
+
+ key_dtype = self.dtype
+ if input_tensor.dtype.is_integer:
+ # `index_table_from_file` requires 64-bit integer keys.
+ key_dtype = dtypes.int64
+ input_tensor = math_ops.to_int64(input_tensor)
+
+ return lookup_ops.index_table_from_file(
+ vocabulary_file=self.vocabulary_file,
+ num_oov_buckets=self.num_oov_buckets,
+ vocab_size=self.vocabulary_size,
+ default_value=self.default_value,
+ key_dtype=key_dtype,
+ name='{}_lookup'.format(self.key)).lookup(input_tensor)
+
+ @property
+ def _num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ return self.vocabulary_size + self.num_oov_buckets
+
+ def _get_sparse_tensors(
+ self, inputs, weight_collections=None, trainable=None):
+ return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
+
+class _VocabularyListCategoricalColumn(
+ _CategoricalColumn,
+ collections.namedtuple('_VocabularyListCategoricalColumn', (
+ 'key', 'vocabulary_list', 'dtype', 'default_value'
+ ))):
+ """See `categorical_column_with_vocabulary_list`."""
+
+ @property
+ def name(self):
+ return self.key
+
+ @property
+ def _parse_example_config(self):
+ return {self.key: parsing_ops.VarLenFeature(self.dtype)}
+
+ def _transform_feature(self, inputs):
+ input_tensor = _to_sparse_input(inputs.get(self.key))
+
+ if self.dtype.is_integer != input_tensor.dtype.is_integer:
+ raise ValueError(
+ 'Column dtype and SparseTensors dtype must be compatible. '
+ 'key: {}, column dtype: {}, tensor dtype: {}'.format(
+ self.key, self.dtype, input_tensor.dtype))
+
+ _assert_string_or_int(
+ input_tensor.dtype,
+ prefix='column_name: {} input_tensor'.format(self.key))
+
+ key_dtype = self.dtype
+ if input_tensor.dtype.is_integer:
+ # `index_table_from_tensor` requires 64-bit integer keys.
+ key_dtype = dtypes.int64
+ input_tensor = math_ops.to_int64(input_tensor)
+
+ return lookup_ops.index_table_from_tensor(
+ mapping=tuple(self.vocabulary_list),
+ default_value=self.default_value,
+ dtype=key_dtype,
+ name='{}_lookup'.format(self.key)).lookup(input_tensor)
+
+ @property
+ def _num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ return len(self.vocabulary_list)
+
+ def _get_sparse_tensors(
+ self, inputs, weight_collections=None, trainable=None):
+ return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
+
# TODO(zakaria): Move this to embedding_ops and make it public.
def _safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index bc62653310..59aa39411f 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -28,9 +28,10 @@ from tensorflow.python.client import session
from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
@@ -40,7 +41,7 @@ from tensorflow.python.platform import test
def _initialized_session():
sess = session.Session()
sess.run(variables_lib.global_variables_initializer())
- sess.run(data_flow_ops.tables_initializer())
+ sess.run(lookup_ops.tables_initializer())
return sess
@@ -552,7 +553,7 @@ class BucketizedColumnTest(test.TestCase):
self.assertAllClose([[81.], [141.]], predictions.eval())
-class SparseColumnHashedTest(test.TestCase):
+class HashedCategoricalColumnTest(test.TestCase):
def test_defaults(self):
a = fc.categorical_column_with_hash_bucket('aaa', 10)
@@ -578,11 +579,14 @@ class SparseColumnHashedTest(test.TestCase):
def test_deep_copy(self):
"""Tests deepcopy of categorical_column_with_hash_bucket."""
- column = fc.categorical_column_with_hash_bucket('aaa', 10)
- column_copy = copy.deepcopy(column)
- self.assertEqual('aaa', column_copy.name)
- self.assertEqual(10, column_copy.hash_bucket_size)
- self.assertEqual(dtypes.string, column_copy.dtype)
+ original = fc.categorical_column_with_hash_bucket('aaa', 10)
+ for column in (original, copy.deepcopy(original)):
+ self.assertEqual('aaa', column.name)
+ self.assertEqual(10, column.hash_bucket_size)
+ # pylint: disable=protected-access
+ self.assertEqual(10, column._num_buckets)
+ # pylint: enable=protected-access
+ self.assertEqual(dtypes.string, column.dtype)
def test_parse_config(self):
a = fc.categorical_column_with_hash_bucket('aaa', 10)
@@ -681,14 +685,45 @@ class SparseColumnHashedTest(test.TestCase):
def test_get_sparse_tensors(self):
hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'],
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- builder = fc._LazyBuilder({'wire': wire_tensor})
- self.assertEqual(
- builder.get(hashed_sparse),
- hashed_sparse._get_sparse_tensors(builder).id_tensor)
+ builder = fc._LazyBuilder({
+ 'wire': sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ })
+ id_weight_pair = hashed_sparse._get_sparse_tensors(builder)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ self.assertEqual(builder.get(hashed_sparse), id_weight_pair.id_tensor)
+
+ def test_get_sparse_tensors_dense_input(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
+ builder = fc._LazyBuilder({
+ 'wire': (('omar', ''), ('stringer', 'marlo'))
+ })
+ id_weight_pair = hashed_sparse._get_sparse_tensors(builder)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ self.assertEqual(builder.get(hashed_sparse), id_weight_pair.id_tensor)
+
+ def test_make_linear_model(self):
+ wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = fc.make_linear_model({
+ wire_column.name: sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 3: wire_var[3] = 4
+ # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
+ self.assertAllClose(((4.,), (6.,)), predictions.eval())
def get_linear_model_bias():
@@ -1158,5 +1193,640 @@ class MakeInputLayerTest(test.TestCase):
self.assertAllClose([[1., 3.]], net2.eval())
+def _assert_sparse_tensor_value(test_case, expected, actual):
+ test_case.assertEqual(np.int64, np.array(actual.indices).dtype)
+ test_case.assertAllEqual(expected.indices, actual.indices)
+
+ test_case.assertEqual(
+ np.array(expected.values).dtype, np.array(actual.values).dtype)
+ test_case.assertAllEqual(expected.values, actual.values)
+
+ test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype)
+ test_case.assertAllEqual(expected.dense_shape, actual.dense_shape)
+
+
+class VocabularyFileCategoricalColumnTest(test.TestCase):
+
+ def setUp(self):
+ super(VocabularyFileCategoricalColumnTest, self).setUp()
+
+ # Contains ints, Golden State Warriors jersey numbers: 30, 35, 11, 23, 22
+ self._warriors_vocabulary_file_name = test.test_src_dir_path(
+ 'python/feature_column/testdata/warriors_vocabulary.txt')
+ self._warriors_vocabulary_size = 5
+
+ # Contains strings, character names from 'The Wire': omar, stringer, marlo
+ self._wire_vocabulary_file_name = test.test_src_dir_path(
+ 'python/feature_column/testdata/wire_vocabulary.txt')
+ self._wire_vocabulary_size = 3
+
+ def test_defaults(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
+ self.assertEqual('aaa', column.name)
+ # pylint: disable=protected-access
+ self.assertEqual(3, column._num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.string)
+ }, column._parse_example_config)
+ # pylint: enable=protected-access
+
+ def test_all_constructor_args(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path_to_file', vocabulary_size=3,
+ num_oov_buckets=4, dtype=dtypes.int32)
+ # pylint: disable=protected-access
+ self.assertEqual(7, column._num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int32)
+ }, column._parse_example_config)
+ # pylint: enable=protected-access
+
+ def test_deep_copy(self):
+ """Tests deepcopy of categorical_column_with_hash_bucket."""
+ original = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path_to_file', vocabulary_size=3,
+ num_oov_buckets=4, dtype=dtypes.int32)
+ for column in (original, copy.deepcopy(original)):
+ self.assertEqual('aaa', column.name)
+ # pylint: disable=protected-access
+ self.assertEqual(7, column._num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int32)
+ }, column._parse_example_config)
+ # pylint: enable=protected-access
+
+ def test_vocabulary_file_none(self):
+ with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file=None, vocabulary_size=3)
+
+ def test_vocabulary_file_empty_string(self):
+ with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='', vocabulary_size=3)
+
+ def test_invalid_vocabulary_file(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='file_does_not_exist', vocabulary_size=10)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ # pylint: disable=protected-access
+ column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'):
+ with self.test_session():
+ lookup_ops.tables_initializer().run()
+
+ def test_invalid_vocabulary_size(self):
+ with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=None)
+ with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=-1)
+ with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=0)
+
+ def test_too_large_vocabulary_size(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size + 1)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ # pylint: disable=protected-access
+ column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ with self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'):
+ with self.test_session():
+ lookup_ops.tables_initializer().run()
+
+ def test_invalid_num_oov_buckets(self):
+ with self.assertRaisesRegexp(ValueError, 'Invalid num_oov_buckets'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path', vocabulary_size=3,
+ num_oov_buckets=-1)
+
+ def test_invalid_dtype(self):
+ with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path', vocabulary_size=3,
+ dtype=dtypes.float64)
+
+ def test_invalid_buckets_and_default_value(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'both num_oov_buckets and default_value'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ num_oov_buckets=100,
+ default_value=2)
+
+ def test_invalid_input_dtype_int32(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ dtype=dtypes.string)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(12, 24, 36),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
+ # pylint: disable=protected-access
+ column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+
+ def test_invalid_input_dtype_string(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._warriors_vocabulary_file_name,
+ vocabulary_size=self._warriors_vocabulary_size,
+ dtype=dtypes.int32)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
+ # pylint: disable=protected-access
+ column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+
+ def test_get_sparse_tensors(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_dense_input(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size)
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
+ 'aaa': (('marlo', ''), ('skywalker', 'omar'))
+ }))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=(2, 2)),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_default_value_in_vocabulary(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ default_value=2)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, 2, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_with_oov_buckets(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ num_oov_buckets=100)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (1, 2)),
+ values=('marlo', 'skywalker', 'omar', 'heisenberg'),
+ dense_shape=(2, 3))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, 33, 0, 62), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_small_vocabulary_size(self):
+ # 'marlo' is the last entry in our vocabulary file, so be setting
+ # `vocabulary_size` to 1 less than number of entries in file, we take
+ # 'marlo' out of the vocabulary.
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size - 1)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((-1, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_int32(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._warriors_vocabulary_file_name,
+ vocabulary_size=self._warriors_vocabulary_size,
+ dtype=dtypes.int32)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+ values=(11, 100, 30, 22),
+ dense_shape=(3, 3))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, -1, 0, 4), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_int32_dense_input(self):
+ default_value = -100
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._warriors_vocabulary_file_name,
+ vocabulary_size=self._warriors_vocabulary_size,
+ dtype=dtypes.int32,
+ default_value=default_value)
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
+ 'aaa': ((11, -1, -1), (100, 30, -1), (-1, -1, 22))
+ }))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+ values=np.array((2, default_value, 0, 4), dtype=np.int64),
+ dense_shape=(3, 3)),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_int32_with_oov_buckets(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._warriors_vocabulary_file_name,
+ vocabulary_size=self._warriors_vocabulary_size,
+ dtype=dtypes.int32,
+ num_oov_buckets=100)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+ values=(11, 100, 30, 22),
+ dense_shape=(3, 3))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, 60, 0, 4), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_make_linear_model(self):
+ wire_column = fc.categorical_column_with_vocabulary_file(
+ key='wire',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ num_oov_buckets=1)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = fc.make_linear_model({
+ wire_column.name: sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 2: wire_var[2] = 3
+ # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
+ self.assertAllClose(((3.,), (5.,)), predictions.eval())
+
+
+class VocabularyListCategoricalColumnTest(test.TestCase):
+
+ def test_defaults_string(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ self.assertEqual('aaa', column.name)
+ # pylint: disable=protected-access
+ self.assertEqual(3, column._num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.string)
+ }, column._parse_example_config)
+ # pylint: enable=protected-access
+
+ def test_defaults_int(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12, 24, 36))
+ self.assertEqual('aaa', column.name)
+ # pylint: disable=protected-access
+ self.assertEqual(3, column._num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, column._parse_example_config)
+ # pylint: enable=protected-access
+
+ def test_all_constructor_args(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32,
+ default_value=-99)
+ # pylint: disable=protected-access
+ self.assertEqual(3, column._num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int32)
+ }, column._parse_example_config)
+ # pylint: enable=protected-access
+
+ def test_deep_copy(self):
+ """Tests deepcopy of categorical_column_with_hash_bucket."""
+ original = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32)
+ for column in (original, copy.deepcopy(original)):
+ self.assertEqual('aaa', column.name)
+ # pylint: disable=protected-access
+ self.assertEqual(3, column._num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int32)
+ }, column._parse_example_config)
+ # pylint: enable=protected-access
+
+ def test_invalid_dtype(self):
+ with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'),
+ dtype=dtypes.float32)
+
+ def test_invalid_mapping_dtype(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'vocabulary dtype must be string or integer'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12., 24., 36.))
+
+ def test_mismatched_int_dtype(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'dtype.*and vocabulary dtype.*do not match'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'),
+ dtype=dtypes.int32)
+
+ def test_mismatched_string_dtype(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'dtype.*and vocabulary dtype.*do not match'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.string)
+
+ def test_none_mapping(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'vocabulary_list.*must be non-empty'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=None)
+
+ def test_empty_mapping(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'vocabulary_list.*must be non-empty'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=tuple([]))
+
+ def test_duplicate_mapping(self):
+ with self.assertRaisesRegexp(ValueError, 'Duplicate keys'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12, 24, 12))
+
+ def test_invalid_input_dtype_int32(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'))
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(12, 24, 36),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
+ # pylint: disable=protected-access
+ column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+
+ def test_invalid_input_dtype_string(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=(12, 24, 36))
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
+ # pylint: disable=protected-access
+ column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+
+ def test_get_sparse_tensors(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'))
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_dense_input(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
+ 'aaa': (('marlo', ''), ('skywalker', 'omar'))
+ }))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=(2, 2)),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_default_value_in_vocabulary(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'),
+ default_value=2)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, 2, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_int32(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32),
+ dtype=dtypes.int32)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+ values=np.array((11, 100, 30, 22), dtype=np.int32),
+ dense_shape=(3, 3))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, -1, 0, 4), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_int32_dense_input(self):
+ default_value = -100
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32),
+ dtype=dtypes.int32,
+ default_value=default_value)
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
+ 'aaa': np.array(
+ ((11, -1, -1), (100, 30, -1), (-1, -1, 22)),
+ dtype=np.int32)
+ }))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+ values=np.array((2, default_value, 0, 4), dtype=np.int64),
+ dense_shape=(3, 3)),
+ id_weight_pair.id_tensor.eval())
+
+ def test_make_linear_model(self):
+ wire_column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'))
+ self.assertEqual(3, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = fc.make_linear_model({
+ wire_column.name: sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,))).eval()
+ # 'marlo' -> 2: wire_var[2] = 3
+ # 'skywalker' -> None, 'omar' -> 0: wire_var[0] = 1
+ self.assertAllClose(((3.,), (1.,)), predictions.eval())
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/python/feature_column/lookup_ops.py
index 9dc7414cd0..8225b47b20 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/python/feature_column/lookup_ops.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Lookup table Operations."""
-# pylint: disable=g-bad-name
+"""Lookup table operations."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -27,7 +27,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_data_flow_ops
+from tensorflow.python.ops import gen_lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.training.saver import BaseSaverBuilder
@@ -151,7 +151,7 @@ class InitializableLookupTableBase(LookupInterface):
with ops.name_scope(name, "%s_Size" % self._name,
[self._table_ref]) as scope:
# pylint: disable=protected-access
- return gen_data_flow_ops._lookup_table_size(self._table_ref, name=scope)
+ return gen_lookup_ops._lookup_table_size(self._table_ref, name=scope)
# pylint: enable=protected-access
def lookup(self, keys, name=None):
@@ -182,7 +182,7 @@ class InitializableLookupTableBase(LookupInterface):
name, "%s_Lookup" % self._name,
(self._table_ref, key_tensor, self._default_value)) as scope:
# pylint: disable=protected-access
- values = gen_data_flow_ops._lookup_table_find(
+ values = gen_lookup_ops._lookup_table_find(
self._table_ref, key_tensor, self._default_value, name=scope)
# pylint: enable=protected-access
@@ -229,7 +229,7 @@ class HashTable(InitializableLookupTableBase):
with ops.name_scope(
name, "hash_table", (initializer, default_value)) as scope:
# pylint: disable=protected-access
- table_ref = gen_data_flow_ops._hash_table(
+ table_ref = gen_lookup_ops._hash_table(
shared_name=shared_name,
key_dtype=initializer.key_dtype,
value_dtype=initializer.value_dtype,
@@ -308,10 +308,8 @@ class KeyValueTensorInitializer(TableInitializerBase):
self._name,
values=(table.table_ref, self._keys, self._values)) as scope:
# pylint: disable=protected-access
- init_op = gen_data_flow_ops._initialize_table(table.table_ref,
- self._keys,
- self._values,
- name=scope)
+ init_op = gen_lookup_ops._initialize_table(
+ table.table_ref, self._keys, self._values, name=scope)
# pylint: enable=protected-access
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
return init_op
@@ -477,7 +475,7 @@ class TextFileInitializer(TableInitializerBase):
dtypes.string,
name="asset_filepath")
# pylint: disable=protected-access
- init_op = gen_data_flow_ops._initialize_table_from_text_file(
+ init_op = gen_lookup_ops._initialize_table_from_text_file(
table.table_ref,
filename,
self._key_index,
@@ -608,7 +606,7 @@ class HasherSpec(collections.namedtuple("HasherSpec", ["hasher", "key"])):
__slots__ = ()
-FastHashSpec = HasherSpec("fasthash", None)
+FastHashSpec = HasherSpec("fasthash", None) # pylint: disable=invalid-name
class StrongHashSpec(HasherSpec):
@@ -1333,14 +1331,14 @@ class MutableHashTable(LookupInterface):
use_node_name_sharing = checkpoint and shared_name is None
# pylint: disable=protected-access
if self._default_value.get_shape().ndims == 0:
- self._table_ref = gen_data_flow_ops._mutable_hash_table(
+ self._table_ref = gen_lookup_ops._mutable_hash_table(
shared_name=shared_name,
use_node_name_sharing=use_node_name_sharing,
key_dtype=key_dtype,
value_dtype=value_dtype,
name=name)
else:
- self._table_ref = gen_data_flow_ops._mutable_hash_table_of_tensors(
+ self._table_ref = gen_lookup_ops._mutable_hash_table_of_tensors(
shared_name=shared_name,
use_node_name_sharing=use_node_name_sharing,
key_dtype=key_dtype,
@@ -1368,7 +1366,7 @@ class MutableHashTable(LookupInterface):
with ops.name_scope(name, "%s_Size" % self._name,
[self._table_ref]) as name:
# pylint: disable=protected-access
- return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name)
+ return gen_lookup_ops._lookup_table_size(self._table_ref, name=name)
def lookup(self, keys, name=None):
"""Looks up `keys` in a table, outputs the corresponding values.
@@ -1394,10 +1392,8 @@ class MutableHashTable(LookupInterface):
with ops.name_scope(name, "%s_lookup_table_find" % self._name,
(self._table_ref, keys, self._default_value)) as name:
# pylint: disable=protected-access
- values = gen_data_flow_ops._lookup_table_find(self._table_ref,
- keys,
- self._default_value,
- name=name)
+ values = gen_lookup_ops._lookup_table_find(
+ self._table_ref, keys, self._default_value, name=name)
values.set_shape(keys.get_shape().concatenate(self._value_shape))
return values
@@ -1423,7 +1419,7 @@ class MutableHashTable(LookupInterface):
with ops.name_scope(name, "%s_lookup_table_insert" % self._name,
[self._table_ref, keys, values]) as name:
# pylint: disable=protected-access
- op = gen_data_flow_ops._lookup_table_insert(
+ op = gen_lookup_ops._lookup_table_insert(
self._table_ref, keys, values, name=name)
return op
@@ -1440,11 +1436,8 @@ class MutableHashTable(LookupInterface):
with ops.name_scope(name, "%s_lookup_table_export_values" % self._name,
[self._table_ref]) as name:
# pylint: disable=protected-access
- exported_keys, exported_values = gen_data_flow_ops._lookup_table_export(
- self._table_ref,
- self._key_dtype,
- self._value_dtype,
- name=name)
+ exported_keys, exported_values = gen_lookup_ops._lookup_table_export(
+ self._table_ref, self._key_dtype, self._value_dtype, name=name)
exported_values.set_shape(exported_keys.get_shape().concatenate(
self._value_shape))
@@ -1464,7 +1457,7 @@ class MutableHashTable(LookupInterface):
def restore(self, restored_tensors, unused_restored_shapes):
# pylint: disable=protected-access
- return gen_data_flow_ops._lookup_table_import(
+ return gen_lookup_ops._lookup_table_import(
self.op._table_ref, restored_tensors[0], restored_tensors[1])
@@ -1539,7 +1532,7 @@ class MutableDenseHashTable(LookupInterface):
use_node_name_sharing = checkpoint and shared_name is None
empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype)
# pylint: disable=protected-access
- self._table_ref = gen_data_flow_ops._mutable_dense_hash_table(
+ self._table_ref = gen_lookup_ops._mutable_dense_hash_table(
empty_key=empty_key,
shared_name=shared_name,
use_node_name_sharing=use_node_name_sharing,
@@ -1567,7 +1560,7 @@ class MutableDenseHashTable(LookupInterface):
with ops.name_scope(name, "%s_Size" % self._name,
[self._table_ref]) as name:
# pylint: disable=protected-access
- return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name)
+ return gen_lookup_ops._lookup_table_size(self._table_ref, name=name)
def lookup(self, keys, name=None):
"""Looks up `keys` in a table, outputs the corresponding values.
@@ -1593,7 +1586,7 @@ class MutableDenseHashTable(LookupInterface):
with ops.name_scope(name, "%s_lookup_table_find" % self._name,
[self._table_ref, keys]) as name:
# pylint: disable=protected-access
- values = gen_data_flow_ops._lookup_table_find(
+ values = gen_lookup_ops._lookup_table_find(
self._table_ref, keys, self._default_value, name=name)
if keys.get_shape().ndims is not None and keys.get_shape().ndims > 0:
@@ -1623,7 +1616,7 @@ class MutableDenseHashTable(LookupInterface):
with ops.name_scope(name, "%s_lookup_table_insert" % self._name,
[self._table_ref, keys, values]) as name:
# pylint: disable=protected-access
- op = gen_data_flow_ops._lookup_table_insert(
+ op = gen_lookup_ops._lookup_table_insert(
self._table_ref, keys, values, name=name)
return op
@@ -1640,7 +1633,7 @@ class MutableDenseHashTable(LookupInterface):
with ops.name_scope(name, "%s_lookup_table_export_values" % self._name,
[self._table_ref]) as name:
# pylint: disable=protected-access
- exported_keys, exported_values = gen_data_flow_ops._lookup_table_export(
+ exported_keys, exported_values = gen_lookup_ops._lookup_table_export(
self._table_ref, self._key_dtype, self._value_dtype, name=name)
exported_values.set_shape(exported_keys.get_shape().concatenate(
@@ -1661,6 +1654,5 @@ class MutableDenseHashTable(LookupInterface):
def restore(self, restored_tensors, unused_restored_shapes):
# pylint: disable=protected-access
- return gen_data_flow_ops._lookup_table_import(self.op._table_ref,
- restored_tensors[0],
- restored_tensors[1])
+ return gen_lookup_ops._lookup_table_import(
+ self.op._table_ref, restored_tensors[0], restored_tensors[1])
diff --git a/tensorflow/python/feature_column/testdata/warriors_vocabulary.txt b/tensorflow/python/feature_column/testdata/warriors_vocabulary.txt
new file mode 100644
index 0000000000..6c917fa699
--- /dev/null
+++ b/tensorflow/python/feature_column/testdata/warriors_vocabulary.txt
@@ -0,0 +1,5 @@
+30
+35
+11
+23
+22
diff --git a/tensorflow/python/feature_column/testdata/wire_vocabulary.txt b/tensorflow/python/feature_column/testdata/wire_vocabulary.txt
new file mode 100644
index 0000000000..32c6b5692a
--- /dev/null
+++ b/tensorflow/python/feature_column/testdata/wire_vocabulary.txt
@@ -0,0 +1,3 @@
+omar
+stringer
+marlo
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 452cf3be70..0b04904ec2 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -70,25 +70,33 @@ def _override_helper(clazz_object, operator, func):
setattr(clazz_object, operator, func)
-def _convert_stack(stack):
+def _convert_stack(stack, include_func_start_lineno=False):
"""Converts a stack extracted using _extract_stack() to a traceback stack.
Args:
- stack: A list of n 4-tuples, (filename, lineno, name, frame_globals).
+ stack: A list of n 5-tuples,
+ (filename, lineno, name, frame_globals, func_start_lineno).
+ include_func_start_lineno: True if function start line number should be
+ included as the 5th entry in return tuples.
Returns:
- A list of n 4-tuples (filename, lineno, name, code), where the code tuple
- element is calculated from the corresponding elements of the input tuple.
+ A list of n 4-tuples or 5-tuples
+ (filename, lineno, name, code, [optional: func_start_lineno]), where the
+ code tuple element is calculated from the corresponding elements of the
+ input tuple.
"""
ret = []
- for filename, lineno, name, frame_globals in stack:
+ for filename, lineno, name, frame_globals, func_start_lineno in stack:
linecache.checkcache(filename)
line = linecache.getline(filename, lineno, frame_globals)
if line:
line = line.strip()
else:
line = None
- ret.append((filename, lineno, name, line))
+ if include_func_start_lineno:
+ ret.append((filename, lineno, name, line, func_start_lineno))
+ else:
+ ret.append((filename, lineno, name, line))
return ret
@@ -103,7 +111,8 @@ def _extract_stack():
be formatted etc. using traceback methods.
Returns:
- A list of 4-tuples (filename, lineno, name, frame_globals) corresponding to
+ A list of 5-tuples
+ (filename, lineno, name, frame_globals, func_start_lineno) corresponding to
the call stack of the current thread.
"""
# pylint: enable=line-too-long
@@ -118,7 +127,8 @@ def _extract_stack():
filename = co.co_filename
name = co.co_name
frame_globals = f.f_globals
- ret.append((filename, lineno, name, frame_globals))
+ func_start_lineno = co.co_firstlineno
+ ret.append((filename, lineno, name, frame_globals, func_start_lineno))
f = f.f_back
ret.reverse()
return ret
@@ -1505,6 +1515,15 @@ class Operation(object):
"""Returns the call stack from when this operation was constructed."""
return _convert_stack(self._traceback)
+ @property
+ def traceback_with_start_lines(self):
+ """Same as traceback but includes start line of function definition.
+
+ Returns:
+ A list of 5-tuples (filename, lineno, name, code, func_start_lineno).
+ """
+ return _convert_stack(self._traceback, include_func_start_lineno=True)
+
def get_attr(self, name):
"""Returns the value of the attr of this op with the given `name`.
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 06d03121a0..3e9f047a7d 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -22,6 +22,7 @@ import gc
import weakref
from tensorflow.core.framework import attr_value_pb2
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
@@ -1703,5 +1704,26 @@ class NameScopeTest(test_util.TensorFlowTestCase):
self.assertEqual("", g.get_name_scope())
+class TracebackTest(test_util.TensorFlowTestCase):
+
+ def testTracebackWithStartLines(self):
+ with self.test_session() as sess:
+ a = constant_op.constant(2.0)
+ sess.run(
+ a,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+ self.assertTrue(sess.graph.get_operations())
+
+ # Tests that traceback_with_start_lines is the same as traceback
+ # but includes one more element at the end.
+ for op in sess.graph.get_operations():
+ self.assertEquals(len(op.traceback), len(op.traceback_with_start_lines))
+ for frame, frame_with_start_line in zip(
+ op.traceback, op.traceback_with_start_lines):
+ self.assertEquals(5, len(frame_with_start_line))
+ self.assertEquals(frame, frame_with_start_line[:-1])
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/kernel_tests/distributions/BUILD b/tensorflow/python/kernel_tests/distributions/BUILD
index 3630adc954..50a0795200 100644
--- a/tensorflow/python/kernel_tests/distributions/BUILD
+++ b/tensorflow/python/kernel_tests/distributions/BUILD
@@ -249,6 +249,23 @@ cuda_py_test(
],
)
+cuda_py_test(
+ name = "identity_bijector_test",
+ size = "small",
+ srcs = ["identity_bijector_test.py"],
+ additional_deps = [
+ "//tensorflow/python/ops/distributions",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/identity_test.py b/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py
index 0969c293d4..e8f9d0b728 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/identity_test.py
+++ b/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
-from tensorflow.contrib.distributions.python.ops.bijectors.identity import Identity
+from tensorflow.python.ops.distributions import bijector_test_util
+from tensorflow.python.ops.distributions import identity_bijector
from tensorflow.python.platform import test
@@ -28,7 +28,7 @@ class IdentityBijectorTest(test.TestCase):
def testBijector(self):
with self.test_session():
- bijector = Identity()
+ bijector = identity_bijector.Identity()
self.assertEqual("identity", bijector.name)
x = [[[0.], [1.]]]
self.assertAllEqual(x, bijector.forward(x).eval())
@@ -38,8 +38,8 @@ class IdentityBijectorTest(test.TestCase):
def testScalarCongruency(self):
with self.test_session():
- bijector = Identity()
- assert_scalar_congruency(
+ bijector = identity_bijector.Identity()
+ bijector_test_util.assert_scalar_congruency(
bijector, lower_x=-2., upper_x=2.)
diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
index 8099175186..a0bd178e24 100644
--- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -161,6 +162,46 @@ class SparseTensorDenseMatMulTest(test.TestCase):
sparse_ops.sparse_tensor_dense_matmul(
sparse_t, dense_t, adjoint_a=True).eval()
+ def testInvalidIndicesForSparseTensorDenseMatmulOnGPU(self):
+ # Note: use_gpu=False because nice errors are only returned from CPU kerne
+ if not test.is_gpu_available():
+ return
+ with self.test_session(use_gpu=True):
+ indices = np.array([[1, 10]]).astype(np.int64)
+ values = np.array([10]).astype(np.float32)
+ shape = [3, 2]
+ sparse_t = sparse_tensor.SparseTensor(indices, values, shape)
+
+ # Test multiplying by both a small and large dense matrix, to hit
+ # both cases in the kernel.
+ dense_t = np.matrix([[1] * 5, [2] * 5], dtype=np.float32)
+ expected_t = np.array([[0] * 5, [np.nan] * 5, [0] * 5], dtype=np.float32)
+ self.assertAllClose(expected_t,
+ sparse_ops.sparse_tensor_dense_matmul(
+ sparse_t, dense_t).eval())
+ dense_t = np.matrix([[1] * 500, [2] * 500], dtype=np.float32)
+ expected_t = np.array(
+ [[0] * 500, [np.nan] * 500, [0] * 500], dtype=np.float32)
+ self.assertAllClose(expected_t,
+ sparse_ops.sparse_tensor_dense_matmul(
+ sparse_t, dense_t).eval())
+
+ # Repeat with adjoint_a, now the error is that the sparse index
+ # is OOO w.r.t. the output. The GPU kernel can't do much here,
+ # so it just doesn't accumulate.
+
+ dense_t = np.matrix([[1] * 5, [2] * 5, [3] * 5], dtype=np.float32)
+ expected_t = np.array([[0] * 5, [0] * 5], dtype=np.float32)
+ self.assertAllClose(expected_t,
+ sparse_ops.sparse_tensor_dense_matmul(
+ sparse_t, dense_t, adjoint_a=True).eval())
+
+ dense_t = np.matrix([[1] * 500, [2] * 500, [3] * 500], dtype=np.float32)
+ expected_t = np.array([[0] * 500, [0] * 500], dtype=np.float32)
+ self.assertAllClose(expected_t,
+ sparse_ops.sparse_tensor_dense_matmul(
+ sparse_t, dense_t, adjoint_a=True).eval())
+
# Tests setting one dimension to be a high value.
def _testLarge(self, np_dtype):
r1 = np.random.randint(6000, 20000)
@@ -175,9 +216,12 @@ class SparseTensorDenseMatMulTest(test.TestCase):
y = _maybe_complex(np.random.randn(k, n).astype(np_dtype))
- self._testMatmul(x, y)
+ self._testMatmul(x, y, adjoint_a=False, adjoint_b=False)
+ self._testMatmul(x.transpose(), y, adjoint_a=True, adjoint_b=False)
+ self._testMatmul(x, y.transpose(), adjoint_a=False, adjoint_b=True)
+ self._testMatmul(
+ x.transpose(), y.transpose(), adjoint_a=True, adjoint_b=True)
- def testLarge(self):
np.random.seed(127) # Repeatable results
self._testLarge(np.float32)
self._testLarge(np.float64)
@@ -221,7 +265,9 @@ def _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(x, y, adjoint_a,
lambda t, _: t < iterations,
body, (t0, v0),
parallel_iterations=1,
- back_prop=False)
+ back_prop=False,
+ shape_invariants=(tensor_shape.TensorShape(()),
+ tensor_shape.TensorShape(None)))
return [final]
return _timeit
@@ -246,7 +292,9 @@ def _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(x_ind, x_val, x_shape,
lambda t, _: t < iterations,
body, (t0, v0),
parallel_iterations=1,
- back_prop=False)
+ back_prop=False,
+ shape_invariants=(tensor_shape.TensorShape(()),
+ tensor_shape.TensorShape(None)))
return [final]
return _timeit
@@ -291,7 +339,7 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh,
if skip_dense:
delta_dense = float("nan")
else:
- with session.Session("", config=config, graph=ops.Graph()) as sess:
+ with session.Session(config=config, graph=ops.Graph()) as sess:
if not use_gpu:
with ops.device("/cpu:0"):
x_t = constant_op.constant(x)
@@ -299,12 +347,12 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh,
ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(
x_t, y_t, adjoint_a, adjoint_b)
else:
- x_t = constant_op.constant(x)
- y_t = constant_op.constant(y)
- ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(x_t, y_t,
- adjoint_a,
- adjoint_b)
- delta_dense = _timer(sess, ops_fn, 1000)
+ with ops.device("/gpu:0"):
+ x_t = constant_op.constant(x)
+ y_t = constant_op.constant(y)
+ ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(
+ x_t, y_t, adjoint_a, adjoint_b)
+ delta_dense = _timer(sess, ops_fn, 200)
# Using sparse_tensor_dense_matmul.
with session.Session("", config=config, graph=ops.Graph()) as sess:
@@ -317,13 +365,14 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh,
ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(
x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b)
else:
- x_ind = constant_op.constant(np.vstack(np.where(x)).astype(np.int64).T)
- x_val = constant_op.constant(x[np.where(x)])
- x_shape = constant_op.constant(np.array(x.shape).astype(np.int64))
- y_t = constant_op.constant(y)
- ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(
- x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b)
- delta_sparse = _timer(sess, ops_fn, 1000)
+ with ops.device("/gpu:0"):
+ x_ind = constant_op.constant(np.vstack(np.where(x)).astype(np.int64).T)
+ x_val = constant_op.constant(x[np.where(x)])
+ x_shape = constant_op.constant(np.array(x.shape).astype(np.int64))
+ y_t = constant_op.constant(y)
+ ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(
+ x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b)
+ delta_sparse = _timer(sess, ops_fn, 200)
print("%g \t %d \t %s \t %d \t %d \t %g \t %g \t %g" %
(1 - thresh, n, use_gpu, m, k, delta_dense, delta_sparse,
@@ -340,7 +389,7 @@ def main(_):
"\t dt(sparse)/dt(dense)")
for thresh in (0.99, 0.8, 0.5, 0.2):
- for n in (1, 10, 25):
+ for n in (50, 100):
for use_gpu in (True, False):
for m in (100, 1000):
for k in (100, 1000):
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 95e803e2aa..9a208613ad 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -38,7 +38,6 @@ from tensorflow.python.ops import math_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_data_flow_ops import *
# pylint: enable=wildcard-import
-from tensorflow.python.util.deprecation import deprecated
def _as_type_list(dtypes):
@@ -1037,47 +1036,6 @@ class Barrier(object):
self._barrier_ref, name=name)
-@deprecated("2017-03-02", "Use `tf.tables_initializer` instead.")
-def initialize_all_tables(name="init_all_tables"):
- """Returns an Op that initializes all tables of the default graph.
-
- Args:
- name: Optional name for the initialization op.
-
- Returns:
- An Op that initializes all tables. Note that if there are
- not tables the returned Op is a NoOp.
- """
- return tables_initializer(name)
-
-
-def tables_initializer(name="init_all_tables"):
- """Returns an Op that initializes all tables of the default graph.
-
- Args:
- name: Optional name for the initialization op.
-
- Returns:
- An Op that initializes all tables. Note that if there are
- not tables the returned Op is a NoOp.
- """
- initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS)
- if initializers:
- return control_flow_ops.group(*initializers, name=name)
- return control_flow_ops.no_op(name=name)
-
-
-ops.NotDifferentiable("LookupTableFind")
-ops.NotDifferentiable("LookupTableInsert")
-ops.NotDifferentiable("LookupTableSize")
-ops.NotDifferentiable("HashTable")
-ops.NotDifferentiable("InitializeTable")
-ops.NotDifferentiable("InitializeTableFromTextFile")
-ops.NotDifferentiable("MutableDenseHashTable")
-ops.NotDifferentiable("MutableHashTable")
-ops.NotDifferentiable("MutableHashTableOfTensors")
-
-
class ConditionalAccumulatorBase(object):
"""A conditional accumulator for aggregating gradients.
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/bijector_test_util.py b/tensorflow/python/ops/distributions/bijector_test_util.py
index ff3535c626..ff3535c626 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/bijector_test_util.py
+++ b/tensorflow/python/ops/distributions/bijector_test_util.py
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/identity_impl.py b/tensorflow/python/ops/distributions/identity_bijector.py
index f277eda8bb..f277eda8bb 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/identity_impl.py
+++ b/tensorflow/python/ops/distributions/identity_bijector.py
diff --git a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py b/tensorflow/python/ops/distributions/transformed_distribution.py
index e146e20d3a..09b26a9fb7 100644
--- a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
+++ b/tensorflow/python/ops/distributions/transformed_distribution.py
@@ -21,7 +21,6 @@ import numpy as np
# Bijectors must be directly imported because `remove_undocumented` prevents
# individual file imports.
-from tensorflow.contrib.distributions.python.ops.bijectors.identity import Identity
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -32,6 +31,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
+from tensorflow.python.ops.distributions import identity_bijector
from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
@@ -265,7 +265,7 @@ class TransformedDistribution(distribution_lib.Distribution):
self._empty = constant_op.constant([], dtype=dtypes.int32, name="empty")
if bijector is None:
- bijector = Identity(validate_args=validate_args)
+ bijector = identity_bijector.Identity(validate_args=validate_args)
# We will keep track of a static and dynamic version of
# self._is_{batch,event}_override. This way we can do more prior to graph
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
new file mode 100644
index 0000000000..54dba9e38e
--- /dev/null
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -0,0 +1,77 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#==============================================================================
+"""Data Flow Operations."""
+# pylint: disable=g-bad-name
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+# go/tf-wildcard-import
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_lookup_ops import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.deprecation import deprecated
+
+
+@deprecated("2017-03-02", "Use `tf.tables_initializer` instead.")
+def initialize_all_tables(name="init_all_tables"):
+ """Returns an Op that initializes all tables of the default graph.
+
+ Args:
+ name: Optional name for the initialization op.
+
+ Returns:
+ An Op that initializes all tables. Note that if there are
+ not tables the returned Op is a NoOp.
+ """
+ return tables_initializer(name)
+
+
+def tables_initializer(name="init_all_tables"):
+ """Returns an Op that initializes all tables of the default graph.
+
+ Args:
+ name: Optional name for the initialization op.
+
+ Returns:
+ An Op that initializes all tables. Note that if there are
+ not tables the returned Op is a NoOp.
+ """
+ initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS)
+ if initializers:
+ return control_flow_ops.group(*initializers, name=name)
+ return control_flow_ops.no_op(name=name)
+
+
+ops.NotDifferentiable("LookupTableFind")
+ops.NotDifferentiable("LookupTableFindV2")
+ops.NotDifferentiable("LookupTableInsert")
+ops.NotDifferentiable("LookupTableInsertV2")
+ops.NotDifferentiable("LookupTableSize")
+ops.NotDifferentiable("LookupTableSizeV2")
+ops.NotDifferentiable("HashTable")
+ops.NotDifferentiable("HashTableV2")
+ops.NotDifferentiable("InitializeTable")
+ops.NotDifferentiable("InitializeTableV2")
+ops.NotDifferentiable("InitializeTableFromTextFile")
+ops.NotDifferentiable("InitializeTableFromTextFileV2")
+ops.NotDifferentiable("MutableDenseHashTable")
+ops.NotDifferentiable("MutableDenseHashTableV2")
+ops.NotDifferentiable("MutableHashTable")
+ops.NotDifferentiable("MutableHashTableV2")
+ops.NotDifferentiable("MutableHashTableOfTensors")
+ops.NotDifferentiable("MutableHashTableOfTensorsV2")
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 4dc8e702ca..28ed3af9d7 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -1924,7 +1924,74 @@ def recall_at_k(labels,
labels = _maybe_expand_labels(labels, predictions)
_, top_k_idx = nn.top_k(predictions, k)
- top_k_idx = math_ops.to_int64(top_k_idx)
+ return _sparse_recall_at_top_k(
+ labels=labels,
+ predictions_idx=top_k_idx,
+ k=k,
+ class_id=class_id,
+ weights=weights,
+ metrics_collections=metrics_collections,
+ updates_collections=updates_collections,
+ name=scope)
+
+
+def _sparse_recall_at_top_k(labels,
+ predictions_idx,
+ k=None,
+ class_id=None,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
+ """Computes recall@k of top-k predictions with respect to sparse labels.
+
+ Differs from `recall_at_k` in that predictions must be in the form of top `k`
+ class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k`
+ for more details.
+
+ Args:
+ labels: `int64` `Tensor` or `SparseTensor` with shape
+ [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
+ num_labels=1. N >= 1 and num_labels is the number of target classes for
+ the associated prediction. Commonly, N=1 and `labels` has shape
+ [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
+ should be in range [0, num_classes), where num_classes is the last
+ dimension of `predictions`. Values outside this range always count
+ towards `false_negative_at_<k>`.
+ predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
+ Commonly, N=1 and predictions has shape [batch size, k]. The final
+ dimension contains the top `k` predicted class indices. [D1, ... DN] must
+ match `labels`.
+ k: Integer, k for @k metric.
+ class_id: Integer class ID for which we want binary metrics. This should be
+ in range [0, num_classes), where num_classes is the last dimension of
+ `predictions`. If class_id is outside this range, the method returns NAN.
+ weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+ `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+ dimensions must be either `1`, or the same as the corresponding `labels`
+ dimension).
+ metrics_collections: An optional list of collections that values should
+ be added to.
+ updates_collections: An optional list of collections that updates should
+ be added to.
+ name: Name of new update operation, and namespace for other dependent ops.
+
+ Returns:
+ recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
+ by the sum of `true_positives` and `false_negatives`.
+ update_op: `Operation` that increments `true_positives` and
+ `false_negatives` variables appropriately, and whose value matches
+ `recall`.
+
+ Raises:
+ ValueError: If `weights` is not `None` and its shape doesn't match
+ `predictions`, or if either `metrics_collections` or `updates_collections`
+ are not a list or tuple.
+ """
+ with ops.name_scope(name,
+ _at_k_name('recall', k, class_id=class_id),
+ (predictions_idx, labels, weights)) as scope:
+ top_k_idx = math_ops.to_int64(predictions_idx)
tp, tp_update = _streaming_sparse_true_positive_at_k(
predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
weights=weights)
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 4810e97b36..c7ac742b5d 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -28,6 +28,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.util import nest
@@ -75,11 +77,13 @@ def _zero_state_tensors(state_size, batch_size, dtype):
return zeros
-class _RNNCell(base_layer.Layer): # pylint: disable=protected-access
+class _RNNCell(base_layer.Layer):
"""Abstract object representing an RNN cell.
- Every `RNNCell` must have the properties below and implement `__call__` with
- the following signature.
+ Every `RNNCell` must have the properties below and implement `call` with
+ the signature `(output, next_state) = call(input, state)`. The optional
+ third input argument, `scope`, is allowed for backwards compatibility
+ purposes; but should be left off for new subclasses.
This definition of cell differs from the definition used in the literature.
In the literature, 'cell' refers to an object with a single scalar output.
@@ -90,8 +94,9 @@ class _RNNCell(base_layer.Layer): # pylint: disable=protected-access
This operation results in an output matrix with `self.output_size` columns.
If `self.state_size` is an integer, this operation also results in a new
state matrix with `self.state_size` columns. If `self.state_size` is a
- tuple of integers, then it results in a tuple of `len(state_size)` state
- matrices, each with a column size corresponding to values in `state_size`.
+ (possibly nested tuple of) TensorShape object(s), then it should return a
+ matching structure of Tensors having shape `[batch_size].concatenate(s)`
+ for each `s` in `self.batch_size`.
"""
def __call__(self, inputs, state, scope=None):
@@ -112,7 +117,25 @@ class _RNNCell(base_layer.Layer): # pylint: disable=protected-access
- New state: Either a single `2-D` tensor, or a tuple of tensors matching
the arity and shapes of `state`.
"""
- return super(_RNNCell, self).__call__(inputs, state, scope=scope)
+ if scope is not None:
+ with vs.variable_scope(scope,
+ custom_getter=self._rnn_get_variable) as scope:
+ return super(_RNNCell, self).__call__(inputs, state, scope=scope)
+ else:
+ with vs.variable_scope(vs.get_variable_scope(),
+ custom_getter=self._rnn_get_variable):
+ return super(_RNNCell, self).__call__(inputs, state)
+
+ def _rnn_get_variable(self, getter, *args, **kwargs):
+ variable = getter(*args, **kwargs)
+ trainable = (variable in tf_variables.trainable_variables() or
+ (isinstance(variable, tf_variables.PartitionedVariable) and
+ list(variable)[0] in tf_variables.trainable_variables()))
+ if trainable and variable not in self._trainable_weights:
+ self._trainable_weights.append(variable)
+ elif not trainable and variable not in self._non_trainable_weights:
+ self._non_trainable_weights.append(variable)
+ return variable
@property
def state_size(self):
@@ -128,6 +151,11 @@ class _RNNCell(base_layer.Layer): # pylint: disable=protected-access
"""Integer or TensorShape: size of outputs produced by this cell."""
raise NotImplementedError("Abstract method")
+ def build(self, _):
+ # This tells the parent Layer object that it's OK to call
+ # self.add_variable() inside the call() method.
+ pass
+
def zero_state(self, batch_size, dtype):
"""Return zero-filled state tensor(s).
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index 0140a27aaa..d6cb7c5be4 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -241,6 +241,8 @@ def sparse_add(a, b, thresh=0):
of arguments does not matter. Use vanilla `tf.add()` for adding two dense
`Tensor`s.
+ The shapes of the two operands must match: broadcasting is not supported.
+
The indices of any input `SparseTensor` are assumed ordered in standard
lexicographic order. If this is not the case, before this step run
`SparseReorder` to restore index ordering.
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index 09e04d4247..a39d28490c 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -57,6 +57,7 @@ from tensorflow.python.ops.io_ops import *
from tensorflow.python.ops.linalg_ops import *
from tensorflow.python.ops.logging_ops import Print
from tensorflow.python.ops.logging_ops import get_summary_op
+from tensorflow.python.ops.lookup_ops import *
from tensorflow.python.ops.math_ops import *
from tensorflow.python.ops.numerics import *
from tensorflow.python.ops.parsing_ops import *
diff --git a/tensorflow/python/saved_model/main_op_impl.py b/tensorflow/python/saved_model/main_op_impl.py
index 66cf9d4d8a..355fd57bf1 100644
--- a/tensorflow/python/saved_model/main_op_impl.py
+++ b/tensorflow/python/saved_model/main_op_impl.py
@@ -20,7 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops as tf_data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
@@ -35,7 +35,7 @@ def main_op():
"""
init = variables.global_variables_initializer()
init_local = variables.local_variables_initializer()
- init_tables = tf_data_flow_ops.tables_initializer()
+ init_tables = lookup_ops.tables_initializer()
return control_flow_ops.group(init, init_local, init_tables)
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index 2fea29d961..c9c56a5014 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -14,68 +14,8 @@
# ==============================================================================
"""Command-line interface to inspect and execute a graph in a SavedModel.
-If TensorFlow is installed on your system through pip, the 'saved_model_cli'
-binary can be invoked directly from command line.
-
-At a high level, SavedModel CLI allows users to both inspect and execute
-computations on a MetaGraphDef in a SavedModel. These are done through `show`
-and `run` commands. Following is the usage of the two commands. SavedModel
-CLI will also display these information with -h option.
-
-'show' command usage: saved_model_cli show [-h] --dir DIR [--tag_set TAG_SET]
- [--signature_def SIGNATURE_DEF_KEY]
-Examples:
-To show all available tag-sets in the SavedModel:
- $saved_model_cli show --dir /tmp/saved_model
-
-To show all available SignatureDef keys in a MetaGraphDef specified by its
-tag-set:
- $saved_model_cli show --dir /tmp/saved_model --tag_set serve
-For a MetaGraphDef with multiple tags in the tag-set, all tags must be passed
-in, separated by ',':
- $saved_model_cli show --dir /tmp/saved_model --tag_set serve,gpu
-
-To show all inputs and outputs TensorInfo for a specific SignatureDef specified
-by the SignatureDef key in a MetaGraphDef:
- $saved_model_cli show --dir /tmp/saved_model --tag_set serve
- --signature_def serving_default
-Example output:
- The given SavedModel SignatureDef contains the following input(s):
- inputs['input0'] tensor_info:
- dtype: DT_FLOAT
- shape: (-1, 1)
- inputs['input1'] tensor_info:
- dtype: DT_FLOAT
- shape: (-1, 1)
- The given SavedModel SignatureDef contains the following output(s):
- outputs['output'] tensor_info:
- dtype: DT_FLOAT
- shape: (-1, 1)
- Method name is: tensorflow/serving/regress
-
-To show all available information in the SavedModel:
- $saved_model_cli show --dir /tmp/saved_model --all
-
-usage: saved_model_cli run [-h] --dir DIR --tag_set TAG_SET --signature_def
- SIGNATURE_DEF_KEY [--inputs INPUTS]
- [--input_exprs INPUT_EXPRS] [--outdir OUTDIR]
- [--overwrite] [--tf_debug]
-
-Examples:
-To run input tensors from files through a MetaGraphDef and save the output
-tensors to files:
- $saved_model_cli run --dir /tmp/saved_model --tag_set serve
- --signature_def serving_default --inputs x=/tmp/124.npz
- --input_exprs 'x2=np.ones((6,2))' --outdir /tmp/out
-
-To observe the intermediate Tensor values in the runtime graph, use the
---tf_debug flag, e.g.:
- $saved_model_cli run --dir /tmp/saved_model --tag_set serve
- --signature_def serving_default --inputs 'x=/tmp/124.npz;x2=/tmp/123.npy'
- --outdir /tmp/out --tf_debug
-
-To build this tool from source, run:
- $bazel build tensorflow/python/tools:saved_model_cli
+For detailed usages and examples, please refer to:
+https://www.tensorflow.org/programmers_guide/saved_model_cli
"""
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index c905672313..a891bae5f2 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -26,7 +26,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
@@ -238,7 +238,7 @@ class Scaffold(object):
@staticmethod
def _default_local_init_op():
return control_flow_ops.group(variables.local_variables_initializer(),
- data_flow_ops.tables_initializer())
+ lookup_ops.tables_initializer())
def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
diff --git a/tensorflow/python/training/saver_test_utils.py b/tensorflow/python/training/saver_test_utils.py
index 5f31e2aa53..6a73565f82 100644
--- a/tensorflow/python/training/saver_test_utils.py
+++ b/tensorflow/python/training/saver_test_utils.py
@@ -20,7 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
-from tensorflow.python.ops import gen_data_flow_ops
+from tensorflow.python.ops import gen_lookup_ops
from tensorflow.python.training import saver as saver_module
@@ -34,7 +34,7 @@ class CheckpointedOp(object):
# pylint: disable=protected-access
def __init__(self, name, table_ref=None):
if table_ref is None:
- self.table_ref = gen_data_flow_ops._mutable_hash_table(
+ self.table_ref = gen_lookup_ops._mutable_hash_table(
key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name)
else:
self.table_ref = table_ref
@@ -52,10 +52,10 @@ class CheckpointedOp(object):
return self._saveable
def insert(self, keys, values):
- return gen_data_flow_ops._lookup_table_insert(self.table_ref, keys, values)
+ return gen_lookup_ops._lookup_table_insert(self.table_ref, keys, values)
def lookup(self, keys, default):
- return gen_data_flow_ops._lookup_table_find(self.table_ref, keys, default)
+ return gen_lookup_ops._lookup_table_find(self.table_ref, keys, default)
def keys(self):
return self._export()[0]
@@ -64,8 +64,8 @@ class CheckpointedOp(object):
return self._export()[1]
def _export(self):
- return gen_data_flow_ops._lookup_table_export(self.table_ref, dtypes.string,
- dtypes.float32)
+ return gen_lookup_ops._lookup_table_export(self.table_ref, dtypes.string,
+ dtypes.float32)
class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject):
"""A custom saveable for CheckpointedOp."""
@@ -81,6 +81,6 @@ class CheckpointedOp(object):
super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name)
def restore(self, restore_tensors, shapes):
- return gen_data_flow_ops._lookup_table_import(
+ return gen_lookup_ops._lookup_table_import(
self.op.table_ref, restore_tensors[0], restore_tensors[1])
# pylint: enable=protected-access
diff --git a/tensorflow/python/training/server_lib.py b/tensorflow/python/training/server_lib.py
index d2ccf37d88..2091eca0b9 100644
--- a/tensorflow/python/training/server_lib.py
+++ b/tensorflow/python/training/server_lib.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import errors
@@ -276,14 +277,14 @@ class ClusterSpec(object):
"from integers to strings." % job_name)
self._cluster_spec[job_name] = job_tasks
self._make_cluster_def()
- elif isinstance(cluster, tensorflow_server_pb2.ClusterDef):
+ elif isinstance(cluster, cluster_pb2.ClusterDef):
self._cluster_def = cluster
self._cluster_spec = {}
for job_def in self._cluster_def.job:
self._cluster_spec[job_def.name] = {
i: t for i, t in job_def.tasks.items()}
elif isinstance(cluster, ClusterSpec):
- self._cluster_def = tensorflow_server_pb2.ClusterDef()
+ self._cluster_def = cluster_pb2.ClusterDef()
self._cluster_def.MergeFrom(cluster.as_cluster_def())
self._cluster_spec = {}
for job_def in self._cluster_def.job:
@@ -440,7 +441,7 @@ class ClusterSpec(object):
TypeError: If `cluster_spec` is not a dictionary mapping strings to lists
of strings.
"""
- self._cluster_def = tensorflow_server_pb2.ClusterDef()
+ self._cluster_def = cluster_pb2.ClusterDef()
# NOTE(mrry): Sort by job_name to produce deterministic protobufs.
for job_name, tasks in sorted(self._cluster_spec.items()):
diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py
index 277c11386d..230ed1db68 100644
--- a/tensorflow/python/training/supervisor.py
+++ b/tensorflow/python/training/supervisor.py
@@ -27,7 +27,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as _summary
@@ -426,8 +426,10 @@ class Supervisor(object):
local_init_op = self._get_first_op_from_collection(
ops.GraphKeys.LOCAL_INIT_OP)
if local_init_op is None:
- op_list = [variables.local_variables_initializer(),
- data_flow_ops.tables_initializer()]
+ op_list = [
+ variables.local_variables_initializer(),
+ lookup_ops.tables_initializer()
+ ]
if op_list:
local_init_op = control_flow_ops.group(*op_list)
ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py
index bdf3d9c017..f4ac3c9758 100644
--- a/tensorflow/python/training/training.py
+++ b/tensorflow/python/training/training.py
@@ -186,8 +186,8 @@ from tensorflow.python.training.learning_rate_decay import *
# pylint: enable=wildcard-import
# Distributed computing support.
-from tensorflow.core.protobuf.tensorflow_server_pb2 import ClusterDef
-from tensorflow.core.protobuf.tensorflow_server_pb2 import JobDef
+from tensorflow.core.protobuf.cluster_pb2 import ClusterDef
+from tensorflow.core.protobuf.cluster_pb2 import JobDef
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.training.server_lib import Server
@@ -196,32 +196,32 @@ from tensorflow.python.training.server_lib import Server
_allowed_symbols = [
# TODO(cwhipkey): review these and move to contrib or expose through
# documentation.
- "generate_checkpoint_state_proto", # Used internally by saver.
+ "generate_checkpoint_state_proto", # Used internally by saver.
"checkpoint_exists", # Only used in test?
"get_checkpoint_mtimes", # Only used in test?
# Legacy: remove.
"do_quantize_training_on_graphdef", # At least use grah_def, not graphdef.
- # No uses within tensorflow.
+ # No uses within tensorflow.
"queue_runner", # Use tf.train.start_queue_runner etc directly.
- # This is also imported internally.
+ # This is also imported internally.
# TODO(drpng): document these. The reference in howtos/distributed does
# not link.
"SyncReplicasOptimizer",
# Protobufs:
- "BytesList", # from example_pb2.
+ "BytesList", # from example_pb2.
"ClusterDef",
- "Example", # from example_pb2
- "Feature", # from example_pb2
- "Features", # from example_pb2
- "FeatureList", # from example_pb2
- "FeatureLists", # from example_pb2
- "FloatList", # from example_pb2.
- "Int64List", # from example_pb2.
+ "Example", # from example_pb2
+ "Feature", # from example_pb2
+ "Features", # from example_pb2
+ "FeatureList", # from example_pb2
+ "FeatureLists", # from example_pb2
+ "FloatList", # from example_pb2.
+ "Int64List", # from example_pb2.
"JobDef",
- "SaverDef", # From saver_pb2.
- "SequenceExample", # from example_pb2.
+ "SaverDef", # From saver_pb2.
+ "SequenceExample", # from example_pb2.
"ServerDef",
]
# Include extra modules for docstrings because:
diff --git a/tensorflow/tensorboard/package.json b/tensorflow/tensorboard/package.json
index 69f08495a3..d424f103dd 100644
--- a/tensorflow/tensorboard/package.json
+++ b/tensorflow/tensorboard/package.json
@@ -30,7 +30,7 @@
"merge2": "~0.3.6",
"minimist": "~1.2.0",
"tsify": "^0.14.8",
- "typescript": "2.2.2",
+ "typescript": "2.3.1",
"typings": "1.4.0",
"vinyl-source-stream": "^1.1.0",
"vulcanize": "^1.14.0",
diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
index 805a9bdd4f..da6af3919e 100644
--- a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
@@ -7,6 +7,10 @@ tf_class {
mtype: "<type \'int\'>"
}
member {
+ name: "CLUSTER_DEF_FIELD_NUMBER"
+ mtype: "<type \'int\'>"
+ }
+ member {
name: "DESCRIPTOR"
mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-operation.pbtxt b/tensorflow/tools/api/golden/tensorflow.-operation.pbtxt
index 0f43a49ee9..64240f7069 100644
--- a/tensorflow/tools/api/golden/tensorflow.-operation.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-operation.pbtxt
@@ -39,6 +39,10 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "traceback_with_start_lines"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "type"
mtype: "<type \'property\'>"
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt
index feb73bd7d4..93ff856b09 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.train.ClusterDef"
tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.tensorflow_server_pb2.ClusterDef\'>"
+ is_instance: "<class \'tensorflow.core.protobuf.cluster_pb2.ClusterDef\'>"
is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
member {
name: "DESCRIPTOR"
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt
index 2d7fcbe545..ac6d81541a 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.train.JobDef.TasksEntry"
tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.tensorflow_server_pb2.TasksEntry\'>"
+ is_instance: "<class \'tensorflow.core.protobuf.cluster_pb2.TasksEntry\'>"
is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
member {
name: "DESCRIPTOR"
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt
index fc5b76341d..ce34537fa1 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.train.JobDef"
tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.tensorflow_server_pb2.JobDef\'>"
+ is_instance: "<class \'tensorflow.core.protobuf.cluster_pb2.JobDef\'>"
is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
member {
name: "DESCRIPTOR"
diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py
index 61c3fe5540..459d6ee328 100644
--- a/tensorflow/tools/pip_package/pip_smoke_test.py
+++ b/tensorflow/tools/pip_package/pip_smoke_test.py
@@ -28,11 +28,13 @@ import subprocess
PIP_PACKAGE_QUERY = """bazel query \
'deps(//tensorflow/tools/pip_package:build_pip_package)'"""
-PY_TEST_QUERY = """bazel query 'filter("^((?!(benchmark|manual|no_pip)).)*$", \
- deps(kind(py_test,\
- //tensorflow/python/... + \
- //tensorflow/tensorboard/... + \
- //tensorflow/contrib/...), 1))'"""
+PY_TEST_QUERY = """bazel query 'deps(\
+ filter("^((?!benchmark).)*$",\
+ kind(py_test,\
+ //tensorflow/python/... \
+ + //tensorflow/tensorboard/... \
+ + //tensorflow/contrib/... \
+ - attr(tags, "manual|no_pip", //tensorflow/...))), 1)'"""
# Hard-coded blacklist of files if not included in pip package
# TODO(amitpatankar): Clean up blacklist.
@@ -45,6 +47,7 @@ BLACKLIST = [
"//tensorflow/python:compare_test_proto_py",
"//tensorflow/core:image_testdata",
"//tensorflow/core/kernels/cloud:bigquery_reader_ops",
+ "//tensorflow/python/feature_column:vocabulary_testdata",
"//tensorflow/python:framework/test_file_system.so",
# contrib
"//tensorflow/contrib/session_bundle:session_bundle_half_plus_two",
@@ -54,7 +57,7 @@ BLACKLIST = [
"//tensorflow/contrib/factorization/examples:mnist.py",
"//tensorflow/contrib/factorization:factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", # pylint:disable=line-too-long
"//tensorflow/contrib/bayesflow:reinforce_simple_example",
- "//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py" # pylint:disable=line-too-long
+ "//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py", # pylint:disable=line-too-long
]
@@ -121,7 +124,10 @@ def main():
affected_tests_list = affected_tests.split("\n")[:-2]
print("\n".join(affected_tests_list))
- raise RuntimeError("One or more dependencies are not in the pip package.")
+ raise RuntimeError("""One or more dependencies are not in the pip package.
+Please either blacklist the dependencies in
+tensorflow/tensorflow/tensorflow/tools/pip_package/pip_smoke_test.py
+or add them to tensorflow/tensorflow/tensorflow/tools/pip_package/BUILD.""")
else:
print("TEST PASSED")
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 6270b95b6b..3831a481ba 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -687,13 +687,13 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
name = "com_microsoft_typescript",
licenses = ["notice"], # Apache 2.0
sha256_urls = {
- "43a7c763fe024d5add8d5365e5a7981f4a359ba5bf86481f545a0db8f60d48cc": [
- "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/Microsoft/TypeScript/v2.2.2/lib/tsc.js",
- "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.2.2/lib/tsc.js",
+ "8465342c318f9c4cf0a29b109fa63ee3742dd4dc7080d05d9fd8f604814d04cf": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/tsc.js",
+ "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/tsc.js",
],
- "aecec1e47a3b3d872e214cb9adb82b30d6bd0471ea0aad7311ad81428566627c": [
- "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/Microsoft/TypeScript/v2.2.2/lib/lib.es6.d.ts",
- "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.2.2/lib/lib.es6.d.ts",
+ "a67e36da3029d232e4e938e61a0a3302f516d71e7100d54dbf5362ad8618e994": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/lib.es6.d.ts",
+ "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/lib.es6.d.ts",
],
},
extra_build_file_content = "\n".join([