author     Vijay Vasudevan <vrv@google.com>    2016-03-01 17:12:55 -0800
committer  Vijay Vasudevan <vrv@google.com>    2016-03-01 17:12:55 -0800
commit     b8ff42a4b3c0bbc4eadaa5ee5cd1b10adb58de7e (patch)
tree       3fd6fb903c758ea93ac1fad9a130cf6589c011a2
parent     ca7cb4e5ef7f6af36ea9e51561587e5ded42bbef (diff)
parent     8fa1b728b6649231cfcd9dee7138579812341b32 (diff)
Merge commit for internal changes
-rw-r--r--  eigen.BUILD | 2
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py | 46
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py | 3
-rw-r--r--  tensorflow/core/BUILD | 2
-rw-r--r--  tensorflow/core/client/tensor_c_api.cc | 14
-rw-r--r--  tensorflow/core/common_runtime/constant_folding.cc | 3
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc | 8
-rw-r--r--  tensorflow/core/common_runtime/executor.cc | 16
-rw-r--r--  tensorflow/core/common_runtime/executor.h | 6
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_device.cc | 79
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_device.h | 8
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_device_factory.cc | 2
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc | 5
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_event_mgr.h | 14
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_util.cc | 10
-rw-r--r--  tensorflow/core/common_runtime/gpu/process_state.cc | 24
-rw-r--r--  tensorflow/core/common_runtime/gpu/process_state.h | 4
-rw-r--r--  tensorflow/core/distributed_runtime/README.md | 6
-rw-r--r--  tensorflow/core/distributed_runtime/graph_mgr.cc | 6
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server_lib.cc | 123
-rw-r--r--  tensorflow/core/framework/allocation_description.proto | 3
-rw-r--r--  tensorflow/core/framework/allocator.cc | 14
-rw-r--r--  tensorflow/core/framework/allocator.h | 6
-rw-r--r--  tensorflow/core/framework/device_base.h | 4
-rw-r--r--  tensorflow/core/framework/kernel_def_builder.h | 3
-rw-r--r--  tensorflow/core/framework/log_memory.cc | 83
-rw-r--r--  tensorflow/core/framework/log_memory.h | 115
-rw-r--r--  tensorflow/core/framework/log_memory.proto | 93
-rw-r--r--  tensorflow/core/framework/node_def_util.cc | 5
-rw-r--r--  tensorflow/core/framework/node_def_util.h | 8
-rw-r--r--  tensorflow/core/framework/op_kernel.cc | 198
-rw-r--r--  tensorflow/core/framework/op_kernel.h | 229
-rw-r--r--  tensorflow/core/framework/tensor.cc | 20
-rw-r--r--  tensorflow/core/framework/tensor.h | 101
-rw-r--r--  tensorflow/core/framework/tensor_shape.cc | 5
-rw-r--r--  tensorflow/core/framework/tensor_shape.h | 3
-rw-r--r--  tensorflow/core/graph/node_builder.cc | 6
-rw-r--r--  tensorflow/core/graph/node_builder.h | 15
-rw-r--r--  tensorflow/core/kernels/BUILD | 2
-rw-r--r--  tensorflow/core/kernels/cwise_ops_common.cc | 16
-rw-r--r--  tensorflow/core/kernels/cwise_ops_common.h | 46
-rw-r--r--  tensorflow/core/kernels/ops_util.cc | 9
-rw-r--r--  tensorflow/core/kernels/ops_util.h | 6
-rw-r--r--  tensorflow/core/kernels/padding_fifo_queue.cc | 15
-rw-r--r--  tensorflow/core/kernels/queue_base.cc | 2
-rw-r--r--  tensorflow/core/kernels/queue_base.h | 2
-rw-r--r--  tensorflow/core/kernels/relu_op.cc | 71
-rw-r--r--  tensorflow/core/kernels/scatter_op.cc | 112
-rw-r--r--  tensorflow/core/kernels/segment_reduction_ops.cc | 40
-rw-r--r--  tensorflow/core/kernels/softplus_op.cc | 20
-rw-r--r--  tensorflow/core/kernels/softsign_op.cc | 21
-rw-r--r--  tensorflow/core/lib/core/errors.h | 69
-rw-r--r--  tensorflow/g3doc/api_docs/cc/ClassRandomAccessFile.md | 2
-rw-r--r--  tensorflow/g3doc/api_docs/cc/ClassSession.md | 37
-rw-r--r--  tensorflow/g3doc/api_docs/cc/ClassTensor.md | 43
-rw-r--r--  tensorflow/g3doc/api_docs/cc/ClassTensorShape.md | 6
-rw-r--r--  tensorflow/g3doc/api_docs/cc/index.md | 34
-rw-r--r--  tensorflow/g3doc/api_docs/python/constant_op.md | 5
-rw-r--r--  tensorflow/g3doc/api_docs/python/contrib.layers.md | 274
-rw-r--r--  tensorflow/g3doc/api_docs/python/contrib.util.md | 23
-rw-r--r--  tensorflow/g3doc/api_docs/python/framework.md | 14
-rw-r--r--  tensorflow/g3doc/api_docs/python/index.md | 16
-rw-r--r--  tensorflow/g3doc/api_docs/python/nn.md | 240
-rw-r--r--  tensorflow/g3doc/api_docs/python/sparse_ops.md | 89
-rw-r--r--  tensorflow/g3doc/api_docs/python/state_ops.md | 26
-rw-r--r--  tensorflow/g3doc/api_docs/python/test.md | 12
-rw-r--r--  tensorflow/python/client/tf_session_helper.cc | 26
-rw-r--r--  tensorflow/python/framework/function.py | 89
-rw-r--r--  tensorflow/python/framework/function_test.py | 43
-rw-r--r--  tensorflow/python/kernel_tests/py_func_test.py | 18
-rw-r--r--  tensorflow/python/kernel_tests/reader_ops_test.py | 4
-rw-r--r--  tensorflow/python/ops/constant_op.py | 5
-rw-r--r--  tensorflow/python/ops/script_ops.py | 6
-rw-r--r--  tensorflow/python/training/saver.py | 5
-rw-r--r--  tensorflow/tensorflow.bzl | 1
-rw-r--r--  tensorflow/tools/docs/gen_cc_md.py | 20
-rw-r--r--  tensorflow/workspace.bzl | 4
-rw-r--r--  third_party/eigen3/Eigen/Cholesky | 2
-rw-r--r--  third_party/eigen3/Eigen/Core | 2
-rw-r--r--  third_party/eigen3/Eigen/Eigenvalues | 2
-rw-r--r--  third_party/eigen3/Eigen/LU | 2
-rw-r--r--  third_party/eigen3/Eigen/QR | 2
-rw-r--r--  third_party/eigen3/unsupported/Eigen/CXX11/Tensor | 2
83 files changed, 2038 insertions(+), 739 deletions(-)
diff --git a/eigen.BUILD b/eigen.BUILD
index 6f74c966c3..0a7b4bbc06 100644
--- a/eigen.BUILD
+++ b/eigen.BUILD
@@ -1,6 +1,6 @@
package(default_visibility = ["//visibility:public"])
-archive_dir = "eigen-eigen-73a4995594c6"
+archive_dir = "eigen-eigen-017cff30cf74"
cc_library(
name = "eigen",
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index 930df7bcac..ee5adf99e2 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -104,7 +104,9 @@ class SdcaOptimizerTest(TensorFlowTestCase):
'gender': [1]}, 1),
]
example_weights = [1.0, 1.0]
- with self.test_session(use_gpu=False):
+ config = tf.ConfigProto(inter_op_parallelism_threads=1,
+ intra_op_parallelism_threads=1)
+ with self.test_session(use_gpu=False, config=config):
examples = make_example_dict(example_protos, example_weights)
variables = make_variable_dict(1, 1)
options = dict(symmetric_l2_regularization=0.5,
@@ -119,8 +121,10 @@ class SdcaOptimizerTest(TensorFlowTestCase):
self.assertAllClose(0.693147, unregularized_loss.eval())
self.assertAllClose(0.693147, loss.eval())
lr.minimize().run()
- self.assertAllClose(0.395226, unregularized_loss.eval())
- self.assertAllClose(0.657446, loss.eval())
+ self.assertAllClose(0.395226, unregularized_loss.eval(),
+ rtol=3e-2, atol=3e-2)
+ self.assertAllClose(0.657446, loss.eval(),
+ rtol=3e-2, atol=3e-2)
predicted_labels = tf.cast(
tf.greater_equal(prediction,
tf.ones_like(prediction) * 0.5), tf.float32)
@@ -148,11 +152,13 @@ class SdcaOptimizerTest(TensorFlowTestCase):
'gender': [0]}, 1),
]
example_weights = [1.0, 0.0, 1.0, 0.0]
- with self.test_session(use_gpu=False):
+ config = tf.ConfigProto(inter_op_parallelism_threads=1,
+ intra_op_parallelism_threads=1)
+ with self.test_session(use_gpu=False, config=config):
# Only use examples 0 and 2
examples = make_example_dict(example_protos, example_weights)
variables = make_variable_dict(1, 1)
- options = dict(symmetric_l2_regularization=0.25,
+ options = dict(symmetric_l2_regularization=0.5,
symmetric_l1_regularization=0,
loss_type='logistic_loss')
tf.initialize_all_variables().run()
@@ -161,8 +167,10 @@ class SdcaOptimizerTest(TensorFlowTestCase):
loss = lr.regularized_loss(examples)
prediction = lr.predictions(examples)
lr.minimize().run()
- self.assertAllClose(0.395226, unregularized_loss.eval())
- self.assertAllClose(0.526336, loss.eval())
+ self.assertAllClose(0.395226, unregularized_loss.eval(),
+ rtol=3e-2, atol=3e-2)
+ self.assertAllClose(0.657446, loss.eval(),
+ rtol=3e-2, atol=3e-2)
predicted_labels = tf.cast(
tf.greater_equal(prediction,
tf.ones_like(prediction) * 0.5), tf.float32)
@@ -180,7 +188,9 @@ class SdcaOptimizerTest(TensorFlowTestCase):
]
# Zeroed out example weights.
example_weights = [0.0, 0.0]
- with self.test_session(use_gpu=False):
+ config = tf.ConfigProto(inter_op_parallelism_threads=1,
+ intra_op_parallelism_threads=1)
+ with self.test_session(use_gpu=False, config=config):
examples = make_example_dict(example_protos, example_weights)
variables = make_variable_dict(1, 1)
options = dict(symmetric_l2_regularization=0.5,
@@ -211,7 +221,9 @@ class SdcaOptimizerTest(TensorFlowTestCase):
'gender': [1]}, 1),
]
example_weights = [1.0, 1.0, 1.0, 1.0]
- with self.test_session(use_gpu=False):
+ config = tf.ConfigProto(inter_op_parallelism_threads=1,
+ intra_op_parallelism_threads=1)
+ with self.test_session(use_gpu=False, config=config):
examples = make_example_dict(example_protos, example_weights)
variables = make_variable_dict(3, 1)
options = dict(symmetric_l2_regularization=0.25,
@@ -224,10 +236,8 @@ class SdcaOptimizerTest(TensorFlowTestCase):
loss = lr.regularized_loss(examples)
prediction = lr.predictions(examples)
lr.minimize().run()
- self.assertAllClose(0.331710,
- unregularized_loss.eval(),
- rtol=3e-2,
- atol=3e-2)
+ self.assertAllClose(0.331710, unregularized_loss.eval(),
+ rtol=3e-2, atol=3e-2)
self.assertAllClose(0.591295, loss.eval(), rtol=3e-2, atol=3e-2)
predicted_labels = tf.cast(
tf.greater_equal(prediction,
@@ -245,7 +255,9 @@ class SdcaOptimizerTest(TensorFlowTestCase):
'gender': [1]}, 1),
]
example_weights = [3.0, 1.0]
- with self.test_session(use_gpu=False):
+ config = tf.ConfigProto(inter_op_parallelism_threads=1,
+ intra_op_parallelism_threads=1)
+ with self.test_session(use_gpu=False, config=config):
examples = make_example_dict(example_protos, example_weights)
variables = make_variable_dict(1, 1)
options = dict(symmetric_l2_regularization=0.25,
@@ -257,10 +269,8 @@ class SdcaOptimizerTest(TensorFlowTestCase):
loss = lr.regularized_loss(examples)
prediction = lr.predictions(examples)
lr.minimize().run()
- self.assertAllClose(0.266189,
- unregularized_loss.eval(),
- rtol=3e-2,
- atol=3e-2)
+ self.assertAllClose(0.266189, unregularized_loss.eval(),
+ rtol=3e-2, atol=3e-2)
self.assertAllClose(0.571912, loss.eval(), rtol=3e-2, atol=3e-2)
predicted_labels = tf.cast(
tf.greater_equal(prediction,
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 099c5ee468..7c4cdb6f4a 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -112,7 +112,7 @@ class SdcaModel(object):
"""Create a new sdca optimizer."""
_maybe_load_sdca_ops()
-
+
if not container or not examples or not variables or not options:
raise ValueError('All arguments must be specified.')
@@ -303,7 +303,6 @@ class SdcaModel(object):
'dense_features'], examples)
self._assertList(['sparse_features', 'dense_features'], examples)
with name_scope('sdca/regularized_loss'):
- logits = self._logits(examples)
# TODO(rohananil): Change loss when supporting linear regression.
return self._l1_loss() + self._l2_loss() + self.unregularized_loss(
examples)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 2f099cecb3..d29d1aed0d 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -229,6 +229,7 @@ tf_cuda_library(
"framework/function.h",
"framework/graph_def_util.h",
"framework/kernel_def_builder.h",
+ "framework/log_memory.h",
"framework/lookup_interface.h",
"framework/memory_types.h",
"framework/node_def_builder.h",
@@ -251,6 +252,7 @@ tf_cuda_library(
"framework/tensor_slice.h",
"framework/tensor_types.h",
"framework/tensor_util.h",
+ "framework/tracking_allocator.h",
"framework/type_index.h",
"framework/type_traits.h",
"framework/types.h",
diff --git a/tensorflow/core/client/tensor_c_api.cc b/tensorflow/core/client/tensor_c_api.cc
index 853c309091..21c5cd58d6 100644
--- a/tensorflow/core/client/tensor_c_api.cc
+++ b/tensorflow/core/client/tensor_c_api.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "tensorflow/core/framework/log_memory.h"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/coding.h"
@@ -99,6 +101,12 @@ class TF_ManagedBuffer : public TensorBuffer {
};
void deallocate_realigned_buffer(void* data, size_t len, void* arg) {
+ if (tensorflow::LogMemory::IsEnabled()) {
+ tensorflow::LogMemory::RecordRawDeallocation(
+ "TensorFlow C Api",
+ tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data,
+ tensorflow::cpu_allocator(), false);
+ }
tensorflow::cpu_allocator()->DeallocateRaw(data);
}
} // namespace
@@ -125,6 +133,12 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, tensorflow::int64* dims,
// requirements.
buf->data_ =
tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len);
+ if (tensorflow::LogMemory::IsEnabled()) {
+ tensorflow::LogMemory::RecordRawAllocation(
+ "TF_NewTensor",
+ tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, len,
+ buf->data_, tensorflow::cpu_allocator());
+ }
std::memcpy(buf->data_, data, len);
buf->deallocator_ = deallocate_realigned_buffer;
buf->deallocator_arg_ = nullptr;
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 910f67bf81..ce394713fe 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
@@ -321,7 +322,7 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, Graph* graph) {
core::ScopedUnref rendez_unref(rendez);
Executor::Args args;
- args.step_id = Executor::Args::CONSTANT_FOLDING_STEP_ID;
+ args.step_id = LogMemory::CONSTANT_FOLDING_STEP_ID;
args.runner = runner;
args.rendezvous = rendez;
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 635412b55b..3dd713d99a 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
@@ -91,7 +92,7 @@ string GetRendezvousKey(const string& tensor_name,
} // namespace
-std::atomic_int_fast64_t DirectSession::step_id_counter_(0);
+std::atomic_int_fast64_t DirectSession::step_id_counter_(1);
// NOTE: On Android with a single device, there is never
// a risk of an OpKernel blocking indefinitely:
@@ -305,8 +306,9 @@ Status DirectSession::RunWithOpts(const RunOptions& run_options,
args.rendezvous = run_state.rendez;
args.cancellation_manager = cancellation_manager_;
args.runner = [this](Executor::Args::Closure c) { SchedClosure(c); };
- VLOG(1) << "Step " << args.step_id << " is for handle "
- << run_state_args.handle;
+ if (LogMemory::IsEnabled()) {
+ LogMemory::RecordStep(args.step_id, run_state_args.handle);
+ }
if (run_options.trace_level() == RunOptions::FULL_TRACE) {
args.stats_collector =
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 41619bc3d2..7072149d65 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_segment.h"
@@ -1246,6 +1247,9 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
Entry* out = &((*outputs)[i]);
out->has_value = true;
+ // This value is filled in below if LogMemory::IsEnabled.
+ Tensor value_to_log;
+
// Set the device context of the output entry.
out->device_context = device_context;
@@ -1259,8 +1263,16 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
if (val.is_ref()) {
out->ref = val.tensor;
out->ref_mu = val.mutex_if_ref;
+ if (LogMemory::IsEnabled()) {
+ // Dereference the tensor under the lock.
+ mutex_lock l(*out->ref_mu);
+ value_to_log = *out->ref;
+ }
} else {
out->val = *val.tensor;
+ if (LogMemory::IsEnabled()) {
+ value_to_log = out->val;
+ }
}
if (stats_collector_ && val.tensor->IsInitialized()) {
nodestats::SetOutput(stats, i, val.tensor);
@@ -1272,6 +1284,10 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
DataTypeString(item.output_type(i)),
" for node ", SummarizeNodeDef(node->def())));
}
+ if (LogMemory::IsEnabled()) {
+ LogMemory::RecordTensorOutput(ctx->op_kernel().name(), ctx->step_id(),
+ i, value_to_log);
+ }
}
if (!val.is_ref()) {
// If OpKernelContext returns outputs via pass-by-value, we
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index 846a3e1775..ab106c85be 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -80,12 +80,6 @@ class Executor {
// RunAsync() dispatches closures to "runner". Typically, "runner"
// is backed up by a bounded threadpool.
struct Args {
- // Executors are sometimes instantiated for initialization work
- // like constant folding that is logically outside any computation
- // step, and SpecialStepIds lists the ids used for those steps.
- enum SpecialStepIds {
- CONSTANT_FOLDING_STEP_ID = -1,
- };
int64 step_id = 0;
Rendezvous* rendezvous = nullptr;
StepStatsCollector* stats_collector = nullptr;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 1b740e33d9..c46c785744 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -75,8 +75,12 @@ class EigenAllocator : public ::Eigen::Allocator {
public:
EigenAllocator() {}
- void Reinitialize(gpu::Stream* stream, ::tensorflow::Allocator* alloc,
- EventMgr* em) {
+ void Reinitialize(OpKernelContext* context, gpu::Stream* stream,
+ ::tensorflow::Allocator* alloc, EventMgr* em) {
+ if (LogMemory::IsEnabled()) {
+ operation_ = context->op_kernel().name() + "/EigenAllocator";
+ step_id_ = context->step_id();
+ }
stream_ = stream;
allocator_ = alloc;
em_ = em;
@@ -90,14 +94,24 @@ class EigenAllocator : public ::Eigen::Allocator {
LOG(FATAL) << "EigenAllocator for GPU ran out of memory when allocating "
<< num_bytes << ". See error logs for more detailed info.";
}
+ if (LogMemory::IsEnabled()) {
+ LogMemory::RecordRawAllocation(operation_, step_id_, num_bytes, ret,
+ allocator_);
+ }
return ret;
}
void deallocate(void* buffer) const override {
- em_->ThenDeleteBuffer(stream_, {allocator_, buffer});
+ if (LogMemory::IsEnabled()) {
+ LogMemory::RecordRawDeallocation(operation_, step_id_, buffer, allocator_,
+ true);
+ }
+ em_->ThenDeleteBuffer(stream_, {allocator_, buffer, operation_, step_id_});
}
private:
+ string operation_;
+ int64 step_id_;
gpu::Stream* stream_; // Not owned.
::tensorflow::Allocator* allocator_; // Not owned.
::tensorflow::EventMgr* em_; // Not owned.
@@ -110,8 +124,12 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
public:
EigenCudaStreamDevice() { Eigen::initializeDeviceProp(); }
- void Reinitialize(const cudaStream_t* cuda_stream, int gpu_id,
- ::tensorflow::Allocator* alloc) {
+ void Reinitialize(OpKernelContext* context, const cudaStream_t* cuda_stream,
+ int gpu_id, ::tensorflow::Allocator* alloc) {
+ if (LogMemory::IsEnabled()) {
+ operation_ = context->op_kernel().name() + "/EigenAllocator";
+ step_id_ = context->step_id();
+ }
stream_ = cuda_stream;
allocator_ = alloc;
device_prop_ = &Eigen::m_deviceProperties[gpu_id];
@@ -128,30 +146,47 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
LOG(FATAL) << "EigenAllocator for GPU ran out of memory when allocating "
<< num_bytes << ". See error logs for more detailed info.";
}
-
+ if (LogMemory::IsEnabled()) {
+ LogMemory::RecordRawAllocation(operation_, step_id_, num_bytes, ret,
+ allocator_);
+ }
return ret;
}
void deallocate(void* buffer) const override {
- AsyncFreeData* afData = new AsyncFreeData(allocator_, buffer);
+ if (LogMemory::IsEnabled()) {
+ LogMemory::RecordRawDeallocation(operation_, step_id_, buffer, allocator_,
+ true);
+ }
+ AsyncFreeData* afData =
+ new AsyncFreeData(allocator_, buffer, operation_, step_id_);
cudaError_t err = cudaStreamAddCallback(*stream_, asyncFree, afData, 0);
CHECK_EQ(err, cudaSuccess);
}
private:
struct AsyncFreeData {
- AsyncFreeData(::tensorflow::Allocator* a, void* p)
- : allocator_(a), address_(p) {}
+ AsyncFreeData(::tensorflow::Allocator* a, void* p, const string& o,
+ const int64 s)
+ : allocator_(a), address_(p), operation_(o), step_id_(s) {}
::tensorflow::Allocator* allocator_;
void* address_;
+ const string operation_;
+ const int64 step_id_;
};
static void CUDART_CB asyncFree(cudaStream_t stream, cudaError_t status,
void* userData) {
AsyncFreeData* data = static_cast<AsyncFreeData*>(userData);
+ if (LogMemory::IsEnabled()) {
+ LogMemory::RecordRawDeallocation(data->operation_, data->step_id_,
+ data->address_, data->allocator_, false);
+ }
data->allocator_->DeallocateRaw(data->address_);
delete data;
}
+ string operation_;
+ int64 step_id_;
const cudaStream_t* stream_; // Not owned.
const cudaDeviceProp* device_prop_; // Not owned.
::tensorflow::Allocator* allocator_; // Not owned.
@@ -435,9 +470,9 @@ namespace {
class ConcretePerOpGpuDevice : public PerOpGpuDevice {
public:
ConcretePerOpGpuDevice() : device_(nullptr) {}
- void Reinitialize(gpu::Stream* stream, Allocator* base_allocator,
- ::tensorflow::EventMgr* em) {
- allocator_.Reinitialize(stream, base_allocator, em);
+ void Reinitialize(OpKernelContext* context, gpu::Stream* stream,
+ Allocator* base_allocator, ::tensorflow::EventMgr* em) {
+ allocator_.Reinitialize(context, stream, base_allocator, em);
device_.Reinitialize(stream, &allocator_);
}
@@ -452,9 +487,9 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice {
public:
ConcretePerOpGpuDevice() : device_(&stream_device_) {}
- void Reinitialize(const cudaStream_t* cuda_stream, int gpu_id,
- Allocator* base_allocator) {
- stream_device_.Reinitialize(cuda_stream, gpu_id, base_allocator);
+ void Reinitialize(OpKernelContext* context, const cudaStream_t* cuda_stream,
+ int gpu_id, Allocator* base_allocator) {
+ stream_device_.Reinitialize(context, cuda_stream, gpu_id, base_allocator);
}
const Eigen::GpuDevice& device() const override { return device_; }
@@ -466,18 +501,19 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice {
#endif
} // namespace
-void BaseGPUDevice::ReinitializeDevice(PerOpGpuDevice* device, int stream_id,
+void BaseGPUDevice::ReinitializeDevice(OpKernelContext* context,
+ PerOpGpuDevice* device, int stream_id,
Allocator* allocator) {
ConcretePerOpGpuDevice* concrete_device =
dynamic_cast<ConcretePerOpGpuDevice*>(device);
DCHECK(concrete_device);
#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
- concrete_device->Reinitialize(streams_[stream_id].compute, allocator,
+ concrete_device->Reinitialize(context, streams_[stream_id].compute, allocator,
em_.get());
#else
const cudaStream_t* cuda_stream = reinterpret_cast<const cudaStream_t*>(
streams_[stream_id].compute->implementation()->CudaStreamMemberHack());
- concrete_device->Reinitialize(cuda_stream, gpu_id_, allocator);
+ concrete_device->Reinitialize(context, cuda_stream, gpu_id_, allocator);
#endif
}
@@ -485,7 +521,8 @@ PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice() {
return new ConcretePerOpGpuDevice();
}
-void BaseGPUDevice::ReinitializeGpuDevice(PerOpGpuDevice* device,
+void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
+ PerOpGpuDevice* device,
DeviceContext* dc,
Allocator* allocator) {
if (dc) {
@@ -494,9 +531,9 @@ void BaseGPUDevice::ReinitializeGpuDevice(PerOpGpuDevice* device,
VLOG(1) << " eigen_gpu_device(" << dc << ") => stream[" << stream_id
<< "]";
CHECK_LT(stream_id, streams_.size());
- ReinitializeDevice(device, stream_id, allocator);
+ ReinitializeDevice(context, device, stream_id, allocator);
} else {
- ReinitializeDevice(device, 0, allocator);
+ ReinitializeDevice(context, device, 0, allocator);
}
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index c940004584..36af8c7b8b 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -74,8 +74,8 @@ class BaseGPUDevice : public LocalDevice {
// The caller owns the returned device.
PerOpGpuDevice* MakeGpuDevice() override;
- void ReinitializeGpuDevice(PerOpGpuDevice* device, DeviceContext* dc,
- Allocator* allocator) override;
+ void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
+ DeviceContext* dc, Allocator* allocator) override;
protected:
Allocator* gpu_allocator_; // not owned
@@ -96,8 +96,8 @@ class BaseGPUDevice : public LocalDevice {
const bool sync_every_op_ = false;
std::unique_ptr<EventMgr> em_;
- void ReinitializeDevice(PerOpGpuDevice* device, int stream_id,
- Allocator* allocator);
+ void ReinitializeDevice(OpKernelContext* context, PerOpGpuDevice* device,
+ int stream_id, Allocator* allocator);
};
class BaseGPUDeviceFactory : public DeviceFactory {
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
index 58820d1c2d..d37a55784d 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
@@ -22,8 +22,6 @@ limitations under the License.
namespace tensorflow {
-void RequireGPUDevice() {}
-
class GPUDevice : public BaseGPUDevice {
public:
GPUDevice(const SessionOptions& options, const string& name,
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
index 45dfc3e129..cb034cf2d8 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
@@ -57,6 +57,11 @@ EventMgr::~EventMgr() {
delete ue->mem;
}
if (ue->bufrec.buf) {
+ if (LogMemory::IsEnabled()) {
+ LogMemory::RecordRawDeallocation(ue->bufrec.operation,
+ ue->bufrec.step_id, ue->bufrec.buf,
+ ue->bufrec.alloc, false);
+ }
ue->bufrec.alloc->DeallocateRaw(ue->bufrec.buf);
}
if (ue->func != nullptr) threadpool_.Schedule(ue->func);
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
index 1e11e67ced..6cc6595767 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <deque>
#include <vector>
+#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/lib/core/notification.h"
@@ -58,6 +59,10 @@ class EventMgr {
struct BufRec {
Allocator* alloc;
void* buf;
+ // operation and step_id are only populated when
+ // LogMemory::IsEnabled() is true.
+ string operation;
+ int64 step_id;
};
// Takes ownership of *bufrec.buf and calls bufrec.alloc->DeallocateRaw()
@@ -110,7 +115,14 @@ class EventMgr {
}
delete iu.mem;
}
- if (iu.bufrec.buf) iu.bufrec.alloc->DeallocateRaw(iu.bufrec.buf);
+ if (iu.bufrec.buf) {
+ if (LogMemory::IsEnabled()) {
+ LogMemory::RecordRawDeallocation(iu.bufrec.operation,
+ iu.bufrec.step_id, iu.bufrec.buf,
+ iu.bufrec.alloc, false);
+ }
+ iu.bufrec.alloc->DeallocateRaw(iu.bufrec.buf);
+ }
// The function must be called in another thread.
if (iu.func != nullptr) threadpool_.Schedule(iu.func);
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc
index 25bccbfd2c..a25d072eea 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc
@@ -154,6 +154,11 @@ void GPUUtil::SetProtoFromGPU(const Tensor& tensor, Device* dev,
port::Tracing::ScopedAnnotation annotation("SetProtoFromGPU");
alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
buf = alloc->Allocate<char>(total_bytes);
+ if (LogMemory::IsEnabled()) {
+ LogMemory::RecordRawAllocation("SetProtoFromGPU",
+ LogMemory::PROTO_BUFFER_STEP_ID,
+ total_bytes, buf, alloc);
+ }
void* src_ptr = GetBase(&tensor);
DeviceMemoryBase gpu_src_ptr(src_ptr, total_bytes);
send_device_to_host_stream->ThenMemcpy(buf, gpu_src_ptr, total_bytes);
@@ -170,6 +175,11 @@ void GPUUtil::SetProtoFromGPU(const Tensor& tensor, Device* dev,
if (total_bytes > 0) {
port::CopyFromArray(proto->mutable_tensor_content(), buf,
total_bytes);
+ if (LogMemory::IsEnabled()) {
+ LogMemory::RecordRawDeallocation("SetProtoFromGPU",
+ LogMemory::PROTO_BUFFER_STEP_ID,
+ buf, alloc, false);
+ }
alloc->Deallocate<char>(buf, total_bytes);
}
done(Status::OK());
diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc
index 596a654920..8c438f588b 100644
--- a/tensorflow/core/common_runtime/gpu/process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/process_state.cc
@@ -22,6 +22,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/common_runtime/gpu/pool_allocator.h"
#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/log_memory.h"
+#include "tensorflow/core/framework/tracking_allocator.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
@@ -155,9 +157,15 @@ Allocator* ProcessState::GetCPUAllocator(int numa_node) {
numa_node = 0;
mutex_lock lock(mu_);
while (cpu_allocators_.size() <= static_cast<size_t>(numa_node)) {
- cpu_allocators_.push_back(new PoolAllocator(
- 100 /*pool_size_limit*/, true /*auto_resize*/, new BasicCPUAllocator(),
- new NoopRounder, "cpu_pool"));
+ Allocator* allocator =
+ new PoolAllocator(100 /*pool_size_limit*/, true /*auto_resize*/,
+ new BasicCPUAllocator(), new NoopRounder, "cpu_pool");
+ if (LogMemory::IsEnabled()) {
+ // Wrap the allocator to track allocation ids for better logging
+ // at the cost of performance.
+ allocator = new TrackingAllocator(allocator, true);
+ }
+ cpu_allocators_.push_back(allocator);
}
return cpu_allocators_[0];
}
@@ -178,9 +186,15 @@ Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) {
gpu::Platform* gpu_platform = GPUMachineManager();
gpu::StreamExecutor* se = gpu_platform->ExecutorForDevice(0).ValueOrDie();
CHECK(se);
- cuda_host_allocators_.push_back(new PoolAllocator(
+ Allocator* allocator = new PoolAllocator(
100 /*pool_size_limit*/, true /*auto_resize*/,
- new CUDAHostAllocator(se), new Pow2Rounder, "cuda_host"));
+ new CUDAHostAllocator(se), new Pow2Rounder, "cuda_host");
+ if (LogMemory::IsEnabled()) {
+ // Wrap the allocator to track allocation ids for better logging
+ // at the cost of performance.
+ allocator = new TrackingAllocator(allocator, true);
+ }
+ cuda_host_allocators_.push_back(allocator);
if (FLAGS_brain_gpu_record_mem_types) {
MemDesc md;
md.loc = MemDesc::CPU;
diff --git a/tensorflow/core/common_runtime/gpu/process_state.h b/tensorflow/core/common_runtime/gpu/process_state.h
index eb289887b9..144be80f58 100644
--- a/tensorflow/core/common_runtime/gpu/process_state.h
+++ b/tensorflow/core/common_runtime/gpu/process_state.h
@@ -116,10 +116,10 @@ class ProcessState {
mutex mu_;
- std::vector<PoolAllocator*> cpu_allocators_ GUARDED_BY(mu_);
+ std::vector<Allocator*> cpu_allocators_ GUARDED_BY(mu_);
std::vector<VisitableAllocator*> gpu_allocators_ GUARDED_BY(mu_);
std::vector<std::vector<AllocVisitor>> gpu_visitors_ GUARDED_BY(mu_);
- std::vector<PoolAllocator*> cuda_host_allocators_ GUARDED_BY(mu_);
+ std::vector<Allocator*> cuda_host_allocators_ GUARDED_BY(mu_);
virtual ~ProcessState();
diff --git a/tensorflow/core/distributed_runtime/README.md b/tensorflow/core/distributed_runtime/README.md
index 243fa47205..918af2d2ba 100644
--- a/tensorflow/core/distributed_runtime/README.md
+++ b/tensorflow/core/distributed_runtime/README.md
@@ -31,7 +31,7 @@ test your installation by starting a server as follows:
```shell
# Start a TensorFlow server as a single-process "cluster".
$ bazel-bin/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server \
- --cluster_spec='local|localhost:2222' --job_name=local --task_index=0 &
+ --cluster_spec='local|localhost:2222' --job_name=local --task_id=0 &
```
...then start a Python interpreter and create a remote session:
@@ -62,9 +62,9 @@ The command-line arguments to `grpc_tensorflow_server` define the membership of
</tr>
</table>
-The `--job_name` and `--task_index` flags indicate which task will run in this
+The `--job_name` and `--task_id` flags indicate which task will run in this
process, out of the jobs and tasks defined in `--cluster_spec`. For example,
-`--job_name=local --task_index=0` means that the process will be task
+`--job_name=local --task_id=0` means that the process will be task
`/job:local/task:0`, and TensorFlow devices in the process will have names
starting with that prefix.
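For context on the renamed flag: with the server above started as `--job_name=local --task_id=0`, the "remote session" mentioned earlier in the README is typically created by pointing a Python session at the server's gRPC address, e.g. `tf.Session("grpc://localhost:2222")`; the exact session snippet sits outside this hunk, so the target string here is only illustrative.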
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index f1bcbf3956..261091f778 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/config.pb.h"
+#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
@@ -335,8 +336,9 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
args.rendezvous = rendezvous;
args.cancellation_manager = cancellation_manager;
args.stats_collector = collector;
- VLOG(1) << "Step " << args.step_id << " is for handle " << handle
- << ", graph-local step " << step_id;
+ if (LogMemory::IsEnabled()) {
+ LogMemory::RecordStep(args.step_id, handle);
+ }
thread::ThreadPool* pool = worker_env_->compute_pool;
args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
for (const auto& unit : item->units) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server_lib.cc
deleted file mode 100644
index c502cbdacf..0000000000
--- a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server_lib.cc
+++ /dev/null
@@ -1,123 +0,0 @@
-/* Copyright 2016 Google Inc. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
-
-#include "grpc++/grpc++.h"
-#include "grpc++/security/credentials.h"
-#include "grpc++/server_builder.h"
-
-#include "tensorflow/core/common_runtime/device_factory.h"
-#include "tensorflow/core/common_runtime/device_mgr.h"
-#include "tensorflow/core/distributed_runtime/graph_mgr.h"
-#include "tensorflow/core/distributed_runtime/master_env.h"
-#include "tensorflow/core/distributed_runtime/master_session.h"
-#include "tensorflow/core/distributed_runtime/process_util.h"
-#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
-#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
-#include "tensorflow/core/distributed_runtime/worker_env.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/init_main.h"
-#include "tensorflow/core/public/session_options.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-// This binary starts a TensorFlow server (master and worker) for test purposes.
-namespace tensorflow {
-
-struct GrpcTaskOptions {
- // This process belongs to the "job_name".
- string job_name;
-
- // This process is the task-th task within the replica. 0th, 1st,
- // 2nd, etc.
- int32 task = 0;
-
- // Specification of peers.
- GrpcChannelSpec channel_spec;
-
- SessionOptions default_session_options;
-};
-
-Status StartTensorFlowServer(const TaskOptions& task_options) {
- thread::ThreadPool* thread_pool =
- new thread::ThreadPool(Env::Default(), "server", 1);
- thread_pool->Schedule([argc, argv, task_options]() {
- // This process provides both the worker service and the master
- // service. We let these two services share the same channel cache
- // (rpc connections) and cpu devices (used by the master as the
- // client device). These client devices require a worker service
- // so that remote devices can copy the feeds from the client
- // device in the master.
- tensorflow::MasterEnv master_env;
- string name_prefix =
- strings::StrCat("/job:", task_optionss.job_name, "/replica:0", "/task:",
- task_options.task);
- DeviceFactory::AddDevices(task_options.default_session_options, name_prefix,
- &master_env.local_devices);
-
- // Create the DeviceMgr before initializing the RPC layer, because that
- // needs to know how many devices of each kind exist.
- WorkerEnv worker_env;
- worker_env.device_mgr = new DeviceMgr(master_env.local_devices);
-
- // Finish setting up Env for Worker service.
- string donotcare;
- CHECK(DeviceNameUtils::SplitDeviceName(master_env.local_devices[0]->name(),
- &worker_env.worker_name,
- &donotcare));
- worker_env.env = Env::Default();
-
- GrpcChannelCache* channel_cache =
- NewGrpcChannelCache(task_options.channel_spec);
- string server_address = channel_cache->TranslateTask(name_prefix);
- worker_env.worker_cache = NewGrpcWorkerCache(channel_cache);
- worker_env.graph_mgr = new GraphMgr(&worker_env);
- worker_env.rendezvous_mgr = new RpcRendezvousMgr(&worker_env);
- worker_env.compute_pool = ComputePool(task_options.default_session_options);
-
- // Finish setting up Env for Master service.
- master_env.env = Env::Default();
- master_env.ops = OpRegistry::Global();
- master_env.worker_cache = worker_env.worker_cache;
- master_env.master_session_factory = internal::NewMasterSession;
-
- ::grpc::ServerBuilder builder;
- builder.AddListeningPort(server_address,
- ::grpc::InsecureServerCredentials());
- auto master_service = NewGrpcMasterService(&master_env, &builder);
- auto worker_service = NewGrpcWorkerService(&worker_env, &builder);
- // Finally assemble the server.
- auto server_ = builder.BuildAndStart();
-
- std::unique_ptr<Thread> master_thread(Env::Default()->StartThread(
- ThreadOptions(), "master_service_thread",
- [master_service]() { master_service->HandleRPCsLoop(); }));
-
- std::unique_ptr<Thread> worker_thread(Env::Default()->StartThread(
- ThreadOptions(), "worker_service_thread",
- [worker_service]() { worker_service->HandleRPCsLoop(); }));
- });
-
- // The ThreadPool destructor waits until all work is done (i.e. forever).
- delete thread_pool;
- return Status::OK();
-}
-
-} // end namespace tensorflow
diff --git a/tensorflow/core/framework/allocation_description.proto b/tensorflow/core/framework/allocation_description.proto
index 992f867b75..8c74ee16c8 100644
--- a/tensorflow/core/framework/allocation_description.proto
+++ b/tensorflow/core/framework/allocation_description.proto
@@ -21,4 +21,7 @@ message AllocationDescription {
// Set if this tensor only has one remaining reference
bool has_single_reference = 5;
+
+ // Address of the allocation.
+ uint64 ptr = 6;
};
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc
index df15bfdc69..72b1865042 100644
--- a/tensorflow/core/framework/allocator.cc
+++ b/tensorflow/core/framework/allocator.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/log_memory.h"
+#include "tensorflow/core/framework/tracking_allocator.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
@@ -99,8 +101,18 @@ class CPUAllocator : public Allocator {
TF_DISALLOW_COPY_AND_ASSIGN(CPUAllocator);
};
+namespace {
+Allocator* MakeCpuAllocator() {
+ Allocator* allocator = new CPUAllocator;
+ if (LogMemory::IsEnabled()) {
+ allocator = new TrackingAllocator(allocator, true);
+ }
+ return allocator;
+}
+} // namespace
+
Allocator* cpu_allocator() {
- static CPUAllocator* cpu_alloc = new CPUAllocator;
+ static Allocator* cpu_alloc = MakeCpuAllocator();
return cpu_alloc;
}
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index 5a55557f7e..97a3f61693 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -36,6 +36,12 @@ struct AllocationAttributes {
// An example use case is optional scratch spaces where a failure
// has only performance impact.
bool no_retry_on_failure = false;
+ // If a Tensor is allocated without the following set to true, then
+ // it is logged as an unknown allocation. During execution Tensors
+ // should be allocated through the OpKernelContext which records
+ // which Op is performing the allocation, and sets this flag to
+ // true.
+ bool allocation_will_be_logged = false;
};
// Runtime statistics collected by an allocator.
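Read together with the `OpKernelConstruction::allocate_temp` change later in this diff, the intended pattern appears to be: set `allocation_will_be_logged` before constructing the `Tensor` (so the allocator does not report it as an unknown allocation), then record the allocation explicitly once logging is known to be enabled. A minimal sketch of that pattern; the helper name and step id argument are ours, not part of this change:

```cpp
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {

// Hypothetical helper mirroring OpKernelConstruction::allocate_temp.
Tensor MakeLoggedTempTensor(Allocator* allocator, DataType type,
                            const TensorShape& shape,
                            const string& kernel_name, int64 step_id) {
  // Tell the allocator the caller will log this allocation itself.
  AllocationAttributes attr;
  attr.allocation_will_be_logged = true;
  Tensor t(allocator, type, shape, attr);
  if (LogMemory::IsEnabled()) {
    // Associate the tensor with the kernel and step that produced it.
    LogMemory::RecordTensorAllocation(kernel_name, step_id, t);
  }
  return t;
}

}  // namespace tensorflow
```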
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index a3fb3c2b67..075f74a0b6 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -42,6 +42,7 @@ namespace tensorflow {
class Device;
class Env;
class EventMgr;
+class OpKernelContext;
class ResourceMgr;
namespace thread {
@@ -170,7 +171,8 @@ class DeviceBase {
// This is overridden by GPU devices to reinitialize the derived
// type returned by MakeGpuDevice.
- virtual void ReinitializeGpuDevice(PerOpGpuDevice* /*device*/,
+ virtual void ReinitializeGpuDevice(OpKernelContext* /*context*/,
+ PerOpGpuDevice* /*device*/,
DeviceContext* /*dc*/,
Allocator* /*allocator*/) {}
diff --git a/tensorflow/core/framework/kernel_def_builder.h b/tensorflow/core/framework/kernel_def_builder.h
index 272ce4b497..824692ab2b 100644
--- a/tensorflow/core/framework/kernel_def_builder.h
+++ b/tensorflow/core/framework/kernel_def_builder.h
@@ -83,8 +83,7 @@ class KernelDefBuilder {
// IMPLEMENTATION
template <class T>
-inline KernelDefBuilder& KernelDefBuilder::TypeConstraint(
- const char* attr_name) {
+KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name) {
return this->TypeConstraint(attr_name, DataTypeToEnum<T>::v());
}
diff --git a/tensorflow/core/framework/log_memory.cc b/tensorflow/core/framework/log_memory.cc
new file mode 100644
index 0000000000..6411c7638d
--- /dev/null
+++ b/tensorflow/core/framework/log_memory.cc
@@ -0,0 +1,83 @@
+#include "tensorflow/core/framework/log_memory.h"
+
+#include "tensorflow/core/framework/log_memory.pb.h"
+
+namespace tensorflow {
+
+const string LogMemory::kLogMemoryLabel = "__LOG_MEMORY__";
+
+bool LogMemory::IsEnabled() { return VLOG_IS_ON(1); }
+
+void LogMemory::OutputToLog(const protobuf::Message& proto) {
+ protobuf::TextFormat::Printer printer;
+ printer.SetExpandAny(true);
+ printer.SetSingleLineMode(true);
+ string contents_string;
+ printer.PrintToString(proto, &contents_string);
+
+ LOG(INFO) << kLogMemoryLabel << " " << proto.GetDescriptor()->name() << " { "
+ << contents_string << " }";
+}
+
+void LogMemory::RecordStep(const int64 step_id, const string& handle) {
+ MemoryLogStep step;
+ step.set_step_id(step_id);
+ step.set_handle(handle);
+ OutputToLog(step);
+}
+
+void LogMemory::RecordTensorAllocation(const string& kernel_name,
+ const int64 step_id,
+ const Tensor& tensor) {
+ MemoryLogTensorAllocation allocation;
+ allocation.set_step_id(step_id);
+ allocation.set_kernel_name(kernel_name);
+ tensor.FillDescription(allocation.mutable_tensor());
+ OutputToLog(allocation);
+}
+
+void LogMemory::RecordTensorDeallocation(const int64 allocation_id,
+ const string& allocator_name) {
+ MemoryLogTensorDeallocation deallocation;
+ deallocation.set_allocation_id(allocation_id);
+ deallocation.set_allocator_name(allocator_name);
+ OutputToLog(deallocation);
+}
+
+void LogMemory::RecordTensorOutput(const string& kernel_name,
+ const int64 step_id, const int index,
+ const Tensor& tensor) {
+ MemoryLogTensorOutput output;
+ output.set_step_id(step_id);
+ output.set_kernel_name(kernel_name);
+ output.set_index(index);
+ tensor.FillDescription(output.mutable_tensor());
+ OutputToLog(output);
+}
+
+void LogMemory::RecordRawAllocation(const string& operation,
+ const int64 step_id, size_t num_bytes,
+ void* ptr, Allocator* allocator) {
+ MemoryLogRawAllocation allocation;
+ allocation.set_step_id(step_id);
+ allocation.set_operation(operation);
+ allocation.set_num_bytes(static_cast<int64>(num_bytes));
+ allocation.set_ptr(reinterpret_cast<uintptr_t>(ptr));
+ allocation.set_allocation_id(allocator->AllocationId(ptr));
+ allocation.set_allocator_name(allocator->Name());
+ OutputToLog(allocation);
+}
+
+void LogMemory::RecordRawDeallocation(const string& operation,
+ const int64 step_id, void* ptr,
+ Allocator* allocator, bool deferred) {
+ MemoryLogRawDeallocation deallocation;
+ deallocation.set_step_id(step_id);
+ deallocation.set_operation(operation);
+ deallocation.set_allocation_id(allocator->AllocationId(ptr));
+ deallocation.set_allocator_name(allocator->Name());
+ deallocation.set_deferred(deferred);
+ OutputToLog(deallocation);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/log_memory.h b/tensorflow/core/framework/log_memory.h
new file mode 100644
index 0000000000..4ff862ef0e
--- /dev/null
+++ b/tensorflow/core/framework/log_memory.h
@@ -0,0 +1,115 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_
+#define TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+// LogMemory contains methods for recording memory allocations and
+// frees, associating each allocation with a step identified by a
+// process-wide id. For now, logging is enabled whenever VLOG_IS_ON(1)
+// for the log_memory module.
+//
+// Limitations: We don't log memory allocations by Eigen on the CPU
+// since that would require major changes to plumb through to the
+// Eigen::{DefaultDevice,ThreadPoolDevice} allocate and deallocate
+// methods. We do log Eigen allocations on GPU since the plumbing was
+// already in place.
+class LogMemory {
+ public:
+ // Allocations sometimes happen outside any computation step, and
+ // SpecialStepIds lists the ids used for those steps.
+ enum SpecialStepIds {
+ // Used when performing a just-in-time constant folding optimization.
+ CONSTANT_FOLDING_STEP_ID = -1,
+ // Used when constructing an Op kernel before executing a step.
+ OP_KERNEL_CONSTRUCTION_STEP_ID = -2,
+ // Used when allocating a tensor buffer from external code, e.g.,
+ // the C API.
+ EXTERNAL_TENSOR_ALLOCATION_STEP_ID = -3,
+ // Used when allocating a buffer for network transfer.
+ NETWORK_BUFFER_STEP_ID = -4,
+ // Used when allocating a buffer to fill a Proto from the GPU.
+ PROTO_BUFFER_STEP_ID = -5,
+ // Used when allocating a Tensor where the caller has not indicated
+ // the step.
+ UNKNOWN_STEP_ID = -6,
+ };
+
+ static const string kLogMemoryLabel;
+
+ // Test to see if memory logging is enabled. For now, logging is
+ // enabled whenever VLOG_IS_ON(1) for the log_memory module.
+ static bool IsEnabled();
+
+ // Log the beginning of a step.
+ static void RecordStep(int64 step_id, const string& handle);
+
+ // Log a tensor buffer allocation. The name indicates which kernel
+ // made the allocation. If the allocation is made through an
+ // OpKernelContext the step_id indicates which step is executing,
+ // otherwise step_id is one of the SpecialStepIds defined in
+ // op_kernel.h, e.g. Op Kernel construction or an optimization pass
+ // such as constant folding.
+ static void RecordTensorAllocation(const string& kernel_name, int64 step_id,
+ const Tensor& tensor);
+
+ // Log a tensor buffer deallocation. The deallocation is triggered
+ // when the buffer's refcount falls to zero, and the tracking
+ // mechanism does not associate it with a particular step or
+ // kernel. The allocation_id/allocator_name should match a
+ // corresponding tensor previously passed in to
+ // RecordTensorAllocation.
+ static void RecordTensorDeallocation(int64 allocation_id,
+ const string& allocator_name);
+
+ // Log the use of a tensor as an output from a kernel.
+ static void RecordTensorOutput(const string& kernel_name, int64 step_id,
+ int index, const Tensor& tensor);
+
+ // Log a "raw" allocation, which is just a buffer sized in
+ // bytes. The Eigen allocator, and memory copies, record their
+ // allocations this way, since they do not allocate TensorFlow
+ // tensors. The operation is set to the OpKernel name if this is
+ // called from within an Op execution, otherwise it indicates an
+ // operation such as memcpy. The step_id if >=0 indicates which step
+ // is executing, otherwise step_id is one of the SpecialStepIds
+ // defined in op_kernel.h, e.g. Op Kernel construction or an
+ // optimization pass such as constant folding.
+ static void RecordRawAllocation(const string& operation, int64 step_id,
+ size_t num_bytes, void* ptr,
+ Allocator* allocator);
+
+ // Log a "raw" deallocation of a buffer. When deferred is true, the
+ // buffer won't be used again, but a GPU kernel may still be
+ // enqueued using the buffer. A deferred deallocation should always
+ // be followed by a matching non-deferred deallocation when the
+ // buffer is actually returned and can be reused.
+ static void RecordRawDeallocation(const string& operation, int64 step_id,
+ void* ptr, Allocator* allocator,
+ bool deferred);
+
+ private:
+ // Write the message as a log entry
+ static void OutputToLog(const protobuf::Message& proto);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_
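Raw (non-Tensor) allocations do not go through an `OpKernelContext`, so callers guard the `Record*` calls on `IsEnabled()` and pick a `SpecialStepIds` value when no step is running, as the `tensor_c_api.cc` and `SetProtoFromGPU` changes above do. A minimal sketch under those assumptions; the function names and the "ExampleCaller" operation label are illustrative only:

```cpp
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/log_memory.h"

namespace tensorflow {

void* AllocateLoggedBuffer(size_t len) {
  // 64-byte alignment is chosen arbitrarily for this sketch.
  void* data = cpu_allocator()->AllocateRaw(64, len);
  if (LogMemory::IsEnabled()) {
    // No OpKernelContext is available here, so use a SpecialStepIds value.
    LogMemory::RecordRawAllocation(
        "ExampleCaller", LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, len,
        data, cpu_allocator());
  }
  return data;
}

void DeallocateLoggedBuffer(void* data) {
  if (LogMemory::IsEnabled()) {
    // deferred=false: the buffer is returned immediately, not queued.
    LogMemory::RecordRawDeallocation(
        "ExampleCaller", LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data,
        cpu_allocator(), false);
  }
  cpu_allocator()->DeallocateRaw(data);
}

}  // namespace tensorflow
```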
diff --git a/tensorflow/core/framework/log_memory.proto b/tensorflow/core/framework/log_memory.proto
new file mode 100644
index 0000000000..83971bd358
--- /dev/null
+++ b/tensorflow/core/framework/log_memory.proto
@@ -0,0 +1,93 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+option java_outer_classname = "LogMemoryProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+import "tensorflow/core/framework/tensor_description.proto";
+
+message MemoryLogStep {
+ // Process-unique step id.
+ int64 step_id = 1;
+
+ // Handle describing the feeds and fetches of the step.
+ string handle = 2;
+};
+
+message MemoryLogTensorAllocation {
+ // Process-unique step id.
+ int64 step_id = 1;
+
+ // Name of the kernel making the allocation as set in GraphDef,
+ // e.g., "affine2/weights/Assign".
+ string kernel_name = 2;
+
+ // Allocated tensor details.
+ TensorDescription tensor = 3;
+};
+
+message MemoryLogTensorDeallocation {
+ // Id of the tensor buffer being deallocated, used to match to a
+ // corresponding allocation.
+ int64 allocation_id = 1;
+
+ // Name of the allocator used.
+ string allocator_name = 2;
+};
+
+message MemoryLogTensorOutput {
+ // Process-unique step id.
+ int64 step_id = 1;
+
+ // Name of the kernel producing an output as set in GraphDef, e.g.,
+ // "affine2/weights/Assign".
+ string kernel_name = 2;
+
+ // Index of the output being set.
+ int32 index = 3;
+
+ // Output tensor details.
+ TensorDescription tensor = 4;
+}
+
+message MemoryLogRawAllocation {
+ // Process-unique step id.
+ int64 step_id = 1;
+
+ // Name of the operation making the allocation.
+ string operation = 2;
+
+ // Number of bytes in the allocation.
+ int64 num_bytes = 3;
+
+ // Address of the allocation.
+ uint64 ptr = 4;
+
+ // Id of the tensor buffer being allocated, used to match to a
+ // corresponding deallocation.
+ int64 allocation_id = 5;
+
+ // Name of the allocator used.
+ string allocator_name = 6;
+};
+
+message MemoryLogRawDeallocation {
+ // Process-unique step id.
+ int64 step_id = 1;
+
+ // Name of the operation making the deallocation.
+ string operation = 2;
+
+ // Id of the tensor buffer being deallocated, used to match to a
+ // corresponding allocation.
+ int64 allocation_id = 3;
+
+ // Name of the allocator used.
+ string allocator_name = 4;
+
+ // True if the deallocation is queued and will be performed later,
+ // e.g. for GPU lazy freeing of buffers.
+ bool deferred = 5;
+};
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index 8bc417a26e..0f08d391ac 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -30,6 +30,11 @@ limitations under the License.
namespace tensorflow {
+AttrSlice::AttrSlice(const NodeDef& node_def)
+ : ndef_(&node_def), attrs_(&ndef_->attr()) {}
+
+AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
+
string SummarizeNodeDef(const NodeDef& node_def) {
string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "[");
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 4ebfd9a16c..d70bb6dd37 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -56,11 +56,9 @@ void AddNodeAttr(const string& name, std::initializer_list<T> value,
class AttrSlice {
public:
- AttrSlice(const NodeDef& node_def) // NOLINT(runtime/explicit)
- : ndef_(&node_def),
- attrs_(&ndef_->attr()) {}
+ AttrSlice(const NodeDef& node_def); // NOLINT(runtime/explicit)
- explicit AttrSlice(const AttrValueMap* a) : attrs_(a) {}
+ explicit AttrSlice(const AttrValueMap* a);
// Returns the attr with attr_name if found. Otherwise, returns
// nullptr.
@@ -71,7 +69,7 @@ class AttrSlice {
Status Find(const string& attr_name, const AttrValue** attr_value) const;
private:
- const NodeDef* ndef_ = nullptr;
+ const NodeDef* ndef_;
const AttrValueMap* attrs_;
};
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 4392dbe5ff..e78c3c9877 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_util.h"
@@ -137,6 +138,10 @@ Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) {
// OpKernelConstruction ------------------------------------------------------
+void OpKernelConstruction::SetStatus(const Status& status) {
+ status_->Update(status);
+}
+
Status OpKernelConstruction::MatchSignature(
const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) {
return MatchSignatureHelper(expected_inputs, expected_outputs, input_types_,
@@ -146,12 +151,18 @@ Status OpKernelConstruction::MatchSignature(
Status OpKernelConstruction::allocate_temp(DataType type,
const TensorShape& shape,
Tensor* out_temp) {
- Tensor new_temp(allocator_, type, shape);
+ AllocationAttributes attr;
+ attr.allocation_will_be_logged = true;
+ Tensor new_temp(allocator_, type, shape, attr);
if (!new_temp.IsInitialized() && shape.num_elements() > 0) {
return errors::ResourceExhausted(
"OOM when allocating temporary tensor with shape", shape.DebugString());
}
+ if (LogMemory::IsEnabled()) {
+ LogMemory::RecordTensorAllocation(
+ def_->name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
+ }
*out_temp = new_temp;
return Status::OK();
}
@@ -174,10 +185,6 @@ Status OpKernelConstruction::allocate_persistent(
return s;
}
-void OpKernelConstruction::SetStatus(const Status& status) {
- status_->Update(status);
-}
-
// OpKernelContext -----------------------------------------------------------
OpKernelContext::OpKernelContext(Params* params)
@@ -186,7 +193,7 @@ OpKernelContext::OpKernelContext(Params* params, int noutputs)
: params_(params), outputs_(noutputs) {
Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
params_->ensure_eigen_gpu_device();
- params_->device->ReinitializeGpuDevice(params_->eigen_gpu_device,
+ params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device,
params_->op_device_context,
eigen_gpu_allocator);
record_tensor_accesses_ = params_->device->RequiresRecordingAccessedTensors();
@@ -223,6 +230,12 @@ void OpKernelContext::SetStatus(const Status& status) {
status_.Update(status);
}
+void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) {
+ mutex_lock l(mu_);
+ // Keep a reference to the underlying memory around.
+ referenced_tensors_.Add(tensor);
+}
+
Status OpKernelContext::input(const string& name, const Tensor** tensor) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
@@ -253,6 +266,69 @@ Status OpKernelContext::input_ref_mutex(const string& name, mutex** out_mutex) {
return Status::OK();
}
+const Tensor& OpKernelContext::input(int index) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_->inputs->size());
+ DCHECK(!(*params_->inputs)[index].is_ref());
+ const Tensor& tensor = *((*params_->inputs)[index].tensor);
+ record_tensor_reference(tensor);
+ return tensor;
+}
+
+Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_->inputs->size());
+ DCHECK((*params_->inputs)[index].is_ref());
+ // return a copy of the Ref acquired while holding the mutex
+ if (lock_held) {
+ Tensor& tensor = *((*params_->inputs)[index].tensor);
+ record_tensor_reference(tensor);
+ return tensor;
+ } else {
+ mutex_lock l(*input_ref_mutex(index));
+ Tensor& tensor = *((*params_->inputs)[index].tensor);
+ record_tensor_reference(tensor);
+ return tensor;
+ }
+}
+
+void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
+ bool lock_held) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_->inputs->size());
+ DCHECK((*params_->inputs)[index].is_ref());
+ // should only modify the tensor while holding the mutex
+ if (lock_held) {
+ *(*params_->inputs)[index].tensor = tensor;
+ } else {
+ mutex_lock l(*input_ref_mutex(index));
+ *(*params_->inputs)[index].tensor = tensor;
+ }
+ record_tensor_reference(tensor);
+}
+
+void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
+ int output_index) {
+ DCHECK_GE(input_index, 0);
+ DCHECK_LT(input_index, params_->inputs->size());
+ DCHECK((*params_->inputs)[input_index].is_ref());
+ set_output_ref(output_index, (*params_->inputs)[input_index].mutex_if_ref,
+ (*params_->inputs)[input_index].tensor);
+}
+
+void OpKernelContext::delete_ref_input(int index, bool lock_held) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_->inputs->size());
+ DCHECK((*params_->inputs)[index].is_ref());
+ // should only modify the tensor while holding the mutex
+ if (lock_held) {
+ delete (*params_->inputs)[index].tensor;
+ } else {
+ mutex_lock l(*input_ref_mutex(index));
+ delete (*params_->inputs)[index].tensor;
+ }
+}
+
Status OpKernelContext::mutable_input(const string& name, Tensor* tensor,
bool lock_held) {
int start, stop;
@@ -317,6 +393,14 @@ Status OpKernelContext::output_list(const string& name, OpOutputList* list) {
return Status::OK();
}
+Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
+ Tensor** output) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, num_outputs());
+ AllocatorAttributes attr = output_alloc_attr(index);
+ return allocate_output(index, shape, output, attr);
+}
+
Status OpKernelContext::allocate_output(const string& name,
const TensorShape& shape,
Tensor** tensor) {
@@ -346,6 +430,70 @@ Status OpKernelContext::allocate_output(const string& name,
return allocate_output(start, shape, tensor, attr);
}
+Status OpKernelContext::allocate_tensor(
+ DataType type, const TensorShape& shape, Tensor* out_tensor,
+ AllocatorAttributes attr, const AllocationAttributes& allocation_attr) {
+ Allocator* a = get_allocator(attr);
+ AllocationAttributes logged_attr(allocation_attr);
+ logged_attr.allocation_will_be_logged = true;
+ Tensor new_tensor(a, type, shape, logged_attr);
+
+ if (!new_tensor.IsInitialized() && shape.num_elements() > 0) {
+    return errors::ResourceExhausted("OOM when allocating tensor with shape ",
+                                     shape.DebugString());
+ }
+ if (LogMemory::IsEnabled()) {
+ LogMemory::RecordTensorAllocation(params_->op_kernel->name(),
+ params_->step_id, new_tensor);
+ }
+ *out_tensor = new_tensor;
+ record_tensor_reference(new_tensor);
+ return Status::OK();
+}
+
+Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
+ Tensor** output,
+ AllocatorAttributes attr) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, outputs_.size());
+ const DataType type = params_->op_kernel->output_type(index);
+ DCHECK(!IsRefType(type));
+ DCHECK(mutable_output(index) == nullptr);
+ Tensor* output_tensor = new Tensor();
+ Status s = allocate_tensor(type, shape, output_tensor, attr);
+ if (s.ok()) {
+ outputs_[index] = TensorValue(output_tensor);
+ *output = outputs_[index].tensor;
+ }
+ return s;
+}
+
+Status OpKernelContext::allocate_temp(
+ DataType type, const TensorShape& shape, Tensor* out_temp,
+ AllocatorAttributes allocator_attr,
+ const AllocationAttributes& allocation_attr) {
+ Status s =
+ allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr);
+ return s;
+}
+
+Status OpKernelContext::allocate_persistent(DataType type,
+ const TensorShape& shape,
+ PersistentTensor* out_persistent,
+ Tensor** out_tensor,
+ AllocatorAttributes attr) {
+ // TODO(misard) add specific memory tracking for persistent tensors
+ Tensor persistent;
+ Status s = allocate_tensor(type, shape, &persistent, attr);
+ if (s.ok()) {
+ *out_persistent = PersistentTensor(persistent);
+ if (out_tensor) {
+ *out_tensor = out_persistent->AccessTensor(this);
+ }
+ }
+ return s;
+}
+
Status OpKernelContext::set_output(const string& name, const Tensor& tensor) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
@@ -359,6 +507,24 @@ Status OpKernelContext::set_output(const string& name, const Tensor& tensor) {
return Status::OK();
}
+void OpKernelContext::set_output(int index, const Tensor& tensor) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, outputs_.size());
+ DCHECK(!IsRefType(params_->op_kernel->output_type(index)));
+ DCHECK_EQ(mutable_output(index), nullptr);
+ record_tensor_reference(tensor);
+ outputs_[index] = TensorValue(new Tensor(tensor));
+}
+
+void OpKernelContext::set_output_ref(int index, mutex* mu,
+ Tensor* tensor_for_ref) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, outputs_.size());
+ DCHECK(IsRefType(params_->op_kernel->output_type(index)));
+ record_tensor_reference(*tensor_for_ref);
+ outputs_[index] = TensorValue(mu, tensor_for_ref);
+}
+
Status OpKernelContext::set_output_ref(const string& name, mutex* mu,
Tensor* tensor_for_ref) {
int start, stop;
@@ -717,4 +883,24 @@ const Eigen::GpuDevice& OpKernelContext::eigen_device() const {
return eigen_gpu_device();
}
+void OpKernelConstruction::CtxFailure(Status s) {
+ VLOG(1) << s;
+ SetStatus(s);
+}
+
+void OpKernelConstruction::CtxFailureWithWarning(Status s) {
+ LOG(WARNING) << s;
+ SetStatus(s);
+}
+
+void OpKernelContext::CtxFailure(Status s) {
+ VLOG(1) << s;
+ SetStatus(s);
+}
+
+void OpKernelContext::CtxFailureWithWarning(Status s) {
+ LOG(WARNING) << s;
+ SetStatus(s);
+}
+
} // namespace tensorflow
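A minimal kernel sketch showing the intended use of the allocation entry
points defined above (the op name and float dtype are hypothetical; this is
an illustration, not code from this change):

class IdentityLikeOp : public OpKernel {
 public:
  explicit IdentityLikeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);  // non-ref input accessor
    Tensor* output = nullptr;
    // allocate_output records the allocation against this kernel and step
    // when memory logging is enabled.
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
    output->flat<float>() = input.flat<float>();  // elementwise copy
  }
};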
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index eb33c5ea4e..f3aecf0b96 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -294,9 +294,7 @@ class OpKernelConstruction {
// attr with attr_name is found in def(), or the attr does not have
// a matching type, a non-ok status will be returned.
template <class T>
- Status GetAttr(const string& attr_name, T* value) const {
- return GetNodeAttr(def(), attr_name, value);
- }
+ Status GetAttr(const string& attr_name, T* value) const;
// May be used, e.g., to get GPU handles, etc.
// TODO(tucker): Add example usage.
@@ -313,6 +311,10 @@ class OpKernelConstruction {
// The GraphDef version whose behavior we should follow.
const int graph_def_version() const { return graph_def_version_; }
+ // Helper routines for the OP_REQUIRES macros
+ void CtxFailure(Status s);
+ void CtxFailureWithWarning(Status s);
+
private:
const DeviceType device_type_;
DeviceBase* const device_;
@@ -911,6 +913,10 @@ class OpKernelContext {
return params_->step_resource_manager;
}
+ // Helper routines for the OP_REQUIRES macros
+ void CtxFailure(Status s);
+ void CtxFailureWithWarning(Status s);
+
private:
Allocator* get_allocator(AllocatorAttributes attr);
@@ -919,6 +925,7 @@ class OpKernelContext {
// accurately track the memory that may not be reused until the Op
// execution completes.
void record_tensor_reference(const Tensor& tensor);
+ void really_record_tensor_reference(const Tensor& tensor);
// Internal common method used when allocating tensor memory
Status allocate_tensor(DataType type, const TensorShape& shape,
@@ -1053,6 +1060,11 @@ class OpKernelRegistrar {
// -----------------------------------------------------------------------------
// Template and inline method implementations, please ignore
+template <class T>
+Status OpKernelConstruction::GetAttr(const string& attr_name, T* value) const {
+ return GetNodeAttr(def(), attr_name, value);
+}
+
inline DataType OpKernelContext::input_dtype(int index) const {
DCHECK_GE(index, 0);
DCHECK_LT(index, params_->inputs->size());
@@ -1074,9 +1086,7 @@ inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) {
DCHECK(params_->device->RequiresRecordingAccessedTensors() ==
record_tensor_accesses_);
if (record_tensor_accesses_) {
- mutex_lock l(mu_);
- // Keep a reference to the underlying memory around.
- referenced_tensors_.Add(tensor);
+ really_record_tensor_reference(tensor);
}
}
@@ -1088,69 +1098,6 @@ inline void OpKernelContext::retrieve_accessed_tensors(
}
}
-inline const Tensor& OpKernelContext::input(int index) {
- DCHECK_GE(index, 0);
- DCHECK_LT(index, params_->inputs->size());
- DCHECK(!(*params_->inputs)[index].is_ref());
- const Tensor& tensor = *((*params_->inputs)[index].tensor);
- record_tensor_reference(tensor);
- return tensor;
-}
-
-inline Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
- DCHECK_GE(index, 0);
- DCHECK_LT(index, params_->inputs->size());
- DCHECK((*params_->inputs)[index].is_ref());
- // return a copy of the Ref acquired while holding the mutex
- if (lock_held) {
- Tensor& tensor = *((*params_->inputs)[index].tensor);
- record_tensor_reference(tensor);
- return tensor;
- } else {
- mutex_lock l(*input_ref_mutex(index));
- Tensor& tensor = *((*params_->inputs)[index].tensor);
- record_tensor_reference(tensor);
- return tensor;
- }
-}
-
-inline void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
- bool lock_held) {
- DCHECK_GE(index, 0);
- DCHECK_LT(index, params_->inputs->size());
- DCHECK((*params_->inputs)[index].is_ref());
- // should only modify the tensor while holding the mutex
- if (lock_held) {
- *(*params_->inputs)[index].tensor = tensor;
- } else {
- mutex_lock l(*input_ref_mutex(index));
- *(*params_->inputs)[index].tensor = tensor;
- }
- record_tensor_reference(tensor);
-}
-
-inline void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
- int output_index) {
- DCHECK_GE(input_index, 0);
- DCHECK_LT(input_index, params_->inputs->size());
- DCHECK((*params_->inputs)[input_index].is_ref());
- set_output_ref(output_index, (*params_->inputs)[input_index].mutex_if_ref,
- (*params_->inputs)[input_index].tensor);
-}
-
-inline void OpKernelContext::delete_ref_input(int index, bool lock_held) {
- DCHECK_GE(index, 0);
- DCHECK_LT(index, params_->inputs->size());
- DCHECK((*params_->inputs)[index].is_ref());
- // should only modify the tensor while holding the mutex
- if (lock_held) {
- delete (*params_->inputs)[index].tensor;
- } else {
- mutex_lock l(*input_ref_mutex(index));
- delete (*params_->inputs)[index].tensor;
- }
-}
-
// no input if tensor == nullptr.
inline bool OpKernelContext::has_input(int index) const {
DCHECK_GE(index, 0);
@@ -1165,96 +1112,12 @@ inline mutex* OpKernelContext::input_ref_mutex(int index) {
return (*params_->inputs)[index].mutex_if_ref;
}
-inline Status OpKernelContext::allocate_output(int index,
- const TensorShape& shape,
- Tensor** output) {
- DCHECK_GE(index, 0);
- DCHECK_LT(index, num_outputs());
- AllocatorAttributes attr = output_alloc_attr(index);
- return allocate_output(index, shape, output, attr);
-}
-
-inline Status OpKernelContext::allocate_tensor(
- DataType type, const TensorShape& shape, Tensor* out_tensor,
- AllocatorAttributes attr, const AllocationAttributes& allocation_attr) {
- Allocator* a = get_allocator(attr);
- Tensor new_tensor(a, type, shape, allocation_attr);
-
- if (!new_tensor.IsInitialized() && shape.num_elements() > 0) {
- return errors::ResourceExhausted("OOM when allocating tensor with shape",
- shape.DebugString());
- }
- *out_tensor = new_tensor;
- record_tensor_reference(new_tensor);
- return Status::OK();
-}
-
-inline Status OpKernelContext::allocate_output(int index,
- const TensorShape& shape,
- Tensor** output,
- AllocatorAttributes attr) {
- DCHECK_GE(index, 0);
- DCHECK_LT(index, outputs_.size());
- const DataType type = params_->op_kernel->output_type(index);
- DCHECK(!IsRefType(type));
- DCHECK(mutable_output(index) == nullptr);
- Tensor* output_tensor = new Tensor();
- Status s = allocate_tensor(type, shape, output_tensor, attr);
- if (s.ok()) {
- outputs_[index] = TensorValue(output_tensor);
- *output = outputs_[index].tensor;
- }
- return s;
-}
-
-inline Status OpKernelContext::allocate_temp(
- DataType type, const TensorShape& shape, Tensor* out_temp,
- AllocatorAttributes allocator_attr,
- const AllocationAttributes& allocation_attr) {
- Status s =
- allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr);
- return s;
-}
-
-inline Status OpKernelContext::allocate_persistent(
- DataType type, const TensorShape& shape, PersistentTensor* out_persistent,
- Tensor** out_tensor, AllocatorAttributes attr) {
- // TODO(misard) add specific memory tracking for persistent tensors
- Tensor persistent;
- Status s = allocate_tensor(type, shape, &persistent, attr);
- if (s.ok()) {
- *out_persistent = PersistentTensor(persistent);
- if (out_tensor) {
- *out_tensor = out_persistent->AccessTensor(this);
- }
- }
- return s;
-}
-
inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
if (t.IsInitialized()) {
record_tensor_reference(t);
}
}
-inline void OpKernelContext::set_output(int index, const Tensor& tensor) {
- DCHECK_GE(index, 0);
- DCHECK_LT(index, outputs_.size());
- DCHECK(!IsRefType(params_->op_kernel->output_type(index)));
- DCHECK_EQ(mutable_output(index), nullptr);
- record_tensor_reference(tensor);
- outputs_[index] = TensorValue(new Tensor(tensor));
-}
-
-inline void OpKernelContext::set_output_ref(int index, mutex* mu,
- Tensor* tensor_for_ref) {
- DCHECK_GE(index, 0);
- DCHECK_LT(index, outputs_.size());
- DCHECK(IsRefType(params_->op_kernel->output_type(index)));
- record_tensor_reference(*tensor_for_ref);
- outputs_[index] = TensorValue(mu, tensor_for_ref);
-}
-
inline Tensor* OpKernelContext::mutable_output(int index) {
DCHECK_GE(index, 0);
DCHECK_LT(index, outputs_.size());
@@ -1342,6 +1205,66 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
ctx_->set_output_ref(i, mu, tensor_for_ref);
}
+// Convenience macros for asserting and handling exceptional conditions.
+// Analogous to the CHECK* macros provided by logging.h.
+//
+// Example use:
+// void Compute(OperationContext* context) {
+// OP_REQUIRES(context, context->num_inputs() == 2,
+// errors::InvalidArgument("FooOp requires 2 arguments"));
+// ...
+// Status status = SomeUncertainMethod();
+// OP_REQUIRES_OK(context, status);
+// ...
+// }
+
+// Declares an op deprecated, and illegal starting at GraphDef version VERSION
+#define OP_DEPRECATED(CTX, VERSION, NOTE) \
+ if ((CTX)->graph_def_version() >= (VERSION)) { \
+ ::tensorflow::Status _s(::tensorflow::errors::Unimplemented( \
+ "Op ", (CTX)->op_def().name(), \
+ " is not available in GraphDef version ", (CTX)->graph_def_version(), \
+ ". It has been removed in version ", (VERSION), ". ", (NOTE), ".")); \
+ (CTX)->CtxFailure(_s); \
+ return; \
+ } else { \
+ LOG(WARNING) << "Op is deprecated." \
+ << " It will cease to work in GraphDef version " << (VERSION) \
+ << ". " << (NOTE) << "."; \
+ }
+
+#define OP_REQUIRES(CTX, EXP, STATUS) \
+ if (!(EXP)) { \
+ (CTX)->CtxFailure((STATUS)); \
+ return; \
+ }
+
+#define OP_REQUIRES_OK(CTX, STATUS) \
+ do { \
+ ::tensorflow::Status _s(STATUS); \
+ if (!_s.ok()) { \
+ (CTX)->CtxFailureWithWarning(_s); \
+ return; \
+ } \
+ } while (0)
+
+#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \
+ if (!(EXP)) { \
+ (CTX)->CtxFailure((STATUS)); \
+ (CALLBACK)(); \
+ return; \
+ }
+
+#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK) \
+ do { \
+ ::tensorflow::Status _s(STATUS); \
+ if (!_s.ok()) { \
+ (CTX)->CtxFailureWithWarning(_s); \
+ (CALLBACK)(); \
+ return; \
+ } \
+ } while (0)
+
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
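The _ASYNC variants above additionally invoke the supplied callback before
returning. A hedged sketch of their use from an AsyncOpKernel (the op itself
is hypothetical):

class ExampleAsyncOp : public AsyncOpKernel {
 public:
  explicit ExampleAsyncOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    OP_REQUIRES_ASYNC(ctx, ctx->num_inputs() == 1,
                      errors::InvalidArgument("expected exactly one input"),
                      done);
    // ... enqueue the real work, then signal completion.
    done();
  }
};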
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 91b95a8a86..0592b9b286 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
@@ -60,6 +61,7 @@ class Buffer : public TensorBuffer {
int64 rb = size();
proto->set_requested_bytes(rb);
proto->set_allocator_name(alloc_->Name());
+ proto->set_ptr(reinterpret_cast<uintptr_t>(data_));
if (alloc_->TracksAllocationSizes()) {
int64 ab = alloc_->AllocatedSize(data_);
proto->set_allocated_bytes(ab);
@@ -257,6 +259,10 @@ Buffer<T>::Buffer(Allocator* a, int64 n,
template <typename T>
Buffer<T>::~Buffer() {
+ if (LogMemory::IsEnabled()) {
+ LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data_),
+ alloc_->Name());
+ }
alloc_->Deallocate<T>(data_, elem_);
}
@@ -402,6 +408,10 @@ Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
}
+ if (IsInitialized() && LogMemory::IsEnabled()) {
+ LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID,
+ *this);
+ }
}
Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
@@ -412,6 +422,11 @@ Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr));
}
+ if (!allocation_attr.allocation_will_be_logged && IsInitialized() &&
+ LogMemory::IsEnabled()) {
+ LogMemory::RecordTensorAllocation("Unknown (with attributes)",
+ LogMemory::UNKNOWN_STEP_ID, *this);
+ }
}
Tensor::Tensor(DataType type, const TensorShape& shape)
@@ -501,6 +516,11 @@ bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
set_dtype(proto.dtype());
UnrefIfNonNull(buf_);
buf_ = p;
+ // TODO(misard) add tracking of which kernels and steps are calling FromProto.
+ if (IsInitialized() && LogMemory::IsEnabled()) {
+ LogMemory::RecordTensorAllocation("Unknown (from Proto)",
+ LogMemory::UNKNOWN_STEP_ID, *this);
+ }
return true;
}
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 748375fb3e..a6143e95c0 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -43,20 +43,36 @@ class Tensor {
/// Default Tensor constructor. Creates a 1-dimension, 0-element float tensor.
Tensor();
- /// \brief Creates a Tensor of the given `type` and `shape`.
+ /// \brief Creates a Tensor of the given `type` and `shape`. If
+ /// LogMemory::IsEnabled() the allocation is logged as coming from
+ /// an unknown kernel and step. Calling the Tensor constructor
+ /// directly from within an Op is deprecated: use the
+ /// OpKernelConstruction/OpKernelContext allocate_* methods to
+ /// allocate a new tensor, which record the kernel and step.
///
/// The underlying buffer is allocated using a `CPUAllocator`.
Tensor(DataType type, const TensorShape& shape);
- /// \brief Creates a tensor with the input `type` and `shape`, using the
- /// allocator `a` to allocate the underlying buffer.
+ /// \brief Creates a tensor with the input `type` and `shape`, using
+ /// the allocator `a` to allocate the underlying buffer. If
+ /// LogMemory::IsEnabled() the allocation is logged as coming from
+ /// an unknown kernel and step. Calling the Tensor constructor
+ /// directly from within an Op is deprecated: use the
+ /// OpKernelConstruction/OpKernelContext allocate_* methods to
+ /// allocate a new tensor, which record the kernel and step.
///
/// `a` must outlive the lifetime of this Tensor.
Tensor(Allocator* a, DataType type, const TensorShape& shape);
- /// \brief Creates a tensor with the input `type` and `shape`, using the
- /// allocator `a` and the specified "allocation_attr" to allocate the
- /// underlying buffer.
+ /// \brief Creates a tensor with the input `type` and `shape`, using
+ /// the allocator `a` and the specified "allocation_attr" to
+ /// allocate the underlying buffer. If the kernel and step are known
+ /// allocation_attr.allocation_will_be_logged should be set to true
+ /// and LogMemory::RecordTensorAllocation should be called after the
+ /// tensor is constructed. Calling the Tensor constructor directly
+ /// from within an Op is deprecated: use the
+ /// OpKernelConstruction/OpKernelContext allocate_* methods to
+ /// allocate a new tensor, which record the kernel and step.
///
/// `a` must outlive the lifetime of this Tensor.
Tensor(Allocator* a, DataType type, const TensorShape& shape,
@@ -227,9 +243,7 @@ class Tensor {
///
/// ```
template <typename T>
- typename TTypes<T>::Flat flat() {
- return shaped<T, 1>({NumElements()});
- }
+ typename TTypes<T>::Flat flat();
template <typename T>
typename TTypes<T>::UnalignedFlat unaligned_flat() {
@@ -294,9 +308,7 @@ class Tensor {
typename TTypes<T, NDIMS>::ConstTensor tensor() const;
template <typename T>
- typename TTypes<T>::ConstFlat flat() const {
- return shaped<T, 1>({NumElements()});
- }
+ typename TTypes<T>::ConstFlat flat() const;
template <typename T>
typename TTypes<T>::UnalignedConstFlat unaligned_flat() const {
@@ -316,16 +328,7 @@ class Tensor {
}
template <typename T>
- typename TTypes<T>::ConstMatrix flat_outer_dims() const {
- int64 first_size = dims() > 0 ? dim_size(0) : 1;
- if (first_size == 0) {
- DCHECK_EQ(NumElements(), 0);
- // Return something empty, avoiding divide by 0
- return shaped<T, 2>({0, 0});
- } else {
- return shaped<T, 2>({first_size, NumElements() / first_size});
- }
- }
+ typename TTypes<T>::ConstMatrix flat_outer_dims() const;
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor shaped(
@@ -398,6 +401,11 @@ class Tensor {
template <typename T>
T* base() const;
+
+ template <size_t NDIMS>
+ void FillDimsAndValidateCompatibleShape(
+ Eigen::array<Eigen::DenseIndex, NDIMS>* dims,
+ gtl::ArraySlice<int64> new_sizes) const;
};
// Implementation details
@@ -477,19 +485,26 @@ typename TTypes<T, NDIMS>::UnalignedTensor Tensor::unaligned_shaped(
return typename TTypes<T, NDIMS>::UnalignedTensor(base<T>(), dims);
}
-template <typename T, size_t NDIMS>
-typename TTypes<T, NDIMS>::ConstTensor Tensor::shaped(
+template <size_t NDIMS>
+void Tensor::FillDimsAndValidateCompatibleShape(
+ Eigen::array<Eigen::DenseIndex, NDIMS>* dims,
gtl::ArraySlice<int64> new_sizes) const {
- CHECK(IsAligned());
- CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
CHECK_EQ(NDIMS, new_sizes.size());
int64 new_num_elements = 1;
- Eigen::array<Eigen::DenseIndex, NDIMS> dims;
for (size_t d = 0; d < NDIMS; d++) {
new_num_elements *= new_sizes[d];
- dims[d] = new_sizes[d];
+ (*dims)[d] = new_sizes[d];
}
CHECK_EQ(new_num_elements, NumElements());
+}
+
+template <typename T, size_t NDIMS>
+typename TTypes<T, NDIMS>::ConstTensor Tensor::shaped(
+ gtl::ArraySlice<int64> new_sizes) const {
+ CHECK(IsAligned());
+ CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
+ Eigen::array<Eigen::DenseIndex, NDIMS> dims;
+ FillDimsAndValidateCompatibleShape(&dims, new_sizes);
return typename TTypes<T, NDIMS>::ConstTensor(base<T>(), dims);
}
@@ -497,14 +512,8 @@ template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::UnalignedConstTensor Tensor::unaligned_shaped(
gtl::ArraySlice<int64> new_sizes) const {
CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
- CHECK_EQ(NDIMS, new_sizes.size());
- int64 new_num_elements = 1;
Eigen::array<Eigen::DenseIndex, NDIMS> dims;
- for (size_t d = 0; d < NDIMS; d++) {
- new_num_elements *= new_sizes[d];
- dims[d] = new_sizes[d];
- }
- CHECK_EQ(new_num_elements, NumElements());
+ FillDimsAndValidateCompatibleShape(&dims, new_sizes);
return typename TTypes<T, NDIMS>::UnalignedConstTensor(base<T>(), dims);
}
@@ -522,6 +531,28 @@ typename TTypes<T>::ConstScalar Tensor::scalar() const {
return typename TTypes<T>::ConstScalar(base<T>());
}
+template <typename T>
+typename TTypes<T>::Flat Tensor::flat() {
+ return shaped<T, 1>({NumElements()});
+}
+
+template <typename T>
+typename TTypes<T>::ConstFlat Tensor::flat() const {
+ return shaped<T, 1>({NumElements()});
+}
+
+template <typename T>
+typename TTypes<T>::ConstMatrix Tensor::flat_outer_dims() const {
+ int64 first_size = dims() > 0 ? dim_size(0) : 1;
+ if (first_size == 0) {
+ DCHECK_EQ(NumElements(), 0);
+ // Return something empty, avoiding divide by 0
+ return shaped<T, 2>({0, 0});
+ } else {
+ return shaped<T, 2>({first_size, NumElements() / first_size});
+ }
+}
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_
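A short worked example of the views defined above, with hypothetical shapes
(outside an Op, where constructing a Tensor directly is still appropriate):

Tensor t(DT_FLOAT, TensorShape({3, 4, 5}));
auto vec = t.flat<float>();             // 1-d view with 60 elements
auto mat = t.flat_outer_dims<float>();  // 2-d view of shape {3, 20}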
diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc
index f608cc815d..9842b8ed3c 100644
--- a/tensorflow/core/framework/tensor_shape.cc
+++ b/tensorflow/core/framework/tensor_shape.cc
@@ -85,6 +85,11 @@ TensorShape::TensorShape() {
num_elements_ = 1;
}
+void TensorShape::DestructorOutOfLine() {
+ DCHECK(tag() == REP_OUT_OF_LINE);
+ delete as64()->dims_;
+}
+
void TensorShape::SlowCopyFrom(const TensorShape& b) {
if (b.tag() != REP_OUT_OF_LINE) {
if (tag() == REP_OUT_OF_LINE) {
diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h
index ab4979c902..09639faafd 100644
--- a/tensorflow/core/framework/tensor_shape.h
+++ b/tensorflow/core/framework/tensor_shape.h
@@ -137,6 +137,7 @@ class TensorShape {
void DumpRep() const; // XXX
private:
+ void DestructorOutOfLine();
void ClearAllButDataType();
void SlowCopyFrom(const TensorShape& b);
@@ -323,7 +324,7 @@ inline TensorShape::TensorShape(const TensorShape& b) {
inline TensorShape::~TensorShape() {
if (tag() == REP_OUT_OF_LINE) {
- delete as64()->dims_;
+ DestructorOutOfLine();
}
}
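The same inline-fast-path / out-of-line-slow-path idiom, sketched on a
hypothetical class (not TensorFlow code) to show why it keeps each inlined
destructor small:

class InlineOrHeapBuffer {
 public:
  ~InlineOrHeapBuffer() {
    if (heap_ != nullptr) FreeHeap();  // rare path, kept out of line
  }

 private:
  void FreeHeap();  // defined in the .cc file; deletes heap_
  int inline_storage_[4];
  int* heap_ = nullptr;
};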
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index e0a71a0856..8de02cc2ca 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -28,6 +28,12 @@ NodeBuilder::NodeOut::NodeOut(Node* n, int i) // NOLINT(runtime/explicit)
index(i),
dt(SafeGetOutput(node, i, &error)) {}
+NodeBuilder::NodeOut::NodeOut(const string& name, int i, DataType t)
+ : node(nullptr), error(false), name(name), index(i), dt(t) {}
+
+NodeBuilder::NodeOut::NodeOut()
+ : node(nullptr), error(true), index(0), dt(DT_FLOAT) {}
+
NodeBuilder::NodeBuilder(const string& name, const string& op_name,
const OpRegistryInterface* op_registry)
: def_builder_(name, op_name, op_registry) {}
diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h
index 51f00f6449..147311f320 100644
--- a/tensorflow/core/graph/node_builder.h
+++ b/tensorflow/core/graph/node_builder.h
@@ -54,21 +54,20 @@ class NodeBuilder {
// useful when preparing a graph for ExtendSession or creating a
// back edge to a node that hasn't been added to the graph yet,
// but will be.
- NodeOut(const string& name, int i, DataType t)
- : node(nullptr), error(false), name(name), index(i), dt(t) {}
+ NodeOut(const string& name, int i, DataType t);
// Default constructor for std::vector<NodeOut>.
- NodeOut() {}
+ NodeOut();
- Node* node = nullptr;
+ Node* node;
// error is set to true if:
// * the NodeOut was default constructed and never overwritten,
// * a nullptr Node* was passed to the NodeOut constructor, or
// * an out-of-range index was passed to the NodeOut constructor.
- bool error = true;
+ bool error;
string name;
- int index = 0;
- DataType dt = DT_FLOAT;
+ int index;
+ DataType dt;
};
// Specify the name and the Op (either via an OpDef or the name of
@@ -139,7 +138,7 @@ class NodeBuilder {
// IMPLEMENTATION -------------------------------------------------------------
template <class T>
-inline NodeBuilder& NodeBuilder::Attr(const string& attr_name, T&& value) {
+NodeBuilder& NodeBuilder::Attr(const string& attr_name, T&& value) {
def_builder_.Attr(attr_name, std::forward<T>(value));
return *this;
}
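An illustrative use of the by-name NodeOut constructor declared above, e.g.
for a back edge to a node that will only be added to the graph later (the
node names, dtype, and the surrounding Graph* are hypothetical):

Node* id_node = nullptr;
Status status =
    NodeBuilder("loop/identity", "Identity")
        .Input(NodeBuilder::NodeOut("loop/next_iteration", 0, DT_FLOAT))
        .Finalize(graph, &id_node);  // 'graph' is assumed to exist already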
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 0e63c5a5ea..77fb9af058 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1044,6 +1044,7 @@ filegroup(
srcs = [
"argmax_op.h",
"avgpooling_op.h",
+ "batch_norm_op.h",
"control_flow_ops.h",
"conv_2d.h",
"maxpooling_op.h",
@@ -1067,6 +1068,7 @@ filegroup(
srcs = [
"argmax_op.cc",
"avgpooling_op.cc",
+ "batch_norm_op.cc",
"bcast_ops.cc",
"control_flow_ops.cc",
"conv_2d.h",
diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc
index e6b44d42bc..8aa933330b 100644
--- a/tensorflow/core/kernels/cwise_ops_common.cc
+++ b/tensorflow/core/kernels/cwise_ops_common.cc
@@ -44,16 +44,22 @@ static TensorShape ToShape(const BCast::Vec& vec) {
}
BinaryOpShared::BinaryOpState::BinaryOpState(OpKernelContext* ctx)
- : bcast(FromShape(ctx->input(0).shape()),
- FromShape(ctx->input(1).shape())) {
+ : in0(ctx->input(0)),
+ in1(ctx->input(1)),
+ bcast(FromShape(in0.shape()), FromShape(in1.shape())) {
if (!bcast.IsValid()) {
- ctx->SetStatus(errors::InvalidArgument(
- "Incompatible shapes: ", ctx->input(0).shape().DebugString(), " vs. ",
- ctx->input(1).shape().DebugString()));
+ ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ",
+ in0.shape().DebugString(), " vs. ",
+ in1.shape().DebugString()));
return;
}
OP_REQUIRES_OK(ctx,
ctx->allocate_output(0, ToShape(bcast.output_shape()), &out));
+ out_num_elements = out->NumElements();
+ in0_num_elements = in0.NumElements();
+ in1_num_elements = in1.NumElements();
+
+ ndims = bcast.x_reshape().size();
}
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index 531fea9501..3771d58e13 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -45,8 +45,17 @@ class BinaryOpShared : public OpKernel {
// If ctx->status().ok() is true, then out is guaranteed to be allocated.
BinaryOpState(OpKernelContext* ctx);
+ const Tensor& in0;
+ const Tensor& in1;
+
BCast bcast;
Tensor* out = nullptr;
+ int64 out_num_elements;
+
+ int64 in0_num_elements;
+ int64 in1_num_elements;
+
+ int ndims;
};
template <int NDIMS>
@@ -74,42 +83,41 @@ class BinaryOp : public BinaryOpShared {
DataTypeToEnum<Tin>::v()) {}
void Compute(OpKernelContext* ctx) override {
- const Tensor& in0 = ctx->input(0);
- const Tensor& in1 = ctx->input(1);
// 'state': Shared helper not dependent on T to reduce code size
BinaryOpState state(ctx);
if (!ctx->status().ok()) return;
Tensor* out = state.out;
BCast* bcast = &state.bcast;
- if (out->NumElements() == 0) {
+ auto& in0 = state.in0;
+ auto& in1 = state.in1;
+ if (state.out_num_elements == 0) {
return;
}
- const int ndims = bcast->x_reshape().size();
+ const int ndims = state.ndims;
+ const Device& eigen_device = ctx->eigen_device<Device>();
if (ndims <= 1) {
- if (in1.NumElements() == 1) {
+ auto out_flat = out->flat<Tout>();
+ if (state.in1_num_elements == 1) {
// tensor op scalar
functor::BinaryFunctor<Device, Functor, 1>().Right(
- ctx->eigen_device<Device>(), out->flat<Tout>(), in0.flat<Tin>(),
- in1.scalar<Tin>());
+ eigen_device, out_flat, in0.flat<Tin>(), in1.scalar<Tin>());
return;
}
- if (in0.NumElements() == 1) {
+ auto in1_flat = in1.flat<Tin>();
+ if (state.in0_num_elements == 1) {
// scalar op tensor
functor::BinaryFunctor<Device, Functor, 1>().Left(
- ctx->eigen_device<Device>(), out->flat<Tout>(), in0.scalar<Tin>(),
- in1.flat<Tin>());
+ eigen_device, out_flat, in0.scalar<Tin>(), in1_flat);
return;
}
- functor::BinaryFunctor<Device, Functor, 1>()(
- ctx->eigen_device<Device>(), out->flat<Tout>(), in0.flat<Tin>(),
- in1.flat<Tin>());
+ functor::BinaryFunctor<Device, Functor, 1>()(eigen_device, out_flat,
+ in0.flat<Tin>(), in1_flat);
return;
}
if (ndims == 2) {
functor::BinaryFunctor<Device, Functor, 2>().BCast(
- ctx->eigen_device<Device>(),
- out->shaped<Tout, 2>(bcast->result_shape()),
+ eigen_device, out->shaped<Tout, 2>(bcast->result_shape()),
in0.shaped<Tin, 2>(bcast->x_reshape()),
ToIndexArray<2>(bcast->x_bcast()),
in1.shaped<Tin, 2>(bcast->y_reshape()),
@@ -119,8 +127,7 @@ class BinaryOp : public BinaryOpShared {
if (ndims == 3) {
functor::BinaryFunctor<Device, Functor, 3>().BCast(
- ctx->eigen_device<Device>(),
- out->shaped<Tout, 3>(bcast->result_shape()),
+ eigen_device, out->shaped<Tout, 3>(bcast->result_shape()),
in0.shaped<Tin, 3>(bcast->x_reshape()),
ToIndexArray<3>(bcast->x_bcast()),
in1.shaped<Tin, 3>(bcast->y_reshape()),
@@ -163,6 +170,11 @@ namespace functor {
// For CPUDevice, we do operations inline if the resulting tensor is
// modestly sized.
+//
+// NOTE(jeff): Changing DoInline to 'return false' gives significant code
+// size benefits, but hurts CPU performance considerably (performance
+// on ptb_word_lm drops from 3568 wps to 1922 wps, but code size for
+// tensorflow code in the binary drops by 3.3%).
static bool DoInline(size_t size) { return size <= 32768; }
template <typename D, typename OUT, typename RHS>
diff --git a/tensorflow/core/kernels/ops_util.cc b/tensorflow/core/kernels/ops_util.cc
index 0ad9fff077..64955ab0b7 100644
--- a/tensorflow/core/kernels/ops_util.cc
+++ b/tensorflow/core/kernels/ops_util.cc
@@ -22,15 +22,6 @@ limitations under the License.
namespace tensorflow {
-void RequireDefaultOps() {
-// TODO(opensource): Use a more generic sounding preprocessor name than
-// GOOGLE_CUDA (maybe SUPPORT_CUDA?)
-#if GOOGLE_CUDA
- void RequireGPUDevice();
- RequireGPUDevice();
-#endif
-}
-
Status Get2dOutputSize(const int in_height, const int in_width,
int filter_height, int filter_width, int row_stride,
int col_stride, Padding padding, int* new_height,
diff --git a/tensorflow/core/kernels/ops_util.h b/tensorflow/core/kernels/ops_util.h
index 4f57cf760e..f27a5bc423 100644
--- a/tensorflow/core/kernels/ops_util.h
+++ b/tensorflow/core/kernels/ops_util.h
@@ -25,12 +25,6 @@ limitations under the License.
namespace tensorflow {
-// Call this function from a test if op kernels are not being
-// registered. This can happen if the test is linked in a shared
-// mode and has no direct references to any code from this directory.
-// TODO(josh11b): Delete this, should no longer be needed.
-void RequireDefaultOps();
-
// Get2dOutputSize(): Given an input tensor, kernel, stride and padding
// type, the function computes the output and padding dimensions.
//
diff --git a/tensorflow/core/kernels/padding_fifo_queue.cc b/tensorflow/core/kernels/padding_fifo_queue.cc
index b8acfb3b20..f660ede290 100644
--- a/tensorflow/core/kernels/padding_fifo_queue.cc
+++ b/tensorflow/core/kernels/padding_fifo_queue.cc
@@ -266,9 +266,8 @@ Status PaddingFIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
return Status::OK();
}
-template <typename T, int NDIMS>
-Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
- int index) {
+static Status ValidateElementToLargerSlice(const Tensor& element,
+ Tensor* parent) {
DCHECK_NE(parent->dim_size(0), 0);
if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
TensorShape chip_shape = parent->shape();
@@ -279,6 +278,16 @@ Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
"Shapes are: [element]: ", element.shape().DebugString(),
", [parent slice]: ", chip_shape.DebugString());
}
+ return Status::OK();
+}
+
+template <typename T, int NDIMS>
+Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
+ int index) {
+ Status s = ValidateElementToLargerSlice(element, parent);
+ if (!s.ok()) {
+ return s;
+ }
auto element_t = element.tensor<T, NDIMS>();
auto parent_t = parent->tensor<T, NDIMS + 1>();
Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
diff --git a/tensorflow/core/kernels/queue_base.cc b/tensorflow/core/kernels/queue_base.cc
index 24e18ae4da..cda2eaf26a 100644
--- a/tensorflow/core/kernels/queue_base.cc
+++ b/tensorflow/core/kernels/queue_base.cc
@@ -72,6 +72,8 @@ QueueBase::QueueBase(int32 capacity, const DataTypeVector& component_dtypes,
name_(name),
closed_(false) {}
+QueueBase::~QueueBase() {}
+
Status QueueBase::ValidateTupleCommon(const Tuple& tuple) const {
if (tuple.size() != static_cast<size_t>(num_components())) {
return errors::InvalidArgument(
diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h
index 8986b15f72..6297d03c88 100644
--- a/tensorflow/core/kernels/queue_base.h
+++ b/tensorflow/core/kernels/queue_base.h
@@ -120,7 +120,7 @@ class QueueBase : public QueueInterface {
// of the *_attempts_ queues.
void FlushUnlocked();
- ~QueueBase() override {}
+ ~QueueBase() override;
// Helpers for implementing MatchesNodeDef().
static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes);
diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc
index 8821e6d868..b70c9657b2 100644
--- a/tensorflow/core/kernels/relu_op.cc
+++ b/tensorflow/core/kernels/relu_op.cc
@@ -42,11 +42,28 @@ class ReluOp : public UnaryElementWiseOp<T, ReluOp<Device, T>> {
}
};
+// Out-of-line check to save code space (we have this code once, rather
+// than once for every NDIMS * NumTypes * Num_different_relu_variants
+// functions).
+static void ValidateSameSizeHelper(OpKernelContext* context, const Tensor& g,
+ const Tensor& a) {
+ OP_REQUIRES(context, a.IsSameSize(g),
+ errors::InvalidArgument("g and a must be the same size"));
+}
+static bool ValidateSameSize(OpKernelContext* context, const Tensor& g,
+ const Tensor& a) {
+ ValidateSameSizeHelper(context, g, a);
+ return context->status().ok();
+}
+
template <typename Device, typename T>
class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> {
public:
using BinaryElementWiseOp<T, ReluGradOp<Device, T>>::BinaryElementWiseOp;
+ void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+ const Tensor& a, Tensor* output);
+
// INPUTS:
// g (gradients): backpropagated gradients
// a (inputs): either the inputs that were passed to ReluOp(), or its
@@ -56,15 +73,21 @@ class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> {
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
Tensor* output) {
- OP_REQUIRES(context, a.IsSameSize(g),
- errors::InvalidArgument("g and a must be the same size"));
- functor::ReluGrad<Device, T> functor;
- functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
- output->flat<T>());
+ OperateNoTemplate(context, g, a, output);
}
};
template <typename Device, typename T>
+void ReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+ const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ if (!ValidateSameSize(context, g, a)) return;
+ functor::ReluGrad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+ output->flat<T>());
+}
+
+template <typename Device, typename T>
class Relu6Op : public UnaryElementWiseOp<T, Relu6Op<Device, T>> {
public:
using UnaryElementWiseOp<T, Relu6Op<Device, T>>::UnaryElementWiseOp;
@@ -81,6 +104,9 @@ class Relu6GradOp : public BinaryElementWiseOp<T, Relu6GradOp<Device, T>> {
public:
using BinaryElementWiseOp<T, Relu6GradOp<Device, T>>::BinaryElementWiseOp;
+ void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+ const Tensor& a, Tensor* output);
+
// INPUTS:
// g (gradients): backpropagated gradients
// a (inputs): inputs that were passed to Relu6Op()
@@ -89,15 +115,21 @@ class Relu6GradOp : public BinaryElementWiseOp<T, Relu6GradOp<Device, T>> {
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
Tensor* output) {
- OP_REQUIRES(context, a.IsSameSize(g),
- errors::InvalidArgument("g and a must be the same size"));
- functor::Relu6Grad<Device, T> functor;
- functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
- output->flat<T>());
+ OperateNoTemplate(context, g, a, output);
}
};
template <typename Device, typename T>
+void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+ const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ if (!ValidateSameSize(context, g, a)) return;
+ functor::Relu6Grad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+ output->flat<T>());
+}
+
+template <typename Device, typename T>
class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
public:
using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp;
@@ -114,6 +146,9 @@ class EluGradOp : public BinaryElementWiseOp<T, EluGradOp<Device, T>> {
public:
using BinaryElementWiseOp<T, EluGradOp<Device, T>>::BinaryElementWiseOp;
+ void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+ const Tensor& a, Tensor* output);
+
// INPUTS:
// g (gradients): backpropagated gradients
// a (outputs): outputs of the EluOp()
@@ -122,14 +157,20 @@ class EluGradOp : public BinaryElementWiseOp<T, EluGradOp<Device, T>> {
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
Tensor* output) {
- OP_REQUIRES(context, a.IsSameSize(g),
- errors::InvalidArgument("g and a must be the same size"));
- functor::EluGrad<Device, T> functor;
- functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
- output->flat<T>());
+ OperateNoTemplate(context, g, a, output);
}
};
+template <typename Device, typename T>
+void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+ const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ if (!ValidateSameSize(context, g, a)) return;
+ functor::EluGrad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+ output->flat<T>());
+}
+
#define REGISTER_RELU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc
index 518053cc0b..8bf83fe44f 100644
--- a/tensorflow/core/kernels/scatter_op.cc
+++ b/tensorflow/core/kernels/scatter_op.cc
@@ -58,6 +58,39 @@ struct Assign<scatter_op::UpdateOp::SUB> {
} // namespace
+// Check whether updates.shape = indices.shape + params.shape[1:]
+static bool ValidShapes(const Tensor& params, const Tensor& updates,
+ const Tensor& indices) {
+ if (updates.dims() != indices.dims() + params.dims() - 1) return false;
+ for (int d = 0; d < indices.dims(); d++) {
+ if (updates.dim_size(d) != indices.dim_size(d)) {
+ return false;
+ }
+ }
+ for (int d = 1; d < params.dims(); d++) {
+ if (params.dim_size(d) != updates.dim_size(d - 1 + indices.dims())) {
+ return false;
+ }
+ }
+ return true;
+}
+
+static void DoValidationChecking(OpKernelContext* c, const Tensor& params,
+ const Tensor& indices, const Tensor& updates) {
+ OP_REQUIRES(c, params.IsInitialized(),
+ errors::FailedPrecondition("Null ref for params"));
+ OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
+ errors::InvalidArgument("params must be at least 1-D, got shape ",
+ params.shape().DebugString()));
+ OP_REQUIRES(
+ c, ValidShapes(params, updates, indices),
+ errors::InvalidArgument(
+ "Must have updates.shape = indices.shape + params.shape[1:], got ",
+ "updates.shape ", updates.shape().DebugString(), ", indices.shape ",
+ indices.shape().DebugString(), ", params.shape ",
+ params.shape().DebugString()));
+}
+
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
class ScatterUpdateOp : public OpKernel {
public:
@@ -83,40 +116,12 @@ class ScatterUpdateOp : public OpKernel {
private:
bool use_exclusive_lock_;
- // Check whether updates.shape = indices.shape + params.shape[1:]
- static bool ValidShapes(const Tensor& params, const Tensor& updates,
- const Tensor& indices) {
- if (updates.dims() != indices.dims() + params.dims() - 1) return false;
- for (int d = 0; d < indices.dims(); d++) {
- if (updates.dim_size(d) != indices.dim_size(d)) {
- return false;
- }
- }
- for (int d = 1; d < params.dims(); d++) {
- if (params.dim_size(d) != updates.dim_size(d - 1 + indices.dims())) {
- return false;
- }
- }
- return true;
- }
-
void DoCompute(OpKernelContext* c) {
Tensor params = c->mutable_input(0, use_exclusive_lock_);
- OP_REQUIRES(c, params.IsInitialized(),
- errors::FailedPrecondition("Null ref for params"));
const Tensor& indices = c->input(1);
const Tensor& updates = c->input(2);
- OP_REQUIRES(
- c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
- errors::InvalidArgument("params must be at least 1-D, got shape ",
- params.shape().DebugString()));
- OP_REQUIRES(
- c, ValidShapes(params, updates, indices),
- errors::InvalidArgument(
- "Must have updates.shape = indices.shape + params.shape[1:], got ",
- "updates.shape ", updates.shape().DebugString(), ", indices.shape ",
- indices.shape().DebugString(), ", params.shape ",
- params.shape().DebugString()));
+ DoValidationChecking(c, params, indices, updates);
+ if (!c->status().ok()) return;
// Check that we have enough index space
const int64 N_big = indices.NumElements();
@@ -178,45 +183,38 @@ struct ScatterFunctor<CPUDevice, T, Index, op> {
};
} // namespace functor
-#define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
- REGISTER_KERNEL_BUILDER( \
- Name(name) \
- .Device(DEVICE_##dev) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<index_type>("Tindices"), \
- ScatterUpdateOp<dev##Device, type, index_type, op>)
+#define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
+ REGISTER_KERNEL_BUILDER(Name(name) \
+ .Device(DEVICE_##dev) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ScatterUpdateOp<dev##Device, type, index_type, op>)
-#define REGISTER_SCATTER_KERNEL(type, dev, name, op) \
- REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
+#define REGISTER_SCATTER_KERNEL(type, dev, name, op) \
+ REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
-#define REGISTER_SCATTER_ADD_SUB(type, dev) \
- REGISTER_SCATTER_KERNEL( \
- type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \
- REGISTER_SCATTER_KERNEL( \
- type, dev, "ScatterSub", scatter_op::UpdateOp::SUB);
+#define REGISTER_SCATTER_ADD_SUB(type, dev) \
+ REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB);
-#define REGISTER_SCATTER_UPDATE(type, dev) \
- REGISTER_SCATTER_KERNEL( \
- type, dev, "ScatterUpdate", scatter_op::UpdateOp::ASSIGN);
+#define REGISTER_SCATTER_UPDATE(type, dev) \
+ REGISTER_SCATTER_KERNEL(type, dev, "ScatterUpdate", \
+ scatter_op::UpdateOp::ASSIGN);
// Registers CPU kernels.
-#define REGISTER_SCATTER_ADD_SUB_CPU(type) \
- REGISTER_SCATTER_ADD_SUB(type, CPU);
+#define REGISTER_SCATTER_ADD_SUB_CPU(type) REGISTER_SCATTER_ADD_SUB(type, CPU);
-#define REGISTER_SCATTER_UPDATE_CPU(type) \
- REGISTER_SCATTER_UPDATE(type, CPU);
+#define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ADD_SUB_CPU);
TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU);
// Registers GPU kernels.
#if GOOGLE_CUDA
-#define REGISTER_SCATTER_ADD_SUB_GPU(type) \
- REGISTER_SCATTER_ADD_SUB(type, GPU);
+#define REGISTER_SCATTER_ADD_SUB_GPU(type) REGISTER_SCATTER_ADD_SUB(type, GPU);
-#define REGISTER_SCATTER_UPDATE_GPU(type) \
- REGISTER_SCATTER_UPDATE(type, GPU);
+#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ADD_SUB_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_GPU);
@@ -251,8 +249,8 @@ namespace functor {
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB);
-#define DECLARE_GPU_SPECS(T) \
- DECLARE_GPU_SPECS_INDEX(T, int32); \
+#define DECLARE_GPU_SPECS(T) \
+ DECLARE_GPU_SPECS_INDEX(T, int32); \
DECLARE_GPU_SPECS_INDEX(T, int64);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
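A worked example of the shape constraint checked by ValidShapes(), using
hypothetical shapes:

  params:  {8, 3}
  indices: {4, 2}
  updates: {4, 2, 3}   // = indices.shape {4, 2} followed by params.shape[1:] {3}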
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc
index 1a828f6994..8b672960d3 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops.cc
@@ -34,6 +34,26 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+// Static routines not in the templated class to reduce code size
+static void SegmentReductionValidationHelper(OpKernelContext* context,
+ const Tensor& input,
+ const Tensor& segment_ids) {
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
+ errors::InvalidArgument("segment_ids should be a vector."));
+ const int64 num_indices = segment_ids.NumElements();
+ OP_REQUIRES(context, num_indices == input.dim_size(0),
+ errors::InvalidArgument(
+ "segment_ids should be the same size as dimension 0 of"
+ " input."));
+}
+
+static bool SegmentReductionDoValidation(OpKernelContext* c,
+ const Tensor& input,
+ const Tensor& segment_ids) {
+ SegmentReductionValidationHelper(c, input, segment_ids);
+ return c->status().ok();
+}
+
// This operator handles reducing segments along the first dimension.
// See core/ops/math_ops.cc for more details.
template <typename Device, class T, class Index, typename Reducer>
@@ -46,14 +66,11 @@ class SegmentReductionOp : public OpKernel {
const Tensor& input = context->input(0);
const Tensor& segment_ids = context->input(1);
- OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
- errors::InvalidArgument("segment_ids should be a vector."));
- const int64 num_indices = segment_ids.NumElements();
- OP_REQUIRES(context, num_indices == input.dim_size(0),
- errors::InvalidArgument(
- "segment_ids should be the same size as dimension 0 of"
- " input."));
+ if (!SegmentReductionDoValidation(context, input, segment_ids)) {
+ return;
+ }
+ const int64 num_indices = segment_ids.NumElements();
auto input_flat = input.flat_outer_dims<T>();
const int64 num_col = input_flat.dimension(1);
@@ -99,7 +116,8 @@ class SegmentReductionOp : public OpKernel {
// Process segment [start, end)
const T* in_slice_ptr = &input_flat(start, 0);
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
- Eigen::Unaligned> OutT;
+ Eigen::Unaligned>
+ OutT;
T* out_slice_ptr = &output_flat(segment_vec(start), 0);
OutT out_slice(out_slice_ptr, out_slice_shape);
// We don't use out_slice.device(context->eigen_device<Device>)
@@ -108,14 +126,16 @@ class SegmentReductionOp : public OpKernel {
// using another thread to do this work.
if (start == end - 1) {
typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
- Eigen::Unaligned> InT;
+ Eigen::Unaligned>
+ InT;
InT in_slice(in_slice_ptr, out_slice_shape);
out_slice = in_slice;
} else {
Eigen::DSizes<Eigen::DenseIndex, 2> in_slice_shape(end - start,
num_col);
typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
- Eigen::Unaligned> InT;
+ Eigen::Unaligned>
+ InT;
InT in_slice(in_slice_ptr, in_slice_shape);
out_slice = in_slice.reduce(dims_to_reduce, Reducer());
diff --git a/tensorflow/core/kernels/softplus_op.cc b/tensorflow/core/kernels/softplus_op.cc
index 142f5a819c..d840f15aca 100644
--- a/tensorflow/core/kernels/softplus_op.cc
+++ b/tensorflow/core/kernels/softplus_op.cc
@@ -48,6 +48,9 @@ class SoftplusGradOp
public:
using BinaryElementWiseOp<T, SoftplusGradOp<Device, T>>::BinaryElementWiseOp;
+ void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+ const Tensor& a, Tensor* output);
+
// INPUTS:
// g (gradients): backpropagated gradients
// a (inputs): inputs that were passed to SoftplusOp()
@@ -56,13 +59,20 @@ class SoftplusGradOp
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
Tensor* output) {
- OP_REQUIRES(context, a.IsSameSize(g),
- errors::InvalidArgument("g and a must be the same size"));
- functor::SoftplusGrad<Device, T> functor;
- functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
- output->flat<T>());
+ OperateNoTemplate(context, g, a, output);
}
};
+template <typename Device, typename T>
+void SoftplusGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+ const Tensor& g,
+ const Tensor& a,
+ Tensor* output) {
+ OP_REQUIRES(context, a.IsSameSize(g),
+ errors::InvalidArgument("g and a must be the same size"));
+ functor::SoftplusGrad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+ output->flat<T>());
+}
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
diff --git a/tensorflow/core/kernels/softsign_op.cc b/tensorflow/core/kernels/softsign_op.cc
index 3a07983ede..60090d9a60 100644
--- a/tensorflow/core/kernels/softsign_op.cc
+++ b/tensorflow/core/kernels/softsign_op.cc
@@ -48,6 +48,9 @@ class SoftsignGradOp
public:
using BinaryElementWiseOp<T, SoftsignGradOp<Device, T>>::BinaryElementWiseOp;
+ void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+ const Tensor& a, Tensor* output);
+
// INPUTS:
// g (gradients): backpropagated gradients
// a (inputs): inputs that were passed to SoftsignOp()
@@ -56,14 +59,22 @@ class SoftsignGradOp
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
Tensor* output) {
- OP_REQUIRES(context, a.IsSameSize(g),
- errors::InvalidArgument("g and a must be the same size"));
- functor::SoftsignGrad<Device, T> functor;
- functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
- output->flat<T>());
+ OperateNoTemplate(context, g, a, output);
}
};
+template <typename Device, typename T>
+void SoftsignGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+ const Tensor& g,
+ const Tensor& a,
+ Tensor* output) {
+ OP_REQUIRES(context, a.IsSameSize(g),
+ errors::InvalidArgument("g and a must be the same size"));
+ functor::SoftsignGrad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+ output->flat<T>());
+}
+
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Softsign").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h
index d0f0a38d8c..fe6bd1c029 100644
--- a/tensorflow/core/lib/core/errors.h
+++ b/tensorflow/core/lib/core/errors.h
@@ -60,7 +60,7 @@ void AppendToMessage(::tensorflow::Status* status, Args... args) {
#define DECLARE_ERROR(FUNC, CONST) \
template <typename... Args> \
- inline ::tensorflow::Status FUNC(Args... args) { \
+ ::tensorflow::Status FUNC(Args... args) { \
return ::tensorflow::Status(::tensorflow::error::CONST, \
::tensorflow::strings::StrCat(args...)); \
} \
@@ -90,73 +90,6 @@ DECLARE_ERROR(Unauthenticated, UNAUTHENTICATED)
// The CanonicalCode() for non-errors.
using ::tensorflow::error::OK;
-// Convenience macros for asserting and handling exceptional conditions.
-// Analogous to the CHECK* macros provided by logging.h.
-//
-// Example use:
-// void Compute(OperationContext* context) {
-// OP_REQUIRES(context, context->num_inputs() == 2,
-// errors::InvalidArgument("FooOp requires 2 arguments"));
-// ...
-// Status status = SomeUncertainMethod();
-// OP_REQUIRES_OK(context, status);
-// ...
-// }
-
-// Declares an op deprecated, and illegal starting at GraphDef version VERSION
-#define OP_DEPRECATED(CTX, VERSION, NOTE) \
- if ((CTX)->graph_def_version() >= (VERSION)) { \
- ::tensorflow::Status _s(::tensorflow::errors::Unimplemented( \
- "Op ", (CTX)->op_def().name(), \
- " is not available in GraphDef version ", (CTX)->graph_def_version(), \
- ". It has been removed in version ", (VERSION), ". ", (NOTE), ".")); \
- VLOG(1) << _s; \
- (CTX)->SetStatus(_s); \
- return; \
- } else { \
- LOG(WARNING) << "Op is deprecated." \
- << " It will cease to work in GraphDef version " << (VERSION) \
- << ". " << (NOTE) << "."; \
- }
-
-#define OP_REQUIRES(CTX, EXP, STATUS) \
- if (!(EXP)) { \
- ::tensorflow::Status _s(STATUS); \
- VLOG(1) << _s; \
- (CTX)->SetStatus(_s); \
- return; \
- }
-
-#define OP_REQUIRES_OK(CTX, STATUS) \
- do { \
- ::tensorflow::Status _s(STATUS); \
- if (!_s.ok()) { \
- LOG(WARNING) << _s; \
- (CTX)->SetStatus(_s); \
- return; \
- } \
- } while (0)
-
-#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \
- if (!(EXP)) { \
- ::tensorflow::Status _s(STATUS); \
- VLOG(1) << _s; \
- (CTX)->SetStatus(_s); \
- (CALLBACK)(); \
- return; \
- }
-
-#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK) \
- do { \
- ::tensorflow::Status _s(STATUS); \
- if (!_s.ok()) { \
- LOG(WARNING) << _s; \
- (CTX)->SetStatus(_s); \
- (CALLBACK)(); \
- return; \
- } \
- } while (0)
-
} // namespace errors
} // namespace tensorflow
diff --git a/tensorflow/g3doc/api_docs/cc/ClassRandomAccessFile.md b/tensorflow/g3doc/api_docs/cc/ClassRandomAccessFile.md
index 1ff484c083..1a1526f66d 100644
--- a/tensorflow/g3doc/api_docs/cc/ClassRandomAccessFile.md
+++ b/tensorflow/g3doc/api_docs/cc/ClassRandomAccessFile.md
@@ -18,7 +18,7 @@ A file abstraction for randomly reading the contents of a file.
-#### `virtual Status tensorflow::RandomAccessFile::Read(uint64 offset, size_t n, StringPiece *result, char *scratch) const =0` {#virtual_Status_tensorflow_RandomAccessFile_Read}
+#### `virtual Status tensorflow::RandomAccessFile::Read(uint64 offset, size_t n, StringPiece *result, char *scratch) const =0` {#virtual_Status_tensorflow_RandomAccessFile_Read}
Reads up to `n` bytes from the file starting at `offset`.
diff --git a/tensorflow/g3doc/api_docs/cc/ClassSession.md b/tensorflow/g3doc/api_docs/cc/ClassSession.md
index c4a7e9bb02..d405727f6c 100644
--- a/tensorflow/g3doc/api_docs/cc/ClassSession.md
+++ b/tensorflow/g3doc/api_docs/cc/ClassSession.md
@@ -6,7 +6,36 @@ When a Session is created with a given target, a new Session object is bound to
Example:
-{c++} tensorflow::GraphDef graph; // ... Create or load graph into "graph". // This example uses the default options which connects // to a local runtime. tensorflow::SessionOptions options; std::unique_ptr<tensorflow::Session> session(tensorflow::NewSession(options)); // Create the session with this graph. tensorflow::Status s = session->Create(graph); if (!s.ok()) { ... } // Run the graph and fetch the first output of the "output" // operation, and also run to but do not return anything // for the "update_state" operation. std::vector<tensorflow::Tensor> outputs; s = session->Run({}, {"output:0"}, {"update_state"}, &outputs); if (!s.ok()) { ... } // Map the output as a flattened float tensor, and do something // with it. auto output_tensor = outputs[0].flat<float>(); if (output_tensor(0) > 0.5) { ... } // Close the session to release the resources associated with // this session. session->Close();
+```c++
+tensorflow::GraphDef graph;
+// ... Create or load graph into "graph".
+
+// This example uses the default options which connects
+// to a local runtime.
+tensorflow::SessionOptions options;
+std::unique_ptr<tensorflow::Session>
+session(tensorflow::NewSession(options));
+
+// Create the session with this graph.
+tensorflow::Status s = session->Create(graph);
+if (!s.ok()) { ... }
+
+// Run the graph and fetch the first output of the "output"
+// operation, and also run to but do not return anything
+// for the "update_state" operation.
+std::vector<tensorflow::Tensor> outputs;
+s = session->Run({}, {"output:0"}, {"update_state"}, &outputs);
+if (!s.ok()) { ... }
+
+// Map the output as a flattened float tensor, and do something
+// with it.
+auto output_tensor = outputs[0].flat<float>();
+if (output_tensor(0) > 0.5) { ... }
+
+// Close the session to release the resources associated with
+// this session.
+session->Close();
+
+```
A Session allows concurrent calls to Run() , though a Session must be created / extended by a single thread.
@@ -38,6 +67,12 @@ REQUIRES: The name of each Tensor of the input or output must match a "Tensor en
REQUIRES: outputs is not nullptr if `output_tensor_names` is non-empty.
+#### `virtual Status tensorflow::Session::RunWithOpts(const RunOptions &run_options, const std::vector< std::pair< string, Tensor > > &inputs, const std::vector< string > &output_tensor_names, const std::vector< string > &target_node_names, std::vector< Tensor > *outputs, RunOutputs *run_outputs)` {#virtual_Status_tensorflow_Session_RunWithOpts}
+
+Like `Run`, but allows users to pass in a `RunOptions` proto and to retrieve non-Tensor metadata output via a `RunOutputs` proto for this step. NOTE: This API is still experimental and may change.
+
+
+
#### `virtual Status tensorflow::Session::PRunSetup(const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_nodes, string *handle)` {#virtual_Status_tensorflow_Session_PRunSetup}
Sets up a graph for partial execution. All future feeds and fetches are specified by &apos;input_names&apos; and &apos;output_names&apos;. Returns &apos;handle&apos; that can be used to perform a sequence of partial feeds and fetches. NOTE: This API is still experimental and may change.
diff --git a/tensorflow/g3doc/api_docs/cc/ClassTensor.md b/tensorflow/g3doc/api_docs/cc/ClassTensor.md
index 812d2354c6..2708244c61 100644
--- a/tensorflow/g3doc/api_docs/cc/ClassTensor.md
+++ b/tensorflow/g3doc/api_docs/cc/ClassTensor.md
@@ -108,6 +108,12 @@ Returns the estimated memory usage of this tensor.
+#### `bool tensorflow::Tensor::IsAligned() const` {#bool_tensorflow_Tensor_IsAligned}
+
+Returns true iff this tensor is aligned.
+
+
+
#### `Tensor& tensorflow::Tensor::operator=(const Tensor &other)` {#Tensor_tensorflow_Tensor_operator_}
Assign operator. This tensor shares other&apos;s underlying storage.
@@ -162,7 +168,15 @@ Use these methods when you know the data type and the number of dimensions of th
Example:
-{c++} typedef float T; Tensor my_mat(...built with Shape{rows: 3, cols: 5}...); auto mat = my_mat.matrix<T>(); // 2D Eigen::Tensor, 3 x 5. auto mat = my_mat.tensor<T, 2>(); // 2D Eigen::Tensor, 3 x 5. auto vec = my_mat.vec<T>(); // CHECK fails as my_mat is 2D. auto vec = my_mat.tensor<T, 3>(); // CHECK fails as my_mat is 2D. auto mat = my_mat.matrix<int32>();// CHECK fails as type mismatch.
+```c++
+typedef float T;
+Tensor my_mat(...built with Shape{rows: 3, cols: 5}...);
+auto mat = my_mat.matrix<T>(); // 2D Eigen::Tensor, 3 x 5.
+auto mat = my_mat.tensor<T, 2>(); // 2D Eigen::Tensor, 3 x 5.
+auto vec = my_mat.vec<T>(); // CHECK fails as my_mat is 2D.
+auto vec = my_mat.tensor<T, 3>(); // CHECK fails as my_mat is 2D.
+auto mat = my_mat.matrix<int32>();// CHECK fails as type mismatch.
+
+```
#### `TTypes<T>::Matrix tensorflow::Tensor::matrix()` {#TTypes_T_Matrix_tensorflow_Tensor_matrix}
@@ -184,7 +198,22 @@ These methods allow you to access the data with the dimensions and sizes of your
Example:
-{c++} typedef float T; Tensor my_ten(...built with Shape{planes: 4, rows: 3, cols: 5}...); // 1D Eigen::Tensor, size 60: auto flat = my_ten.flat<T>(); // 2D Eigen::Tensor 12 x 5: auto inner = my_ten.flat_inner_dims<T>(); // 2D Eigen::Tensor 4 x 15: auto outer = my_ten.shaped<T, 2>({4, 15}); // CHECK fails, bad num elements: auto outer = my_ten.shaped<T, 2>({4, 8}); // 3D Eigen::Tensor 6 x 5 x 2: auto weird = my_ten.shaped<T, 3>({6, 5, 2}); // CHECK fails, type mismatch: auto bad = my_ten.flat<int32>();
+```c++
+typedef float T;
+Tensor my_ten(...built with Shape{planes: 4, rows: 3, cols: 5}...);
+// 1D Eigen::Tensor, size 60:
+auto flat = my_ten.flat<T>();
+// 2D Eigen::Tensor 12 x 5:
+auto inner = my_ten.flat_inner_dims<T>();
+// 2D Eigen::Tensor 4 x 15:
+auto outer = my_ten.shaped<T, 2>({4, 15});
+// CHECK fails, bad num elements:
+auto outer = my_ten.shaped<T, 2>({4, 8});
+// 3D Eigen::Tensor 6 x 5 x 2:
+auto weird = my_ten.shaped<T, 3>({6, 5, 2});
+// CHECK fails, type mismatch:
+auto bad = my_ten.flat<int32>();
+
+```
#### `TTypes<T>::UnalignedFlat tensorflow::Tensor::unaligned_flat()` {#TTypes_T_UnalignedFlat_tensorflow_Tensor_unaligned_flat}
@@ -308,4 +337,12 @@ The returned ` StringPiece ` may point to memory location on devices that the CP
NOTE: The underlying tensor buffer is refcounted, so the lifetime of the contents mapped by the ` StringPiece ` matches the lifetime of the buffer; callers should arrange to make sure the buffer does not get destroyed while the ` StringPiece ` is still used.
-REQUIRES: `DataTypeCanUseMemcpy(dtype())`.
+REQUIRES: `DataTypeCanUseMemcpy( dtype() )`.
+
+#### `void tensorflow::Tensor::UnsafeCopyFromInternal(const Tensor &, const TensorShape &)` {#void_tensorflow_Tensor_UnsafeCopyFromInternal}
+
+
+
+Copy the other tensor into this tensor, reshape it, and reinterpret the buffer&apos;s datatype.
+
+This tensor shares other&apos;s underlying storage.
diff --git a/tensorflow/g3doc/api_docs/cc/ClassTensorShape.md b/tensorflow/g3doc/api_docs/cc/ClassTensorShape.md
index def4232b9f..19d0ec14d7 100644
--- a/tensorflow/g3doc/api_docs/cc/ClassTensorShape.md
+++ b/tensorflow/g3doc/api_docs/cc/ClassTensorShape.md
@@ -132,6 +132,12 @@ Returns true if `*this` and `b` have the same sizes. Ignores dimension names.
+#### `bool tensorflow::TensorShape::operator!=(const TensorShape &b) const` {#bool_tensorflow_TensorShape_operator_}
+
+
+
+
+
#### `void tensorflow::TensorShape::AsProto(TensorShapeProto *proto) const` {#void_tensorflow_TensorShape_AsProto}
Fill `*proto` from `*this`.
diff --git a/tensorflow/g3doc/api_docs/cc/index.md b/tensorflow/g3doc/api_docs/cc/index.md
index c84025ce59..6feb73d986 100644
--- a/tensorflow/g3doc/api_docs/cc/index.md
+++ b/tensorflow/g3doc/api_docs/cc/index.md
@@ -25,33 +25,33 @@ write the graph to a file.
## Env
-* [tensorflow::Env](classEnv.md)
-* [tensorflow::RandomAccessFile](classRandomAccessFile.md)
-* [tensorflow::WritableFile](classWritableFile.md)
-* [tensorflow::EnvWrapper](classEnvWrapper.md)
+* [tensorflow::Env](ClassEnv.md)
+* [tensorflow::RandomAccessFile](ClassRandomAccessFile.md)
+* [tensorflow::WritableFile](ClassWritableFile.md)
+* [tensorflow::EnvWrapper](ClassEnvWrapper.md)
## Session
-* [tensorflow::Session](classSession.md)
-* [tensorflow::SessionOptions](structSessionOptions.md)
+* [tensorflow::Session](ClassSession.md)
+* [tensorflow::SessionOptions](StructSessionOptions.md)
## Status
-* [tensorflow::Status](classStatus.md)
-* [tensorflow::Status::State](structState.md)
+* [tensorflow::Status](ClassStatus.md)
+* [tensorflow::Status::State](StructState.md)
## Tensor
-* [tensorflow::Tensor](classTensor.md)
-* [tensorflow::TensorShape](classTensorShape.md)
-* [tensorflow::TensorShapeDim](structTensorShapeDim.md)
-* [tensorflow::TensorShapeUtils](classTensorShapeUtils.md)
-* [tensorflow::PartialTensorShape](classPartialTensorShape.md)
-* [tensorflow::PartialTensorShapeUtils](classPartialTensorShapeUtils.md)
-* [TF_Buffer](structTF_Buffer.md)
+* [tensorflow::Tensor](ClassTensor.md)
+* [tensorflow::TensorShape](ClassTensorShape.md)
+* [tensorflow::TensorShapeDim](StructTensorShapeDim.md)
+* [tensorflow::TensorShapeUtils](ClassTensorShapeUtils.md)
+* [tensorflow::PartialTensorShape](ClassPartialTensorShape.md)
+* [tensorflow::PartialTensorShapeUtils](ClassPartialTensorShapeUtils.md)
+* [TF_Buffer](StructTF_Buffer.md)
## Thread
-* [tensorflow::Thread](classThread.md)
-* [tensorflow::ThreadOptions](structThreadOptions.md)
+* [tensorflow::Thread](ClassThread.md)
+* [tensorflow::ThreadOptions](StructThreadOptions.md)
diff --git a/tensorflow/g3doc/api_docs/python/constant_op.md b/tensorflow/g3doc/api_docs/python/constant_op.md
index 4cffb39fd8..19ba42d96f 100644
--- a/tensorflow/g3doc/api_docs/python/constant_op.md
+++ b/tensorflow/g3doc/api_docs/python/constant_op.md
@@ -176,9 +176,8 @@ Creates a constant tensor.
elements specified by `shape`, the last element in the list will be used
to fill the remaining entries.
- The argument `shape` is optional. If present, it specifies the dimensions
- of the resulting tensor. If not present, then the tensor is a scalar (0-D)
- if `value` is a scalar, or 1-D otherwise.
+ The argument `shape` is optional. If present, it specifies the dimensions of
+ the resulting tensor. If not present, the shape of `value` is used.
If the argument `dtype` is not specified, then the type is inferred from
the type of `value`.
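For illustration, a minimal sketch of the two behaviours described above (hypothetical values; assuming a TensorFlow build at this revision):

```python
import tensorflow as tf

# Shape taken from `value`: a 1-D tensor with three elements.
a = tf.constant([1, 2, 3])

# Explicit `shape`: the last element of `value` fills the remaining
# entries, giving [[1, 2, 3], [3, 3, 3]].
b = tf.constant([1, 2, 3], shape=[2, 3])

with tf.Session() as sess:
    print(sess.run([a, b]))
```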
diff --git a/tensorflow/g3doc/api_docs/python/contrib.layers.md b/tensorflow/g3doc/api_docs/python/contrib.layers.md
index 78bc28d8d5..842a48d539 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.layers.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.layers.md
@@ -350,6 +350,36 @@ Summarize activations, using `summarize_activation` to summarize.
## Other Functions and Classes
- - -
+### `tf.contrib.layers.absolute_loss(predicted, target, name=None)` {#absolute_loss}
+
+Computes and returns the per-example absolute loss.
+
+Computes the per-example absolute value of the difference between
+the target and predicted tensors. The tensors must have the same
+shape.
+
+##### Args:
+
+
+* <b>`predicted`</b>: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
+ of predicted values.
+* <b>`target`</b>: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
+ target values. The shape of the target tensor should match the
+ `predicted` tensor.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A `[batch_size, dim_1, ..., dim_n]` tensor of per-example absolute losses.
+
+##### Raises:
+
+
+* <b>`ValueError`</b>: If `predicted` and `target` shapes do not match.
+
+
+- - -
+
### `tf.contrib.layers.assert_same_float_dtype(tensors=None, dtype=None)` {#assert_same_float_dtype}
Validate and return float type based on `tensors` and `dtype`.
@@ -406,3 +436,247 @@ Assert `tensor` is 0-D, of type `tf.int32` or `tf.int64`.
+- - -
+
+### `tf.contrib.layers.mean_absolute_loss(predicted, target, name=None)` {#mean_absolute_loss}
+
+Calculates the mean absolute loss across batches.
+
+Computes the absolute difference between the target and predicted
+tensors, averaged across all dimensions except dimension 0:
+
+ losses = reduce_batch_mean(absolute_loss(predicted, target))
+
+where `losses` is a tensor with dimensions [batch_size].
+
+The tensors must have the same shape.
+
+This loss function is a form of L1 loss.
+
+##### Args:
+
+
+* <b>`predicted`</b>: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
+ of predicted values.
+* <b>`target`</b>: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
+ target values. The shape of the target tensor should match the
+ `predicted` tensor.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A `[batch_size]` tensor of absolute differences, averaged across all
+ dimensions except dimension 0.
+
+##### Raises:
+
+
+* <b>`ValueError`</b>: If `predicted` and `target` shapes do not match.
+
+
+- - -
+
+### `tf.contrib.layers.mean_squared_loss(predicted, target, name=None)` {#mean_squared_loss}
+
+Calculates the mean squared loss across batches.
+
+Computes the squared difference between the target and predicted
+tensors, and averages across all dimensions except dimension 0:
+
+ losses = reduce_batch_mean(squared_loss(predicted, target))
+
+where `losses` is a tensor with dimensions [batch_size].
+
+The tensors must have the same shape.
+
+##### Args:
+
+
+* <b>`predicted`</b>: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
+ of predicted values.
+* <b>`target`</b>: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
+ target values. The shape of the target tensor should match the
+ `predicted` tensor.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A `[batch_size]` tensor of squared differences, averaged across
+ all dimensions except dimension 0.
+
+##### Raises:
+
+
+* <b>`ValueError`</b>: If `predicted` and `target` shapes do not match.
+
+
+- - -
+
+### `tf.contrib.layers.reduce_batch_mean(x, name=None)` {#reduce_batch_mean}
+
+Given a tensor `x`, returns the mean across all dimensions except dim 0.
+
+Given a tensor with the number of dimensions > 1, reduce_batch_mean
+will calculate the mean across all dimensions except for dimension
+0. This function is useful for calculating the mean loss (error)
+across all examples in a batch when training. As an example, given a
+tensor of shape [batch_size, d1, d2], this function will calculate
+the mean across dimensions d1 and d2, returning a tensor of shape
+[batch_size].
+
+Tensors of dimension 1 are returned as-is.
+
+##### Args:
+
+
+* <b>`x`</b>: A `Tensor` with dimension > 0.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A `Tensor` with values averaged across all dimensions > 0.
+
+##### Raises:
+
+
+* <b>`ValueError`</b>: If `x` has dimension 0.
+
+
+- - -
+
+### `tf.contrib.layers.reduce_batch_sum(x, name=None)` {#reduce_batch_sum}
+
+Given a tensor `x`, sums across all dimensions except dimension 0.
+
+Given a tensor with the number of dimensions > 1, reduce_batch_sum
+will sum across all dimensions except for dimension 0. This function
+is useful for summing the loss (error) across all examples in a
+batch when training. As an example, given a tensor of shape
+[batch_size, d1, d2], this function will sum across dimensions d1
+and d2, returning a tensor of shape [batch_size].
+
+Tensors of dimension 1 are returned as-is, while tensors of dimension 0
+raise a ValueError.
+
+##### Args:
+
+
+* <b>`x`</b>: A `Tensor` with dimension > 0.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A `Tensor` with values summed across all dimensions > 0.
+
+##### Raises:
+
+
+* <b>`ValueError`</b>: If `x` has dimension 0.
+
+
+- - -
+
+### `tf.contrib.layers.root_mean_squared_loss(predicted, target, name=None)` {#root_mean_squared_loss}
+
+Calculates the root mean squared loss across batches.
+
+Computes the root mean squared loss between the target and predicted
+tensors, which is the square root of the mean squared differences
+between the predicted and target tensors:
+
+ losses = sqrt(mean_squared_loss(predicted, target))
+
+where `losses` is a tensor with dimensions [batch_size].
+
+The tensors must have the same shape.
+
+##### Args:
+
+
+* <b>`predicted`</b>: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
+ of predicted values.
+* <b>`target`</b>: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
+ target values. The shape of the target tensor should match the
+ `predicted` tensor.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A `[batch_size]` tensor of the root mean squared differences.
+
+##### Raises:
+
+
+* <b>`ValueError`</b>: If `predicted` and `target` shapes do not match.
+
+
+- - -
+
+### `tf.contrib.layers.squared_loss(predicted, target, name=None)` {#squared_loss}
+
+Computes and returns the per-example squared loss.
+
+Computes the per-example squared difference between the target and
+predicted tensors. The tensors must have the same shape.
+
+##### Args:
+
+
+* <b>`predicted`</b>: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
+ of predicted values.
+* <b>`target`</b>: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
+ target values. The shape of the target tensor should match the
+ `predicted` tensor.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A `[batch_size, dim_1, ..., dim_n]` tensor of per-example squared losses.
+
+##### Raises:
+
+
+* <b>`ValueError`</b>: If `predicted` and `target` shapes do not match.
+
+
+- - -
+
+### `tf.contrib.layers.sum_squared_loss(predicted, target, name=None)` {#sum_squared_loss}
+
+Calculates 1/2 the sum of the squared loss across batches.
+
+Computes the squared difference between the target and predicted
+tensors, sums across all dimensions except dimension 0, and divides
+by 2:
+
+ losses = reduce_batch_sum(squared_loss(predicted, target)) / 2.0
+
+where `losses` is a tensor with dimensions [batch_size].
+
+The tensors must have the same shape.
+
+This function is equivalent to typical formulations of L2 loss, and similar
+to TensorFlow's l2_loss function. It differs from the l2_loss function
+by allowing the caller to specify both the predicted and target tensors.
+
+##### Args:
+
+
+* <b>`predicted`</b>: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
+ of predicted values.
+* <b>`target`</b>: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
+ target values. The shape of the target tensor should match the
+ `predicted` tensor.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A `[batch_size]` tensor of squared losses summed across all dimensions
+ except dimension 0, divided by 2.
+
+##### Raises:
+
+
+* <b>`ValueError`</b>: If `predicted` and `target` shapes do not match.
+
+
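As a usage note for the loss helpers documented above, a minimal sketch with hypothetical tensors (assuming the `tf.contrib.layers` signatures listed in this file):

```python
import tensorflow as tf

predicted = tf.constant([[1.0, 2.0], [3.0, 4.0]])
target = tf.constant([[1.5, 2.0], [2.0, 6.0]])

# Per-example losses keep the full [batch_size, dim_1, ..., dim_n] shape.
abs_loss = tf.contrib.layers.absolute_loss(predicted, target)
sq_loss = tf.contrib.layers.squared_loss(predicted, target)

# Batch reductions collapse everything except dimension 0.
mean_abs = tf.contrib.layers.mean_absolute_loss(predicted, target)  # shape [2]
sum_sq = tf.contrib.layers.sum_squared_loss(predicted, target)      # shape [2]

with tf.Session() as sess:
    print(sess.run([abs_loss, sq_loss, mean_abs, sum_sq]))
```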
diff --git a/tensorflow/g3doc/api_docs/python/contrib.util.md b/tensorflow/g3doc/api_docs/python/contrib.util.md
index 26b71172cb..d5442b4f67 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.util.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.util.md
@@ -83,3 +83,26 @@ Otherwise, "shape" specifies the tensor's shape and the numpy array
can not have more elements than what "shape" specifies.
+- - -
+
+### `tf.contrib.util.make_ndarray(tensor)` {#make_ndarray}
+
+Create a numpy ndarray from a tensor.
+
+Create a numpy ndarray with the same shape and data as the tensor.
+
+##### Args:
+
+
+* <b>`tensor`</b>: A TensorProto.
+
+##### Returns:
+
+ A numpy array with the tensor contents.
+
+##### Raises:
+
+
+* <b>`TypeError`</b>: if tensor has unsupported type.
+
+
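A round-trip sketch using `make_tensor_proto` (documented earlier in this file) together with `make_ndarray` (assuming a TensorFlow build at this revision):

```python
import numpy as np
import tensorflow as tf

# numpy array -> TensorProto -> numpy array.
proto = tf.contrib.util.make_tensor_proto(np.arange(6.0), shape=[2, 3])
array = tf.contrib.util.make_ndarray(proto)  # ndarray of shape (2, 3)
print(array)
```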
diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md
index 52d2560dd3..33f36387bf 100644
--- a/tensorflow/g3doc/api_docs/python/framework.md
+++ b/tensorflow/g3doc/api_docs/python/framework.md
@@ -1234,6 +1234,13 @@ Returns a non-reference `DType` based on this `DType`.
- - -
+#### `tf.DType.real_dtype` {#DType.real_dtype}
+
+Returns the dtype corresponding to this dtype's real part.
+
+
+- - -
+
#### `tf.DType.is_ref_dtype` {#DType.is_ref_dtype}
Returns `True` if this `DType` represents a reference type.
@@ -1255,6 +1262,13 @@ Returns whether this is a (real) floating point type.
- - -
+#### `tf.DType.is_complex` {#DType.is_complex}
+
+Returns whether this is a complex floating point type.
+
+
+- - -
+
#### `tf.DType.is_integer` {#DType.is_integer}
Returns whether this is a (non-quantized) integer type.
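A small sketch of the two new `DType` properties (assuming a TensorFlow build at this revision):

```python
import tensorflow as tf

print(tf.complex64.is_complex)   # True
print(tf.complex64.real_dtype)   # tf.float32, the dtype of the real part
print(tf.float32.is_complex)     # False
print(tf.float32.real_dtype)     # tf.float32; a real dtype is its own real part
```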
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index c85e25bbe8..d87eef8ff6 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -60,8 +60,10 @@
* [`get_variable_scope`](../../api_docs/python/state_ops.md#get_variable_scope)
* [`IndexedSlices`](../../api_docs/python/state_ops.md#IndexedSlices)
* [`initialize_all_variables`](../../api_docs/python/state_ops.md#initialize_all_variables)
+ * [`initialize_local_variables`](../../api_docs/python/state_ops.md#initialize_local_variables)
* [`initialize_variables`](../../api_docs/python/state_ops.md#initialize_variables)
* [`latest_checkpoint`](../../api_docs/python/state_ops.md#latest_checkpoint)
+ * [`local_variables`](../../api_docs/python/state_ops.md#local_variables)
* [`make_template`](../../api_docs/python/state_ops.md#make_template)
* [`moving_average_variables`](../../api_docs/python/state_ops.md#moving_average_variables)
* [`no_regularizer`](../../api_docs/python/state_ops.md#no_regularizer)
@@ -272,6 +274,7 @@
* [`shape`](../../api_docs/python/sparse_ops.md#shape)
* [`sparse_concat`](../../api_docs/python/sparse_ops.md#sparse_concat)
* [`sparse_fill_empty_rows`](../../api_docs/python/sparse_ops.md#sparse_fill_empty_rows)
+ * [`sparse_merge`](../../api_docs/python/sparse_ops.md#sparse_merge)
* [`sparse_reorder`](../../api_docs/python/sparse_ops.md#sparse_reorder)
* [`sparse_retain`](../../api_docs/python/sparse_ops.md#sparse_retain)
* [`sparse_split`](../../api_docs/python/sparse_ops.md#sparse_split)
@@ -326,9 +329,11 @@
* [`conv2d`](../../api_docs/python/nn.md#conv2d)
* [`conv2d_transpose`](../../api_docs/python/nn.md#conv2d_transpose)
* [`depthwise_conv2d`](../../api_docs/python/nn.md#depthwise_conv2d)
+ * [`depthwise_conv2d_native`](../../api_docs/python/nn.md#depthwise_conv2d_native)
* [`dropout`](../../api_docs/python/nn.md#dropout)
* [`elu`](../../api_docs/python/nn.md#elu)
* [`embedding_lookup`](../../api_docs/python/nn.md#embedding_lookup)
+ * [`embedding_lookup_sparse`](../../api_docs/python/nn.md#embedding_lookup_sparse)
* [`fixed_unigram_candidate_sampler`](../../api_docs/python/nn.md#fixed_unigram_candidate_sampler)
* [`in_top_k`](../../api_docs/python/nn.md#in_top_k)
* [`l2_loss`](../../api_docs/python/nn.md#l2_loss)
@@ -340,6 +345,7 @@
* [`max_pool_with_argmax`](../../api_docs/python/nn.md#max_pool_with_argmax)
* [`moments`](../../api_docs/python/nn.md#moments)
* [`nce_loss`](../../api_docs/python/nn.md#nce_loss)
+ * [`normalize_moments`](../../api_docs/python/nn.md#normalize_moments)
* [`relu`](../../api_docs/python/nn.md#relu)
* [`relu6`](../../api_docs/python/nn.md#relu6)
* [`sampled_softmax_loss`](../../api_docs/python/nn.md#sampled_softmax_loss)
@@ -351,6 +357,7 @@
* [`softplus`](../../api_docs/python/nn.md#softplus)
* [`softsign`](../../api_docs/python/nn.md#softsign)
* [`sparse_softmax_cross_entropy_with_logits`](../../api_docs/python/nn.md#sparse_softmax_cross_entropy_with_logits)
+ * [`sufficient_statistics`](../../api_docs/python/nn.md#sufficient_statistics)
* [`tanh`](../../api_docs/python/nn.md#tanh)
* [`top_k`](../../api_docs/python/nn.md#top_k)
* [`uniform_candidate_sampler`](../../api_docs/python/nn.md#uniform_candidate_sampler)
@@ -426,6 +433,7 @@
* [`main`](../../api_docs/python/test.md#main)
* **[Layers (contrib)](../../api_docs/python/contrib.layers.md)**:
+ * [`absolute_loss`](../../api_docs/python/contrib.layers.md#absolute_loss)
* [`assert_same_float_dtype`](../../api_docs/python/contrib.layers.md#assert_same_float_dtype)
* [`assert_scalar_int`](../../api_docs/python/contrib.layers.md#assert_scalar_int)
* [`convolution2d`](../../api_docs/python/contrib.layers.md#convolution2d)
@@ -433,6 +441,13 @@
* [`is_numeric_tensor`](../../api_docs/python/contrib.layers.md#is_numeric_tensor)
* [`l1_regularizer`](../../api_docs/python/contrib.layers.md#l1_regularizer)
* [`l2_regularizer`](../../api_docs/python/contrib.layers.md#l2_regularizer)
+ * [`mean_absolute_loss`](../../api_docs/python/contrib.layers.md#mean_absolute_loss)
+ * [`mean_squared_loss`](../../api_docs/python/contrib.layers.md#mean_squared_loss)
+ * [`reduce_batch_mean`](../../api_docs/python/contrib.layers.md#reduce_batch_mean)
+ * [`reduce_batch_sum`](../../api_docs/python/contrib.layers.md#reduce_batch_sum)
+ * [`root_mean_squared_loss`](../../api_docs/python/contrib.layers.md#root_mean_squared_loss)
+ * [`squared_loss`](../../api_docs/python/contrib.layers.md#squared_loss)
+ * [`sum_squared_loss`](../../api_docs/python/contrib.layers.md#sum_squared_loss)
* [`summarize_activation`](../../api_docs/python/contrib.layers.md#summarize_activation)
* [`summarize_activations`](../../api_docs/python/contrib.layers.md#summarize_activations)
* [`summarize_collection`](../../api_docs/python/contrib.layers.md#summarize_collection)
@@ -443,5 +458,6 @@
* **[Utilities (contrib)](../../api_docs/python/contrib.util.md)**:
* [`constant_value`](../../api_docs/python/contrib.util.md#constant_value)
+ * [`make_ndarray`](../../api_docs/python/contrib.util.md#make_ndarray)
* [`make_tensor_proto`](../../api_docs/python/contrib.util.md#make_tensor_proto)
diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md
index b8685b6a61..35a072c1a4 100644
--- a/tensorflow/g3doc/api_docs/python/nn.md
+++ b/tensorflow/g3doc/api_docs/python/nn.md
@@ -285,7 +285,7 @@ concatenated.
- - -
-### `tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)` {#conv2d}
+### `tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)` {#conv2d}
Computes a 2-D convolution given 4-D `input` and `filter` tensors.
@@ -302,7 +302,7 @@ performs the following:
3. For each patch, right-multiplies the filter matrix and the image patch
vector.
-In detail,
+In detail, with the default NHWC format,
output[b, i, j, k] =
sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] *
@@ -318,10 +318,16 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
* <b>`filter`</b>: A `Tensor`. Must have the same type as `input`.
* <b>`strides`</b>: A list of `ints`.
1-D of length 4. The stride of the sliding window for each dimension
- of `input`.
+ of `input`. Must be in the same order as the dimensions specified by `data_format`.
* <b>`padding`</b>: A `string` from: `"SAME", "VALID"`.
The type of padding algorithm to use.
* <b>`use_cudnn_on_gpu`</b>: An optional `bool`. Defaults to `True`.
+* <b>`data_format`</b>: An optional `string` from: `"NHWC", "NCHW"`. Defaults to `"NHWC"`.
+ Specify the data format of the input and output data. With the
+ default format "NHWC", the data is stored in the order of:
+ [batch, in_height, in_width, in_channels].
+ Alternatively, the format could be "NCHW", the data storage order of:
+ [batch, in_channels, in_height, in_width].
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
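For the new `data_format` argument, a minimal usage sketch with hypothetical shapes (assuming the `tf.nn.conv2d` signature documented above):

```python
import tensorflow as tf

# NHWC (the default): a batch of 8 RGB images convolved with 3x3 filters
# that produce 16 output channels.
images = tf.placeholder(tf.float32, shape=[8, 32, 32, 3])
filters = tf.Variable(tf.truncated_normal([3, 3, 3, 16], stddev=0.1))

out = tf.nn.conv2d(images, filters, strides=[1, 1, 1, 1], padding="SAME",
                   data_format="NHWC")  # shape [8, 32, 32, 16]
```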
@@ -467,7 +473,7 @@ to the `Convolution` section for details about the padding calculation.
- - -
-### `tf.nn.avg_pool(value, ksize, strides, padding, name=None)` {#avg_pool}
+### `tf.nn.avg_pool(value, ksize, strides, padding, data_format='NHWC', name=None)` {#avg_pool}
Performs the average pooling on the input.
@@ -485,6 +491,7 @@ window in `value`.
The stride of the sliding window for each dimension of the
input tensor.
* <b>`padding`</b>: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
+* <b>`data_format`</b>: A string. 'NHWC' and 'NCHW' are supported.
* <b>`name`</b>: Optional name for the operation.
##### Returns:
@@ -494,7 +501,7 @@ window in `value`.
- - -
-### `tf.nn.max_pool(value, ksize, strides, padding, name=None)` {#max_pool}
+### `tf.nn.max_pool(value, ksize, strides, padding, data_format='NHWC', name=None)` {#max_pool}
Performs the max pooling on the input.
@@ -508,6 +515,7 @@ Performs the max pooling on the input.
* <b>`strides`</b>: A list of ints that has length >= 4. The stride of the sliding
window for each dimension of the input tensor.
* <b>`padding`</b>: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
+* <b>`data_format`</b>: A string. 'NHWC' and 'NCHW' are supported.
* <b>`name`</b>: Optional name for the operation.
##### Returns:
@@ -594,7 +602,7 @@ each component is divided by the weighted, squared sum of inputs within
sqr_sum[a, b, c, d] =
sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2)
- output = input / (bias + alpha * sqr_sum ** beta)
+ output = input / (bias + alpha * sqr_sum) ** beta
For details, see [Krizhevsky et al., ImageNet classification with deep
convolutional neural networks (NIPS 2012)]
@@ -620,6 +628,58 @@ convolutional neural networks (NIPS 2012)]
- - -
+### `tf.nn.sufficient_statistics(x, axes, shift=True, keep_dims=False, name=None)` {#sufficient_statistics}
+
+Calculate the sufficient statistics for the mean and variance of `x`.
+
+These sufficient statistics are computed using a one-pass algorithm on
+an input that is optionally shifted by the value of the first element in `x`.
+See:
+https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
+
+##### Args:
+
+
+* <b>`x`</b>: A `Tensor`.
+* <b>`axes`</b>: Array of ints. Axes along which to compute mean and variance.
+* <b>`shift`</b>: If true, shift the data to provide more numerically stable results.
+* <b>`keep_dims`</b>: produce statistics with the same dimensionality as the input.
+* <b>`name`</b>: Name used to scope the operations that compute the sufficient stats.
+
+##### Returns:
+
+ Four `Tensor` objects of the same type as `x`:
+ * the count (number of elements to average over).
+ * the (possibly shifted) sum of the elements in the array.
+ * the (possibly shifted) sum of squares of the elements in the array.
+ * the shift by which the mean must be corrected or None if `shift` is False.
+
+
+- - -
+
+### `tf.nn.normalize_moments(counts, mean_ss, variance_ss, shift, name=None)` {#normalize_moments}
+
+Calculate the mean and variance based on the sufficient statistics.
+
+##### Args:
+
+
+* <b>`counts`</b>: A `Tensor` containing the total count of the data (one value).
+* <b>`mean_ss`</b>: A `Tensor` containing the mean sufficient statistics: the (possibly
+ shifted) sum of the elements to average over.
+* <b>`variance_ss`</b>: A `Tensor` containing the variance sufficient statistics: the
+ (possibly shifted) squared sum of the data to compute the variance over.
+* <b>`shift`</b>: A `Tensor` containing the value by which the data is shifted for
+ numerical stability, or `None` if no shift was performed.
+* <b>`name`</b>: Name used to scope the operations that compute the moments.
+
+##### Returns:
+
+ Two `Tensor` objects: `mean` and `variance`.
+
+
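Taken together, a minimal sketch of computing moments from sufficient statistics (hypothetical shapes; assuming the `tf.nn` signatures documented above):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 28, 28, 3])

# Two-step equivalent of tf.nn.moments(x, axes=[0, 1, 2]).
counts, mean_ss, var_ss, shift = tf.nn.sufficient_statistics(x, axes=[0, 1, 2])
mean, variance = tf.nn.normalize_moments(counts, mean_ss, var_ss, shift)
```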
+- - -
+
### `tf.nn.moments(x, axes, name=None, keep_dims=False)` {#moments}
Calculate the mean and variance of `x`.
@@ -628,9 +688,11 @@ The mean and variance are calculated by aggregating the contents of `x`
across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean
and variance of a vector.
-For so-called "global normalization" needed for convolutional filters pass
-`axes=[0, 1, 2]` (batch, height, width). For batch normalization pass
-`axes=[0]` (batch).
+When using these moments for batch normalization (see
+`tf.nn.batch_normalization`):
+ * for so-called "global normalization", used with convolutional filters with
+ shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
+ * for simple batch normalization pass `axes=[0]` (batch only).
##### Args:
@@ -877,6 +939,75 @@ tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
* <b>`ValueError`</b>: If `params` is empty.
+- - -
+
+### `tf.nn.embedding_lookup_sparse(params, sp_ids, sp_weights, partition_strategy='mod', name=None, combiner='mean')` {#embedding_lookup_sparse}
+
+Computes embeddings for the given ids and weights.
+
+This op assumes that there is at least one id for each row in the dense tensor
+represented by sp_ids (i.e. there are no rows with empty features), and that
+all the indices of sp_ids are in canonical row-major order.
+
+It also assumes that all id values lie in the range [0, p0), where p0
+is the sum of the size of params along dimension 0.
+
+##### Args:
+
+
+* <b>`params`</b>: A single tensor representing the complete embedding tensor,
+ or a list of P tensors all of same shape except for the first dimension,
+ representing sharded embedding tensors.
+* <b>`sp_ids`</b>: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
+ where N is typically batch size and M is arbitrary.
+* <b>`sp_weights`</b>: either a SparseTensor of float / double weights, or None to
+ indicate all weights should be taken to be 1. If specified, sp_weights
+ must have exactly the same shape and indices as sp_ids.
+* <b>`partition_strategy`</b>: A string specifying the partitioning strategy, relevant
+ if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
+ is `"mod"`. See `tf.nn.embedding_lookup` for more details.
+* <b>`name`</b>: Optional name for the op.
+* <b>`combiner`</b>: A string specifying the reduction op. Currently "mean", "sqrtn"
+ and "sum" are supported.
+ "sum" computes the weighted sum of the embedding results for each row.
+ "mean" is the weighted sum divided by the total weight.
+ "sqrtn" is the weighted sum divided by the square root of the sum of the
+ squares of the weights.
+
+##### Returns:
+
+ A dense tensor representing the combined embeddings for the
+ sparse ids. For each row in the dense tensor represented by sp_ids, the op
+ looks up the embeddings for all ids in that row, multiplies them by the
+ corresponding weight, and combines these embeddings as specified.
+
+ In other words, if
+ shape(combined params) = [p0, p1, ..., pm]
+ and
+ shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]
+ then
+ shape(output) = [d0, d1, ..., dn-1, p1, ..., pm].
+
+ For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are
+
+ [0, 0]: id 1, weight 2.0
+ [0, 1]: id 3, weight 0.5
+ [1, 0]: id 0, weight 1.0
+ [2, 3]: id 1, weight 3.0
+
+ with combiner="mean", then the output will be a 3x20 matrix where
+ output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
+ output[1, :] = params[0, :] * 1.0
+ output[2, :] = params[1, :] * 3.0
+
+##### Raises:
+
+
+* <b>`TypeError`</b>: If sp_ids is not a SparseTensor, or if sp_weights is neither
+ None nor SparseTensor.
+* <b>`ValueError`</b>: If combiner is not one of {"mean", "sqrtn", "sum"}.
+
+
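A sketch of the worked example above, with the sparse ids and weights built by hand rather than read from `Example` protos (assuming a TensorFlow build at this revision):

```python
import numpy as np
import tensorflow as tf

# A 10 x 20 embedding matrix.
params = tf.constant(np.random.rand(10, 20), dtype=tf.float32)

indices = np.array([[0, 0], [0, 1], [1, 0], [2, 3]], dtype=np.int64)
sp_ids = tf.SparseTensor(indices=indices,
                         values=np.array([1, 3, 0, 1], dtype=np.int64),
                         shape=np.array([3, 4], dtype=np.int64))
sp_weights = tf.SparseTensor(indices=indices,
                             values=np.array([2.0, 0.5, 1.0, 3.0],
                                             dtype=np.float32),
                             shape=np.array([3, 4], dtype=np.int64))

# One combined embedding per row of sp_ids: shape [3, 20].
embedded = tf.nn.embedding_lookup_sparse(params, sp_ids, sp_weights,
                                         combiner="mean")
```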
## Evaluation
@@ -1381,3 +1512,94 @@ target classes as noise classes for the same example.
Each value is `-FLOAT_MAX`.
+
+## Other Functions and Classes
+- - -
+
+### `tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)` {#batch_normalization}
+
+Batch normalization.
+
+As described in http://arxiv.org/abs/1502.03167.
+Normalizes a tensor by `mean` and `variance`, and applies (optionally) a
+`scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\):
+
+\\(\frac{\gamma(x-\mu)}{\sigma}+\beta\\)
+
+`mean`, `variance`, `offset` and `scale` are all expected to be of one of two
+shapes:
+ * In all generality, they can have the same number of dimensions as the
+ input `x`, with identical sizes as `x` for the dimensions that are not
+ normalized over (the 'depth' dimension(s)), and dimension 1 for the
+ others which are being normalized over.
+ `mean` and `variance` in this case would typically be the outputs of
+ `tf.nn.moments(..., keep_dims=True)` during training, or running averages
+ thereof during inference.
+ * In the common case where the 'depth' dimension is the last dimension in
+ the input tensor `x`, they may be one dimensional tensors of the same
+ size as the 'depth' dimension.
+ This is the case for example for the common `[batch, depth]` layout of
+ fully-connected layers, and `[batch, height, width, depth]` for
+ convolutions.
+ `mean` and `variance` in this case would typically be the outputs of
+ `tf.nn.moments(..., keep_dims=False)` during training, or running averages
+ thereof during inference.
+
+##### Args:
+
+
+* <b>`x`</b>: Input `Tensor` of arbitrary dimensionality.
+* <b>`mean`</b>: A mean `Tensor`.
+* <b>`variance`</b>: A variance `Tensor`.
+* <b>`offset`</b>: An offset `Tensor`, often denoted \\(\beta\\) in equations, or
+ None. If present, will be added to the normalized tensor.
+* <b>`scale`</b>: A scale `Tensor`, often denoted \\(\gamma\\) in equations, or
+ `None`. If present, the scale is applied to the normalized tensor.
+* <b>`variance_epsilon`</b>: A small float number to avoid dividing by 0.
+* <b>`name`</b>: A name for this operation (optional).
+
+##### Returns:
+
+ the normalized, scaled, offset tensor.
+
+
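A minimal sketch of the second (one-dimensional) case described above, for a `[batch, height, width, depth]` input (hypothetical shapes; assuming the `tf.nn` signatures documented in this file):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 100, 100, 32])

# "Global normalization": moments over every dimension except depth.
mean, variance = tf.nn.moments(x, axes=[0, 1, 2])

offset = tf.Variable(tf.zeros([32]))  # beta
scale = tf.Variable(tf.ones([32]))    # gamma

y = tf.nn.batch_normalization(x, mean, variance, offset, scale,
                              variance_epsilon=1e-3)
```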
+- - -
+
+### `tf.nn.depthwise_conv2d_native(input, filter, strides, padding, name=None)` {#depthwise_conv2d_native}
+
+Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors.
+
+Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
+and a filter / kernel tensor of shape
+`[filter_height, filter_width, in_channels, channel_multiplier]`, containing
+`in_channels` convolutional filters of depth 1, `depthwise_conv2d` applies
+a different filter to each input channel (expanding from 1 channel to
+`channel_multiplier` channels for each), then concatenates the results
+together. Thus, the output has `in_channels * channel_multiplier` channels.
+
+for k in 0..in_channels-1
+ for q in 0..channel_multiplier-1
+ output[b, i, j, k * channel_multiplier + q] =
+ sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] *
+ filter[di, dj, k, q]
+
+Must have `strides[0] = strides[3] = 1`. For the most common case of the same
+horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
+
+##### Args:
+
+
+* <b>`input`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`.
+* <b>`filter`</b>: A `Tensor`. Must have the same type as `input`.
+* <b>`strides`</b>: A list of `ints`.
+ 1-D of length 4. The stride of the sliding window for each dimension
+ of `input`.
+* <b>`padding`</b>: A `string` from: `"SAME", "VALID"`.
+ The type of padding algorithm to use.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A `Tensor`. Has the same type as `input`.
+
+
diff --git a/tensorflow/g3doc/api_docs/python/sparse_ops.md b/tensorflow/g3doc/api_docs/python/sparse_ops.md
index 419a3a0757..dcbc71fb1b 100644
--- a/tensorflow/g3doc/api_docs/python/sparse_ops.md
+++ b/tensorflow/g3doc/api_docs/python/sparse_ops.md
@@ -153,7 +153,7 @@ Alias for field number 1
-## Sparse to Dense Conversion
+## Conversion
- - -
@@ -255,8 +255,8 @@ tested if validate_indices is True.
Converts a `SparseTensor` of ids into a dense bool indicator tensor.
-The last dimension of `sp_input` is discarded and replaced with the values of
-`sp_input`. If `sp_input.shape = [D0, D1, ..., Dn, K]`, then
+The last dimension of `sp_input.indices` is discarded and replaced with
+the values of `sp_input`. If `sp_input.shape = [D0, D1, ..., Dn, K]`, then
`output.shape = [D0, D1, ..., Dn, vocab_size]`, where
output[d_0, d_1, ..., d_n, sp_input[d_0, d_1, ..., d_n, k]] = True
@@ -288,9 +288,10 @@ The input `SparseTensor` must be in row-major order.
##### Args:
-* <b>`sp_input`</b>: A `SparseTensor` of type `int32` or `int64`.
-* <b>`vocab_size`</b>: The new size of the last dimension, with
- `all(0 <= sp_input.values < vocab_size)`.
+* <b>`sp_input`</b>: A `SparseTensor` with `values` property of type `int32` or
+ `int64`.
+* <b>`vocab_size`</b>: A scalar int64 Tensor (or Python int) containing the new size
+ of the last dimension, `all(0 <= sp_input.values < vocab_size)`.
* <b>`name`</b>: A name prefix for the returned tensors (optional)
##### Returns:
@@ -303,6 +304,82 @@ The input `SparseTensor` must be in row-major order.
* <b>`TypeError`</b>: If `sp_input` is not a `SparseTensor`.
+- - -
+
+### `tf.sparse_merge(sp_ids, sp_values, vocab_size, name=None)` {#sparse_merge}
+
+Combines a batch of feature ids and values into a single `SparseTensor`.
+
+The most common use case for this function occurs when feature ids and
+their corresponding values are stored in `Example` protos on disk.
+`parse_example` will return a batch of ids and a batch of values, and this
+function joins them into a single logical `SparseTensor` for use in
+functions such as `sparse_tensor_dense_matmul`, `sparse_to_dense`, etc.
+
+The `SparseTensor` returned by this function has the following properties:
+
+ - `indices` is equivalent to `sp_ids.indices` with the last
+ dimension discarded and replaced with `sp_ids.values`.
+ - `values` is simply `sp_values.values`.
+ - If `sp_ids.shape = [D0, D1, ..., Dn, K]`, then
+ `output.shape = [D0, D1, ..., Dn, vocab_size]`.
+
+For example, consider the following feature vectors:
+
+ vector1 = [-3, 0, 0, 0, 0, 0]
+ vector2 = [ 0, 1, 0, 4, 1, 0]
+ vector3 = [ 5, 0, 0, 9, 0, 0]
+
+These might be stored sparsely in the following Example protos by storing
+only the feature ids (column number if the vectors are treated as a matrix)
+of the non-zero elements and the corresponding values:
+
+ examples = [Example(features={
+ "ids": Feature(int64_list=Int64List(value=[0])),
+ "values": Feature(float_list=FloatList(value=[-3]))}),
+ Example(features={
+ "ids": Feature(int64_list=Int64List(value=[1, 4, 3])),
+ "values": Feature(float_list=FloatList(value=[1, 1, 4]))}),
+ Example(features={
+ "ids": Feature(int64_list=Int64List(value=[0, 3])),
+ "values": Feature(float_list=FloatList(value=[5, 9]))})]
+
+The result of calling parse_example on these examples will produce a
+dictionary with entries for "ids" and "values". Passing those two objects
+to this function will produce a `SparseTensor` that sparsely represents
+all three instances. Namely, the `indices` property will contain
+the coordinates of the non-zero entries in the feature matrix (the first
+dimension is the row number in the matrix, i.e., the index within the batch,
+and the second dimension is the column number, i.e., the feature id);
+`values` will contain the actual values. `shape` will be the shape of the
+original matrix, i.e., (3, 7). For our example above, the output will be
+equal to:
+
+ SparseTensor(indices=[[0, 0], [1, 1], [1, 3], [1, 4], [2, 0], [2, 3]],
+ values=[-3, 1, 4, 1, 5, 9],
+ shape=[3, 7])
+
+##### Args:
+
+
+* <b>`sp_ids`</b>: A `SparseTensor` with `values` property of type `int32`
+ or `int64`.
+* <b>`sp_values`</b>: A `SparseTensor` of any type.
+* <b>`vocab_size`</b>: A scalar `int64` Tensor (or Python int) containing the new size
+ of the last dimension, `all(0 <= sp_ids.values < vocab_size)`.
+* <b>`name`</b>: A name prefix for the returned tensors (optional)
+
+##### Returns:
+
+ A `SparseTensor` compactly representing a batch of feature ids and values,
+ useful for passing to functions that expect such a `SparseTensor`.
+
+##### Raises:
+
+
+* <b>`TypeError`</b>: If `sp_ids` or `sp_values` are not a `SparseTensor`.
+
+
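A sketch of the example above with the two `SparseTensor` inputs built by hand, rather than returned by `parse_example` (assuming a TensorFlow build at this revision):

```python
import numpy as np
import tensorflow as tf

indices = np.array([[0, 0], [1, 0], [1, 1], [1, 2], [2, 0], [2, 1]],
                   dtype=np.int64)
shape = np.array([3, 3], dtype=np.int64)

sp_ids = tf.SparseTensor(indices=indices,
                         values=np.array([0, 1, 4, 3, 0, 3], dtype=np.int64),
                         shape=shape)
sp_values = tf.SparseTensor(indices=indices,
                            values=np.array([-3, 1, 1, 4, 5, 9],
                                            dtype=np.float32),
                            shape=shape)

# Produces the SparseTensor with shape [3, 7] shown above.
merged = tf.sparse_merge(sp_ids, sp_values, vocab_size=7)
```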
## Manipulation
diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md
index 860087939c..a7957bc474 100644
--- a/tensorflow/g3doc/api_docs/python/state_ops.md
+++ b/tensorflow/g3doc/api_docs/python/state_ops.md
@@ -462,7 +462,7 @@ collected in the graph.
### `tf.all_variables()` {#all_variables}
-Returns all variables collected in the graph.
+Returns all variables that must be saved/restored.
The `Variable()` constructor automatically adds new variables to the graph
collection `GraphKeys.VARIABLES`. This convenience function returns the
@@ -491,6 +491,17 @@ contents of that collection.
- - -
+### `tf.local_variables()` {#local_variables}
+
+Returns all variables created with collection=[LOCAL_VARIABLES].
+
+##### Returns:
+
+ A list of local Variable objects.
+
+
+- - -
+
### `tf.moving_average_variables()` {#moving_average_variables}
Returns all variables that maintain their moving averages.
@@ -548,6 +559,19 @@ be run. That Op just has no effect.
- - -
+### `tf.initialize_local_variables()` {#initialize_local_variables}
+
+Returns an Op that initializes all local variables.
+
+This is just a shortcut for `initialize_variables(local_variables())`
+
+##### Returns:
+
+ An Op that initializes all local variables in the graph.
+
+
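A minimal sketch of local variables in use (assuming the `LOCAL_VARIABLES` collection mentioned above is exposed as `tf.GraphKeys.LOCAL_VARIABLES`):

```python
import tensorflow as tf

# Added only to the LOCAL_VARIABLES collection, so it is not saved/restored,
# but it still needs explicit initialization.
counter = tf.Variable(0, name="counter",
                      collections=[tf.GraphKeys.LOCAL_VARIABLES])

with tf.Session() as sess:
    sess.run(tf.initialize_local_variables())
    print(sess.run(tf.local_variables()))  # [0]
```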
+- - -
+
### `tf.assert_variables_initialized(var_list=None)` {#assert_variables_initialized}
Returns an Op to check if variables are initialized.
diff --git a/tensorflow/g3doc/api_docs/python/test.md b/tensorflow/g3doc/api_docs/python/test.md
index 1b5f409b1c..65963ea5d4 100644
--- a/tensorflow/g3doc/api_docs/python/test.md
+++ b/tensorflow/g3doc/api_docs/python/test.md
@@ -92,6 +92,18 @@ differentiation of graphs for comparison against registered analytic gradients.
Computes and returns the theoretical and numerical Jacobian.
+If `x` or `y` is complex, the Jacobian will still be real but the
+corresponding Jacobian dimension(s) will be twice as large. This is required
+even if both input and output is complex since TensorFlow graphs are not
+necessarily holomorphic, and may have gradients not expressible as complex
+numbers. For example, if `x` is complex with shape `[m]` and `y` is complex
+with shape `[n]`, each Jacobian `J` will have shape `[m * 2, n * 2]` with
+
+ J[:m, :n] = d(Re y)/d(Re x)
+ J[:m, n:] = d(Im y)/d(Re x)
+ J[m:, :n] = d(Re y)/d(Im x)
+ J[m:, n:] = d(Im y)/d(Im x)
+
##### Args:
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index 2698230b11..5936b6e9c0 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -22,6 +22,8 @@ limitations under the License.
#include <cstring>
#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/log_memory.h"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/equal_graph_def.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/platform/types.h"
@@ -474,13 +476,25 @@ void TF_Run_wrapper_helper(TF_Session* session, const char* handle,
// requirements for tensorflow::Tensor. We hard code this here to
// avoid taking a dependency on Eigen in the client code.
void* data = tensorflow::cpu_allocator()->AllocateRaw(32, size);
+ if (tensorflow::LogMemory::IsEnabled()) {
+ LogMemory::RecordRawAllocation(
+ "Python session helper",
+ tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, size,
+ data, tensorflow::cpu_allocator());
+ }
std::memcpy(data, PyArray_DATA(array), size);
- inputs_safe.emplace_back(make_safe(
- TF_NewTensor(dtype, dims.data(), dims.size(), data, size,
- [](void* data, size_t len, void* arg) {
- tensorflow::cpu_allocator()->DeallocateRaw(data);
- },
- nullptr)));
+ inputs_safe.emplace_back(make_safe(TF_NewTensor(
+ dtype, dims.data(), dims.size(), data, size,
+ [](void* data, size_t len, void* arg) {
+ if (tensorflow::LogMemory::IsEnabled()) {
+ LogMemory::RecordRawDeallocation(
+ "Python session helper",
+ tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID,
+ data, tensorflow::cpu_allocator(), false);
+ }
+ tensorflow::cpu_allocator()->DeallocateRaw(data);
+ },
+ nullptr)));
// The destruction of the numpy array will now be handled by the
// inputs_safe destructor.
py_inputs_safe[i].reset();
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 043db5fd72..f0329a6730 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -264,6 +264,15 @@ def call_function(func_def, *inputs, **kwargs):
return op
+def _get_func_name(func):
+ if inspect.isfunction(func):
+ return func.__name__
+ elif inspect.ismethod(func):
+ return func.__self__.__name__ + "." + func.__name__
+ else:
+ raise ValueError("Argument must be a function")
+
+
def define_function(func, input_types):
"""Creates a `FunctionDef` for a python function.
@@ -301,6 +310,8 @@ def define_function(func, input_types):
# Create a FunctionDef for 'my_func'. (This does not change the default
graph.)
my_func_def = tf.define_function(my_func, {'x': tf.float32, 'y': tf.float32})
+ # Alternatively:
+ # my_func_def = tf.define_function(my_func, [tf.float32, tf.float32])
# Build the graph, calling the function.
a = tf.constant([1.0])
@@ -310,8 +321,9 @@ def define_function(func, input_types):
Args:
func: a Python function.
- input_types: dict. Keys are the names of the arguments of `func`, values
- are their expected `tf.DType`.
+ input_types: if a dict, keys are the names of the arguments of
+ `func`, values are their expected `tf.DType`. Otherwise,
+ a list of `tf.DType`s.
Returns:
A FunctionDef protocol buffer.
@@ -321,26 +333,39 @@ def define_function(func, input_types):
"""
# TODO(touts): Lift the limitation that func can only receive Tensor args.
- if inspect.isfunction(func):
- func_name = func.__name__
- elif inspect.ismethod(func):
- func_name = func.__self__.__name__ + "." + func.__name__
- else:
- raise ValueError("Argument must be a function")
+ func_name = _get_func_name(func)
+
argspec = inspect.getargspec(func)
- if argspec.varargs or argspec.keywords or argspec.defaults:
- raise ValueError("Only functions with plain arglists are supported.")
+ if argspec.keywords or argspec.defaults:
+    raise ValueError("Functions with argument defaults or keyword "
+ "arguments are not supported.")
if inspect.isfunction(func):
- if len(argspec.args) != len(input_types):
- raise ValueError("The function must have the same number of arguments "
- "as the number of specified input types.")
- args = argspec.args
+ if argspec.varargs and (
+ len(argspec.args) > len(input_types)) or not argspec.varargs and (
+ len(argspec.args) != len(input_types)):
+ raise ValueError("The function has fewer arguments "
+ "than the number of specified input types.")
+ argnames = argspec.args
elif inspect.ismethod(func):
- if len(argspec.args) != 1 + len(input_types):
- raise ValueError(
- "The class function must have the same number of arguments "
- "as the number of specified input types.")
- args = argspec.args[1:] # 1st argument is the "class" type.
+ if argspec.varargs and (
+ len(argspec.args) > 1 + len(input_types)) or not argspec.varargs and (
+ len(argspec.args) != 1 + len(input_types)):
+ raise ValueError("The class function has fewer arguments "
+ "than the number of specified input types.")
+ # 1st argument is the "class" type.
+ argnames = argspec.args[1:]
+
+ args = []
+ if isinstance(input_types, (list, tuple)):
+ for i in range(len(input_types)):
+ argname = argnames[i] if i < len(argnames) else ("arg%d" % i)
+ argtype = input_types[i]
+ args.append((argname, argtype))
+ else:
+ for name in argnames:
+ if name not in input_types:
+ raise ValueError("Missing type for argument: " + name)
+ args.append((name, input_types[name]))
# Create the func_def object.
temp_graph = ops.Graph()
@@ -349,14 +374,15 @@ def define_function(func, input_types):
inputs = []
# Arglist to call 'func'
kwargs = {}
- for argname in args:
- if argname not in input_types:
- raise ValueError("Missing type for argument: " + argname)
- argholder = array_ops.placeholder(input_types[argname], name=argname)
+ for (argname, argtype) in args:
+ argholder = array_ops.placeholder(argtype, name=argname)
inputs.append(argholder)
kwargs[argname] = argholder
# Call func and gather the output tensors.
- outputs = func(**kwargs)
+ if isinstance(input_types, (list, tuple)):
+ outputs = func(*inputs)
+ else:
+ outputs = func(**kwargs)
if not outputs:
raise ValueError("Function must return at least one tensor")
# Convenience: if func only returned one value, make it a tuple.
@@ -383,7 +409,7 @@ class Defun(object):
For example if the function to decorate accepts to `tf.float32` arguments
named `x` and `y`, call the decorator with:
- @Defun(x=tf.float32, y=tf.float32)
+ @Defun(tf.float32, tf.float32)
def foo(x, y):
...
@@ -393,7 +419,7 @@ class Defun(object):
```python
# Defining the function.
- @tf.Defun(x=tf.float32, y=tf.float32)
+ @tf.Defun(tf.float32, tf.float32)
def MyFunc(x, y):
return x + y, x - y
@@ -406,15 +432,22 @@ class Defun(object):
@@__init__
"""
- def __init__(self, **input_types):
+ def __init__(self, *input_type_list, **input_types):
"""Create a `Defun` decorator.
Args:
+ *input_type_list: A list of `tf.DType`
**input_types: Dict mapping string with `tf.DType`
One key for each argument of the function to decorate.
"""
+ assert not input_type_list or not input_types, (
+ "Can't specify both *input_type_list and **input_types")
self._input_types = input_types
+ self._input_type_list = input_type_list
def __call__(self, f):
- func_def = define_function(f, self._input_types)
+ if self._input_types:
+ func_def = define_function(f, self._input_types)
+ else:
+ func_def = define_function(f, self._input_type_list)
return lambda *args, **kwargs: call_function(func_def, *args, **kwargs)
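
With this change, `Defun` accepts the argument types either positionally or as keyword arguments keyed by parameter name (but not both at once). A minimal sketch of the two spellings under the new API, assuming functions that take two `tf.float32` tensors; the function names are illustrative only:

```python
import tensorflow as tf
from tensorflow.python.framework import function

# Positional form: types are matched to the arguments by position.
@function.Defun(tf.float32, tf.float32)
def AddSub(x, y):
  return x + y, x - y

# Keyword form: types are matched to the arguments by name.
@function.Defun(a=tf.float32, b=tf.float32)
def Scale(a, b):
  return a * b
```

Functions with `*args` are also allowed in the positional form, as the `Loop10` rewrite further down demonstrates.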
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 6728d31a25..48bc12237e 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -131,7 +131,7 @@ class FunctionTest(tf.test.TestCase):
def testTanhSymGrad(self):
g = tf.Graph()
with g.as_default():
- @function.Defun(x=tf.float32)
+ @function.Defun(tf.float32)
def Forward(x):
return tf.reduce_sum(tf.tanh(x))
x = tf.placeholder(tf.float32)
@@ -200,9 +200,6 @@ class FunctionTest(tf.test.TestCase):
def NoResult():
pass
- def VarArgs(*unused_b):
- return tf.constant([1])
-
def DefaultArg(unused_a=12):
return tf.constant([1])
@@ -215,11 +212,9 @@ class FunctionTest(tf.test.TestCase):
with tf.Graph().as_default():
with self.assertRaisesRegexp(ValueError, "return at least one tensor"):
function.define_function(NoResult, {})
- with self.assertRaisesRegexp(ValueError, "plain arglists are supported"):
- function.define_function(VarArgs, {})
- with self.assertRaisesRegexp(ValueError, "plain arglists are supported"):
+ with self.assertRaisesRegexp(ValueError, "are not supported"):
function.define_function(DefaultArg, {})
- with self.assertRaisesRegexp(ValueError, "plain arglists are supported"):
+ with self.assertRaisesRegexp(ValueError, "are not supported"):
function.define_function(KwArgs, {})
with self.assertRaisesRegexp(ValueError, "specified input types"):
function.define_function(PlusMinus, {})
@@ -278,7 +273,7 @@ class FunctionTest(tf.test.TestCase):
with tf.Graph().as_default():
- @function.Defun(b=tf.float32)
+ @function.Defun(tf.float32)
def Minus1(b):
return b - 1.0
@@ -296,11 +291,11 @@ class FunctionTest(tf.test.TestCase):
def testNestedFunction(self):
with tf.Graph().as_default():
- @function.Defun(x=tf.float32)
+ @function.Defun(tf.float32)
def Cube(x):
return x * x * x
- @function.Defun(x=tf.float32, y=tf.float32)
+ @function.Defun(tf.float32, tf.float32)
def CubeXPlusY(x, y):
return Cube(x) + y
@@ -358,7 +353,7 @@ class UnrollLSTMTest(tf.test.TestCase):
if mode == "loop":
# Wraps the whole loop as a function.
- @function.Defun(w=tf.float32, i=tf.float32)
+ @function.Defun(tf.float32, tf.float32)
def LSTMLoop(w, i):
return Loop(cell, w, i)
@@ -368,27 +363,15 @@ class UnrollLSTMTest(tf.test.TestCase):
# Wraps 10 lstm steps into one function, and the whole loop
# into another calling the formers.
- # Groups 10 steps at a time):
- # TODO(zhifengc): Any way to make the syntax less hideous?
- @function.Defun(m=tf.float32,
- c=tf.float32,
- w=tf.float32,
- x0=tf.float32,
- x1=tf.float32,
- x2=tf.float32,
- x3=tf.float32,
- x4=tf.float32,
- x5=tf.float32,
- x6=tf.float32,
- x7=tf.float32,
- x8=tf.float32,
- x9=tf.float32)
- def Loop10(w, m, c, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9):
- for x in [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9]:
+ # Groups 10 steps at a time.
+ @function.Defun(tf.float32, tf.float32, tf.float32,
+ *([tf.float32] * 10))
+ def Loop10(w, m, c, *args):
+ for x in args:
m, c = cell(x, m, c, w)
return m, c
- @function.Defun(weights=tf.float32, inp=tf.float32)
+ @function.Defun(tf.float32, tf.float32)
def LSTMLoop10(weights, inp):
x = tf.unpack(inp, self.NUM_UNROLL)
m = tf.zeros_like(x[0])
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index f73ba8e980..863c9fcd1c 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -71,6 +71,24 @@ class PyOpTest(tf.test.TestCase):
y, = tf.py_func(literal, [x], [tf.float64])
self.assertAllClose(y.eval(), 1.0)
+ # returns a list
+ with self.test_session():
+ def list_func(x):
+ return [x, x + 1]
+ x = tf.constant(0.0, tf.float64)
+ y, z = tf.py_func(list_func, [x], [tf.float64] * 2)
+ self.assertAllClose(y.eval(), 0.0)
+ self.assertAllClose(z.eval(), 1.0)
+
+ # returns a tuple
+ with self.test_session():
+ def tuple_func(x):
+ return x, x + 1
+ x = tf.constant(0.0, tf.float64)
+ y, z = tf.py_func(tuple_func, [x], [tf.float64] * 2)
+ self.assertAllClose(y.eval(), 0.0)
+ self.assertAllClose(z.eval(), 1.0)
+
def testStrings(self):
def read_fixed_length_numpy_strings():
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index c69f7858dc..0c94eb8858 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -412,10 +412,10 @@ class AsyncReaderTest(tf.test.TestCase):
for i, d in enumerate(reversed(thread_data)):
fname = os.path.join(self.get_temp_dir(), "deadlock.%s.txt" % i)
with open(fname, "wb") as f:
- f.write("file-%s" % i)
+ f.write(("file-%s" % i).encode())
d.queue.enqueue_many([[fname]]).run()
d.thread.join()
- self.assertEqual([["file-%s" % i]], d.output)
+ self.assertEqual([[("file-%s" % i).encode()]], d.output)
@staticmethod
def _RunSessionAndSave(sess, args, output):
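
The added `.encode()` calls keep this test working under Python 3, where a file opened in binary mode only accepts `bytes`, so formatted strings must be encoded before writing and compared as `bytes` afterwards. A small illustration outside the test harness (the path is illustrative):

```python
# In Python 3, writing a str to a file opened with "wb" raises TypeError,
# so the formatted string is encoded to bytes first.
with open("/tmp/deadlock.0.txt", "wb") as f:
  f.write(("file-%s" % 0).encode())
```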
diff --git a/tensorflow/python/ops/constant_op.py b/tensorflow/python/ops/constant_op.py
index 7ad64f5c75..c15e8e4ed9 100644
--- a/tensorflow/python/ops/constant_op.py
+++ b/tensorflow/python/ops/constant_op.py
@@ -125,9 +125,8 @@ def constant(value, dtype=None, shape=None, name="Const"):
elements specified by `shape`, the last element in the list will be used
to fill the remaining entries.
- The argument `shape` is optional. If present, it specifies the dimensions
- of the resulting tensor. If not present, then the tensor is a scalar (0-D)
- if `value` is a scalar, or 1-D otherwise.
+ The argument `shape` is optional. If present, it specifies the dimensions of
+ the resulting tensor. If not present, the shape of `value` is used.
If the argument `dtype` is not specified, then the type is inferred from
the type of `value`.
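
A short sketch of the two cases described above; the values are illustrative:

```python
import tensorflow as tf

# No `shape`: the constant takes the shape of `value`, here (2, 2).
a = tf.constant([[1, 2], [3, 4]])

# Explicit `shape`: four values are laid out into a 2x3 tensor, and the
# last element is repeated to fill the remaining entries.
b = tf.constant([1, 2, 3, 4], shape=[2, 3])
```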
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index 7bac9a622a..07563e829a 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -59,9 +59,9 @@ class FuncRegistry(object):
if func is None:
raise ValueError("callback %s is not found" % token)
ret = func(*args)
- # Ensures that we return either a single np array or or a list of
- # np array.
- if isinstance(ret, list):
+ # Ensures that we return either a single numpy array or a list of numpy
+ # arrays.
+ if isinstance(ret, (tuple, list)):
ret = [np.array(x) for x in ret]
else:
ret = np.array(ret)
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index c3833a2bf6..c5f235cff9 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -1032,7 +1032,12 @@ class Saver(object):
Args:
sess: A `Session` to use to restore the parameters.
save_path: Path where parameters were previously saved.
+
+ Raises:
+ ValueError: If the given `save_path` does not point to a file.
"""
+ if not gfile.Glob(save_path):
+ raise ValueError("Restore called with invalid save path %s" % save_path)
sess.run(self.saver_def.restore_op_name,
{self.saver_def.filename_tensor_name: save_path})
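
With the new `gfile.Glob` check, an invalid checkpoint path now fails fast with a `ValueError` instead of erroring inside the restore op. A hedged sketch of how a caller might guard against it (the variable and path are illustrative):

```python
import tensorflow as tf

v = tf.Variable(0.0, name="v")
saver = tf.train.Saver()
with tf.Session() as sess:
  try:
    saver.restore(sess, "/tmp/does-not-exist.ckpt")
  except ValueError as e:
    # Raised because gfile.Glob finds nothing at the given save_path.
    print("Invalid checkpoint path:", e)
```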
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 85ac42421c..6023ab20f9 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -18,6 +18,7 @@ def tf_android_core_proto_sources():
"//tensorflow/core:framework/function.proto",
"//tensorflow/core:framework/graph.proto",
"//tensorflow/core:framework/kernel_def.proto",
+ "//tensorflow/core:framework/log_memory.proto",
"//tensorflow/core:framework/op_def.proto",
"//tensorflow/core:framework/step_stats.proto",
"//tensorflow/core:framework/summary.proto",
diff --git a/tensorflow/tools/docs/gen_cc_md.py b/tensorflow/tools/docs/gen_cc_md.py
index 8ad3eb80d2..389d8ef857 100644
--- a/tensorflow/tools/docs/gen_cc_md.py
+++ b/tensorflow/tools/docs/gen_cc_md.py
@@ -20,11 +20,9 @@ from __future__ import print_function
import os
import re
-import sys
from BeautifulSoup import BeautifulStoneSoup
-
-from tensorflow.python import flags
+import tensorflow as tf
ANCHOR_RE = re.compile(r'\W+')
@@ -95,11 +93,11 @@ write the graph to a file.
@@ThreadOptions
'''
-FLAGS = flags.FLAGS
-flags.DEFINE_string('src_dir', None,
- 'Directory containing the doxygen output.')
-flags.DEFINE_string('out_dir', None,
- 'Directory to which docs should be written.')
+FLAGS = tf.flags.FLAGS
+tf.flags.DEFINE_string('src_dir', None,
+ 'Directory containing the doxygen output.')
+tf.flags.DEFINE_string('out_dir', None,
+ 'Directory to which docs should be written.')
def member_definition(member_elt):
@@ -255,7 +253,6 @@ class Page(object):
self.name = soup.find('compoundname').text
print('Making page with name ' + self.name + ' (from ' + xml_path + ')')
members = soup('memberdef', prot='public')
- briefs = all_briefs(members)
fulls = all_fulls(members)
self.overview = page_overview(soup.find('compounddef'))
self.page_text = PAGE_TEMPLATE.format(
@@ -275,7 +272,8 @@ class Page(object):
return self.type
def get_md_filename(self):
- return self.get_type() + anchorize(self.get_short_name()) + '.md'
+ capitalized_type = self.get_type()[0].upper() + self.get_type()[1:]
+ return capitalized_type + anchorize(self.get_short_name()) + '.md'
def main(unused_argv):
@@ -303,4 +301,4 @@ def main(unused_argv):
return 0
if __name__ == '__main__':
- main(sys.argv)
+ tf.app.run()
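
The script now uses the standard `tf.flags`/`tf.app.run()` entry-point pattern rather than passing `sys.argv` to `main` directly. A minimal, self-contained sketch of that pattern, reusing the `src_dir` flag from the diff (the body of `main` is illustrative):

```python
import tensorflow as tf

FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('src_dir', None,
                       'Directory containing the doxygen output.')

def main(unused_argv):
  # Flags are parsed by tf.app.run() before main is invoked.
  print('Reading doxygen output from', FLAGS.src_dir)
  return 0

if __name__ == '__main__':
  tf.app.run()
```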
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index e4a37b0ffb..ab05f11366 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -10,8 +10,8 @@ def tf_workspace(path_prefix = ""):
native.new_http_archive(
name = "eigen_archive",
- url = "https://bitbucket.org/eigen/eigen/get/73a4995594c6.tar.gz",
- sha256 = "7d7e9b47788b643f65678c3239787edbd049a15bc274c7861bd1f7dc6e265dc6",
+ url = "https://bitbucket.org/eigen/eigen/get/017cff30cf74.tar.gz",
+ sha256 = "c06ce36dc8fd740336c5b169ad2fa3dd587f2e4b8168be50656cf2c849649c7c",
build_file = path_prefix + "eigen.BUILD",
)
diff --git a/third_party/eigen3/Eigen/Cholesky b/third_party/eigen3/Eigen/Cholesky
index 69ec33da61..b83f7d314d 100644
--- a/third_party/eigen3/Eigen/Cholesky
+++ b/third_party/eigen3/Eigen/Cholesky
@@ -1 +1 @@
-#include "eigen-eigen-73a4995594c6/Eigen/Cholesky"
+#include "eigen-eigen-017cff30cf74/Eigen/Cholesky"
diff --git a/third_party/eigen3/Eigen/Core b/third_party/eigen3/Eigen/Core
index 006358aea8..4561d530b6 100644
--- a/third_party/eigen3/Eigen/Core
+++ b/third_party/eigen3/Eigen/Core
@@ -1 +1 @@
-#include "eigen-eigen-73a4995594c6/Eigen/Core"
+#include "eigen-eigen-017cff30cf74/Eigen/Core"
diff --git a/third_party/eigen3/Eigen/Eigenvalues b/third_party/eigen3/Eigen/Eigenvalues
index 6977cfce35..24932d2234 100644
--- a/third_party/eigen3/Eigen/Eigenvalues
+++ b/third_party/eigen3/Eigen/Eigenvalues
@@ -1 +1 @@
-#include "eigen-eigen-73a4995594c6/Eigen/Eigenvalues"
+#include "eigen-eigen-017cff30cf74/Eigen/Eigenvalues"
diff --git a/third_party/eigen3/Eigen/LU b/third_party/eigen3/Eigen/LU
index d70cf3dc37..7670c26e6c 100644
--- a/third_party/eigen3/Eigen/LU
+++ b/third_party/eigen3/Eigen/LU
@@ -1 +1 @@
-#include "eigen-eigen-73a4995594c6/Eigen/LU"
+#include "eigen-eigen-017cff30cf74/Eigen/LU"
diff --git a/third_party/eigen3/Eigen/QR b/third_party/eigen3/Eigen/QR
index 8137cb7a2d..afa99c9b6c 100644
--- a/third_party/eigen3/Eigen/QR
+++ b/third_party/eigen3/Eigen/QR
@@ -1 +1 @@
-#include "eigen-eigen-73a4995594c6/Eigen/QR"
+#include "eigen-eigen-017cff30cf74/Eigen/QR"
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
index 7abd359696..82ad592a3d 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
@@ -1 +1 @@
-#include "eigen-eigen-73a4995594c6/unsupported/Eigen/CXX11/Tensor"
+#include "eigen-eigen-017cff30cf74/unsupported/Eigen/CXX11/Tensor"