aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2015-11-12 11:27:00 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2015-11-12 11:27:00 -0800
commit4dffee7f62d81ec9173aba1b0ef6b96e47f8037c (patch)
tree4b5c04b37afe45fdc5f1729252514a2770fbf1ab /tensorflow
parentf2102f4e2c1c87f1d1bf9ab856a2849c54478760 (diff)
TensorFlow: Minor updates to docs, BUILD, GPU config / perf, etc.
Changes: - Updates to op documentation and index by Josh - More changes to BUILD files for python 3 support by @girving - Fix to Eigen to use DenseIndex everywhere by @jiayq - Enable configuration for cuda compute capability by @zheng-xq, including updates to docs. - Route aggregation method through optimizer by schuster - Updates to install instructions for bazel 0.1.1. Base CL: 107702099
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/common_runtime/executor.cc25
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc57
-rw-r--r--tensorflow/core/framework/rendezvous.cc2
-rw-r--r--tensorflow/core/framework/tensor_slice.h11
-rw-r--r--tensorflow/core/kernels/concat_op_gpu.cu.cc4
-rw-r--r--tensorflow/core/kernels/fifo_queue.cc200
-rw-r--r--tensorflow/core/kernels/fifo_queue.h76
-rw-r--r--tensorflow/core/kernels/lrn_op.cc4
-rw-r--r--tensorflow/core/kernels/pooling_ops_common.h2
-rw-r--r--tensorflow/core/kernels/queue_base.cc265
-rw-r--r--tensorflow/core/kernels/queue_base.h78
-rw-r--r--tensorflow/core/kernels/random_shuffle_queue_op.cc274
-rw-r--r--tensorflow/core/kernels/slice_op.cc8
-rw-r--r--tensorflow/core/kernels/slice_op.h4
-rw-r--r--tensorflow/core/kernels/split_op.cc10
-rw-r--r--tensorflow/core/kernels/split_op.h8
-rw-r--r--tensorflow/core/kernels/split_op_cpu.cc4
-rw-r--r--tensorflow/core/kernels/split_op_gpu.cu.cc4
-rw-r--r--tensorflow/core/kernels/tile_ops.cc48
-rw-r--r--tensorflow/core/kernels/tile_ops.h13
-rw-r--r--tensorflow/core/kernels/typed_queue.h54
-rw-r--r--tensorflow/core/kernels/unpack_op.cc4
-rw-r--r--tensorflow/g3doc/api_docs/cc/index.md29
-rw-r--r--tensorflow/g3doc/api_docs/python/nn.md2
-rw-r--r--tensorflow/g3doc/get_started/os_setup.md18
-rw-r--r--tensorflow/g3doc/resources/index.md5
-rw-r--r--tensorflow/models/embedding/BUILD4
-rw-r--r--tensorflow/models/image/alexnet/BUILD1
-rw-r--r--tensorflow/models/image/cifar10/BUILD6
-rw-r--r--tensorflow/models/image/mnist/BUILD2
-rw-r--r--tensorflow/models/rnn/BUILD8
-rw-r--r--tensorflow/models/rnn/ptb/BUILD3
-rw-r--r--tensorflow/models/rnn/translate/BUILD4
-rw-r--r--tensorflow/python/BUILD42
-rw-r--r--tensorflow/python/framework/tensor_shape.py30
-rw-r--r--tensorflow/python/framework/tensor_shape_div_test.py24
-rw-r--r--tensorflow/python/framework/tensor_shape_test.py6
-rw-r--r--tensorflow/python/ops/array_ops.py2
-rw-r--r--tensorflow/python/training/optimizer.py17
-rw-r--r--tensorflow/python/training/optimizer_test.py54
-rw-r--r--tensorflow/tensorboard/BUILD4
-rw-r--r--tensorflow/tensorflow.bzl3
-rw-r--r--tensorflow/tools/docker/BUILD1
-rw-r--r--tensorflow/tools/pip_package/BUILD1
44 files changed, 734 insertions, 687 deletions
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 7f2473f93b..2d5a63ac92 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -294,6 +294,31 @@ Status ExecutorImpl::InferAllocAttr(
const DeviceNameUtils::ParsedName& local_dev_name,
AllocatorAttributes* attr) {
Status s;
+ // Note that it's possible for *n to be a Recv and *dst to be a Send,
+ // so these two cases are not mutually exclusive.
+ if (IsRecv(n)) {
+ string src_name;
+ s = GetNodeAttr(n->def(), "send_device", &src_name);
+ if (!s.ok()) return s;
+ DeviceNameUtils::ParsedName parsed_src_name;
+ if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) {
+ s = errors::Internal("Bad send_device attr '", src_name, "' in node ",
+ n->name());
+ return s;
+ }
+ if (!DeviceNameUtils::IsSameAddressSpace(parsed_src_name, local_dev_name)) {
+ // Value is going to be the sink of an RPC.
+ attr->set_nic_compatible(true);
+ VLOG(2) << "node " << n->name() << " is the sink of an RPC in";
+ } else if (local_dev_name.type == "CPU" && parsed_src_name.type == "GPU") {
+ // Value is going to be the sink of a local DMA from GPU to CPU.
+ attr->set_gpu_compatible(true);
+ VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy";
+ } else {
+ VLOG(2) << "default alloc case local type " << local_dev_name.type
+ << " remote type " << parsed_src_name.type;
+ }
+ }
if (IsSend(dst)) {
string dst_name;
s = GetNodeAttr(dst->def(), "recv_device", &dst_name);
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 65174135d8..b6bae7c0f8 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -8,6 +8,7 @@
#include <stdlib.h>
#include <string.h>
+#include <algorithm>
//#include "base/commandlineflags.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
@@ -590,10 +591,50 @@ static int GetMinGPUMultiprocessorCount() {
return kDefaultMinGPUMultiprocessorCount;
}
+namespace {
+
+struct CudaVersion {
+ // Initialize from version_name in the form of "3.5"
+ explicit CudaVersion(const std::string& version_name) {
+ size_t dot_pos = version_name.find('.');
+ CHECK(dot_pos != string::npos);
+ string major_str = version_name.substr(0, dot_pos);
+ CHECK(strings::safe_strto32(major_str.c_str(), &major_part));
+ string minor_str = version_name.substr(dot_pos + 1);
+ CHECK(strings::safe_strto32(minor_str.c_str(), &minor_part));
+ }
+ CudaVersion() {}
+ bool operator<(const CudaVersion& other) const {
+ if (this->major_part != other.major_part) {
+ return this->major_part < other.major_part;
+ }
+ return this->minor_part < other.minor_part;
+ }
+ friend std::ostream& operator<<(std::ostream& os,
+ const CudaVersion& version) {
+ os << version.major_part << "." << version.minor_part;
+ return os;
+ }
+ int major_part = -1;
+ int minor_part = -1;
+};
+
+// "configure" uses the specific name to substitute the following string.
+// If you change it, make sure you modify "configure" as well.
+std::vector<CudaVersion> supported_cuda_compute_capabilities = {
+ CudaVersion("3.5"), CudaVersion("5.2")};
+
+} // namespace
+
void BaseGPUDeviceFactory::GetValidDeviceIds(std::vector<int>* ids) {
auto gpu_manager = GPUMachineManager();
int min_gpu_core_count = GetMinGPUMultiprocessorCount();
if (gpu_manager) {
+ CHECK(!supported_cuda_compute_capabilities.empty());
+ CudaVersion min_supported_capability =
+ *std::min_element(supported_cuda_compute_capabilities.begin(),
+ supported_cuda_compute_capabilities.end());
+
auto visible_device_count = gpu_manager->VisibleDeviceCount();
for (int i = 0; i < gpu_manager->VisibleDeviceCount(); ++i) {
auto exec_status = gpu_manager->ExecutorForDevice(i);
@@ -602,17 +643,19 @@ void BaseGPUDeviceFactory::GetValidDeviceIds(std::vector<int>* ids) {
}
gpu::StreamExecutor* se = exec_status.ValueOrDie();
const gpu::DeviceDescription& desc = se->GetDeviceDescription();
- int major, minor;
- if (!desc.cuda_compute_capability(&major, &minor)) {
+ CudaVersion device_capability;
+ if (!desc.cuda_compute_capability(&device_capability.major_part,
+ &device_capability.minor_part)) {
continue;
}
- // Only consider GPUs with compute capability >= 3.5 (Kepler or
- // higher)
- if (major < 3 || (major == 3 && minor < 5)) {
+ // Only GPUs with no less than the minimum supported compute capability is
+ // accepted.
+ if (device_capability < min_supported_capability) {
LOG(INFO) << "Ignoring gpu device "
<< "(" << GetShortDeviceDescription(i, desc) << ") "
- << "with Cuda compute capability " << major << "." << minor
- << ". The minimum required Cuda capability is 3.5.";
+ << "with Cuda compute capability " << device_capability
+ << ". The minimum required Cuda capability is "
+ << min_supported_capability << ".";
continue;
}
diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc
index 7f551ea65f..205098d58f 100644
--- a/tensorflow/core/framework/rendezvous.cc
+++ b/tensorflow/core/framework/rendezvous.cc
@@ -188,9 +188,9 @@ class LocalRendezvousImpl : public Rendezvous {
// message arrives.
Item* item = new Item;
item->waiter = done;
+ item->recv_alloc_attrs = recv_args.alloc_attrs;
if (recv_args.device_context) {
item->recv_dev_context = recv_args.device_context;
- item->recv_alloc_attrs = recv_args.alloc_attrs;
item->recv_dev_context->Ref();
}
CHECK(table_.insert({key, item}).second);
diff --git a/tensorflow/core/framework/tensor_slice.h b/tensorflow/core/framework/tensor_slice.h
index 8e2f108c3f..62e1543789 100644
--- a/tensorflow/core/framework/tensor_slice.h
+++ b/tensorflow/core/framework/tensor_slice.h
@@ -98,9 +98,10 @@ class TensorSlice {
// We allow NDIMS to be greater than dims(), in which case we will pad the
// higher dimensions with trivial dimensions.
template <int NDIMS>
- void FillIndicesAndSizes(const TensorShape& shape,
- Eigen::DSizes<ptrdiff_t, NDIMS>* indices,
- Eigen::DSizes<ptrdiff_t, NDIMS>* sizes) const;
+ void FillIndicesAndSizes(
+ const TensorShape& shape,
+ Eigen::DSizes<Eigen::DenseIndex, NDIMS>* indices,
+ Eigen::DSizes<Eigen::DenseIndex, NDIMS>* sizes) const;
// Interaction with other TensorSlices.
@@ -162,8 +163,8 @@ class TensorSlice {
template <int NDIMS>
void TensorSlice::FillIndicesAndSizes(
- const TensorShape& shape, Eigen::DSizes<ptrdiff_t, NDIMS>* indices,
- Eigen::DSizes<ptrdiff_t, NDIMS>* sizes) const {
+ const TensorShape& shape, Eigen::DSizes<Eigen::DenseIndex, NDIMS>* indices,
+ Eigen::DSizes<Eigen::DenseIndex, NDIMS>* sizes) const {
CHECK_EQ(shape.dims(), dims()) << "Incompatible dimensions between shape "
<< "slices: shape = " << shape.DebugString()
<< ", slice = " << DebugString();
diff --git a/tensorflow/core/kernels/concat_op_gpu.cu.cc b/tensorflow/core/kernels/concat_op_gpu.cu.cc
index d8ce6bd85d..aed36dccef 100644
--- a/tensorflow/core/kernels/concat_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/concat_op_gpu.cu.cc
@@ -18,9 +18,9 @@ void ConcatGPU(const GPUDevice& d,
const std::vector<
std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs,
typename TTypes<T, 2>::Matrix* output) {
- Eigen::array<ptrdiff_t, 2> offset(0, 0);
+ Eigen::array<Eigen::DenseIndex, 2> offset(0, 0);
for (int i = 0; i < inputs.size(); ++i) {
- Eigen::array<ptrdiff_t, 2> size = inputs[i]->dimensions();
+ Eigen::array<Eigen::DenseIndex, 2> size = inputs[i]->dimensions();
output->slice(offset, size).device(d) = *inputs[i];
offset[1] += size[1];
}
diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc
index 122c3c1c81..9828b460da 100644
--- a/tensorflow/core/kernels/fifo_queue.cc
+++ b/tensorflow/core/kernels/fifo_queue.cc
@@ -17,74 +17,7 @@ namespace tensorflow {
FIFOQueue::FIFOQueue(int capacity, const DataTypeVector& component_dtypes,
const std::vector<TensorShape>& component_shapes,
const string& name)
- : QueueBase(component_dtypes, component_shapes, name),
- capacity_(capacity),
- closed_(false) {}
-
-Status FIFOQueue::Initialize() {
- if (component_dtypes_.empty()) {
- return errors::InvalidArgument("Empty component types for queue ", name_);
- }
- if (!component_shapes_.empty() &&
- component_dtypes_.size() != component_shapes_.size()) {
- return errors::InvalidArgument("Different number of component types (",
- component_dtypes_.size(), ") vs. shapes (",
- component_shapes_.size(), ").");
- }
-
- mutex_lock lock(mu_);
- queues_.reserve(num_components());
- for (int i = 0; i < num_components(); ++i) {
- queues_.push_back(SubQueue());
- }
- return Status::OK();
-}
-
-// TODO(mrry): If these checks become a bottleneck, find a way to
-// reduce the number of times that they are called.
-Status FIFOQueue::ValidateTuple(const Tuple& tuple) {
- TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
- if (specified_shapes()) {
- for (size_t i = 0; i < tuple.size(); ++i) {
- if (!tuple[i].shape().IsSameSize(component_shapes_[i])) {
- return errors::InvalidArgument(
- "Shape mismatch in tuple component ", i, ". Expected ",
- component_shapes_[i].ShortDebugString(), ", got ",
- tuple[i].shape().ShortDebugString());
- }
- }
- }
- return Status::OK();
-}
-
-// TODO(mrry): If these checks become a bottleneck, find a way to
-// reduce the number of times that they are called.
-Status FIFOQueue::ValidateManyTuple(const Tuple& tuple) {
- TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
- const int64 batch_size = tuple[0].dim_size(0);
- if (specified_shapes()) {
- for (size_t i = 0; i < tuple.size(); ++i) {
- // Expected shape is [batch_size] + component_shapes_[i]
- const TensorShape expected_shape = ManyOutShape(i, batch_size);
- if (!tuple[i].shape().IsSameSize(expected_shape)) {
- return errors::InvalidArgument(
- "Shape mismatch in tuple component ", i, ". Expected ",
- expected_shape.ShortDebugString(), ", got ",
- tuple[i].shape().ShortDebugString());
- }
- }
- } else {
- for (size_t i = 1; i < tuple.size(); ++i) {
- if (tuple[i].dim_size(0) != batch_size) {
- return errors::InvalidArgument(
- "All input tensors must have the same size in the 0th ",
- "dimension. Component ", i, " has ", tuple[i].dim_size(0),
- ", and should have ", batch_size);
- }
- }
- }
- return Status::OK();
-}
+ : TypedQueue(capacity, component_dtypes, component_shapes, name) {}
void FIFOQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
DCHECK_GT(queues_[0].size(), 0);
@@ -95,113 +28,6 @@ void FIFOQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
}
}
-void FIFOQueue::Cancel(Action action, CancellationToken token) {
- DoneCallback callback = nullptr;
- {
- mutex_lock lock(mu_);
- std::deque<Attempt>* attempts =
- action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
-
- for (Attempt& attempt : *attempts) {
- if (attempt.cancellation_token == token) {
- attempt.is_cancelled = true;
- if (action == kEnqueue) {
- attempt.context->SetStatus(
- errors::Cancelled("Enqueue operation was cancelled"));
- } else {
- attempt.context->SetStatus(
- errors::Cancelled("Dequeue operation was cancelled"));
- }
- std::swap(callback, attempt.done_callback);
- break;
- }
- }
- }
- if (callback) {
- callback();
- FlushUnlocked();
- }
-}
-
-void FIFOQueue::CloseAndCancel() {
- std::vector<DoneCallback> callbacks;
- {
- mutex_lock lock(mu_);
- closed_ = true;
- for (Attempt& attempt : enqueue_attempts_) {
- attempt.is_cancelled = true;
- attempt.context->SetStatus(
- errors::Cancelled("Enqueue operation was cancelled"));
- callbacks.emplace_back(std::move(attempt.done_callback));
- }
- }
- for (const DoneCallback& callback : callbacks) {
- callback();
- }
- FlushUnlocked();
-}
-
-bool FIFOQueue::TryAttemptLocked(Action action,
- std::vector<CleanUp>* clean_up) {
- std::deque<Attempt>* attempts =
- action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
-
- bool progress = false;
- bool done = false;
- while (!done && !attempts->empty()) {
- if (attempts->front().is_cancelled) {
- if (action == kEnqueue) {
- LOG(INFO) << "Skipping cancelled enqueue attempt";
- } else {
- LOG(INFO) << "Skipping cancelled dequeue attempt";
- }
- attempts->pop_front();
- } else {
- Attempt* cur_attempt = &attempts->front();
- switch (cur_attempt->run_callback(cur_attempt)) {
- case kNoProgress:
- done = true;
- break;
- case kProgress:
- done = true;
- progress = true;
- break;
- case kComplete:
- progress = true;
- clean_up->emplace_back(std::move(cur_attempt->done_callback),
- cur_attempt->cancellation_token,
- cur_attempt->context->cancellation_manager());
- attempts->pop_front();
- break;
- }
- }
- }
- return progress;
-}
-
-void FIFOQueue::FlushUnlocked() {
- std::vector<CleanUp> clean_up;
- Ref();
- {
- mutex_lock lock(mu_);
- bool changed;
- do {
- changed = TryAttemptLocked(kEnqueue, &clean_up);
- changed = TryAttemptLocked(kDequeue, &clean_up) || changed;
- } while (changed);
- }
- Unref();
- for (const auto& to_clean : clean_up) {
- if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
- // NOTE(mrry): We can safely ignore the return value of
- // DeregisterCallback because the mutex mu_ ensures that the
- // cleanup action only executes once.
- to_clean.cm->DeregisterCallback(to_clean.to_deregister);
- }
- to_clean.finished();
- }
-}
-
void FIFOQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
DoneCallback callback) {
CancellationManager* cm = ctx->cancellation_manager();
@@ -484,30 +310,6 @@ void FIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
}
}
-void FIFOQueue::Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
- DoneCallback callback) {
- if (cancel_pending_enqueues) {
- CloseAndCancel();
- callback();
- } else {
- {
- mutex_lock lock(mu_);
- enqueue_attempts_.emplace_back(
- 0, callback, ctx, CancellationManager::kInvalidToken,
- [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (closed_) {
- attempt->context->SetStatus(errors::Aborted(
- "FIFOQueue '", name_, "' is already closed."));
- } else {
- closed_ = true;
- }
- return kComplete;
- });
- }
- FlushUnlocked();
- }
-}
-
Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "FIFOQueue"));
TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h
index e9fe5f34a4..4fc0ed75d2 100644
--- a/tensorflow/core/kernels/fifo_queue.h
+++ b/tensorflow/core/kernels/fifo_queue.h
@@ -6,24 +6,21 @@
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/queue_base.h"
+#include "tensorflow/core/kernels/typed_queue.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/public/tensor.h"
#include "tensorflow/core/public/tensor_shape.h"
namespace tensorflow {
-class FIFOQueue : public QueueBase {
+class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > {
public:
FIFOQueue(int32 capacity, const DataTypeVector& component_dtypes,
const std::vector<TensorShape>& component_shapes,
const string& name);
- Status Initialize(); // Must be called before any other method.
// Implementations of QueueInterface methods --------------------------------
- Status ValidateTuple(const Tuple& tuple) override;
- Status ValidateManyTuple(const Tuple& tuple) override;
void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
DoneCallback callback) override;
void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
@@ -31,8 +28,6 @@ class FIFOQueue : public QueueBase {
void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
void TryDequeueMany(int num_elements, OpKernelContext* ctx,
CallbackWithTuple callback) override;
- void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
- DoneCallback callback) override;
Status MatchesNodeDef(const NodeDef& node_def) override;
int32 size() override {
@@ -40,80 +35,13 @@ class FIFOQueue : public QueueBase {
return queues_[0].size();
}
- int32 capacity() const { return capacity_; }
-
private:
- enum Action { kEnqueue, kDequeue };
-
~FIFOQueue() override {}
- TensorShape ManyOutShape(int i, int64 batch_size) {
- TensorShape shape({batch_size});
- shape.AppendShape(component_shapes_[i]);
- return shape;
- }
-
// Helper for dequeuing a single element from queues_.
void DequeueLocked(OpKernelContext* ctx, Tuple* tuple)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
- void Cancel(Action action, CancellationToken token);
-
- // Helper for cancelling all pending Enqueue(Many) operations when
- // Close is called with cancel_pending_enqueues.
- void CloseAndCancel();
-
- // Tries to enqueue/dequeue (or close) based on whatever is at the
- // front of enqueue_attempts_/dequeue_attempts_. Appends to
- // *finished the callback for any finished attempt (so it may be
- // called once mu_ is released). Returns true if any progress was
- // made.
- struct CleanUp {
- CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
- : finished(f), to_deregister(ct), cm(cm) {}
- DoneCallback finished;
- CancellationToken to_deregister;
- CancellationManager* cm;
- };
- bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up)
- EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
- // Tries to make progress on the enqueues or dequeues at the front
- // of the *_attempts_ queues.
- void FlushUnlocked();
-
- const int32 capacity_;
-
- mutex mu_;
- typedef std::deque<PersistentTensor> SubQueue;
- std::vector<SubQueue> queues_ GUARDED_BY(mu_);
- bool closed_ GUARDED_BY(mu_);
-
- enum RunResult { kNoProgress, kProgress, kComplete };
- struct Attempt;
- typedef std::function<RunResult(Attempt*)> RunCallback;
- struct Attempt {
- int32 elements_requested;
- DoneCallback done_callback; // must be run outside mu_
- OpKernelContext* context;
- CancellationToken cancellation_token;
- RunCallback run_callback; // must be run while holding mu_
- bool is_cancelled;
- Tuple tuple;
-
- Attempt(int32 elements_requested, DoneCallback done_callback,
- OpKernelContext* context, CancellationToken cancellation_token,
- RunCallback run_callback)
- : elements_requested(elements_requested),
- done_callback(done_callback),
- context(context),
- cancellation_token(cancellation_token),
- run_callback(run_callback),
- is_cancelled(false) {}
- };
- std::deque<Attempt> enqueue_attempts_ GUARDED_BY(mu_);
- std::deque<Attempt> dequeue_attempts_ GUARDED_BY(mu_);
-
static Status GetElementComponentFromBatch(const Tuple& tuple, int index,
int component,
OpKernelContext* ctx,
diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc
index 4b67304a37..bb2657085e 100644
--- a/tensorflow/core/kernels/lrn_op.cc
+++ b/tensorflow/core/kernels/lrn_op.cc
@@ -23,8 +23,8 @@ static void GetBandMatrix(int depth, int64 depth_radius,
for (int row = 0; row < depth; ++row) {
const int begin = std::max<int>(0, row - depth_radius);
const int end = std::min<int64>(depth, row + depth_radius + 1);
- Eigen::DSizes<ptrdiff_t, 2> start(row, begin);
- Eigen::DSizes<ptrdiff_t, 2> sizes(1, end - begin);
+ Eigen::DSizes<Eigen::DenseIndex, 2> start(row, begin);
+ Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, end - begin);
result->slice(start, sizes).setConstant(1.0f);
}
}
diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h
index 5bf44b6e40..d086b6850e 100644
--- a/tensorflow/core/kernels/pooling_ops_common.h
+++ b/tensorflow/core/kernels/pooling_ops_common.h
@@ -243,7 +243,7 @@ void SpatialAvgPool(OpKernelContext* context, Tensor* output,
std::min(wpad / params.col_stride + 1, params.out_width);
const int in_offset =
(b * params.tensor_in_rows + h) * params.tensor_in_cols + w;
- Eigen::DSizes<ptrdiff_t, 2> in_indices(0, in_offset);
+ Eigen::DSizes<Eigen::DenseIndex, 2> in_indices(0, in_offset);
for (int ph = h_start; ph < h_end; ++ph) {
for (int pw = w_start; pw < w_end; ++pw) {
const int out_offset =
diff --git a/tensorflow/core/kernels/queue_base.cc b/tensorflow/core/kernels/queue_base.cc
index 4217b9ce86..d0f47505c4 100644
--- a/tensorflow/core/kernels/queue_base.cc
+++ b/tensorflow/core/kernels/queue_base.cc
@@ -46,52 +46,14 @@ Status HandleElementToSlice(const Tensor& element, Tensor* parent, int index) {
} // namespace
-// static
-Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
- int index) {
-#define HANDLE_TYPE(DT) \
- if (parent.dtype() == DT) { \
- TF_RETURN_IF_ERROR(HandleSliceToElement<DT>(parent, element, index)); \
- return Status::OK(); \
- }
- HANDLE_TYPE(DT_FLOAT);
- HANDLE_TYPE(DT_DOUBLE);
- HANDLE_TYPE(DT_INT32);
- HANDLE_TYPE(DT_UINT8);
- HANDLE_TYPE(DT_INT16);
- HANDLE_TYPE(DT_INT8);
- HANDLE_TYPE(DT_STRING);
- HANDLE_TYPE(DT_INT64);
-#undef HANDLE_TYPE
- return errors::Unimplemented("Unhandled data type: ", parent.dtype());
-}
-
-// static
-Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent,
- int index) {
-#define HANDLE_TYPE(DT) \
- if (element.dtype() == DT) { \
- TF_RETURN_IF_ERROR(HandleElementToSlice<DT>(element, parent, index)); \
- return Status::OK(); \
- }
- HANDLE_TYPE(DT_FLOAT);
- HANDLE_TYPE(DT_DOUBLE);
- HANDLE_TYPE(DT_INT32);
- HANDLE_TYPE(DT_UINT8);
- HANDLE_TYPE(DT_INT16);
- HANDLE_TYPE(DT_INT8);
- HANDLE_TYPE(DT_STRING);
- HANDLE_TYPE(DT_INT64);
-#undef HANDLE_TYPE
- return errors::Unimplemented("Unhandled data type: ", element.dtype());
-}
-
-QueueBase::QueueBase(const DataTypeVector& component_dtypes,
+QueueBase::QueueBase(int32 capacity, const DataTypeVector& component_dtypes,
const std::vector<TensorShape>& component_shapes,
const string& name)
- : component_dtypes_(component_dtypes),
+ : capacity_(capacity),
+ component_dtypes_(component_dtypes),
component_shapes_(component_shapes),
- name_(name) {}
+ name_(name),
+ closed_(false) {}
Status QueueBase::ValidateTupleCommon(const Tuple& tuple) const {
if (tuple.size() != static_cast<size_t>(num_components())) {
@@ -172,4 +134,221 @@ Status QueueBase::MatchesNodeDefShapes(const NodeDef& node_def) const {
return Status::OK();
}
+// TODO(mrry): If these checks become a bottleneck, find a way to
+// reduce the number of times that they are called.
+Status QueueBase::ValidateTuple(const Tuple& tuple) {
+ TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
+ if (specified_shapes()) {
+ for (size_t i = 0; i < tuple.size(); ++i) {
+ if (!tuple[i].shape().IsSameSize(component_shapes_[i])) {
+ return errors::InvalidArgument(
+ "Shape mismatch in tuple component ", i, ". Expected ",
+ component_shapes_[i].ShortDebugString(), ", got ",
+ tuple[i].shape().ShortDebugString());
+ }
+ }
+ }
+ return Status::OK();
+}
+
+// TODO(mrry): If these checks become a bottleneck, find a way to
+// reduce the number of times that they are called.
+Status QueueBase::ValidateManyTuple(const Tuple& tuple) {
+ TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
+ const int64 batch_size = tuple[0].dim_size(0);
+ if (specified_shapes()) {
+ for (size_t i = 0; i < tuple.size(); ++i) {
+ // Expected shape is [batch_size] + component_shapes_[i]
+ const TensorShape expected_shape = ManyOutShape(i, batch_size);
+ if (!tuple[i].shape().IsSameSize(expected_shape)) {
+ return errors::InvalidArgument(
+ "Shape mismatch in tuple component ", i, ". Expected ",
+ expected_shape.ShortDebugString(), ", got ",
+ tuple[i].shape().ShortDebugString());
+ }
+ }
+ } else {
+ for (size_t i = 1; i < tuple.size(); ++i) {
+ if (tuple[i].dim_size(0) != batch_size) {
+ return errors::InvalidArgument(
+ "All input tensors must have the same size in the 0th ",
+ "dimension. Component ", i, " has ", tuple[i].dim_size(0),
+ ", and should have ", batch_size);
+ }
+ }
+ }
+ return Status::OK();
+}
+
+void QueueBase::Cancel(Action action, CancellationToken token) {
+ DoneCallback callback = nullptr;
+ {
+ mutex_lock lock(mu_);
+ std::deque<Attempt>* attempts =
+ action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
+
+ for (Attempt& attempt : *attempts) {
+ if (attempt.cancellation_token == token) {
+ attempt.is_cancelled = true;
+ if (action == kEnqueue) {
+ attempt.context->SetStatus(
+ errors::Cancelled("Enqueue operation was cancelled"));
+ } else {
+ attempt.context->SetStatus(
+ errors::Cancelled("Dequeue operation was cancelled"));
+ }
+ std::swap(callback, attempt.done_callback);
+ break;
+ }
+ }
+ }
+ if (callback) {
+ callback();
+ FlushUnlocked();
+ }
+}
+
+void QueueBase::CloseAndCancel() {
+ std::vector<DoneCallback> callbacks;
+ {
+ mutex_lock lock(mu_);
+ closed_ = true;
+ for (Attempt& attempt : enqueue_attempts_) {
+ attempt.is_cancelled = true;
+ attempt.context->SetStatus(
+ errors::Cancelled("Enqueue operation was cancelled"));
+ callbacks.emplace_back(std::move(attempt.done_callback));
+ }
+ }
+ for (const DoneCallback& callback : callbacks) {
+ callback();
+ }
+ FlushUnlocked();
+}
+
+void QueueBase::Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
+ DoneCallback callback) {
+ if (cancel_pending_enqueues) {
+ CloseAndCancel();
+ callback();
+ } else {
+ {
+ mutex_lock lock(mu_);
+ enqueue_attempts_.emplace_back(
+ 0, callback, ctx, CancellationManager::kInvalidToken,
+ [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (closed_) {
+ attempt->context->SetStatus(
+ errors::Aborted("Queue '", name_, "' is already closed."));
+ } else {
+ closed_ = true;
+ }
+ return kComplete;
+ });
+ }
+ FlushUnlocked();
+ }
+}
+
+bool QueueBase::TryAttemptLocked(Action action,
+ std::vector<CleanUp>* clean_up) {
+ std::deque<Attempt>* attempts =
+ action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
+
+ bool progress = false;
+ bool done = false;
+ while (!done && !attempts->empty()) {
+ if (attempts->front().is_cancelled) {
+ if (action == kEnqueue) {
+ LOG(INFO) << "Skipping cancelled enqueue attempt";
+ } else {
+ LOG(INFO) << "Skipping cancelled dequeue attempt";
+ }
+ attempts->pop_front();
+ } else {
+ Attempt* cur_attempt = &attempts->front();
+ switch (cur_attempt->run_callback(cur_attempt)) {
+ case kNoProgress:
+ done = true;
+ break;
+ case kProgress:
+ done = true;
+ progress = true;
+ break;
+ case kComplete:
+ progress = true;
+ clean_up->emplace_back(std::move(cur_attempt->done_callback),
+ cur_attempt->cancellation_token,
+ cur_attempt->context->cancellation_manager());
+ attempts->pop_front();
+ break;
+ }
+ }
+ }
+ return progress;
+}
+
+void QueueBase::FlushUnlocked() {
+ std::vector<CleanUp> clean_up;
+ Ref();
+ {
+ mutex_lock lock(mu_);
+ bool changed;
+ do {
+ changed = TryAttemptLocked(kEnqueue, &clean_up);
+ changed = TryAttemptLocked(kDequeue, &clean_up) || changed;
+ } while (changed);
+ }
+ Unref();
+ for (const auto& to_clean : clean_up) {
+ if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
+ // NOTE(mrry): We can safely ignore the return value of
+ // DeregisterCallback because the mutex mu_ ensures that the
+ // cleanup action only executes once.
+ to_clean.cm->DeregisterCallback(to_clean.to_deregister);
+ }
+ to_clean.finished();
+ }
+}
+
+// Static method
+Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
+ int index) {
+#define HANDLE_TYPE(DT) \
+ if (parent.dtype() == DT) { \
+ TF_RETURN_IF_ERROR(HandleSliceToElement<DT>(parent, element, index)); \
+ return Status::OK(); \
+ }
+ HANDLE_TYPE(DT_FLOAT);
+ HANDLE_TYPE(DT_DOUBLE);
+ HANDLE_TYPE(DT_INT32);
+ HANDLE_TYPE(DT_UINT8);
+ HANDLE_TYPE(DT_INT16);
+ HANDLE_TYPE(DT_INT8);
+ HANDLE_TYPE(DT_STRING);
+ HANDLE_TYPE(DT_INT64);
+#undef HANDLE_TYPE
+ return errors::Unimplemented("Unhandled data type: ", parent.dtype());
+}
+
+// Static method
+Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent,
+ int index) {
+#define HANDLE_TYPE(DT) \
+ if (element.dtype() == DT) { \
+ TF_RETURN_IF_ERROR(HandleElementToSlice<DT>(element, parent, index)); \
+ return Status::OK(); \
+ }
+ HANDLE_TYPE(DT_FLOAT);
+ HANDLE_TYPE(DT_DOUBLE);
+ HANDLE_TYPE(DT_INT32);
+ HANDLE_TYPE(DT_UINT8);
+ HANDLE_TYPE(DT_INT16);
+ HANDLE_TYPE(DT_INT8);
+ HANDLE_TYPE(DT_STRING);
+ HANDLE_TYPE(DT_INT64);
+#undef HANDLE_TYPE
+ return errors::Unimplemented("Unhandled data type: ", element.dtype());
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h
index 4897102974..d32d98b7eb 100644
--- a/tensorflow/core/kernels/queue_base.h
+++ b/tensorflow/core/kernels/queue_base.h
@@ -1,6 +1,9 @@
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
+#include <deque>
+#include <vector>
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/types.h"
@@ -11,7 +14,7 @@
namespace tensorflow {
-// Functionality common to QueueInterface implementations.
+// Functionality common to asynchronous QueueInterface implementations.
class QueueBase : public QueueInterface {
public:
// As a possible value of 'capacity'.
@@ -23,7 +26,7 @@ class QueueBase : public QueueInterface {
// which must either be empty (if the shapes are not specified) or
// or have the same size as component_dtypes.
// name: A name to use for the queue.
- QueueBase(const DataTypeVector& component_dtypes,
+ QueueBase(int32 capacity, const DataTypeVector& component_dtypes,
const std::vector<TensorShape>& component_shapes,
const string& name);
@@ -32,12 +35,36 @@ class QueueBase : public QueueInterface {
return component_dtypes_;
}
+ Status ValidateTuple(const Tuple& tuple) override;
+ Status ValidateManyTuple(const Tuple& tuple) override;
+
+ void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
+ DoneCallback callback) override;
+
// Other public methods -----------------------------------------------------
const std::vector<TensorShape>& component_shapes() const {
return component_shapes_;
}
+ int32 capacity() const { return capacity_; }
+
protected:
+ enum Action { kEnqueue, kDequeue };
+ enum RunResult { kNoProgress, kProgress, kComplete };
+
+ // Tries to enqueue/dequeue (or close) based on whatever is at the
+ // front of enqueue_attempts_/dequeue_attempts_. Appends to
+ // *finished the callback for any finished attempt (so it may be
+ // called once mu_ is released). Returns true if any progress was
+ // made.
+ struct CleanUp {
+ CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
+ : finished(f), to_deregister(ct), cm(cm) {}
+ DoneCallback finished;
+ CancellationToken to_deregister;
+ CancellationManager* cm;
+ };
+
// Returns the number of components in a queue-element tuple.
int32 num_components() const { return component_dtypes_.size(); }
@@ -48,6 +75,12 @@ class QueueBase : public QueueInterface {
// Code common to Validate*Tuple().
Status ValidateTupleCommon(const Tuple& tuple) const;
+ TensorShape ManyOutShape(int i, int64 batch_size) {
+ TensorShape shape({batch_size});
+ shape.AppendShape(component_shapes_[i]);
+ return shape;
+ }
+
// Copies the index^th slice (in the first dimension) of parent into element.
static Status CopySliceToElement(const Tensor& parent, Tensor* element,
int index);
@@ -56,6 +89,19 @@ class QueueBase : public QueueInterface {
static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
int index);
+ void Cancel(Action action, CancellationToken token);
+
+ // Helper for cancelling all pending Enqueue(Many) operations when
+ // Close is called with cancel_pending_enqueues.
+ void CloseAndCancel();
+
+ bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Tries to make progress on the enqueues or dequeues at the front
+ // of the *_attempts_ queues.
+ void FlushUnlocked();
+
~QueueBase() override {}
// Helpers for implementing MatchesNodeDef().
@@ -65,9 +111,37 @@ class QueueBase : public QueueInterface {
Status MatchesNodeDefTypes(const NodeDef& node_def) const;
Status MatchesNodeDefShapes(const NodeDef& node_def) const;
+ protected:
+ const int32 capacity_;
const DataTypeVector component_dtypes_;
const std::vector<TensorShape> component_shapes_;
const string name_;
+ mutex mu_;
+ bool closed_ GUARDED_BY(mu_);
+
+ struct Attempt;
+ typedef std::function<RunResult(Attempt*)> RunCallback;
+ struct Attempt {
+ int32 elements_requested;
+ DoneCallback done_callback; // must be run outside mu_
+ OpKernelContext* context;
+ CancellationToken cancellation_token;
+ RunCallback run_callback; // must be run while holding mu_
+ bool is_cancelled;
+ Tuple tuple;
+
+ Attempt(int32 elements_requested, DoneCallback done_callback,
+ OpKernelContext* context, CancellationToken cancellation_token,
+ RunCallback run_callback)
+ : elements_requested(elements_requested),
+ done_callback(done_callback),
+ context(context),
+ cancellation_token(cancellation_token),
+ run_callback(run_callback),
+ is_cancelled(false) {}
+ };
+ std::deque<Attempt> enqueue_attempts_ GUARDED_BY(mu_);
+ std::deque<Attempt> dequeue_attempts_ GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(QueueBase);
};
diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc
index 561ec76e53..0723e4fc61 100644
--- a/tensorflow/core/kernels/random_shuffle_queue_op.cc
+++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc
@@ -6,7 +6,7 @@
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/queue_base.h"
+#include "tensorflow/core/kernels/typed_queue.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random.h"
@@ -19,18 +19,16 @@
namespace tensorflow {
-class RandomShuffleQueue : public QueueBase {
+class RandomShuffleQueue : public TypedQueue<std::vector<PersistentTensor> > {
public:
RandomShuffleQueue(int32 capacity, int32 min_after_dequeue, int64 seed,
int64 seed2, const DataTypeVector& component_dtypes,
const std::vector<TensorShape>& component_shapes,
const string& name);
- Status Initialize(); // Must be called before any other method.
- // Implementations of QueueInterface methods --------------------------------
+ Status Initialize() override; // Must be called before any other method.
- Status ValidateTuple(const Tuple& tuple) override;
- Status ValidateManyTuple(const Tuple& tuple) override;
+ // Implementations of QueueInterface methods --------------------------------
void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
DoneCallback callback) override;
void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
@@ -38,8 +36,6 @@ class RandomShuffleQueue : public QueueBase {
void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
void TryDequeueMany(int num_elements, OpKernelContext* ctx,
CallbackWithTuple callback) override;
- void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
- DoneCallback callback) override;
Status MatchesNodeDef(const NodeDef& node_def) override;
int32 size() override {
@@ -48,95 +44,30 @@ class RandomShuffleQueue : public QueueBase {
}
private:
- enum Action { kEnqueue, kDequeue };
-
~RandomShuffleQueue() override {}
- TensorShape ManyOutShape(int i, int batch_size) {
- TensorShape shape({batch_size});
- shape.AppendShape(component_shapes_[i]);
- return shape;
- }
-
// Helper for dequeuing a single random element from queues_.
void DequeueLocked(OpKernelContext* ctx, Tuple* tuple)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
- void Cancel(Action action, CancellationToken token);
-
- // Helper for cancelling all pending Enqueue(Many) operations when
- // Close is called with cancel_pending_enqueues.
- void CloseAndCancel();
-
- // Tries to enqueue/dequeue (or close) based on whatever is at the
- // front of enqueue_attempts_/dequeue_attempts_. Appends to
- // *finished the callback for any finished attempt (so it may be
- // called once mu_ is released). Returns true if any progress was
- // made.
- struct CleanUp {
- CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
- : finished(f), to_deregister(ct), cm(cm) {}
- DoneCallback finished;
- CancellationToken to_deregister;
- CancellationManager* cm;
- };
- bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up)
- EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
- // Tries to make progress on the enqueues or dequeues at the front
- // of the *_attempts_ queues.
- void FlushUnlocked();
-
- const int32 capacity_;
const int32 min_after_dequeue_;
const int64 original_seed_;
const int64 original_seed2_;
- mutex mu_;
- typedef std::vector<PersistentTensor> SubQueue;
- std::vector<SubQueue> queues_ GUARDED_BY(mu_);
- bool closed_ GUARDED_BY(mu_);
random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
random::SingleSampleAdapter<random::PhiloxRandom> generator_ GUARDED_BY(mu_);
- enum RunResult { kNoProgress, kProgress, kComplete };
- struct Attempt;
- typedef std::function<RunResult(Attempt*)> RunCallback;
- struct Attempt {
- int32 elements_requested;
- DoneCallback done_callback; // must be run outside mu_
- OpKernelContext* context;
- CancellationToken cancellation_token;
- RunCallback run_callback; // must be run while holding mu_
- bool is_cancelled;
- Tuple tuple;
-
- Attempt(int32 elements_requested, DoneCallback done_callback,
- OpKernelContext* context, CancellationToken cancellation_token,
- RunCallback run_callback)
- : elements_requested(elements_requested),
- done_callback(done_callback),
- context(context),
- cancellation_token(cancellation_token),
- run_callback(run_callback),
- is_cancelled(false) {}
- };
- std::deque<Attempt> enqueue_attempts_ GUARDED_BY(mu_);
- std::deque<Attempt> dequeue_attempts_ GUARDED_BY(mu_);
-
TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueue);
};
RandomShuffleQueue::RandomShuffleQueue(
- int capacity, int min_after_dequeue, int64 seed, int64 seed2,
+ int32 capacity, int32 min_after_dequeue, int64 seed, int64 seed2,
const DataTypeVector& component_dtypes,
const std::vector<TensorShape>& component_shapes, const string& name)
- : QueueBase(component_dtypes, component_shapes, name),
- capacity_(capacity),
+ : TypedQueue(capacity, component_dtypes, component_shapes, name),
min_after_dequeue_(min_after_dequeue),
original_seed_(seed),
original_seed2_(seed2),
- closed_(false),
generator_(&parent_generator_) {
if (seed == 0 && seed2 == 0) {
// If both seeds are unspecified, use completely random seeds.
@@ -147,71 +78,16 @@ RandomShuffleQueue::RandomShuffleQueue(
}
Status RandomShuffleQueue::Initialize() {
- if (component_dtypes_.empty()) {
- return errors::InvalidArgument("Empty component types for queue ", name_);
- }
- if (!component_shapes_.empty() &&
- component_dtypes_.size() != component_shapes_.size()) {
- return errors::InvalidArgument("Different number of component types (",
- component_dtypes_.size(), ") vs. shapes (",
- component_shapes_.size(), ").");
- }
+ Status s = TypedQueue::Initialize();
+ if (!s.ok()) return s;
mutex_lock lock(mu_);
- queues_.reserve(num_components());
for (int i = 0; i < num_components(); ++i) {
- queues_.push_back(SubQueue());
queues_.back().reserve(min_after_dequeue_);
}
return Status::OK();
}
-// TODO(mrry): If these checks become a bottleneck, find a way to
-// reduce the number of times that they are called.
-Status RandomShuffleQueue::ValidateTuple(const Tuple& tuple) {
- TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
- if (specified_shapes()) {
- for (size_t i = 0; i < tuple.size(); ++i) {
- if (!tuple[i].shape().IsSameSize(component_shapes_[i])) {
- return errors::InvalidArgument(
- "Shape mismatch in tuple component ", i, ". Expected ",
- component_shapes_[i].ShortDebugString(), ", got ",
- tuple[i].shape().ShortDebugString());
- }
- }
- }
- return Status::OK();
-}
-
-// TODO(mrry): If these checks become a bottleneck, find a way to
-// reduce the number of times that they are called.
-Status RandomShuffleQueue::ValidateManyTuple(const Tuple& tuple) {
- TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
- const int64 batch_size = tuple[0].dim_size(0);
- if (specified_shapes()) {
- for (size_t i = 0; i < tuple.size(); ++i) {
- // Expected shape is [batch_size] + component_shapes_[i]
- const TensorShape expected_shape = ManyOutShape(i, batch_size);
- if (!tuple[i].shape().IsSameSize(expected_shape)) {
- return errors::InvalidArgument(
- "Shape mismatch in tuple component ", i, ". Expected ",
- expected_shape.ShortDebugString(), ", got ",
- tuple[i].shape().ShortDebugString());
- }
- }
- } else {
- for (size_t i = 1; i < tuple.size(); ++i) {
- if (tuple[i].dim_size(0) != batch_size) {
- return errors::InvalidArgument(
- "All input tensors must have the same size in the 0th ",
- "dimension. Component ", i, " has ", tuple[i].dim_size(0),
- ", and should have ", batch_size);
- }
- }
- }
- return Status::OK();
-}
-
void RandomShuffleQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
DCHECK_GT(queues_[0].size(), 0);
int64 index = generator_() % queues_[0].size();
@@ -223,113 +99,6 @@ void RandomShuffleQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
}
}
-void RandomShuffleQueue::Cancel(Action action, CancellationToken token) {
- DoneCallback callback = nullptr;
- {
- mutex_lock lock(mu_);
- std::deque<Attempt>* attempts =
- action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
-
- for (Attempt& attempt : *attempts) {
- if (attempt.cancellation_token == token) {
- attempt.is_cancelled = true;
- if (action == kEnqueue) {
- attempt.context->SetStatus(
- errors::Cancelled("Enqueue operation was cancelled"));
- } else {
- attempt.context->SetStatus(
- errors::Cancelled("Dequeue operation was cancelled"));
- }
- std::swap(callback, attempt.done_callback);
- break;
- }
- }
- }
- if (callback) {
- callback();
- FlushUnlocked();
- }
-}
-
-void RandomShuffleQueue::CloseAndCancel() {
- std::vector<DoneCallback> callbacks;
- {
- mutex_lock lock(mu_);
- closed_ = true;
- for (Attempt& attempt : enqueue_attempts_) {
- attempt.is_cancelled = true;
- attempt.context->SetStatus(
- errors::Cancelled("Enqueue operation was cancelled"));
- callbacks.emplace_back(std::move(attempt.done_callback));
- }
- }
- for (const DoneCallback& callback : callbacks) {
- callback();
- }
- FlushUnlocked();
-}
-
-bool RandomShuffleQueue::TryAttemptLocked(
- Action action, std::vector<CleanUp>* clean_up) {
- std::deque<Attempt>* attempts =
- action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
-
- bool progress = false;
- bool done = false;
- while (!done && !attempts->empty()) {
- if (attempts->front().is_cancelled) {
- if (action == kEnqueue) {
- LOG(INFO) << "Skipping cancelled enqueue attempt";
- } else {
- LOG(INFO) << "Skipping cancelled dequeue attempt";
- }
- attempts->pop_front();
- } else {
- Attempt* cur_attempt = &attempts->front();
- switch (cur_attempt->run_callback(cur_attempt)) {
- case kNoProgress:
- done = true;
- break;
- case kProgress:
- done = true;
- progress = true;
- break;
- case kComplete:
- progress = true;
- clean_up->emplace_back(std::move(cur_attempt->done_callback),
- cur_attempt->cancellation_token,
- cur_attempt->context->cancellation_manager());
- attempts->pop_front();
- break;
- }
- }
- }
- return progress;
-}
-
-void RandomShuffleQueue::FlushUnlocked() {
- std::vector<CleanUp> clean_up;
- Ref();
- {
- mutex_lock lock(mu_);
- bool changed;
- do {
- changed = TryAttemptLocked(kEnqueue, &clean_up);
- changed = TryAttemptLocked(kDequeue, &clean_up) || changed;
- } while (changed);
- }
- Unref();
- for (const auto& to_clean : clean_up) {
- if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
- // NOTE(mrry): We can safely ignore the return value of
- // DeregisterCallback because the mutex mu_ ensures that the
- // cleanup action only executes once.
- to_clean.cm->DeregisterCallback(to_clean.to_deregister);
- }
- to_clean.finished();
- }
-}
-
void RandomShuffleQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
DoneCallback callback) {
CancellationManager* cm = ctx->cancellation_manager();
@@ -583,31 +352,6 @@ void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
}
}
-void RandomShuffleQueue::Close(OpKernelContext* ctx,
- bool cancel_pending_enqueues,
- DoneCallback callback) {
- if (cancel_pending_enqueues) {
- CloseAndCancel();
- callback();
- } else {
- {
- mutex_lock lock(mu_);
- enqueue_attempts_.emplace_back(
- 0, callback, ctx, CancellationManager::kInvalidToken,
- [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (closed_) {
- attempt->context->SetStatus(errors::Aborted(
- "RandomShuffleQueue '", name_, "' is already closed."));
- } else {
- closed_ = true;
- }
- return kComplete;
- });
- }
- FlushUnlocked();
- }
-}
-
Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) {
TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "RandomShuffleQueue"));
TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
@@ -640,8 +384,6 @@ Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) {
return Status::OK();
}
-typedef std::shared_ptr<QueueInterface> QueueInterfacePtr;
-
// Defines a RandomShuffleQueueOp, which produces a Queue (specifically, one
// backed by RandomShuffleQueue) that persists across different graph
// executions, and sessions. Running this op produces a single-element
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index 3477266d5d..7e55149cd1 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -171,8 +171,8 @@ class SliceOp : public OpKernel {
template <int NDIM>
void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
const gtl::ArraySlice<int64>& size, Tensor* result) {
- Eigen::DSizes<ptrdiff_t, NDIM> indices;
- Eigen::DSizes<ptrdiff_t, NDIM> sizes;
+ Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
+ Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
for (int i = 0; i < NDIM; ++i) {
indices[i] = begin[i];
sizes[i] = size[i];
@@ -205,8 +205,8 @@ namespace functor {
void Slice<GPUDevice, T, NDIM>::operator()( \
const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
typename TTypes<T, NDIM>::ConstTensor input, \
- const Eigen::DSizes<ptrdiff_t, NDIM>& indices, \
- const Eigen::DSizes<ptrdiff_t, NDIM>& sizes); \
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \
extern template struct Slice<GPUDevice, T, NDIM>;
#define DECLARE_FOR_N(T) \
diff --git a/tensorflow/core/kernels/slice_op.h b/tensorflow/core/kernels/slice_op.h
index 1b6bd9c112..89bc8be8ac 100644
--- a/tensorflow/core/kernels/slice_op.h
+++ b/tensorflow/core/kernels/slice_op.h
@@ -13,8 +13,8 @@ template <typename Device, typename T, int NDIMS>
struct Slice {
void operator()(const Device& d, typename TTypes<T, NDIMS>::Tensor output,
typename TTypes<T, NDIMS>::ConstTensor input,
- const Eigen::DSizes<ptrdiff_t, NDIMS>& slice_indices,
- const Eigen::DSizes<ptrdiff_t, NDIMS>& slice_sizes) {
+ const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& slice_indices,
+ const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& slice_sizes) {
output.device(d) = input.slice(slice_indices, slice_sizes);
}
};
diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc
index f4f9ada000..e8808c1be2 100644
--- a/tensorflow/core/kernels/split_op.cc
+++ b/tensorflow/core/kernels/split_op.cc
@@ -90,17 +90,17 @@ class SplitOp : public OpKernel {
TensorShape output_shape(input_shape);
output_shape.set_dim(split_dim, split_dim_output_size);
- Eigen::DSizes<ptrdiff_t, 3> indices{0, 0, 0};
- Eigen::DSizes<ptrdiff_t, 3> sizes{prefix_dim_size, split_dim_output_size,
- suffix_dim_size};
+ Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0};
+ Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
+ prefix_dim_size, split_dim_output_size, suffix_dim_size};
for (int i = 0; i < num_split; ++i) {
Tensor* result = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(i, output_shape, &result));
if (prefix_dim_size * split_dim_output_size * suffix_dim_size > 0) {
- Eigen::DSizes<ptrdiff_t, 3> slice_indices;
- Eigen::DSizes<ptrdiff_t, 3> slice_sizes;
+ Eigen::DSizes<Eigen::DenseIndex, 3> slice_indices;
+ Eigen::DSizes<Eigen::DenseIndex, 3> slice_sizes;
for (int j = 0; j < 3; ++j) {
slice_indices[j] = indices[j];
slice_sizes[j] = sizes[j];
diff --git a/tensorflow/core/kernels/split_op.h b/tensorflow/core/kernels/split_op.h
index 2572c77285..fb81d93a39 100644
--- a/tensorflow/core/kernels/split_op.h
+++ b/tensorflow/core/kernels/split_op.h
@@ -12,8 +12,8 @@ template <typename Device, typename T>
struct Split {
void operator()(const Device& d, typename TTypes<T, 3>::Tensor output,
typename TTypes<T, 3>::ConstTensor input,
- const Eigen::DSizes<ptrdiff_t, 3>& slice_indices,
- const Eigen::DSizes<ptrdiff_t, 3>& slice_sizes);
+ const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
+ const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes);
};
template <typename T>
@@ -21,8 +21,8 @@ struct Split<Eigen::ThreadPoolDevice, T> {
void operator()(const Eigen::ThreadPoolDevice& d,
typename TTypes<T, 3>::Tensor output,
typename TTypes<T, 3>::ConstTensor input,
- const Eigen::DSizes<ptrdiff_t, 3>& slice_indices,
- const Eigen::DSizes<ptrdiff_t, 3>& slice_sizes);
+ const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
+ const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes);
};
} // namespace functor
diff --git a/tensorflow/core/kernels/split_op_cpu.cc b/tensorflow/core/kernels/split_op_cpu.cc
index b86deeb8fb..eda432b6f9 100644
--- a/tensorflow/core/kernels/split_op_cpu.cc
+++ b/tensorflow/core/kernels/split_op_cpu.cc
@@ -13,8 +13,8 @@ template <typename T>
void Split<Eigen::ThreadPoolDevice, T>::operator()(
const Eigen::ThreadPoolDevice& d, typename TTypes<T, 3>::Tensor output,
typename TTypes<T, 3>::ConstTensor input,
- const Eigen::DSizes<ptrdiff_t, 3>& slice_indices,
- const Eigen::DSizes<ptrdiff_t, 3>& slice_sizes) {
+ const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
+ const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) {
if (output.size() < 131072) {
output = input.slice(slice_indices, slice_sizes);
} else {
diff --git a/tensorflow/core/kernels/split_op_gpu.cu.cc b/tensorflow/core/kernels/split_op_gpu.cu.cc
index f8931d6a89..d6a68bf9a5 100644
--- a/tensorflow/core/kernels/split_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/split_op_gpu.cu.cc
@@ -16,8 +16,8 @@ template <typename Device, typename T>
void Split<Device, T>::operator()(
const Device& d, typename TTypes<T, 3>::Tensor output,
typename TTypes<T, 3>::ConstTensor input,
- const Eigen::DSizes<ptrdiff_t, 3>& slice_indices,
- const Eigen::DSizes<ptrdiff_t, 3>& slice_sizes) {
+ const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
+ const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) {
output.device(d) = input.slice(slice_indices, slice_sizes);
}
diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc
index d5e0e89d60..decc3207a1 100644
--- a/tensorflow/core/kernels/tile_ops.cc
+++ b/tensorflow/core/kernels/tile_ops.cc
@@ -273,8 +273,8 @@ class TileGradientOp : public OpKernel {
#undef HANDLE_DIM
}
- Eigen::DSizes<ptrdiff_t, NDIM> indices;
- Eigen::DSizes<ptrdiff_t, NDIM> sizes;
+ Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
+ Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
// Accumulate slices along the dimensions into the output. The number of
// slices along dimension 'i' is simply the multiple along dimension 'i'
@@ -309,8 +309,8 @@ class TileGradientOp : public OpKernel {
void HandleReduce(OpKernelContext* context,
const std::vector<int32>& reduce_dim_in, Tensor* result) {
static_assert(NDIM >= REDUCENDIM, "Too many reduced dimensions");
- Eigen::DSizes<ptrdiff_t, REDUCENDIM> reduce_dim;
- Eigen::DSizes<ptrdiff_t, NDIM> reshape_dim;
+ Eigen::DSizes<Eigen::DenseIndex, REDUCENDIM> reduce_dim;
+ Eigen::DSizes<Eigen::DenseIndex, NDIM> reshape_dim;
for (int i = 0; i < REDUCENDIM; ++i) {
reduce_dim[i] = reduce_dim_in[i];
@@ -392,26 +392,26 @@ REGISTER_KERNEL_BUILDER(Name("TileGrad")
DEFINE_GPU_DIM(T, 4) \
DEFINE_GPU_DIM(T, 5)
-#define DEFINE_GPU_DIM(T, NDIM) \
- template <> \
- void Tile<GPUDevice, T, NDIM>::operator()( \
- const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
- typename TTypes<T, NDIM>::ConstTensor in, \
- const Eigen::array<int32, NDIM>& broadcast_array) const; \
- extern template struct Tile<GPUDevice, T, NDIM>; \
- template <> \
- void TileGrad<GPUDevice, T, NDIM>::operator()( \
- const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
- typename TTypes<T, NDIM>::ConstTensor in, \
- const Eigen::DSizes<ptrdiff_t, NDIM>& indices, \
- const Eigen::DSizes<ptrdiff_t, NDIM>& sizes, bool first) const; \
- extern template struct TileGrad<GPUDevice, T, NDIM>; \
- template <> \
- void ReduceAndReshape<GPUDevice, T, NDIM, 1>::operator()( \
- const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
- typename TTypes<T, NDIM>::ConstTensor in, \
- const Eigen::DSizes<ptrdiff_t, 1>& reduce_dim, \
- const Eigen::DSizes<ptrdiff_t, NDIM>& reshape_dim) const; \
+#define DEFINE_GPU_DIM(T, NDIM) \
+ template <> \
+ void Tile<GPUDevice, T, NDIM>::operator()( \
+ const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
+ typename TTypes<T, NDIM>::ConstTensor in, \
+ const Eigen::array<int32, NDIM>& broadcast_array) const; \
+ extern template struct Tile<GPUDevice, T, NDIM>; \
+ template <> \
+ void TileGrad<GPUDevice, T, NDIM>::operator()( \
+ const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
+ typename TTypes<T, NDIM>::ConstTensor in, \
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes, bool first) const; \
+ extern template struct TileGrad<GPUDevice, T, NDIM>; \
+ template <> \
+ void ReduceAndReshape<GPUDevice, T, NDIM, 1>::operator()( \
+ const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
+ typename TTypes<T, NDIM>::ConstTensor in, \
+ const Eigen::DSizes<Eigen::DenseIndex, 1>& reduce_dim, \
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& reshape_dim) const; \
extern template struct ReduceAndReshape<GPUDevice, T, NDIM, 1>;
namespace functor {
diff --git a/tensorflow/core/kernels/tile_ops.h b/tensorflow/core/kernels/tile_ops.h
index 41c2deb42d..1a614fe4f1 100644
--- a/tensorflow/core/kernels/tile_ops.h
+++ b/tensorflow/core/kernels/tile_ops.h
@@ -31,8 +31,8 @@ template <typename Device, typename T, int NDIM>
struct TileGrad {
void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out,
typename TTypes<T, NDIM>::ConstTensor in,
- const Eigen::DSizes<ptrdiff_t, NDIM>& indices,
- const Eigen::DSizes<ptrdiff_t, NDIM>& sizes,
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes,
bool first) const {
if (first) {
out.device(d) = in.slice(indices, sizes);
@@ -58,10 +58,11 @@ struct TileGrad<Device, T, 0> {
template <typename Device, typename T, int NDIM, int REDUCEDNDIM>
struct ReduceAndReshape {
- void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out,
- typename TTypes<T, NDIM>::ConstTensor in,
- const Eigen::DSizes<ptrdiff_t, REDUCEDNDIM>& reduce_dim,
- const Eigen::DSizes<ptrdiff_t, NDIM>& reshape_dim) const {
+ void operator()(
+ const Device& d, typename TTypes<T, NDIM>::Tensor out,
+ typename TTypes<T, NDIM>::ConstTensor in,
+ const Eigen::DSizes<Eigen::DenseIndex, REDUCEDNDIM>& reduce_dim,
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& reshape_dim) const {
out.device(d) = in.sum(reduce_dim).reshape(reshape_dim);
}
};
diff --git a/tensorflow/core/kernels/typed_queue.h b/tensorflow/core/kernels/typed_queue.h
new file mode 100644
index 0000000000..ae2878d87b
--- /dev/null
+++ b/tensorflow/core/kernels/typed_queue.h
@@ -0,0 +1,54 @@
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_
+
+#include <vector>
+
+#include "tensorflow/core/kernels/queue_base.h"
+
+namespace tensorflow {
+
+// TypedQueue builds on QueueBase, with backing class (SubQueue)
+// known and stored within. Shared methods that need to have access
+// to the backed data sit in this class.
+template <typename SubQueue>
+class TypedQueue : public QueueBase {
+ public:
+ TypedQueue(const int32 capacity, const DataTypeVector& component_dtypes,
+ const std::vector<TensorShape>& component_shapes,
+ const string& name);
+
+ virtual Status Initialize(); // Must be called before any other method.
+
+ protected:
+ std::vector<SubQueue> queues_ GUARDED_BY(mu_);
+}; // class TypedQueue
+
+template <typename SubQueue>
+TypedQueue<SubQueue>::TypedQueue(
+ int32 capacity, const DataTypeVector& component_dtypes,
+ const std::vector<TensorShape>& component_shapes, const string& name)
+ : QueueBase(capacity, component_dtypes, component_shapes, name) {}
+
+template <typename SubQueue>
+Status TypedQueue<SubQueue>::Initialize() {
+ if (component_dtypes_.empty()) {
+ return errors::InvalidArgument("Empty component types for queue ", name_);
+ }
+ if (!component_shapes_.empty() &&
+ component_dtypes_.size() != component_shapes_.size()) {
+ return errors::InvalidArgument("Different number of component types (",
+ component_dtypes_.size(), ") vs. shapes (",
+ component_shapes_.size(), ").");
+ }
+
+ mutex_lock lock(mu_);
+ queues_.reserve(num_components());
+ for (int i = 0; i < num_components(); ++i) {
+ queues_.push_back(SubQueue());
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_
diff --git a/tensorflow/core/kernels/unpack_op.cc b/tensorflow/core/kernels/unpack_op.cc
index 36cfb2c8e5..5d1376be83 100644
--- a/tensorflow/core/kernels/unpack_op.cc
+++ b/tensorflow/core/kernels/unpack_op.cc
@@ -63,8 +63,8 @@ class UnpackOp : public OpKernel {
context->allocate_output(i, output_shape, &output));
auto output_shaped = output->shaped<T, 3>({1, 1, output_size});
- Eigen::DSizes<ptrdiff_t, 3> indices{0, i, 0};
- Eigen::DSizes<ptrdiff_t, 3> sizes{1, 1, output_size};
+ Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, i, 0};
+ Eigen::DSizes<Eigen::DenseIndex, 3> sizes{1, 1, output_size};
functor::Split<Device, T>()(context->eigen_device<Device>(),
output_shaped, input_reshaped, indices,
sizes);
diff --git a/tensorflow/g3doc/api_docs/cc/index.md b/tensorflow/g3doc/api_docs/cc/index.md
index e30022d9c2..ddbdd72e16 100644
--- a/tensorflow/g3doc/api_docs/cc/index.md
+++ b/tensorflow/g3doc/api_docs/cc/index.md
@@ -23,28 +23,37 @@ write the graph to a file.
1. Run the graph with a call to `session->Run()`
-
-##Classes <a class="md-anchor" id="AUTOGENERATED-classes"></a>
+## Env <a class="md-anchor" id="AUTOGENERATED-env"></a>
* [tensorflow::Env](../../api_docs/cc/ClassEnv.md)
* [tensorflow::RandomAccessFile](../../api_docs/cc/ClassRandomAccessFile.md)
* [tensorflow::WritableFile](../../api_docs/cc/ClassWritableFile.md)
* [tensorflow::EnvWrapper](../../api_docs/cc/ClassEnvWrapper.md)
+
+## Session <a class="md-anchor" id="AUTOGENERATED-session"></a>
+
* [tensorflow::Session](../../api_docs/cc/ClassSession.md)
+* [tensorflow::SessionOptions](../../api_docs/cc/StructSessionOptions.md)
+
+## Status <a class="md-anchor" id="AUTOGENERATED-status"></a>
+
* [tensorflow::Status](../../api_docs/cc/ClassStatus.md)
+* [tensorflow::Status::State](../../api_docs/cc/StructState.md)
+
+## Tensor <a class="md-anchor" id="AUTOGENERATED-tensor"></a>
+
* [tensorflow::Tensor](../../api_docs/cc/ClassTensor.md)
* [tensorflow::TensorShape](../../api_docs/cc/ClassTensorShape.md)
+* [tensorflow::TensorShapeDim](../../api_docs/cc/StructTensorShapeDim.md)
* [tensorflow::TensorShapeUtils](../../api_docs/cc/ClassTensorShapeUtils.md)
-* [tensorflow::Thread](../../api_docs/cc/ClassThread.md)
-##Structs <a class="md-anchor" id="AUTOGENERATED-structs"></a>
+## Thread <a class="md-anchor" id="AUTOGENERATED-thread"></a>
-* [tensorflow::SessionOptions](../../api_docs/cc/StructSessionOptions.md)
-* [tensorflow::Status::State](../../api_docs/cc/StructState.md)
-* [tensorflow::TensorShapeDim](../../api_docs/cc/StructTensorShapeDim.md)
+* [tensorflow::Thread](../../api_docs/cc/ClassThread.md)
* [tensorflow::ThreadOptions](../../api_docs/cc/StructThreadOptions.md)
+
<div class='sections-order' style="display: none;">
<!--
<!-- ClassEnv.md -->
@@ -52,14 +61,14 @@ write the graph to a file.
<!-- ClassWritableFile.md -->
<!-- ClassEnvWrapper.md -->
<!-- ClassSession.md -->
+<!-- StructSessionOptions.md -->
<!-- ClassStatus.md -->
+<!-- StructState.md -->
<!-- ClassTensor.md -->
<!-- ClassTensorShape.md -->
+<!-- StructTensorShapeDim.md -->
<!-- ClassTensorShapeUtils.md -->
<!-- ClassThread.md -->
-<!-- StructSessionOptions.md -->
-<!-- StructState.md -->
-<!-- StructTensorShapeDim.md -->
<!-- StructThreadOptions.md -->
-->
</div>
diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md
index 5bb7cd9c7a..4f4292ff5e 100644
--- a/tensorflow/g3doc/api_docs/python/nn.md
+++ b/tensorflow/g3doc/api_docs/python/nn.md
@@ -597,7 +597,7 @@ For so-called "global normalization" needed for convolutional filters pass
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
- Two `Tensors`: `mean` and `variance`.
+ Two `Tensor` objects: `mean` and `variance`.
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index 9014a81150..aa0301028a 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -175,20 +175,20 @@ depends on.
Follow instructions [here](http://bazel.io/docs/install.html) to install the
-dependencies for Bazel. Then download and build the Bazel source with the
-following commands:
+dependencies for Bazel. Then download bazel version 0.1.1 using the
+[installer for your system](https://github.com/bazelbuild/bazel/releases) and
+run the installer as mentioned there:
```bash
-$ git clone https://github.com/bazelbuild/bazel.git
-$ cd bazel
-$ git checkout tags/0.1.0
-$ ./compile.sh
+$ chmod +x PATH_TO_INSTALL.SH
+$ ./PATH_TO_INSTALL.SH --user
```
-These commands use the commit tag `0.1.0`, which is known to work with
-TensorFlow. `HEAD` may be unstable.
+Remember to replace `PATH_TO_INSTALL.SH` to point to the location where you
+downloaded the installer.
-Add the executable `output/bazel` to your `$PATH` environment variable.
+Finally, follow the instructions in that script to place bazel into your binary
+path.
#### Install other dependencies <a class="md-anchor" id="AUTOGENERATED-install-other-dependencies"></a>
diff --git a/tensorflow/g3doc/resources/index.md b/tensorflow/g3doc/resources/index.md
index a2bc573348..57c88f4167 100644
--- a/tensorflow/g3doc/resources/index.md
+++ b/tensorflow/g3doc/resources/index.md
@@ -15,6 +15,11 @@ system, we suggest you cite the paper above.
You can use this [BibTeX entry](../resources/bib.md). As the project progresses, we
may update the suggested citation with new papers.
+Please only use the TensorFlow name and marks when accurately referencing this
+software distribution, and do not use our marks in a way that suggests you are
+endorsed by or otherwise affiliated with Google. When referring to our marks,
+please include the following attribution statement: "TensorFlow, the TensorFlow
+logo and any related marks are trademarks of Google Inc."
## Community <a class="md-anchor" id="AUTOGENERATED-community"></a>
diff --git a/tensorflow/models/embedding/BUILD b/tensorflow/models/embedding/BUILD
index f8f7e7bcb2..fe52778fa9 100644
--- a/tensorflow/models/embedding/BUILD
+++ b/tensorflow/models/embedding/BUILD
@@ -12,6 +12,7 @@ py_binary(
srcs = [
"word2vec.py",
],
+ srcs_version = "PY2AND3",
deps = [
":gen_word2vec",
"//tensorflow:tensorflow_py",
@@ -24,6 +25,7 @@ py_binary(
srcs = [
"word2vec_optimized.py",
],
+ srcs_version = "PY2AND3",
deps = [
":gen_word2vec",
"//tensorflow:tensorflow_py",
@@ -35,6 +37,7 @@ py_test(
name = "word2vec_test",
size = "small",
srcs = ["word2vec_test.py"],
+ srcs_version = "PY2AND3",
deps = [
":word2vec",
"//tensorflow:tensorflow_py",
@@ -45,6 +48,7 @@ py_test(
name = "word2vec_optimized_test",
size = "small",
srcs = ["word2vec_optimized_test.py"],
+ srcs_version = "PY2AND3",
deps = [
":word2vec_optimized",
"//tensorflow:tensorflow_py",
diff --git a/tensorflow/models/image/alexnet/BUILD b/tensorflow/models/image/alexnet/BUILD
index e1b9cd6965..bbe29da6f5 100644
--- a/tensorflow/models/image/alexnet/BUILD
+++ b/tensorflow/models/image/alexnet/BUILD
@@ -10,6 +10,7 @@ py_binary(
srcs = [
"alexnet_benchmark.py",
],
+ srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
],
diff --git a/tensorflow/models/image/cifar10/BUILD b/tensorflow/models/image/cifar10/BUILD
index adf9aaffd4..25dce65f28 100644
--- a/tensorflow/models/image/cifar10/BUILD
+++ b/tensorflow/models/image/cifar10/BUILD
@@ -8,6 +8,7 @@ exports_files(["LICENSE"])
py_library(
name = "cifar10_input",
srcs = ["cifar10_input.py"],
+ srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
],
@@ -16,6 +17,7 @@ py_library(
py_test(
name = "cifar10_input_test",
srcs = ["cifar10_input_test.py"],
+ srcs_version = "PY2AND3",
deps = [
":cifar10_input",
"//tensorflow:tensorflow_py",
@@ -27,6 +29,7 @@ py_test(
py_library(
name = "cifar10",
srcs = ["cifar10.py"],
+ srcs_version = "PY2AND3",
deps = [
":cifar10_input",
"//tensorflow:tensorflow_py",
@@ -38,6 +41,7 @@ py_binary(
srcs = [
"cifar10_eval.py",
],
+ srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
":cifar10",
@@ -49,6 +53,7 @@ py_binary(
srcs = [
"cifar10_train.py",
],
+ srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
":cifar10",
@@ -60,6 +65,7 @@ py_binary(
srcs = [
"cifar10_multi_gpu_train.py",
],
+ srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
":cifar10",
diff --git a/tensorflow/models/image/mnist/BUILD b/tensorflow/models/image/mnist/BUILD
index 6774810e82..6dd96e1e6f 100644
--- a/tensorflow/models/image/mnist/BUILD
+++ b/tensorflow/models/image/mnist/BUILD
@@ -10,6 +10,7 @@ py_binary(
srcs = [
"convolutional.py",
],
+ srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = ["//tensorflow:tensorflow_py"],
)
@@ -24,6 +25,7 @@ py_test(
"--self_test=True",
],
main = "convolutional.py",
+ srcs_version = "PY2AND3",
deps = ["//tensorflow:tensorflow_py"],
)
diff --git a/tensorflow/models/rnn/BUILD b/tensorflow/models/rnn/BUILD
index 3e5e6b37ca..1a81ce2801 100644
--- a/tensorflow/models/rnn/BUILD
+++ b/tensorflow/models/rnn/BUILD
@@ -14,6 +14,7 @@ py_library(
srcs = [
"linear.py",
],
+ srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
],
@@ -23,6 +24,7 @@ py_test(
name = "linear_test",
size = "small",
srcs = ["linear_test.py"],
+ srcs_version = "PY2AND3",
deps = [
":linear",
"//tensorflow:tensorflow_py",
@@ -34,6 +36,7 @@ py_library(
srcs = [
"rnn_cell.py",
],
+ srcs_version = "PY2AND3",
deps = [
":linear",
"//tensorflow:tensorflow_py",
@@ -44,6 +47,7 @@ py_test(
name = "rnn_cell_test",
size = "small",
srcs = ["rnn_cell_test.py"],
+ srcs_version = "PY2AND3",
deps = [
":rnn_cell",
"//tensorflow:tensorflow_py",
@@ -55,6 +59,7 @@ py_library(
srcs = [
"__init__.py",
],
+ srcs_version = "PY2AND3",
deps = [
":rnn",
":rnn_cell",
@@ -67,6 +72,7 @@ py_library(
srcs = [
"rnn.py",
],
+ srcs_version = "PY2AND3",
deps = [
":rnn_cell",
"//tensorflow:tensorflow_py",
@@ -88,6 +94,7 @@ py_library(
srcs = [
"seq2seq.py",
],
+ srcs_version = "PY2AND3",
deps = [
":rnn",
"//tensorflow:tensorflow_py",
@@ -99,6 +106,7 @@ py_test(
srcs = [
"seq2seq_test.py",
],
+ srcs_version = "PY2AND3",
deps = [
":seq2seq",
"//tensorflow:tensorflow_py",
diff --git a/tensorflow/models/rnn/ptb/BUILD b/tensorflow/models/rnn/ptb/BUILD
index 56d459a0f1..c5feb191d5 100644
--- a/tensorflow/models/rnn/ptb/BUILD
+++ b/tensorflow/models/rnn/ptb/BUILD
@@ -10,12 +10,14 @@ exports_files(["LICENSE"])
py_library(
name = "reader",
srcs = ["reader.py"],
+ srcs_version = "PY2AND3",
deps = ["//tensorflow:tensorflow_py"],
)
py_test(
name = "reader_test",
srcs = ["reader_test.py"],
+ srcs_version = "PY2AND3",
deps = [
":reader",
"//tensorflow:tensorflow_py",
@@ -27,6 +29,7 @@ py_binary(
srcs = [
"ptb_word_lm.py",
],
+ srcs_version = "PY2AND3",
deps = [
":reader",
"//tensorflow:tensorflow_py",
diff --git a/tensorflow/models/rnn/translate/BUILD b/tensorflow/models/rnn/translate/BUILD
index 57f17fb5ab..cf3780165b 100644
--- a/tensorflow/models/rnn/translate/BUILD
+++ b/tensorflow/models/rnn/translate/BUILD
@@ -12,6 +12,7 @@ py_library(
srcs = [
"data_utils.py",
],
+ srcs_version = "PY2AND3",
deps = ["//tensorflow:tensorflow_py"],
)
@@ -20,6 +21,7 @@ py_library(
srcs = [
"seq2seq_model.py",
],
+ srcs_version = "PY2AND3",
deps = [
":data_utils",
"//tensorflow:tensorflow_py",
@@ -32,6 +34,7 @@ py_binary(
srcs = [
"translate.py",
],
+ srcs_version = "PY2AND3",
deps = [
":data_utils",
":seq2seq_model",
@@ -49,6 +52,7 @@ py_test(
"--self_test=True",
],
main = "translate.py",
+ srcs_version = "PY2AND3",
deps = [
":data_utils",
":seq2seq_model",
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 7002ebfd65..5c6e08ae44 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -27,6 +27,7 @@ numpy_macosx_include_dir = select({
py_library(
name = "python",
srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
visibility = ["//tensorflow:__pkg__"],
deps = [
":client",
@@ -43,6 +44,7 @@ py_library(
py_library(
name = "platform",
srcs = glob(["platform/**/*.py"]),
+ srcs_version = "PY2AND3",
)
py_library(
@@ -51,6 +53,7 @@ py_library(
"platform/default/_googletest.py",
"platform/googletest.py",
],
+ srcs_version = "PY2AND3",
deps = [":platform"],
)
@@ -94,6 +97,7 @@ py_test(
name = "pywrap_status_test",
size = "small",
srcs = ["lib/core/pywrap_status_test.py"],
+ srcs_version = "PY2AND3",
deps = [
":framework_test_lib",
":platform_test",
@@ -133,6 +137,7 @@ py_library(
"framework/tensor_util.py",
"ops/common_shapes.py",
],
+ srcs_version = "PY2AND3",
deps = [
":platform",
"//tensorflow/core:protos_all_py",
@@ -143,6 +148,7 @@ py_library(
py_library(
name = "extra_py_tests_deps",
+ srcs_version = "PY2AND3",
deps = ["//tensorflow:tensorflow_py"],
)
@@ -151,6 +157,7 @@ py_library(
srcs = [
"framework/test_util.py",
],
+ srcs_version = "PY2AND3",
deps = [
":framework",
":platform_test",
@@ -165,6 +172,7 @@ py_library(
srcs = [
"platform/test.py",
],
+ srcs_version = "PY2AND3",
deps = [
":framework_test_lib",
":platform_test",
@@ -175,6 +183,7 @@ py_test(
name = "framework_errors_test",
srcs = ["framework/errors_test.py"],
main = "framework/errors_test.py",
+ srcs_version = "PY2AND3",
deps = [
":framework_test_lib",
":platform_test",
@@ -187,6 +196,7 @@ py_test(
name = "framework_importer_test",
srcs = ["framework/importer_test.py"],
main = "framework/importer_test.py",
+ srcs_version = "PY2AND3",
deps = [
":framework_test_lib",
":ops",
@@ -213,6 +223,7 @@ py_test(
name = "framework_ops_test",
srcs = ["framework/ops_test.py"],
main = "framework/ops_test.py",
+ srcs_version = "PY2AND3",
deps = [
":framework_test_lib",
":ops",
@@ -226,6 +237,19 @@ py_test(
name = "framework_tensor_shape_test",
srcs = ["framework/tensor_shape_test.py"],
main = "framework/tensor_shape_test.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":framework_test_lib",
+ ":platform_test",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+py_test(
+ name = "framework_tensor_shape_div_test",
+ srcs = ["framework/tensor_shape_div_test.py"],
+ main = "framework/tensor_shape_div_test.py",
+ srcs_version = "PY2AND3",
deps = [
":framework_test_lib",
":platform_test",
@@ -237,6 +261,7 @@ py_test(
name = "framework_tensor_util_test",
srcs = ["framework/tensor_util_test.py"],
main = "framework/tensor_util_test.py",
+ srcs_version = "PY2AND3",
deps = [
":framework_test_lib",
":ops",
@@ -248,6 +273,7 @@ py_test(
name = "framework_test_util_test",
srcs = ["framework/test_util_test.py"],
main = "framework/test_util_test.py",
+ srcs_version = "PY2AND3",
deps = [
":framework_test_lib",
":ops",
@@ -259,6 +285,7 @@ py_test(
name = "framework_types_test",
srcs = ["framework/types_test.py"],
main = "framework/types_test.py",
+ srcs_version = "PY2AND3",
deps = [
":framework_test_lib",
":platform_test",
@@ -271,6 +298,7 @@ py_test(
name = "op_def_library_test",
srcs = ["ops/op_def_library_test.py"],
main = "ops/op_def_library_test.py",
+ srcs_version = "PY2AND3",
deps = [
":framework_test_lib",
":ops",
@@ -565,6 +593,7 @@ py_library(
"ops/variables.py",
"user_ops/user_ops.py",
],
+ srcs_version = "PY2AND3",
deps = [
":array_ops",
":candidate_sampling_ops",
@@ -591,6 +620,7 @@ py_library(
["training/**/*.py"],
exclude = ["**/*test*"],
),
+ srcs_version = "PY2AND3",
deps = [
":client",
":framework",
@@ -609,6 +639,7 @@ py_library(
["client/**/*.py"],
exclude = ["**/*test*"],
),
+ srcs_version = "PY2AND3",
deps = [
":framework",
":ops",
@@ -620,6 +651,7 @@ py_library(
py_library(
name = "util",
srcs = glob(["util/**/*.py"]),
+ srcs_version = "PY2AND3",
deps = ["//google/protobuf:protobuf_python"],
)
@@ -641,6 +673,7 @@ py_test(
name = "protobuf_compare_test",
srcs = ["util/protobuf/compare_test.py"],
main = "util/protobuf/compare_test.py",
+ srcs_version = "PY2AND3",
deps = [
":compare_test_proto_py",
":platform_test",
@@ -654,6 +687,7 @@ py_test(
srcs = [
"client/events_writer_test.py",
],
+ srcs_version = "PY2AND3",
deps = [
":framework_test_lib",
":lib",
@@ -719,6 +753,7 @@ tf_py_wrap_cc(
py_library(
name = "lib",
srcs = glob(["lib/**/*.py"]),
+ srcs_version = "PY2AND3",
deps = [
":pywrap_tensorflow",
],
@@ -727,6 +762,7 @@ py_library(
py_library(
name = "session",
srcs = ["client/session.py"],
+ srcs_version = "PY2AND3",
deps = [
":framework",
":ops",
@@ -750,6 +786,7 @@ tf_cuda_library(
py_test(
name = "session_test",
srcs = ["client/session_test.py"],
+ srcs_version = "PY2AND3",
deps = [
":framework",
":framework_test_lib",
@@ -760,6 +797,7 @@ py_test(
py_test(
name = "graph_util_test",
srcs = ["client/graph_util_test.py"],
+ srcs_version = "PY2AND3",
deps = [
":framework",
":framework_test_lib",
@@ -770,6 +808,7 @@ py_test(
py_library(
name = "kernel_tests/gradient_checker",
srcs = ["kernel_tests/gradient_checker.py"],
+ srcs_version = "PY2AND3",
)
cpu_only_kernel_test_list = glob([
@@ -899,6 +938,7 @@ py_library(
["summary/**/*.py"],
exclude = ["**/*test*"],
),
+ srcs_version = "PY2AND3",
deps = [
":client",
":framework",
@@ -921,6 +961,7 @@ py_library(
srcs = [
"framework/docs.py",
],
+ srcs_version = "PY2AND3",
deps = [
":platform",
],
@@ -932,6 +973,7 @@ py_binary(
"framework/gen_docs_combined.py",
],
main = "framework/gen_docs_combined.py",
+ srcs_version = "PY2AND3",
deps = [
":docs",
":platform",
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index b5462dcd17..865533cf92 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -170,18 +170,18 @@ class Dimension(object):
def __floordiv__(self, other):
"""Returns the quotient of `self` and `other` rounded down.
- Dimensions are summed as follows:
+ Dimensions are divided as follows:
- Dimension(m) / Dimension(n) == Dimension(m / n)
- Dimension(m) / Dimension(None) == Dimension(None)
- Dimension(None) / Dimension(n) == Dimension(None)
- Dimension(None) / Dimension(None) == Dimension(None)
+ Dimension(m) // Dimension(n) == Dimension(m // n)
+ Dimension(m) // Dimension(None) == Dimension(None)
+ Dimension(None) // Dimension(n) == Dimension(None)
+ Dimension(None) // Dimension(None) == Dimension(None)
Args:
- other: Another Dimension.
+ other: Another `Dimension`.
Returns:
- A Dimension whose value is the sum of `self` and `other`.
+ A `Dimension` whose value is the integer quotient of `self` and `other`.
"""
other = as_dimension(other)
if self._value is None or other.value is None:
@@ -189,6 +189,22 @@ class Dimension(object):
else:
return Dimension(self._value // other.value)
+ def __div__(self, other):
+ """DEPRECATED: Use `__floordiv__` via `x // y` instead.
+
+ This function exists only for backwards compatibility purposes; new code
+ should use `__floordiv__` via the syntax `x // y`. Using `x // y`
+ communicates clearly that the result rounds down, and is forward compatible
+ to Python 3.
+
+ Args:
+ other: Another `Dimension`.
+
+ Returns:
+ A `Dimension` whose value is the integer quotient of `self` and `other`.
+ """
+ return self // other
+
def __mod__(self, other):
"""Returns `self` modulo `other.
diff --git a/tensorflow/python/framework/tensor_shape_div_test.py b/tensorflow/python/framework/tensor_shape_div_test.py
new file mode 100644
index 0000000000..27219dbb9a
--- /dev/null
+++ b/tensorflow/python/framework/tensor_shape_div_test.py
@@ -0,0 +1,24 @@
+"""Test that old style division works for Dimension."""
+from __future__ import absolute_import
+# from __future__ import division # Intentionally skip this import
+from __future__ import print_function
+
+import tensorflow.python.platform
+
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class DimensionDivTest(test_util.TensorFlowTestCase):
+
+ def testDivSucceeds(self):
+ """Without from __future__ import division, __div__ should work."""
+ values = [tensor_shape.Dimension(x) for x in 3, 7, 11, None]
+ for x in values:
+ for y in values:
+ self.assertEqual((x / y).value, (x // y).value)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py
index be5fbb51cb..43cadc7b13 100644
--- a/tensorflow/python/framework/tensor_shape_test.py
+++ b/tensorflow/python/framework/tensor_shape_test.py
@@ -233,6 +233,12 @@ class ShapeTest(test_util.TensorFlowTestCase):
tensor_shape.TensorShape(
[94, 43]).assert_is_compatible_with(tensor_shape.matrix(94, 43))
+ def testTruedivFails(self):
+ unknown = tensor_shape.Dimension(None)
+ self.assertEqual((unknown // unknown).value, None)
+ with self.assertRaisesRegexp(TypeError, r"unsupported operand type"):
+ unknown / unknown # pylint: disable=pointless-statement
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 28138fbf39..78262a55f3 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -409,7 +409,7 @@ def split(split_dim, num_split, value, name="split"):
Args:
split_dim: A 0-D `int32` `Tensor`. The dimension along which to split.
Must be in the range `[0, rank(value))`.
- num_split: A 0-D `int32` `Tensor`. The number of ways to split.
+ num_split: A Python integer. The number of ways to split.
value: The `Tensor` to split.
name: A name for the operation (optional).
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 63826170be..c5730dce21 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -138,7 +138,7 @@ class Optimizer(object):
self._slots = {}
def minimize(self, loss, global_step=None, var_list=None,
- gate_gradients=GATE_OP, name=None):
+ gate_gradients=GATE_OP, aggregation_method=None, name=None):
"""Add operations to minimize 'loss' by updating 'var_list'.
This method simply combines calls compute_gradients() and
@@ -155,6 +155,8 @@ class Optimizer(object):
under the key GraphKeys.TRAINABLE_VARIABLES.
gate_gradients: How to gate the computation of gradients. Can be
GATE_NONE, GATE_OP, or GATE_GRAPH.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Valid values are defined in the class `AggregationMethod`.
name: Optional name for the returned operation.
Returns:
@@ -164,12 +166,14 @@ class Optimizer(object):
Raises:
ValueError: if some of the variables are not variables.Variable objects.
"""
- grads_and_vars = self.compute_gradients(loss, var_list=var_list,
- gate_gradients=gate_gradients)
+ grads_and_vars = self.compute_gradients(
+ loss, var_list=var_list, gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method)
return self.apply_gradients(grads_and_vars, global_step=global_step,
name=name)
- def compute_gradients(self, loss, var_list=None, gate_gradients=GATE_OP):
+ def compute_gradients(self, loss, var_list=None, gate_gradients=GATE_OP,
+ aggregation_method=None):
"""Compute gradients of "loss" for the variables in "var_list".
This is the first part of minimize(). It returns a list
@@ -185,6 +189,8 @@ class Optimizer(object):
under the key GraphKey.TRAINABLE_VARIABLES.
gate_gradients: How to gate the computation of gradients. Can be
GATE_NONE, GATE_OP, or GATE_GRAPH.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Valid values are defined in the class `AggregationMethod`.
Returns:
A list of (gradient, variable) pairs.
@@ -205,7 +211,8 @@ class Optimizer(object):
if not isinstance(var, variables.Variable):
raise TypeError("Argument is not a variables.Variable: %s" % var)
grads = gradients.gradients(
- loss, var_list, gate_gradients=(gate_gradients == Optimizer.GATE_OP))
+ loss, var_list, gate_gradients=(gate_gradients == Optimizer.GATE_OP),
+ aggregation_method=aggregation_method)
if gate_gradients == Optimizer.GATE_GRAPH:
grads = control_flow_ops.tuple(grads)
grads_and_vars = list(zip(grads, var_list))
diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py
new file mode 100644
index 0000000000..a743240d8a
--- /dev/null
+++ b/tensorflow/python/training/optimizer_test.py
@@ -0,0 +1,54 @@
+"""Functional test for optimizer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+class OptimizerTest(tf.test.TestCase):
+
+ def testBasic(self):
+ with self.test_session():
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+ cost = 5 * var0 + 3 * var1
+ global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
+ sgd_op = tf.train.GradientDescentOptimizer(3.0)
+ opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
+
+ tf.initialize_all_variables().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd through optimizer
+ opt_op.run()
+ # Validate updated params
+ self.assertAllClose([-14., -13.], var0.eval())
+ self.assertAllClose([-6., -5.], var1.eval())
+
+ def testAggregationMethod(self):
+ with self.test_session():
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+ cost = 5 * var0 + 3 * var1
+ global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
+ sgd_op = tf.train.GradientDescentOptimizer(3.0)
+ opt_op = sgd_op.minimize(
+ cost, global_step, [var0, var1], aggregation_method=
+ tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
+
+ tf.initialize_all_variables().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd through optimizer
+ opt_op.run()
+ # Validate updated params
+ self.assertAllClose([-14., -13.], var0.eval())
+ self.assertAllClose([-6., -5.], var1.eval())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/tensorboard/BUILD b/tensorflow/tensorboard/BUILD
index 2dcb5e4fa9..74bdd1d6ab 100644
--- a/tensorflow/tensorboard/BUILD
+++ b/tensorflow/tensorboard/BUILD
@@ -20,11 +20,13 @@ py_library(
"//tensorflow/python:platform",
"//tensorflow/python:summary",
],
+ srcs_version = "PY2AND3",
)
py_library(
name = "float_wrapper",
srcs = ["float_wrapper.py"],
+ srcs_version = "PY2AND3",
)
py_test(
@@ -35,6 +37,7 @@ py_test(
":float_wrapper",
"//tensorflow/python:platform_test",
],
+ srcs_version = "PY2AND3",
)
py_binary(
@@ -46,4 +49,5 @@ py_binary(
"//tensorflow/python:platform",
"//tensorflow/python:summary",
],
+ srcs_version = "PY2AND3",
)
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 4bcfd6234c..f88f3eb2a6 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -332,7 +332,8 @@ def py_tests(name,
deps=[
"//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:kernel_tests/gradient_checker",
- ] + additional_deps)
+ ] + additional_deps,
+ srcs_version="PY2AND3")
def cuda_py_tests(name, srcs, additional_deps=[], data=[], shard_count=1):
diff --git a/tensorflow/tools/docker/BUILD b/tensorflow/tools/docker/BUILD
index 2cc540ed3b..7d5ae0a94d 100644
--- a/tensorflow/tools/docker/BUILD
+++ b/tensorflow/tools/docker/BUILD
@@ -10,6 +10,7 @@ exports_files(["LICENSE"])
py_binary(
name = "simple_console",
srcs = ["simple_console.py"],
+ srcs_version = "PY2AND3",
deps = ["//tensorflow:tensorflow_py"],
)
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index b9a50e4288..d0eb717ad6 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -7,6 +7,7 @@ py_binary(
name = "simple_console",
srcs = ["simple_console.py"],
deps = ["//tensorflow:tensorflow_py"],
+ srcs_version = "PY2AND3",
)
sh_binary(