diff options
author | A. Unique TensorFlower <nobody@tensorflow.org> | 2016-01-05 14:05:27 -0800 |
---|---|---|
committer | Vijay Vasudevan <vrv@google.com> | 2016-01-05 14:05:27 -0800 |
commit | 1c579361cd1e088dd5e05a394b1561a73e3667ba (patch) | |
tree | ec464b9ac18113dc052744b6714eebbc7c6cc34d | |
parent | 208350a6092f9faa473daf8b6eb6a80e9f9518f1 (diff) |
Added 'logging' import to control_flow_ops which is used in the file but not imported.
Change: 110842260
170 files changed, 5924 insertions, 2045 deletions
@@ -21,8 +21,8 @@ new_http_archive( new_http_archive( name = "eigen_archive", - url = "https://bitbucket.org/eigen/eigen/get/3.3-beta1.tar.gz", - sha256 = "2d6533e86ed6b54d30ae1d6c10808533b335d1c570c5e4c58ce2f03da99c134b", + url = "https://bitbucket.org/eigen/eigen/get/a0661a2.tar.gz", + sha256 = "d4d13995a0b3a2d80189f83d28647eb35819a478522149c15a761d91f53579b1", build_file = "eigen.BUILD", ) diff --git a/eigen.BUILD b/eigen.BUILD index c33fe7186e..5c6127e6a9 100644 --- a/eigen.BUILD +++ b/eigen.BUILD @@ -1,6 +1,6 @@ package(default_visibility = ["//visibility:public"]) -archive_dir = "eigen-eigen-ce5a455b34c0" +archive_dir = "eigen-eigen-a0661a2bb165" cc_library( name = "eigen", diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 33444cd45d..b9f253740e 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -111,6 +111,7 @@ cc_library( "public/tensor_shape.h", ], copts = tf_copts(), + linkopts = ["-ldl"], visibility = [ ":friends", "//tensorflow:internal", @@ -171,8 +172,12 @@ tf_cuda_library( hdrs = glob([ "public/**/*.h", "util/device_name_utils.h", - ]), + ]) + [ + "framework/op.h", + "framework/op_kernel.h", + ], copts = tf_copts(), + linkopts = ["-ldl"], visibility = ["//visibility:public"], deps = [ ":lib", @@ -422,6 +427,7 @@ tf_gen_op_libs( "no_op", "parsing_ops", "random_ops", + "script_ops", "sendrecv_ops", "sparse_ops", "state_ops", diff --git a/tensorflow/core/client/tensor_c_api.cc b/tensorflow/core/client/tensor_c_api.cc index 87bdc664fb..bc6c88f8c0 100644 --- a/tensorflow/core/client/tensor_c_api.cc +++ b/tensorflow/core/client/tensor_c_api.cc @@ -52,6 +52,11 @@ struct TF_Status { Status status; }; +struct TF_Library { + void* lib_handle; + TF_Buffer op_list; +}; + TF_Status* TF_NewStatus() { return new TF_Status; } void TF_DeleteStatus(TF_Status* s) { delete s; } @@ -304,6 +309,10 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) { [](void*, size_t, void*) {}, nullptr); } +// Helpers for loading a TensorFlow plugin (a .so file). +Status LoadLibrary(const char* library_filename, void** result, + const void** buf, size_t* len); + } // namespace tensorflow extern "C" { @@ -382,4 +391,22 @@ void TF_Run(TF_Session* s, } } +const void* TF_BufferData(TF_Buffer* buffer) { return buffer->data; } + +size_t TF_BufferLength(TF_Buffer* buffer) { return buffer->length; } + +TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { + TF_Library* lib_handle = new TF_Library; + status->status = tensorflow::LoadLibrary( + library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data, + &lib_handle->op_list.length); + if (!status->status.ok()) { + delete lib_handle; + return nullptr; + } + return lib_handle; +} + +TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; } + } // end extern "C" diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 94818e3938..26b2948166 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_partition.h" #include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/refcount.h" @@ -454,6 +455,31 @@ Status DirectSession::CreateGraphs( std::unique_ptr<FunctionLibraryDefinition> fdefs; std::unique_ptr<Graph> graph; GraphConstructorOptions opts; + if (options_.config.has_graph_options()) { + opts.optimizer_do_cse = !options_.config.graph_options() + .skip_common_subexpression_elimination(); + } else { + opts.optimizer_do_cse = true; + } + + if (opts.optimizer_do_cse) { + // Prevent CSE from eliminating nodes that will be required during + // RewriteGraphForExecution, below. + std::unordered_set<StringPiece, StringPiece::Hasher> no_cse_nodes; + for (const string& feed : feeds) { + no_cse_nodes.insert(ParseTensorName(feed).first); + } + for (const string& fetch : fetches) { + no_cse_nodes.insert(ParseTensorName(fetch).first); + } + for (const string& target_node : target_nodes) { + no_cse_nodes.insert(target_node); + } + opts.cse_consider_function = [no_cse_nodes](const Node* n) { + return n->type_string() != "Const" && !no_cse_nodes.count(n->name()); + }; + } + { mutex_lock l(graph_def_lock_); fdefs.reset(new FunctionLibraryDefinition(graph_def_.library())); diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 9b6c2de473..c0376e29fa 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -185,7 +185,7 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name, LOG(ERROR) << "Failed to get StreamExecutor for device " << gpu_id_; return; } - em_.reset(new EventMgr(executor)); + em_.reset(new EventMgr(executor, options.config.gpu_options())); if (FLAGS_brain_gpu_max_streams < 1) { LOG(FATAL) << "Invalid value for brain_gpu_max_streams."; @@ -262,9 +262,14 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { gpu::Stream* stream = gpu_device_context->stream(); const auto stream_id = gpu_device_context->stream_id(); - VLOG(1) << "GpuDevice::Compute " << op_kernel->name() << " op " - << op_kernel->def().op() << " on GPU" << gpu_id_ << " stream[" - << stream_id << "]"; + const bool vlog_1 = VLOG_IS_ON(1); + const bool vlog_2 = vlog_1 && VLOG_IS_ON(2); + + if (vlog_1) { + VLOG(1) << "GpuDevice::Compute " << op_kernel->name() << " op " + << op_kernel->def().op() << " on GPU" << gpu_id_ << " stream[" + << stream_id << "]"; + } // NOTE(tucker): We need to discriminate between Eigen GPU // operations and all others. If an operation is Eigen @@ -292,7 +297,7 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { OP_REQUIRES(context, idc != nullptr, errors::Internal("Input device context ", i, " was not set properly.")); - if (VLOG_IS_ON(2)) { + if (vlog_2) { const void* base; size_t len; if (context->has_input(i)) { @@ -316,35 +321,36 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { } gpu::cuda::ScopedActivateExecutorContext scoped_activation{ stream->parent(), gpu::cuda::MultiOpActivation::kYes}; - // Keep a copy of the inputs before Compute runs, in case they get - // deleted. TODO(misard) this will be fixed when the tracking is - // done right. - EventMgr::TensorReferenceVector* tensor_refs = nullptr; - if (!FLAGS_brain_gpu_sync_every_op) { + + if (FLAGS_brain_gpu_sync_every_op) { + op_kernel->Compute(context); + if (context->status().ok()) { + // Note: GPUUtil::Sync() only syncs the default stream. + // We need to either sync the stream used by this op, or + // all streams. Given that this flag is typically used for + // debugging it makes more sense to sync all GPU activity. + context->SetStatus(GPUUtil::SyncAll(this)); + } + } else { + // Keep a copy of the inputs before Compute runs, in case they get + // deleted. TODO(misard) this will be fixed when the tracking is + // done right. + EventMgr::TensorReferenceVector tensor_refs; const int N_inputs = context->num_inputs(); - tensor_refs = new EventMgr::TensorReferenceVector; - tensor_refs->reserve(N_inputs + context->num_outputs()); + tensor_refs.reserve(N_inputs + context->num_outputs()); for (int ii = 0; ii < N_inputs; ++ii) { if (context->has_input(ii)) { if (IsRefType(context->input_dtype(ii))) { Tensor in = context->mutable_input(ii, false); - tensor_refs->push_back(TensorReference(in)); + tensor_refs.push_back(TensorReference(in)); } else { const Tensor& in = context->input(ii); - tensor_refs->push_back(TensorReference(in)); + tensor_refs.push_back(TensorReference(in)); } } } - } - op_kernel->Compute(context); - if (context->status().ok()) { - if (FLAGS_brain_gpu_sync_every_op) { - // Note: GPUUtil::Sync() only syncs the default stream. - // We need to either sync the stream used by this op, or - // all streams. Given that this flag is typically used for - // debugging it makes more sense to sync all GPU activity. - context->SetStatus(GPUUtil::SyncAll(this)); - } else { + op_kernel->Compute(context); + if (context->status().ok()) { // The GPU kernel has been queued, but may not complete for some // time. As soon as this function completes, the caller will // discard its refs on the inputs, outputs and any scratch @@ -352,21 +358,19 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { // held until the kernel completes. for (int ii = 0; ii < context->num_temps(); ++ii) { Tensor* temp = context->temp(ii); - VLOG(2) << "Saving ref to temp Tensor @ " << DMAHelper::base(temp); - tensor_refs->push_back(TensorReference(*temp)); + if (vlog_2) { + VLOG(2) << "Saving ref to temp Tensor @ " << DMAHelper::base(temp); + } + tensor_refs.push_back(TensorReference(*temp)); } for (int ii = 0; ii < context->num_outputs(); ++ii) { Tensor* temp = context->mutable_output(ii); if (nullptr != temp) { - tensor_refs->push_back(TensorReference(*temp)); + tensor_refs.push_back(TensorReference(*temp)); } } em_->ThenDeleteTensors(stream, tensor_refs); } - } else { - if (!FLAGS_brain_gpu_sync_every_op) { - delete tensor_refs; - } } } } @@ -431,28 +435,29 @@ namespace { class ConcretePerOpGpuDevice : public PerOpGpuDevice { public: explicit ConcretePerOpGpuDevice(gpu::Stream* stream, - EigenAllocator* allocator) - : device_(stream, allocator), allocator_(allocator) {} - ~ConcretePerOpGpuDevice() { delete allocator_; } + Allocator* base_allocator, + ::tensorflow::EventMgr* em) + : allocator_(stream, base_allocator, em), device_(stream, &allocator_) {} const Eigen::GpuDevice& device() const override { return device_; } private: + EigenAllocator allocator_; Eigen::GpuDevice device_; - EigenAllocator* allocator_; }; #else class ConcretePerOpGpuDevice : public PerOpGpuDevice { public: - explicit ConcretePerOpGpuDevice(EigenCudaStreamDevice* stream_device) - : device_(stream_device), stream_device_(stream_device) {} - ~ConcretePerOpGpuDevice() { delete stream_device_; } + explicit ConcretePerOpGpuDevice(const cudaStream_t* cuda_stream, int gpu_id, + Allocator* base_allocator) + : stream_device_(cuda_stream, gpu_id, base_allocator), + device_(&stream_device_) {} const Eigen::GpuDevice& device() const override { return device_; } private: + EigenCudaStreamDevice stream_device_; Eigen::GpuDevice device_; - EigenCudaStreamDevice* stream_device_; }; #endif } // namespace @@ -460,13 +465,11 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice { const PerOpGpuDevice* BaseGPUDevice::NewDevice(int stream_id, Allocator* allocator) { #if defined(__GCUDACC__) || defined(__GCUDACC_HOST__) - auto ea = new EigenAllocator(streams_[stream_id], allocator, em_.get()); - return new ConcretePerOpGpuDevice(streams_[stream_id], ea); + return new ConcretePerOpGpuDevice(streams_[stream_id], allocator, em_.get()); #else const cudaStream_t* cuda_stream = reinterpret_cast<const cudaStream_t*>( streams_[stream_id]->implementation()->CudaStreamMemberHack()); - auto es = new EigenCudaStreamDevice(cuda_stream, gpu_id_, allocator); - return new ConcretePerOpGpuDevice(es); + return new ConcretePerOpGpuDevice(cuda_stream, gpu_id_, allocator); #endif } diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc index 962848ad17..6dd7c3c235 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc @@ -17,13 +17,20 @@ limitations under the License. #include "tensorflow/stream_executor/event.h" #include "tensorflow/stream_executor/stream.h" +#include "tensorflow/core/framework/config.pb.h" namespace gpu = ::perftools::gputools; namespace tensorflow { -EventMgr::EventMgr(gpu::StreamExecutor* se) +EventMgr::EventMgr(gpu::StreamExecutor* se, const GPUOptions& gpu_options) : exec_(se), + deferred_bytes_threshold_(gpu_options.deferred_deletion_bytes() + ? gpu_options.deferred_deletion_bytes() + : 8 * 1048576), + accumulated_stream_(nullptr), + accumulated_tensors_(new TensorReferenceVector), + accumulated_tensor_bytes_(0), // threadpool_ has 1 thread for the polling loop, and one to execute // event callback functions. Maybe we should have more? threadpool_(Env::Default(), "GPU_Event_Manager", 2) { @@ -39,6 +46,10 @@ EventMgr::~EventMgr() { for (auto& e : free_events_) { delete e; } + for (auto& t : *(accumulated_tensors_)) { + t.Unref(); + } + delete accumulated_tensors_; while (!used_events_.empty()) { InUse* ue = &used_events_[0]; delete ue->event; @@ -51,6 +62,35 @@ EventMgr::~EventMgr() { } } +void EventMgr::ThenDeleteTensors(perftools::gputools::Stream* stream, + const TensorReferenceVector& tensors) { + mutex_lock l(mu_); + // TODO(jeff): We currently keep one accumulated_tensors_ object. + // If we start to use multiple streams heavily, we might want to keep + // separate vectors/byte counters per stream + if (!accumulated_tensors_->empty() && stream != accumulated_stream_) { + FlushAccumulatedTensors(); + } + accumulated_stream_ = stream; + for (auto t : tensors) { + // accumulated_tensors_ takes over ownership of the reference to "t" + accumulated_tensors_->push_back(t); + accumulated_tensor_bytes_ += t.TotalBytes(); + } + if (accumulated_tensor_bytes_ >= deferred_bytes_threshold_) { + FlushAccumulatedTensors(); + } +} + +void EventMgr::FlushAccumulatedTensors() { + DCHECK(!accumulated_tensors_->empty()); + DCHECK(accumulated_stream_ != nullptr); + QueueTensors(accumulated_stream_, accumulated_tensors_); + accumulated_tensors_ = new TensorReferenceVector; + accumulated_tensor_bytes_ = 0; + accumulated_stream_ = nullptr; +} + // This polling loop runs at a relatively low frequency. Most calls to // PollEvents() should come directly from Compute() via // ThenDeleteTensors(). This function's purpose is to ensure that diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h index 3faee71614..09d785d792 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h +++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h @@ -37,29 +37,24 @@ class StreamExecutor; namespace tensorflow { +class GPUOptions; + // An object to keep track of pending Events in the StreamExecutor streams // and associated Tensors that cannot safely be deleted until the associated // Events are recorded. class EventMgr { public: - explicit EventMgr(perftools::gputools::StreamExecutor* se); + EventMgr(perftools::gputools::StreamExecutor* se, + const GPUOptions& gpu_options); ~EventMgr(); typedef gtl::InlinedVector<TensorReference, 4> TensorReferenceVector; - // Takes ownership of *tensors and deletes it as soon as all events - // currently enqueued on *stream have completed. - inline void ThenDeleteTensors(perftools::gputools::Stream* stream, - TensorReferenceVector* tensors) { - ToFreeVector to_free; - { - mutex_lock l(mu_); - QueueTensors(stream, tensors); - PollEvents(false, &to_free); - } - FreeMemory(to_free); - } + // Releases the references on the elements of "tensors" as soon as + // all events currently enqueued on "stream" have completed. + void ThenDeleteTensors(perftools::gputools::Stream* stream, + const TensorReferenceVector& tensors); struct BufRec { Allocator* alloc; @@ -92,8 +87,11 @@ class EventMgr { private: friend class TEST_EventMgrHelper; + perftools::gputools::StreamExecutor* const exec_; + const int64 deferred_bytes_threshold_; mutex mu_; - perftools::gputools::StreamExecutor* exec_; + + void FlushAccumulatedTensors() EXCLUSIVE_LOCKS_REQUIRED(mu_); struct InUse { perftools::gputools::Event* event; @@ -122,7 +120,6 @@ class EventMgr { // Tensors and/or a BufRec to be deleted only after the Event // records. void QueueInUse(perftools::gputools::Stream* stream, InUse in_use) - EXCLUSIVE_LOCKS_REQUIRED(mu_); void QueueTensors(perftools::gputools::Stream* stream, @@ -156,6 +153,12 @@ class EventMgr { // A stack of unused events std::vector<perftools::gputools::Event*> free_events_ GUARDED_BY(mu_); + // Buffered list of tensors waiting to have an event queued for deletion + perftools::gputools::Stream* accumulated_stream_ GUARDED_BY(mu_); + TensorReferenceVector* accumulated_tensors_ GUARDED_BY(mu_); + // Sum of the TotalBytes() of the tensors in "accumulated_tensors_" + int64 accumulated_tensor_bytes_ GUARDED_BY(mu_); + // A FIFO queue of InUse events and associated tensors. std::deque<InUse> used_events_ GUARDED_BY(mu_); diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc index 910093a069..57c1554678 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc @@ -17,10 +17,12 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include <atomic> #include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/stream_executor.h" #include <gtest/gtest.h> #include "tensorflow/core/common_runtime/gpu/gpu_init.h" +#include "tensorflow/core/framework/config.pb.h" namespace gpu = ::perftools::gputools; @@ -59,11 +61,32 @@ class TEST_EventMgrHelper { EventMgr* em_; }; +static std::atomic_int_fast64_t live_tensor_bytes(0); + +// A TensorBuffer that counts live memory usage for testing +class TestTensorBuffer : public TensorBuffer { + public: + TestTensorBuffer(size_t bytes) : bytes_(bytes) { + live_tensor_bytes += bytes_; + } + ~TestTensorBuffer() { live_tensor_bytes -= bytes_; } + + size_t size() const override { return bytes_; } + + // Not used in this test + void* data() const override { return nullptr; } + TensorBuffer* root_buffer() override { return nullptr; } + void FillAllocationDescription(AllocationDescription* arg) const override {} + + private: + size_t bytes_; +}; + namespace { TEST(EventMgr, Empty) { auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie(); - EventMgr em(stream_exec); + EventMgr em(stream_exec, GPUOptions()); TEST_EventMgrHelper th(&em); EXPECT_EQ(0, th.queue_size()); EXPECT_EQ(0, th.free_size()); @@ -74,7 +97,7 @@ TEST(EventMgr, Empty) { // the max simultaneously pending, we should not allocate any more. TEST(EventMgr, DelayedPolling) { auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie(); - EventMgr em(stream_exec); + EventMgr em(stream_exec, GPUOptions()); TEST_EventMgrHelper th(&em); EXPECT_EQ(0, th.queue_size()); EventMgr::TensorReferenceVector* v = nullptr; @@ -103,22 +126,87 @@ TEST(EventMgr, DelayedPolling) { } } -// Immediate polling should require only one event to be allocated. -TEST(EventMgr, ImmediatePolling) { +static void AddTensorReference(EventMgr::TensorReferenceVector* v, int64 size) { + TestTensorBuffer* buf = new TestTensorBuffer(size); + v->push_back(TensorReference(buf)); + buf->Unref(); +} + +TEST(EventMgr, FlushLargeTensorImmediately) { auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie(); - EventMgr em(stream_exec); + EventMgr em(stream_exec, GPUOptions()); TEST_EventMgrHelper th(&em); - EXPECT_EQ(0, th.queue_size()); - EXPECT_EQ(0, th.free_size()); - EventMgr::TensorReferenceVector* v = nullptr; + EXPECT_EQ(0, live_tensor_bytes); std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec)); CHECK(stream.get()); stream->Init(); for (int i = 0; i < 5; ++i) { - v = new EventMgr::TensorReferenceVector; + EventMgr::TensorReferenceVector v; + AddTensorReference(&v, 100 * 1048576); em.ThenDeleteTensors(stream.get(), v); - EXPECT_EQ(0, th.queue_size()); - EXPECT_EQ(1, th.free_size()); + th.PollEvents(false); // Ensure things get registered to be freed by Poll + EXPECT_EQ(0, live_tensor_bytes); + } +} + +TEST(EventMgr, ManySmallTensorsFlushedImmediately) { + auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie(); + EventMgr em(stream_exec, GPUOptions()); + TEST_EventMgrHelper th(&em); + EXPECT_EQ(0, live_tensor_bytes); + std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec)); + CHECK(stream.get()); + stream->Init(); + for (int i = 0; i < 5; ++i) { + EventMgr::TensorReferenceVector v; + for (int i = 0; i < 1000; i++) { + AddTensorReference(&v, 100 * 1024); + } + em.ThenDeleteTensors(stream.get(), v); + th.PollEvents(false); // Ensure things get registered to be freed by Poll + EXPECT_EQ(0, live_tensor_bytes); + } +} + +TEST(EventMgr, StreamSwitchingFlushesImmediately) { + auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie(); + EventMgr em(stream_exec, GPUOptions()); + TEST_EventMgrHelper th(&em); + EXPECT_EQ(0, live_tensor_bytes); + std::unique_ptr<gpu::Stream> stream1(new gpu::Stream(stream_exec)); + std::unique_ptr<gpu::Stream> stream2(new gpu::Stream(stream_exec)); + stream1->Init(); + stream2->Init(); + EventMgr::TensorReferenceVector v1; + AddTensorReference(&v1, 1024); + em.ThenDeleteTensors(stream1.get(), v1); + + EventMgr::TensorReferenceVector v2; + AddTensorReference(&v2, 1024); + int64 initial_live_bytes = live_tensor_bytes; + em.ThenDeleteTensors(stream2.get(), v2); + th.PollEvents(false); // Ensure things get registered to be freed by Poll + // Different stream should cause first tensor to get deleted + EXPECT_GT(initial_live_bytes, live_tensor_bytes); +} + +TEST(EventMgr, ManySmallTensorsSeperateCallsFlushed) { + auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie(); + EventMgr em(stream_exec, GPUOptions()); + TEST_EventMgrHelper th(&em); + EXPECT_EQ(0, live_tensor_bytes); + std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec)); + CHECK(stream.get()); + stream->Init(); + for (int i = 0; i < 5; ++i) { + for (int i = 0; i < 1000; i++) { + EventMgr::TensorReferenceVector v; + AddTensorReference(&v, 100 * 1024); + em.ThenDeleteTensors(stream.get(), v); + } + th.PollEvents(false); // Ensure things get registered to be freed by Poll + // Some of the tensors at least should be flushed + EXPECT_GT(1000 * 100 * 1024, live_tensor_bytes); } } @@ -126,16 +214,15 @@ TEST(EventMgr, ImmediatePolling) { // should clear the queue. TEST(EventMgr, LongDelayedPolling) { auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie(); - EventMgr em(stream_exec); + EventMgr em(stream_exec, GPUOptions()); TEST_EventMgrHelper th(&em); EXPECT_EQ(0, th.queue_size()); EXPECT_EQ(0, th.free_size()); - EventMgr::TensorReferenceVector* v = nullptr; std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec)); CHECK(stream.get()); stream->Init(); for (int i = 0; i < 5; ++i) { - v = new EventMgr::TensorReferenceVector; + EventMgr::TensorReferenceVector* v = new EventMgr::TensorReferenceVector; th.QueueTensors(stream.get(), v); EXPECT_EQ(1 + i, th.queue_size()); EXPECT_EQ(0, th.free_size()); @@ -149,16 +236,15 @@ TEST(EventMgr, LongDelayedPolling) { // down gracefully. TEST(EventMgr, NonEmptyShutdown) { auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie(); - EventMgr em(stream_exec); + EventMgr em(stream_exec, GPUOptions()); TEST_EventMgrHelper th(&em); EXPECT_EQ(0, th.queue_size()); EXPECT_EQ(0, th.free_size()); - EventMgr::TensorReferenceVector* v = nullptr; std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec)); CHECK(stream.get()); stream->Init(); for (int i = 0; i < 5; ++i) { - v = new EventMgr::TensorReferenceVector; + EventMgr::TensorReferenceVector* v = new EventMgr::TensorReferenceVector; th.QueueTensors(stream.get(), v); EXPECT_EQ(1 + i, th.queue_size()); EXPECT_EQ(0, th.free_size()); diff --git a/tensorflow/core/example/example.proto b/tensorflow/core/example/example.proto index f4d946dcf0..d2e9f24563 100644 --- a/tensorflow/core/example/example.proto +++ b/tensorflow/core/example/example.proto @@ -4,6 +4,8 @@ syntax = "proto3"; import "tensorflow/core/example/feature.proto"; // option cc_enable_arenas = true; +option java_multiple_files = true; +option java_package = "org.tensorflow.example"; package tensorflow; @@ -163,12 +165,13 @@ message Example { // an empty list (zero length). // - If a FeatureList L exists, it may be empty (zero length). // - If a FeatureList L is non-empty, all features within the FeatureList -// must have data type T, and all features within the FeatureList must -// have the same size. +// must have data type T. +// - If a FeatureList L is non-empty, it is up to the parser configuration +// to determine if all features within the FeatureList must +// have the same size. The same holds for this FeatureList across multiple +// examples. // - If a FeatureList L exists in one example with data type T, // it must be of type T in all other examples when present. -// - If a FeatureList L exists in one example having features' sizes all S, -// these sizes must be S in all other examples when present. // // Examples of conformant and non-conformant examples' FeatureLists: // @@ -186,7 +189,8 @@ message Example { // feature: { int64_list: { value: [ 5 ] } } } // } } // -// Non-conformant FeatureLists (mismatched sizes): +// Conditionally conformant FeatureLists, the parser configuration determines +// if the feature sizes must match: // feature_lists: { feature_list: { // key: "movie_ratings" // value: { feature: { float_list: { value: [ 4.5 ] } } @@ -244,7 +248,8 @@ message Example { // feature: { int64_list: { value: [ 2 ] } } } // } } // -// Non-conformant pair of SequenceExample (mismatched sizes) +// Conditionally conformant pair of SequenceExample; the parser configuration +// determines if the feature sizes must match: // feature_lists: { feature_list: { // key: "movie_ratings" // value: { feature: { float_list: { value: [ 4.5 ] } } @@ -253,7 +258,7 @@ message Example { // and: // feature_lists: { feature_list: { // key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.0, 5.0 ] } } +// value: { feature: { float_list: { value: [ 4.0 ] } } // feature: { float_list: { value: [ 5.0, 3.0 ] } } // } } diff --git a/tensorflow/core/example/feature.proto b/tensorflow/core/example/feature.proto index 52d5fac441..130e142503 100644 --- a/tensorflow/core/example/feature.proto +++ b/tensorflow/core/example/feature.proto @@ -55,6 +55,8 @@ syntax = "proto3"; // option cc_enable_arenas = true; +option java_multiple_files = true; +option java_package = "org.tensorflow.example"; package tensorflow; diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc index ccdcf35b34..004f65fe62 100644 --- a/tensorflow/core/framework/attr_value_util.cc +++ b/tensorflow/core/framework/attr_value_util.cc @@ -18,9 +18,9 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/regexp.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/platform/regexp.h" namespace tensorflow { diff --git a/tensorflow/core/framework/config.proto b/tensorflow/core/framework/config.proto index 3f5d01fb8d..167dd632ac 100644 --- a/tensorflow/core/framework/config.proto +++ b/tensorflow/core/framework/config.proto @@ -19,6 +19,21 @@ message GPUOptions { // "BFC": A "Best-fit with coalescing" algorithm, simplified from a // version of dlmalloc. string allocator_type = 2; + + // Delay deletion of up to this many bytes to reduce the number of + // interactions with gpu driver code. If 0, the system chooses + // a reasonable default (several MBs). + int64 deferred_deletion_bytes = 3; +}; + +message GraphOptions { + // If true, do not attempt to optimize the graph using common + // subexpression elimination. + bool skip_common_subexpression_elimination = 1; + + // If true, use control flow to schedule the activation of Recv nodes. + // (Currently ignored.) + bool enable_recv_scheduling = 2; }; // Session configuration parameters. @@ -75,4 +90,7 @@ message ConfigProto { // Whether device placements should be logged. bool log_device_placement = 8; + + // Options that apply to all graphs. + GraphOptions graph_options = 10; }; diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 14ffeca6e4..fc9f1d0324 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -25,48 +25,6 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("_Arg") - .Output("output: T") - .Attr("T: type") - .Attr("index: int >= 0") - .Doc(R"doc( -A graph node which represents an argument to a function. - -output: The argument. -index: This argument is the index-th argument of the function. -)doc"); - -REGISTER_OP("_Retval") - .Input("input: T") - .Attr("T: type") - .Attr("index: int >= 0") - .Doc(R"doc( -A graph node which represents a return value of a function. - -input: The return value. -index: This return value is the index-th return value of the function. -)doc"); - -REGISTER_OP("_ListToArray") - .Input("input: Tin") - .Output("output: N * T") - .Attr("Tin: list(type)") - .Attr("T: type") - .Attr("N: int >= 1") - .Doc(R"doc( -Converts a list of tensors to an array of tensors. -)doc"); - -REGISTER_OP("_ArrayToList") - .Input("input: N * T") - .Output("output: out_types") - .Attr("T: type") - .Attr("N: int >= 1") - .Attr("out_types: list(type)") - .Doc(R"doc( -Converts an array of tensors to a list of tensors. -)doc"); - namespace { // Extracts the actual type from "attr_values" based on its definition diff --git a/tensorflow/core/framework/graph.proto b/tensorflow/core/framework/graph.proto index 8bf4fd5e5f..d18dd81912 100644 --- a/tensorflow/core/framework/graph.proto +++ b/tensorflow/core/framework/graph.proto @@ -21,6 +21,7 @@ message GraphDef { // 0. Graphs created before GraphDef versioning // 1. First real version (2dec2015) // 2. adjust_contrast only takes float, doesn't perform clamping (11dec2015) + // 3. Remove TileGrad, since it was equivalent to reduce_sum (30dec2015) // // The GraphDef version is distinct from the TensorFlow version. // Each released version of TensorFlow will support a range of diff --git a/tensorflow/core/framework/load_library.cc b/tensorflow/core/framework/load_library.cc new file mode 100644 index 0000000000..0d6b8563b0 --- /dev/null +++ b/tensorflow/core/framework/load_library.cc @@ -0,0 +1,76 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <dlfcn.h> +#include <memory> + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { + +namespace { + +template <typename R, typename... Args> +Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + R (**symbol)(Args...)) { + Env* env = Env::Default(); + void* symbol_ptr; + Status status = env->GetSymbolFromLibrary(handle, symbol_name, &symbol_ptr); + *symbol = reinterpret_cast<R (*)(Args...)>(symbol_ptr); + return status; +} + +} // namespace + +// Load a dynamic library and register the ops and kernels defined in that file. +// Expects the symbols "RegisterOps", "RegisterKernels", and "GetOpList" to be +// defined in the library. +// On success, returns the handle to library in result, copies the serialized +// OpList of OpDefs registered in the library to *buf and the length to *len, +// and returns OK from the function. Otherwise return nullptr in result +// and an error status from the function, leaving buf and len untouched. +Status LoadLibrary(const char* library_filename, void** result, + const void** buf, size_t* len) { + Env* env = Env::Default(); + void* lib; + TF_RETURN_IF_ERROR(env->LoadLibrary(library_filename, &lib)); + + typedef void (*FuncType)(void*); + FuncType RegisterOps, RegisterKernels, GetOpList; + TF_RETURN_IF_ERROR(GetSymbolFromLibrary(lib, "RegisterOps", &RegisterOps)); + TF_RETURN_IF_ERROR( + GetSymbolFromLibrary(lib, "RegisterKernels", &RegisterKernels)); + TF_RETURN_IF_ERROR(GetSymbolFromLibrary(lib, "GetOpList", &GetOpList)); + + *buf = nullptr; + *len = 0; + + RegisterOps(OpRegistry::Global()); + RegisterKernels(GlobalKernelRegistry()); + string str; + GetOpList(&str); + char* str_buf = reinterpret_cast<char*>(operator new(str.length())); + strncpy(str_buf, str.data(), str.length()); + *buf = str_buf; + *len = str.length(); + + *result = lib; + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index bd246cda80..06e913ec4e 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -33,14 +33,13 @@ OpRegistryInterface::~OpRegistryInterface() {} OpRegistry::OpRegistry() : initialized_(false) {} -void OpRegistry::Register(std::function<OpDef(void)> func) { +void OpRegistry::Register(const OpDef& op_def) { mutex_lock lock(mu_); if (initialized_) { - OpDef def = func(); - TF_QCHECK_OK(RegisterAlreadyLocked(def)) << "Attempting to register: " - << SummarizeOpDef(def); + TF_QCHECK_OK(RegisterAlreadyLocked(op_def)) << "Attempting to register: " + << SummarizeOpDef(op_def); } else { - deferred_.push_back(func); + deferred_.push_back(op_def); } } @@ -75,6 +74,14 @@ const OpDef* OpRegistry::LookUp(const string& op_type_name, return op_def; } +void OpRegistry::GetRegisteredOps(std::vector<OpDef>* op_defs) { + mutex_lock lock(mu_); + CallDeferred(); + for (auto p : registry_) { + op_defs->push_back(*p.second); + } +} + void OpRegistry::Export(bool include_internal, OpList* ops) const { mutex_lock lock(mu_); CallDeferred(); @@ -107,10 +114,9 @@ string OpRegistry::DebugString(bool include_internal) const { bool OpRegistry::CallDeferred() const { if (initialized_) return false; initialized_ = true; - for (const auto& fn : deferred_) { - OpDef def = fn(); - TF_QCHECK_OK(RegisterAlreadyLocked(def)) << "Attempting to register: " - << SummarizeOpDef(def); + for (const auto& op_def : deferred_) { + TF_QCHECK_OK(RegisterAlreadyLocked(op_def)) << "Attempting to register: " + << SummarizeOpDef(op_def); } deferred_.clear(); return true; @@ -136,12 +142,25 @@ OpRegistry* OpRegistry::Global() { namespace register_op { OpDefBuilderReceiver::OpDefBuilderReceiver(const OpDefBuilder& builder) { - OpRegistry::Global()->Register([builder]() { - OpDef op_def; - TF_QCHECK_OK(builder.Finalize(&op_def)); - return op_def; - }); + OpDef op_def; + builder.Finalize(&op_def); + OpRegistry::Global()->Register(op_def); } } // namespace register_op +extern "C" void RegisterOps(void* registry_ptr) { + OpRegistry* op_registry = static_cast<OpRegistry*>(registry_ptr); + std::vector<OpDef> op_defs; + OpRegistry::Global()->GetRegisteredOps(&op_defs); + for (auto const& op_def : op_defs) { + op_registry->Register(op_def); + } +} + +extern "C" void GetOpList(void* str) { + OpList op_list; + OpRegistry::Global()->Export(true, &op_list); + op_list.SerializeToString(reinterpret_cast<string*>(str)); +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h index 6e20a0fb4a..2a6a34f28e 100644 --- a/tensorflow/core/framework/op.h +++ b/tensorflow/core/framework/op.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_FRAMEWORK_OP_H_ #define TENSORFLOW_FRAMEWORK_OP_H_ -#include <functional> #include <unordered_map> #include "tensorflow/core/framework/op_def.pb.h" @@ -65,7 +64,7 @@ class OpRegistry : public OpRegistryInterface { // we defer calling func() until the first call to LookUp() or // Export() (if one of those has already been called, func() is // called immediately). - void Register(std::function<OpDef(void)> func); + void Register(const OpDef& op_def); const OpDef* LookUp(const string& op_type_name, Status* status) const override; @@ -81,6 +80,9 @@ class OpRegistry : public OpRegistryInterface { // A singleton available at startup. static OpRegistry* Global(); + // Get all registered ops. + void GetRegisteredOps(std::vector<OpDef>* op_defs); + private: // Ensures that all the functions in deferred_ get called, their OpDef's // registered, and returns with deferred_ empty. Returns true the first @@ -94,11 +96,17 @@ class OpRegistry : public OpRegistryInterface { mutable mutex mu_; // Functions in deferred_ may only be called with mu_ held. - mutable std::vector<std::function<OpDef(void)>> deferred_ GUARDED_BY(mu_); + mutable std::vector<OpDef> deferred_ GUARDED_BY(mu_); mutable std::unordered_map<string, OpDef*> registry_ GUARDED_BY(mu_); mutable bool initialized_ GUARDED_BY(mu_); }; +// Treats 'registry_ptr' as a pointer to OpRegistry, and calls +// registry_ptr->Register(op_def) for each op_def that has been registered with +// the current library's global op registry (obtained by calling +// OpRegistry::Global(). +extern "C" void RegisterOps(void* registry_ptr); + // Support for defining the OpDef (specifying the semantics of the Op and how // it should be created) and registering it in the OpRegistry::Global() // registry. Usage: diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc index 767d2a0466..a59ddb3e9b 100644 --- a/tensorflow/core/framework/op_def_builder.cc +++ b/tensorflow/core/framework/op_def_builder.cc @@ -20,9 +20,9 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/regexp.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/regexp.h" namespace tensorflow { diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 6947ebfc7f..b984966148 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -442,17 +442,27 @@ struct KernelRegistration { // KernelDef. typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry; -static KernelRegistry* GlobalKernelRegistry() { +void* GlobalKernelRegistry() { static KernelRegistry* global_kernel_registry = new KernelRegistry; return global_kernel_registry; } +static KernelRegistry* GlobalKernelRegistryTyped() { + return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry()); +} + static string Key(const string& op_type, DeviceType device_type, const string& label) { return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":", label); } +extern "C" void RegisterKernels(void* registry_ptr) { + KernelRegistry* kernel_registry = static_cast<KernelRegistry*>(registry_ptr); + kernel_registry->insert(GlobalKernelRegistryTyped()->begin(), + GlobalKernelRegistryTyped()->end()); +} + namespace kernel_factory { OpKernelRegistrar::OpKernelRegistrar(const KernelDef* kernel_def, @@ -460,7 +470,7 @@ OpKernelRegistrar::OpKernelRegistrar(const KernelDef* kernel_def, const string key = Key(kernel_def->op(), DeviceType(kernel_def->device_type()), kernel_def->label()); - GlobalKernelRegistry()->insert( + GlobalKernelRegistryTyped()->insert( std::make_pair(key, KernelRegistration(*kernel_def, factory))); delete kernel_def; } @@ -533,7 +543,7 @@ Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def, string label; // Label defaults to empty if not found in NodeDef. GetNodeAttr(node_def, "_kernel", &label); const string key = Key(node_def.op(), device_type, label); - auto regs = GlobalKernelRegistry()->equal_range(key); + auto regs = GlobalKernelRegistryTyped()->equal_range(key); for (auto iter = regs.first; iter != regs.second; ++iter) { // If there is a kernel registered for the op and device_type, // check that the attrs match. @@ -730,7 +740,7 @@ bool FindArgInOp(const string& arg_name, Status ValidateKernelRegistrations(const OpRegistryInterface* op_registry) { Status unused_status; - for (const auto& key_registration : *GlobalKernelRegistry()) { + for (const auto& key_registration : *GlobalKernelRegistryTyped()) { const KernelDef& kernel_def(key_registration.second.def); const OpDef* op_def = op_registry->LookUp(kernel_def.op(), &unused_status); if (op_def == nullptr) { diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index a68a170fde..dedd600b05 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -966,6 +966,13 @@ typedef ::tensorflow::KernelDefBuilder Name; +[](::tensorflow::OpKernelConstruction* context) \ -> ::tensorflow::OpKernel* { return new __VA_ARGS__(context); }) +void* GlobalKernelRegistry(); + +// Treats 'registry_ptr' as a pointer to KernelRegistry. For each kernel 'k' +// registered with the current library's global kernel registry (obtained by +// calling GlobalKernelRegistry()), inserts 'k' into registry_ptr. +extern "C" void RegisterKernels(void* registry_ptr); + namespace kernel_factory { class OpKernelRegistrar { diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc index 9a597a1042..8176b91a5e 100644 --- a/tensorflow/core/framework/rendezvous.cc +++ b/tensorflow/core/framework/rendezvous.cc @@ -43,18 +43,39 @@ string Rendezvous::CreateKey(const string& src_device, uint64 src_incarnation, ":", frame_iter.iter_id); } +// Return the prefix of "*s" up to the next occurrence of "delim", or +// the whole remaining string if "delim" is not found. "*s" is advanced +// past the string returned plus the delimiter (if found). +static StringPiece ConsumeNextPart(StringPiece* s, char delim) { + for (int offset = 0; offset < s->size(); offset++) { + if ((*s)[offset] == delim) { + StringPiece result(s->data(), offset); + s->remove_prefix(offset + 1); // +1: remove delim, as well + return result; + } + } + // No delimiter found: return rest of string + StringPiece result(s->data(), s->size()); + s->remove_prefix(s->size()); + return result; +} + /* static */ Status Rendezvous::ParseKey(const string& key, ParsedKey* out) { - // TODO(zhifengc): This code is not fast enough. - std::vector<string> parts = str_util::Split(key, ';'); - if (parts.size() == 5 && + StringPiece s(key); + StringPiece parts[5]; + for (int i = 0; i < 5; i++) { + parts[i] = ConsumeNextPart(&s, ';'); + } + if (s.empty() && // Consumed the whole string + !parts[4].empty() && // Exactly five parts DeviceNameUtils::ParseFullName(parts[0], &out->src) && - strings::StringToFp(parts[1], &out->src_incarnation) && + strings::StringToFp(parts[1].ToString(), &out->src_incarnation) && DeviceNameUtils::ParseFullName(parts[2], &out->dst) && !parts[3].empty()) { - out->src_device = parts[0]; - out->dst_device = parts[2]; - out->edge_name = parts[3]; + out->src_device.assign(parts[0].data(), parts[0].size()); + out->dst_device.assign(parts[2].data(), parts[2].size()); + out->edge_name.assign(parts[3].data(), parts[3].size()); return Status::OK(); } return errors::InvalidArgument("Invalid rendezvous key: ", key); diff --git a/tensorflow/core/framework/tensor_reference.h b/tensorflow/core/framework/tensor_reference.h index 88853130a0..e700bb4b6d 100644 --- a/tensorflow/core/framework/tensor_reference.h +++ b/tensorflow/core/framework/tensor_reference.h @@ -38,6 +38,17 @@ class TensorReference { if (buf_) buf_->Unref(); } + // Return an estimate of the total bytes being kept alive by this reference. + size_t TotalBytes() const { + // We add 128 as a baseline to account for per-Tensor metadata + return 128 + (buf_ ? buf_->size() : 0); + } + + // A constructor used only for tests + explicit TensorReference(TensorBuffer* test_buffer) : buf_(test_buffer) { + if (buf_) buf_->Ref(); + } + private: TensorBuffer* buf_; }; diff --git a/tensorflow/core/graph/algorithm.cc b/tensorflow/core/graph/algorithm.cc index aaaf226bbd..31a470e4ea 100644 --- a/tensorflow/core/graph/algorithm.cc +++ b/tensorflow/core/graph/algorithm.cc @@ -19,6 +19,8 @@ limitations under the License. #include <deque> #include <vector> +#include "tensorflow/core/platform/logging.h" + namespace tensorflow { void DFS(const Graph& g, std::function<void(Node*)> enter, @@ -78,14 +80,18 @@ void PruneForReverseReachability(Graph* g, // nodes, and accumulating the visited nodes. std::deque<const Node*> queue; for (const Node* n : nodes) { - queue.push_back(n); + if (visited.insert(n).second) { + VLOG(2) << "Reverse reach init: " << n->name(); + queue.push_back(n); + } } while (!queue.empty()) { const Node* n = queue.front(); queue.pop_front(); - if (visited.insert(n).second) { - for (const Node* in : n->in_nodes()) { + for (const Node* in : n->in_nodes()) { + if (visited.insert(in).second) { queue.push_back(in); + VLOG(2) << "Reverse reach : " << n->name() << " from " << in->name(); } } } diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 4459d0b54b..e74033bd98 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -53,7 +53,7 @@ class GraphConstructor { *status = errors::InvalidArgument( "GraphDef version ", version, " is ", low ? "no longer" : "not yet", " supported: TensorFlow ", TF_VERSION_STRING, " needs ", - TF_GRAPH_DEF_VERSION_MAX, " <= version <= ", TF_GRAPH_DEF_VERSION_MIN, + TF_GRAPH_DEF_VERSION_MIN, " <= version <= ", TF_GRAPH_DEF_VERSION_MAX, ". ", low ? "Please regenerate your graph." : "Please upgrade TensorFlow."); return; @@ -150,8 +150,8 @@ void GraphConstructor::BuildNodeIndex() { SetNodeError(node_def, "Node name contains invalid characters"); return; } - if (!name_index_.insert(std::make_pair(StringPiece(node_def.name()), - NodeInfo(n))) + if (!name_index_ + .insert(std::make_pair(StringPiece(node_def.name()), NodeInfo(n))) .second) { SetNodeError(node_def, "Node name is not unique"); return; @@ -346,8 +346,8 @@ void GraphConstructor::Convert() { if (opts_.optimizer_do_cse) { if (!back_edges.empty()) { - LOG(WARNING) << "Not doing CSE. We need to figure out how to handle " - << "loops in the CSE phase."; + VLOG(1) << "Not doing CSE. We need to figure out how to handle " + << "loops in the CSE phase."; } else { VLOG(1) << "Starting CSE: graph of " << CountNodes(g_) << " nodes"; OptimizeCSE(g_, opts_.cse_consider_function); @@ -392,6 +392,9 @@ void CopyGraph(const Graph& src, Graph* dest) { CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty"; } + // Copy GraphDef version + dest->set_version(src.version()); + // Copy the nodes std::unordered_map<Node*, Node*> node_map; // "Node in src" -> "Node in *dest" diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index 7706a3d0c6..657dfa4e90 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -180,16 +180,20 @@ TEST_F(GraphConstructorTest, VersionGraph) { TEST_F(GraphConstructorTest, LowVersion) { ExpectError(strings::StrCat("version: ", -1), - R"(^GraphDef version -1 is no longer supported: TensorFlow \S+ )" - R"(needs \d+ <= version <= \d+\. )" - R"(Please regenerate your graph\.$)"); + strings::StrCat(R"(^GraphDef version -1 is no longer supported: )" + R"(TensorFlow \S+ needs )", + TF_GRAPH_DEF_VERSION_MIN, " <= version <= ", + TF_GRAPH_DEF_VERSION_MAX, + R"(. Please regenerate your graph\.$)")); } TEST_F(GraphConstructorTest, HighVersion) { ExpectError(strings::StrCat("version: ", TF_GRAPH_DEF_VERSION_MAX + 1), - R"(^GraphDef version \d+ is not yet supported: TensorFlow \S+ )" - R"(needs \d+ <= version <= \d+\. )" - R"(Please upgrade TensorFlow\.$)"); + strings::StrCat(R"(^GraphDef version \d+ is not yet supported: )" + R"(TensorFlow \S+ needs )", + TF_GRAPH_DEF_VERSION_MIN, " <= version <= ", + TF_GRAPH_DEF_VERSION_MAX, + R"(. Please upgrade TensorFlow\.$)")); } TEST_F(GraphConstructorTest, SimpleModel) { @@ -231,5 +235,16 @@ TEST_F(GraphConstructorTest, Error_ControlEdgeBeforeRealInput) { "Node 't2': Control dependencies must come after regular dependencies"); } +TEST_F(GraphConstructorTest, CopyGraph) { + const int version = TF_GRAPH_DEF_VERSION - 1; + + Graph src(OpRegistry::Global()); + src.set_version(version); + + Graph dst(OpRegistry::Global()); + CopyGraph(src, &dst); + EXPECT_EQ(dst.version(), version); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/adjust_contrast_op.cc b/tensorflow/core/kernels/adjust_contrast_op.cc index 18f7cb083d..b3fdba055c 100644 --- a/tensorflow/core/kernels/adjust_contrast_op.cc +++ b/tensorflow/core/kernels/adjust_contrast_op.cc @@ -38,7 +38,7 @@ template <typename Device, typename T> class AdjustContrastOp : public OpKernel { public: explicit AdjustContrastOp(OpKernelConstruction* context) : OpKernel(context) { - OP_DEPRECATED(context, 2); + OP_DEPRECATED(context, 2, "Use AdjustContrastv2 instead"); } void Compute(OpKernelContext* context) override { diff --git a/tensorflow/core/kernels/cwise_op_erf.cc b/tensorflow/core/kernels/cwise_op_erf.cc new file mode 100644 index 0000000000..02f6b4b8d1 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_erf.cc @@ -0,0 +1,23 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER2(UnaryOp, CPU, "Erf", functor::erf, float, double); +#if GOOGLE_CUDA +REGISTER2(UnaryOp, GPU, "Erf", functor::erf, float, double); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_erfc.cc b/tensorflow/core/kernels/cwise_op_erfc.cc new file mode 100644 index 0000000000..65862d4082 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_erfc.cc @@ -0,0 +1,23 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER2(UnaryOp, CPU, "Erfc", functor::erfc, float, double); +#if GOOGLE_CUDA +REGISTER2(UnaryOp, GPU, "Erfc", functor::erfc, float, double); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_erf.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_erf.cu.cc new file mode 100644 index 0000000000..a1e31a1b2f --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_erf.cu.cc @@ -0,0 +1,26 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY2(erf, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_erfc.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_erfc.cu.cc new file mode 100644 index 0000000000..260463c8bf --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_erfc.cu.cc @@ -0,0 +1,26 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY2(erfc, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_lgamma.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_lgamma.cu.cc new file mode 100644 index 0000000000..8105ac1694 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_lgamma.cu.cc @@ -0,0 +1,26 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY2(lgamma, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_lgamma.cc b/tensorflow/core/kernels/cwise_op_lgamma.cc new file mode 100644 index 0000000000..6985e5c6ba --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_lgamma.cc @@ -0,0 +1,23 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER2(UnaryOp, CPU, "Lgamma", functor::lgamma, float, double); +#if GOOGLE_CUDA +REGISTER2(UnaryOp, GPU, "Lgamma", functor::lgamma, float, double); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index 7f42a4ca2b..1aec41fdc0 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -342,6 +342,15 @@ template <typename T> struct tanh : base<T, Eigen::internal::scalar_tanh_op<T> > {}; template <typename T> +struct lgamma : base<T, Eigen::internal::scalar_lgamma_op<T> > {}; + +template <typename T> +struct erf : base<T, Eigen::internal::scalar_erf_op<T> > {}; + +template <typename T> +struct erfc : base<T, Eigen::internal::scalar_erfc_op<T> > {}; + +template <typename T> struct sigmoid : base<T, Eigen::internal::scalar_sigmoid_op<T> > {}; template <typename T> diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc index 599b05525c..e17c0da061 100644 --- a/tensorflow/core/kernels/example_parsing_ops.cc +++ b/tensorflow/core/kernels/example_parsing_ops.cc @@ -469,9 +469,13 @@ class SingleSequenceExampleParserOp : public OpKernel { ctx, ctx->GetAttr("Nfeature_list_dense", &num_feature_list_dense_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Ncontext_sparse", &num_context_sparse_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcontext_dense", &context_dense_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("feature_list_sparse_types", + &feature_list_sparse_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("feature_list_dense_types", &feature_list_dense_types_)); OP_REQUIRES_OK( + ctx, ctx->GetAttr("Nfeature_list_sparse", &num_feature_list_sparse_)); + OP_REQUIRES_OK( ctx, ctx->GetAttr("context_dense_shapes", &context_dense_shapes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("feature_list_dense_shapes", &feature_list_dense_shapes_)); @@ -488,6 +492,11 @@ class SingleSequenceExampleParserOp : public OpKernel { context_dense_shapes_.size(), errors::InvalidArgument( "len(context_dense_keys) != len(context_dense_shapes")); + OP_REQUIRES( + ctx, static_cast<size_t>(num_feature_list_sparse_) == + feature_list_sparse_types_.size(), + errors::InvalidArgument( + "len(feature_list_sparse_keys) != len(feature_list_sparse_types")); OP_REQUIRES(ctx, static_cast<size_t>(num_feature_list_dense_) == feature_list_dense_types_.size(), errors::InvalidArgument("len(feature_list_dense_keys) != " @@ -501,6 +510,9 @@ class SingleSequenceExampleParserOp : public OpKernel { for (const DataType& type : feature_list_dense_types_) { OP_REQUIRES_OK(ctx, CheckValidType(type)); } + for (const DataType& type : feature_list_sparse_types_) { + OP_REQUIRES_OK(ctx, CheckValidType(type)); + } } void Compute(OpKernelContext* ctx) override { @@ -510,6 +522,7 @@ class SingleSequenceExampleParserOp : public OpKernel { OpInputList context_sparse_keys; OpInputList context_dense_defaults; OpInputList feature_list_dense_keys; + OpInputList feature_list_sparse_keys; const Tensor* feature_list_dense_missing_assumed_empty; OP_REQUIRES_OK(ctx, ctx->input("debug_name", &debug_name)); @@ -522,16 +535,20 @@ class SingleSequenceExampleParserOp : public OpKernel { &feature_list_dense_keys)); OP_REQUIRES_OK( ctx, ctx->input_list("context_sparse_keys", &context_sparse_keys)); + OP_REQUIRES_OK(ctx, ctx->input_list("feature_list_sparse_keys", + &feature_list_sparse_keys)); OP_REQUIRES_OK(ctx, ctx->input_list("context_dense_defaults", &context_dense_defaults)); std::vector<string> context_dense_keys_t(num_context_dense_); std::vector<string> context_sparse_keys_t(num_context_sparse_); std::vector<string> feature_list_dense_keys_t(num_feature_list_dense_); + std::vector<string> feature_list_sparse_keys_t(num_feature_list_sparse_); std::unordered_set<string> feature_list_dense_missing_assumed_empty_set; CHECK_EQ(context_dense_keys.size(), num_context_dense_); CHECK_EQ(context_sparse_keys.size(), num_context_sparse_); CHECK_EQ(feature_list_dense_keys.size(), num_feature_list_dense_); + CHECK_EQ(feature_list_sparse_keys.size(), num_feature_list_sparse_); for (int di = 0; di < num_context_dense_; ++di) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(context_dense_keys[di].shape()), @@ -560,6 +577,16 @@ class SingleSequenceExampleParserOp : public OpKernel { feature_list_dense_keys_t[di] = feature_list_dense_keys[di].scalar<string>()(); } + for (int di = 0; di < num_feature_list_sparse_; ++di) { + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(feature_list_sparse_keys[di].shape()), + errors::InvalidArgument( + "Expected feature_list_sparse_keys[", di, + "] to be a vector, got shape: ", + feature_list_sparse_keys[di].shape().ShortDebugString())); + feature_list_sparse_keys_t[di] = + feature_list_sparse_keys[di].scalar<string>()(); + } OP_REQUIRES(ctx, TensorShapeUtils::IsVector( feature_list_dense_missing_assumed_empty->shape()), errors::InvalidArgument( @@ -622,6 +649,9 @@ class SingleSequenceExampleParserOp : public OpKernel { OpOutputList context_sparse_values; OpOutputList context_sparse_shapes; OpOutputList context_dense_values; + OpOutputList feature_list_sparse_indices; + OpOutputList feature_list_sparse_values; + OpOutputList feature_list_sparse_shapes; OpOutputList feature_list_dense_values; OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices", @@ -632,6 +662,14 @@ class SingleSequenceExampleParserOp : public OpKernel { ctx, ctx->output_list("context_sparse_shapes", &context_sparse_shapes)); OP_REQUIRES_OK( ctx, ctx->output_list("context_dense_values", &context_dense_values)); + OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices", + &context_sparse_indices)); + OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_indices", + &feature_list_sparse_indices)); + OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_values", + &feature_list_sparse_values)); + OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_shapes", + &feature_list_sparse_shapes)); OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_values", &feature_list_dense_values)); @@ -784,16 +822,78 @@ class SingleSequenceExampleParserOp : public OpKernel { feature_list_dense_values[d])); } } + + // Feature List Sparse ----------------------------------------------------- + for (int d = 0; d < num_feature_list_sparse_; ++d) { + const string& key = feature_list_sparse_keys_t[d]; + const DataType& dtype = feature_list_sparse_types_[d]; + + const auto& feature_list_found = feature_list_dict.find(key); + bool feature_list_has_data = // Found key + (feature_list_found != feature_list_dict.end()); + + std::vector<Tensor> sparse_values_tmp; + int64 feature_list_size = 0; + if (feature_list_has_data) { + const FeatureList& fl = feature_list_found->second; + feature_list_size = fl.feature_size(); + for (int64 t = 0; t < feature_list_size; ++t) { + const Feature& f = fl.feature(t); + bool types_match; + OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match)); + OP_REQUIRES( + ctx, types_match, + errors::InvalidArgument( + "Name: ", name, ", Feature List: ", key, ", Index: ", t, + ". Data types don't match. ", "Expected type: ", + DataTypeString(dtype), " Feature is: ", f.DebugString())); + sparse_values_tmp.push_back(FeatureSparseCopy(t, key, dtype, f)); + } + } else { + sparse_values_tmp.push_back(Tensor(dtype, TensorShape({0}))); + } + + int64 total_num_features = 0; + int64 max_num_features = 0; + for (int t = 0; t < feature_list_size; ++t) { + const Tensor& v = sparse_values_tmp[t]; + const int64 num_elements = v.shape().num_elements(); + total_num_features += num_elements; + max_num_features = std::max(max_num_features, num_elements); + } + + TensorShape indices_shape({total_num_features, 2}); + TensorShape values_shape({total_num_features}); + Tensor* sp_indices_d = nullptr; + Tensor* sp_values_d = nullptr; + Tensor* sp_shape_d = nullptr; + feature_list_sparse_indices.allocate(d, indices_shape, &sp_indices_d); + feature_list_sparse_values.allocate(d, values_shape, &sp_values_d); + feature_list_sparse_shapes.allocate(d, TensorShape({2}), &sp_shape_d); + auto shape_t = sp_shape_d->vec<int64>(); + shape_t(0) = feature_list_size; + shape_t(1) = max_num_features; + + int64 offset = 0; + + for (int t = 0; t < feature_list_size; ++t) { + const int64 num_elements = CopyIntoSparseTensor( + sparse_values_tmp[t], t, offset, sp_indices_d, sp_values_d); + offset += num_elements; + } + } } protected: int64 num_context_sparse_; int64 num_context_dense_; + int64 num_feature_list_sparse_; int64 num_feature_list_dense_; std::vector<DataType> context_sparse_types_; std::vector<DataType> context_dense_types_; - std::vector<DataType> feature_list_dense_types_; std::vector<TensorShape> context_dense_shapes_; + std::vector<DataType> feature_list_sparse_types_; + std::vector<DataType> feature_list_dense_types_; std::vector<TensorShape> feature_list_dense_shapes_; }; diff --git a/tensorflow/core/kernels/queue_base.cc b/tensorflow/core/kernels/queue_base.cc index 470a7a3eb6..9153e5a31b 100644 --- a/tensorflow/core/kernels/queue_base.cc +++ b/tensorflow/core/kernels/queue_base.cc @@ -345,7 +345,14 @@ Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element, HANDLE_TYPE(DT_INT16); HANDLE_TYPE(DT_INT8); HANDLE_TYPE(DT_STRING); + HANDLE_TYPE(DT_COMPLEX64); HANDLE_TYPE(DT_INT64); + HANDLE_TYPE(DT_BOOL); + HANDLE_TYPE(DT_QINT8); + HANDLE_TYPE(DT_QUINT8); + HANDLE_TYPE(DT_QINT32); + HANDLE_TYPE(DT_QINT16); + HANDLE_TYPE(DT_QUINT16); #undef HANDLE_TYPE return errors::Unimplemented("Unhandled data type: ", parent.dtype()); } @@ -365,7 +372,14 @@ Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent, HANDLE_TYPE(DT_INT16); HANDLE_TYPE(DT_INT8); HANDLE_TYPE(DT_STRING); + HANDLE_TYPE(DT_COMPLEX64); HANDLE_TYPE(DT_INT64); + HANDLE_TYPE(DT_BOOL); + HANDLE_TYPE(DT_QINT8); + HANDLE_TYPE(DT_QUINT8); + HANDLE_TYPE(DT_QINT32); + HANDLE_TYPE(DT_QINT16); + HANDLE_TYPE(DT_QUINT16); #undef HANDLE_TYPE return errors::Unimplemented("Unhandled data type: ", element.dtype()); } diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h index e8db7106ef..44911c9d36 100644 --- a/tensorflow/core/kernels/reduction_ops_common.h +++ b/tensorflow/core/kernels/reduction_ops_common.h @@ -24,6 +24,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "tensorflow/core/kernels/reduction_ops.h" +#include "tensorflow/core/kernels/transpose_op.h" #include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -31,6 +32,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/public/status.h" #include "tensorflow/core/public/tensor.h" @@ -76,7 +78,7 @@ class ReductionHelper { Status Simplify(const Tensor& data, const Tensor& axis, const bool keep_dims) { // bitmap[i] indicates whether to reduce data along i-th axis. - std::vector<bool> bitmap(data.dims(), false); + gtl::InlinedVector<bool, 4> bitmap(data.dims(), false); auto axis_vec = axis.flat<int32>(); for (int64 i = 0; i < axis.NumElements(); ++i) { const int32 index = axis_vec(i); @@ -194,11 +196,43 @@ class ReductionHelper { return data.shaped<T, N>(data_reshape_); } + // Shape of shuffled input + const gtl::ArraySlice<int64> data_reshape() const { return data_reshape_; } + + // Shape with all reduction dimensions at the end + TensorShape shuffled_shape() { + const int dims = data_reshape_.size(); + TensorShape shape; + for (int i = reduce_first_axis_; i < dims; i += 2) { + shape.AddDim(data_reshape_[i]); + } + for (int i = !reduce_first_axis_; i < dims; i += 2) { + shape.AddDim(data_reshape_[i]); + } + return shape; + } + + // Permutation of reduced dims needed to put reduction dimensions at the end + gtl::InlinedVector<int32, 8> permutation() { + const int dims = data_reshape_.size(); + const int unreduced_dims = (dims + !reduce_first_axis_) / 2; + gtl::InlinedVector<int32, 8> perm(dims); + for (int i = 0; i < unreduced_dims; i++) { + perm[i] = 2 * i + reduce_first_axis_; + } + for (int i = unreduced_dims; i < dims; i++) { + perm[i] = 2 * (i - unreduced_dims) + !reduce_first_axis_; + } + return perm; + } + private: bool reduce_first_axis_; // True if need to reduce the 0-th dimension. - std::vector<int64> data_reshape_; // Reshape the data before reduction. - std::vector<int64> out_shape_; // The final output shape. - std::vector<int64> out_reshape_; // Reshape the output for reduction. + gtl::InlinedVector<int64, 4> + data_reshape_; // Reshape the data before reduction. + gtl::InlinedVector<int64, 4> out_shape_; // The final output shape. + gtl::InlinedVector<int64, 4> + out_reshape_; // Reshape the output for reduction. }; } // end namespace @@ -252,6 +286,9 @@ class ReductionOp : public OpKernel { const Device& d = ctx->eigen_device<Device>(); Reducer reducer; + if (tmp_out.NumElements() == 0) { + // Nothing to do, fall through to final reshaping. + } if ((helper.ndims() == 1) && helper.reduce_first_axis()) { // Reduce to a scalar. Functor::Reduce(d, helper.out<T, 0>(&tmp_out), helper.in<T, 1>(data), @@ -274,15 +311,20 @@ class ReductionOp : public OpKernel { Functor::Reduce(d, helper.out<T, 2>(&tmp_out), helper.in<T, 3>(data), constants.kOne, reducer); } else { - // TODO(zhifengc): We can implement reduction for arbitrary rank - // tensor and arbitrary reduction axes by iterating the reduction - // multiple times. This may also be accomplished in the graph - // construction. - ctx->SetStatus( - errors::Unimplemented("Reducing ", data.shape().ShortDebugString(), - " axes [", axes.SummarizeValue(10), "] to ", - tmp_out.shape().ShortDebugString())); - return; + // If we don't hit one of the cases above, transpose the data so that + // all reduced dimensions are last and reuse the 2-D -> 1-D case. + Tensor shuffled; + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum<T>::value, + helper.shuffled_shape(), &shuffled)); + TransposeTensor<Device, T>(d, data, helper.data_reshape(), + helper.permutation(), &shuffled); + const int64 unreduced = tmp_out.NumElements(); + const int64 reduced = shuffled.NumElements() / unreduced; + const Tensor& const_shuffled = shuffled; + Functor::Reduce(d, tmp_out.flat<T>(), + const_shuffled.shaped<T, 2>({unreduced, reduced}), + constants.kOne, reducer); } // Set the real output using the contents of the reduction but the diff --git a/tensorflow/core/kernels/reshape_op.h b/tensorflow/core/kernels/reshape_op.h index 8f908109ed..f1260746bf 100644 --- a/tensorflow/core/kernels/reshape_op.h +++ b/tensorflow/core/kernels/reshape_op.h @@ -39,9 +39,6 @@ class ReshapeOp : public OpKernel { errors::InvalidArgument("sizes input must be 1-D, not shape ", sizes.shape().ShortDebugString())); const int64 num_dims = sizes.NumElements(); - OP_REQUIRES( - context, num_dims <= 8, - errors::InvalidArgument(num_dims, " > max 8 output dims supported")); // Compute the output shape. Determine product of specified // dimensions, and find the index of the unspecified one. diff --git a/tensorflow/core/kernels/sparse_to_dense_op.cc b/tensorflow/core/kernels/sparse_to_dense_op.cc index 3de5132049..7759dbdc0f 100644 --- a/tensorflow/core/kernels/sparse_to_dense_op.cc +++ b/tensorflow/core/kernels/sparse_to_dense_op.cc @@ -41,7 +41,10 @@ namespace tensorflow { template <typename T, typename Index> class SparseToDense : public OpKernel { public: - explicit SparseToDense(OpKernelConstruction* context) : OpKernel(context) {} + explicit SparseToDense(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("validate_indices", &validate_indices_)); + } void Compute(OpKernelContext* c) override { // sparse_indices @@ -111,17 +114,28 @@ class SparseToDense : public OpKernel { sparse_values_b = sparse_values; } + // Assume SparseTensor is lexicographically sorted. gtl::InlinedVector<int64, 8> order(output->shape().dims()); - std::iota(order.begin(), order.end(), 0); // Assume order is correct + std::iota(order.begin(), order.end(), 0); sparse::SparseTensor st(indices_shaped, sparse_values_b, output->shape(), order); + if (validate_indices_) { + OP_REQUIRES(c, st.IndicesValid(), + errors::InvalidArgument("Indices are not valid: not " + "lexicographically sorted or " + "containing repeats.")); + } + output->flat<T>().setConstant(default_value.scalar<T>()()); OP_REQUIRES(c, st.template ToDense<T>(output, false /* initialize */), errors::InvalidArgument( "Indices are not valid (out of bounds). Shape: ", output->shape().DebugString())); } + + private: + bool validate_indices_; }; #define REGISTER_KERNELS(type, index_type) \ diff --git a/tensorflow/core/kernels/summary_op.cc b/tensorflow/core/kernels/summary_op.cc index 972889f878..4031e90857 100644 --- a/tensorflow/core/kernels/summary_op.cc +++ b/tensorflow/core/kernels/summary_op.cc @@ -61,15 +61,6 @@ class SummaryScalarOp : public OpKernel { } }; -REGISTER_KERNEL_BUILDER(Name("ScalarSummary") - .Device(DEVICE_CPU) - .TypeConstraint<float>("T"), - SummaryScalarOp<float>); -REGISTER_KERNEL_BUILDER(Name("ScalarSummary") - .Device(DEVICE_CPU) - .TypeConstraint<double>("T"), - SummaryScalarOp<double>); - template <typename T> class SummaryHistoOp : public OpKernel { public: @@ -108,6 +99,9 @@ class SummaryHistoOp : public OpKernel { #define REGISTER(T) \ REGISTER_KERNEL_BUILDER( \ + Name("ScalarSummary").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + SummaryScalarOp<T>); \ + REGISTER_KERNEL_BUILDER( \ Name("HistogramSummary").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ SummaryHistoOp<T>); TF_CALL_REAL_NUMBER_TYPES(REGISTER) diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc index d4c6ecb6c7..971cef6ddb 100644 --- a/tensorflow/core/kernels/tile_ops.cc +++ b/tensorflow/core/kernels/tile_ops.cc @@ -65,14 +65,17 @@ class TileOp : public OpKernel { TensorShape output_shape; for (int i = 0; i < input_dims; ++i) { OP_REQUIRES( - context, multiples_array[i] > 0, - errors::InvalidArgument("Expected multiples[", i, "] > 0, but got ", + context, multiples_array[i] >= 0, + errors::InvalidArgument("Expected multiples[", i, "] >= 0, but got ", multiples_array[i])); output_shape.AddDim(input.dim_size(i) * multiples_array[i]); } Tensor* result = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &result)); + // If there's no output, there's nothing to do. + if (output_shape.num_elements() == 0) return; + #define HANDLE_DIM(DT, NDIM) \ if (context->input(0).dtype() == DT && input_dims == NDIM) { \ HandleCase<DT, NDIM>(context, multiples_array, result); \ @@ -180,7 +183,9 @@ HANDLE_CASE_DIM(GPUDevice, DT_INT64); template <typename Device> class TileGradientOp : public OpKernel { public: - explicit TileGradientOp(OpKernelConstruction* context) : OpKernel(context) {} + explicit TileGradientOp(OpKernelConstruction* context) : OpKernel(context) { + OP_DEPRECATED(context, 3, "TileGrad has been replaced with reduce_sum"); + } void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc index ad312b8e7b..7d9c2a90e5 100644 --- a/tensorflow/core/kernels/transpose_op.cc +++ b/tensorflow/core/kernels/transpose_op.cc @@ -97,8 +97,8 @@ void TransposeOp<Device, T>::Compute(OpKernelContext* context) { perm.shape().DebugString())); auto Vperm = perm.vec<int32>(); const int dims = input.dims(); - static const int kMinDims = 1; - static const int kMaxDims = 8; + static const int kMinDims = 0; + static const int kMaxDims = 10; OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims, errors::Unimplemented("Transposing a tensor of rank ", dims, " is not implemented.")); @@ -125,20 +125,35 @@ void TransposeOp<Device, T>::Compute(OpKernelContext* context) { str_util::Join(permutation, ","), "}.")); } + // 0-D and 1-D transposes do nothing + if (dims <= 1) { + context->set_output(0, input); + return; + } + Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output)); + TransposeTensor<Device, T>(context->eigen_device<Device>(), input, + input.shape().dim_sizes(), permutation, output); +} + +template <typename Device, typename T> +void TransposeTensor(const Device& device, const Tensor& input, + const gtl::ArraySlice<int64> input_shape, + gtl::ArraySlice<int32> permutation, Tensor* output) { + const int dims = input_shape.size(); + CHECK(permutation.size() == dims); if (input.NumElements() == 0) { return; } switch (dims) { -#define EXPAND_DIM(N) \ - case N: { \ - functor::TransposeFunctor<Device, T, N> func; \ - func(context->eigen_device<Device>(), output->tensor<T, N>(), \ - input.tensor<T, N>(), permutation.data()); \ - break; \ +#define EXPAND_DIM(N) \ + case N: { \ + functor::TransposeFunctor<Device, T, N> func; \ + func(device, output->tensor<T, N>(), input.shaped<T, N>(input_shape), \ + permutation.data()); \ + break; \ } - EXPAND_DIM(1); EXPAND_DIM(2); EXPAND_DIM(3); EXPAND_DIM(4); @@ -146,6 +161,8 @@ void TransposeOp<Device, T>::Compute(OpKernelContext* context) { EXPAND_DIM(6); EXPAND_DIM(7); EXPAND_DIM(8); + EXPAND_DIM(9); + EXPAND_DIM(10); default: LOG(FATAL) << "Unexpected dims: " << dims; } @@ -179,13 +196,16 @@ struct TransposeFunctor<CPUDevice, T, NDIMS> { } // namespace functor -#define REGISTER(D, T) \ - template class TransposeOp<D##Device, T>; \ - REGISTER_KERNEL_BUILDER(Name("Transpose") \ - .Device(DEVICE_##D) \ - .TypeConstraint<T>("T") \ - .HostMemory("perm"), \ - TransposeOp<D##Device, T>) +#define REGISTER(D, T) \ + template class TransposeOp<D##Device, T>; \ + REGISTER_KERNEL_BUILDER(Name("Transpose") \ + .Device(DEVICE_##D) \ + .TypeConstraint<T>("T") \ + .HostMemory("perm"), \ + TransposeOp<D##Device, T>); \ + template void TransposeTensor<D##Device, T>( \ + const D##Device&, const Tensor&, const gtl::ArraySlice<int64>, \ + gtl::ArraySlice<int32>, Tensor*); REGISTER(CPU, float); REGISTER(CPU, double); REGISTER(CPU, complex64); @@ -195,6 +215,7 @@ REGISTER(CPU, int16); REGISTER(CPU, int32); REGISTER(CPU, int64); REGISTER(CPU, string); +REGISTER(CPU, bool); #if GOOGLE_CUDA REGISTER(GPU, uint8); REGISTER(GPU, int8); @@ -203,6 +224,8 @@ REGISTER(GPU, int32); REGISTER(GPU, int64); REGISTER(GPU, float); REGISTER(GPU, double); +REGISTER(GPU, complex64); +REGISTER(GPU, bool); #endif #undef REGISTER } // namespace tensorflow diff --git a/tensorflow/core/kernels/transpose_op.h b/tensorflow/core/kernels/transpose_op.h index f4d36fea54..15cd1b6488 100644 --- a/tensorflow/core/kernels/transpose_op.h +++ b/tensorflow/core/kernels/transpose_op.h @@ -29,6 +29,12 @@ class TransposeOp : public OpKernel { void Compute(OpKernelContext* context) override; }; +// Exposed for use in reduction ops +template <typename Device, typename T> +void TransposeTensor(const Device& device, const Tensor& input, + const gtl::ArraySlice<int64> input_shape, + gtl::ArraySlice<int32> permutation, Tensor* output); + } // namespace tensorflow #endif // TENSORFLOW_KERNELS_TRANSPOSE_OP_H_ diff --git a/tensorflow/core/kernels/transpose_op_gpu.cu.cc b/tensorflow/core/kernels/transpose_op_gpu.cu.cc index c2f720a121..b8d664b95f 100644 --- a/tensorflow/core/kernels/transpose_op_gpu.cu.cc +++ b/tensorflow/core/kernels/transpose_op_gpu.cu.cc @@ -17,8 +17,9 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/kernels/transpose_op_functor.h" +#include "tensorflow/core/platform/port.h" namespace tensorflow { namespace functor { @@ -34,14 +35,15 @@ struct TransposeFunctor<Eigen::GpuDevice, T, NDIMS> { #define DEFINE(T, N) template struct TransposeFunctor<Eigen::GpuDevice, T, N>; #define DEFINE_DIM(T) \ - DEFINE(T, 1); \ DEFINE(T, 2); \ DEFINE(T, 3); \ DEFINE(T, 4); \ DEFINE(T, 5); \ DEFINE(T, 6); \ DEFINE(T, 7); \ - DEFINE(T, 8); + DEFINE(T, 8); \ + DEFINE(T, 9); \ + DEFINE(T, 10); DEFINE_DIM(uint8); DEFINE_DIM(int8); DEFINE_DIM(int16); @@ -49,6 +51,8 @@ DEFINE_DIM(int32); DEFINE_DIM(int64); DEFINE_DIM(float); DEFINE_DIM(double); +DEFINE_DIM(complex64); +DEFINE_DIM(bool); #undef DEFINE_DIM #undef DEFINE diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h index b4a42c45f8..50fc40e18d 100644 --- a/tensorflow/core/lib/core/errors.h +++ b/tensorflow/core/lib/core/errors.h @@ -103,19 +103,19 @@ using ::tensorflow::error::OK; // } // Declares an op deprecated, and illegal starting at GraphDef version VERSION -#define OP_DEPRECATED(CTX, VERSION) \ +#define OP_DEPRECATED(CTX, VERSION, NOTE) \ if ((CTX)->graph_def_version() >= (VERSION)) { \ ::tensorflow::Status _s(::tensorflow::errors::Unimplemented( \ "Op ", (CTX)->op_def().name(), \ " is not available in GraphDef version ", (CTX)->graph_def_version(), \ - ". It has been removed in version ", (VERSION), ".")); \ + ". It has been removed in version ", (VERSION), ". ", (NOTE), ".")); \ VLOG(1) << _s; \ (CTX)->SetStatus(_s); \ return; \ } else { \ LOG(WARNING) << "Op is deprecated." \ << " It will cease to work in GraphDef version " << (VERSION) \ - << "."; \ + << ". " << (NOTE) << "."; \ } #define OP_REQUIRES(CTX, EXP, STATUS) \ diff --git a/tensorflow/core/lib/io/inputbuffer.cc b/tensorflow/core/lib/io/inputbuffer.cc index 55514c0242..722f1da48c 100644 --- a/tensorflow/core/lib/io/inputbuffer.cc +++ b/tensorflow/core/lib/io/inputbuffer.cc @@ -61,7 +61,10 @@ Status InputBuffer::ReadLine(string* result) { // We don't append the '\n' to *result return Status::OK(); } - *result += c; + // We don't append '\r' to *result + if (c != '\r') { + *result += c; + } } if (errors::IsOutOfRange(s) && !result->empty()) { return Status::OK(); diff --git a/tensorflow/core/lib/io/inputbuffer_test.cc b/tensorflow/core/lib/io/inputbuffer_test.cc index 5e4888b727..d424336e06 100644 --- a/tensorflow/core/lib/io/inputbuffer_test.cc +++ b/tensorflow/core/lib/io/inputbuffer_test.cc @@ -116,6 +116,32 @@ TEST(InputBuffer, ReadLine_EmptyLines) { } } +TEST(InputBuffer, ReadLine_CRLF) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, "line one\r\n\r\n\r\nline two\r\nline three"); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string line; + io::InputBuffer in(file, buf_size); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line one"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, ""); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, ""); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line two"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line three"); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + // A second call should also return end of file + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + } +} + TEST(InputBuffer, ReadNBytes) { Env* env = Env::Default(); string fname = testing::TmpDir() + "/inputbuffer_test"; diff --git a/tensorflow/core/lib/strings/regexp.h b/tensorflow/core/lib/strings/regexp.h new file mode 100644 index 0000000000..aaf58f8139 --- /dev/null +++ b/tensorflow/core/lib/strings/regexp.h @@ -0,0 +1,33 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_STRINGS_REGEXP_H_ +#define TENSORFLOW_CORE_LIB_STRINGS_REGEXP_H_ + +#include "tensorflow/core/platform/regexp.h" + +namespace tensorflow { + +// Conversion to/from the appropriate StringPiece type for using in RE2 +inline RegexpStringPiece ToRegexpStringPiece(tensorflow::StringPiece sp) { + return RegexpStringPiece(sp.data(), sp.size()); +} +inline tensorflow::StringPiece FromRegexpStringPiece(RegexpStringPiece sp) { + return tensorflow::StringPiece(sp.data(), sp.size()); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_STRINGS_REGEXP_H_ diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h index cd972e2371..134a49e631 100644 --- a/tensorflow/core/lib/strings/str_util.h +++ b/tensorflow/core/lib/strings/str_util.h @@ -81,9 +81,7 @@ void TitlecaseString(string* s, StringPiece delimiters); // Join functionality template <typename T> -string Join(const std::vector<T>& s, const char* sep); -template <typename T> -string Join(const gtl::ArraySlice<T>& s, const char* sep); +string Join(const T& s, const char* sep); struct AllowEmpty { bool operator()(StringPiece sp) const { return true; } @@ -110,31 +108,16 @@ bool SplitAndParseAsInts(StringPiece text, char delim, // ------------------------------------------------------------------ // Implementation details below -namespace internal { template <typename T> -string JoinHelper(typename gtl::ArraySlice<T>::const_iterator begin, - typename gtl::ArraySlice<T>::const_iterator end, - const char* sep) { +string Join(const T& s, const char* sep) { string result; bool first = true; - for (typename gtl::ArraySlice<T>::const_iterator it = begin; it != end; - ++it) { - tensorflow::strings::StrAppend(&result, (first ? "" : sep), *it); + for (const auto& x : s) { + tensorflow::strings::StrAppend(&result, (first ? "" : sep), x); first = false; } return result; } -} // namespace internal - -template <typename T> -string Join(const std::vector<T>& s, const char* sep) { - return Join<T>(gtl::ArraySlice<T>(s), sep); -} - -template <typename T> -string Join(const gtl::ArraySlice<T>& s, const char* sep) { - return internal::JoinHelper<T>(s.begin(), s.end(), sep); -} inline std::vector<string> Split(StringPiece text, char delim) { return Split(text, delim, AllowEmpty()); diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index c9d782f1c5..0bad9feb7a 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -388,6 +388,7 @@ REGISTER_OP("HashTable") .Attr("shared_name: string = ''") .Attr("key_dtype: type") .Attr("value_dtype: type") + .SetIsStateful() .Doc(R"doc( Creates a non-initialized hash table. diff --git a/tensorflow/core/ops/function_ops.cc b/tensorflow/core/ops/function_ops.cc new file mode 100644 index 0000000000..3842c025b3 --- /dev/null +++ b/tensorflow/core/ops/function_ops.cc @@ -0,0 +1,70 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/function.h" + +#include <unordered_set> + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace tensorflow { + +REGISTER_OP("_Arg") + .Output("output: T") + .Attr("T: type") + .Attr("index: int >= 0") + .Doc(R"doc( +A graph node which represents an argument to a function. + +output: The argument. +index: This argument is the index-th argument of the function. +)doc"); + +REGISTER_OP("_Retval") + .Input("input: T") + .Attr("T: type") + .Attr("index: int >= 0") + .Doc(R"doc( +A graph node which represents a return value of a function. + +input: The return value. +index: This return value is the index-th return value of the function. +)doc"); + +REGISTER_OP("_ListToArray") + .Input("input: Tin") + .Output("output: N * T") + .Attr("Tin: list(type)") + .Attr("T: type") + .Attr("N: int >= 1") + .Doc(R"doc( +Converts a list of tensors to an array of tensors. +)doc"); + +REGISTER_OP("_ArrayToList") + .Input("input: N * T") + .Output("output: out_types") + .Attr("T: type") + .Attr("N: int >= 1") + .Attr("out_types: list(type)") + .Doc(R"doc( +Converts an array of tensors to a list of tensors. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 3f359334b0..88e2b34d6a 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -188,6 +188,24 @@ REGISTER_OP("Tanh") Computes hyperbolic tangent of `x` element-wise. )doc"); +REGISTER_OP("Lgamma") + .UNARY() + .Doc(R"doc( +Computes the log of the absolute value of Gamma of `x` element-wise. +)doc"); + +REGISTER_OP("Erf") + .UNARY() + .Doc(R"doc( +Computes the Gauss error function of `x` element-wise. +)doc"); + +REGISTER_OP("Erfc") + .UNARY() + .Doc(R"doc( +Computes the complementary error function of `x` element-wise. +)doc"); + REGISTER_OP("Sigmoid") .UNARY() .Doc(R"doc( diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index b90d6b2ddc..56f70f9420 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -2413,6 +2413,56 @@ op { is_commutative: true } op { + name: "Erf" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_COMPLEX64 + type: DT_INT64 + } + } + } + summary: "Computes the Gauss error function of `x` element-wise." +} +op { + name: "Erfc" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_COMPLEX64 + type: DT_INT64 + } + } + } + summary: "Computes the complementary error function of `x` element-wise." +} +op { name: "Exit" input_arg { name: "data" @@ -2949,6 +2999,7 @@ op { } summary: "Creates a non-initialized hash table." description: "This op creates a hash table, specifying the type of its keys and values.\nBefore using the table you will have to initialize it. After initialization the\ntable will be immutable." + is_stateful: true } op { name: "HistogramSummary" @@ -3554,6 +3605,31 @@ op { summary: "Returns the truth value of (x <= y) element-wise." } op { + name: "Lgamma" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_COMPLEX64 + type: DT_INT64 + } + } + } + summary: "Computes the log of the absolute value of Gamma of `x` element-wise." +} +op { name: "LinSpace" input_arg { name: "start" @@ -4731,6 +4807,12 @@ op { number_attr: "Ncontext_dense" } input_arg { + name: "feature_list_sparse_keys" + description: "A list of Nfeature_list_sparse string Tensors\n(scalars). The keys expected in the FeatureLists associated with sparse\nvalues." + type: DT_STRING + number_attr: "Nfeature_list_sparse" + } + input_arg { name: "feature_list_dense_keys" description: "A list of Nfeature_list_dense string Tensors (scalars).\nThe keys expected in the SequenceExamples\' feature_lists associated\nwith lists of dense values." type: DT_STRING @@ -4765,27 +4847,62 @@ op { type_list_attr: "Tcontext_dense" } output_arg { + name: "feature_list_sparse_indices" + type: DT_INT64 + number_attr: "Nfeature_list_sparse" + } + output_arg { + name: "feature_list_sparse_values" + type_list_attr: "feature_list_sparse_types" + } + output_arg { + name: "feature_list_sparse_shapes" + type: DT_INT64 + number_attr: "Nfeature_list_sparse" + } + output_arg { name: "feature_list_dense_values" type_list_attr: "feature_list_dense_types" } attr { name: "Ncontext_sparse" type: "int" + default_value { + i: 0 + } has_minimum: true } attr { name: "Ncontext_dense" type: "int" + default_value { + i: 0 + } + has_minimum: true + } + attr { + name: "Nfeature_list_sparse" + type: "int" + default_value { + i: 0 + } has_minimum: true } attr { name: "Nfeature_list_dense" type: "int" + default_value { + i: 0 + } has_minimum: true } attr { name: "context_sparse_types" type: "list(type)" + default_value { + list { + } + } description: "A list of Ncontext_sparse types; the data types of data in\neach context Feature given in context_sparse_keys.\nCurrently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),\nDT_INT64 (Int64List), and DT_STRING (BytesList)." has_minimum: true allowed_values { @@ -4799,6 +4916,10 @@ op { attr { name: "Tcontext_dense" type: "list(type)" + default_value { + list { + } + } has_minimum: true allowed_values { list { @@ -4811,6 +4932,10 @@ op { attr { name: "feature_list_dense_types" type: "list(type)" + default_value { + list { + } + } has_minimum: true allowed_values { list { @@ -4823,12 +4948,37 @@ op { attr { name: "context_dense_shapes" type: "list(shape)" + default_value { + list { + } + } description: "A list of Ncontext_dense shapes; the shapes of data in\neach context Feature given in context_dense_keys.\nThe number of elements in the Feature corresponding to context_dense_key[j]\nmust always equal context_dense_shapes[j].NumEntries().\nThe shape of context_dense_values[j] will match context_dense_shapes[j]." has_minimum: true } attr { + name: "feature_list_sparse_types" + type: "list(type)" + default_value { + list { + } + } + description: "A list of Nfeature_list_sparse types; the data types\nof data in each FeatureList given in feature_list_sparse_keys.\nCurrently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),\nDT_INT64 (Int64List), and DT_STRING (BytesList)." + has_minimum: true + allowed_values { + list { + type: DT_FLOAT + type: DT_INT64 + type: DT_STRING + } + } + } + attr { name: "feature_list_dense_shapes" type: "list(shape)" + default_value { + list { + } + } description: "A list of Nfeature_list_dense shapes; the shapes of\ndata in each FeatureList given in feature_list_dense_keys.\nThe shape of each Feature in the FeatureList corresponding to\nfeature_list_dense_key[j] must always equal\nfeature_list_dense_shapes[j].NumEntries()." has_minimum: true } @@ -4986,6 +5136,39 @@ op { description: "Reduces `input` along the dimensions given in `reduction_indices`. Unless\n`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in\n`reduction_indices`. If `keep_dims` is true, the reduced dimensions are\nretained with length 1." } op { + name: "PyFunc" + input_arg { + name: "input" + description: "List of Tensors that will provide input to the Op." + type_list_attr: "Tin" + } + output_arg { + name: "output" + description: "The outputs from the Op." + type_list_attr: "Tout" + } + attr { + name: "token" + type: "string" + description: "A token representing a registered python function in this address space." + } + attr { + name: "Tin" + type: "list(type)" + description: "Data types of the inputs to the op." + has_minimum: true + minimum: 1 + } + attr { + name: "Tout" + type: "list(type)" + description: "Data types of the outputs from the op.\nThe length of the list specifies the number of outputs." + has_minimum: true + minimum: 1 + } + summary: "Invokes a python function to compute func(input)->output." +} +op { name: "QueueClose" input_arg { name: "handle" @@ -6354,12 +6537,12 @@ op { name: "ScalarSummary" input_arg { name: "tags" - description: "1-D. Tags for the summary." + description: "Tags for the summary." type: DT_STRING } input_arg { name: "values" - description: "1-D, same size as `tags. Values for the summary." + description: "Same shape as `tags. Values for the summary." type_attr: "T" } output_arg { @@ -6374,6 +6557,11 @@ op { list { type: DT_FLOAT type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 } } } @@ -7806,6 +7994,14 @@ op { type_attr: "T" } attr { + name: "validate_indices" + type: "bool" + default_value { + b: true + } + description: "If true, indices are checked to make sure they are sorted in\nlexicographic order and that there are no repeats." + } + attr { name: "T" type: "type" } @@ -7820,7 +8016,7 @@ op { } } summary: "Converts a sparse representation into a dense tensor." - description: "Builds an array `dense` with shape `output_shape` such that\n\n```prettyprint\n# If sparse_indices is scalar\ndense[i] = (i == sparse_indices ? sparse_values : default_value)\n\n# If sparse_indices is a vector, then for each i\ndense[sparse_indices[i]] = sparse_values[i]\n\n# If sparse_indices is an n by d matrix, then for each i in [0, n)\ndense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i]\n```\n\nAll other values in `dense` are set to `default_value`. If `sparse_values` is a\nscalar, all sparse indices are set to this single value." + description: "Builds an array `dense` with shape `output_shape` such that\n\n```prettyprint\n# If sparse_indices is scalar\ndense[i] = (i == sparse_indices ? sparse_values : default_value)\n\n# If sparse_indices is a vector, then for each i\ndense[sparse_indices[i]] = sparse_values[i]\n\n# If sparse_indices is an n by d matrix, then for each i in [0, n)\ndense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i]\n```\n\nAll other values in `dense` are set to `default_value`. If `sparse_values` is a\nscalar, all sparse indices are set to this single value.\n\nIndices should be sorted in lexicographic order, and indices must not\ncontain any repeats. If `validate_indices` is true, these properties\nare checked during execution." } op { name: "Split" diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc index eb40783206..150e36ffad 100644 --- a/tensorflow/core/ops/parsing_ops.cc +++ b/tensorflow/core/ops/parsing_ops.cc @@ -87,6 +87,7 @@ REGISTER_OP("ParseSingleSequenceExample") .Input("feature_list_dense_missing_assumed_empty: string") .Input("context_sparse_keys: Ncontext_sparse * string") .Input("context_dense_keys: Ncontext_dense * string") + .Input("feature_list_sparse_keys: Nfeature_list_sparse * string") .Input("feature_list_dense_keys: Nfeature_list_dense * string") .Input("context_dense_defaults: Tcontext_dense") .Input("debug_name: string") @@ -94,16 +95,24 @@ REGISTER_OP("ParseSingleSequenceExample") .Output("context_sparse_values: context_sparse_types") .Output("context_sparse_shapes: Ncontext_sparse * int64") .Output("context_dense_values: Tcontext_dense") + .Output("feature_list_sparse_indices: Nfeature_list_sparse * int64") + .Output("feature_list_sparse_values: feature_list_sparse_types") + .Output("feature_list_sparse_shapes: Nfeature_list_sparse * int64") .Output("feature_list_dense_values: feature_list_dense_types") - .Attr("Ncontext_sparse: int >= 0") // Infer from context_sparse_keys - .Attr("Ncontext_dense: int >= 0") // Infer from context_dense_keys - .Attr( - "Nfeature_list_dense: int >= 0") // Infer from feature_list_dense_keys - .Attr("context_sparse_types: list({float,int64,string}) >= 0") - .Attr("Tcontext_dense: list({float,int64,string}) >= 0") - .Attr("feature_list_dense_types: list({float,int64,string}) >= 0") - .Attr("context_dense_shapes: list(shape) >= 0") - .Attr("feature_list_dense_shapes: list(shape) >= 0") + // Infer from context_sparse_keys + .Attr("Ncontext_sparse: int >= 0 = 0") + // Infer from context_dense_keys + .Attr("Ncontext_dense: int >= 0 = 0") + // Infer from feature_list_sparse_keys + .Attr("Nfeature_list_sparse: int >= 0 = 0") + // Infer from feature_list_dense_keys + .Attr("Nfeature_list_dense: int >= 0 = 0") + .Attr("context_sparse_types: list({float,int64,string}) >= 0 = []") + .Attr("Tcontext_dense: list({float,int64,string}) >= 0 = []") + .Attr("feature_list_dense_types: list({float,int64,string}) >= 0 = []") + .Attr("context_dense_shapes: list(shape) >= 0 = []") + .Attr("feature_list_sparse_types: list({float,int64,string}) >= 0 = []") + .Attr("feature_list_dense_shapes: list(shape) >= 0 = []") .Doc(R"doc( Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors. @@ -148,6 +157,13 @@ context_sparse_types: A list of Ncontext_sparse types; the data types of data in each context Feature given in context_sparse_keys. Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), DT_INT64 (Int64List), and DT_STRING (BytesList). +feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors + (scalars). The keys expected in the FeatureLists associated with sparse + values. +feature_list_sparse_types: A list of Nfeature_list_sparse types; the data types + of data in each FeatureList given in feature_list_sparse_keys. + Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), + DT_INT64 (Int64List), and DT_STRING (BytesList). )doc"); REGISTER_OP("DecodeCSV") diff --git a/tensorflow/core/ops/script_ops.cc b/tensorflow/core/ops/script_ops.cc new file mode 100644 index 0000000000..7b6d6d7c81 --- /dev/null +++ b/tensorflow/core/ops/script_ops.cc @@ -0,0 +1,37 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("PyFunc") + .Input("input: Tin") + .Output("output: Tout") + .Attr("token: string") + .Attr("Tin: list(type)") + .Attr("Tout: list(type)") + .Doc(R"doc( +Invokes a python function to compute func(input)->output. + +token: A token representing a registered python function in this address space. +input: List of Tensors that will provide input to the Op. +output: The outputs from the Op. +Tin: Data types of the inputs to the op. +Tout: Data types of the outputs from the op. + The length of the list specifies the number of outputs. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc index 57913c8a30..ea8bf95ab3 100644 --- a/tensorflow/core/ops/sparse_ops.cc +++ b/tensorflow/core/ops/sparse_ops.cc @@ -114,8 +114,9 @@ REGISTER_OP("SparseToDense") .Input("output_shape: Tindices") .Input("sparse_values: T") .Input("default_value: T") - .Output("dense: T") + .Attr("validate_indices: bool = true") .Attr("T: type") + .Output("dense: T") .Attr("Tindices: {int32, int64}") .Doc(R"doc( Converts a sparse representation into a dense tensor. @@ -136,6 +137,10 @@ dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i] All other values in `dense` are set to `default_value`. If `sparse_values` is a scalar, all sparse indices are set to this single value. +Indices should be sorted in lexicographic order, and indices must not +contain any repeats. If `validate_indices` is true, these properties +are checked during execution. + sparse_indices: 0-D, 1-D, or 2-D. `sparse_indices[i]` contains the complete index where `sparse_values[i]` will be placed. output_shape: 1-D. Shape of the dense output tensor. @@ -143,6 +148,8 @@ sparse_values: 1-D. Values corresponding to each row of `sparse_indices`, or a scalar value to be used for all sparse indices. default_value: Scalar value to set for indices not specified in `sparse_indices`. +validate_indices: If true, indices are checked to make sure they are sorted in + lexicographic order and that there are no repeats. dense: Dense output tensor of shape `output_shape`. )doc"); diff --git a/tensorflow/core/ops/summary_ops.cc b/tensorflow/core/ops/summary_ops.cc index 63fa4a8b5c..33a7a614ab 100644 --- a/tensorflow/core/ops/summary_ops.cc +++ b/tensorflow/core/ops/summary_ops.cc @@ -24,15 +24,15 @@ REGISTER_OP("ScalarSummary") .Input("tags: string") .Input("values: T") .Output("summary: string") - .Attr("T: {float, double}") + .Attr("T: realnumbertype") .Doc(R"doc( Outputs a `Summary` protocol buffer with scalar values. The input `tags` and `values` must have the same shape. The generated summary has a summary value for each tag-value pair in `tags` and `values`. -tags: 1-D. Tags for the summary. -values: 1-D, same size as `tags. Values for the summary. +tags: Tags for the summary. +values: Same shape as `tags. Values for the summary. summary: Scalar. Serialized `Summary` protocol buffer. )doc"); diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc index 9497b48726..d24276c547 100644 --- a/tensorflow/core/platform/env.cc +++ b/tensorflow/core/platform/env.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/public/env.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { @@ -30,32 +31,31 @@ Thread::~Thread() {} EnvWrapper::~EnvWrapper() {} Status ReadFileToString(Env* env, const string& fname, string* data) { - data->clear(); + uint64 file_size; + Status s = env->GetFileSize(fname, &file_size); + if (!s.ok()) { + return s; + } RandomAccessFile* file; - Status s = env->NewRandomAccessFile(fname, &file); + s = env->NewRandomAccessFile(fname, &file); if (!s.ok()) { return s; } - int64 offset = 0; - static const int kBufferSize = 8192; - char* space = new char[kBufferSize]; - while (true) { - StringPiece fragment; - s = file->Read(offset, kBufferSize, &fragment, space); - if (!s.ok()) { - if (errors::IsOutOfRange(s)) { // No more bytes, but not an error - s = Status::OK(); - data->append(fragment.data(), fragment.size()); - } - break; - } - offset += fragment.size(); - data->append(fragment.data(), fragment.size()); - if (fragment.empty()) { - break; - } + gtl::STLStringResizeUninitialized(data, file_size); + char* p = gtl::string_as_array(data); + StringPiece result; + s = file->Read(0, file_size, &result, p); + if (!s.ok()) { + data->clear(); + } else if (result.size() != file_size) { + s = errors::Aborted("File ", fname, " changed while reading: ", file_size, + " vs. ", result.size()); + data->clear(); + } else if (result.data() == p) { + // Data is already in the correct location + } else { + memmove(p, result.data(), result.size()); } - delete[] space; delete file; return s; } diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc index 678d427242..ca4cdc8d66 100644 --- a/tensorflow/core/platform/env_test.cc +++ b/tensorflow/core/platform/env_test.cc @@ -27,7 +27,8 @@ struct EnvTest {}; TEST(EnvTest, ReadFileToString) { Env* env = Env::Default(); const string dir = testing::TmpDir(); - for (const int length : {0, 1, 1212, 2553, 4928, 8196, 9000}) { + for (const int length : {0, 1, 1212, 2553, 4928, 8196, 9000, (1 << 20) - 1, + 1 << 20, (1 << 20) + 1}) { const string filename = io::JoinPath(dir, strings::StrCat("file", length)); // Write a file with the given length diff --git a/tensorflow/core/platform/load_library.cc b/tensorflow/core/platform/load_library.cc new file mode 100644 index 0000000000..aff2562d95 --- /dev/null +++ b/tensorflow/core/platform/load_library.cc @@ -0,0 +1,44 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <dlfcn.h> + +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +namespace internal { + +Status LoadLibrary(const char* library_filename, void** handle) { + *handle = dlopen(library_filename, RTLD_NOW | RTLD_LOCAL); + if (!*handle) { + return errors::NotFound("Unable to find library ", library_filename); + } + return Status::OK(); +} + +Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol) { + *symbol = dlsym(handle, symbol_name); + if (!*symbol) { + return errors::NotFound("Unable to find symbol ", symbol_name, + " in library"); + } + return Status::OK(); +} + +} // namespace internal + +} // namespace tensorflow diff --git a/tensorflow/core/platform/load_library.h b/tensorflow/core/platform/load_library.h new file mode 100644 index 0000000000..eb546acc55 --- /dev/null +++ b/tensorflow/core/platform/load_library.h @@ -0,0 +1,33 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PLATFORM_LOAD_LIBRARY_H_ +#define TENSORFLOW_PLATFORM_LOAD_LIBRARY_H_ + +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +namespace internal { + +Status LoadLibrary(const char* library_filename, void** handle); +Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol); + +} // namespace internal + +} // namespace tensorflow + +#endif // TENSORFLOW_PLATFORM_LOAD_LIBRARY_H_ diff --git a/tensorflow/core/platform/posix/env.cc b/tensorflow/core/platform/posix/env.cc index 2c8daf98a5..164d11a81f 100644 --- a/tensorflow/core/platform/posix/env.cc +++ b/tensorflow/core/platform/posix/env.cc @@ -26,6 +26,7 @@ limitations under the License. #include <thread> #include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/platform/load_library.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/public/env.h" @@ -397,9 +398,20 @@ class PosixEnv : public Env { // TODO(mrry): Replace with a non-blocking timer mechanism and threadpool. CHECK(false) << "PosixEnv::SchedClosureAfter not implemented."; } + + Status LoadLibrary(const char* library_filename, void** handle) override { + return tensorflow::internal::LoadLibrary(library_filename, handle); + } + + Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol) override { + return tensorflow::internal::GetSymbolFromLibrary(handle, symbol_name, + symbol); + } }; } // namespace + #if defined(PLATFORM_POSIX) || defined(__ANDROID__) Env* Env::Default() { static Env* default_env = new PosixEnv; diff --git a/tensorflow/core/platform/regexp.h b/tensorflow/core/platform/regexp.h index 8432f47289..52fb475062 100644 --- a/tensorflow/core/platform/regexp.h +++ b/tensorflow/core/platform/regexp.h @@ -33,16 +33,4 @@ typedef re2::StringPiece RegexpStringPiece; #endif -namespace tensorflow { - -// Conversion to/from the appropriate StringPiece type for using in RE2 -inline RegexpStringPiece ToRegexpStringPiece(tensorflow::StringPiece sp) { - return RegexpStringPiece(sp.data(), sp.size()); -} -inline tensorflow::StringPiece FromRegexpStringPiece(RegexpStringPiece sp) { - return tensorflow::StringPiece(sp.data(), sp.size()); -} - -} // namespace tensorflow - #endif // TENSORFLOW_PLATFORM_REGEXP_H_ diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h index 23d70fc3eb..0ce7bd379d 100644 --- a/tensorflow/core/platform/tracing.h +++ b/tensorflow/core/platform/tracing.h @@ -152,6 +152,9 @@ class Tracing::Engine { Engine() {} virtual ~Engine(); + // Returns true if Tracing is currently enabled. + virtual bool IsEnabled() const = 0; + // Represents an active annotation. class Annotation { public: @@ -225,7 +228,7 @@ class Tracing::TraceMe { inline Tracing::ScopedAnnotation::ScopedAnnotation(StringPiece name) { auto e = Tracing::engine(); - if (e) { + if (e && e->IsEnabled()) { annotation_.reset(e->PushAnnotation(name)); } } @@ -233,7 +236,7 @@ inline Tracing::ScopedAnnotation::ScopedAnnotation(StringPiece name) { inline Tracing::ScopedAnnotation::ScopedAnnotation(StringPiece name_part1, StringPiece name_part2) { auto e = Tracing::engine(); - if (e) { + if (e && e->IsEnabled()) { annotation_.reset( e->PushAnnotation(strings::StrCat(name_part1, ":", name_part2))); } @@ -241,7 +244,7 @@ inline Tracing::ScopedAnnotation::ScopedAnnotation(StringPiece name_part1, inline Tracing::TraceMe::TraceMe(StringPiece name) { auto e = Tracing::engine(); - if (e) { + if (e && e->IsEnabled()) { tracer_.reset(e->StartTracing(name)); } } diff --git a/tensorflow/core/public/env.h b/tensorflow/core/public/env.h index ac34a02c89..e40fe8974f 100644 --- a/tensorflow/core/public/env.h +++ b/tensorflow/core/public/env.h @@ -145,6 +145,27 @@ class Env { // NOTE(mrry): This closure must not block. virtual void SchedClosureAfter(int micros, std::function<void()> closure) = 0; + // \brief Load a dynamic library. + // + // Pass "library_filename" to a platform-specific mechanism for dynamically + // loading a library. The rules for determining the exact location of the + // library are platform-specific and are not documented here. + // + // On success, returns a handle to the library in "*handle" and returns + // OK from the function. + // Otherwise returns nullptr in "*handle" and an error status from the + // function. + virtual Status LoadLibrary(const char* library_filename, void** handle) = 0; + + // \brief Get a pointer to a symbol from a dynamic library. + // + // "handle" should be a pointer returned from a previous call to LoadLibrary. + // On success, store a pointer to the located symbol in "*symbol" and return + // OK from the function. Otherwise, returns nullptr in "*symbol" and an error + // status from the function. + virtual Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol) = 0; + private: /// No copying allowed Env(const Env&); @@ -251,6 +272,13 @@ class EnvWrapper : public Env { void SchedClosureAfter(int micros, std::function<void()> closure) override { target_->SchedClosureAfter(micros, closure); } + Status LoadLibrary(const char* library_filename, void** handle) override { + return target_->LoadLibrary(library_filename, handle); + } + Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol) override { + return target_->GetSymbolFromLibrary(handle, symbol_name, symbol); + } private: Env* target_; diff --git a/tensorflow/core/public/tensor_c_api.h b/tensorflow/core/public/tensor_c_api.h index 5d90e80342..22219d1413 100644 --- a/tensorflow/core/public/tensor_c_api.h +++ b/tensorflow/core/public/tensor_c_api.h @@ -117,6 +117,21 @@ typedef enum { // else an error code with an associated error message. typedef struct TF_Status TF_Status; +// -------------------------------------------------------------------------- +// TF_Buffer holds a pointer to a block of data and its associated length. +// Typically, the data consists of a serialized protocol buffer, but other data +// may also be held in a buffer. +// +// TF_Buffer itself does not do any memory management of the pointed-to block. +typedef struct { + const void* data; + size_t length; +} TF_Buffer; + +// -------------------------------------------------------------------------- +// TF_Library holds information about dynamically loaded TensorFlow plugins. +typedef struct TF_Library TF_Library; + // Return a new status object. extern TF_Status* TF_NewStatus(); @@ -253,6 +268,32 @@ extern void TF_Run(TF_Session*, // Output status TF_Status*); +// -------------------------------------------------------------------------- +// Load plugins containing custom ops and kernels + +// Load the library specified by library_filename and register the ops and +// kernels present in that library. +// +// Pass "library_filename" to a platform-specific mechanism for dynamically +// loading a library. The rules for determining the exact location of the +// library are platform-specific and are not documented here. +// Expects the symbols "RegisterOps", "RegisterKernels", and "GetOpList", to be +// defined in the library. +// +// On success, place OK in status and return the newly created library handle. +// The caller owns the library handle. +// +// On failure, place an error status in status and return nullptr. +extern TF_Library* TF_LoadLibrary(const char* library_filename, + TF_Status* status); + +// Get the OpList of OpDefs defined in the library pointed by lib_handle. +// +// Returns a TF_Buffer. The memory pointed to by the result is owned by +// lib_handle. The data in the buffer will be the serialized OpList proto for +// ops defined in the library. +extern TF_Buffer TF_GetOpList(TF_Library* lib_handle); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 0884538d52..5863a3d782 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -38,7 +38,7 @@ limitations under the License. // Supported GraphDef versions (see graph.proto). #define TF_GRAPH_DEF_VERSION_MIN 0 -#define TF_GRAPH_DEF_VERSION_MAX 1 +#define TF_GRAPH_DEF_VERSION_MAX 3 #define TF_GRAPH_DEF_VERSION TF_GRAPH_DEF_VERSION_MAX #endif // THIRD_PARTY_TENSORFLOW_CORE_PUBLIC_VERSION_H_ diff --git a/tensorflow/core/util/bcast.h b/tensorflow/core/util/bcast.h index 9681dc4c18..19aee104dd 100644 --- a/tensorflow/core/util/bcast.h +++ b/tensorflow/core/util/bcast.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -71,7 +72,7 @@ class BCast { // element is the outer-most dimension and the last element is the // inner-most dimension. Note that we do not use TensorShape since // it's more convenient to manipulate Vec directly for this module. - typedef std::vector<int64> Vec; + typedef gtl::InlinedVector<int64, 4> Vec; BCast(const Vec& x, const Vec& y); ~BCast() {} diff --git a/tensorflow/examples/tutorials/mnist/BUILD b/tensorflow/examples/tutorials/mnist/BUILD index 4fc90730d5..55eff77960 100644 --- a/tensorflow/examples/tutorials/mnist/BUILD +++ b/tensorflow/examples/tutorials/mnist/BUILD @@ -22,7 +22,7 @@ py_library( name = "input_data", srcs = ["input_data.py"], srcs_version = "PY2AND3", - visibility = ["//tensorflow:__subpackages__"], + visibility = ["//tensorflow:internal"], deps = ["//tensorflow:tensorflow_py"], ) diff --git a/tensorflow/examples/tutorials/mnist/input_data.py b/tensorflow/examples/tutorials/mnist/input_data.py index ae3727c82e..bca3bcf008 100644 --- a/tensorflow/examples/tutorials/mnist/input_data.py +++ b/tensorflow/examples/tutorials/mnist/input_data.py @@ -21,9 +21,12 @@ from __future__ import print_function import gzip import os +import tensorflow.python.platform + import numpy from six.moves import urllib from six.moves import xrange # pylint: disable=redefined-builtin +import tensorflow as tf SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' @@ -91,9 +94,18 @@ def extract_labels(filename, one_hot=False): class DataSet(object): - def __init__(self, images, labels, fake_data=False, one_hot=False): - """Construct a DataSet. one_hot arg is used only if fake_data is true.""" - + def __init__(self, images, labels, fake_data=False, one_hot=False, + dtype=tf.float32): + """Construct a DataSet. + + one_hot arg is used only if fake_data is true. `dtype` can be either + `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into + `[0, 1]`. + """ + dtype = tf.as_dtype(dtype).base_dtype + if dtype not in (tf.uint8, tf.float32): + raise TypeError('Invalid image dtype %r, expected uint8 or float32' % + dtype) if fake_data: self._num_examples = 10000 self.one_hot = one_hot @@ -108,9 +120,10 @@ class DataSet(object): assert images.shape[3] == 1 images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]) - # Convert from [0, 255] -> [0.0, 1.0]. - images = images.astype(numpy.float32) - images = numpy.multiply(images, 1.0 / 255.0) + if dtype == tf.float32: + # Convert from [0, 255] -> [0.0, 1.0]. + images = images.astype(numpy.float32) + images = numpy.multiply(images, 1.0 / 255.0) self._images = images self._labels = labels self._epochs_completed = 0 @@ -160,15 +173,17 @@ class DataSet(object): return self._images[start:end], self._labels[start:end] -def read_data_sets(train_dir, fake_data=False, one_hot=False): +def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32): class DataSets(object): pass data_sets = DataSets() if fake_data: - data_sets.train = DataSet([], [], fake_data=True, one_hot=one_hot) - data_sets.validation = DataSet([], [], fake_data=True, one_hot=one_hot) - data_sets.test = DataSet([], [], fake_data=True, one_hot=one_hot) + def fake(): + return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype) + data_sets.train = fake() + data_sets.validation = fake() + data_sets.test = fake() return data_sets TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' @@ -194,8 +209,9 @@ def read_data_sets(train_dir, fake_data=False, one_hot=False): train_images = train_images[VALIDATION_SIZE:] train_labels = train_labels[VALIDATION_SIZE:] - data_sets.train = DataSet(train_images, train_labels) - data_sets.validation = DataSet(validation_images, validation_labels) - data_sets.test = DataSet(test_images, test_labels) + data_sets.train = DataSet(train_images, train_labels, dtype=dtype) + data_sets.validation = DataSet(validation_images, validation_labels, + dtype=dtype) + data_sets.test = DataSet(test_images, test_labels, dtype=dtype) return data_sets diff --git a/tensorflow/g3doc/api_docs/cc/ClassEnv.md b/tensorflow/g3doc/api_docs/cc/ClassEnv.md index 38bb94ac63..ca50ca1563 100644 --- a/tensorflow/g3doc/api_docs/cc/ClassEnv.md +++ b/tensorflow/g3doc/api_docs/cc/ClassEnv.md @@ -36,6 +36,10 @@ All Env implementations are safe for concurrent access from multiple threads wit * Sleeps/delays the thread for the prescribed number of micro-seconds. * [`virtual Thread* tensorflow::Env::StartThread(const ThreadOptions &thread_options, const string &name, std::function< void()> fn) TF_MUST_USE_RESULT=0`](#virtual_Thread_tensorflow_Env_StartThread) * Returns a new thread that is running fn() and is identified (for debugging/performance-analysis) by "name". +* [`virtual void tensorflow::Env::SchedClosure(std::function< void()> closure)=0`](#virtual_void_tensorflow_Env_SchedClosure) +* [`virtual void tensorflow::Env::SchedClosureAfter(int micros, std::function< void()> closure)=0`](#virtual_void_tensorflow_Env_SchedClosureAfter) +* [`virtual Status tensorflow::Env::LoadLibrary(const char *library_filename, void **handle)=0`](#virtual_Status_tensorflow_Env_LoadLibrary) +* [`virtual Status tensorflow::Env::GetSymbolFromLibrary(void *handle, const char *symbol_name, void **symbol)=0`](#virtual_Status_tensorflow_Env_GetSymbolFromLibrary) * [`static Env* tensorflow::Env::Default()`](#static_Env_tensorflow_Env_Default) * Returns a default environment suitable for the current operating system. @@ -137,6 +141,30 @@ Returns a new thread that is running fn() and is identified (for debugging/perfo Caller takes ownership of the result and must delete it eventually (the deletion will block until fn() stops running). +#### `virtual void tensorflow::Env::SchedClosure(std::function< void()> closure)=0` {#virtual_void_tensorflow_Env_SchedClosure} + + + + + +#### `virtual void tensorflow::Env::SchedClosureAfter(int micros, std::function< void()> closure)=0` {#virtual_void_tensorflow_Env_SchedClosureAfter} + + + + + +#### `virtual Status tensorflow::Env::LoadLibrary(const char *library_filename, void **handle)=0` {#virtual_Status_tensorflow_Env_LoadLibrary} + + + + + +#### `virtual Status tensorflow::Env::GetSymbolFromLibrary(void *handle, const char *symbol_name, void **symbol)=0` {#virtual_Status_tensorflow_Env_GetSymbolFromLibrary} + + + + + #### `static Env* tensorflow::Env::Default()` {#static_Env_tensorflow_Env_Default} Returns a default environment suitable for the current operating system. diff --git a/tensorflow/g3doc/api_docs/cc/ClassEnvWrapper.md b/tensorflow/g3doc/api_docs/cc/ClassEnvWrapper.md index 9ed2a97016..bdfb1af1d4 100644 --- a/tensorflow/g3doc/api_docs/cc/ClassEnvWrapper.md +++ b/tensorflow/g3doc/api_docs/cc/ClassEnvWrapper.md @@ -37,6 +37,10 @@ May be useful to clients who wish to override just part of the functionality of * Sleeps/delays the thread for the prescribed number of micro-seconds. * [`Thread* tensorflow::EnvWrapper::StartThread(const ThreadOptions &thread_options, const string &name, std::function< void()> fn) override`](#Thread_tensorflow_EnvWrapper_StartThread) * Returns a new thread that is running fn() and is identified (for debugging/performance-analysis) by "name". +* [`void tensorflow::EnvWrapper::SchedClosure(std::function< void()> closure) override`](#void_tensorflow_EnvWrapper_SchedClosure) +* [`void tensorflow::EnvWrapper::SchedClosureAfter(int micros, std::function< void()> closure) override`](#void_tensorflow_EnvWrapper_SchedClosureAfter) +* [`Status tensorflow::EnvWrapper::LoadLibrary(const char *library_filename, void **handle) override`](#Status_tensorflow_EnvWrapper_LoadLibrary) +* [`Status tensorflow::EnvWrapper::GetSymbolFromLibrary(void *handle, const char *symbol_name, void **symbol) override`](#Status_tensorflow_EnvWrapper_GetSymbolFromLibrary) ##Member Details @@ -141,3 +145,27 @@ Sleeps/delays the thread for the prescribed number of micro-seconds. Returns a new thread that is running fn() and is identified (for debugging/performance-analysis) by "name". Caller takes ownership of the result and must delete it eventually (the deletion will block until fn() stops running). + +#### `void tensorflow::EnvWrapper::SchedClosure(std::function< void()> closure) override` {#void_tensorflow_EnvWrapper_SchedClosure} + + + + + +#### `void tensorflow::EnvWrapper::SchedClosureAfter(int micros, std::function< void()> closure) override` {#void_tensorflow_EnvWrapper_SchedClosureAfter} + + + + + +#### `Status tensorflow::EnvWrapper::LoadLibrary(const char *library_filename, void **handle) override` {#Status_tensorflow_EnvWrapper_LoadLibrary} + + + + + +#### `Status tensorflow::EnvWrapper::GetSymbolFromLibrary(void *handle, const char *symbol_name, void **symbol) override` {#Status_tensorflow_EnvWrapper_GetSymbolFromLibrary} + + + + diff --git a/tensorflow/g3doc/api_docs/cc/ClassSession.md b/tensorflow/g3doc/api_docs/cc/ClassSession.md index a0fe3a4a30..ffe51ca310 100644 --- a/tensorflow/g3doc/api_docs/cc/ClassSession.md +++ b/tensorflow/g3doc/api_docs/cc/ClassSession.md @@ -33,7 +33,7 @@ if (output_tensor(0) > 0.5) { ... } // Close the session to release the resources associated with // this session. -session->Close() +session->Close(); ``` diff --git a/tensorflow/g3doc/api_docs/cc/StructTF_Buffer.md b/tensorflow/g3doc/api_docs/cc/StructTF_Buffer.md new file mode 100644 index 0000000000..e0e46a8794 --- /dev/null +++ b/tensorflow/g3doc/api_docs/cc/StructTF_Buffer.md @@ -0,0 +1,24 @@ +# Struct `TF_Buffer` + + + + + +##Member Summary + +* [`const void* TF_Buffer::data`](#const_void_TF_Buffer_data) +* [`size_t TF_Buffer::length`](#size_t_TF_Buffer_length) + +##Member Details + +#### `const void* TF_Buffer::data` {#const_void_TF_Buffer_data} + + + + + +#### `size_t TF_Buffer::length` {#size_t_TF_Buffer_length} + + + + diff --git a/tensorflow/g3doc/api_docs/cc/index.md b/tensorflow/g3doc/api_docs/cc/index.md index 2bb24375cb..97abde341e 100644 --- a/tensorflow/g3doc/api_docs/cc/index.md +++ b/tensorflow/g3doc/api_docs/cc/index.md @@ -46,6 +46,7 @@ write the graph to a file. * [tensorflow::TensorShape](ClassTensorShape.md) * [tensorflow::TensorShapeDim](StructTensorShapeDim.md) * [tensorflow::TensorShapeUtils](ClassTensorShapeUtils.md) +* [TF_Buffer](StructTF_Buffer.md) ## Thread @@ -68,6 +69,7 @@ write the graph to a file. <!-- ClassTensorShape.md --> <!-- StructTensorShapeDim.md --> <!-- ClassTensorShapeUtils.md --> +<!-- StructTF_Buffer.md --> <!-- ClassThread.md --> <!-- StructThreadOptions.md --> --> diff --git a/tensorflow/g3doc/api_docs/index.md b/tensorflow/g3doc/api_docs/index.md index 7e41a44d7f..f58624cf51 100644 --- a/tensorflow/g3doc/api_docs/index.md +++ b/tensorflow/g3doc/api_docs/index.md @@ -1,4 +1,4 @@ -# Overview +# API Documentation TensorFlow has APIs available in several languages both for constructing and executing a TensorFlow graph. The Python API is at present the most complete diff --git a/tensorflow/g3doc/api_docs/python/client.md b/tensorflow/g3doc/api_docs/python/client.md index 4f243e58b3..4feeb7556a 100644 --- a/tensorflow/g3doc/api_docs/python/client.md +++ b/tensorflow/g3doc/api_docs/python/client.md @@ -277,7 +277,7 @@ with tf.Session(): - - - -#### `tf.InteractiveSession.__init__(target='', graph=None)` {#InteractiveSession.__init__} +#### `tf.InteractiveSession.__init__(target='', graph=None, config=None)` {#InteractiveSession.__init__} Creates a new interactive TensorFlow session. @@ -296,6 +296,7 @@ the session constructor. Defaults to using an in-process engine. At present, no value other than the empty string is supported. * <b>`graph`</b>: (Optional.) The `Graph` to be launched (described above). +* <b>`config`</b>: (Optional) `ConfigProto` proto used to configure the session. - - - diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md index 6fd8e44dc1..9ee01c727d 100644 --- a/tensorflow/g3doc/api_docs/python/framework.md +++ b/tensorflow/g3doc/api_docs/python/framework.md @@ -1339,7 +1339,7 @@ Converts the given `type_value` to a `DType`. Wrapper for `Graph.device()` using the default graph. See -[`Graph.name_scope()`](../../api_docs/python/framework.md#Graph.name_scope) +[`Graph.device()`](../../api_docs/python/framework.md#Graph.device) for more details. ##### Args: @@ -1544,6 +1544,35 @@ protocol buffer, and extract individual objects in the `GraphDef` as it refers to an unknown tensor). +- - - + +### `tf.load_op_library(library_filename)` {#load_op_library} + +Loads a TensorFlow plugin, containing custom ops and kernels. + +Pass "library_filename" to a platform-specific mechanism for dynamically +loading a library. The rules for determining the exact location of the +library are platform-specific and are not documented here. +Expects the symbols "RegisterOps", "RegisterKernels", and "GetOpList", to be +defined in the library. + +##### Args: + + +* <b>`library_filename`</b>: Path to the plugin. + Relative or absolute filesystem path to a dynamic library file. + +##### Returns: + + A python module containing the Python wrappers for Ops defined in + the plugin. + +##### Raises: + + +* <b>`RuntimeError`</b>: when unable to load the library or get the python wrappers. + + ## Graph collections diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md index a2165c834d..efd38ed148 100644 --- a/tensorflow/g3doc/api_docs/python/index.md +++ b/tensorflow/g3doc/api_docs/python/index.md @@ -17,6 +17,7 @@ * [`Graph`](../../api_docs/python/framework.md#Graph) * [`GraphKeys`](../../api_docs/python/framework.md#GraphKeys) * [`import_graph_def`](../../api_docs/python/framework.md#import_graph_def) + * [`load_op_library`](../../api_docs/python/framework.md#load_op_library) * [`name_scope`](../../api_docs/python/framework.md#name_scope) * [`NoGradient`](../../api_docs/python/framework.md#NoGradient) * [`op_scope`](../../api_docs/python/framework.md#op_scope) @@ -124,6 +125,8 @@ * [`diag`](../../api_docs/python/math_ops.md#diag) * [`div`](../../api_docs/python/math_ops.md#div) * [`edit_distance`](../../api_docs/python/math_ops.md#edit_distance) + * [`erf`](../../api_docs/python/math_ops.md#erf) + * [`erfc`](../../api_docs/python/math_ops.md#erfc) * [`exp`](../../api_docs/python/math_ops.md#exp) * [`fft2d`](../../api_docs/python/math_ops.md#fft2d) * [`floor`](../../api_docs/python/math_ops.md#floor) @@ -132,6 +135,7 @@ * [`imag`](../../api_docs/python/math_ops.md#imag) * [`inv`](../../api_docs/python/math_ops.md#inv) * [`invert_permutation`](../../api_docs/python/math_ops.md#invert_permutation) + * [`lgamma`](../../api_docs/python/math_ops.md#lgamma) * [`listdiff`](../../api_docs/python/math_ops.md#listdiff) * [`log`](../../api_docs/python/math_ops.md#log) * [`matmul`](../../api_docs/python/math_ops.md#matmul) @@ -355,6 +359,7 @@ * [`gradients`](../../api_docs/python/train.md#gradients) * [`histogram_summary`](../../api_docs/python/train.md#histogram_summary) * [`image_summary`](../../api_docs/python/train.md#image_summary) + * [`LooperThread`](../../api_docs/python/train.md#LooperThread) * [`merge_all_summaries`](../../api_docs/python/train.md#merge_all_summaries) * [`merge_summary`](../../api_docs/python/train.md#merge_summary) * [`MomentumOptimizer`](../../api_docs/python/train.md#MomentumOptimizer) @@ -369,3 +374,6 @@ * [`write_graph`](../../api_docs/python/train.md#write_graph) * [`zero_fraction`](../../api_docs/python/train.md#zero_fraction) +* **[Wraps python functions](../../api_docs/python/script_ops.md)**: + * [`py_func`](../../api_docs/python/script_ops.md#py_func) + diff --git a/tensorflow/g3doc/api_docs/python/math_ops.md b/tensorflow/g3doc/api_docs/python/math_ops.md index 878bae23d4..66a9c20e04 100644 --- a/tensorflow/g3doc/api_docs/python/math_ops.md +++ b/tensorflow/g3doc/api_docs/python/math_ops.md @@ -527,6 +527,63 @@ Computes sin of x element-wise. A `Tensor`. Has the same type as `x`. +- - - + +### `tf.lgamma(x, name=None)` {#lgamma} + +Computes `ln(|gamma(x)|)` element-wise. + +##### Args: + + +* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `int64`, + or `qint32`. +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + + A Tensor with the same type as `x` if `x.dtype != qint32` otherwise + the return type is `quint8`. + + +- - - + +### `tf.erf(x, name=None)` {#erf} + +Computes Gauss error function of `x` element-wise. + +##### Args: + + +* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `int64`, + or `qint32`. +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + + A Tensor with the same type as `x` if `x.dtype != qint32` otherwise + the return type is `quint8`. + + +- - - + +### `tf.erfc(x, name=None)` {#erfc} + +Computes complementary error function of `x` element-wise. + +##### Args: + + +* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `int64`, + or `qint32`. +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + + A Tensor with the same type as `x` if `x.dtype != qint32` otherwise + the return type is `quint8`. + + ## Matrix Math Functions diff --git a/tensorflow/g3doc/api_docs/python/script_ops.md b/tensorflow/g3doc/api_docs/python/script_ops.md new file mode 100644 index 0000000000..aa1ff4c9c5 --- /dev/null +++ b/tensorflow/g3doc/api_docs/python/script_ops.md @@ -0,0 +1,46 @@ +<!-- This file is machine generated: DO NOT EDIT! --> + +# Wraps python functions + +Note: Functions taking `Tensor` arguments can also take anything accepted by +[`tf.convert_to_tensor`](../../api_docs/python/framework.md#convert_to_tensor). + +[TOC] + +## Script Language Operators. + +TensorFlow provides allows you to wrap python/numpy functions as +TensorFlow operators. + +## Other Functions and Classes +- - - + +### `tf.py_func(func, inp, Tout, name=None)` {#py_func} + +Wraps a python function and uses it as a tensorflow op. + +Given a python function `func`, which takes numpy arrays as its +inputs and returns numpy arrays as its outputs. E.g., + + def my_func(x): + return np.sinh(x) + inp = tf.placeholder(..., tf.float32) + y = py_func(my_func, [inp], [tf.float32]) + +The above snippet constructs a tf graph which invokes a numpy +sinh(x) as an op in the graph. + +##### Args: + + +* <b>`func`</b>: A python function. +* <b>`inp`</b>: A list of `Tensor`. +* <b>`Tout`</b>: A list of tensorflow data types indicating what `func` + returns. +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + + A list of `Tensor` which `func` computes. + + diff --git a/tensorflow/g3doc/api_docs/python/sparse_ops.md b/tensorflow/g3doc/api_docs/python/sparse_ops.md index cdcb5e2c0f..afaae8facf 100644 --- a/tensorflow/g3doc/api_docs/python/sparse_ops.md +++ b/tensorflow/g3doc/api_docs/python/sparse_ops.md @@ -157,7 +157,7 @@ Alias for field number 1 - - - -### `tf.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0, name=None)` {#sparse_to_dense} +### `tf.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0, validate_indices=True, name=None)` {#sparse_to_dense} Converts a sparse representation into a dense tensor. @@ -177,6 +177,10 @@ dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i] All other values in `dense` are set to `default_value`. If `sparse_values` is a scalar, all sparse indices are set to this single value. +Indices should be sorted in lexicographic order, and indices must not +contain any repeats. If `validate_indices` is True, these properties +are checked during execution. + ##### Args: @@ -189,6 +193,8 @@ is a scalar, all sparse indices are set to this single value. `sparse_indices`, or a scalar value to be used for all sparse indices. * <b>`default_value`</b>: A 0-D `Tensor` of the same type as `sparse_values`. Value to set for indices not specified in `sparse_indices`. Defaults to zero. +* <b>`validate_indices`</b>: A boolean value. If True, indices are checked to make + sure they are sorted in lexicographic order and that there are no repeats. * <b>`name`</b>: A name for the operation (optional). ##### Returns: @@ -199,7 +205,7 @@ is a scalar, all sparse indices are set to this single value. - - - -### `tf.sparse_tensor_to_dense(sp_input, default_value=0, name=None)` {#sparse_tensor_to_dense} +### `tf.sparse_tensor_to_dense(sp_input, default_value=0, validate_indices=True, name=None)` {#sparse_tensor_to_dense} Converts a `SparseTensor` into a dense tensor. @@ -218,12 +224,17 @@ string tensor with values: [x x x x x] [c x x x x]] +Indices must be without repeats. This is only +tested if validate_indices is True. + ##### Args: * <b>`sp_input`</b>: The input `SparseTensor`. * <b>`default_value`</b>: Scalar value to set for indices not specified in `sp_input`. Defaults to zero. +* <b>`validate_indices`</b>: A boolean value. If `True`, indices are checked to make + sure they are sorted in lexicographic order and that there are no repeats. * <b>`name`</b>: A name prefix for the returned tensors (optional). ##### Returns: @@ -257,15 +268,18 @@ For example, if `sp_input.shape = [2, 3, 4]` with non-empty values: [0, 0, 0]: 0 [0, 1, 0]: 10 [1, 0, 3]: 103 - [1, 1, 2]: 112 - [1, 1, 3]: 113 + [1, 1, 2]: 150 + [1, 1, 3]: 149 + [1, 1, 4]: 150 [1, 2, 1]: 121 and `vocab_size = 200`, then the output will be a `[2, 3, 200]` dense bool tensor with False everywhere except at positions - (0, 0, 0), (0, 1, 10), (1, 0, 103), (1, 1, 112), (1, 1, 113), (1, 2, 121). + (0, 0, 0), (0, 1, 10), (1, 0, 103), (1, 1, 149), (1, 1, 150), + (1, 2, 121). +Note that repeats are allowed in the input SparseTensor. This op is useful for converting `SparseTensor`s into dense formats for compatibility with ops that expect dense tensors. diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md index a1b7244c56..7a18323685 100644 --- a/tensorflow/g3doc/api_docs/python/state_ops.md +++ b/tensorflow/g3doc/api_docs/python/state_ops.md @@ -784,7 +784,7 @@ Sets the list of old checkpoint filenames. ##### Raises: -* <b>`AssertionError`</b>: If the list of checkpoint filenames has already been set. +* <b>`AssertionError`</b>: If last_checkpoints is not a list. - - - diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md index 2a9687535c..4bfbd3007f 100644 --- a/tensorflow/g3doc/api_docs/python/train.md +++ b/tensorflow/g3doc/api_docs/python/train.md @@ -32,7 +32,7 @@ opt = GradientDescentOptimizer(learning_rate=0.1) # Add Ops to the graph to minimize a cost by updating a list of variables. # "cost" is a Tensor, and the list of variables contains tf.Variable # objects. -opt_op = opt.minimize(cost, <list of variables>) +opt_op = opt.minimize(cost, var_list=<list of variables>) ``` In the training program you will just have to run the returned Op. @@ -1471,8 +1471,8 @@ summary has a summary value for each tag-value pair in `tags` and `values`. ##### Args: -* <b>`tags`</b>: A 1-D `string` `Tensor`. Tags for the summaries. -* <b>`values`</b>: A 1-D `float32` or `float64` Tensor. Values for the summaries. +* <b>`tags`</b>: A `string` `Tensor`. Tags for the summaries. +* <b>`values`</b>: A real numeric Tensor. Values for the summaries. * <b>`collections`</b>: Optional list of graph collections keys. The new summary op is added to these collections. Defaults to `[GraphKeys.SUMMARIES]`. * <b>`name`</b>: A name for the operation (optional). @@ -1874,3 +1874,216 @@ tf.train.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt') * <b>`as_text`</b>: If `True`, writes the graph as an ASCII proto. + +## Other Functions and Classes +- - - + +### `class tf.train.LooperThread` {#LooperThread} + +A thread that runs code repeatedly, optionally on a timer. + +This thread class is intended to be used with a `Coordinator`. It repeatedly +runs code specified either as `target` and `args` or by the `run_loop()` +method. + +Before each run the thread checks if the coordinator has requested stop. In +that case the looper thread terminates immediately. + +If the code being run raises an exception, that exception is reported to the +coordinator and the thread terminates. The coordinator will then request all +the other threads it coordinates to stop. + +You typically pass looper threads to the supervisor `Join()` method. +- - - + +#### `tf.train.LooperThread.__init__(coord, timer_interval_secs, target=None, args=None)` {#LooperThread.__init__} + +Create a LooperThread. + +##### Args: + + +* <b>`coord`</b>: a Coordinator. +* <b>`timer_interval_secs`</b>: Time boundaries at which to call Run(), or None + if it should be called back to back. +* <b>`target`</b>: Optional callable object that will be executed in the thread. +* <b>`args`</b>: Optional arguments to pass to `target` when calling it. + +##### Raises: + + +* <b>`ValueError`</b>: If one of the arguments is invalid. + + +- - - + +#### `tf.train.LooperThread.daemon` {#LooperThread.daemon} + +A boolean value indicating whether this thread is a daemon thread (True) or not (False). + +This must be set before start() is called, otherwise RuntimeError is +raised. Its initial value is inherited from the creating thread; the +main thread is not a daemon thread and therefore all threads created in +the main thread default to daemon = False. + +The entire Python program exits when no alive non-daemon threads are +left. + + +- - - + +#### `tf.train.LooperThread.getName()` {#LooperThread.getName} + + + + +- - - + +#### `tf.train.LooperThread.ident` {#LooperThread.ident} + +Thread identifier of this thread or None if it has not been started. + +This is a nonzero integer. See the thread.get_ident() function. Thread +identifiers may be recycled when a thread exits and another thread is +created. The identifier is available even after the thread has exited. + + +- - - + +#### `tf.train.LooperThread.isAlive()` {#LooperThread.isAlive} + +Return whether the thread is alive. + +This method returns True just before the run() method starts until just +after the run() method terminates. The module function enumerate() +returns a list of all alive threads. + + +- - - + +#### `tf.train.LooperThread.isDaemon()` {#LooperThread.isDaemon} + + + + +- - - + +#### `tf.train.LooperThread.is_alive()` {#LooperThread.is_alive} + +Return whether the thread is alive. + +This method returns True just before the run() method starts until just +after the run() method terminates. The module function enumerate() +returns a list of all alive threads. + + +- - - + +#### `tf.train.LooperThread.join(timeout=None)` {#LooperThread.join} + +Wait until the thread terminates. + +This blocks the calling thread until the thread whose join() method is +called terminates -- either normally or through an unhandled exception +or until the optional timeout occurs. + +When the timeout argument is present and not None, it should be a +floating point number specifying a timeout for the operation in seconds +(or fractions thereof). As join() always returns None, you must call +isAlive() after join() to decide whether a timeout happened -- if the +thread is still alive, the join() call timed out. + +When the timeout argument is not present or None, the operation will +block until the thread terminates. + +A thread can be join()ed many times. + +join() raises a RuntimeError if an attempt is made to join the current +thread as that would cause a deadlock. It is also an error to join() a +thread before it has been started and attempts to do so raises the same +exception. + + +- - - + +#### `tf.train.LooperThread.loop(coord, timer_interval_secs, target, args=None)` {#LooperThread.loop} + +Start a LooperThread that calls a function periodically. + +If `timer_interval_secs` is None the thread calls `target(args)` +repeatedly. Otherwise `target(args)` is called every `timer_interval_secs` +seconds. The thread terminates when a stop of the coordinator is +requested. + +##### Args: + + +* <b>`coord`</b>: A Coordinator. +* <b>`timer_interval_secs`</b>: Number. Time boundaries at which to call `target`. +* <b>`target`</b>: A callable object. +* <b>`args`</b>: Optional arguments to pass to `target` when calling it. + +##### Returns: + + The started thread. + + +- - - + +#### `tf.train.LooperThread.name` {#LooperThread.name} + +A string used for identification purposes only. + +It has no semantics. Multiple threads may be given the same name. The +initial name is set by the constructor. + + +- - - + +#### `tf.train.LooperThread.run()` {#LooperThread.run} + + + + +- - - + +#### `tf.train.LooperThread.run_loop()` {#LooperThread.run_loop} + +Called at 'timer_interval_secs' boundaries. + + +- - - + +#### `tf.train.LooperThread.setDaemon(daemonic)` {#LooperThread.setDaemon} + + + + +- - - + +#### `tf.train.LooperThread.setName(name)` {#LooperThread.setName} + + + + +- - - + +#### `tf.train.LooperThread.start()` {#LooperThread.start} + +Start the thread's activity. + +It must be called at most once per thread object. It arranges for the +object's run() method to be invoked in a separate thread of control. + +This method will raise a RuntimeError if called more than once on the +same thread object. + + +- - - + +#### `tf.train.LooperThread.start_loop()` {#LooperThread.start_loop} + +Called when the thread starts. + + + diff --git a/tensorflow/g3doc/extras/README.txt b/tensorflow/g3doc/extras/README.txt index 2c9682d2fb..765809a762 100644 --- a/tensorflow/g3doc/extras/README.txt +++ b/tensorflow/g3doc/extras/README.txt @@ -1,2 +1,3 @@ This directory holds extra files we'd like to be able -to link to and serve from within tensorflow.org +to link to and serve from within tensorflow.org. +They are excluded from versioning.
\ No newline at end of file diff --git a/tensorflow/g3doc/how_tos/adding_an_op/index.md b/tensorflow/g3doc/how_tos/adding_an_op/index.md index fe943fac6c..4b2e623f00 100644 --- a/tensorflow/g3doc/how_tos/adding_an_op/index.md +++ b/tensorflow/g3doc/how_tos/adding_an_op/index.md @@ -844,7 +844,8 @@ For more details, see In general, changes to specifications must be backwards-compatible: changing the specification of an Op must not break prior serialized `GraphDef` protocol -buffers constructed from older specfications. +buffers constructed from older specfications. The details of `GraphDef` +compatibility are [described here](../../resources/versions.md#graphs). There are several ways to preserve backwards-compatibility. @@ -897,7 +898,8 @@ generated Python code may change in a way that isn't compatible with old callers. The Python API may be kept compatible by careful changes in a hand-written Python wrapper, by keeping the old signature except possibly adding new optional arguments to the end. Generally incompatible changes may only be -made when TensorFlow's changes major versions. +made when TensorFlow's changes major versions, and must conform to the +[`GraphDef` version semantics](../../resources/versions.md#graphs). ## GPU Support {#mult-archs} diff --git a/tensorflow/g3doc/how_tos/index.md b/tensorflow/g3doc/how_tos/index.md index 748ecfd398..c9ab79aa2a 100644 --- a/tensorflow/g3doc/how_tos/index.md +++ b/tensorflow/g3doc/how_tos/index.md @@ -1,4 +1,4 @@ -# Overview +# How-Tos ## Variables: Creation, Initializing, Saving, and Restoring diff --git a/tensorflow/g3doc/resources/index.md b/tensorflow/g3doc/resources/index.md index f53ae3a18c..d19093871d 100644 --- a/tensorflow/g3doc/resources/index.md +++ b/tensorflow/g3doc/resources/index.md @@ -12,7 +12,7 @@ implementation can be found in our white paper: If you use TensorFlow in your research and would like to cite the TensorFlow system, we suggest you cite the paper above. -You can use this [BibTeX entry](../resources/bib.md). As the project progresses, we +You can use this [BibTeX entry](bib.md). As the project progresses, we may update the suggested citation with new papers. Please only use the TensorFlow name and marks when accurately referencing this @@ -55,3 +55,8 @@ https://github.com/tensorflow/tensorflow/issues) on GitHub. If you need help with using TensorFlow, please do not use the issue tracker for that. Instead, direct your questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow). +## Versioning + +TensorFlow uses [Semantic Versioning 2.0](http://semver.org). For details on +the versioning of our public API and binary compatibility, see the [versioning +document](versions.md). diff --git a/tensorflow/g3doc/resources/leftnav_files b/tensorflow/g3doc/resources/leftnav_files index 2e1940b5d4..b0df3a8368 100644 --- a/tensorflow/g3doc/resources/leftnav_files +++ b/tensorflow/g3doc/resources/leftnav_files @@ -3,3 +3,4 @@ uses.md faq.md glossary.md dims_types.md +versions.md diff --git a/tensorflow/g3doc/resources/versions.md b/tensorflow/g3doc/resources/versions.md new file mode 100644 index 0000000000..a16bd8f549 --- /dev/null +++ b/tensorflow/g3doc/resources/versions.md @@ -0,0 +1,143 @@ +# TensorFlow Version Semantics + +## Semantic Versioning 2.0 + +Once we reach version 1.0, TensorFlow will follow Semantic Versioning 2.0 +(semver). For details, see <http://semver.org>.  Each release version of +TensorFlow has the form `MAJOR.MINOR.PATCH`.  Changes to the each number have +the following meaning: + +* **MAJOR**:  Backwards incompatible changes.  Code and data that worked with + a previous major release will not necessarily work with a new release. + However, in some cases existing TensorFlow data (graphs, checkpoints, and + other protobufs) may be migratable to the newer release; see below for details + on data compatibility. + +* **MINOR**: Backwards compatible features, speed improvements, etc.  Code and + data that worked with a previous minor release *and* which depends only the + public API will continue to work unchanged.  For details on what is and is + not the public API, see below. + +* **PATCH**: Backwards compatible bug fixes. + +Before 1.0, semver allows backwards incompatible changes at any time.  However, +to support users now, we will use the format `0.MAJOR.MINOR` (shifted one step +to the right).  Thus 0.5.0 to 0.6.0 may be backwards incompatible, but 0.6.0 to +0.6.1 will include only backwards compatible features and bug fixes. + +At some point (especially as we approach 1.0) we will likely use prerelease +versions such as X.Y.Z-alpha.1, but we do not yet have specific plans (beyond +the restrictions of semver). + + +## Public API + +Only the public API of TensorFlow is backwards compatible across minor and patch +versions.  The public API consists of + +* The documented [C++ and Python APIs](../api_docs). + +* The following protocol buffer files: + [`attr_value`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/attr_value.proto), + [`config`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/config.proto), + [`event`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/event.proto), + [`graph`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto), + [`op_def`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/op_def.proto), + [`reader_base`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/reader_base.proto), + [`summary`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/summary.proto), + [`tensor`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor.proto), + [`tensor_shape`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor_shape.proto), + and [`types`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/types.proto). + +The public C++ API is exposed through the header files in +[`tensorflow/core/public`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/public). +The public Python API is unfortunately **not** everything available through the +tensorflow python module and its submodules, since we do not yet use `__all__` +everywhere ([#421](https://github.com/tensorflow/tensorflow/issues/421)). + Please refer to the documentation to determine whether a given Python feature +is part of the public API. For now, the protocol buffers are defined in +[`tensorflow/core/framework/*.proto`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/framework) +([#484](https://github.com/tensorflow/tensorflow/issues/484)). + + +## Details That Are Not Public + +The following are specifically **not** part of the public API: they are allowed +to change without notice across minor releases and even patch releases if bug +fixes require it: + +* **Details of composite ops:**  Many public functions in Python expand to + several primitive ops in the graph, and these details will be part of any + graphs saved to disk as GraphDefs.  These details are allowed to change for + minor releases. In particular, regressions tests that check for exact + matching between graphs are likely to break across minor releases, even though + the behavior of the graph should be unchanged and existing checkpoints will + still work. + +* **Floating point numerical details:** The specific floating point values + computed by ops may change at any time: users should rely only on approximate + accuracy and numerical stability, not on the specific bits computed.  Changes + to numerical formulas in minor and patch releases should result in comparable + or improved accuracy, with the caveat that in machine learning improved + accuracy of specific formulas may result in worse accuracy for the overall + system. + +* **Random numbers:** The specific random numbers computed by the [random + ops](../api_docs/python/constant_op.html#random-tensors) may change at any + time: users should rely only on approximately correct distributions and + statistical strength, not the specific bits computed.  However, we will make + changes to random bits rarely and ideally never for patch releases, and all + such intended changes will be documented. + + +## Compatibility for Graphs and Checkpoints {#graphs} + +Many users of TensorFlow will be saving graphs and trained models to disk for +later evaluation or more training, often changing versions of TensorFlow in the +process.  First, following semver, any graph or checkpoint written out with one +version of TensorFlow can be loaded and evaluated with a later version of +TensorFlow with the same major release.  However, we will endeavour to preserve +backwards compatibility even across major releases when possible, so that the +serialized files are usable over long periods of time. + +There are two main classes of saved TensorFlow data: graphs and checkpoints. +Graphs describe the data flow graphs of ops to be run during training and +inference, and checkpoints contain the saved tensor values of variables in a +graph. + +Graphs are serialized via the `GraphDef` protocol buffer.  To facilitate (rare) +backwards incompatible changes to graphs, each `GraphDef` has an integer version +separate from the TensorFlow version.  The semantics are: + +* Each version of TensorFlow supports an interval of `GraphDef` versions.  This + interval with be constant across patch releases, and will only grow across + minor releases.  Dropping support for a `GraphDef` version will only occur + for a major release of TensorFlow. + +* Newly created graphs use the newest `GraphDef` version. + +* If a given version of TensorFlow supports the `GraphDef` version of a graph, + it will load and evaluate with the same behavior as when it was written out + (except for floating point numerical details and random numbers), regardless + of the major version of TensorFlow.  In particular, all checkpoint files will + be compatible. + +* If the `GraphDef` upper bound is increased to X in a (minor) release, there + will be at least six months before the lower bound is increased to X. + +For example (numbers and versions hypothetical), TensorFlow 1.2 might support +`GraphDef` versions 4 to 7.  TensorFlow 1.3 could add `GraphDef` version 8 and +support versions 4 to 8.  At least six months later, TensorFlow 2.0.0 could drop +support for versions 4 to 7, leaving version 8 only. + +Finally, when support for a `GraphDef` version is dropped, we will attempt to +provide tools for automatically converting graphs to a newer supported +`GraphDef` version. + + +## C++ API Compatibility + +Only patch releases will be binary compatible at the C++ level.  That is, minor +releases are backwards compatible in terms of behavior but may require a +recompile for downstream C++ code.  As always, backwards compatibility is only +provided for the public C++ API. diff --git a/tensorflow/g3doc/tutorials/index.md b/tensorflow/g3doc/tutorials/index.md index 8fbdcfcd21..98e1d60fbc 100644 --- a/tensorflow/g3doc/tutorials/index.md +++ b/tensorflow/g3doc/tutorials/index.md @@ -1,4 +1,4 @@ -# Overview +# Tutorials ## MNIST For ML Beginners diff --git a/tensorflow/models/embedding/word2vec_kernels.cc b/tensorflow/models/embedding/word2vec_kernels.cc index f579ce138c..58f5f15d5c 100644 --- a/tensorflow/models/embedding/word2vec_kernels.cc +++ b/tensorflow/models/embedding/word2vec_kernels.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/random/distribution_sampler.h" #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/simple_philox.h" -#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/lib/strings/regexp.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/util/guarded_philox_random.h" diff --git a/tensorflow/models/image/cifar10/BUILD b/tensorflow/models/image/cifar10/BUILD index 25dce65f28..87e11bab62 100644 --- a/tensorflow/models/image/cifar10/BUILD +++ b/tensorflow/models/image/cifar10/BUILD @@ -9,6 +9,7 @@ py_library( name = "cifar10_input", srcs = ["cifar10_input.py"], srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], deps = [ "//tensorflow:tensorflow_py", ], diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py index b9b89473e8..32234db496 100644 --- a/tensorflow/models/image/cifar10/cifar10.py +++ b/tensorflow/models/image/cifar10/cifar10.py @@ -43,11 +43,9 @@ import tarfile import tensorflow.python.platform from six.moves import urllib -from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf from tensorflow.models.image.cifar10 import cifar10_input -from tensorflow.python.platform import gfile FLAGS = tf.app.flags.FLAGS @@ -57,15 +55,12 @@ tf.app.flags.DEFINE_integer('batch_size', 128, tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data', """Path to the CIFAR-10 data directory.""") -# Process images of this size. Note that this differs from the original CIFAR -# image size of 32 x 32. If one alters this number, then the entire model -# architecture will change and any model would need to be retrained. -IMAGE_SIZE = 24 - # Global constants describing the CIFAR-10 data set. -NUM_CLASSES = 10 -NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000 -NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000 +IMAGE_SIZE = cifar10_input.IMAGE_SIZE +NUM_CLASSES = cifar10_input.NUM_CLASSES +NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN +NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL + # Constants describing the training process. MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average. @@ -139,91 +134,21 @@ def _variable_with_weight_decay(name, shape, stddev, wd): return var -def _generate_image_and_label_batch(image, label, min_queue_examples): - """Construct a queued batch of images and labels. - - Args: - image: 3-D Tensor of [IMAGE_SIZE, IMAGE_SIZE, 3] of type.float32. - label: 1-D Tensor of type.int32 - min_queue_examples: int32, minimum number of samples to retain - in the queue that provides of batches of examples. - - Returns: - images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. - labels: Labels. 1D tensor of [batch_size] size. - """ - # Create a queue that shuffles the examples, and then - # read 'FLAGS.batch_size' images + labels from the example queue. - num_preprocess_threads = 16 - images, label_batch = tf.train.shuffle_batch( - [image, label], - batch_size=FLAGS.batch_size, - num_threads=num_preprocess_threads, - capacity=min_queue_examples + 3 * FLAGS.batch_size, - min_after_dequeue=min_queue_examples) - - # Display the training images in the visualizer. - tf.image_summary('images', images) - - return images, tf.reshape(label_batch, [FLAGS.batch_size]) - - def distorted_inputs(): """Construct distorted input for CIFAR training using the Reader ops. - Raises: - ValueError: if no data_dir - Returns: images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. labels: Labels. 1D tensor of [batch_size] size. - """ - filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin', - 'data_batch_%d.bin' % i) - for i in xrange(1, 6)] - for f in filenames: - if not gfile.Exists(f): - raise ValueError('Failed to find file: ' + f) - - # Create a queue that produces the filenames to read. - filename_queue = tf.train.string_input_producer(filenames) - # Read examples from files in the filename queue. - read_input = cifar10_input.read_cifar10(filename_queue) - reshaped_image = tf.cast(read_input.uint8image, tf.float32) - - height = IMAGE_SIZE - width = IMAGE_SIZE - - # Image processing for training the network. Note the many random - # distortions applied to the image. - - # Randomly crop a [height, width] section of the image. - distorted_image = tf.image.random_crop(reshaped_image, [height, width]) - - # Randomly flip the image horizontally. - distorted_image = tf.image.random_flip_left_right(distorted_image) - - # Because these operations are not commutative, consider randomizing - # randomize the order their operation. - distorted_image = tf.image.random_brightness(distorted_image, - max_delta=63) - distorted_image = tf.image.random_contrast(distorted_image, - lower=0.2, upper=1.8) - - # Subtract off the mean and divide by the variance of the pixels. - float_image = tf.image.per_image_whitening(distorted_image) - - # Ensure that the random shuffling has good mixing properties. - min_fraction_of_examples_in_queue = 0.4 - min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * - min_fraction_of_examples_in_queue) - print ('Filling queue with %d CIFAR images before starting to train. ' - 'This will take a few minutes.' % min_queue_examples) - - # Generate a batch of images and labels by building up a queue of examples. - return _generate_image_and_label_batch(float_image, read_input.label, - min_queue_examples) + Raises: + ValueError: If no data_dir + """ + if not FLAGS.data_dir: + raise ValueError('Please supply a data_dir') + data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin') + return cifar10_input.distorted_inputs(data_dir=data_dir, + batch_size=FLAGS.batch_size) def inputs(eval_data): @@ -232,56 +157,18 @@ def inputs(eval_data): Args: eval_data: bool, indicating if one should use the train or eval data set. - Raises: - ValueError: if no data_dir - Returns: images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. labels: Labels. 1D tensor of [batch_size] size. + + Raises: + ValueError: If no data_dir """ if not FLAGS.data_dir: raise ValueError('Please supply a data_dir') - - if not eval_data: - filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin', - 'data_batch_%d.bin' % i) - for i in xrange(1, 6)] - num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN - else: - filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin', - 'test_batch.bin')] - num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL - - for f in filenames: - if not gfile.Exists(f): - raise ValueError('Failed to find file: ' + f) - - # Create a queue that produces the filenames to read. - filename_queue = tf.train.string_input_producer(filenames) - - # Read examples from files in the filename queue. - read_input = cifar10_input.read_cifar10(filename_queue) - reshaped_image = tf.cast(read_input.uint8image, tf.float32) - - height = IMAGE_SIZE - width = IMAGE_SIZE - - # Image processing for evaluation. - # Crop the central [height, width] of the image. - resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, - width, height) - - # Subtract off the mean and divide by the variance of the pixels. - float_image = tf.image.per_image_whitening(resized_image) - - # Ensure that the random shuffling has good mixing properties. - min_fraction_of_examples_in_queue = 0.4 - min_queue_examples = int(num_examples_per_epoch * - min_fraction_of_examples_in_queue) - - # Generate a batch of images and labels by building up a queue of examples. - return _generate_image_and_label_batch(float_image, read_input.label, - min_queue_examples) + data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin') + return cifar10_input.inputs(eval_data=eval_data, data_dir=data_dir, + batch_size=FLAGS.batch_size) def inference(images): diff --git a/tensorflow/models/image/cifar10/cifar10_input.py b/tensorflow/models/image/cifar10/cifar10_input.py index ac73c493a3..ffe8facd27 100644 --- a/tensorflow/models/image/cifar10/cifar10_input.py +++ b/tensorflow/models/image/cifar10/cifar10_input.py @@ -19,9 +19,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + import tensorflow.python.platform +from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf +from tensorflow.python.platform import gfile + +# Process images of this size. Note that this differs from the original CIFAR +# image size of 32 x 32. If one alters this number, then the entire model +# architecture will change and any model would need to be retrained. +IMAGE_SIZE = 24 + +# Global constants describing the CIFAR-10 data set. +NUM_CLASSES = 10 +NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000 +NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000 + def read_cifar10(filename_queue): """Reads and parses examples from CIFAR10 data files. @@ -82,3 +97,144 @@ def read_cifar10(filename_queue): result.uint8image = tf.transpose(depth_major, [1, 2, 0]) return result + + +def _generate_image_and_label_batch(image, label, min_queue_examples, + batch_size): + """Construct a queued batch of images and labels. + + Args: + image: 3-D Tensor of [height, width, 3] of type.float32. + label: 1-D Tensor of type.int32 + min_queue_examples: int32, minimum number of samples to retain + in the queue that provides of batches of examples. + batch_size: Number of images per batch. + + Returns: + images: Images. 4D tensor of [batch_size, height, width, 3] size. + labels: Labels. 1D tensor of [batch_size] size. + """ + # Create a queue that shuffles the examples, and then + # read 'batch_size' images + labels from the example queue. + num_preprocess_threads = 16 + images, label_batch = tf.train.shuffle_batch( + [image, label], + batch_size=batch_size, + num_threads=num_preprocess_threads, + capacity=min_queue_examples + 3 * batch_size, + min_after_dequeue=min_queue_examples) + + # Display the training images in the visualizer. + tf.image_summary('images', images) + + return images, tf.reshape(label_batch, [batch_size]) + + +def distorted_inputs(data_dir, batch_size): + """Construct distorted input for CIFAR training using the Reader ops. + + Args: + data_dir: Path to the CIFAR-10 data directory. + batch_size: Number of images per batch. + + Returns: + images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. + labels: Labels. 1D tensor of [batch_size] size. + """ + filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) + for i in xrange(1, 6)] + for f in filenames: + if not gfile.Exists(f): + raise ValueError('Failed to find file: ' + f) + + # Create a queue that produces the filenames to read. + filename_queue = tf.train.string_input_producer(filenames) + + # Read examples from files in the filename queue. + read_input = read_cifar10(filename_queue) + reshaped_image = tf.cast(read_input.uint8image, tf.float32) + + height = IMAGE_SIZE + width = IMAGE_SIZE + + # Image processing for training the network. Note the many random + # distortions applied to the image. + + # Randomly crop a [height, width] section of the image. + distorted_image = tf.image.random_crop(reshaped_image, [height, width]) + + # Randomly flip the image horizontally. + distorted_image = tf.image.random_flip_left_right(distorted_image) + + # Because these operations are not commutative, consider randomizing + # randomize the order their operation. + distorted_image = tf.image.random_brightness(distorted_image, + max_delta=63) + distorted_image = tf.image.random_contrast(distorted_image, + lower=0.2, upper=1.8) + + # Subtract off the mean and divide by the variance of the pixels. + float_image = tf.image.per_image_whitening(distorted_image) + + # Ensure that the random shuffling has good mixing properties. + min_fraction_of_examples_in_queue = 0.4 + min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * + min_fraction_of_examples_in_queue) + print ('Filling queue with %d CIFAR images before starting to train. ' + 'This will take a few minutes.' % min_queue_examples) + + # Generate a batch of images and labels by building up a queue of examples. + return _generate_image_and_label_batch(float_image, read_input.label, + min_queue_examples, batch_size) + + +def inputs(eval_data, data_dir, batch_size): + """Construct input for CIFAR evaluation using the Reader ops. + + Args: + eval_data: bool, indicating if one should use the train or eval data set. + data_dir: Path to the CIFAR-10 data directory. + batch_size: Number of images per batch. + + Returns: + images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. + labels: Labels. 1D tensor of [batch_size] size. + """ + if not eval_data: + filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) + for i in xrange(1, 6)] + num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN + else: + filenames = [os.path.join(data_dir, 'test_batch.bin')] + num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL + + for f in filenames: + if not gfile.Exists(f): + raise ValueError('Failed to find file: ' + f) + + # Create a queue that produces the filenames to read. + filename_queue = tf.train.string_input_producer(filenames) + + # Read examples from files in the filename queue. + read_input = read_cifar10(filename_queue) + reshaped_image = tf.cast(read_input.uint8image, tf.float32) + + height = IMAGE_SIZE + width = IMAGE_SIZE + + # Image processing for evaluation. + # Crop the central [height, width] of the image. + resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, + width, height) + + # Subtract off the mean and divide by the variance of the pixels. + float_image = tf.image.per_image_whitening(resized_image) + + # Ensure that the random shuffling has good mixing properties. + min_fraction_of_examples_in_queue = 0.4 + min_queue_examples = int(num_examples_per_epoch * + min_fraction_of_examples_in_queue) + + # Generate a batch of images and labels by building up a queue of examples. + return _generate_image_and_label_batch(float_image, read_input.label, + min_queue_examples, batch_size) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 62aa3ee0c5..ee2e769ede 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -69,6 +69,20 @@ py_tests( ) cc_library( + name = "py_func_lib", + srcs = ["lib/core/py_func.cc"], + hdrs = [ + "lib/core/py_func.h", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//third_party/py/numpy:headers", + "//util/python:python_headers", + ], +) + +cc_library( name = "py_record_reader_lib", srcs = [ "lib/io/py_record_reader.cc", @@ -107,17 +121,28 @@ py_test( ) cc_library( - name = "python_op_gen_main", + name = "python_op_gen", srcs = [ "framework/python_op_gen.cc", "framework/python_op_gen.h", - "framework/python_op_gen_main.cc", ], visibility = ["//visibility:public"], deps = [ "//tensorflow/core:framework", "//tensorflow/core:protos_cc", ], + alwayslink = 1, +) + +cc_library( + name = "python_op_gen_main", + srcs = [ + "framework/python_op_gen_main.cc", + ], + visibility = ["//visibility:public"], + deps = [ + ":python_op_gen", + ], ) # What is needed for tf_gen_op_wrapper_py. @@ -154,6 +179,7 @@ py_library( "framework/importer.py", "framework/random_seed.py", "framework/tensor_util.py", + "framework/load_library.py", # TODO(josh11b): Move this to the framework directory "ops/common_shapes.py", ], @@ -483,6 +509,9 @@ tf_gen_op_wrapper_py( "MatMul", "Sigmoid", "Tanh", + "Lgamma", + "Erf", + "Erfc", ], require_shape_functions = True, ) @@ -531,6 +560,14 @@ tf_gen_op_wrapper_py( ) tf_gen_op_wrapper_py( + name = "script_ops", + hidden = [ + "PyFunc", + ], + require_shape_functions = True, +) + +tf_gen_op_wrapper_py( name = "state_ops", hidden = [ "Variable", @@ -631,6 +668,7 @@ py_library( "ops/random_ops.py", "ops/rnn.py", "ops/rnn_cell.py", + "ops/script_ops.py", "ops/seq2seq.py", "ops/sparse_grad.py", "ops/sparse_ops.py", @@ -658,6 +696,7 @@ py_library( ":nn_ops", ":parsing_ops", ":random_ops", + ":script_ops", ":sparse_ops", ":string_ops", ":summary_ops", @@ -710,8 +749,18 @@ tf_proto_library_py( name = "protos_all", srcs = glob( ["**/*.proto"], - exclude = ["util/protobuf/compare_test.proto"], + exclude = [ + "util/protobuf/compare_test.proto", + "training/saver.proto", + ], ), + deps = [":public_protos_py"], +) + +tf_proto_library_py( + name = "public_protos", + srcs = ["training/saver.proto"], + visibility = ["//visibility:public"], ) tf_proto_library_py( @@ -785,6 +834,8 @@ tf_py_wrap_cc( swig_includes = [ "client/events_writer.i", "client/tf_session.i", + "framework/python_op_gen.i", + "lib/core/py_func.i", "lib/core/status.i", "lib/core/status_helper.i", "lib/core/strings.i", @@ -795,8 +846,10 @@ tf_py_wrap_cc( "util/port.i", ], deps = [ + ":py_func_lib", ":py_record_reader_lib", ":py_record_writer_lib", + ":python_op_gen", ":tf_session_helper", "//util/python:python_headers", ], diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 2afbea6e63..5c8dfc74a7 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -594,7 +594,7 @@ class InteractiveSession(BaseSession): @@close """ - def __init__(self, target='', graph=None): + def __init__(self, target='', graph=None, config=None): """Creates a new interactive TensorFlow session. If no `graph` argument is specified when constructing the session, @@ -610,8 +610,9 @@ class InteractiveSession(BaseSession): Defaults to using an in-process engine. At present, no value other than the empty string is supported. graph: (Optional.) The `Graph` to be launched (described above). + config: (Optional) `ConfigProto` proto used to configure the session. """ - super(InteractiveSession, self).__init__(target, graph) + super(InteractiveSession, self).__init__(target, graph, config) self._default_session = self.as_default() self._default_session.__enter__() self._explicit_graph = graph diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 2d6a73eb9e..dcb64e8a9e 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -200,7 +200,13 @@ tensorflow::ImportNumpy(); // END TYPEMAPS FOR tensorflow::TF_Run_wrapper() //////////////////////////////////////////////////////////////////////////////// - +// Typemaps for TF_GetOpList. +// The wrapped function TF_GetOpList returns a TF_Buffer pointer. This typemap +// creates a Python string from the TF_Buffer and returns it. +%typemap(out) TF_Buffer TF_GetOpList { + $result = PyString_FromStringAndSize( + reinterpret_cast<const char*>($1.data), $1.length); +} // Include the functions from tensor_c_api.h, except TF_Run. %ignoreall @@ -219,6 +225,9 @@ tensorflow::ImportNumpy(); %unignore TF_CloseSession; %unignore TF_DeleteSession; %unignore TF_ExtendGraph; +%unignore TF_NewLibrary; +%unignore TF_LoadLibrary; +%unignore TF_GetOpList; %include "tensorflow/core/public/tensor_c_api.h" %ignoreall diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py index 84163ca1c0..9d4f2c54f9 100644 --- a/tensorflow/python/framework/framework_lib.py +++ b/tensorflow/python/framework/framework_lib.py @@ -36,6 +36,7 @@ @@convert_to_tensor_or_indexed_slices @@get_default_graph @@import_graph_def +@@load_op_library ## Graph collections @@ -89,3 +90,6 @@ from tensorflow.python.framework.tensor_shape import Dimension from tensorflow.python.framework.tensor_shape import TensorShape from tensorflow.python.framework.dtypes import * + +# Load a TensorFlow plugin +from tensorflow.python.framework.load_library import * diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py index a20f706f38..e998310a69 100644 --- a/tensorflow/python/framework/gen_docs_combined.py +++ b/tensorflow/python/framework/gen_docs_combined.py @@ -105,17 +105,19 @@ def all_libraries(module_to_name, members, documented): "rnn", "state_saving_rnn", "bidirectional_rnn", "dynamic_rnn", "seq2seq", "rnn_cell"], prefix=PREFIX_TEXT), - library('client', "Running Graphs", client_lib), + library("client", "Running Graphs", client_lib), library("train", "Training", tf.train, exclude_symbols=["Feature", "Features", "BytesList", "FloatList", "Int64List", "Example", "InferenceExample", "FeatureList", "FeatureLists", "RankingExample", "SequenceExample"]), + library("script_ops", "Wraps python functions", prefix=PREFIX_TEXT) ] _hidden_symbols = ["Event", "Summary", "xrange", "HistogramProto", "ConfigProto", "NodeDef", "GraphDef", - "GPUOptions", "SessionInterface", "BaseSession"] + "GPUOptions", "GraphOptions", "SessionInterface", + "BaseSession"] def main(unused_argv): if not FLAGS.out_dir: diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index efc977aeb5..14b990fffa 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -622,14 +622,16 @@ class ImportGraphDefTest(tf.test.TestCase): def testVersionLow(self): with tf.Graph().as_default(): pat = (r"^GraphDef version -1 is no longer supported: TensorFlow \S+ " - r"needs \d+ <= version <= \d+. Please regenerate your graph.$") + r"needs %d <= version <= %d. Please regenerate your graph.$" % + (tf.GRAPH_DEF_VERSION_MIN, tf.GRAPH_DEF_VERSION_MAX)) with self.assertRaisesRegexp(ValueError, pat): tf.import_graph_def(self._MakeGraphDef("", version=-1)) def testVersionHigh(self): with tf.Graph().as_default(): pat = (r"^GraphDef version \d+ is not yet supported: TensorFlow \S+ " - r"needs \d+ <= version <= \d+. Please upgrade TensorFlow.$") + r"needs %d <= version <= %d. Please upgrade TensorFlow.$" % + (tf.GRAPH_DEF_VERSION_MIN, tf.GRAPH_DEF_VERSION_MAX)) with self.assertRaisesRegexp(ValueError, pat): tf.import_graph_def(self._MakeGraphDef("", version=1 << 30)) diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py new file mode 100644 index 0000000000..9436638022 --- /dev/null +++ b/tensorflow/python/framework/load_library.py @@ -0,0 +1,74 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Function for loading TensorFlow plugins.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import hashlib +import imp +import sys + +from tensorflow.core.framework import op_def_pb2 +from tensorflow.python import pywrap_tensorflow as py_tf +from tensorflow.python.util import compat + + +def load_op_library(library_filename): + """Loads a TensorFlow plugin, containing custom ops and kernels. + + Pass "library_filename" to a platform-specific mechanism for dynamically + loading a library. The rules for determining the exact location of the + library are platform-specific and are not documented here. + Expects the symbols "RegisterOps", "RegisterKernels", and "GetOpList", to be + defined in the library. + + Args: + library_filename: Path to the plugin. + Relative or absolute filesystem path to a dynamic library file. + + Returns: + A python module containing the Python wrappers for Ops defined in + the plugin. + + Raises: + RuntimeError: when unable to load the library or get the python wrappers. + """ + status = py_tf.TF_NewStatus() + + lib_handle = py_tf.TF_LoadLibrary(library_filename, status) + try: + if py_tf.TF_GetCode(status) != 0: + raise RuntimeError(compat.as_text(py_tf.TF_Message(status))) + finally: + py_tf.TF_DeleteStatus(status) + + op_list_str = py_tf.TF_GetOpList(lib_handle) + op_list = op_def_pb2.OpList() + op_list.ParseFromString(op_list_str) + wrappers = py_tf.GetPythonWrappers(op_list_str, len(op_list_str)) + + # Get a unique name for the module. + module_name = hashlib.md5(wrappers).hexdigest() + module = imp.new_module(module_name) + # pylint: disable=exec-used + exec(wrappers, module.__dict__) + # Stash away the library handle for making calls into the dynamic library. + module.LIB_HANDLE = lib_handle + # OpDefs of the list of ops defined in the library. + module.OP_LIST = op_list + sys.modules[module_name] = module + return module diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index c511b2ea28..390a293c95 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -542,6 +542,35 @@ def convert_to_tensor(value, dtype=None, name=None, as_ref=False): % (error_prefix, value, type(value))) +def convert_n_to_tensor(values, dtype=None, name=None, as_ref=False): + """Converts `values` to a list of `Tensor` objects. + + Args: + values: A list of objects that can be consumed by `tf.convert_to_tensor()`. + dtype: (Optional.) The required `DType` of the returned `Tensor` objects. + name: (Optional.) A name prefix to used when a new `Tensor` is + created, in which case element `i` will be given the name `name + + '_' + i`. + as_ref: True if the caller wants the results as ref tensors. + + Returns: + A list of `Tensor` and/or `IndexedSlices` objects. + + Raises: + TypeError: If no conversion function is registered for an element in + `values`. + RuntimeError: If a registered conversion function returns an invalid + value. + """ + if not isinstance(values, collections.Sequence): + raise TypeError("values must be a list.") + ret = [] + for i, value in enumerate(values): + n = None if name is None else "%s_%d" % (name, i) + ret.append(convert_to_tensor(value, dtype=dtype, name=n, as_ref=as_ref)) + return ret + + def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None, as_ref=False): """Converts the given object to a `Tensor` or an `IndexedSlices`. @@ -2218,7 +2247,7 @@ class Graph(object): """ try: old_stack = self._name_stack - if not name: # Both for name=None nad name="" we re-set to empty scope. + if not name: # Both for name=None and name="" we re-set to empty scope. new_stack = (None, None) elif name and name[-1] == "/": new_stack = (name[:-1], name[:-1]) @@ -2734,7 +2763,7 @@ def device(dev): """Wrapper for `Graph.device()` using the default graph. See - [`Graph.name_scope()`](../../api_docs/python/framework.md#Graph.name_scope) + [`Graph.device()`](../../api_docs/python/framework.md#Graph.device) for more details. Args: diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index ae28319a38..898c4acf18 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/python/framework/python_op_gen.h" #include <stdio.h> +#include <sstream> #include <unordered_map> #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" @@ -28,6 +29,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/port.h" @@ -252,18 +254,19 @@ string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg, } } -void PrintReturns(const OpDef& op_def, - const std::vector<string>& output_type_string) { +static string GetReturns(const OpDef& op_def, + const std::vector<string>& output_type_string) { + string result; DCHECK_EQ(op_def.output_arg_size(), output_type_string.size()); const int num_outs = op_def.output_arg_size(); - printf("\n Returns:\n"); + strings::Appendf(&result, "\n Returns:\n"); if (num_outs == 0) { - printf(" The created Operation.\n"); + strings::Appendf(&result, " The created Operation.\n"); } else { if (num_outs == 1) { StringPiece description = op_def.output_arg(0).description(); if (ConsumeEquals(&description)) { // Skip the generated type info. - printf("%s", Indent(4, 4, description).c_str()); + strings::Appendf(&result, "%s", Indent(4, 4, description).c_str()); } else { // Special case of one output, don't use the name of the output unless // there is no description. @@ -282,7 +285,7 @@ void PrintReturns(const OpDef& op_def, } else if (!description.empty()) { AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */); } - printf("%s", Indent(4, 4, desc).c_str()); + strings::Appendf(&result, "%s", Indent(4, 4, desc).c_str()); } } else { std::vector<string> out_names(num_outs); @@ -293,8 +296,8 @@ void PrintReturns(const OpDef& op_def, out_names[i] = strings::StrCat("output", i); } } - printf(" A tuple of `Tensor` objects (%s).\n", - str_util::Join(out_names, ", ").c_str()); + strings::Appendf(&result, " A tuple of `Tensor` objects (%s).\n", + str_util::Join(out_names, ", ").c_str()); for (int i = 0; i < num_outs; ++i) { string desc = strings::StrCat(out_names[i], ": "); StringPiece description = op_def.output_arg(i).description(); @@ -317,10 +320,16 @@ void PrintReturns(const OpDef& op_def, strings::StrAppend(&desc, type); } } - printf("%s", Indent(4, 6, desc).c_str()); + strings::Appendf(&result, "%s", Indent(4, 6, desc).c_str()); } } } + return result; +} + +void PrintReturns(const OpDef& op_def, + const std::vector<string>& output_type_string) { + printf("%s", GetReturns(op_def, output_type_string).c_str()); } string StringToPython(const string& str) { @@ -400,8 +409,8 @@ string AttrValueToPython(const string& type, const AttrValue& value) { } } -// Requires: ValidateOpDef(op_def).ok() -void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) { +static string GetPythonOp(const OpDef& op_def, bool is_hidden, string op_name) { + string result; // Map from attr name to the first input arg it is inferred from. std::unordered_map<string, string> inferred_attrs; // This has all the input args followed by those attrs that don't have @@ -472,7 +481,8 @@ void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) { const string def_suffix = strings::StrCat(parameters, has_args ? ", " : "", "name=None):"); - printf("%s\n", WordWrap(def_prefix, def_suffix, kRightMargin).c_str()); + strings::Appendf(&result, "%s\n", + WordWrap(def_prefix, def_suffix, kRightMargin).c_str()); // Format the Op's descriptions so that it can be a Python docstring. string comment; @@ -485,10 +495,7 @@ void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) { } } - printf(R"( r"""%s - Args: -)", - comment.c_str()); + strings::Appendf(&result, " r\"\"\"%s\n Args:\n", comment.c_str()); // Inputs for (int i = 0; i < op_def.input_arg_size(); ++i) { @@ -504,7 +511,7 @@ void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) { if (!description.empty()) { AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */); } - printf("%s", Indent(4, 6, desc).c_str()); + strings::Appendf(&result, "%s", Indent(4, 6, desc).c_str()); } // Attrs @@ -569,10 +576,10 @@ void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) { AppendWithinWidth(&desc, attr.description(), kRightMargin - 4 /* indent */); } - printf("%s", Indent(4, 6, desc).c_str()); + strings::Appendf(&result, "%s", Indent(4, 6, desc).c_str()); } - printf(" name: A name for the operation (optional).\n"); + strings::Appendf(&result, " name: A name for the operation (optional).\n"); std::vector<string> output_type_string; output_type_string.reserve(op_def.output_arg_size()); @@ -580,7 +587,7 @@ void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) { output_type_string.push_back( ArgTypeName(op_def, op_def.output_arg(i), inferred_attrs, true)); } - PrintReturns(op_def, output_type_string); + strings::StrAppend(&result, GetReturns(op_def, output_type_string)); string return_prefix = strings::StrCat(" return _op_def_lib.apply_op("); string return_args = strings::StrCat("\"", op_def.name(), "\", "); @@ -589,13 +596,12 @@ void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) { } strings::StrAppend(&return_args, "name=name)"); - printf(R"( """ -%s -)", - // Wrap the arguments, and indent to the (. - WordWrap(return_prefix, return_args, kRightMargin).c_str()); + strings::Appendf(&result, " \"\"\"\n%s\n", + // Wrap the arguments, and indent to the (. + WordWrap(return_prefix, return_args, kRightMargin).c_str()); - printf("\n\n"); + strings::Appendf(&result, "\n\n"); + return result; } void GenerateLowerCaseOpName(const string& str, string* result) { @@ -616,11 +622,12 @@ void GenerateLowerCaseOpName(const string& str, string* result) { } // namespace -void PrintPythonOps(const OpList& ops, const string& hidden_ops, +string GetPythonOps(const OpList& ops, const string& hidden_ops, bool require_shapes) { + string result; // Header // TODO(josh11b): Mention the library for which wrappers are being generated. - printf(R"("""Python wrappers around Brain. + strings::Appendf(&result, R"("""Python wrappers around Brain. This file is MACHINE GENERATED! Do not edit. """ @@ -662,10 +669,12 @@ from tensorflow.python.ops import op_def_library continue; } - PrintPythonOp(op_def, is_hidden, lower_case_name); + strings::StrAppend(&result, + GetPythonOp(op_def, is_hidden, lower_case_name)); if (!require_shapes) { - printf("ops.RegisterShape(\"%s\")(None)\n", op_def.name().c_str()); + strings::Appendf(&result, "ops.RegisterShape(\"%s\")(None)\n", + op_def.name().c_str()); } auto added = out->Add(); @@ -673,7 +682,7 @@ from tensorflow.python.ops import op_def_library RemoveDescriptionsFromOpDef(added); } - printf(R"(def _InitOpDefLibrary(): + strings::Appendf(&result, R"(def _InitOpDefLibrary(): op_list = op_def_pb2.OpList() text_format.Merge(_InitOpDefLibrary.op_list_ascii, op_list) op_def_registry.register_op_list(op_list) @@ -687,7 +696,26 @@ _InitOpDefLibrary.op_list_ascii = """%s""" _op_def_lib = _InitOpDefLibrary() )", - cleaned_ops.DebugString().c_str()); + cleaned_ops.DebugString().c_str()); + return result; +} + +void PrintPythonOps(const OpList& ops, const string& hidden_ops, + bool require_shapes) { + printf("%s", GetPythonOps(ops, hidden_ops, require_shapes).c_str()); +} + +string GetAllPythonOps(const char* hidden, bool require_shapes) { + OpList ops; + OpRegistry::Global()->Export(false, &ops); + return GetPythonOps(ops, hidden, require_shapes); +} + +string GetPythonWrappers(const char* buf, size_t len) { + string op_list_str(buf, len); + OpList ops; + ops.ParseFromString(op_list_str); + return GetPythonOps(ops, "", false); } } // namespace tensorflow diff --git a/tensorflow/python/framework/python_op_gen.h b/tensorflow/python/framework/python_op_gen.h index b998f2247b..faac3caab5 100644 --- a/tensorflow/python/framework/python_op_gen.h +++ b/tensorflow/python/framework/python_op_gen.h @@ -22,10 +22,19 @@ limitations under the License. namespace tensorflow { -// Result is printed to stdout. hidden_ops should be a comma-separated +// hidden_ops should be a comma-separated // list of Op names that should get a leading _ in the output. +// The Print* version prints the output to stdout, Get* version returns the +// output as a string. void PrintPythonOps(const OpList& ops, const string& hidden_ops, bool require_shapes); +string GetPythonOps(const OpList& ops, const string& hidden_ops, + bool require_shapes); + +// Get the python wrappers for a list of ops in a OpList. +// buf should be a pointer to a buffer containing the binary encoded OpList +// proto, and len should be the length of that buffer. +string GetPythonWrappers(const char* buf, size_t len); } // namespace tensorflow diff --git a/tensorflow/python/framework/python_op_gen.i b/tensorflow/python/framework/python_op_gen.i new file mode 100644 index 0000000000..08f53f101b --- /dev/null +++ b/tensorflow/python/framework/python_op_gen.i @@ -0,0 +1,24 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +%include "tensorflow/python/platform/base.i" + +%{ +#include "tensorflow/python/framework/python_op_gen.h" +%} + +%ignoreall; +%unignore tensorflow::GetPythonWrappers; +%include "tensorflow/python/framework/python_op_gen.h" diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 2b65e483de..0f4eb744f1 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -171,6 +171,8 @@ class TensorFlowTestCase(googletest.TestCase): A Session object that should be used as a context manager to surround the graph building and execution code in a test case. """ + if self.id().endswith(".test_session"): + self.skipTest("Not a test.") def prepare_config(config): if config is None: config = config_pb2.ConfigProto() diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py index 0ea573932b..ab0676d9ec 100644 --- a/tensorflow/python/kernel_tests/concat_op_test.py +++ b/tensorflow/python/kernel_tests/concat_op_test.py @@ -364,5 +364,14 @@ class ConcatOpTest(tf.test.TestCase): err = tf.test.compute_gradient_error(xs, x_shapes, output, output_shape) self.assertLess(err, 1e-11) + def testConcatTuple(self): + c1 = np.random.rand(4, 4) + c2 = np.random.rand(4, 4) + with self.test_session(): + concat_list_t = tf.concat(0, [c1, c2]) + concat_tuple_t = tf.concat(0, (c1, c2)) + self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval()) + + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 58302a683d..6de4c905b1 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -330,6 +330,13 @@ class ControlFlowTest(tf.test.TestCase): result = exit_i.eval() self.assertAllEqual(10, result) + def testCondBool(self): + values = tf.constant(10) + fn1 = lambda: tf.add(values, 1) + fn2 = lambda: tf.sub(values, 1) + with self.assertRaisesRegexp(TypeError, "must not be a Python bool"): + _ = control_flow_ops.cond(False, fn1, fn2) + def testCondIndexedSlices(self): with self.test_session(): values = tf.constant(10) diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index a823250d51..8f2720f1cf 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -19,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math + import tensorflow.python.platform import numpy as np @@ -55,6 +57,11 @@ class UnaryOpTest(tf.test.TestCase): tf_cpu = y.eval() self.assertShapeEqual(np_ans, y) self.assertAllClose(np_ans, tf_cpu) + + # TODO(ebrevdo): add gradient for lgamma (digamma) and remove lgamma here. + if tf_func in (tf.lgamma,): + return # Return early + if x.dtype == np.float32: s = list(np.shape(x)) jacob_t, jacob_n = tf.test.compute_gradient(inx, @@ -94,6 +101,17 @@ class UnaryOpTest(tf.test.TestCase): def _sigmoid(self, x): return 1.0 / (1.0 + np.exp(-x)) + def _replace_domain_error_with_inf(self, fn): + def func(x): + try: + return fn(x) + except ValueError, e: + if "domain error" in e.message: + return np.inf * np.ones_like(x) + else: + raise e + return func + def testFloatBasic(self): x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32) y = (x + .5).astype(np.float32) # no zero @@ -113,6 +131,12 @@ class UnaryOpTest(tf.test.TestCase): self._compareBoth(y, np.sign, tf.sign) self._compareBoth(x, np.sin, tf.sin) self._compareBoth(x, np.cos, tf.cos) + self._compareBoth( + x, + np.vectorize(self._replace_domain_error_with_inf(math.lgamma)), + tf.lgamma) + self._compareBoth(x, np.vectorize(math.erf), tf.erf) + self._compareBoth(x, np.vectorize(math.erfc), tf.erfc) def testFloatTanhEdge(self): x = np.arange(40, 40 + 6).reshape(6).astype(np.float32) diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py index f02e16a4ae..db8d4ba5c4 100644 --- a/tensorflow/python/kernel_tests/fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/fifo_queue_test.py @@ -1124,6 +1124,33 @@ class FIFOQueueTest(tf.test.TestCase): thread.join() self.assertAllEqual(elem, results) + def testDtypes(self): + with self.test_session() as sess: + dtypes = [tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8, + tf.int64, tf.bool, tf.complex64] + shape = (32, 4, 128) + q = tf.FIFOQueue(32, dtypes, [shape[1:]] * len(dtypes)) + + input_tuple = [] + for dtype in dtypes: + np_dtype = dtype.as_numpy_dtype + np_array = np.random.randint(-10, 10, shape) + if dtype == tf.bool: + np_array = np_array > 0 + elif dtype == tf.complex64: + np_array = np.sqrt(np_array.astype(np_dtype)) + else: + np_array = np_array.astype(np_dtype) + input_tuple.append(np_array) + + q.enqueue_many(input_tuple).run() + + output_tuple_t = q.dequeue_many(32) + output_tuple = sess.run(output_tuple_t) + + for (input_elem, output_elem) in zip(input_tuple, output_tuple): + self.assertAllEqual(input_elem, output_elem) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/kernel_tests/gradient_checker_test.py b/tensorflow/python/kernel_tests/gradient_checker_test.py index 2ded0375a8..bcaaa8cc4e 100644 --- a/tensorflow/python/kernel_tests/gradient_checker_test.py +++ b/tensorflow/python/kernel_tests/gradient_checker_test.py @@ -27,6 +27,7 @@ import tensorflow as tf class GradientCheckerTest(tf.test.TestCase): def testAddSimple(self): + np.random.seed(1) # Fix seed to avoid flakiness with self.test_session(use_gpu=False): # a test case for Add operation size = (2, 3) @@ -40,6 +41,7 @@ class GradientCheckerTest(tf.test.TestCase): assert error < 1e-4 def testAddSimpleGPU(self): + np.random.seed(2) # Fix seed to avoid flakiness with self.test_session(use_gpu=True): # a test case for Add operation size = (2, 3) @@ -53,6 +55,7 @@ class GradientCheckerTest(tf.test.TestCase): assert error < 1e-4 def testAddCustomized(self): + np.random.seed(3) # Fix seed to avoid flakiness with self.test_session(): # a test case for Add operation size = (2, 3) @@ -74,6 +77,7 @@ class GradientCheckerTest(tf.test.TestCase): assert error < 1e-10 def testGather(self): + np.random.seed(4) # Fix seed to avoid flakiness with self.test_session(): p_shape = (4, 2) p_size = 8 @@ -89,6 +93,7 @@ class GradientCheckerTest(tf.test.TestCase): assert error < 1e-4 def testNestedGather(self): + np.random.seed(5) # Fix seed to avoid flakiness with self.test_session(): p_shape = (8, 2) p_size = 16 @@ -110,6 +115,9 @@ class GradientCheckerTest(tf.test.TestCase): # Gradient checker for MNIST. def BuildAndTestMiniMNIST(param_index, tag): + # Fix seed to avoid occasional flakiness + np.random.seed(6) + # Hyperparameters batch = 3 inputs = 16 diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py index 5a0ffce6b4..a470fb7274 100644 --- a/tensorflow/python/kernel_tests/parsing_ops_test.py +++ b/tensorflow/python/kernel_tests/parsing_ops_test.py @@ -642,6 +642,60 @@ class ParseSequenceExampleTest(tf.test.TestCase): "feature_list_dense_defaults": {"d": None}, }, expected_feat_list_values=expected_feature_list_output) + def testSequenceExampleWithSparseAndDenseFeatureLists(self): + feature_list_dense_keys = ["a"] + feature_list_dense_types = [tf.int64] + feature_list_dense_shapes = [(2,)] + + original = sequence_example(feature_lists=feature_lists({ + "a": feature_list([ + int64_feature([3, 4]), + int64_feature([1, 0])]), + "st_a": feature_list([ + float_feature([3.0, 4.0]), + float_feature([5.0]), + float_feature([])]), + "st_b": feature_list([ + bytes_feature([b"a"]), + bytes_feature([]), + bytes_feature([]), + bytes_feature([b"b", b"c"])])})) + + serialized = original.SerializeToString() + + expected_st_a = ( + np.array([[0, 0], [0, 1], [1, 0]], dtype=np.int64), # indices + np.array([3.0, 4.0, 5.0], dtype=np.float32), # values + np.array([3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2 + + expected_st_b = ( + np.array([[0, 0], [3, 0], [3, 1]], dtype=np.int64), # indices + np.array(["a", "b", "c"], dtype=np.str), # values + np.array([4, 2], dtype=np.int64)) # shape: num_time = 4, max_feat = 2 + + expected_st_c = ( + np.empty((0, 2), dtype=np.int64), # indices + np.empty((0,), dtype=np.int64), # values + np.array([0, 0], dtype=np.int64)) # shape: num_time = 0, max_feat = 0 + + expected_feature_list_output = { + "a": np.array([[3, 4], [1, 0]], dtype=np.int64), + "st_a": expected_st_a, + "st_b": expected_st_b, + "st_c": expected_st_c, + } + + self._test( + { + "debug_name": "in1", + "serialized": tf.convert_to_tensor(serialized), + "feature_list_dense_types": feature_list_dense_types, + "feature_list_dense_keys": feature_list_dense_keys, + "feature_list_dense_shapes": feature_list_dense_shapes, + "feature_list_sparse_keys": ["st_a", "st_b", "st_c"], + "feature_list_sparse_types": [tf.float32, tf.string, tf.int64] + }, expected_feat_list_values=expected_feature_list_output) + def testSequenceExampleListWithInconsistentDataFails(self): feature_list_dense_types = [tf.int64] feature_list_dense_shapes = [(2,)] @@ -687,6 +741,29 @@ class ParseSequenceExampleTest(tf.test.TestCase): expected_err_re=("Feature list: a, Index: 0. Data types don't match. " "Expected type: int64")) + def testSequenceExampleListWithWrongSparseDataTypeFails(self): + feature_list_sparse_types = [tf.int64] + + original = sequence_example(feature_lists=feature_lists({ + "a": feature_list([ + int64_feature([3, 4]), + int64_feature([1, 2, 3]), + float_feature([2])]) + })) + + serialized = original.SerializeToString() + + self._test( + { + "debug_name": "in1", + "serialized": tf.convert_to_tensor(serialized), + "feature_list_sparse_types": feature_list_sparse_types, + "feature_list_sparse_keys": ["a"] + }, + expected_err_re=( + "Name: in1, Feature List: a, Index: 2. Data types don't match. " + "Expected type: int64 Feature is: float_list")) + def testSequenceExampleListWithWrongShapeFails(self): feature_list_dense_types = [tf.int64] feature_list_dense_shapes = [(2,)] diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py new file mode 100644 index 0000000000..742402b8b7 --- /dev/null +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -0,0 +1,84 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for py_func op.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + +from tensorflow.python.ops import script_ops + + +class PyOpTest(tf.test.TestCase): + + def testBasic(self): + + def my_func(x, y): + return np.sinh(x) + np.cosh(y) + + # scalar + with self.test_session(): + x = tf.constant(1.0, tf.float32) + y = tf.constant(2.0, tf.float32) + z = tf.py_func(my_func, [x, y], [tf.float32]) + self.assertEqual(z[0].eval(), my_func(1.0, 2.0).astype(np.float32)) + + # array + with self.test_session(): + x = tf.constant([1.0, 2.0], tf.float64) + y = tf.constant([2.0, 3.0], tf.float64) + z = tf.py_func(my_func, [x, y], [tf.float64]) + self.assertAllEqual( + z[0].eval(), + my_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64)) + + # a bit exotic type (complex64) + with self.test_session(): + x = tf.constant(1+2j, tf.complex64) + y = tf.constant(3+4j, tf.complex64) + z, = tf.py_func(my_func, [x, y], [tf.complex64]) + self.assertAllClose(z.eval(), my_func(1+2j, 3+4j)) + + # a bit excotic function (rfft) + with self.test_session(): + x = tf.constant([1., 2., 3., 4.], tf.float32) + def rfft(x): + return np.fft.rfft(x).astype(np.complex64) + y, = tf.py_func(rfft, [x], [tf.complex64]) + self.assertAllClose(y.eval(), np.fft.rfft([1., 2., 3., 4.])) + + def testLarge(self): + with self.test_session() as sess: + x = tf.zeros([1000000], dtype=np.float32) + y = tf.py_func(lambda x: x + 1, [x], [tf.float32]) + z = tf.py_func(lambda x: x * 2, [x], [tf.float32]) + for _ in xrange(100): + sess.run([y[0].op, z[0].op]) + + def testCleanup(self): + for _ in range(1000): + g = tf.Graph() + with g.as_default(): + c = tf.constant([1.], tf.float32) + _ = tf.py_func(lambda x: x + 1, [c], [tf.float32]) + self.assertTrue(script_ops._py_funcs.size() < 100) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py index b1188d0672..2882182a03 100644 --- a/tensorflow/python/kernel_tests/reader_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_ops_test.py @@ -235,7 +235,7 @@ class TextLineReaderTest(tf.test.TestCase): def _LineText(self, f, l): return tf.compat.as_bytes("%d: %d" % (f, l)) - def _CreateFiles(self): + def _CreateFiles(self, crlf=False): filenames = [] for i in range(self._num_files): fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i) @@ -246,11 +246,10 @@ class TextLineReaderTest(tf.test.TestCase): # Always include a newline after the record unless it is # at the end of the file, in which case we include it sometimes. if j + 1 != self._num_lines or i == 0: - f.write(b"\n") + f.write(b"\r\n" if crlf else b"\n") return filenames - def testOneEpoch(self): - files = self._CreateFiles() + def _testOneEpoch(self, files): with self.test_session() as sess: reader = tf.TextLineReader(name="test_reader") queue = tf.FIFOQueue(99, [tf.string], shapes=()) @@ -268,6 +267,12 @@ class TextLineReaderTest(tf.test.TestCase): "\\(requested 1, current size 0\\)"): k, v = sess.run([key, value]) + def testOneEpochLF(self): + self._testOneEpoch(self._CreateFiles(crlf=False)) + + def testOneEpochCRLF(self): + self._testOneEpoch(self._CreateFiles(crlf=True)) + def testSkipHeaderLines(self): files = self._CreateFiles() with self.test_session() as sess: diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py index 7ff3851da7..3b79ae341b 100644 --- a/tensorflow/python/kernel_tests/reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/reduction_ops_test.py @@ -174,6 +174,28 @@ class SumReductionTest(tf.test.TestCase): def testGradient4(self): self._compareGradient([2, 3, 4, 2], [], None) + def testHighRank(self): + # Do a bunch of random high dimensional reductions + np.random.seed(42) + for _ in range(20): + rank = np.random.randint(4, 10 + 1) + axes, = np.nonzero(np.random.randint(2, size=rank)) + shape = tuple(np.random.randint(1, 3 + 1, size=rank)) + data = np.random.randint(1024, size=shape) + self._compareAll(data, axes) + # Check some particular axis patterns + for rank in 4, 7, 10: + shape = tuple(np.random.randint(1, 3 + 1, size=rank)) + data = np.random.randint(1024, size=shape) + for axes in ([], np.arange(rank), np.arange(0, rank, 2), + np.arange(1, rank, 2)): + self._compareAll(data, axes) + + def testExpand(self): + # Reduce an empty tensor to a nonempty tensor + x = np.zeros((5, 0)) + self._compareAll(x, [1]) + class MeanReductionTest(tf.test.TestCase): diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py index 81be48990b..38ba890c74 100644 --- a/tensorflow/python/kernel_tests/shape_ops_test.py +++ b/tensorflow/python/kernel_tests/shape_ops_test.py @@ -227,15 +227,23 @@ class TileTest(tf.test.TestCase): def testSimple(self): with self.test_session(): - inp = np.random.rand(4, 1).astype("f") - a = tf.constant([float(x) for x in inp.ravel(order="C")], - shape=[4, 1], dtype=tf.float32) + inp = np.random.rand(4, 1).astype(np.float32) + a = tf.constant(inp) tiled = tf.tile(a, [1, 4]) result = tiled.eval() self.assertEqual(result.shape, (4, 4)) self.assertEqual([4, 4], tiled.get_shape()) self.assertTrue((result == np.tile(inp, (1, 4))).all()) + def testEmpty(self): + with self.test_session(): + inp = np.random.rand(2, 3).astype(np.float32) + a = tf.constant(inp) + tiled = tf.tile(a, [5, 0]) + result = tiled.eval() + self.assertEqual(result.shape, (10, 0)) + self.assertEqual([10, 0], tiled.get_shape()) + def testTypes(self): types_to_test = { "bool": (tf.bool, bool), diff --git a/tensorflow/python/ops/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py index c6e91dcd71..d6ee16c8e2 100644 --- a/tensorflow/python/ops/sparse_ops_test.py +++ b/tensorflow/python/kernel_tests/sparse_ops_test.py @@ -19,7 +19,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +# pylint: disable=unused-import, g-bad-import-order import tensorflow.python.platform +# pylint: enable=unused-import, g-bad-import-order import numpy as np @@ -46,13 +48,17 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase): constant_op.constant(shape, dtypes.int64)) def _SparseTensor_2x3x4(self, dtype): + # Includes two entries with the form [1, 1, x] : 150. ind = np.array([ [0, 0, 1], - [0, 1, 0], [0, 1, 2], + [0, 1, 0], + [0, 1, 2], [1, 0, 3], - [1, 1, 1], [1, 1, 3], + [1, 1, 0], + [1, 1, 1], + [1, 1, 2], [1, 2, 2]]) - val = np.array([1, 10, 12, 103, 111, 113, 122]) + val = np.array([1, 10, 12, 103, 150, 149, 150, 122]) shape = np.array([2, 3, 4]) return ops.SparseTensor( constant_op.constant(ind, dtypes.int64), @@ -90,7 +96,8 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase): expected_output = np.zeros((2, 3, 200), dtype=np.bool) expected_trues = [(0, 0, 1), (0, 1, 10), (0, 1, 12), - (1, 0, 103), (1, 1, 111), (1, 1, 113), (1, 2, 122)] + (1, 0, 103), (1, 1, 149), (1, 1, 150), + (1, 2, 122)] for expected_true in expected_trues: expected_output[expected_true] = True diff --git a/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py b/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py index ee9a697a0b..6ea1e6d8eb 100644 --- a/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py +++ b/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py @@ -25,9 +25,11 @@ import tensorflow as tf def _SparseToDense(sparse_indices, output_size, sparse_values, - default_value): + default_value, validate_indices=True): return tf.sparse_to_dense(sparse_indices, output_size, - sparse_values, default_value) + sparse_values, + default_value=default_value, + validate_indices=validate_indices) class SparseToDenseTest(tf.test.TestCase): @@ -107,10 +109,24 @@ class SparseToDenseTest(tf.test.TestCase): def testBadDefault(self): with self.test_session(): - dense = _SparseToDense([1, 3], [5], [1, 2], [1, 2]) + dense = _SparseToDense([1, 3], [5], [1, 2], [0]) with self.assertRaisesOpError("default_value should be a scalar"): dense.eval() + def testInvalidIndicesWithWithoutValidation(self): + with self.test_session(): + dense = _SparseToDense( + sparse_indices=[[1], [1]], output_size=[5], + sparse_values=[-1.0, 1.0], default_value=0.0) + with self.assertRaisesOpError( + "not lexicographically sorted or containing repeats"): + dense.eval() + # Disable checks + dense_without_validation = _SparseToDense( + sparse_indices=[[1], [1]], output_size=[5], + sparse_values=[-1.0, 1.0], default_value=0.0, validate_indices=False) + dense_without_validation.eval() + def testShapeInferenceKnownShape(self): with self.test_session(use_gpu=False): indices = tf.placeholder(tf.int64) diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py index c6af05ff22..c769987a47 100644 --- a/tensorflow/python/kernel_tests/transpose_op_test.py +++ b/tensorflow/python/kernel_tests/transpose_op_test.py @@ -186,8 +186,8 @@ class TransposeTest(tf.test.TestCase): def testError(self): with self.assertRaises(ValueError): tf.transpose(np.arange(0., 30).reshape([2, 3, 5]), [[0, 1], [2, 3]]) - self._testError(np.arange(0., 2 ** 10).reshape([2] * 10), - np.arange(10), + self._testError(np.arange(0., 2 ** 11).reshape([2] * 11), + np.arange(11), "not implemented") with self.assertRaises(IndexError): tf.transpose(np.arange(0., 30).reshape([2, 3, 5]), [0, 1, 3]) diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc new file mode 100644 index 0000000000..87c64014bf --- /dev/null +++ b/tensorflow/python/lib/core/py_func.cc @@ -0,0 +1,338 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/python/lib/core/py_func.h" + +#include <Python.h> +#include "numpy/arrayobject.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace { + +static mutex mu; +static bool initialized GUARDED_BY(mu) = false; +static PyObject* py_trampoline GUARDED_BY(mu) = nullptr; + +// Returns the py_trampoline that is used to pass the control to the +// python runtime. +PyObject* GetPyTrampoline() { + mutex_lock l(mu); + return py_trampoline; +} + +// Module initialization (mainly import numpy) if needed. +void InitIfNeeded() { + mutex_lock l(mu); + if (!initialized) { + PyGILState_STATE py_threadstate; + py_threadstate = PyGILState_Ensure(); + import_array(); + PyGILState_Release(py_threadstate); + initialized = true; + } +} + +// Returns a single-thread threadpool used to execute python +// trampoline and the python function. It is single threaded because +// GIL is needed running the trampoline. +thread::ThreadPool* py_thread() { + static thread::ThreadPool* w = + new thread::ThreadPool(Env::Default(), "PyTrampoline", 1); + return w; +} + +// Returns the corresponding numpy dtype in 'np' for tf data type +// 'tf'. Returns an error if the type is not supported by this +// module. +Status TfDTypeToNpDType(const DataType& tf, int* np) { + switch (tf) { + case DT_FLOAT: + *np = NPY_FLOAT32; + break; + case DT_DOUBLE: + *np = NPY_FLOAT64; + break; + case DT_INT32: + *np = NPY_INT32; + break; + case DT_UINT8: + *np = NPY_UINT8; + break; + case DT_INT8: + *np = NPY_INT8; + break; + case DT_INT16: + *np = NPY_INT16; + break; + case DT_INT64: + *np = NPY_INT64; + break; + case DT_BOOL: + *np = NPY_BOOL; + break; + case DT_COMPLEX64: + *np = NPY_COMPLEX64; + break; + default: + return errors::Unimplemented("Unsupported tf type ", DataTypeString(tf)); + } + return Status::OK(); +} + +// Creates a numpy array in 'ret' and copies the content of tensor 't' +// into 'ret'. +Status ConvertTensorToNdarray(const Tensor& t, PyObject** ret) { + int typenum; + TF_RETURN_IF_ERROR(TfDTypeToNpDType(t.dtype(), &typenum)); + PyArray_Descr* descr = PyArray_DescrFromType(typenum); + CHECK(descr); + std::vector<npy_intp> dims; + for (int i = 0; i < t.dims(); ++i) { + dims.push_back(t.dim_size(i)); + } + PyObject* obj = PyArray_Empty(dims.size(), dims.data(), descr, 0); + if (obj == nullptr) { + return errors::Internal("Failed to allocate np array: ", + t.shape().ShortDebugString()); + } + PyArrayObject* np_array = reinterpret_cast<PyArrayObject*>(obj); + CHECK(DataTypeCanUseMemcpy(t.dtype())); + StringPiece p = t.tensor_data(); + memcpy(np_array->data, p.data(), p.size()); + *ret = PyArray_Return(np_array); + return Status::OK(); +} + +// A call to the registered python function. +struct PyCall { + // Passed to python runtime to call the python function registered + // with this "token". + string token; + + // Inputs and outputs of this function invokation. + std::vector<Tensor> ins; + std::vector<Tensor> out; +}; + +// Givens the 'call', prepares the token and inputs as a python tuple +// that is appropriate for calling the trampoline. +Status MakeArgTuple(PyCall* call, PyObject** tuple) { + int64 n = call->ins.size(); + PyObject* lst = PyList_New(n); + CHECK(lst); + for (int64 i = 0; i < n; ++i) { + const Tensor& t = call->ins[i]; + PyObject* a; + Status s = ConvertTensorToNdarray(t, &a); + if (!s.ok()) { + Py_DECREF(lst); + return s; + } + PyList_SetItem(lst, i, a); + } + *tuple = Py_BuildValue("(sN)", call->token.c_str(), lst); + CHECK(*tuple); + return Status::OK(); +} + +// Returns the corresponding tf dtype in 'tf' for numpy data type +// 'np'. Returns an error if the type is not supported by this +// module. +Status NpDTypeToTfDType(const int np, DataType* tf) { + switch (np) { + case NPY_FLOAT32: + *tf = DT_FLOAT; + break; + case NPY_FLOAT64: + *tf = DT_DOUBLE; + break; + case NPY_INT32: + *tf = DT_INT32; + break; + case NPY_UINT8: + *tf = DT_UINT8; + break; + case NPY_INT8: + *tf = DT_INT8; + break; + case NPY_INT16: + *tf = DT_INT16; + break; + case NPY_INT64: + *tf = DT_INT64; + break; + case NPY_BOOL: + *tf = DT_BOOL; + break; + case NPY_COMPLEX64: + *tf = DT_COMPLEX64; + break; + default: + return errors::Unimplemented("Unsupported numpy type ", np); + } + return Status::OK(); +} + +// Given an numpy ndarray object 'obj', creates a corresponding tf +// Tensor in '*ret'. +Status ConvertNdarrayToTensor(PyObject* obj, Tensor* ret) { + PyArrayObject* a = reinterpret_cast<PyArrayObject*>(obj); + DataType dtype; + TF_RETURN_IF_ERROR(NpDTypeToTfDType(PyArray_TYPE(a), &dtype)); + CHECK(DataTypeCanUseMemcpy(dtype)); + TensorShape shape; + for (int i = 0; i < PyArray_NDIM(a); ++i) { + shape.AddDim(PyArray_SHAPE(a)[i]); + } + Tensor t(dtype, shape); + StringPiece p = t.tensor_data(); + memcpy(const_cast<char*>(p.data()), a->data, p.size()); + *ret = t; + return Status::OK(); +} + +// Calls the registered py function through the trampoline. +Status DoCallPyFunc(PyCall* call) { + PyObject* trampoline = GetPyTrampoline(); + if (trampoline == nullptr) { + return errors::InvalidArgument( + "Missing py trampoline. Most likely, it is a link error."); + } + // Prepare the argument. + PyObject* args = nullptr; + TF_RETURN_IF_ERROR(MakeArgTuple(call, &args)); + CHECK(args); + + // Invokes the trampoline. + PyObject* result = PyEval_CallObject(trampoline, args); + Py_DECREF(args); + if (result == nullptr) { + return errors::Internal("Failed to run py callback ", call->token, + ": see error log."); + } + + // Process the return values and converts them to tf Tensors. + Status s; + if (PyList_Check(result)) { + // 'result' is a list. + call->out.clear(); + for (int i = 0; i < PyList_Size(result); ++i) { + Tensor t; + s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t); + if (!s.ok()) { + break; + } + call->out.push_back(t); + } + } else if (PyArray_Check(result)) { + // 'result' is a single ndarray. + Tensor t; + s = ConvertNdarrayToTensor(result, &t); + if (s.ok()) { + call->out.push_back(t); + } + } else { + // 'result' is a plain python scalar. We convert it to an numpy + // scalar then convert it to a Tensor. + PyObject* scalar = PyArray_FromScalar(result, nullptr); + if (scalar == nullptr) { + s = errors::InvalidArgument( + call->token, + " returns a value which can't be converted into numpy scalar."); + } else { + Tensor t; + s = ConvertNdarrayToTensor(scalar, &t); + if (s.ok()) { + call->out.push_back(t); + } + Py_DECREF(scalar); + } + } + Py_DECREF(result); + return s; +} + +// Calls the python function in a separate thread. Arranges to call +// done() when the python function returns. +void CallPyFunc(PyCall* call, std::function<void(Status)> done) { + InitIfNeeded(); + py_thread()->Schedule([call, done]() { + PyGILState_STATE py_threadstate; + py_threadstate = PyGILState_Ensure(); + Status s = DoCallPyFunc(call); + PyGILState_Release(py_threadstate); + done(s); + }); +} + +} // end namespace + +void InitializePyTrampoline(PyObject* trampoline) { + mutex_lock l(mu); + if (py_trampoline == nullptr) { + py_trampoline = trampoline; + Py_INCREF(py_trampoline); + } else { + LOG(WARNING) << "InitializeCallback should only be called once"; + } +} + +class PyFuncOp : public AsyncOpKernel { + public: + explicit PyFuncOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_)); + } + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + PyCall* call = new PyCall; + call->token = token_; + for (int i = 0; i < ctx->num_inputs(); ++i) { + call->ins.push_back(ctx->input(i)); + } + CallPyFunc(call, [this, ctx, call, done](Status s) { + std::unique_ptr<PyCall> delete_me(call); + OP_REQUIRES_OK_ASYNC(ctx, s, done); + OP_REQUIRES_ASYNC( + ctx, call->out.size() == ctx->num_outputs(), + errors::InvalidArgument(token_, " returns ", call->out.size(), + " values, but expects to see ", + ctx->num_outputs(), " values."), + done); + for (int i = 0; i < call->out.size(); ++i) { + const auto& t = call->out[i]; + OP_REQUIRES_ASYNC( + ctx, t.dtype() == output_type(i), + errors::InvalidArgument(i, "-th value returned by ", token_, " is ", + DataTypeString(t.dtype()), ", but expects ", + DataTypeString(output_type(i))), + done); + ctx->set_output(i, t); + } + done(); + }); + } + + private: + string token_; + + TF_DISALLOW_COPY_AND_ASSIGN(PyFuncOp); +}; +REGISTER_KERNEL_BUILDER(Name("PyFunc").Device(DEVICE_CPU), PyFuncOp); + +} // end namespace tensorflow diff --git a/tensorflow/python/lib/core/py_func.h b/tensorflow/python/lib/core/py_func.h new file mode 100644 index 0000000000..2de52fb492 --- /dev/null +++ b/tensorflow/python/lib/core/py_func.h @@ -0,0 +1,47 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_FUNC_H_ +#define TENSORFLOW_PYTHON_LIB_CORE_PY_FUNC_H_ + +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" + +#include <Python.h> + +namespace tensorflow { + +// Called by py code on initialization. +// +// "trampoline" must represent a python function which has the +// following signature: +// (string, list(ndarray)) -> ndarray | list(ndarray) | python scalar +// +// The trampoline takes two arguments, the first is a string token +// used by the python frontend's dispatching logic; the second is a +// list of numpy ndarrays. +// +// The trampoline can return a single numpy ndarray, a list of numpy +// ndarrays, or a simply python scalar. The C++ runtime converts them, +// if supported, back to Tensor objects. +// +// This is called by script_ops.py during its module initialization. +// +// TODO(zhifengc): Support distributed runtime. +void InitializePyTrampoline(PyObject* trampoline); + +} // end namespace tensorflow + +#endif // TENSORFLOW_PYTHON_LIB_CORE_PY_FUNC_H_ diff --git a/tensorflow/python/lib/core/py_func.i b/tensorflow/python/lib/core/py_func.i new file mode 100644 index 0000000000..c85bbc1c55 --- /dev/null +++ b/tensorflow/python/lib/core/py_func.i @@ -0,0 +1,29 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +%include "tensorflow/python/platform/base.i" + +%{ +#include "tensorflow/python/lib/core/py_func.h" +%} + +%ignoreall + +%unignore tensorflow; +%unignore tensorflow::InitializePyTrampoline; + +%include "tensorflow/python/lib/core/py_func.h" + +%unignoreall diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index e3526d32c4..d3b62f03c6 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -23,7 +23,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op -from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops @@ -231,7 +230,22 @@ ops.NoGradient("Size") def _TileGrad(op, grad): """Sum reduces grad along the tiled dimensions.""" assert isinstance(grad, ops.Tensor) - return [gen_array_ops._tile_grad(grad, op.inputs[1]), None] + input_shape = array_ops.shape(op.inputs[0]) + # We interleave multiples and input_shape to get split_shape, + # reshape grad to split_shape, and reduce along all even + # dimensions (the tiled dimensions) to get the result + # with shape input_shape. For example + # input_shape = [20, 30, 40] + # multiples = [2, 3, 4] + # split_shape = [2, 20, 3, 30, 4, 40] + # axes = [0, 2, 4] + split_shape = array_ops.reshape(array_ops.transpose( + array_ops.pack([op.inputs[1], input_shape])), [-1]) + axes = math_ops.range(0, array_ops.size(split_shape), 2) + input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes) + # Fix shape inference + input_grad.set_shape(op.inputs[0].get_shape()) + return [input_grad, None] ops.NoGradient("TileGrad") diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 613bdf49f0..0f36ed7e41 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -299,7 +299,7 @@ def concat(concat_dim, values, name="concat"): Returns: A `Tensor` resulting from concatenation of the input tensors. """ - if not isinstance(values, (list)): + if not isinstance(values, (list, tuple)): values = [values] # TODO(mrry): Change to return values? if len(values) == 1: # Degenerate case of one tensor. diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index ce78458515..f3d17aa12d 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -69,6 +69,7 @@ from __future__ import print_function import six from six.moves import xrange # pylint: disable=redefined-builtin + from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -82,6 +83,7 @@ from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops # pylint: disable=wildcard-import,undefined-variable from tensorflow.python.ops.gen_control_flow_ops import * +from tensorflow.python.platform import logging # We override the 'tuple' for a control flow op, so we keep python's @@ -630,6 +632,8 @@ def cond(pred, fn1, fn2, name=None): raise TypeError("fn2 must be callable.") # Add the Switch to the graph. + if isinstance(pred, bool): + raise TypeError("pred must not be a Python bool") p_2, p_1 = switch(pred, pred) pivot_1 = array_ops.identity(p_1, name="switch_t") pivot_2 = array_ops.identity(p_2, name="switch_f") @@ -1172,7 +1176,7 @@ def with_dependencies(dependencies, output_tensor, name=None): TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`. """ with ops.op_scope(dependencies + [output_tensor], name, - "control_dependency") as name: + "control_dependency") as name: with ops.device(output_tensor.device or ops.get_default_graph().get_default_device()): with ops.control_dependencies(dependencies): @@ -1237,6 +1241,7 @@ def group(*inputs, **kwargs): # 2-level tree. The root node is the returned NoOp node. # deps contains 1 NoOp node for each device. deps = [] + def device_key(dev): """A sort key that allows None to be compared to strings.""" return "" if dev is None else dev @@ -1244,6 +1249,7 @@ def group(*inputs, **kwargs): deps.append(_GroupControlDeps(dev, ops_on_device[dev])) return _GroupControlDeps(None, deps, name=name) + def tuple(tensors, name=None, control_inputs=None): """Group tensors together. diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index 5afc8a779b..33d4ac2a0b 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -308,6 +308,22 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase): c_dense = math_ops.mul(c_sparse, 1.0) self.assertAllClose(np_val, c_dense.eval()) + def testIndexedSlicesToTensorList(self): + with self.test_session(): + numpy_list = [] + dense_list = [] + sparse_list = [] + for _ in range(3): + np_val = np.random.rand(4, 4, 4, 4).astype(np.float32) + c = constant_op.constant(np_val) + c_sparse = math_ops._as_indexed_slices(c) + numpy_list.append(np_val) + dense_list.append(c) + sparse_list.append(c_sparse) + packed_dense = array_ops.pack(dense_list) + packed_sparse = array_ops.pack(sparse_list) + self.assertAllClose(packed_dense.eval(), packed_sparse.eval()) + def testInt64Indices(self): with self.test_session(): np_val = np.random.rand(4, 4, 4, 4).astype(np.float32) diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 55b20a1d10..aa5a2edf86 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -223,6 +225,28 @@ def _TanhGrad(op, grad): return grad * (1 - math_ops.square(y)) +@ops.RegisterGradient("Erf") +def _ErfGrad(op, grad): + """Returns grad * 2/sqrt(pi) * exp(-x**2).""" + x = op.inputs[0] + two_over_root_pi = constant_op.constant(2 / np.sqrt(np.pi), dtype=grad.dtype) + return grad * two_over_root_pi * math_ops.exp(-math_ops.square(x)) + + +@ops.RegisterGradient("Erfc") +def _ErfcGrad(op, grad): + """Returns -grad * 2/sqrt(pi) * exp(-x**2).""" + x = op.inputs[0] + two_over_root_pi = constant_op.constant(2 / np.sqrt(np.pi), dtype=grad.dtype) + return -grad * two_over_root_pi * math_ops.exp(-math_ops.square(x)) + + +@ops.RegisterGradient("Lgamma") +def _LgammaGrad(op, grad): # pylint: disable=unused-argument + # TODO(ebrevdo): implement digamma + raise NotImplementedError("grad(Lgamma) == Digamma is not implemented") + + @ops.RegisterGradient("Sigmoid") def _SigmoidGrad(op, grad): """Returns grad * sigmoid(x) * (1 - sigmoid(x)).""" diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index ec382da9b2..fa12adf8ce 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -49,6 +49,9 @@ mathematical functions to your graph. @@minimum @@cos @@sin +@@lgamma +@@erf +@@erfc ## Matrix Math Functions @@ -1097,6 +1100,57 @@ def tanh(x, name=None): return gen_math_ops._tanh(x, name=name) +def lgamma(x, name=None): + """Computes `ln(|gamma(x)|)` element-wise. + + Args: + x: A Tensor with type `float`, `double`, `int32`, `int64`, + or `qint32`. + name: A name for the operation (optional). + + Returns: + A Tensor with the same type as `x` if `x.dtype != qint32` otherwise + the return type is `quint8`. + """ + with ops.op_scope([x], name, "Lgamma") as name: + x = ops.convert_to_tensor(x, name="x") + return gen_math_ops._lgamma(x, name=name) + + +def erf(x, name=None): + """Computes Gauss error function of `x` element-wise. + + Args: + x: A Tensor with type `float`, `double`, `int32`, `int64`, + or `qint32`. + name: A name for the operation (optional). + + Returns: + A Tensor with the same type as `x` if `x.dtype != qint32` otherwise + the return type is `quint8`. + """ + with ops.op_scope([x], name, "Erf") as name: + x = ops.convert_to_tensor(x, name="x") + return gen_math_ops._erf(x, name=name) + + +def erfc(x, name=None): + """Computes complementary error function of `x` element-wise. + + Args: + x: A Tensor with type `float`, `double`, `int32`, `int64`, + or `qint32`. + name: A name for the operation (optional). + + Returns: + A Tensor with the same type as `x` if `x.dtype != qint32` otherwise + the return type is `quint8`. + """ + with ops.op_scope([x], name, "Erfc") as name: + x = ops.convert_to_tensor(x, name="x") + return gen_math_ops._erfc(x, name=name) + + ops.RegisterShape("Abs")(common_shapes.unchanged_shape) ops.RegisterShape("Ceil")(common_shapes.unchanged_shape) ops.RegisterShape("Conj")(common_shapes.unchanged_shape) @@ -1119,6 +1173,9 @@ ops.RegisterShape("Sqrt")(common_shapes.unchanged_shape) ops.RegisterShape("Square")(common_shapes.unchanged_shape) ops.RegisterShape("Sigmoid")(common_shapes.unchanged_shape) ops.RegisterShape("Tanh")(common_shapes.unchanged_shape) +ops.RegisterShape("Lgamma")(common_shapes.unchanged_shape) +ops.RegisterShape("Erf")(common_shapes.unchanged_shape) +ops.RegisterShape("Erfc")(common_shapes.unchanged_shape) ops.RegisterShape("Cast")(common_shapes.unchanged_shape) ops.RegisterShape("ComplexAbs")(common_shapes.unchanged_shape) ops.RegisterShape("FFT2D")(common_shapes.unchanged_shape) diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py index 72adf9e498..6fecea8666 100644 --- a/tensorflow/python/ops/nn.py +++ b/tensorflow/python/ops/nn.py @@ -686,7 +686,8 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled, if sampled_logits.dtype != acc_weights.dtype: acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype) sampled_logits += sparse_ops.sparse_to_dense( - sparse_indices, sampled_logits_shape, acc_weights, 0.0) + sparse_indices, sampled_logits_shape, acc_weights, + default_value=0.0, validate_indices=False) if subtract_log_q: # Subtract log of Q(l), prior probability that l appears in sampled. diff --git a/tensorflow/python/ops/op_def_library.py b/tensorflow/python/ops/op_def_library.py index 149bfe712a..94d874f067 100644 --- a/tensorflow/python/ops/op_def_library.py +++ b/tensorflow/python/ops/op_def_library.py @@ -376,16 +376,14 @@ class OpDefLibrary(object): try: if not input_arg.is_ref and dtype: dtype = dtypes.as_dtype(dtype).base_dtype - values = ops.convert_n_to_tensor_or_indexed_slices( - values, name=input_arg.name, - dtype=dtype if dtype else None, + values = ops.convert_n_to_tensor( + values, name=input_arg.name, dtype=dtype if dtype else None, as_ref=input_arg.is_ref) except (TypeError, ValueError): assert dtype is not None, "Should not fail if dtype is None" assert input_arg.number_attr, "Should be number_attr case" # What types does the conversion function think values have? - values = ops.convert_n_to_tensor_or_indexed_slices( - values, as_ref=input_arg.is_ref) + values = ops.convert_n_to_tensor(values, as_ref=input_arg.is_ref) observed = ", ".join(v.dtype.base_dtype.name for v in values) prefix = ( @@ -659,8 +657,7 @@ class OpDefLibrary(object): input_types=input_types, attrs=attr_protos, op_def=op_def) outputs = op.outputs - return _Restructure(ops.convert_n_to_tensor_or_indexed_slices(outputs), - output_structure) + return _Restructure(ops.convert_n_to_tensor(outputs), output_structure) else: return g.create_op(op_type_name, inputs, output_types, name=scope, input_types=input_types, attrs=attr_protos, diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py index 69cd7fbc56..c9cbfb1d7d 100644 --- a/tensorflow/python/ops/parsing_ops.py +++ b/tensorflow/python/ops/parsing_ops.py @@ -404,6 +404,8 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name context_dense_types=None, context_dense_defaults=None, context_dense_shapes=None, + feature_list_sparse_keys=None, + feature_list_sparse_types=None, feature_list_dense_keys=None, feature_list_dense_types=None, feature_list_dense_shapes=None, @@ -461,6 +463,12 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name map will be treated as empty (zero length) if not found in the `FeatureList` map. + The key `feature_list_sparse_keys[j]` is mapped to a `SparseTensor` of type + `feature_list_sparse_types[j]`. This `SparseTensor` represents a ragged + vector. Its indices are `[time, index]`, where `time` is the FeatureList + entry `index` is the value's index in the list of values associated with that + time. + `debug_name` may contain a descriptive name for the corresponding serialized proto. This may be useful for debugging purposes, but it has no effect on the output. If not `None`, `debug_name` must be a scalar. @@ -485,6 +493,12 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name The shape of the data for each context_dense feature referenced by `context_dense_keys`. Required for any input tensors identified by `context_dense_keys` whose shapes are anything other than `[]` or `[1]`. + feature_list_sparse_keys: A list of string keys in the `SequenceExample`'s + feature_lists. The results for these keys will be returned as + `SparseTensor` objects. + feature_list_sparse_types: A list of `DTypes`, same length as `sparse_keys`. + Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), + and `tf.string` (`BytesList`) are supported. feature_list_dense_keys: A list of string keys in the `SequenceExample`'s features_lists. The results for these keys will be returned as `Tensor`s. feature_list_dense_types: A list of `DTypes`, same length as @@ -528,6 +542,10 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name context_dense_shapes = ( [[]] * len(context_dense_keys) if context_dense_shapes is None else context_dense_shapes) + feature_list_sparse_keys = ( + [] if feature_list_sparse_keys is None else feature_list_sparse_keys) + feature_list_sparse_types = ( + [] if feature_list_sparse_types is None else feature_list_sparse_types) feature_list_dense_keys = ( [] if feature_list_dense_keys is None else feature_list_dense_keys) feature_list_dense_types = ( @@ -545,6 +563,7 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name num_context_dense = len(context_dense_keys) num_feature_list_dense = len(feature_list_dense_keys) num_context_sparse = len(context_sparse_keys) + num_feature_list_sparse = len(feature_list_sparse_keys) if len(context_dense_shapes) != num_context_dense: raise ValueError( @@ -567,15 +586,28 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name raise ValueError( "len(context_sparse_types) != len(context_sparse_keys): %d vs. %d" % (len(context_sparse_types), num_context_sparse)) - if num_context_dense + num_context_sparse + num_feature_list_dense == 0: + if len(feature_list_sparse_types) != num_feature_list_sparse: + raise ValueError( + "len(feature_list_sparse_types) != len(feature_list_sparse_keys): " + "%d vs. %d" + % (len(feature_list_sparse_types), num_feature_list_sparse)) + if (num_context_dense + num_context_sparse + + num_feature_list_dense + num_feature_list_sparse) == 0: raise ValueError( "Must provide at least one context_sparse key, context_dense key, " - "or feature_list_dense key") + ", feature_list_sparse key, or feature_list_dense key") if not set(context_dense_keys).isdisjoint(set(context_sparse_keys)): raise ValueError( - "Context_Dense and context_sparse keys must not intersect; " + "context_dense and context_sparse keys must not intersect; " "intersection: %s" % set(context_dense_keys).intersection(set(context_sparse_keys))) + if not set(feature_list_dense_keys).isdisjoint( + set(feature_list_sparse_keys)): + raise ValueError( + "feature_list_dense and feature_list_sparse keys must not intersect; " + "intersection: %s" % + set(feature_list_dense_keys).intersection( + set(feature_list_sparse_keys))) if not isinstance(feature_list_dense_defaults, dict): raise TypeError("feature_list_dense_defaults must be a dict") for k, v in feature_list_dense_defaults.items(): @@ -613,6 +645,8 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name context_sparse_types=context_sparse_types, context_dense_keys=context_dense_keys, context_dense_shapes=context_dense_shapes, + feature_list_sparse_keys=feature_list_sparse_keys, + feature_list_sparse_types=feature_list_sparse_types, feature_list_dense_keys=feature_list_dense_keys, feature_list_dense_types=feature_list_dense_types, feature_list_dense_shapes=feature_list_dense_shapes, @@ -622,7 +656,8 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name (context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, - feature_list_dense_values) = outputs + feature_list_sparse_indices, feature_list_sparse_values, + feature_list_sparse_shapes, feature_list_dense_values) = outputs context_sparse_tensors = [ ops.SparseTensor(ix, val, shape) for (ix, val, shape) @@ -630,12 +665,18 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name context_sparse_values, context_sparse_shapes)] + feature_list_sparse_tensors = [ + ops.SparseTensor(ix, val, shape) for (ix, val, shape) + in zip(feature_list_sparse_indices, + feature_list_sparse_values, + feature_list_sparse_shapes)] + context_output = dict( zip(context_sparse_keys + context_dense_keys, context_sparse_tensors + context_dense_values)) feature_list_output = dict( - zip(feature_list_dense_keys, - feature_list_dense_values)) + zip(feature_list_sparse_keys + feature_list_dense_keys, + feature_list_sparse_tensors + feature_list_dense_values)) return (context_output, feature_list_output) @@ -651,6 +692,7 @@ def _ParseSingleSequenceExampleShape(op): num_context_dense = op.get_attr("Ncontext_dense") num_feature_list_dense = op.get_attr("Nfeature_list_dense") context_dense_shapes = op.get_attr("context_dense_shapes") + num_feature_list_sparse = op.get_attr("Nfeature_list_sparse") feature_list_dense_shapes = op.get_attr("feature_list_dense_shapes") context_sparse_index_shapes = [ tensor_shape.matrix(None, 1) for _ in range(num_context_sparse)] @@ -661,6 +703,12 @@ def _ParseSingleSequenceExampleShape(op): context_dense_shapes = [ tensor_shape.TensorShape(dense_shape) for dense_shape in context_dense_shapes] + feature_list_sparse_index_shapes = [ + tensor_shape.matrix(None, 2) for _ in range(num_feature_list_sparse)] + feature_list_sparse_value_shapes = [ + tensor_shape.vector(None) for _ in range(num_feature_list_sparse)] + feature_list_sparse_shape_shapes = [ + tensor_shape.vector(2) for _ in range(num_feature_list_sparse)] feature_list_dense_shapes = [ tensor_shape.vector(None).concatenate(dense_shape) for dense_shape in feature_list_dense_shapes] @@ -668,7 +716,8 @@ def _ParseSingleSequenceExampleShape(op): assert num_feature_list_dense == len(feature_list_dense_shapes) return (context_sparse_index_shapes + context_sparse_value_shapes + context_sparse_shape_shapes + context_dense_shapes + - feature_list_dense_shapes) + feature_list_sparse_index_shapes + feature_list_sparse_value_shapes + + feature_list_sparse_shape_shapes + feature_list_dense_shapes) ops.RegisterShape("StringToNumber")( diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py new file mode 100644 index 0000000000..caf376987f --- /dev/null +++ b/tensorflow/python/ops/script_ops.py @@ -0,0 +1,135 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""## Script Language Operators. + +TensorFlow provides allows you to wrap python/numpy functions as +TensorFlow operators. + +""" + +# pylint: disable=g-bad-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.framework import ops +from tensorflow.python.ops import common_shapes +from tensorflow.python.ops import gen_script_ops + + +class FuncRegistry(object): + """A helper class to keep track of registered py functions. + + FuncRegistry keeps a map from unique tokens (string) to python + functions, which takes numpy arrays and outputs numpy arrays. + """ + + def __init__(self): + self._unique_id = 0 + self._funcs = {} + + def insert(self, func): + """Registers `func` and returns a unique token for this entry.""" + token = self._next_unique_token() + self._funcs[token] = func + return token + + def remove(self, token): + """Removes the registered function corresponding to `token`.""" + self._funcs.pop(token, None) + + def __call__(self, token, args): + """Calls the registered function for `token` with args.""" + func = self._funcs[token] + if func is None: + raise ValueError("callback %s is not found" % token) + return func(*args) + + def size(self): + """Returns how many functions are currently registered.""" + return len(self._funcs) + + def _next_unique_token(self): + """Returns a unique token.""" + uid = self._unique_id + self._unique_id += 1 + return "pyfunc_%d" % uid + +# Global registry for py functions. +_py_funcs = FuncRegistry() + +pywrap_tensorflow.InitializePyTrampoline(_py_funcs) + + +class CleanupFunc(object): + """A helper class to remove a registered function from _py_funcs.""" + + def __init__(self, token): + self._token = token + + def __del__(self): + _py_funcs.remove(self._token) + + +def py_func(func, inp, Tout, name=None): + """Wraps a python function and uses it as a tensorflow op. + + Given a python function `func`, which takes numpy arrays as its + inputs and returns numpy arrays as its outputs. E.g., + + def my_func(x): + return np.sinh(x) + inp = tf.placeholder(..., tf.float32) + y = py_func(my_func, [inp], [tf.float32]) + + The above snippet constructs a tf graph which invokes a numpy + sinh(x) as an op in the graph. + + Args: + func: A python function. + inp: A list of `Tensor`. + Tout: A list of tensorflow data types indicating what `func` + returns. + name: A name for the operation (optional). + + Returns: + A list of `Tensor` which `func` computes. + """ + token = _py_funcs.insert(func) + # We tie the registered function's life-time with the current + # default graph. I.e., when the current graph is destroyed, we + # should remove its py funcs. + cleanup = CleanupFunc(token) + g = ops.get_default_graph() + # pylint: disable=protected-access + # + # TODO(zhifengc): Consider adding a Graph method to collect + # `cleanup` objects in one of its member. + if not hasattr(g, "_cleanup_py_funcs_used_in_graph"): + g._cleanup_py_funcs_used_in_graph = [] + + # When g is destroyed, elements in _cleanup_py_funcs_used_in_graph + # will be destroyed and their __del__ will remove the 'token' from + # the funcs registry. + g._cleanup_py_funcs_used_in_graph.append(cleanup) + + return gen_script_ops._py_func(input=inp, token=token, Tout=Tout, name=name) + # pylint: enable=protected-access + + +ops.RegisterShape("PyFunc")(common_shapes.unknown_shape) + +ops.NoGradient("PyFunc") diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 99e7e708f1..1a7ce8e40f 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -310,6 +310,7 @@ def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0, + validate_indices=True, name=None): """Converts a sparse representation into a dense tensor. @@ -329,6 +330,10 @@ def sparse_to_dense(sparse_indices, All other values in `dense` are set to `default_value`. If `sparse_values` is a scalar, all sparse indices are set to this single value. + Indices should be sorted in lexicographic order, and indices must not + contain any repeats. If `validate_indices` is True, these properties + are checked during execution. + Args: sparse_indices: A 0-D, 1-D, or 2-D `Tensor` of type `int32` or `int64`. `sparse_indices[i]` contains the complete index where `sparse_values[i]` @@ -339,6 +344,8 @@ def sparse_to_dense(sparse_indices, `sparse_indices`, or a scalar value to be used for all sparse indices. default_value: A 0-D `Tensor` of the same type as `sparse_values`. Value to set for indices not specified in `sparse_indices`. Defaults to zero. + validate_indices: A boolean value. If True, indices are checked to make + sure they are sorted in lexicographic order and that there are no repeats. name: A name for the operation (optional). Returns: @@ -348,11 +355,15 @@ def sparse_to_dense(sparse_indices, return gen_sparse_ops._sparse_to_dense(sparse_indices, output_shape, sparse_values, - default_value, + default_value=default_value, + validate_indices=validate_indices, name=name) -def sparse_tensor_to_dense(sp_input, default_value=0, name=None): +def sparse_tensor_to_dense(sp_input, + default_value=0, + validate_indices=True, + name=None): """Converts a `SparseTensor` into a dense tensor. This op is a convenience wrapper around `sparse_to_dense` for `SparseTensor`s. @@ -370,10 +381,15 @@ def sparse_tensor_to_dense(sp_input, default_value=0, name=None): [x x x x x] [c x x x x]] + Indices must be without repeats. This is only + tested if validate_indices is True. + Args: sp_input: The input `SparseTensor`. default_value: Scalar value to set for indices not specified in `sp_input`. Defaults to zero. + validate_indices: A boolean value. If `True`, indices are checked to make + sure they are sorted in lexicographic order and that there are no repeats. name: A name prefix for the returned tensors (optional). Returns: @@ -390,7 +406,8 @@ def sparse_tensor_to_dense(sp_input, default_value=0, name=None): return sparse_to_dense(sp_input.indices, sp_input.shape, sp_input.values, - default_value, + default_value=default_value, + validate_indices=validate_indices, name=name) @@ -410,15 +427,18 @@ def sparse_to_indicator(sp_input, vocab_size, name=None): [0, 0, 0]: 0 [0, 1, 0]: 10 [1, 0, 3]: 103 - [1, 1, 2]: 112 - [1, 1, 3]: 113 + [1, 1, 2]: 150 + [1, 1, 3]: 149 + [1, 1, 4]: 150 [1, 2, 1]: 121 and `vocab_size = 200`, then the output will be a `[2, 3, 200]` dense bool tensor with False everywhere except at positions - (0, 0, 0), (0, 1, 10), (1, 0, 103), (1, 1, 112), (1, 1, 113), (1, 2, 121). + (0, 0, 0), (0, 1, 10), (1, 0, 103), (1, 1, 149), (1, 1, 150), + (1, 2, 121). + Note that repeats are allowed in the input SparseTensor. This op is useful for converting `SparseTensor`s into dense formats for compatibility with ops that expect dense tensors. @@ -460,7 +480,10 @@ def sparse_to_indicator(sp_input, vocab_size, name=None): sp_new = ops.SparseTensor(new_indices, new_values, new_shape) - return sparse_tensor_to_dense(sp_new, False, name=name) + # validate_indices may be False because we allow duplicates in new_indices: + # repeated indices are allowed when creating an indicator matrix. + return sparse_tensor_to_dense( + sp_new, default_value=False, validate_indices=False, name=name) def sparse_retain(sp_input, to_retain): diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index 2075e3c913..e2180737df 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -42,6 +42,7 @@ from tensorflow.python.ops.math_ops import * from tensorflow.python.ops.numerics import * from tensorflow.python.ops.parsing_ops import * from tensorflow.python.ops.random_ops import * +from tensorflow.python.ops.script_ops import py_func from tensorflow.python.ops.sparse_ops import * from tensorflow.python.ops.state_ops import assign from tensorflow.python.ops.state_ops import assign_add diff --git a/tensorflow/python/ops/summary_ops.py b/tensorflow/python/ops/summary_ops.py index 800ab7bc7e..dc6d2b17f2 100644 --- a/tensorflow/python/ops/summary_ops.py +++ b/tensorflow/python/ops/summary_ops.py @@ -165,8 +165,8 @@ def scalar_summary(tags, values, collections=None, name=None): summary has a summary value for each tag-value pair in `tags` and `values`. Args: - tags: A 1-D `string` `Tensor`. Tags for the summaries. - values: A 1-D `float32` or `float64` Tensor. Values for the summaries. + tags: A `string` `Tensor`. Tags for the summaries. + values: A real numeric Tensor. Values for the summaries. collections: Optional list of graph collections keys. The new summary op is added to these collections. Defaults to `[GraphKeys.SUMMARIES]`. name: A name for the operation (optional). diff --git a/tensorflow/python/platform/default/_flags.py b/tensorflow/python/platform/default/_flags.py index d7ae189c21..4e84623c79 100644 --- a/tensorflow/python/platform/default/_flags.py +++ b/tensorflow/python/platform/default/_flags.py @@ -32,7 +32,7 @@ class _FlagValues(object): self.__dict__['__parsed'] = False def _parse_flags(self): - result = _global_parser.parse_args() + result, _ = _global_parser.parse_known_args() for flag_name, val in vars(result).items(): self.__dict__['__flags'][flag_name] = val self.__dict__['__parsed'] = True diff --git a/tensorflow/python/platform/default/flags_test.py b/tensorflow/python/platform/default/flags_test.py index 3868576c2f..e6cd57d5a9 100644 --- a/tensorflow/python/platform/default/flags_test.py +++ b/tensorflow/python/platform/default/flags_test.py @@ -86,7 +86,8 @@ class FlagsTest(googletest.TestCase): if __name__ == "__main__": # Test command lines sys.argv.extend(["--bool_a", "--nobool_negation", "--bool_c=True", - "--bool_d=False", "--bool_e=gibberish"]) + "--bool_d=False", "--bool_e=gibberish", "--unknown_flag", + "and_argument"]) # googletest.main() tries to interpret the above flags, so use the # direct functions instead. diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index ce9770b3f4..65ea4d2e17 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -19,6 +19,7 @@ limitations under the License. %include "tensorflow/python/util/port.i" +%include "tensorflow/python/lib/core/py_func.i" %include "tensorflow/python/lib/core/status.i" %include "tensorflow/python/lib/core/status_helper.i" @@ -27,3 +28,5 @@ limitations under the License. %include "tensorflow/python/client/events_writer.i" %include "tensorflow/python/client/tf_session.i" + +%include "tensorflow/python/framework/python_op_gen.i" diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py index efd6f2a807..e7510fe325 100644 --- a/tensorflow/python/training/coordinator.py +++ b/tensorflow/python/training/coordinator.py @@ -256,3 +256,97 @@ class Coordinator(object): elif stragglers: raise RuntimeError("Coordinator stopped with threads still running: %s", " ".join(stragglers)) + + +# Threads for the standard services. +class LooperThread(threading.Thread): + """A thread that runs code repeatedly, optionally on a timer. + + This thread class is intended to be used with a `Coordinator`. It repeatedly + runs code specified either as `target` and `args` or by the `run_loop()` + method. + + Before each run the thread checks if the coordinator has requested stop. In + that case the looper thread terminates immediately. + + If the code being run raises an exception, that exception is reported to the + coordinator and the thread terminates. The coordinator will then request all + the other threads it coordinates to stop. + + You typically pass looper threads to the supervisor `Join()` method. + """ + + def __init__(self, coord, timer_interval_secs, target=None, args=None): + """Create a LooperThread. + + Args: + coord: a Coordinator. + timer_interval_secs: Time boundaries at which to call Run(), or None + if it should be called back to back. + target: Optional callable object that will be executed in the thread. + args: Optional arguments to pass to `target` when calling it. + + Raises: + ValueError: If one of the arguments is invalid. + """ + if not isinstance(coord, Coordinator): + raise ValueError("'coord' argument must be a Coordinator: %s" % coord) + super(LooperThread, self).__init__() + self.daemon = True + self._coord = coord + self._timer_interval_secs = timer_interval_secs + self._target = target + if self._target: + if args is None: + self._args = () + else: + self._args = args + elif args: + raise ValueError("'args' argument require that you also pass 'target'") + + @staticmethod + def loop(coord, timer_interval_secs, target, args=None): + """Start a LooperThread that calls a function periodically. + + If `timer_interval_secs` is None the thread calls `target(args)` + repeatedly. Otherwise `target(args)` is called every `timer_interval_secs` + seconds. The thread terminates when a stop of the coordinator is + requested. + + Args: + coord: A Coordinator. + timer_interval_secs: Number. Time boundaries at which to call `target`. + target: A callable object. + args: Optional arguments to pass to `target` when calling it. + + Returns: + The started thread. + """ + looper = LooperThread(coord, timer_interval_secs, target=target, args=args) + looper.start() + return looper + + # pylint: disable=broad-except + def run(self): + with self._coord.stop_on_exception(): + self.start_loop() + if self._timer_interval_secs is None: + # Call back-to-back. + while not self._coord.should_stop(): + self.run_loop() + else: + # Next time at which to call run_loop(), starts as 'now'. + next_timer_time = time.time() + while not self._coord.wait_for_stop(next_timer_time - time.time()): + next_timer_time += self._timer_interval_secs + self.run_loop() + # pylint: enable=broad-except + + def start_loop(self): + """Called when the thread starts.""" + pass + + def run_loop(self): + """Called at 'timer_interval_secs' boundaries.""" + if self._target: + self._target(*self._args) diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 6b70ddae3e..2091687e7c 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -44,7 +44,7 @@ class Optimizer(object): # Add Ops to the graph to minimize a cost by updating a list of variables. # "cost" is a Tensor, and the list of variables contains tf.Variable # objects. - opt_op = opt.minimize(cost, <list of variables>) + opt_op = opt.minimize(cost, var_list=<list of variables>) ``` In the training program you will just have to run the returned Op. diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 126f520328..34db5c3cd3 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -813,9 +813,8 @@ class Saver(object): last_checkpoints: A list of checkpoint filenames. Raises: - AssertionError: If the list of checkpoint filenames has already been set. + AssertionError: If last_checkpoints is not a list. """ - assert not self._last_checkpoints assert isinstance(last_checkpoints, list) # We use a timestamp of +inf so that this checkpoint will never be # deleted. This is both safe and backwards compatible to a previous diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py index 737c980c6c..b9f5f9a54b 100644 --- a/tensorflow/python/training/training.py +++ b/tensorflow/python/training/training.py @@ -139,6 +139,7 @@ from tensorflow.python.training.gradient_descent import GradientDescentOptimizer # Utility classes for training. from tensorflow.python.training.coordinator import Coordinator +from tensorflow.python.training.coordinator import LooperThread from tensorflow.python.training.queue_runner import * # For the module level doc. diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc index 6c94856772..c88dc88d29 100644 --- a/tensorflow/stream_executor/cuda/cuda_driver.cc +++ b/tensorflow/stream_executor/cuda/cuda_driver.cc @@ -319,7 +319,7 @@ void PopContextAndCheckNowNull(CUcontext expected) { CUcontext popped; CHECK_EQ(CUDA_SUCCESS, dynload::cuCtxPopCurrent_v2(&popped)); CHECK_EQ(expected, popped); - CHECK(nullptr == CurrentContext()); + DCHECK(nullptr == CurrentContext()); VLOG(3) << "popped context " << expected << " and current context is now null"; } @@ -395,7 +395,7 @@ ScopedActivateContext::ScopedActivateContext(CUcontext context, ScopedActivateContext::~ScopedActivateContext() { if (tls_in_multi_op_activation.get()) { - CHECK_EQ(context_, CurrentContext()); + DCHECK_EQ(context_, CurrentContext()); if (FLAGS_gpuexec_cuda_sync_around_driver_calls) { auto res = dynload::cuCtxSynchronize(); if (res != CUDA_SUCCESS) { @@ -470,7 +470,7 @@ static port::Status InternalInit() { LOG(ERROR) << "injecting CUDA init error; initialization will fail"; } else if (internal::CachedDsoLoader::GetLibcudaDsoHandle().ok()) { // We only call cuInit if we can dynload libcuda. - + res = dynload::cuInit(0 /* = flags */); } @@ -570,7 +570,7 @@ bool DeviceOptionsToContextFlags(DeviceOptions device_options, int *flags) { { // TODO(leary) Need to see if NVIDIA can expunge the leakiness in their // context creation: see http://b/13248943 - + res = dynload::cuCtxCreate_v2(context, flags, device); } if (res == CUDA_SUCCESS) { @@ -737,7 +737,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) { { // TODO(leary) Need to see if NVIDIA can expunge the leakiness in their // module loading: see http://b/13248943 - + res = dynload::cuModuleLoadDataEx(module, ptx_data, ARRAYSIZE(options), options, option_values); } diff --git a/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts b/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts index b3511ce7b6..c3e23c760b 100644 --- a/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts +++ b/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts @@ -25,23 +25,23 @@ module TF { function router(route: string): ((tag: string, run: string) => string) { return function(tag: string, run: string): string { - return "/" + route + "?tag=" + encodeURIComponent(tag) + return "/data/" + route + "?tag=" + encodeURIComponent(tag) + "&run=" + encodeURIComponent(run); }; } export function runsUrl() { - return "/runs"; + return "/data/runs"; } export var scalarsUrl = router("scalars"); export var histogramsUrl = router("histograms"); export var compressedHistogramsUrl = router("compressedHistograms"); export var imagesUrl = router("images"); export function individualImageUrl(query: string) { - return "/individualImage?" + query; + return "/data/individualImage?" + query; } export function graphUrl(run: string) { - return "/graph?run=" + encodeURIComponent(run); + return "/data/graph?run=" + encodeURIComponent(run); } } diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/common.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/common.ts index 17b753a2f9..66ed9fc84a 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/common.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/common.ts @@ -18,17 +18,19 @@ limitations under the License. declare module graphlib { interface GraphOptions { - name: string; + name?: string; /** * Direction for rank nodes. Can be TB, BT, LR, or RL, where T = top, * B = bottom, L = left, and R = right. */ - rankdir: string; - type: string|number; + rankdir?: string; + type?: string|number; /** Number of pixels between each rank in the layout. */ ranksep?: number; /** Number of pixels that separate nodes horizontally in the layout. */ nodesep?: number; + /** Number of pixels that separate edges horizontally in the layout */ + edgesep?: number; } export interface EdgeObject { @@ -58,7 +60,10 @@ declare module graphlib { edges(): EdgeObject[]; outEdges(name: string): E[]; inEdges(name: string): E[]; - /** Returns those nodes in the graph that have no in-edges. Takes O(|V|) time. */ + /** + * Returns those nodes in the graph that have no in-edges. + * Takes O(|V|) time. + */ sources(): string[]; /** * Remove the node with the id v in the graph or do nothing if diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts index 41f00c54f4..a9b2cf1934 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts @@ -19,7 +19,6 @@ module tf.graph { /** Delimiter used in node names to denote namespaces. */ export const NAMESPACE_DELIM = "/"; -const FULL_GRAPH_NAME = "fullGraph"; export const ROOT_NAME = "__root__"; // Separator between the source and the destination name of the edge. @@ -315,8 +314,8 @@ class OpNodeImpl implements OpNode { * @param rawNode The raw node. * @param normalizedInputs An array of normalized * inputs that denote the incoming edges to the current node. Each input - * contains the normalized name of the source node, whether it has a number - * part and whether it is a control dependency. + * contains the normalized name of the source node, whether it has a + * number part and whether it is a control dependency. */ constructor(rawNode: tf.TFNode, normalizedInputs: NormalizedInput[]) { this.op = rawNode.op; @@ -340,8 +339,8 @@ export function createMetanode(name: string, opt = {}): Metanode { } /** - * Joins the information from the stats file (memory, compute time) with the graph - * information. + * Joins the information from the stats file (memory, compute time) with the + * graph information. */ export function joinStatsInfoWithGraph(graph: SlimGraph, statsJson: TFStats): void { @@ -894,7 +893,8 @@ export function hasSimilarDegreeSequence(graph1: graphlib.Graph<any, any>, /** * Returns the hierarchical path of the current node, based on the node's name. - * For example, if the name is 'a/b/c', the returned path is ['a', 'a/b', 'a/b/c']. + * For example, if the name is 'a/b/c', the returned path is + * ['a', 'a/b', 'a/b/c']. */ export function getHierarchicalPath(name: string, seriesNames?: { [name: string]: string }): string[] { diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts index 1c8e1b2e18..504e47f4d5 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts @@ -21,8 +21,6 @@ limitations under the License. */ module tf.graph.hierarchy { -const LOG_PREFIX_MSG = "Graph hierarchy: "; - /** * Class used as output for getPredecessors and getSuccessors methods */ @@ -469,7 +467,8 @@ function addNodes(h: Hierarchy, graph: SlimGraph) { } parent = child; } - // Assuming node name is 'a/b/c', assign the OpNode as a child of the metanode 'a/b'. + // Assuming node name is 'a/b/c', assign the OpNode as a child of the + // metanode 'a/b'. h.setNode(node.name, node); node.parentNode = parent; parent.metagraph.setNode(node.name, node); @@ -567,7 +566,8 @@ function addEdges(h: Hierarchy, graph: SlimGraph, * @param hierarchy * @param threshold If the series has this many nodes or more, then group them * into a series. - * @return A dictionary from node name to series node name that contains the node + * @return A dictionary from node name to series node name that contains the + * node. */ function groupSeries(metanode: Metanode, hierarchy: Hierarchy, seriesNames: { [name: string]: string }, threshold: number) { @@ -589,9 +589,6 @@ function groupSeries(metanode: Metanode, hierarchy: Hierarchy, if (nodeMemberNames.length < threshold) { return; } - let firstMember = seriesNode.metagraph.node(nodeMemberNames[0]); - let seriesType = firstMember.type; - hierarchy.setNode(seriesName, seriesNode); // add to the index metagraph.setNode(seriesName, seriesNode); _.each(nodeMemberNames, n => { @@ -620,7 +617,8 @@ function groupSeries(metanode: Metanode, hierarchy: Hierarchy, function clusterNodes(metagraph: graphlib.Graph<GroupNode|OpNode, Metaedge>): {[clusterId: string]: string[]} { let result: {[clusterId: string]: string[]} = {}; - return _.reduce(metagraph.nodes(), function(clusters: {[clusterId: string]: string[]}, n: string) { + return _.reduce(metagraph.nodes(), + (clusters: {[clusterId: string]: string[]}, n: string) => { let child = metagraph.node(n); if (child.type === NodeType.META) { // skip metanodes @@ -702,7 +700,8 @@ function detectSeries(clusters: {[clusterId: string]: string[]}, let seriesNodes = [seriesInfoArray[0]]; for (let index = 1; index < seriesInfoArray.length; index++) { let nextNode = seriesInfoArray[index]; - if (nextNode.clusterId === seriesNodes[seriesNodes.length - 1].clusterId + 1) { + if (nextNode.clusterId === seriesNodes[seriesNodes.length - 1].clusterId + + 1) { seriesNodes.push(nextNode); continue; } diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts index 5a0559627f..b003e33177 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts @@ -33,14 +33,19 @@ export const PARAMS = { * * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout */ - nodeSep: 110, + nodeSep: 5, /** * Dagre's ranksep param - number of pixels * between each rank in the layout. * * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout */ - rankSep: 25 + rankSep: 25, + /** + * Dagre's edgesep param - number of pixels that separate + * edges horizontally in the layout. + */ + edgeSep: 5, }, /** Graph parameter for metanode. */ series: { @@ -50,7 +55,7 @@ export const PARAMS = { * * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout */ - nodeSep: 90, + nodeSep: 5, /** * Dagre's ranksep param - number of pixels * between each rank in the layout. @@ -58,6 +63,11 @@ export const PARAMS = { * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout */ rankSep: 25, + /** + * Dagre's edgesep param - number of pixels that separate + * edges horizontally in the layout. + */ + edgeSep: 5 }, /** * Padding is used to correctly position the graph SVG inside of its parent @@ -166,6 +176,10 @@ export const PARAMS = { } }, annotations: { + /** Maximum possible width of the bounding box for in annotations */ + inboxWidth: 50, + /** Maximum possible width of the bounding box for out annotations */ + outboxWidth: 50, /** X-space between the shape and each annotation-node. */ xOffset: 10, /** Y-space between each annotation-node. */ @@ -202,7 +216,7 @@ export const PARAMS = { }; /** Calculate layout for a scene of a group node. */ -export function scene(renderNodeInfo: render.RenderGroupNodeInformation) +export function layoutScene(renderNodeInfo: render.RenderGroupNodeInfo) : void { // Update layout, size, and annotations of its children nodes and edges. if (renderNodeInfo.node.isGroupNode) { @@ -218,9 +232,32 @@ export function scene(renderNodeInfo: render.RenderGroupNodeInformation) }; /** + * Updates the total width of an unexpanded node which includes the size of its + * in and out annotations. + */ +function updateTotalWidthOfNode(renderInfo: render.RenderNodeInfo): void { + renderInfo.inboxWidth = renderInfo.inAnnotations.list.length > 0 ? + PARAMS.annotations.inboxWidth : 0; + renderInfo.outboxWidth = renderInfo.outAnnotations.list.length > 0 ? + PARAMS.annotations.outboxWidth : 0; + // Assign the width of the core box (the main shape of the node). + renderInfo.coreBox.width = renderInfo.width; + renderInfo.coreBox.height = renderInfo.height; + // TODO(jimbo): Account for font width rather than using a magic number. + let labelLength = renderInfo.node.name.length - + renderInfo.node.name.lastIndexOf(NAMESPACE_DELIM) - 1; + let charWidth = 3; // 3 pixels per character. + // Compute the total width of the node. + renderInfo.width = Math.max(renderInfo.coreBox.width + + renderInfo.inboxWidth + renderInfo.outboxWidth, + labelLength * charWidth); + +} + +/** * Update layout, size, and annotations of its children nodes and edges. */ -function layoutChildren(renderNodeInfo: render.RenderGroupNodeInformation) +function layoutChildren(renderNodeInfo: render.RenderGroupNodeInfo) : void { let children = renderNodeInfo.coreGraph.nodes().map(n => { return renderNodeInfo.coreGraph.node(n); @@ -238,25 +275,25 @@ function layoutChildren(renderNodeInfo: render.RenderGroupNodeInformation) break; case NodeType.META: if (!childNodeInfo.expanded) { - // set fixed width and scalable height based on cardinality + // Set fixed width and scalable height based on cardinality _.extend(childNodeInfo, PARAMS.nodeSize.meta); childNodeInfo.height = PARAMS.nodeSize.meta.height(childNodeInfo.node.cardinality); } else { let childGroupNodeInfo = - <render.RenderGroupNodeInformation>childNodeInfo; - scene(childGroupNodeInfo); // Recursively layout its subscene. + <render.RenderGroupNodeInfo>childNodeInfo; + layoutScene(childGroupNodeInfo); // Recursively layout its subscene. } break; case NodeType.SERIES: if (childNodeInfo.expanded) { _.extend(childNodeInfo, PARAMS.nodeSize.series.expanded); let childGroupNodeInfo = - <render.RenderGroupNodeInformation>childNodeInfo; - scene(childGroupNodeInfo); // Recursively layout its subscene. + <render.RenderGroupNodeInfo>childNodeInfo; + layoutScene(childGroupNodeInfo); // Recursively layout its subscene. } else { let childGroupNodeInfo = - <render.RenderGroupNodeInformation>childNodeInfo; + <render.RenderGroupNodeInfo>childNodeInfo; let seriesParams = childGroupNodeInfo.node.hasNonControlEdges ? PARAMS.nodeSize.series.vertical : @@ -267,7 +304,11 @@ function layoutChildren(renderNodeInfo: render.RenderGroupNodeInformation) default: throw Error("Unrecognized node type: " + childNodeInfo.node.type); } - + // Compute total width of un-expanded nodes. Width of expanded nodes + // has already been computed. + if (!childNodeInfo.expanded) { + updateTotalWidthOfNode(childNodeInfo); + } // Layout each child's annotations layoutAnnotation(childNodeInfo); }); @@ -279,13 +320,14 @@ function layoutChildren(renderNodeInfo: render.RenderGroupNodeInformation) * @param params layout parameters * @return width and height of the core graph */ -function dagreLayout(graph: graphlib.Graph<any, any>, params) - : {height: number, width: number} { +function dagreLayout( + graph: graphlib.Graph<render.RenderNodeInfo, render.RenderMetaedgeInfo>, + params): {height: number, width: number} { _.extend(graph.graph(), { - nodeSep: params.nodeSep, - rankSep: params.rankSep - }); - + nodesep: params.nodeSep, + ranksep: params.rankSep, + edgesep: params.edgeSep + }); let bridgeNodeNames = []; let nonBridgeNodeNames = []; @@ -307,11 +349,8 @@ function dagreLayout(graph: graphlib.Graph<any, any>, params) height: 0, }; } - dagre.layout(graph); - let graphLabel = graph.graph(); - // Calculate the true bounding box of the graph by iterating over nodes and // edges rather than accepting dagre's word for it. In particular, we should // ignore the extra-wide bridge nodes and bridge edges, and allow for @@ -323,33 +362,65 @@ function dagreLayout(graph: graphlib.Graph<any, any>, params) _.each(nonBridgeNodeNames, nodeName => { let nodeInfo = graph.node(nodeName); let w = 0.5 * nodeInfo.width; - let x1 = nodeInfo.x - w - nodeInfo.inboxWidth; - let x2 = nodeInfo.x + w + nodeInfo.outboxWidth; + let x1 = nodeInfo.x - w; + let x2 = nodeInfo.x + w; minX = x1 < minX ? x1 : minX; maxX = x2 > maxX ? x2 : maxX; - let labelLength = - nodeName.length - nodeName.lastIndexOf(NAMESPACE_DELIM); - // TODO(jimbo): Account for font width rather than using a magic number. - let charWidth = 3; // 3 pixels per character. - let lw = 0.5 * labelLength * charWidth; - let lx1 = nodeInfo.x - lw; - let lx2 = nodeInfo.x + lw; - minX = lx1 < minX ? lx1 : minX; - maxX = lx2 > maxX ? lx2 : maxX; // TODO(jimbo): Account for the height of labels above op nodes here. - let h = 0.5 * nodeInfo.outerHeight; + let h = 0.5 * nodeInfo.height; let y1 = nodeInfo.y - h; let y2 = nodeInfo.y + h; minY = y1 < minY ? y1 : minY; maxY = y2 > maxY ? y2 : maxY; }); _.each(graph.edges(), edgeObj => { - let renderMetaedgeInfo = graph.edge(edgeObj); - if (renderMetaedgeInfo.structural) { + let edgeInfo = graph.edge(edgeObj); + if (edgeInfo.structural) { return; // Skip structural edges from min/max calculations. } - _.each(renderMetaedgeInfo.points, - (point: { x: number, y: number }) => { + + // Since the node size passed to dagre includes the in and out + // annotations, the endpoints of the edge produced by dagre may not + // point to the actual node shape (rectangle, ellipse). We correct the + // end-points by finding the intersection of a line between the + // next-to-last (next-to-first) point and the destination (source) + // rectangle. + let sourceNode = graph.node(edgeInfo.metaedge.v); + let destNode = graph.node(edgeInfo.metaedge.w); + + // Straight 3-points edges are special case, since they are curved after + // our default correction. To keep them straight, we remove the mid point + // and correct the first and the last point to be the center of the + // source and destination node respectively. + if (edgeInfo.points.length === 3 && isStraightLine(edgeInfo.points)) { + if (sourceNode != null) { + let cxSource = sourceNode.expanded ? + sourceNode.x : computeCXPositionOfNodeShape(sourceNode); + edgeInfo.points[0].x = cxSource; + } + if (destNode != null) { + let cxDest = destNode.expanded ? + destNode.x : computeCXPositionOfNodeShape(destNode); + edgeInfo.points[2].x = cxDest; + } + // Remove the middle point so the edge doesn't curve. + edgeInfo.points = [edgeInfo.points[0], edgeInfo.points[1]]; + } + // Correct the destination endpoint of the edge. + let nextToLastPoint = edgeInfo.points[edgeInfo.points.length - 2]; + // The destination node might be null if this is a bridge edge. + if (destNode != null) { + edgeInfo.points[edgeInfo.points.length - 1] = + intersectPointAndNode(nextToLastPoint, destNode); + } + // Correct the source endpoint of the edge. + let secondPoint = edgeInfo.points[1]; + // The source might be null if this is a bridge edge. + if (sourceNode != null) { + edgeInfo.points[0] = intersectPointAndNode(secondPoint, sourceNode); + } + + _.each(edgeInfo.points, (point: render.Point) => { minX = point.x < minX ? point.x : minX; maxX = point.x > maxX ? point.x : maxX; minY = point.y < minY ? point.y : minY; @@ -365,8 +436,7 @@ function dagreLayout(graph: graphlib.Graph<any, any>, params) nodeInfo.y -= minY; }); _.each(graph.edges(), edgeObj => { - _.each(graph.edge(edgeObj).points, - (point: { x: number, y: number }) => { + _.each(graph.edge(edgeObj).points, (point: render.Point) => { point.x -= minX; point.y -= minY; }); @@ -374,16 +444,15 @@ function dagreLayout(graph: graphlib.Graph<any, any>, params) return { width: maxX - minX, - height: maxY - minY, + height: maxY - minY }; } -/** Layout a metanode. */ -function layoutMetanode(renderNodeInfo): void { +/** Layout a metanode. Only called for an expanded node. */ +function layoutMetanode(renderNodeInfo: render.RenderGroupNodeInfo): void { // First, copy params specific to meta nodes onto this render info object. let params = PARAMS.subscene.meta; - renderNodeInfo = _.extend(renderNodeInfo, params); - + _.extend(renderNodeInfo, params); // Invoke dagre.layout() on the core graph and record the bounding box // dimensions. _.extend(renderNodeInfo.coreBox, @@ -392,70 +461,70 @@ function layoutMetanode(renderNodeInfo): void { // Calculate the position of nodes in isolatedInExtract relative to the // top-left corner of inExtractBox (the bounding box for all inExtract nodes) // and calculate the size of the inExtractBox. - let hasInExtract = renderNodeInfo.isolatedInExtract.length > 0; - - renderNodeInfo.inExtractBox.width = hasInExtract ? - _(renderNodeInfo.isolatedInExtract).pluck("outerWidth").max() : 0; + let maxInExtractWidth = _.max(renderNodeInfo.isolatedInExtract, + renderNode => renderNode.width).width; + renderNodeInfo.inExtractBox.width = maxInExtractWidth != null ? + maxInExtractWidth : 0; renderNodeInfo.inExtractBox.height = - _.reduce(renderNodeInfo.isolatedInExtract, (height, child: any, i) => { + _.reduce(renderNodeInfo.isolatedInExtract, (height, child, i) => { let yOffset = i > 0 ? params.extractYOffset : 0; - // use outerWidth/Height here to avoid overlaps between extracts - child.x = renderNodeInfo.inExtractBox.width / 2; - child.y = height + yOffset + child.outerHeight / 2; - return height + yOffset + child.outerHeight; + // use width/height here to avoid overlaps between extracts + child.x = 0; + child.y = height + yOffset + child.height / 2; + return height + yOffset + child.height; }, 0); // Calculate the position of nodes in isolatedOutExtract relative to the // top-left corner of outExtractBox (the bounding box for all outExtract // nodes) and calculate the size of the outExtractBox. - let hasOutExtract = renderNodeInfo.isolatedOutExtract.length > 0; - renderNodeInfo.outExtractBox.width = hasOutExtract ? - _(renderNodeInfo.isolatedOutExtract).pluck("outerWidth").max() : 0; + let maxOutExtractWidth = _.max(renderNodeInfo.isolatedOutExtract, + renderNode => renderNode.width).width; + renderNodeInfo.outExtractBox.width = maxOutExtractWidth != null ? + maxOutExtractWidth : 0; renderNodeInfo.outExtractBox.height = - _.reduce(renderNodeInfo.isolatedOutExtract, (height, child: any, i) => { + _.reduce(renderNodeInfo.isolatedOutExtract, (height, child, i) => { let yOffset = i > 0 ? params.extractYOffset : 0; - // use outerWidth/Height here to avoid overlaps between extracts - child.x = renderNodeInfo.outExtractBox.width / 2; - child.y = height + yOffset + child.outerHeight / 2; - return height + yOffset + child.outerHeight; + // use width/height here to avoid overlaps between extracts + child.x = 0; + child.y = height + yOffset + child.height / 2; + return height + yOffset + child.height; }, 0); + // Add the in-extract and out-extract width to the core box width. + renderNodeInfo.coreBox.width += renderNodeInfo.inExtractBox.width + + renderNodeInfo.outExtractBox.width; + renderNodeInfo.coreBox.height = + params.labelHeight + + Math.max( + renderNodeInfo.inExtractBox.height, + renderNodeInfo.coreBox.height, + renderNodeInfo.outExtractBox.height + ); // Determine the whole metanode's width (from left to right). - renderNodeInfo.width = - params.paddingLeft + renderNodeInfo.coreBox.width + params.paddingRight + - (hasInExtract ? - renderNodeInfo.inExtractBox.width + params.extractXOffset : 0) + - (hasOutExtract ? - params.extractXOffset + renderNodeInfo.outExtractBox.width : 0); - - // TODO(jimbo): Remove labelHeight and instead incorporate into box sizes. + renderNodeInfo.width = renderNodeInfo.coreBox.width + + params.paddingLeft + params.paddingRight; + // Determine the whole metanode's height (from top to bottom). renderNodeInfo.height = - renderNodeInfo.labelHeight + - params.paddingTop + - Math.max( - renderNodeInfo.inExtractBox.height, - renderNodeInfo.coreBox.height, - renderNodeInfo.outExtractBox.height - ) + - params.paddingBottom; + renderNodeInfo.paddingTop + + renderNodeInfo.coreBox.height + + renderNodeInfo.paddingBottom; } /** * Calculate layout for series node's core graph. Only called for an expanded * series. */ -function layoutSeriesNode(node: render.RenderGroupNodeInformation): void { +function layoutSeriesNode(node: render.RenderGroupNodeInfo): void { let graph = node.coreGraph; let params = PARAMS.subscene.series; _.extend(node, params); // Layout the core. - _.extend(node.coreBox, - dagreLayout(node.coreGraph, PARAMS.graph.series)); + _.extend(node.coreBox, dagreLayout(node.coreGraph, PARAMS.graph.series)); _.each(graph.nodes(), nodeName => { graph.node(nodeName).excluded = false; @@ -468,24 +537,16 @@ function layoutSeriesNode(node: render.RenderGroupNodeInformation): void { /** * Calculate layout for annotations of a given node. - * This will modify positions of the the given node and its annotations. + * This will modify positions of the given node and its annotations. * * @see tf.graph.render.Node and tf.graph.render.Annotation * for description of each property of each render node. * */ - function layoutAnnotation(renderNodeInfo: render.RenderNodeInformation): void { +function layoutAnnotation(renderNodeInfo: render.RenderNodeInfo): void { // If the render node is an expanded metanode, then its annotations will not // be visible and we should skip the annotation calculations. if (renderNodeInfo.expanded) { - _.extend(renderNodeInfo, { - inboxWidth: 0, - inboxHeight: 0, - outboxWidth: 0, - outboxHeight: 0, - outerWidth: renderNodeInfo.width, - outerHeight: renderNodeInfo.height - }); return; } @@ -499,31 +560,20 @@ function layoutSeriesNode(node: render.RenderGroupNodeInformation): void { _.each(outAnnotations, a => sizeAnnotation(a)); let params = PARAMS.annotations; - renderNodeInfo.inboxWidth = - inAnnotations.length > 0 ? - (<any>_(inAnnotations).pluck("width").max()) + - params.xOffset + params.labelWidth + params.labelOffset : - 0; - - renderNodeInfo.outboxWidth = - outAnnotations.length > 0 ? - (<any>_(outAnnotations).pluck("width").max()) + - params.xOffset + params.labelWidth + params.labelOffset : - 0; // Calculate annotation node position (a.dx, a.dy) // and total height for in-annotations // After this chunk of code: // inboxHeight = sum of annotation heights+ (annotation.length - 1 * yOffset) let inboxHeight = _.reduce(inAnnotations, - (height, a: any, i) => { + (height, a, i) => { let yOffset = i > 0 ? params.yOffset : 0; - a.dx = -(renderNodeInfo.width + a.width) / 2 - params.xOffset; + a.dx = -(renderNodeInfo.coreBox.width + a.width) / 2 - params.xOffset; a.dy = height + yOffset + a.height / 2; return height + yOffset + a.height; }, 0); - _.each(inAnnotations, (a: any) => { + _.each(inAnnotations, a => { a.dy -= inboxHeight / 2; a.labelOffset = params.labelOffset; @@ -535,14 +585,14 @@ function layoutSeriesNode(node: render.RenderGroupNodeInformation): void { // outboxHeight = sum of annotation heights + // (annotation.length - 1 * yOffset) let outboxHeight = _.reduce(outAnnotations, - (height, a: any, i) => { + (height, a, i) => { let yOffset = i > 0 ? params.yOffset : 0; - a.dx = (renderNodeInfo.width + a.width) / 2 + params.xOffset; + a.dx = (renderNodeInfo.coreBox.width + a.width) / 2 + params.xOffset; a.dy = height + yOffset + a.height / 2; return height + yOffset + a.height; }, 0); - _.each(outAnnotations, (a: any) => { + _.each(outAnnotations, a => { // adjust by (half of ) the total height // so dy is relative to the host node's center. a.dy -= outboxHeight / 2; @@ -563,7 +613,7 @@ function layoutSeriesNode(node: render.RenderGroupNodeInformation): void { .range([-inTouchHeight, inTouchHeight]); // Calculate annotation edge position - _.each(inAnnotations, (a: any, i) => { + _.each(inAnnotations, (a, i) => { a.points = [ // The annotation node end { @@ -573,7 +623,7 @@ function layoutSeriesNode(node: render.RenderGroupNodeInformation): void { // The host node end { - dx: - renderNodeInfo.width / 2, + dx: - renderNodeInfo.coreBox.width / 2, // only use scale if there are more than one, // otherwise center it vertically dy: inAnnotations.length > 1 ? inY(i) : 0 @@ -591,12 +641,12 @@ function layoutSeriesNode(node: render.RenderGroupNodeInformation): void { .domain([0, outAnnotations.length - 1]) .range([-outTouchHeight, outTouchHeight]); - _.each(outAnnotations, (a: any, i) => { + _.each(outAnnotations, (a, i) => { // Add point from the border of the annotation node a.points = [ // The host node end { - dx: renderNodeInfo.width / 2, + dx: renderNodeInfo.coreBox.width / 2, // only use scale if there are more than one, // otherwise center it vertically dy: outAnnotations.length > 1 ? outY(i) : 0 @@ -609,9 +659,7 @@ function layoutSeriesNode(node: render.RenderGroupNodeInformation): void { ]; }); - renderNodeInfo.outerWidth = renderNodeInfo.width + renderNodeInfo.inboxWidth + - renderNodeInfo.outboxWidth; - renderNodeInfo.outerHeight = + renderNodeInfo.height = Math.max(renderNodeInfo.height, inboxHeight, outboxHeight); } @@ -640,4 +688,75 @@ function sizeAnnotation(a: render.Annotation): void { } } +/** + * Determines the center position of the node's shape. The position depends + * on if the node has in and out-annotations. + */ +export function computeCXPositionOfNodeShape(renderInfo: render.RenderNodeInfo): + number { + if (renderInfo.expanded) { + return renderInfo.x; + } + let dx = renderInfo.inAnnotations.list.length ? renderInfo.inboxWidth : 0; + return renderInfo.x - renderInfo.width / 2 + dx + + renderInfo.coreBox.width / 2; +} + +/** Returns the angle (in degrees) between two points. */ +function angleBetweenTwoPoints(a: render.Point, b: render.Point): number { + let dx = b.x - a.x; + let dy = b.y - a.y; + return 180 * Math.atan(dy / dx) / Math.PI; +} + +/** + * Returns if a line going through the specified points is a straight line. + */ +function isStraightLine(points: render.Point[]) { + let angle = angleBetweenTwoPoints(points[0], points[1]); + for (let i = 1; i < points.length - 1; i++) { + let newAngle = angleBetweenTwoPoints(points[i], points[i + 1]); + // Have a tolerance of 1 degree. + if (Math.abs(newAngle - angle) > 1) { + return false; + } + angle = newAngle; + } + return true; +} + +/** + * Returns the intersection of a line between the provided point + * and the provided rectangle. + */ +function intersectPointAndNode(point: render.Point, node: render.RenderNodeInfo): + render.Point { + // cx and cy are the center of the rectangle. + let cx = node.expanded ? + node.x : computeCXPositionOfNodeShape(node); + let cy = node.y; + // Calculate the slope + let dx = point.x - cx; + let dy = point.y - cy; + let w = node.expanded ? node.width : node.coreBox.width; + let h = node.expanded ? node.height : node.coreBox.height; + let deltaX, deltaY; + if (Math.abs(dy) * w / 2 > Math.abs(dx) * h / 2) { + // The intersection is above or below the rectangle. + if (dy < 0) { + h = -h; + } + deltaX = dy === 0 ? 0 : h / 2 * dx / dy; + deltaY = h / 2; + } else { + // The intersection is left or right of the rectangle. + if (dx < 0) { + w = -w; + } + deltaX = w / 2; + deltaY = dx === 0 ? 0 : w / 2 * dy / dx; + } + return {x: cx + deltaX, y: cy + deltaY}; +} + } // close module diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts index c99c61a849..956b39d986 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts @@ -22,6 +22,8 @@ limitations under the License. module tf.graph.render { +export type Point = {x: number, y: number}; + /** * Color parameters for op nodes. */ @@ -149,18 +151,19 @@ export interface RenderGraphParams { * Stores the rendering information, such as x and y coordinates, * for each node in the graph. */ -export class RenderGraphInformation { +export class RenderGraphInfo { private hierarchy: hierarchy.Hierarchy; - private index: {[nodeName: string]: RenderNodeInformation}; + private index: {[nodeName: string]: RenderNodeInfo}; private params: RenderGraphParams; private deviceColorMap: d3.scale.Ordinal<string, string>; private memoryUsageScale: d3.scale.Linear<string, string>; private computeTimeScale: d3.scale.Linear<string, string>; // Since the rendering information for each node is constructed lazily, - // upon node's expansion by the user, we keep a map between the node's name and - // whether the rendering information was already constructed for that node. + // upon node's expansion by the user, we keep a map between the node's name + // and whether the rendering information was already constructed for that + // node. private hasSubhierarchy: {[nodeName: string]: boolean}; - root: RenderGroupNodeInformation; + root: RenderGroupNodeInfo; constructor(hierarchy: hierarchy.Hierarchy, params: RenderGraphParams) { this.hierarchy = hierarchy; @@ -185,7 +188,8 @@ export class RenderGraphInformation { .range(params.minMaxColors); // Find also the minimum and maximum compute time. - let computeTimeExtent = d3.extent(topLevelGraph.nodes(), (nodeName, index) => { + let computeTimeExtent = d3.extent(topLevelGraph.nodes(), + (nodeName, index) => { let node = topLevelGraph.node(nodeName); // Some ops don't have stats at all. if (node.stats != null) { @@ -196,27 +200,28 @@ export class RenderGraphInformation { .domain(computeTimeExtent) .range(params.minMaxColors); - // Maps node name to whether the rendering hierarchy was already constructed. + // Maps node name to whether the rendering hierarchy was already + // constructed. this.hasSubhierarchy = {}; this.params = params; - this.root = new RenderGroupNodeInformation(hierarchy.root); + this.root = new RenderGroupNodeInfo(hierarchy.root); this.index[hierarchy.root.name] = this.root; this.buildSubhierarchy(hierarchy.root.name); this.root.expanded = true; } /** - * Get a previously created RenderNodeInformation by its node name. + * Get a previously created RenderNodeInfo by its node name. */ - getRenderNodeByName(nodeName: string): RenderNodeInformation { + getRenderNodeByName(nodeName: string): RenderNodeInfo { return this.index[nodeName]; } /** - * Get a previously created RenderNodeInformation for the specified node name, + * Get a previously created RenderNodeInfo for the specified node name, * or create one if it hasn't been created yet. */ - getOrCreateRenderNodeByName(nodeName: string): RenderNodeInformation { + getOrCreateRenderNodeByName(nodeName: string): RenderNodeInfo { // Polymer may invoke this with null. if (!nodeName) { return null; @@ -228,8 +233,8 @@ export class RenderGraphInformation { let node = this.hierarchy.node(nodeName); let renderInfo = node.isGroupNode ? - new RenderGroupNodeInformation(<GroupNode>node) : - new RenderNodeInformation(node); + new RenderGroupNodeInfo(<GroupNode>node) : + new RenderNodeInfo(node); this.index[nodeName] = renderInfo; if (node.stats) { @@ -291,8 +296,8 @@ export class RenderGraphInformation { /** * Returns true if the renderNode is an isolated node within its parent node. */ - isNodeAuxilliary(renderNode: RenderNodeInformation): boolean { - let parentNode = <RenderGroupNodeInformation>this.getRenderNodeByName( + isNodeAuxilliary(renderNode: RenderNodeInfo): boolean { + let parentNode = <RenderGroupNodeInfo>this.getRenderNodeByName( renderNode.node.parentNode.name); let found = _.find(parentNode.isolatedInExtract, node => { return node.node.name === renderNode.node.name; @@ -322,7 +327,7 @@ export class RenderGraphInformation { } // At this point we know the rendering information is about a group node. - let renderGroupNodeInfo = <RenderGroupNodeInformation> renderNodeInfo; + let renderGroupNodeInfo = <RenderGroupNodeInfo> renderNodeInfo; let metagraph = renderGroupNodeInfo.node.metagraph; let coreGraph = renderGroupNodeInfo.coreGraph; @@ -339,16 +344,16 @@ export class RenderGraphInformation { if (!childNode.isGroupNode) { _.each((<OpNode>childNode).inEmbeddings, embedding => { - let renderMetaedgeInfo = new RenderMetaedgeInformation(null); + let renderMetaedgeInfo = new RenderMetaedgeInfo(null); addInAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo, AnnotationType.CONSTANT, this.params); - this.index[embedding.name] = new RenderNodeInformation(embedding); + this.index[embedding.name] = new RenderNodeInfo(embedding); }); _.each((<OpNode>childNode).outEmbeddings, embedding => { - let renderMetaedgeInfo = new RenderMetaedgeInformation(null); + let renderMetaedgeInfo = new RenderMetaedgeInfo(null); addOutAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo, AnnotationType.SUMMARY, this.params); - this.index[embedding.name] = new RenderNodeInformation(embedding); + this.index[embedding.name] = new RenderNodeInfo(embedding); }); } @@ -357,7 +362,7 @@ export class RenderGraphInformation { // Add render metaedge info for edges in the metagraph. _.each(metagraph.edges(), edgeObj => { let metaedge = metagraph.edge(edgeObj); - let renderMetaedgeInfo = new RenderMetaedgeInformation(metaedge); + let renderMetaedgeInfo = new RenderMetaedgeInfo(metaedge); coreGraph.setEdge(edgeObj.v, edgeObj.w, renderMetaedgeInfo); }); @@ -376,7 +381,7 @@ export class RenderGraphInformation { return; } let parentNodeInfo = - <RenderGroupNodeInformation> this.index[parentNode.name]; + <RenderGroupNodeInfo> this.index[parentNode.name]; // Utility function for computing the name of a bridge node. let getBridgeNodeName = (inbound, ...rest) => @@ -439,7 +444,7 @@ export class RenderGraphInformation { let isHighDegreeControlEdge = !bridgeMetaedge.numRegularEdges && otherCounts.control[otherName] > this.params.maxControlDegree; - let [annotations, childAnnotations] = + let [, childAnnotations] = inbound ? [renderNodeInfo.inAnnotations, childRenderInfo.inAnnotations] : [renderNodeInfo.outAnnotations, childRenderInfo.outAnnotations]; @@ -472,7 +477,7 @@ export class RenderGraphInformation { inbound ? { v: targetName, w: nodeName } : { v: nodeName, w: targetName }; - return <RenderMetaedgeInformation> + return <RenderMetaedgeInfo> parentNodeInfo.coreGraph.edge(adjoiningEdgeObj); }; @@ -547,7 +552,7 @@ export class RenderGraphInformation { childAnnotations.push(new Annotation( otherNode, otherRenderInfo, - new RenderMetaedgeInformation(bridgeMetaedge), + new RenderMetaedgeInfo(bridgeMetaedge), AnnotationType.SHORTCUT, inbound), this.params); return; @@ -578,7 +583,7 @@ export class RenderGraphInformation { inbound: inbound, }; bridgeContainerInfo = - new RenderNodeInformation(bridgeContainerNode); + new RenderNodeInfo(bridgeContainerNode); this.index[bridgeContainerName] = bridgeContainerInfo; coreGraph.setNode(bridgeContainerName, bridgeContainerInfo); } @@ -596,7 +601,7 @@ export class RenderGraphInformation { // BridgeNode properties. inbound: inbound, }; - bridgeNodeRenderInfo = new RenderNodeInformation(bridgeNode); + bridgeNodeRenderInfo = new RenderNodeInfo(bridgeNode); this.index[bridgeNodeName] = bridgeNodeRenderInfo; coreGraph.setNode(bridgeNodeName, bridgeNodeRenderInfo); @@ -607,7 +612,7 @@ export class RenderGraphInformation { // Create and add a bridge render metaedge. let bridgeRenderMetaedge = - new RenderMetaedgeInformation(bridgeMetaedge); + new RenderMetaedgeInfo(bridgeMetaedge); bridgeRenderMetaedge.adjoiningMetaedge = adjoiningMetaedge; inbound ? coreGraph.setEdge(bridgeNodeName, childName, bridgeRenderMetaedge) : @@ -716,7 +721,7 @@ export class RenderGraphInformation { // BridgeNode properties. inbound: inbound, }; - structuralRenderInfo = new RenderNodeInformation(bridgeNode); + structuralRenderInfo = new RenderNodeInfo(bridgeNode); structuralRenderInfo.structural = true; this.index[structuralNodeName] = structuralRenderInfo; coreGraph.setNode(structuralNodeName, structuralRenderInfo); @@ -725,7 +730,7 @@ export class RenderGraphInformation { } // Create the structural Metaedge and insert it. - let structuralMetaedgeInfo = new RenderMetaedgeInformation(null); + let structuralMetaedgeInfo = new RenderMetaedgeInfo(null); structuralMetaedgeInfo.structural = true; structuralMetaedgeInfo.weight--; // Reduce weight for dagre layout. inbound ? @@ -748,8 +753,8 @@ export class RenderGraphInformation { */ export class Annotation { node: Node; - renderNodeInfo: RenderNodeInformation; - renderMetaedgeInfo: RenderMetaedgeInformation; + renderNodeInfo: RenderNodeInfo; + renderMetaedgeInfo: RenderMetaedgeInfo; annotationType: AnnotationType; /** * Center position of annotation relative to the host @@ -791,8 +796,8 @@ export class Annotation { * @param isIn True if it is an in-annotation. False if it is an * out-annotation. */ - constructor(node: Node, renderNodeInfo: RenderNodeInformation, - renderMetaedgeInfo: RenderMetaedgeInformation, type: AnnotationType, + constructor(node: Node, renderNodeInfo: RenderNodeInfo, + renderMetaedgeInfo: RenderMetaedgeInfo, type: AnnotationType, isIn: boolean) { this.node = node; this.renderNodeInfo = renderNodeInfo; @@ -813,7 +818,7 @@ export enum AnnotationType {SHORTCUT, CONSTANT, SUMMARY, ELLIPSIS}; /** * Manages a list of annotations. Two will be used for each - * RenderNodeInformation, one for in annotations and one for out annotations. + * RenderNodeInfo, one for in annotations and one for out annotations. */ export class AnnotationList { /** @@ -857,7 +862,7 @@ export class AnnotationList { let ellipsisNode = new tf.graph.EllipsisNodeImpl(1); this.list.push(new Annotation(ellipsisNode, - new RenderNodeInformation(ellipsisNode), null, + new RenderNodeInfo(ellipsisNode), null, AnnotationType.ELLIPSIS, annotation.isIn)); } } @@ -865,7 +870,7 @@ export class AnnotationList { /** * Contains rendering information about a node in the hierarchical graph. */ -export class RenderNodeInformation { +export class RenderNodeInfo { /** Reference to the original underlying Node from the hierarchical graph. */ node: Node; /** Whether the node is expanded or not. */ @@ -875,7 +880,9 @@ export class RenderNodeInformation { * shortcuts to high-degree nodes. */ inAnnotations: AnnotationList; - /** List of rendering information about out-annotations (e.g. summary nodes) */ + /** + * List of rendering information about out-annotations (e.g. summary nodes) + */ outAnnotations: AnnotationList; // --- Params specified by layout --- // @@ -884,10 +891,25 @@ export class RenderNodeInformation { x: number; /** Center y position */ y: number; - /** Width of the node's shape */ + /** + * Total width of the node's shape, including in- and out-annotations. This + * property is used by dagre to layout the graph. + */ width: number; - /** Height of the node's shape. */ + /** + * Total height of the node's shape, including in- and out-annotations. This + * property is used by dagre to layout the graph. + */ height: number; + /** + * Size of the main box of the node, excluding in- and out-annotations. This + * property is used to draw the rectangle/ellipse shape denoting the node. + */ + coreBox: { + width: number, + height: number, + }; + /** Width of the bounding box for all in-annotations. */ inboxWidth: number; /** Width of the bounding box for all out-annotations. */ @@ -930,13 +952,9 @@ export class RenderNodeInformation { paddingRight: number; paddingBottom: number; - /** Width of the whole node including its shape and its annotations */ - outerWidth: number; - /** Height of the whole node including its shape and its annotations */ - outerHeight: number; /** - * Whether a node is extracted as source-like (having high out-degree or matching - * predefined in-extract pattern.) + * Whether a node is extracted as source-like (having high out-degree or + * matching predefined in-extract pattern.) */ isInExtract: boolean; /** @@ -991,11 +1009,9 @@ export class RenderNodeInformation { this.paddingLeft = 0; this.paddingRight = 0; this.paddingBottom = 0; - - this.outerWidth = 0; - this.outerHeight = 0; this.isInExtract = false; this.isOutExtract = false; + this.coreBox = {width: 0, height: 0}; } isInCore(): boolean { @@ -1007,7 +1023,7 @@ export class RenderNodeInformation { * Contains rendering information about a Metaedge from the underlying * hierarchical graph. It may be from either a metagraph or a bridgegraph. */ -export class RenderMetaedgeInformation { +export class RenderMetaedgeInfo { /** * Reference to the original underlying Metaedge from the hierarchical graph, * if any. This will be null for the edges which connect OpNodes to their @@ -1016,15 +1032,15 @@ export class RenderMetaedgeInformation { metaedge: Metaedge; /** - * Reference to the adjoining RenderMeteaedgeInformation from the parent's + * Reference to the adjoining RenderMeteaedgeInfo from the parent's * coreGraph. This is used during layout to determine the point at which this * edge should touch the node's bounding box. This property will be null for * edges which terminate at a node on both ends (all non-bridge edges). */ - adjoiningMetaedge: RenderMetaedgeInformation; + adjoiningMetaedge: RenderMetaedgeInfo; /** - * Most of the time, a RenderMetaedgeInformation object represents a real + * Most of the time, a RenderMetaedgeInfo object represents a real * edge between nodes in the underlying graph structure. But sometimes, an * edge only exsts for layout purposes. These structural edges are added * during buildSubhierarchy() to force dagre.layout() to put bridge nodes @@ -1044,12 +1060,12 @@ export class RenderMetaedgeInformation { * X and Y coordinate pairs of the points in the path of the edge. * @see tf.graph.node.subsceneAdjustPaths */ - points: any[]; + points: Point[]; /** * D3 selection of the group containing the path that displays this edge. */ - edgeGroup: d3.Selection<RenderMetaedgeInformation>; + edgeGroup: d3.Selection<RenderMetaedgeInfo>; constructor(metaedge: Metaedge) { this.metaedge = metaedge; @@ -1059,23 +1075,24 @@ export class RenderMetaedgeInformation { } } -function addInAnnotation(node: RenderNodeInformation, predecessor: Node, - predecessorRenderInfo: RenderNodeInformation, edge: any, - type: AnnotationType, params: RenderGraphParams): void { +function addInAnnotation(node: RenderNodeInfo, predecessor: Node, + predecessorRenderInfo: RenderNodeInfo, + edge: RenderMetaedgeInfo, type: AnnotationType, + params: RenderGraphParams): void { let annotation = new Annotation(predecessor, predecessorRenderInfo, edge, type, true); node.inAnnotations.push(annotation, params); } -function addOutAnnotation(node: RenderNodeInformation, successor: Node, - successorRenderInfo: RenderNodeInformation, edge: any, +function addOutAnnotation(node: RenderNodeInfo, successor: Node, + successorRenderInfo: RenderNodeInfo, edge: RenderMetaedgeInfo, type: AnnotationType, params: RenderGraphParams): void { let annotation = new Annotation(successor, successorRenderInfo, edge, type, false); node.outAnnotations.push(annotation, params); } -function setGraphDepth(graph: graphlib.Graph<RenderNodeInformation, any>, +function setGraphDepth(graph: graphlib.Graph<RenderNodeInfo, any>, depth: number) { _.each(graph.nodes(), nodeName => { let child = graph.node(nodeName); @@ -1084,7 +1101,7 @@ function setGraphDepth(graph: graphlib.Graph<RenderNodeInformation, any>, switch (child.node.type) { case NodeType.META: case NodeType.SERIES: - setGroupNodeDepth(<RenderGroupNodeInformation>child, depth - 1); + setGroupNodeDepth(<RenderGroupNodeInfo>child, depth - 1); break; // Do nothing for leaf } @@ -1092,35 +1109,31 @@ function setGraphDepth(graph: graphlib.Graph<RenderNodeInformation, any>, }); }; -export class RenderGroupNodeInformation extends RenderNodeInformation { +export class RenderGroupNodeInfo extends RenderNodeInfo { node: GroupNode; /** * The core graph is derived from the underlying node's metagraph, minus * the extracted source-like and sink-like nodes. */ - coreGraph: graphlib.Graph<RenderNodeInformation, RenderMetaedgeInformation>; - /** Size of the bounding box for a metanode's core graph. */ - coreBox: { - width: number, - height: number, - }; + coreGraph: graphlib.Graph<RenderNodeInfo, RenderMetaedgeInfo>; /** Size of the bounding box for a metanode's isolated in-extract children. */ inExtractBox: {width: number, height: number}; - /** Size of the bounding box for a metanode's isolated out-extract children. */ + /** + * Size of the bounding box for a metanode's isolated out-extract children. + */ outExtractBox: {width: number, height: number}; /** Array of isolated in-extract nodes. */ - isolatedInExtract: RenderNodeInformation[]; + isolatedInExtract: RenderNodeInfo[]; /** Array of isolated out-extract nodes. */ - isolatedOutExtract: RenderNodeInformation[]; + isolatedOutExtract: RenderNodeInfo[]; constructor(groupNode: GroupNode) { super(groupNode); let metagraph = groupNode.metagraph; let gl = metagraph.graph(); this.coreGraph = - createGraph<RenderNodeInformation, RenderMetaedgeInformation>( + createGraph<RenderNodeInfo, RenderMetaedgeInfo>( gl.name, GraphType.CORE, { compound: true }); - this.coreBox = {width: 0, height: 0}; this.inExtractBox = {width: 0, height: 0}; this.outExtractBox = {width: 0, height: 0}; this.isolatedInExtract = []; @@ -1128,7 +1141,7 @@ export class RenderGroupNodeInformation extends RenderNodeInformation { } } -function setGroupNodeDepth(renderInfo: RenderGroupNodeInformation, +function setGroupNodeDepth(renderInfo: RenderGroupNodeInfo, depth: number): void { if (renderInfo.coreGraph) { setGraphDepth(renderInfo.coreGraph, depth); @@ -1142,8 +1155,9 @@ function setGroupNodeDepth(renderInfo: RenderGroupNodeInformation, * @param v Source name. * @param w Sink name. */ -function createShortcut(graph: graphlib.Graph<RenderNodeInformation, {}>, - v: string, w: string, params: RenderGraphParams) { +function createShortcut( + graph: graphlib.Graph<RenderNodeInfo, RenderMetaedgeInfo>, + v: string, w: string, params: RenderGraphParams) { let src = graph.node(v); let sink = graph.node(w); let edge = graph.edge(v, w); @@ -1173,7 +1187,7 @@ function createShortcut(graph: graphlib.Graph<RenderNodeInformation, {}>, * If detachAllEdgesForHighDegree or forceDetach is true, extract all of its * edges. Otherwise, only extract all in-edges. */ -function makeOutExtract(renderNode: RenderGroupNodeInformation, n: string, +function makeOutExtract(renderNode: RenderGroupNodeInfo, n: string, params: RenderGraphParams, forceDetach?: boolean) { let graph = renderNode.coreGraph; let child = graph.node(n); @@ -1204,7 +1218,7 @@ function makeOutExtract(renderNode: RenderGroupNodeInformation, n: string, * If detachAllEdgesForHighDegree or forceDetach is true, extract all of its * edges. Otherwise, only remove all out-edges. */ -export function makeInExtract(renderNode: RenderGroupNodeInformation, n: string, +export function makeInExtract(renderNode: RenderGroupNodeInfo, n: string, params: RenderGraphParams, forceDetach?: boolean) { let graph = renderNode.coreGraph; let child = graph.node(n); @@ -1251,7 +1265,7 @@ function hasTypeIn(node: Node, types: string[]): boolean { } /** Move nodes that are speficied to be excluded out of the core graph. */ -function extractSpeficiedNodes(renderNode: RenderGroupNodeInformation, +function extractSpecifiedNodes(renderNode: RenderGroupNodeInfo, params: RenderGraphParams) { let graph = renderNode.coreGraph; _.each(graph.nodes(), n => { @@ -1268,7 +1282,7 @@ function extractSpeficiedNodes(renderNode: RenderGroupNodeInformation, } /** Remove edges from pre-defined out-extract patterns */ -function extractPredefinedSink(renderNode: RenderGroupNodeInformation, +function extractPredefinedSink(renderNode: RenderGroupNodeInfo, params: RenderGraphParams) { let graph = renderNode.coreGraph; _.each(graph.nodes(), n => { @@ -1283,7 +1297,7 @@ function extractPredefinedSink(renderNode: RenderGroupNodeInformation, } /** Remove edges from pre-defined in-extract patterns */ -function extractPredefinedSource(renderNode: RenderGroupNodeInformation, +function extractPredefinedSource(renderNode: RenderGroupNodeInfo, params: RenderGraphParams) { let graph = renderNode.coreGraph; @@ -1299,7 +1313,7 @@ function extractPredefinedSource(renderNode: RenderGroupNodeInformation, } /** Extract from nodes with in-degree > maxInDegree */ -function extractHighInDegree(renderNode: RenderGroupNodeInformation, +function extractHighInDegree(renderNode: RenderGroupNodeInfo, params: RenderGraphParams) { let graph = renderNode.coreGraph; let maxInDegree = params.maxInDegree; @@ -1313,7 +1327,8 @@ function extractHighInDegree(renderNode: RenderGroupNodeInformation, // no regular edges, in which case use the number of control edges. // This is done so that control edges don't effect if nodes are extracted // from the core graph, unless the node is only used for control. - let numEdgesToCount = _.reduce(graph.predecessors(n), (numEdgesToCount, pred) => { + let numEdgesToCount = _.reduce(graph.predecessors(n), + (numEdgesToCount, pred) => { let metaedge = graph.edge(pred, n).metaedge; return numEdgesToCount + (metaedge.numRegularEdges ? 1 : 0); }, 0); @@ -1329,7 +1344,7 @@ function extractHighInDegree(renderNode: RenderGroupNodeInformation, } /** Extract nodes with out-degree > maxOutDegree */ -function extractHighOutDegree(renderNode: RenderGroupNodeInformation, +function extractHighOutDegree(renderNode: RenderGroupNodeInfo, params: RenderGraphParams) { let graph = renderNode.coreGraph; let maxOutDegree = params.maxOutDegree; @@ -1343,7 +1358,8 @@ function extractHighOutDegree(renderNode: RenderGroupNodeInformation, // no regular edges, in which case use the number of control edges. // This is done so that control edges don't effect if nodes are extracted // from the core graph, unless the node is only used for control. - let numEdgesToCount = _.reduce(graph.successors(n), (numEdgesToCount, succ) => { + let numEdgesToCount = _.reduce(graph.successors(n), + (numEdgesToCount, succ) => { let metaedge = graph.edge(n, succ).metaedge; return numEdgesToCount + (metaedge.numRegularEdges ? 1 : 0); }, 0); @@ -1359,7 +1375,7 @@ function extractHighOutDegree(renderNode: RenderGroupNodeInformation, } /** Remove control edges from nodes that have too many control edges */ -function removeControlEdges(renderNode: RenderGroupNodeInformation, +function removeControlEdges(renderNode: RenderGroupNodeInfo, params: RenderGraphParams) { let graph = renderNode.coreGraph; @@ -1408,10 +1424,10 @@ export function mapIndexToHue(id: number): number { * @param {Object} params render Graph construction parameters. See * <tf-graph-params>'s output */ -function extractHighDegrees(renderNode: RenderGroupNodeInformation, +function extractHighDegrees(renderNode: RenderGroupNodeInfo, params: RenderGraphParams) { - extractSpeficiedNodes(renderNode, params); + extractSpecifiedNodes(renderNode, params); if (params.outExtractTypes) { extractPredefinedSink(renderNode, params); diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts index d973f75fd3..425eea0408 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts @@ -43,7 +43,7 @@ module tf.graph.scene.annotation { * @return selection of appended objects */ export function buildGroup(container, annotationData: render.AnnotationList, - d: render.RenderNodeInformation, sceneBehavior) { + d: render.RenderNodeInfo, sceneBehavior) { // Select all children and join with data. let annotationGroups = container.selectAll(function() { // using d3's selector function @@ -151,7 +151,7 @@ function addAnnotationLabel(aGroup, label, a, additionalClassNames, .append("title").text(titleText); } -function addInteraction(selection, d: render.RenderNodeInformation, +function addInteraction(selection, d: render.RenderNodeInfo, annotation: tf.graph.render.Annotation, sceneBehavior) { selection .on("mouseover", a => { @@ -190,8 +190,9 @@ function addInteraction(selection, d: render.RenderNodeInformation, * @param a annotation node data. * @param scene Polymer scene element. */ -function update(aGroup, d: render.RenderNodeInformation, a: render.Annotation, +function update(aGroup, d: render.RenderNodeInfo, a: render.Annotation, sceneBehavior) { + let cx = layout.computeCXPositionOfNodeShape(d); // Annotations that point to embedded nodes (constants,summary) // don't have a render information attached so we don't stylize these. // Also we don't stylize ellipsis annotations (the string "... and X more"). @@ -208,7 +209,7 @@ function update(aGroup, d: render.RenderNodeInformation, a: render.Annotation, // label position aGroup.select("text." + Class.Annotation.LABEL).transition().attr({ - x: d.x + a.dx + (a.isIn ? -1 : 1) * (a.width / 2 + a.labelOffset), + x: cx + a.dx + (a.isIn ? -1 : 1) * (a.width / 2 + a.labelOffset), y: d.y + a.dy }); @@ -218,23 +219,23 @@ function update(aGroup, d: render.RenderNodeInformation, a: render.Annotation, // centered with the node and horizontally centered between the arrow and the // text label. aGroup.select("use.summary").transition().attr({ - x: d.x + a.dx - 3, + x: cx + a.dx - 3, y: d.y + a.dy - 6 }); // Node position (only one of the shape selection will be non-empty.) scene.positionEllipse(aGroup.select("." + Class.Annotation.NODE + " ellipse"), - d.x + a.dx, d.y + a.dy, a.width, a.height); + cx + a.dx, d.y + a.dy, a.width, a.height); scene.positionRect(aGroup.select("." + Class.Annotation.NODE + " rect"), - d.x + a.dx, d.y + a.dy, a.width, a.height); + cx + a.dx, d.y + a.dy, a.width, a.height); scene.positionRect(aGroup.select("." + Class.Annotation.NODE + " use"), - d.x + a.dx, d.y + a.dy, a.width, a.height); + cx + a.dx, d.y + a.dy, a.width, a.height); // Edge position aGroup.select("path." + Class.Annotation.EDGE).transition().attr("d", a => { // map relative position to absolute position let points = a.points.map(p => { - return {x: p.dx + d.x, y: p.dy + d.y}; + return {x: p.dx + cx, y: p.dy + d.y}; }); return edge.interpolate(points); }); diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts index d84bb8e2ca..5ae6244e00 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts @@ -19,9 +19,9 @@ limitations under the License. module tf.graph.scene.edge { -let Scene = tf.graph.scene; // Aliased +export type EdgeData = {v: string, w: string, label: render.RenderMetaedgeInfo}; -export function getEdgeKey(edgeObj) { +export function getEdgeKey(edgeObj: EdgeData) { return edgeObj.v + tf.graph.EDGE_KEY_DELIM + edgeObj.w; } @@ -45,8 +45,9 @@ export function getEdgeKey(edgeObj) { * @return selection of the created nodeGroups */ export function buildGroup(sceneGroup, - graph: graphlib.Graph<tf.graph.render.RenderNodeInformation, - tf.graph.render.RenderMetaedgeInformation>, sceneBehavior) { + graph: graphlib.Graph<tf.graph.render.RenderNodeInfo, + tf.graph.render.RenderMetaedgeInfo>, sceneBehavior) { + let edges: EdgeData[] = []; let edgeData = _.reduce(graph.edges(), (edges, edgeObj) => { let edgeLabel = graph.edge(edgeObj); edges.push({ @@ -55,11 +56,10 @@ export function buildGroup(sceneGroup, label: edgeLabel }); return edges; - }, []); + }, edges); let container = scene.selectOrCreateChild(sceneGroup, "g", Class.Edge.CONTAINER); - let containerNode = container.node(); // Select all children and join with data. // (Note that all children of g.edges are g.edge) @@ -76,7 +76,7 @@ export function buildGroup(sceneGroup, .append("g") .attr("class", Class.Edge.GROUP) .attr("data-edge", getEdgeKey) - .each(function(d) { + .each(function(d: EdgeData) { let edgeGroup = d3.select(this); d.label.edgeGroup = edgeGroup; // index node group for quick highlighting @@ -108,11 +108,11 @@ export function buildGroup(sceneGroup, * For a given d3 selection and data object, create a path to represent the * edge described in d.label. * - * If d.label is defined, it will be a RenderMetaedgeInformation instance. It + * If d.label is defined, it will be a RenderMetaedgeInfo instance. It * will sometimes be undefined, for example for some Annotation edges for which * there is no underlying Metaedge in the hierarchical graph. */ -export function appendEdge(edgeGroup, d, sceneBehavior, edgeClass?) { +export function appendEdge(edgeGroup, d: EdgeData, sceneBehavior, edgeClass?) { edgeClass = edgeClass || Class.Edge.LINE; // set default type if (d.label && d.label.structural) { @@ -123,11 +123,16 @@ export function appendEdge(edgeGroup, d, sceneBehavior, edgeClass?) { .attr("class", edgeClass); }; +export let interpolate = d3.svg.line<{x: number, y: number}>() + .interpolate("basis") + .x((d) => { return d.x; }) + .y((d) => { return d.y; }); + /** * Returns a tween interpolator for the endpoint of an edge path. */ -function getEdgePathInterpolator(d, i, a) { - let renderMetaedgeInfo = d.label; +function getEdgePathInterpolator(d: EdgeData, i: number, a: string) { + let renderMetaedgeInfo = <render.RenderMetaedgeInfo> d.label; let adjoiningMetaedge = renderMetaedgeInfo.adjoiningMetaedge; if (!adjoiningMetaedge) { return d3.interpolate(a, interpolate(renderMetaedgeInfo.points)); @@ -162,11 +167,6 @@ function getEdgePathInterpolator(d, i, a) { }; } -export let interpolate = d3.svg.line() - .interpolate("basis") - .x((d: any) => { return d.x; }) - .y((d: any) => { return d.y; }); - function position(d) { d3.select(this).select("path." + Class.Edge.LINE) .each(function(d) { @@ -179,10 +179,9 @@ function position(d) { * For a given d3 selection and data object, mark the edge as a control * dependency if it contains only control edges. * - * d's label property will be a RenderMetaedgeInformation object. + * d's label property will be a RenderMetaedgeInfo object. */ -function stylize(edgeGroup, d, stylize) { - let a; +function stylize(edgeGroup, d: EdgeData, stylize) { let metaedge = d.label.metaedge; edgeGroup .select("path." + Class.Edge.LINE) diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts index c3cf3d684b..32566c99ef 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts @@ -142,6 +142,10 @@ export class Minimap { * was updated (e.g. when a node was expanded). */ update(): void { + // The origin hasn't rendered yet. Ignore making an update. + if (this.zoomG.childElementCount === 0) { + return; + } let $download = d3.select("#graphdownload"); this.download = <HTMLLinkElement>$download.node(); $download.on("click", d => { diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts index bb1d1fdcdc..c5781a2fd0 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts @@ -65,7 +65,7 @@ module tf.graph.scene.node { * @return selection of the created nodeGroups */ export function buildGroup(sceneGroup, - nodeData: render.RenderNodeInformation[], sceneBehavior) { + nodeData: render.RenderNodeInfo[], sceneBehavior) { let container = scene.selectOrCreateChild(sceneGroup, "g", Class.Node.CONTAINER); // Select all children and join with data. @@ -76,7 +76,7 @@ export function buildGroup(sceneGroup, // (It's not listed in the d3 wiki.) return this.childNodes; // this here refers to container.node() }) - .data(nodeData, (d: any) => { + .data(nodeData, (d) => { // make sure that we don't have to swap shape type return d.node.name + ":" + d.node.type; }); @@ -124,7 +124,8 @@ export function buildGroup(sceneGroup, addInteraction(shape, d, sceneBehavior); // build subscene on the top - subsceneBuild(nodeGroup, d, sceneBehavior); + subsceneBuild(nodeGroup, <render.RenderGroupNodeInfo> d, + sceneBehavior); stylize(nodeGroup, d, sceneBehavior); position(nodeGroup, d, sceneBehavior); @@ -168,7 +169,7 @@ export function buildGroup(sceneGroup, * not have a subscene. */ function subsceneBuild(nodeGroup, - renderNodeInfo: render.RenderGroupNodeInformation, sceneBehavior) { + renderNodeInfo: render.RenderGroupNodeInfo, sceneBehavior) { if (renderNodeInfo.node.isGroupNode) { if (renderNodeInfo.expanded) { // Recursively build the subscene. @@ -184,7 +185,7 @@ function subsceneBuild(nodeGroup, /** * Translate the subscene of the given node group */ -function subscenePosition(nodeGroup, d: render.RenderNodeInformation) { +function subscenePosition(nodeGroup, d: render.RenderNodeInfo) { let x0 = d.x - d.width / 2.0 + d.paddingLeft; let y0 = d.y - d.height / 2.0 + d.paddingTop; @@ -199,7 +200,7 @@ function subscenePosition(nodeGroup, d: render.RenderNodeInformation) { * @param d Info about the node being rendered. * @param sceneBehavior parent scene module. */ -function addButton(selection, d: render.RenderNodeInformation, sceneBehavior) { +function addButton(selection, d: render.RenderNodeInfo, sceneBehavior) { let group = scene.selectOrCreateChild( selection, "g", Class.Node.BUTTON_CONTAINER); scene.selectOrCreateChild(group, "circle", Class.Node.BUTTON_CIRCLE); @@ -224,7 +225,7 @@ function addButton(selection, d: render.RenderNodeInformation, sceneBehavior) { * don't need interaction as their surrounding shape has interaction, and if * given interaction would cause conflicts with the expand/collapse button. */ -function addInteraction(selection, d: render.RenderNodeInformation, +function addInteraction(selection, d: render.RenderNodeInfo, sceneBehavior, disableInteraction?: boolean) { if (disableInteraction) { selection.attr("pointer-events", "none"); @@ -282,7 +283,7 @@ export function getContextMenu(node: Node, sceneBehavior) { * @param renderNodeInfo The render node information for the label. * @param sceneBehavior parent scene module. */ -function labelBuild(nodeGroup, renderNodeInfo: render.RenderNodeInformation, +function labelBuild(nodeGroup, renderNodeInfo: render.RenderNodeInfo, sceneBehavior) { let namePath = renderNodeInfo.node.name.split("/"); let text = namePath[namePath.length - 1]; @@ -320,14 +321,15 @@ function getLabelFontScale(sceneBehavior) { } return fontScale; } + /** * Set label position of a given node group */ -function labelPosition(nodeGroup, d: render.RenderNodeInformation, +function labelPosition(nodeGroup, cx: number, cy: number, yOffset: number) { scene.selectChild(nodeGroup, "text", Class.Node.LABEL).transition() - .attr("x", d.x) - .attr("y", d.y + yOffset); + .attr("x", cx) + .attr("y", cy + yOffset); }; /** @@ -335,7 +337,7 @@ function labelPosition(nodeGroup, d: render.RenderNodeInformation, * as the shape's data. * * @param nodeGroup - * @param d RenderNodeInformation + * @param d Render node information. * @param nodeClass class for the element. * @param before Reference DOM node for insertion. * @return Selection of the shape. @@ -353,7 +355,7 @@ export function buildShape(nodeGroup, d, nodeClass: string, before?) { case NodeType.SERIES: // Choose the correct stamp to use to represent this series. let stampType = "annotation"; - let groupNodeInfo = <render.RenderGroupNodeInformation>d; + let groupNodeInfo = <render.RenderGroupNodeInfo>d; if (groupNodeInfo.coreGraph) { stampType = groupNodeInfo.node.hasNonControlEdges ? "vertical" : "horizontal"; @@ -377,7 +379,7 @@ export function buildShape(nodeGroup, d, nodeClass: string, before?) { return shapeGroup; }; -export function nodeClass(d: render.RenderNodeInformation) { +export function nodeClass(d: render.RenderNodeInfo) { switch (d.node.type) { case NodeType.OP: return Class.OPNODE; @@ -394,43 +396,43 @@ export function nodeClass(d: render.RenderNodeInformation) { }; /** Modify node and its subscene and its label's positional attributes */ -function position(nodeGroup, d: render.RenderNodeInformation, sceneBehavior) { +function position(nodeGroup, d: render.RenderNodeInfo, sceneBehavior) { let shapeGroup = scene.selectChild(nodeGroup, "g", Class.Node.SHAPE); + let cx = layout.computeCXPositionOfNodeShape(d); switch (d.node.type) { case NodeType.OP: { // position shape let shape = scene.selectChild(shapeGroup, "ellipse"); - scene.positionEllipse(shape, d.x, d.y, d.width, d.height); - labelPosition(nodeGroup, d, d.labelOffset); + scene.positionEllipse(shape, cx, d.y, d.coreBox.width, d.coreBox.height); + labelPosition(nodeGroup, cx, d.y, d.labelOffset); break; } case NodeType.META: { // position shape let shape = scene.selectChild(shapeGroup, "rect"); - scene.positionRect(shape, d.x, d.y, d.width, d.height); - if (d.expanded) { + scene.positionRect(shape, d.x, d.y, d.width, d.height); subscenePosition(nodeGroup, d); - // put label on top - labelPosition(nodeGroup, d, + labelPosition(nodeGroup, cx, d.y, - d.height / 2 + d.labelHeight / 2); } else { - labelPosition(nodeGroup, d, 0); + scene.positionRect(shape, cx, d.y, d.coreBox.width, d.coreBox.height); + labelPosition(nodeGroup, cx, d.y, 0); } break; } case NodeType.SERIES: { let shape = scene.selectChild(shapeGroup, "use"); - scene.positionRect(shape, d.x, d.y, d.width, d.height); if (d.expanded) { - subscenePosition(nodeGroup, d); - + scene.positionRect(shape, d.x, d.y, d.width, d.height); + subscenePosition(nodeGroup, d); // put label on top - labelPosition(nodeGroup, d, + labelPosition(nodeGroup, cx, d.y, - d.height / 2 + d.labelHeight / 2); } else { - labelPosition(nodeGroup, d, d.labelOffset); + scene.positionRect(shape, cx, d.y, d.coreBox.width, d.coreBox.height); + labelPosition(nodeGroup, cx, d.y, d.labelOffset); } } case NodeType.BRIDGE: { @@ -455,7 +457,7 @@ export enum ColorBy { STRUCTURE, DEVICE, COMPUTE_TIME, MEMORY }; * option. */ export function getFillForNode(templateIndex, colorBy, - renderInfo: render.RenderNodeInformation, isExpanded: boolean): string { + renderInfo: render.RenderNodeInfo, isExpanded: boolean): string { let colorParams = tf.graph.render.MetanodeColors; switch (colorBy) { case ColorBy.STRUCTURE: @@ -493,7 +495,7 @@ export function getFillForNode(templateIndex, colorBy, linearGradient.selectAll("*").remove(); let cumulativeProportion = 0; // For each device, create a stop using the proportion of that device. - _.each(renderInfo.deviceColors, (d: any) => { + _.each(renderInfo.deviceColors, d => { let color = d.color; linearGradient.append("stop") .attr("offset", cumulativeProportion) @@ -522,7 +524,7 @@ export function getFillForNode(templateIndex, colorBy, * Modify node style by toggling class and assign attributes (only for things * that can't be done in css). */ -export function stylize(nodeGroup, renderInfo: render.RenderNodeInformation, +export function stylize(nodeGroup, renderInfo: render.RenderNodeInfo, sceneBehavior, nodeClass?) { nodeClass = nodeClass || Class.Node.SHAPE; let isHighlighted = sceneBehavior.isNodeHighlighted(renderInfo.node.name); diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts index 6e97e904da..24c16e31ee 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts @@ -118,8 +118,8 @@ export function fit(svg, zoomG, d3zoom, callback) { * provided node. */ export function panToNode(nodeName: String, svg, zoomG, d3zoom): boolean { - let node: any = d3.selectAll("[data-name='" + nodeName + "']." - + Class.Node.GROUP)[0][0]; + let node = <SVGAElement> d3.select("[data-name='" + nodeName + "']." + + Class.Node.GROUP).node(); if (!node) { return false; } @@ -247,7 +247,7 @@ export function selectChild(container, tagName: string, className?: string) { * @param sceneClass class attribute of the scene (default="scene"). */ export function buildGroup(container, - renderNode: render.RenderGroupNodeInformation, + renderNode: render.RenderGroupNodeInfo, sceneBehavior, sceneClass: string) { sceneClass = sceneClass || Class.Scene.GROUP; @@ -301,8 +301,7 @@ export function buildGroup(container, // Fade in the scene group if it didn't already exist. if (isNewSceneGroup) { - sceneGroup.attr("opacity", 0) - .transition().attr("opacity", 1); + sceneGroup.attr("opacity", 0).transition().attr("opacity", 1); } return sceneGroup; @@ -315,7 +314,7 @@ export function buildGroup(container, * @param sceneGroup * @param renderNode render node of a metanode or series node. */ -function position(sceneGroup, renderNode: render.RenderGroupNodeInformation) { +function position(sceneGroup, renderNode: render.RenderGroupNodeInfo) { // Translate scenes down by the label height so that when showing graphs in // expanded metanodes, the graphs are below the labels. Do not shift them // down for series nodes as series nodes don't have labels inside of their @@ -324,14 +323,13 @@ function position(sceneGroup, renderNode: render.RenderGroupNodeInformation) { 0 : layout.PARAMS.subscene.meta.labelHeight; // core - translate(selectChild(sceneGroup, "g", Class.Scene.CORE), - 0, yTranslate); + translate(selectChild(sceneGroup, "g", Class.Scene.CORE), 0, yTranslate); // in-extract - let inExtractX = renderNode.coreBox.width === 0 ? - 0 : renderNode.coreBox.width; let hasInExtract = renderNode.isolatedInExtract.length > 0; if (hasInExtract) { + let inExtractX = renderNode.coreBox.width - + renderNode.inExtractBox.width / 2 - renderNode.outExtractBox.width; translate(selectChild(sceneGroup, "g", Class.Scene.INEXTRACT), inExtractX, yTranslate); } @@ -339,8 +337,8 @@ function position(sceneGroup, renderNode: render.RenderGroupNodeInformation) { // out-extract let hasOutExtract = renderNode.isolatedOutExtract.length > 0; if (hasOutExtract) { - let outExtractX = inExtractX + renderNode.inExtractBox.width - + renderNode.extractXOffset; + let outExtractX = renderNode.coreBox.width - + renderNode.outExtractBox.width / 2; translate(selectChild(sceneGroup, "g", Class.Scene.OUTEXTRACT), outExtractX, yTranslate); } @@ -355,6 +353,10 @@ export function addGraphClickListener(graphGroup, sceneBehavior) { /** Helper for adding transform: translate(x0, y0) */ export function translate(selection, x0: number, y0: number) { + // If it is already placed on the screen, make it a transition. + if (selection.attr("transform") != null) { + selection = selection.transition("position"); + } selection.attr("transform", "translate(" + x0 + "," + y0 + ")"); }; @@ -382,12 +384,16 @@ export function positionRect(rect, cx: number, cy: number, width: number, * @param renderNode the render node of the group node to position * the button on. */ -export function positionButton(button, - renderNode: render.RenderNodeInformation) { +export function positionButton(button, renderNode: render.RenderNodeInfo) { + let cx = layout.computeCXPositionOfNodeShape(renderNode); // Position the button in the top-right corner of the group node, // with space given the draw the button inside of the corner. - let x = renderNode.x + renderNode.width / 2 - 6; - let y = renderNode.y - renderNode.height / 2 + 6; + let width = renderNode.expanded ? + renderNode.width : renderNode.coreBox.width; + let height = renderNode.expanded ? + renderNode.height : renderNode.coreBox.height; + let x = cx + width / 2 - 6; + let y = renderNode.y - height / 2 + 6; // For unexpanded series nodes, the button has special placement due // to the unique visuals of this group node. if (renderNode.node.type === NodeType.SERIES && !renderNode.expanded) { diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts index 41fbbbb9ff..0423e1c863 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts @@ -57,13 +57,13 @@ export function detect(h, verifyTemplate): {[templateId: string]: string[]} { * @return Unique string for a metanode based on depth, |V|, |E| and * op type histogram. */ - function getSignature(metanode) { +function getSignature(metanode) { // depth=<number> |V|=<number> |E|=<number> - let props = _.map({ - "depth": metanode.depth, - "|V|": metanode.metagraph.nodes().length, - "|E|": metanode.metagraph.edges().length - }, function(v, k) { return k + "=" + v; }).join(" "); + let props = _.map({ + "depth": metanode.depth, + "|V|": metanode.metagraph.nodes().length, + "|E|": metanode.metagraph.edges().length + }, function(v, k) { return k + "=" + v; }).join(" "); // optype1=count1,optype2=count2 let ops = _.map(metanode.opHistogram, function(count, op) { @@ -84,7 +84,8 @@ export function detect(h, verifyTemplate): {[templateId: string]: string[]} { */ function clusterSimilarSubgraphs(h: hierarchy.Hierarchy) { /** a dict from metanode.signature() => Array of tf.graph.Groups */ - let hashDict = _(h.getNodeMap()).reduce(function(hash, node: OpNode|Metanode, name) { + let hashDict = _(h.getNodeMap()).reduce( + (hash, node: OpNode|Metanode, name) => { if (node.type !== NodeType.META) { return hash; } @@ -156,8 +157,8 @@ function groupTemplateAndAssignId(nnGroups, verifyTemplate) { }, result); } -function sortNodes(names: string[], graph: graphlib.Graph<Metanode|OpNode, Metaedge>, - prefix: string) { +function sortNodes(names: string[], + graph: graphlib.Graph<Metanode|OpNode, Metaedge>, prefix: string) { return _.sortByAll(names, function(name) { let node = graph.node(name); @@ -181,7 +182,8 @@ function sortNodes(names: string[], graph: graphlib.Graph<Metanode|OpNode, Metae }); } -function isSimilarSubgraph(g1: graphlib.Graph<any, any>, g2: graphlib.Graph<any, any>) { +function isSimilarSubgraph(g1: graphlib.Graph<any, any>, + g2: graphlib.Graph<any, any>) { if (!tf.graph.hasSimilarDegreeSequence(g1, g2)) { return false; } @@ -273,25 +275,27 @@ function isSimilarSubgraph(g1: graphlib.Graph<any, any>, g2: graphlib.Graph<any, /** * Returns if two nodes have identical structure. */ - function isSimilarNode(n1: OpNode|Metanode|SeriesNode, n2: OpNode|Metanode|SeriesNode): boolean { +function isSimilarNode(n1: OpNode|Metanode|SeriesNode, + n2: OpNode|Metanode|SeriesNode): boolean { if (n1.type === NodeType.META) { // compare metanode let metanode1 = <Metanode> n1; let metanode2 = <Metanode> n2; - return metanode1.templateId && metanode2.templateId && metanode1.templateId === metanode2.templateId; + return metanode1.templateId && metanode2.templateId && + metanode1.templateId === metanode2.templateId; } else if (n1.type === NodeType.OP && n2.type === NodeType.OP) { // compare leaf node return (<OpNode>n1).op === (<OpNode>n2).op; } else if (n1.type === NodeType.SERIES && n2.type === NodeType.SERIES) { // compare series node sizes and operations // (only need to check one op as all op nodes are identical in series) - let seriesnode1 = <SeriesNode> n1; - let seriesnode2 = <SeriesNode> n2; - let seriesnode1Count = seriesnode1.metagraph.nodeCount(); - return (seriesnode1Count === seriesnode2.metagraph.nodeCount() && + let sn1 = <SeriesNode> n1; + let sn2 = <SeriesNode> n2; + let seriesnode1Count = sn1.metagraph.nodeCount(); + return (seriesnode1Count === sn2.metagraph.nodeCount() && (seriesnode1Count === 0 || - ((<OpNode>seriesnode1.metagraph.node(seriesnode1.metagraph.nodes()[0])).op === - (<OpNode>seriesnode2.metagraph.node(seriesnode2.metagraph.nodes()[0])).op))); + ((<OpNode>sn1.metagraph.node(sn1.metagraph.nodes()[0])).op === + (<OpNode>sn2.metagraph.node(sn2.metagraph.nodes()[0])).op))); } return false; } diff --git a/tensorflow/tensorboard/components/tf-graph/tf-graph-icon.html b/tensorflow/tensorboard/components/tf-graph/tf-graph-icon.html index 9204b392e3..765803a6a9 100644 --- a/tensorflow/tensorboard/components/tf-graph/tf-graph-icon.html +++ b/tensorflow/tensorboard/components/tf-graph/tf-graph-icon.html @@ -83,7 +83,7 @@ * Render node information associated with this node. Optional. If * specified, this is only used when computing the fill of the icon * element. - * @type {tf.graph.render.RenderNodeInformation} + * @type {tf.graph.render.RenderNodeInfo} */ renderInfo: { type: Object, diff --git a/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html b/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html index 996735679f..e9e6f6ce02 100644 --- a/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html +++ b/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html @@ -98,10 +98,7 @@ Polymer({ properties: { renderHierarchy: Object, name: String, - colorBy: { - type: String, - observer: '_colorByChanged' - }, + colorBy: String, /** @type {d3_zoom} d3 zoom object */ _zoom: Object, highlightedNode: { @@ -201,6 +198,7 @@ Polymer({ progress: Object }, observers: [ + '_colorByChanged(colorBy, renderHierarchy)', '_buildAndFit(renderHierarchy)' ], getNode: function(nodeName) { @@ -234,7 +232,7 @@ Polymer({ this.templateIndex = renderHierarchy.hierarchy.getTemplateIndex(); tf.time('tf-graph-scene (layout):', function() { // layout the scene for this meta / series node - tf.graph.layout.scene(renderHierarchy.root, this); + tf.graph.layout.layoutScene(renderHierarchy.root, this); }.bind(this)); tf.time('tf-graph-scene (build scene):', function() { diff --git a/tensorflow/tensorboard/components/tf-graph/tf-graph.html b/tensorflow/tensorboard/components/tf-graph/tf-graph.html index ffb737f761..e35664ae7f 100644 --- a/tensorflow/tensorboard/components/tf-graph/tf-graph.html +++ b/tensorflow/tensorboard/components/tf-graph/tf-graph.html @@ -112,8 +112,8 @@ Polymer({ // and thus mistakenly pass non-metanode to this module. return; } - var renderGraph = new tf.graph.render.RenderGraphInformation( - graphHierarchy, params); + var renderGraph = new tf.graph.render.RenderGraphInfo(graphHierarchy, + params); // Producing the 'color by' parameters to be consumed // by the tf-graph-controls panel. It contains information about the // min and max values and their respective colors, as well as list diff --git a/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html b/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html index 1d211f9b2a..d6748ff167 100644 --- a/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html +++ b/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html @@ -20,11 +20,11 @@ allows the user to toggle between various dashboards. <paper-toolbar id="toolbar"> <div id="toolbar-content"> <div class="toolbar-title">TensorBoard</div> - <paper-tabs selected="0" noink class="tabs"> - <paper-tab on-click="chooseEvents">Events</paper-tab> - <paper-tab on-click="chooseImages">Images</paper-tab> - <paper-tab on-click="chooseGraphs">Graph</paper-tab> - <paper-tab on-click="chooseHistograms">Histograms</paper-tab> + <paper-tabs selected="0" noink class="tabs" id="tabs"> + <paper-tab data-mode="events" on-click="changeMode">Events</paper-tab> + <paper-tab data-mode="images" on-click="changeMode">Images</paper-tab> + <paper-tab data-mode="graphs" on-click="changeMode">Graph</paper-tab> + <paper-tab data-mode="histograms" on-click="changeMode">Histograms</paper-tab> </paper-tabs> </div> </paper-toolbar> @@ -100,17 +100,9 @@ allows the user to toggle between various dashboards. value: "events", }, }, - chooseEvents: function() { - this.mode = "events"; - }, - chooseImages: function() { - this.mode = "images"; - }, - chooseGraphs: function() { - this.mode = "graphs"; - }, - chooseHistograms: function() { - this.mode = "histograms"; + changeMode: function(ev) { + var mode = ev.target.parentElement.getAttribute('data-mode'); + this._changeMode(mode, true); }, eventDashboard: function(mode) { return mode === "events"; @@ -123,7 +115,47 @@ allows the user to toggle between various dashboards. }, histogramDashboard: function(mode) { return mode === "histograms"; - } + }, + loadPreviousMode: function() { + this._changeMode(this._getModeAndPath().mode, false); + }, + ready: function() { + this._changeMode(this._getModeAndPath().mode, true); + + var tb = this; + window.addEventListener('popstate', function(){ + tb.loadPreviousMode(); + }); + }, + _changeMode: function(mode, isNewState) { + this.mode = mode; + + // Change the selected tab + this.$.tabs.selected = this._tabs().indexOf(mode); + + if (isNewState){ + var basePath = this._getModeAndPath().path; + basePath += basePath[basePath.length - 1] == '/' ? '' : '/'; + history.pushState(null, null, basePath + mode); + } + }, + _getModeAndPath: function() { + // Returns a {mode: 'mode', path: 'basePathWithoutMode'} + // The mode is assumed to be at the end of the pathname. + var tokens = window.location.pathname.split('/'); + var mode = tokens[tokens.length - 1]; + + if (_.contains(this._tabs(), mode)) { + return {mode: mode, path: tokens.slice(0, tokens.length-1).join('/')}; + } else { + // Unrecognized modes turn into events + return {mode: 'events', path: tokens.join('/')}; + } + }, + _tabs: function() { + var elts = Array.prototype.slice.call(this.querySelectorAll('paper-tab')); + return elts.map(function(elt){ return elt.getAttribute('data-mode')}); + }, }); </script> </dom-module> diff --git a/tensorflow/tensorboard/dist/index.html b/tensorflow/tensorboard/dist/index.html index e75a87a4f3..a72f6a62a6 100644 --- a/tensorflow/tensorboard/dist/index.html +++ b/tensorflow/tensorboard/dist/index.html @@ -33,6 +33,7 @@ <link rel="import" href="external/paper-styles/paper-styles.html"> <link rel="import" href="external/paper-toggle-button/paper-toggle-button.html"> <link rel="import" href="external/paper-toolbar/paper-toolbar.html"> + <link rel="import" href="external/paper-tabs/paper-tabs.html"> <link rel="import" href="dist/tf-tensorboard.html"> diff --git a/tensorflow/tensorboard/dist/tf-tensorboard.html b/tensorflow/tensorboard/dist/tf-tensorboard.html index 2aa1e46ca3..7d01603e1c 100644 --- a/tensorflow/tensorboard/dist/tf-tensorboard.html +++ b/tensorflow/tensorboard/dist/tf-tensorboard.html @@ -1,3 +1,4 @@ +// AUTOGENERATED FILE - DO NOT MODIFY <html><head><meta charset="UTF-8"> @@ -11,6 +12,8 @@ --tb-orange-strong: #f3913e; --tb-grey-darker: #e2e2e2; --tb-grey-lighter: #f3f3f3; + --tb-ui-dark-accent: #757575; + --tb-ui-light-accent: #e0e0e0; } </style> @@ -32,8 +35,21 @@ +<script>/* Copyright 2015 Google Inc. All Rights Reserved. -<script>/// <reference path="../../../typings/tsd.d.ts" /> +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// <reference path="../../../typings/tsd.d.ts" /> var tf; (function (tf) { /** @@ -112,7 +128,21 @@ var tf; tf.escapeQuerySelector = escapeQuerySelector; })(tf || (tf = {})); // close module tf </script> -<script>/// <reference path="../../../typings/tsd.d.ts" /> +<script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// <reference path="../../../typings/tsd.d.ts" /> /// <reference path="common.ts" /> var tf; (function (tf) { @@ -120,7 +150,6 @@ var tf; (function (graph_1) { /** Delimiter used in node names to denote namespaces. */ graph_1.NAMESPACE_DELIM = "/"; - var FULL_GRAPH_NAME = "fullGraph"; graph_1.ROOT_NAME = "__root__"; // Separator between the source and the destination name of the edge. graph_1.EDGE_KEY_DELIM = "--"; @@ -145,6 +174,14 @@ var tf; })(graph_1.NodeType || (graph_1.NodeType = {})); var NodeType = graph_1.NodeType; ; + /** Indicates if a node is to be included in the main graph when rendered. */ + (function (InclusionType) { + InclusionType[InclusionType["INCLUDE"] = 0] = "INCLUDE"; + InclusionType[InclusionType["EXCLUDE"] = 1] = "EXCLUDE"; + InclusionType[InclusionType["UNSPECIFIED"] = 2] = "UNSPECIFIED"; + })(graph_1.InclusionType || (graph_1.InclusionType = {})); + var InclusionType = graph_1.InclusionType; + ; /** * A SlimGraph is inspired by graphlib.Graph, but having only the functionality * that we need. @@ -170,6 +207,7 @@ var tf; this.parentNode = null; this.stats = null; this.setNumMoreNodes(numNodes); + this.include = InclusionType.UNSPECIFIED; } EllipsisNodeImpl.prototype.setNumMoreNodes = function (numNodes) { this.numMoreNodes = numNodes; @@ -190,8 +228,8 @@ var tf; * @param rawNode The raw node. * @param normalizedInputs An array of normalized * inputs that denote the incoming edges to the current node. Each input - * contains the normalized name of the source node, whether it has a number - * part and whether it is a control dependency. + * contains the normalized name of the source node, whether it has a + * number part and whether it is a control dependency. */ function OpNodeImpl(rawNode, normalizedInputs) { this.op = rawNode.op; @@ -206,6 +244,7 @@ var tf; this.inEmbeddings = []; this.outEmbeddings = []; this.parentNode = null; + this.include = InclusionType.UNSPECIFIED; } return OpNodeImpl; })(); @@ -216,8 +255,8 @@ var tf; } graph_1.createMetanode = createMetanode; /** - * Joins the information from the stats file (memory, compute time) with the graph - * information. + * Joins the information from the stats file (memory, compute time) with the + * graph information. */ function joinStatsInfoWithGraph(graph, statsJson) { _.each(statsJson.devStats, function (stats) { @@ -274,6 +313,7 @@ var tf; }; return NodeStats; })(); + graph_1.NodeStats = NodeStats; var MetanodeImpl = (function () { /** A label object for meta-nodes in the graph hierarchy */ function MetanodeImpl(name, opt) { @@ -304,6 +344,7 @@ var tf; this.parentNode = null; this.stats = new NodeStats(0, 0, null); this.hasNonControlEdges = false; + this.include = InclusionType.UNSPECIFIED; } MetanodeImpl.prototype.getFirstChild = function () { return this.metagraph.node(this.metagraph.nodes()[0]); @@ -404,6 +445,7 @@ var tf; this.deviceHistogram = {}; this.hasNonControlEdges = false; this.stats = new NodeStats(0, 0, null); + this.include = InclusionType.UNSPECIFIED; } return SeriesNodeImpl; })(); @@ -655,7 +697,8 @@ var tf; ; /** * Returns the hierarchical path of the current node, based on the node's name. - * For example, if the name is 'a/b/c', the returned path is ['a', 'a/b', 'a/b/c']. + * For example, if the name is 'a/b/c', the returned path is + * ['a', 'a/b', 'a/b/c']. */ function getHierarchicalPath(name, seriesNames) { var path = []; @@ -679,10 +722,38 @@ var tf; } graph_1.getHierarchicalPath = getHierarchicalPath; ; + /** + * Returns the string for the node inclusion toggle button, dependant + * on the provided current InclusionType. + */ + function getIncludeNodeButtonString(include) { + if (include === tf.graph.InclusionType.EXCLUDE) { + return "Add to main graph"; + } + else { + return "Remove from main graph"; + } + } + graph_1.getIncludeNodeButtonString = getIncludeNodeButtonString; + ; })(graph = tf.graph || (tf.graph = {})); })(tf || (tf = {})); // close module tf.graph </script> -<script>/// <reference path="../../../typings/tsd.d.ts" /> +<script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// <reference path="../../../typings/tsd.d.ts" /> /// <reference path="common.ts" /> var tf; (function (tf) { @@ -871,7 +942,21 @@ var tf; })(graph = tf.graph || (tf.graph = {})); })(tf || (tf = {})); // Close module tf.graph.parser. </script> -<script>/// <reference path="graph.ts" /> +<script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// <reference path="graph.ts" /> /// <reference path="template.ts" /> /** * Package for the Graph Hierarchy for TensorFlow graph. @@ -882,7 +967,6 @@ var tf; (function (graph_1) { var hierarchy; (function (hierarchy_1) { - var LOG_PREFIX_MSG = "Graph hierarchy: "; /** * Class for the Graph Hierarchy for TensorFlow graph. */ @@ -1139,6 +1223,17 @@ var tf; } return ordering; }; + /** + * Returns a d3 Ordinal function that can be used to look up the index of + * a node based on its template id. + */ + HierarchyImpl.prototype.getTemplateIndex = function () { + var templateNames = d3.keys(this.templates); + var templateIndex = d3.scale.ordinal() + .domain(templateNames) + .range(d3.range(0, templateNames.length)); + return function (templateId) { return templateIndex(templateId); }; + }; return HierarchyImpl; })(); /** @@ -1192,8 +1287,8 @@ var tf; }, tracker) .then(function () { return tf.runAsyncTask("Detect series", 20, function () { - if (params.groupSeries) { - groupSeries(h.root, h, seriesNames); + if (params.seriesNodeMinSize > 0) { + groupSeries(h.root, h, seriesNames, params.seriesNodeMinSize); } }, tracker); }) @@ -1251,7 +1346,8 @@ var tf; } parent = child; } - // Assuming node name is 'a/b/c', assign the OpNode as a child of the metanode 'a/b'. + // Assuming node name is 'a/b/c', assign the OpNode as a child of the + // metanode 'a/b'. h.setNode(node.name, node); node.parentNode = parent; parent.metagraph.setNode(node.name, node); @@ -1333,14 +1429,17 @@ var tf; * * @param metanode * @param hierarchy - * @return A dictionary from node name to series node name that contains the node + * @param threshold If the series has this many nodes or more, then group them + * into a series. + * @return A dictionary from node name to series node name that contains the + * node. */ - function groupSeries(metanode, hierarchy, seriesNames) { + function groupSeries(metanode, hierarchy, seriesNames, threshold) { var metagraph = metanode.metagraph; _.each(metagraph.nodes(), function (n) { var child = metagraph.node(n); if (child.type === tf.graph.NodeType.META) { - groupSeries(child, hierarchy, seriesNames); + groupSeries(child, hierarchy, seriesNames, threshold); } }); var clusters = clusterNodes(metagraph); @@ -1349,8 +1448,9 @@ var tf; // metagraph. _.each(seriesDict, function (seriesNode, seriesName) { var nodeMemberNames = seriesNode.metagraph.nodes(); - var firstMember = seriesNode.metagraph.node(nodeMemberNames[0]); - var seriesType = firstMember.type; + if (nodeMemberNames.length < threshold) { + return; + } hierarchy.setNode(seriesName, seriesNode); // add to the index metagraph.setNode(seriesName, seriesNode); _.each(nodeMemberNames, function (n) { @@ -1453,7 +1553,8 @@ var tf; var seriesNodes = [seriesInfoArray[0]]; for (var index = 1; index < seriesInfoArray.length; index++) { var nextNode = seriesInfoArray[index]; - if (nextNode.clusterId === seriesNodes[seriesNodes.length - 1].clusterId + 1) { + if (nextNode.clusterId === seriesNodes[seriesNodes.length - 1].clusterId + + 1) { seriesNodes.push(nextNode); continue; } @@ -1489,14 +1590,28 @@ var tf; })(graph = tf.graph || (tf.graph = {})); })(tf || (tf = {})); // close module tf.graph.hierarchy </script> -<script>/// <reference path="graph.ts" /> -/// <reference path="hierarchy.ts" /> +<script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ var __extends = (this && this.__extends) || function (d, b) { for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p]; function __() { this.constructor = d; } __.prototype = b.prototype; d.prototype = new __(); }; +/// <reference path="graph.ts" /> +/// <reference path="hierarchy.ts" /> /** * Package for the Render Hierarchy for TensorFlow graph. */ @@ -1507,10 +1622,22 @@ var tf; var render; (function (render) { /** + * Color parameters for op nodes. + */ + render.OpNodeColors = { + DEFAULT_FILL: "white", + DEFAULT_STROKE: "#b2b2b2" + }; + /** * Color parameters for node encoding. * @type {Object} */ render.MetanodeColors = { + /** + * Default fill and stroke to use when no other information is available. + */ + DEFAULT_FILL: "#d9d9d9", + DEFAULT_STROKE: "#a6a6a6", SATURATION: 0.6, LIGHTNESS: 0.85, /** @@ -1540,11 +1667,18 @@ var tf; GRADIENT_OUTLINE: "#888" }; /** + * Color parameters for op nodes. + */ + render.SeriesNodeColors = { + DEFAULT_FILL: "white", + DEFAULT_STROKE: "#b2b2b2" + }; + /** * Stores the rendering information, such as x and y coordinates, * for each node in the graph. */ - var RenderGraphInformation = (function () { - function RenderGraphInformation(hierarchy, params) { + var RenderGraphInfo = (function () { + function RenderGraphInfo(hierarchy, params) { this.hierarchy = hierarchy; this.index = {}; this.deviceColorMap = d3.scale.ordinal() @@ -1573,15 +1707,67 @@ var tf; this.computeTimeScale = d3.scale.linear() .domain(computeTimeExtent) .range(params.minMaxColors); - // Maps node name to whether the rendering hierarchy was already constructed. + // Maps node name to whether the rendering hierarchy was already + // constructed. this.hasSubhierarchy = {}; this.params = params; - this.root = new RenderGroupNodeInformation(hierarchy.root); + this.root = new RenderGroupNodeInfo(hierarchy.root); this.index[hierarchy.root.name] = this.root; this.buildSubhierarchy(hierarchy.root.name); this.root.expanded = true; } - RenderGraphInformation.prototype.getRenderNodeByName = function (nodeName) { + /** + * Get a previously created RenderNodeInfo by its node name. + */ + RenderGraphInfo.prototype.getRenderNodeByName = function (nodeName) { + return this.index[nodeName]; + }; + /** + * Get a previously created RenderNodeInfo for the specified node name, + * or create one if it hasn't been created yet. + */ + RenderGraphInfo.prototype.getOrCreateRenderNodeByName = function (nodeName) { + var _this = this; + // Polymer may invoke this with null. + if (!nodeName) { + return null; + } + if (nodeName in this.index) { + return this.index[nodeName]; + } + var node = this.hierarchy.node(nodeName); + var renderInfo = node.isGroupNode ? + new RenderGroupNodeInfo(node) : + new RenderNodeInfo(node); + this.index[nodeName] = renderInfo; + if (node.stats) { + renderInfo.memoryColor = this.memoryUsageScale(node.stats.totalBytes); + renderInfo.computeTimeColor = + this.computeTimeScale(node.stats.totalMicros); + } + if (node.isGroupNode) { + // Make a list of tuples (device, proportion), where proportion + // is the fraction of op nodes that have that device. + var pairs = _.pairs(node.deviceHistogram); + if (pairs.length > 0) { + // Compute the total # of devices. + var numDevices = _.sum(pairs, _.last); + renderInfo.deviceColors = _.map(pairs, function (pair) { return ({ + color: _this.deviceColorMap(pair[0]), + // Normalize to a proportion of total # of devices. + proportion: pair[1] / numDevices + }); }); + } + } + else { + var device = renderInfo.node.device; + if (device) { + renderInfo.deviceColors = [{ + color: this.deviceColorMap(device), + proportion: 1.0 + }]; + } + } return this.index[nodeName]; }; /** @@ -1590,7 +1776,7 @@ var tf; * (highlight) a node that isn't drawn yet, by selecting (highlighting) * its nearest ancestor that has been drawn. */ - RenderGraphInformation.prototype.getNearestVisibleAncestor = function (name) { + RenderGraphInfo.prototype.getNearestVisibleAncestor = function (name) { var path = graph_1.getHierarchicalPath(name); for (var i = 0; i < path.length; i++) { var nodeName = path[i]; @@ -1603,10 +1789,26 @@ var tf; return name; }; // TODO(jimbo): Delete this an any code it touches (all deprecated). - RenderGraphInformation.prototype.setDepth = function (depth) { + RenderGraphInfo.prototype.setDepth = function (depth) { setGroupNodeDepth(this.root, +depth); }; - RenderGraphInformation.prototype.buildSubhierarchy = function (nodeName) { + /** + * Returns true if the renderNode is an isolated node within its parent node. + */ + RenderGraphInfo.prototype.isNodeAuxilliary = function (renderNode) { + var parentNode = this.getRenderNodeByName(renderNode.node.parentNode.name); + var found = _.find(parentNode.isolatedInExtract, function (node) { + return node.node.name === renderNode.node.name; + }); + if (found) { + return true; + } + found = _.find(parentNode.isolatedOutExtract, function (node) { + return node.node.name === renderNode.node.name; + }); + return !!found; + }; + RenderGraphInfo.prototype.buildSubhierarchy = function (nodeName) { var _this = this; // Terminate if the rendering hierarchy was already constructed // for this node. @@ -1628,58 +1830,26 @@ var tf; // extracted. Also, due to extraction, the coreGraph may contain disjoint // groups between which there is no visible path (other than annotations). _.each(metagraph.nodes(), function (childName) { - var childNode = metagraph.node(childName); - var childRenderInfo = childNode.isGroupNode ? - new RenderGroupNodeInformation(childNode) : - new RenderNodeInformation(childNode); - _this.index[childName] = childRenderInfo; + var childRenderInfo = _this.getOrCreateRenderNodeByName(childName); + var childNode = childRenderInfo.node; coreGraph.setNode(childName, childRenderInfo); - if (childRenderInfo.node.stats != null) { - childRenderInfo.memoryColor = - _this.memoryUsageScale(childRenderInfo.node.stats.totalBytes); - childRenderInfo.computeTimeColor = - _this.computeTimeScale(childRenderInfo.node.stats.totalMicros); - } if (!childNode.isGroupNode) { _.each(childNode.inEmbeddings, function (embedding) { - var renderMetaedgeInfo = new RenderMetaedgeInformation(null); + var renderMetaedgeInfo = new RenderMetaedgeInfo(null); addInAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo, AnnotationType.CONSTANT, _this.params); - _this.index[embedding.name] = new RenderNodeInformation(embedding); + _this.index[embedding.name] = new RenderNodeInfo(embedding); }); _.each(childNode.outEmbeddings, function (embedding) { - var renderMetaedgeInfo = new RenderMetaedgeInformation(null); + var renderMetaedgeInfo = new RenderMetaedgeInfo(null); addOutAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo, AnnotationType.SUMMARY, _this.params); - _this.index[embedding.name] = new RenderNodeInformation(embedding); + _this.index[embedding.name] = new RenderNodeInfo(embedding); }); - var device = childRenderInfo.node.device; - if (device != null) { - childRenderInfo.deviceColors = [{ - color: _this.deviceColorMap(device), - proportion: 1.0 - }]; - } - } - else { - // Make a list of tuples (device, proportion), where proportion - // is the fraction of op nodes that have that device. - var pairs = _.pairs(childNode.deviceHistogram); - if (pairs.length > 0) { - // Compute the total # of devices. - var numDevices = _.sum(pairs, _.last); - childRenderInfo.deviceColors = _.map(pairs, function (pair) { - return { - color: _this.deviceColorMap(pair[0]), - // Normalize to a proportion of total # of devices. - proportion: pair[1] / numDevices - }; - }); - } } }); // Add render metaedge info for edges in the metagraph. _.each(metagraph.edges(), function (edgeObj) { var metaedge = metagraph.edge(edgeObj); - var renderMetaedgeInfo = new RenderMetaedgeInformation(metaedge); + var renderMetaedgeInfo = new RenderMetaedgeInfo(metaedge); coreGraph.setEdge(edgeObj.v, edgeObj.w, renderMetaedgeInfo); }); if (this.params.enableExtraction && @@ -1755,7 +1925,7 @@ var tf; otherCounts.control[otherName] > _this.params.maxControlDegree; var _b = inbound ? [renderNodeInfo.inAnnotations, childRenderInfo.inAnnotations] : - [renderNodeInfo.outAnnotations, childRenderInfo.outAnnotations], annotations = _b[0], childAnnotations = _b[1]; + [renderNodeInfo.outAnnotations, childRenderInfo.outAnnotations], childAnnotations = _b[1]; var isOtherHighDegree = inbound ? otherCounts.out[otherName] > _this.params.maxOutDegree : otherCounts.in[otherName] > _this.params.maxInDegree; @@ -1843,7 +2013,7 @@ var tf; // If we can't make a bridge path for any reason, then we add an // annotation instead. if (!canDrawBridgePath) { - childAnnotations.push(new Annotation(otherNode, otherRenderInfo, new RenderMetaedgeInformation(bridgeMetaedge), AnnotationType.SHORTCUT, inbound), _this.params); + childAnnotations.push(new Annotation(otherNode, otherRenderInfo, new RenderMetaedgeInfo(bridgeMetaedge), AnnotationType.SHORTCUT, inbound), _this.params); return; } // At this point, all conditions have been met for drawing a bridge path. @@ -1864,11 +2034,12 @@ var tf; cardinality: 0, parentNode: null, stats: null, + include: graph_1.InclusionType.UNSPECIFIED, // BridgeNode properties. inbound: inbound, }; bridgeContainerInfo = - new RenderNodeInformation(bridgeContainerNode); + new RenderNodeInfo(bridgeContainerNode); _this.index[bridgeContainerName] = bridgeContainerInfo; coreGraph.setNode(bridgeContainerName, bridgeContainerInfo); } @@ -1881,10 +2052,11 @@ var tf; cardinality: 1, parentNode: null, stats: null, + include: graph_1.InclusionType.UNSPECIFIED, // BridgeNode properties. inbound: inbound, }; - bridgeNodeRenderInfo = new RenderNodeInformation(bridgeNode); + bridgeNodeRenderInfo = new RenderNodeInfo(bridgeNode); _this.index[bridgeNodeName] = bridgeNodeRenderInfo; coreGraph.setNode(bridgeNodeName, bridgeNodeRenderInfo); // Set bridgeNode to be a graphlib child of the container node. @@ -1892,7 +2064,7 @@ var tf; bridgeContainerInfo.node.cardinality++; } // Create and add a bridge render metaedge. - var bridgeRenderMetaedge = new RenderMetaedgeInformation(bridgeMetaedge); + var bridgeRenderMetaedge = new RenderMetaedgeInfo(bridgeMetaedge); bridgeRenderMetaedge.adjoiningMetaedge = adjoiningMetaedge; inbound ? coreGraph.setEdge(bridgeNodeName, childName, bridgeRenderMetaedge) : @@ -1993,10 +2165,11 @@ var tf; cardinality: 1, parentNode: null, stats: null, + include: graph_1.InclusionType.UNSPECIFIED, // BridgeNode properties. inbound: inbound, }; - structuralRenderInfo = new RenderNodeInformation(bridgeNode); + structuralRenderInfo = new RenderNodeInfo(bridgeNode); structuralRenderInfo.structural = true; _this.index[structuralNodeName] = structuralRenderInfo; coreGraph.setNode(structuralNodeName, structuralRenderInfo); @@ -2004,7 +2177,7 @@ var tf; coreGraph.setParent(structuralNodeName, bridgeContainerName); } // Create the structural Metaedge and insert it. - var structuralMetaedgeInfo = new RenderMetaedgeInformation(null); + var structuralMetaedgeInfo = new RenderMetaedgeInfo(null); structuralMetaedgeInfo.structural = true; structuralMetaedgeInfo.weight--; // Reduce weight for dagre layout. inbound ? @@ -2013,9 +2186,9 @@ var tf; }); }); }; - return RenderGraphInformation; + return RenderGraphInfo; })(); - render.RenderGraphInformation = RenderGraphInformation; + render.RenderGraphInfo = RenderGraphInfo; /** * A class for rendering annotation object which contains label * about the node embedded as annotation, type of annotation and the location @@ -2066,7 +2239,7 @@ var tf; ; /** * Manages a list of annotations. Two will be used for each - * RenderNodeInformation, one for in annotations and one for out annotations. + * RenderNodeInfo, one for in annotations and one for out annotations. */ var AnnotationList = (function () { function AnnotationList() { @@ -2093,7 +2266,7 @@ var tf; return; } var ellipsisNode = new tf.graph.EllipsisNodeImpl(1); - this.list.push(new Annotation(ellipsisNode, new RenderNodeInformation(ellipsisNode), null, AnnotationType.ELLIPSIS, annotation.isIn)); + this.list.push(new Annotation(ellipsisNode, new RenderNodeInfo(ellipsisNode), null, AnnotationType.ELLIPSIS, annotation.isIn)); }; return AnnotationList; })(); @@ -2101,8 +2274,8 @@ var tf; /** * Contains rendering information about a node in the hierarchical graph. */ - var RenderNodeInformation = (function () { - function RenderNodeInformation(node) { + var RenderNodeInfo = (function () { + function RenderNodeInfo(node) { this.node = node; this.expanded = false; this.inAnnotations = new AnnotationList(); @@ -2127,31 +2300,30 @@ var tf; this.paddingLeft = 0; this.paddingRight = 0; this.paddingBottom = 0; - this.outerWidth = 0; - this.outerHeight = 0; this.isInExtract = false; this.isOutExtract = false; + this.coreBox = { width: 0, height: 0 }; } - RenderNodeInformation.prototype.isInCore = function () { + RenderNodeInfo.prototype.isInCore = function () { return !this.isInExtract && !this.isOutExtract; }; - return RenderNodeInformation; + return RenderNodeInfo; })(); - render.RenderNodeInformation = RenderNodeInformation; + render.RenderNodeInfo = RenderNodeInfo; /** * Contains rendering information about a Metaedge from the underlying * hierarchical graph. It may be from either a metagraph or a bridgegraph. */ - var RenderMetaedgeInformation = (function () { - function RenderMetaedgeInformation(metaedge) { + var RenderMetaedgeInfo = (function () { + function RenderMetaedgeInfo(metaedge) { this.metaedge = metaedge; this.adjoiningMetaedge = null; this.structural = false; this.weight = 1; } - return RenderMetaedgeInformation; + return RenderMetaedgeInfo; })(); - render.RenderMetaedgeInformation = RenderMetaedgeInformation; + render.RenderMetaedgeInfo = RenderMetaedgeInfo; function addInAnnotation(node, predecessor, predecessorRenderInfo, edge, type, params) { var annotation = new Annotation(predecessor, predecessorRenderInfo, edge, type, true); node.inAnnotations.push(annotation, params); @@ -2175,23 +2347,22 @@ var tf; }); } ; - var RenderGroupNodeInformation = (function (_super) { - __extends(RenderGroupNodeInformation, _super); - function RenderGroupNodeInformation(groupNode) { + var RenderGroupNodeInfo = (function (_super) { + __extends(RenderGroupNodeInfo, _super); + function RenderGroupNodeInfo(groupNode) { _super.call(this, groupNode); var metagraph = groupNode.metagraph; var gl = metagraph.graph(); this.coreGraph = graph_1.createGraph(gl.name, graph_1.GraphType.CORE, { compound: true }); - this.coreBox = { width: 0, height: 0 }; this.inExtractBox = { width: 0, height: 0 }; this.outExtractBox = { width: 0, height: 0 }; this.isolatedInExtract = []; this.isolatedOutExtract = []; } - return RenderGroupNodeInformation; - })(RenderNodeInformation); - render.RenderGroupNodeInformation = RenderGroupNodeInformation; + return RenderGroupNodeInfo; + })(RenderNodeInfo); + render.RenderGroupNodeInfo = RenderGroupNodeInfo; function setGroupNodeDepth(renderInfo, depth) { if (renderInfo.coreGraph) { setGraphDepth(renderInfo.coreGraph, depth); @@ -2208,6 +2379,15 @@ var tf; var src = graph.node(v); var sink = graph.node(w); var edge = graph.edge(v, w); + // If either of the nodes is explicitly included in the main graph and + // both nodes are in the main graph then do not create the shortcut + // and instead keep the real edge. + if ((src.node.include === graph_1.InclusionType.INCLUDE || + sink.node.include === graph_1.InclusionType.INCLUDE) && + src.node.include !== graph_1.InclusionType.EXCLUDE && + sink.node.include !== graph_1.InclusionType.EXCLUDE) { + return; + } // Add each annotation. addOutAnnotation(src, sink.node, sink, edge, AnnotationType.SHORTCUT, params); addInAnnotation(sink, src.node, src, edge, AnnotationType.SHORTCUT, params); @@ -2218,48 +2398,55 @@ var tf; * Remove edges from a node, and set its isOutExtract property to true, * and remove the node and move it to isolatedOutExtract. * - * If detachAllEdgesForHighDegree is true, extract all of its edges. - * Otherwise, only extract all in-edges. + * If detachAllEdgesForHighDegree or forceDetach is true, extract all of its + * edges. Otherwise, only extract all in-edges. */ - function makeOutExtract(renderNode, n, params) { + function makeOutExtract(renderNode, n, params, forceDetach) { var graph = renderNode.coreGraph; - graph.node(n).isOutExtract = true; + var child = graph.node(n); + child.isOutExtract = true; _.each(graph.predecessors(n), function (p, index) { createShortcut(graph, p, n, params); }); - if (params.detachAllEdgesForHighDegree) { + if (params.detachAllEdgesForHighDegree || forceDetach) { _.each(graph.successors(n), function (s, index) { createShortcut(graph, n, s, params); }); } - if (params.detachAllEdgesForHighDegree || graph.neighbors(n).length === 0) { - renderNode.isolatedOutExtract.push(graph.node(n)); + // Remove the node from the core graph if it no longer has neighbors. + if (graph.neighbors(n).length === 0) { + child.node.include = graph_1.InclusionType.EXCLUDE; + renderNode.isolatedOutExtract.push(child); graph.removeNode(n); } } /** * Remove edges from a node, set its isInExtract property to true, * and remove the node and move it to isolatedInExtract. - * If detachAllEdgesForHighDegree is true, extract all of its edges. - * Otherwise, only remove all out-edges. + * + * If detachAllEdgesForHighDegree or forceDetach is true, extract all of its + * edges. Otherwise, only remove all out-edges. */ - function makeInExtract(renderNode, n, params) { + function makeInExtract(renderNode, n, params, forceDetach) { var graph = renderNode.coreGraph; - graph.node(n).isInExtract = true; + var child = graph.node(n); + child.isInExtract = true; _.each(graph.successors(n), function (s, index) { createShortcut(graph, n, s, params); }); - if (params.detachAllEdgesForHighDegree) { + if (params.detachAllEdgesForHighDegree || forceDetach) { _.each(graph.predecessors(n), function (p, index) { createShortcut(graph, p, n, params); }); } - // Remove the node from the core graph if conditions are met. - if (params.detachAllEdgesForHighDegree || graph.neighbors(n).length === 0) { - renderNode.isolatedInExtract.push(graph.node(n)); + // Remove the node from the core graph if it no longer has neighbors. + if (graph.neighbors(n).length === 0) { + child.node.include = graph_1.InclusionType.EXCLUDE; + renderNode.isolatedInExtract.push(child); graph.removeNode(n); } } + render.makeInExtract = makeInExtract; /** * Check whether the node's type is a member of the given list of types. * @@ -2286,11 +2473,30 @@ var tf; } return false; } + /** Move nodes that are speficied to be excluded out of the core graph. */ + function extractSpecifiedNodes(renderNode, params) { + var graph = renderNode.coreGraph; + _.each(graph.nodes(), function (n) { + var renderInfo = graph.node(n); + if (renderInfo.node.include === graph_1.InclusionType.EXCLUDE) { + if (renderNode.coreGraph.outEdges(n).length > + renderNode.coreGraph.inEdges(n).length) { + makeOutExtract(renderNode, n, params, true); + } + else { + makeInExtract(renderNode, n, params, true); + } + } + }); + } /** Remove edges from pre-defined out-extract patterns */ function extractPredefinedSink(renderNode, params) { var graph = renderNode.coreGraph; _.each(graph.nodes(), function (n) { var renderInfo = graph.node(n); + if (renderInfo.node.include !== graph_1.InclusionType.UNSPECIFIED) { + return; + } if (hasTypeIn(renderInfo.node, params.outExtractTypes)) { makeOutExtract(renderNode, n, params); } @@ -2301,6 +2507,9 @@ var tf; var graph = renderNode.coreGraph; _.each(graph.nodes(), function (n) { var renderInfo = graph.node(n); + if (renderInfo.node.include !== graph_1.InclusionType.UNSPECIFIED) { + return; + } if (hasTypeIn(renderInfo.node, params.inExtractTypes)) { makeInExtract(renderNode, n, params); } @@ -2312,6 +2521,9 @@ var tf; var maxInDegree = params.maxInDegree; // detect first so degrees don't get affected by other removal var highInDegreeNames = _.filter(graph.nodes(), function (n) { + if (graph.node(n).node.include !== graph_1.InclusionType.UNSPECIFIED) { + return false; + } // Count the in-degree based on only regular edges, unless there are // no regular edges, in which case use the number of control edges. // This is done so that control edges don't effect if nodes are extracted @@ -2335,6 +2547,9 @@ var tf; var maxOutDegree = params.maxOutDegree; // detect first so degrees don't get affected by other removal var highOutDegreeNames = _.filter(graph.nodes(), function (n) { + if (graph.node(n).node.include !== graph_1.InclusionType.UNSPECIFIED) { + return false; + } // Count the out-degree based on only regular edges, unless there are // no regular edges, in which case use the number of control edges. // This is done so that control edges don't effect if nodes are extracted @@ -2400,6 +2615,7 @@ var tf; * <tf-graph-params>'s output */ function extractHighDegrees(renderNode, params) { + extractSpecifiedNodes(renderNode, params); if (params.outExtractTypes) { extractPredefinedSink(renderNode, params); } @@ -2434,6 +2650,9 @@ var tf; _.each(graph.nodes(), function (n) { var child = graph.node(n); var degree = graph.neighbors(n).length; + if (child.node.include !== graph_1.InclusionType.UNSPECIFIED) { + return; + } if (degree === 0) { var hasOutAnnotations = child.outAnnotations.list.length > 0; var hasInAnnotations = child.inAnnotations.list.length > 0; @@ -2441,23 +2660,27 @@ var tf; // This case only happens if detachAllEdgesForHighDegree is false. // (Otherwise all source-like nodes are all isolated already.) renderNode.isolatedInExtract.push(child); + child.node.include = graph_1.InclusionType.EXCLUDE; graph.removeNode(n); } else if (child.isOutExtract) { // This case only happens if detachAllEdgesForHighDegree is false. // // (Otherwise all sink-like nodes are all isolated already.) renderNode.isolatedOutExtract.push(child); + child.node.include = graph_1.InclusionType.EXCLUDE; graph.removeNode(n); } else if (params.extractIsolatedNodesWithAnnotationsOnOneSide) { if (hasOutAnnotations && !hasInAnnotations) { child.isInExtract = true; // for ones with high out-annotations renderNode.isolatedInExtract.push(child); + child.node.include = graph_1.InclusionType.EXCLUDE; graph.removeNode(n); } else if (hasInAnnotations && !hasOutAnnotations) { child.isOutExtract = true; // for ones with high in-annotations renderNode.isolatedOutExtract.push(child); + child.node.include = graph_1.InclusionType.EXCLUDE; graph.removeNode(n); } else { @@ -2470,7 +2693,21 @@ var tf; })(graph = tf.graph || (tf.graph = {})); })(tf || (tf = {})); // close module tf.graph.render </script> -<script>/// <reference path="graph.ts" /> +<script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// <reference path="graph.ts" /> /// <reference path="hierarchy.ts" /> var tf; (function (tf) { @@ -2568,6 +2805,7 @@ var tf; function groupTemplateAndAssignId(nnGroups, verifyTemplate) { // For each metanode, compare its subgraph (starting from shallower groups) // and assign template id. + var result = {}; return _.reduce(nnGroups, function (templates, nnGroupPair) { var signature = nnGroupPair[0], nnGroup = nnGroupPair[1].nodes, clusters = []; nnGroup.forEach(function (metanode) { @@ -2597,7 +2835,7 @@ var tf; }; }); return templates; - }, {}); + }, result); } function sortNodes(names, graph, prefix) { return _.sortByAll(names, function (name) { @@ -2697,7 +2935,8 @@ var tf; // compare metanode var metanode1 = n1; var metanode2 = n2; - return metanode1.templateId && metanode2.templateId && metanode1.templateId === metanode2.templateId; + return metanode1.templateId && metanode2.templateId && + metanode1.templateId === metanode2.templateId; } else if (n1.type === graph_1.NodeType.OP && n2.type === graph_1.NodeType.OP) { // compare leaf node @@ -2706,13 +2945,13 @@ var tf; else if (n1.type === graph_1.NodeType.SERIES && n2.type === graph_1.NodeType.SERIES) { // compare series node sizes and operations // (only need to check one op as all op nodes are identical in series) - var seriesnode1 = n1; - var seriesnode2 = n2; - var seriesnode1Count = seriesnode1.metagraph.nodeCount(); - return (seriesnode1Count === seriesnode2.metagraph.nodeCount() && + var sn1 = n1; + var sn2 = n2; + var seriesnode1Count = sn1.metagraph.nodeCount(); + return (seriesnode1Count === sn2.metagraph.nodeCount() && (seriesnode1Count === 0 || - (seriesnode1.metagraph.node(seriesnode1.metagraph.nodes()[0]).op === - seriesnode2.metagraph.node(seriesnode2.metagraph.nodes()[0]).op))); + (sn1.metagraph.node(sn1.metagraph.nodes()[0]).op === + sn2.metagraph.node(sn2.metagraph.nodes()[0]).op))); } return false; } @@ -2720,7 +2959,21 @@ var tf; })(graph = tf.graph || (tf.graph = {})); })(tf || (tf = {})); </script> -<script>/// <reference path="../graph.ts" /> +<script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// <reference path="../graph.ts" /> /// <reference path="edge.ts" /> /// <reference path="node.ts" /> /// <reference path="../layout.ts" /> @@ -2824,8 +3077,8 @@ var tf; * provided node. */ function panToNode(nodeName, svg, zoomG, d3zoom) { - var node = d3.selectAll("[data-name='" + nodeName + "']." - + scene.Class.Node.GROUP)[0][0]; + var node = d3.select("[data-name='" + nodeName + "']." + + scene.Class.Node.GROUP).node(); if (!node) { return false; } @@ -2993,8 +3246,7 @@ var tf; position(sceneGroup, renderNode); // Fade in the scene group if it didn't already exist. if (isNewSceneGroup) { - sceneGroup.attr("opacity", 0) - .transition().attr("opacity", 1); + sceneGroup.attr("opacity", 0).transition().attr("opacity", 1); } return sceneGroup; } @@ -3017,17 +3269,17 @@ var tf; // core translate(selectChild(sceneGroup, "g", scene.Class.Scene.CORE), 0, yTranslate); // in-extract - var inExtractX = renderNode.coreBox.width === 0 ? - 0 : renderNode.coreBox.width; var hasInExtract = renderNode.isolatedInExtract.length > 0; if (hasInExtract) { + var inExtractX = renderNode.coreBox.width - + renderNode.inExtractBox.width / 2 - renderNode.outExtractBox.width; translate(selectChild(sceneGroup, "g", scene.Class.Scene.INEXTRACT), inExtractX, yTranslate); } // out-extract var hasOutExtract = renderNode.isolatedOutExtract.length > 0; if (hasOutExtract) { - var outExtractX = inExtractX + renderNode.inExtractBox.width - + renderNode.extractXOffset; + var outExtractX = renderNode.coreBox.width - + renderNode.outExtractBox.width / 2; translate(selectChild(sceneGroup, "g", scene.Class.Scene.OUTEXTRACT), outExtractX, yTranslate); } } @@ -3042,6 +3294,10 @@ var tf; ; /** Helper for adding transform: translate(x0, y0) */ function translate(selection, x0, y0) { + // If it is already placed on the screen, make it a transition. + if (selection.attr("transform") != null) { + selection = selection.transition("position"); + } selection.attr("transform", "translate(" + x0 + "," + y0 + ")"); } scene.translate = translate; @@ -3071,10 +3327,15 @@ var tf; * the button on. */ function positionButton(button, renderNode) { + var cx = graph.layout.computeCXPositionOfNodeShape(renderNode); // Position the button in the top-right corner of the group node, // with space given the draw the button inside of the corner. - var x = renderNode.x + renderNode.width / 2 - 6; - var y = renderNode.y - renderNode.height / 2 + 6; + var width = renderNode.expanded ? + renderNode.width : renderNode.coreBox.width; + var height = renderNode.expanded ? + renderNode.height : renderNode.coreBox.height; + var x = cx + width / 2 - 6; + var y = renderNode.y - height / 2 + 6; // For unexpanded series nodes, the button has special placement due // to the unique visuals of this group node. if (renderNode.node.type === graph.NodeType.SERIES && !renderNode.expanded) { @@ -3113,10 +3374,25 @@ var tf; })(graph = tf.graph || (tf.graph = {})); })(tf || (tf = {})); // close module </script> -<script>/// <reference path="../graph.ts" /> +<script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// <reference path="../graph.ts" /> /// <reference path="../render.ts" /> /// <reference path="scene.ts" /> /// <reference path="edge.ts" /> +/// <reference path="contextmenu.ts" /> var tf; (function (tf) { var graph; @@ -3124,7 +3400,7 @@ var tf; var scene; (function (scene) { var annotation; - (function (annotation) { + (function (annotation_1) { /** * Populate a given annotation container group * @@ -3191,7 +3467,7 @@ var tf; var aGroup = d3.select(this); update(aGroup, d, a, sceneBehavior); if (a.annotationType !== tf.graph.render.AnnotationType.ELLIPSIS) { - addInteraction(aGroup, d, sceneBehavior); + addInteraction(aGroup, d, a, sceneBehavior); } }); annotationGroups.exit() @@ -3203,7 +3479,7 @@ var tf; .remove(); return annotationGroups; } - annotation.buildGroup = buildGroup; + annotation_1.buildGroup = buildGroup; ; /** * Maps an annotation enum to a class name used in css rules. @@ -3214,11 +3490,10 @@ var tf; } function buildShape(aGroup, a, sceneBehavior) { if (a.annotationType === tf.graph.render.AnnotationType.SUMMARY) { - var image = scene.selectOrCreateChild(aGroup, "image"); - image.attr({ - "xlink:href": sceneBehavior.resolveUrl("../../lib/svg/summary-icon.svg"), - "height": "12px", - "width": "12px", + var summary = scene.selectOrCreateChild(aGroup, "use"); + summary.attr({ + "class": "summary", + "xlink:href": "#summary-icon", "cursor": "pointer" }); } @@ -3247,7 +3522,7 @@ var tf; .text(label) .append("title").text(titleText); } - function addInteraction(selection, d, sceneBehavior) { + function addInteraction(selection, d, annotation, sceneBehavior) { selection .on("mouseover", function (a) { sceneBehavior.fire("annotation-highlight", { @@ -3270,6 +3545,10 @@ var tf; hostName: d.node.name }); }); + if (annotation.annotationType !== tf.graph.render.AnnotationType.SUMMARY && + annotation.annotationType !== tf.graph.render.AnnotationType.CONSTANT) { + selection.on("contextmenu", tf.graph.scene.contextmenu.getMenu(tf.graph.scene.node.getContextMenu(annotation.node, sceneBehavior))); + } } ; /** @@ -3281,6 +3560,7 @@ var tf; * @param scene Polymer scene element. */ function update(aGroup, d, a, sceneBehavior) { + var cx = graph.layout.computeCXPositionOfNodeShape(d); // Annotations that point to embedded nodes (constants,summary) // don't have a render information attached so we don't stylize these. // Also we don't stylize ellipsis annotations (the string "... and X more"). @@ -3294,7 +3574,7 @@ var tf; } // label position aGroup.select("text." + scene.Class.Annotation.LABEL).transition().attr({ - x: d.x + a.dx + (a.isIn ? -1 : 1) * (a.width / 2 + a.labelOffset), + x: cx + a.dx + (a.isIn ? -1 : 1) * (a.width / 2 + a.labelOffset), y: d.y + a.dy }); // Some annotations (such as summary) are represented using a 12x12 image tag. @@ -3302,19 +3582,19 @@ var tf; // If there is an image, we adjust the location of the image to be vertically // centered with the node and horizontally centered between the arrow and the // text label. - aGroup.select("image").transition().attr({ - x: d.x + a.dx - 3, + aGroup.select("use.summary").transition().attr({ + x: cx + a.dx - 3, y: d.y + a.dy - 6 }); // Node position (only one of the shape selection will be non-empty.) - scene.positionEllipse(aGroup.select("." + scene.Class.Annotation.NODE + " ellipse"), d.x + a.dx, d.y + a.dy, a.width, a.height); - scene.positionRect(aGroup.select("." + scene.Class.Annotation.NODE + " rect"), d.x + a.dx, d.y + a.dy, a.width, a.height); - scene.positionRect(aGroup.select("." + scene.Class.Annotation.NODE + " use"), d.x + a.dx, d.y + a.dy, a.width, a.height); + scene.positionEllipse(aGroup.select("." + scene.Class.Annotation.NODE + " ellipse"), cx + a.dx, d.y + a.dy, a.width, a.height); + scene.positionRect(aGroup.select("." + scene.Class.Annotation.NODE + " rect"), cx + a.dx, d.y + a.dy, a.width, a.height); + scene.positionRect(aGroup.select("." + scene.Class.Annotation.NODE + " use"), cx + a.dx, d.y + a.dy, a.width, a.height); // Edge position aGroup.select("path." + scene.Class.Annotation.EDGE).transition().attr("d", function (a) { // map relative position to absolute position var points = a.points.map(function (p) { - return { x: p.dx + d.x, y: p.dy + d.y }; + return { x: p.dx + cx, y: p.dy + d.y }; }); return scene.edge.interpolate(points); }); @@ -3325,7 +3605,21 @@ var tf; })(graph = tf.graph || (tf.graph = {})); })(tf || (tf = {})); // close module </script> -<script>/// <reference path="../graph.ts" /> +<script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// <reference path="../graph.ts" /> /// <reference path="../render.ts" /> /// <reference path="scene.ts" /> var tf; @@ -3336,7 +3630,6 @@ var tf; (function (scene) { var edge; (function (edge) { - var Scene = tf.graph.scene; // Aliased function getEdgeKey(edgeObj) { return edgeObj.v + tf.graph.EDGE_KEY_DELIM + edgeObj.w; } @@ -3361,6 +3654,7 @@ var tf; * @return selection of the created nodeGroups */ function buildGroup(sceneGroup, graph, sceneBehavior) { + var edges = []; var edgeData = _.reduce(graph.edges(), function (edges, edgeObj) { var edgeLabel = graph.edge(edgeObj); edges.push({ @@ -3369,9 +3663,8 @@ var tf; label: edgeLabel }); return edges; - }, []); + }, edges); var container = scene.selectOrCreateChild(sceneGroup, "g", scene.Class.Edge.CONTAINER); - var containerNode = container.node(); // Select all children and join with data. // (Note that all children of g.edges are g.edge) var edgeGroups = container.selectAll(function () { @@ -3416,7 +3709,7 @@ var tf; * For a given d3 selection and data object, create a path to represent the * edge described in d.label. * - * If d.label is defined, it will be a RenderMetaedgeInformation instance. It + * If d.label is defined, it will be a RenderMetaedgeInfo instance. It * will sometimes be undefined, for example for some Annotation edges for which * there is no underlying Metaedge in the hierarchical graph. */ @@ -3430,6 +3723,10 @@ var tf; } edge.appendEdge = appendEdge; ; + edge.interpolate = d3.svg.line() + .interpolate("basis") + .x(function (d) { return d.x; }) + .y(function (d) { return d.y; }); /** * Returns a tween interpolator for the endpoint of an edge path. */ @@ -3462,10 +3759,6 @@ var tf; return dPath; }; } - edge.interpolate = d3.svg.line() - .interpolate("basis") - .x(function (d) { return d.x; }) - .y(function (d) { return d.y; }); function position(d) { d3.select(this).select("path." + scene.Class.Edge.LINE) .each(function (d) { @@ -3478,10 +3771,9 @@ var tf; * For a given d3 selection and data object, mark the edge as a control * dependency if it contains only control edges. * - * d's label property will be a RenderMetaedgeInformation object. + * d's label property will be a RenderMetaedgeInfo object. */ function stylize(edgeGroup, d, stylize) { - var a; var metaedge = d.label.metaedge; edgeGroup .select("path." + scene.Class.Edge.LINE) @@ -3493,9 +3785,24 @@ var tf; })(graph = tf.graph || (tf.graph = {})); })(tf || (tf = {})); // close module </script> -<script>/// <reference path="../graph.ts" /> +<script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// <reference path="../graph.ts" /> /// <reference path="scene.ts" /> /// <reference path="annotation.ts" /> +/// <reference path="contextmenu.ts" /> var tf; (function (tf) { var graph; @@ -3693,6 +4000,7 @@ var tf; selection.attr("pointer-events", "none"); return; } + var contextMenuFunction = tf.graph.scene.contextmenu.getMenu(getContextMenu(d.node, sceneBehavior)); selection.on("dblclick", function (d) { sceneBehavior.fire("node-toggle-expand", { name: d.node.name }); }) @@ -3717,10 +4025,28 @@ var tf; // a graph-select. d3.event.stopPropagation(); sceneBehavior.fire("node-select", { name: d.node.name }); + }) + .on("contextmenu", function (d, i) { + sceneBehavior.fire("node-select", { name: d.node.name }); + contextMenuFunction.call(d, i); }); } ; /** + * Returns the d3 context menu specification for the provided node. + */ + function getContextMenu(node, sceneBehavior) { + return [{ + title: function (d) { + return tf.graph.getIncludeNodeButtonString(node.include); + }, + action: function (elm, d, i) { + sceneBehavior.fire("node-toggle-extract", { name: node.name }); + } + }]; + } + node_1.getContextMenu = getContextMenu; + /** * Append svg text for label and assign data. * @param nodeGroup * @param renderNodeInfo The render node information for the label. @@ -3764,10 +4090,10 @@ var tf; /** * Set label position of a given node group */ - function labelPosition(nodeGroup, d, yOffset) { + function labelPosition(nodeGroup, cx, cy, yOffset) { scene.selectChild(nodeGroup, "text", scene.Class.Node.LABEL).transition() - .attr("x", d.x) - .attr("y", d.y + yOffset); + .attr("x", cx) + .attr("y", cy + yOffset); } ; /** @@ -3775,7 +4101,7 @@ var tf; * as the shape's data. * * @param nodeGroup - * @param d RenderNodeInformation + * @param d Render node information. * @param nodeClass class for the element. * @param before Reference DOM node for insertion. * @return Selection of the shape. @@ -3837,38 +4163,41 @@ var tf; /** Modify node and its subscene and its label's positional attributes */ function position(nodeGroup, d, sceneBehavior) { var shapeGroup = scene.selectChild(nodeGroup, "g", scene.Class.Node.SHAPE); + var cx = graph.layout.computeCXPositionOfNodeShape(d); switch (d.node.type) { case graph.NodeType.OP: { // position shape var shape = scene.selectChild(shapeGroup, "ellipse"); - scene.positionEllipse(shape, d.x, d.y, d.width, d.height); - labelPosition(nodeGroup, d, d.labelOffset); + scene.positionEllipse(shape, cx, d.y, d.coreBox.width, d.coreBox.height); + labelPosition(nodeGroup, cx, d.y, d.labelOffset); break; } case graph.NodeType.META: { // position shape var shape = scene.selectChild(shapeGroup, "rect"); - scene.positionRect(shape, d.x, d.y, d.width, d.height); if (d.expanded) { + scene.positionRect(shape, d.x, d.y, d.width, d.height); subscenePosition(nodeGroup, d); // put label on top - labelPosition(nodeGroup, d, -d.height / 2 + d.labelHeight / 2); + labelPosition(nodeGroup, cx, d.y, -d.height / 2 + d.labelHeight / 2); } else { - labelPosition(nodeGroup, d, 0); + scene.positionRect(shape, cx, d.y, d.coreBox.width, d.coreBox.height); + labelPosition(nodeGroup, cx, d.y, 0); } break; } case graph.NodeType.SERIES: { var shape = scene.selectChild(shapeGroup, "use"); - scene.positionRect(shape, d.x, d.y, d.width, d.height); if (d.expanded) { + scene.positionRect(shape, d.x, d.y, d.width, d.height); subscenePosition(nodeGroup, d); // put label on top - labelPosition(nodeGroup, d, -d.height / 2 + d.labelHeight / 2); + labelPosition(nodeGroup, cx, d.y, -d.height / 2 + d.labelHeight / 2); } else { - labelPosition(nodeGroup, d, d.labelOffset); + scene.positionRect(shape, cx, d.y, d.coreBox.width, d.coreBox.height); + labelPosition(nodeGroup, cx, d.y, d.labelOffset); } } case graph.NodeType.BRIDGE: { @@ -3886,29 +4215,33 @@ var tf; } ; /** Enum specifying the options to color nodes by */ - var ColorBy = { - STRUCTURE: 0, - DEVICE: 1, - COMPUTE_TIME: 2, - MEMORY: 3 - }; + (function (ColorBy) { + ColorBy[ColorBy["STRUCTURE"] = 0] = "STRUCTURE"; + ColorBy[ColorBy["DEVICE"] = 1] = "DEVICE"; + ColorBy[ColorBy["COMPUTE_TIME"] = 2] = "COMPUTE_TIME"; + ColorBy[ColorBy["MEMORY"] = 3] = "MEMORY"; + })(node_1.ColorBy || (node_1.ColorBy = {})); + var ColorBy = node_1.ColorBy; + ; /** * Returns the fill color for the node given its state and the "color by" * option. */ - function getFillForNode(sceneBehavior, colorBy, renderInfo, isExpanded) { + function getFillForNode(templateIndex, colorBy, renderInfo, isExpanded) { var colorParams = tf.graph.render.MetanodeColors; switch (colorBy) { case ColorBy.STRUCTURE: if (renderInfo.node.type === tf.graph.NodeType.META) { var tid = renderInfo.node.templateId; - return tid === null ? colorParams.UNKNOWN : colorParams.STRUCTURE_PALETTE(sceneBehavior.templateIndex(tid), renderInfo.expanded); + return tid === null ? + colorParams.UNKNOWN : + colorParams.STRUCTURE_PALETTE(templateIndex(tid), isExpanded); } else if (renderInfo.node.type === tf.graph.NodeType.SERIES) { // If expanded, we're showing the background rect, which we want to // appear gray. Otherwise we're showing a stack of ellipses which we // want to show white. - return renderInfo.expanded ? colorParams.EXPANDED_COLOR : "white"; + return isExpanded ? colorParams.EXPANDED_COLOR : "white"; } else if (renderInfo.node.type === graph.NodeType.BRIDGE) { return renderInfo.structural ? "#f0e" : @@ -3958,6 +4291,7 @@ var tf; throw new Error("Unknown case to color nodes by"); } } + node_1.getFillForNode = getFillForNode; /** * Modify node style by toggling class and assign attributes (only for things * that can't be done in css). @@ -3975,29 +4309,111 @@ var tf; // Main node always exists here and it will be reached before subscene, // so d3 selection is fine here. var node = nodeGroup.select("." + nodeClass + " ." + scene.Class.Node.COLOR_TARGET); - var fillColor = getFillForNode(sceneBehavior, ColorBy[sceneBehavior.colorBy.toUpperCase()], renderInfo, isExpanded); + var fillColor = getFillForNode(sceneBehavior.templateIndex, ColorBy[sceneBehavior.colorBy.toUpperCase()], renderInfo, isExpanded); node.style("fill", fillColor); // Choose outline to be darker version of node color if the node is a single // color and is not selected. - if (isSelected) { - node.style("stroke", null); - } - else { - // If node is colored by a gradient, then use a dark gray outline. - var outlineColor = fillColor.substring(0, 3) === "url" ? - tf.graph.render.MetanodeColors.GRADIENT_OUTLINE : - d3.rgb(fillColor).darker().toString(); - node.style("stroke", outlineColor); - } + node.style("stroke", isSelected ? null : getStrokeForFill(fillColor)); } node_1.stylize = stylize; ; + /** + * Given a node's fill color/gradient, determine the stroke for the node. + */ + function getStrokeForFill(fill) { + // If node is colored by a gradient, then use a dark gray outline. + return fill.substring(0, 3) === "url" ? + tf.graph.render.MetanodeColors.GRADIENT_OUTLINE : + d3.rgb(fill).darker().toString(); + } + node_1.getStrokeForFill = getStrokeForFill; })(node = scene.node || (scene.node = {})); })(scene = graph.scene || (graph.scene = {})); })(graph = tf.graph || (tf.graph = {})); })(tf || (tf = {})); // close module </script> -<script>/// <reference path="graph.ts" /> +<script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +var tf; +(function (tf) { + var graph; + (function (graph) { + var scene; + (function (scene) { + var contextmenu; + (function (contextmenu) { + /** + * Returns the event listener, which can be used as an argument for the d3 + * selection.on function. Renders the context menu that is to be displayed + * in response to the event. + */ + function getMenu(menu) { + var menuSelection = d3.select(".context-menu"); + // Close the menu when anything else is clicked. + d3.select("body").on("click.context", function () { + menuSelection.style("display", "none"); + }); + // Function called to populate the context menu. + return function (data, index) { + var _this = this; + // Position and display the menu. + var event = d3.event; + menuSelection.style({ + "display": "block", + "left": (event.layerX + 1) + "px", + "top": (event.layerY + 1) + "px" + }); + // Stop the event from propagating further. + event.preventDefault(); + event.stopPropagation(); + // Add provided items to the context menu. + menuSelection.html(""); + var list = menuSelection.append("ul"); + list.selectAll("li").data(menu).enter() + .append("li") + .html(function (d) { + return d.title(data); + }) + .on("click", function (d, i) { + d.action(_this, data, index); + menuSelection.style("display", "none"); + }); + }; + } + contextmenu.getMenu = getMenu; + ; + })(contextmenu = scene.contextmenu || (scene.contextmenu = {})); + })(scene = graph.scene || (graph.scene = {})); + })(graph = tf.graph || (tf.graph = {})); +})(tf || (tf = {})); // close module +</script> +<script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// <reference path="graph.ts" /> /// <reference path="render.ts" /> var tf; (function (tf) { @@ -4020,14 +4436,19 @@ var tf; * * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout */ - nodeSep: 110, + nodeSep: 5, /** * Dagre's ranksep param - number of pixels * between each rank in the layout. * * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout */ - rankSep: 25 + rankSep: 25, + /** + * Dagre's edgesep param - number of pixels that separate + * edges horizontally in the layout. + */ + edgeSep: 5, }, /** Graph parameter for metanode. */ series: { @@ -4037,7 +4458,7 @@ var tf; * * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout */ - nodeSep: 90, + nodeSep: 5, /** * Dagre's ranksep param - number of pixels * between each rank in the layout. @@ -4045,6 +4466,11 @@ var tf; * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout */ rankSep: 25, + /** + * Dagre's edgesep param - number of pixels that separate + * edges horizontally in the layout. + */ + edgeSep: 5 }, /** * Padding is used to correctly position the graph SVG inside of its parent @@ -4153,6 +4579,10 @@ var tf; } }, annotations: { + /** Maximum possible width of the bounding box for in annotations */ + inboxWidth: 50, + /** Maximum possible width of the bounding box for out annotations */ + outboxWidth: 50, /** X-space between the shape and each annotation-node. */ xOffset: 10, /** Y-space between each annotation-node. */ @@ -4188,7 +4618,7 @@ var tf; } }; /** Calculate layout for a scene of a group node. */ - function scene(renderNodeInfo) { + function layoutScene(renderNodeInfo) { // Update layout, size, and annotations of its children nodes and edges. if (renderNodeInfo.node.isGroupNode) { layoutChildren(renderNodeInfo); @@ -4201,9 +4631,29 @@ var tf; layoutSeriesNode(renderNodeInfo); } } - layout.scene = scene; + layout.layoutScene = layoutScene; ; /** + * Updates the total width of an unexpanded node which includes the size of its + * in and out annotations. + */ + function updateTotalWidthOfNode(renderInfo) { + renderInfo.inboxWidth = renderInfo.inAnnotations.list.length > 0 ? + layout.PARAMS.annotations.inboxWidth : 0; + renderInfo.outboxWidth = renderInfo.outAnnotations.list.length > 0 ? + layout.PARAMS.annotations.outboxWidth : 0; + // Assign the width of the core box (the main shape of the node). + renderInfo.coreBox.width = renderInfo.width; + renderInfo.coreBox.height = renderInfo.height; + // TODO(jimbo): Account for font width rather than using a magic number. + var labelLength = renderInfo.node.name.length - + renderInfo.node.name.lastIndexOf(graph_1.NAMESPACE_DELIM) - 1; + var charWidth = 3; // 3 pixels per character. + // Compute the total width of the node. + renderInfo.width = Math.max(renderInfo.coreBox.width + + renderInfo.inboxWidth + renderInfo.outboxWidth, labelLength * charWidth); + } + /** * Update layout, size, and annotations of its children nodes and edges. */ function layoutChildren(renderNodeInfo) { @@ -4221,21 +4671,21 @@ var tf; break; case graph_1.NodeType.META: if (!childNodeInfo.expanded) { - // set fixed width and scalable height based on cardinality + // Set fixed width and scalable height based on cardinality _.extend(childNodeInfo, layout.PARAMS.nodeSize.meta); childNodeInfo.height = layout.PARAMS.nodeSize.meta.height(childNodeInfo.node.cardinality); } else { var childGroupNodeInfo = childNodeInfo; - scene(childGroupNodeInfo); // Recursively layout its subscene. + layoutScene(childGroupNodeInfo); // Recursively layout its subscene. } break; case graph_1.NodeType.SERIES: if (childNodeInfo.expanded) { _.extend(childNodeInfo, layout.PARAMS.nodeSize.series.expanded); var childGroupNodeInfo = childNodeInfo; - scene(childGroupNodeInfo); // Recursively layout its subscene. + layoutScene(childGroupNodeInfo); // Recursively layout its subscene. } else { var childGroupNodeInfo = childNodeInfo; @@ -4248,6 +4698,11 @@ var tf; default: throw Error("Unrecognized node type: " + childNodeInfo.node.type); } + // Compute total width of un-expanded nodes. Width of expanded nodes + // has already been computed. + if (!childNodeInfo.expanded) { + updateTotalWidthOfNode(childNodeInfo); + } // Layout each child's annotations layoutAnnotation(childNodeInfo); }); @@ -4260,8 +4715,9 @@ var tf; */ function dagreLayout(graph, params) { _.extend(graph.graph(), { - nodeSep: params.nodeSep, - rankSep: params.rankSep + nodesep: params.nodeSep, + ranksep: params.rankSep, + edgesep: params.edgeSep }); var bridgeNodeNames = []; var nonBridgeNodeNames = []; @@ -4284,7 +4740,6 @@ var tf; }; } dagre.layout(graph); - var graphLabel = graph.graph(); // Calculate the true bounding box of the graph by iterating over nodes and // edges rather than accepting dagre's word for it. In particular, we should // ignore the extra-wide bridge nodes and bridge edges, and allow for @@ -4296,31 +4751,62 @@ var tf; _.each(nonBridgeNodeNames, function (nodeName) { var nodeInfo = graph.node(nodeName); var w = 0.5 * nodeInfo.width; - var x1 = nodeInfo.x - w - nodeInfo.inboxWidth; - var x2 = nodeInfo.x + w + nodeInfo.outboxWidth; + var x1 = nodeInfo.x - w; + var x2 = nodeInfo.x + w; minX = x1 < minX ? x1 : minX; maxX = x2 > maxX ? x2 : maxX; - var labelLength = nodeName.length - nodeName.lastIndexOf(graph_1.NAMESPACE_DELIM); - // TODO(jimbo): Account for font width rather than using a magic number. - var charWidth = 3; // 3 pixels per character. - var lw = 0.5 * labelLength * charWidth; - var lx1 = nodeInfo.x - lw; - var lx2 = nodeInfo.x + lw; - minX = lx1 < minX ? lx1 : minX; - maxX = lx2 > maxX ? lx2 : maxX; // TODO(jimbo): Account for the height of labels above op nodes here. - var h = 0.5 * nodeInfo.outerHeight; + var h = 0.5 * nodeInfo.height; var y1 = nodeInfo.y - h; var y2 = nodeInfo.y + h; minY = y1 < minY ? y1 : minY; maxY = y2 > maxY ? y2 : maxY; }); _.each(graph.edges(), function (edgeObj) { - var renderMetaedgeInfo = graph.edge(edgeObj); - if (renderMetaedgeInfo.structural) { + var edgeInfo = graph.edge(edgeObj); + if (edgeInfo.structural) { return; // Skip structural edges from min/max calculations. } - _.each(renderMetaedgeInfo.points, function (point) { + // Since the node size passed to dagre includes the in and out + // annotations, the endpoints of the edge produced by dagre may not + // point to the actual node shape (rectangle, ellipse). We correct the + // end-points by finding the intersection of a line between the + // next-to-last (next-to-first) point and the destination (source) + // rectangle. + var sourceNode = graph.node(edgeInfo.metaedge.v); + var destNode = graph.node(edgeInfo.metaedge.w); + // Straight 3-points edges are special case, since they are curved after + // our default correction. To keep them straight, we remove the mid point + // and correct the first and the last point to be the center of the + // source and destination node respectively. + if (edgeInfo.points.length === 3 && isStraightLine(edgeInfo.points)) { + if (sourceNode != null) { + var cxSource = sourceNode.expanded ? + sourceNode.x : computeCXPositionOfNodeShape(sourceNode); + edgeInfo.points[0].x = cxSource; + } + if (destNode != null) { + var cxDest = destNode.expanded ? + destNode.x : computeCXPositionOfNodeShape(destNode); + edgeInfo.points[2].x = cxDest; + } + // Remove the middle point so the edge doesn't curve. + edgeInfo.points = [edgeInfo.points[0], edgeInfo.points[1]]; + } + // Correct the destination endpoint of the edge. + var nextToLastPoint = edgeInfo.points[edgeInfo.points.length - 2]; + // The destination node might be null if this is a bridge edge. + if (destNode != null) { + edgeInfo.points[edgeInfo.points.length - 1] = + intersectPointAndNode(nextToLastPoint, destNode); + } + // Correct the source endpoint of the edge. + var secondPoint = edgeInfo.points[1]; + // The source might be null if this is a bridge edge. + if (sourceNode != null) { + edgeInfo.points[0] = intersectPointAndNode(secondPoint, sourceNode); + } + _.each(edgeInfo.points, function (point) { minX = point.x < minX ? point.x : minX; maxX = point.x > maxX ? point.x : maxX; minY = point.y < minY ? point.y : minY; @@ -4342,59 +4828,59 @@ var tf; }); return { width: maxX - minX, - height: maxY - minY, + height: maxY - minY }; } - /** Layout a metanode. */ + /** Layout a metanode. Only called for an expanded node. */ function layoutMetanode(renderNodeInfo) { // First, copy params specific to meta nodes onto this render info object. var params = layout.PARAMS.subscene.meta; - renderNodeInfo = _.extend(renderNodeInfo, params); + _.extend(renderNodeInfo, params); // Invoke dagre.layout() on the core graph and record the bounding box // dimensions. _.extend(renderNodeInfo.coreBox, dagreLayout(renderNodeInfo.coreGraph, layout.PARAMS.graph.meta)); // Calculate the position of nodes in isolatedInExtract relative to the // top-left corner of inExtractBox (the bounding box for all inExtract nodes) // and calculate the size of the inExtractBox. - var hasInExtract = renderNodeInfo.isolatedInExtract.length > 0; - renderNodeInfo.inExtractBox.width = hasInExtract ? - _(renderNodeInfo.isolatedInExtract).pluck("outerWidth").max() : 0; + var maxInExtractWidth = _.max(renderNodeInfo.isolatedInExtract, function (renderNode) { return renderNode.width; }).width; + renderNodeInfo.inExtractBox.width = maxInExtractWidth != null ? + maxInExtractWidth : 0; renderNodeInfo.inExtractBox.height = _.reduce(renderNodeInfo.isolatedInExtract, function (height, child, i) { var yOffset = i > 0 ? params.extractYOffset : 0; - // use outerWidth/Height here to avoid overlaps between extracts - child.x = renderNodeInfo.inExtractBox.width / 2; - child.y = height + yOffset + child.outerHeight / 2; - return height + yOffset + child.outerHeight; + // use width/height here to avoid overlaps between extracts + child.x = 0; + child.y = height + yOffset + child.height / 2; + return height + yOffset + child.height; }, 0); // Calculate the position of nodes in isolatedOutExtract relative to the // top-left corner of outExtractBox (the bounding box for all outExtract // nodes) and calculate the size of the outExtractBox. - var hasOutExtract = renderNodeInfo.isolatedOutExtract.length > 0; - renderNodeInfo.outExtractBox.width = hasOutExtract ? - _(renderNodeInfo.isolatedOutExtract).pluck("outerWidth").max() : 0; + var maxOutExtractWidth = _.max(renderNodeInfo.isolatedOutExtract, function (renderNode) { return renderNode.width; }).width; + renderNodeInfo.outExtractBox.width = maxOutExtractWidth != null ? + maxOutExtractWidth : 0; renderNodeInfo.outExtractBox.height = _.reduce(renderNodeInfo.isolatedOutExtract, function (height, child, i) { var yOffset = i > 0 ? params.extractYOffset : 0; - // use outerWidth/Height here to avoid overlaps between extracts - child.x = renderNodeInfo.outExtractBox.width / 2; - child.y = height + yOffset + child.outerHeight / 2; - return height + yOffset + child.outerHeight; + // use width/height here to avoid overlaps between extracts + child.x = 0; + child.y = height + yOffset + child.height / 2; + return height + yOffset + child.height; }, 0); + // Add the in-extract and out-extract width to the core box width. + renderNodeInfo.coreBox.width += renderNodeInfo.inExtractBox.width + + renderNodeInfo.outExtractBox.width; + renderNodeInfo.coreBox.height = + params.labelHeight + + Math.max(renderNodeInfo.inExtractBox.height, renderNodeInfo.coreBox.height, renderNodeInfo.outExtractBox.height); // Determine the whole metanode's width (from left to right). - renderNodeInfo.width = - params.paddingLeft + renderNodeInfo.coreBox.width + params.paddingRight + - (hasInExtract ? - renderNodeInfo.inExtractBox.width + params.extractXOffset : 0) + - (hasOutExtract ? - params.extractXOffset + renderNodeInfo.outExtractBox.width : 0); - // TODO(jimbo): Remove labelHeight and instead incorporate into box sizes. + renderNodeInfo.width = renderNodeInfo.coreBox.width + + params.paddingLeft + params.paddingRight; // Determine the whole metanode's height (from top to bottom). renderNodeInfo.height = - renderNodeInfo.labelHeight + - params.paddingTop + - Math.max(renderNodeInfo.inExtractBox.height, renderNodeInfo.coreBox.height, renderNodeInfo.outExtractBox.height) + - params.paddingBottom; + renderNodeInfo.paddingTop + + renderNodeInfo.coreBox.height + + renderNodeInfo.paddingBottom; } /** * Calculate layout for series node's core graph. Only called for an expanded @@ -4415,7 +4901,7 @@ var tf; } /** * Calculate layout for annotations of a given node. - * This will modify positions of the the given node and its annotations. + * This will modify positions of the given node and its annotations. * * @see tf.graph.render.Node and tf.graph.render.Annotation * for description of each property of each render node. @@ -4425,14 +4911,6 @@ var tf; // If the render node is an expanded metanode, then its annotations will not // be visible and we should skip the annotation calculations. if (renderNodeInfo.expanded) { - _.extend(renderNodeInfo, { - inboxWidth: 0, - inboxHeight: 0, - outboxWidth: 0, - outboxHeight: 0, - outerWidth: renderNodeInfo.width, - outerHeight: renderNodeInfo.height - }); return; } var inAnnotations = renderNodeInfo.inAnnotations.list; @@ -4442,23 +4920,13 @@ var tf; // Calculate size for out-annotations _.each(outAnnotations, function (a) { return sizeAnnotation(a); }); var params = layout.PARAMS.annotations; - renderNodeInfo.inboxWidth = - inAnnotations.length > 0 ? - _(inAnnotations).pluck("width").max() + - params.xOffset + params.labelWidth + params.labelOffset : - 0; - renderNodeInfo.outboxWidth = - outAnnotations.length > 0 ? - _(outAnnotations).pluck("width").max() + - params.xOffset + params.labelWidth + params.labelOffset : - 0; // Calculate annotation node position (a.dx, a.dy) // and total height for in-annotations // After this chunk of code: // inboxHeight = sum of annotation heights+ (annotation.length - 1 * yOffset) var inboxHeight = _.reduce(inAnnotations, function (height, a, i) { var yOffset = i > 0 ? params.yOffset : 0; - a.dx = -(renderNodeInfo.width + a.width) / 2 - params.xOffset; + a.dx = -(renderNodeInfo.coreBox.width + a.width) / 2 - params.xOffset; a.dy = height + yOffset + a.height / 2; return height + yOffset + a.height; }, 0); @@ -4473,7 +4941,7 @@ var tf; // (annotation.length - 1 * yOffset) var outboxHeight = _.reduce(outAnnotations, function (height, a, i) { var yOffset = i > 0 ? params.yOffset : 0; - a.dx = (renderNodeInfo.width + a.width) / 2 + params.xOffset; + a.dx = (renderNodeInfo.coreBox.width + a.width) / 2 + params.xOffset; a.dy = height + yOffset + a.height / 2; return height + yOffset + a.height; }, 0); @@ -4500,7 +4968,7 @@ var tf; }, // The host node end { - dx: -renderNodeInfo.width / 2, + dx: -renderNodeInfo.coreBox.width / 2, // only use scale if there are more than one, // otherwise center it vertically dy: inAnnotations.length > 1 ? inY(i) : 0 @@ -4519,7 +4987,7 @@ var tf; a.points = [ // The host node end { - dx: renderNodeInfo.width / 2, + dx: renderNodeInfo.coreBox.width / 2, // only use scale if there are more than one, // otherwise center it vertically dy: outAnnotations.length > 1 ? outY(i) : 0 @@ -4531,9 +4999,7 @@ var tf; } ]; }); - renderNodeInfo.outerWidth = renderNodeInfo.width + renderNodeInfo.inboxWidth + - renderNodeInfo.outboxWidth; - renderNodeInfo.outerHeight = + renderNodeInfo.height = Math.max(renderNodeInfo.height, inboxHeight, outboxHeight); } /** @@ -4563,11 +5029,92 @@ var tf; break; } } + /** + * Determines the center position of the node's shape. The position depends + * on if the node has in and out-annotations. + */ + function computeCXPositionOfNodeShape(renderInfo) { + if (renderInfo.expanded) { + return renderInfo.x; + } + var dx = renderInfo.inAnnotations.list.length ? renderInfo.inboxWidth : 0; + return renderInfo.x - renderInfo.width / 2 + dx + + renderInfo.coreBox.width / 2; + } + layout.computeCXPositionOfNodeShape = computeCXPositionOfNodeShape; + /** Returns the angle (in degrees) between two points. */ + function angleBetweenTwoPoints(a, b) { + var dx = b.x - a.x; + var dy = b.y - a.y; + return 180 * Math.atan(dy / dx) / Math.PI; + } + /** + * Returns if a line going through the specified points is a straight line. + */ + function isStraightLine(points) { + var angle = angleBetweenTwoPoints(points[0], points[1]); + for (var i = 1; i < points.length - 1; i++) { + var newAngle = angleBetweenTwoPoints(points[i], points[i + 1]); + // Have a tolerance of 1 degree. + if (Math.abs(newAngle - angle) > 1) { + return false; + } + angle = newAngle; + } + return true; + } + /** + * Returns the intersection of a line between the provided point + * and the provided rectangle. + */ + function intersectPointAndNode(point, node) { + // cx and cy are the center of the rectangle. + var cx = node.expanded ? + node.x : computeCXPositionOfNodeShape(node); + var cy = node.y; + // Calculate the slope + var dx = point.x - cx; + var dy = point.y - cy; + var w = node.expanded ? node.width : node.coreBox.width; + var h = node.expanded ? node.height : node.coreBox.height; + var deltaX, deltaY; + if (Math.abs(dy) * w / 2 > Math.abs(dx) * h / 2) { + // The intersection is above or below the rectangle. + if (dy < 0) { + h = -h; + } + deltaX = dy === 0 ? 0 : h / 2 * dx / dy; + deltaY = h / 2; + } + else { + // The intersection is left or right of the rectangle. + if (dx < 0) { + w = -w; + } + deltaX = w / 2; + deltaY = dx === 0 ? 0 : w / 2 * dy / dx; + } + return { x: cx + deltaX, y: cy + deltaY }; + } })(layout = graph_1.layout || (graph_1.layout = {})); })(graph = tf.graph || (tf.graph = {})); })(tf || (tf = {})); // close module </script> -<script>var tf; +<script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +var tf; (function (tf) { /** * Mapping from color palette name to color pallette, which contains @@ -4699,7 +5246,21 @@ var tf; }, {}); })(tf || (tf = {})); </script> -<script>/// <reference path="../../../../typings/tsd.d.ts" /> +<script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// <reference path="../../../../typings/tsd.d.ts" /> /// <reference path="../common.ts" /> var tf; (function (tf) { @@ -4759,6 +5320,9 @@ var tf; this.canvas = $minimap.select("canvas.first").node(); this.canvasBuffer = $minimap.select("canvas.second").node(); + this.downloadCanvas = + $minimap.select("canvas.download").node(); + d3.select(this.downloadCanvas).style("display", "none"); } /** * Updates the position and the size of the viewpoint rectangle. @@ -4782,6 +5346,11 @@ var tf; */ Minimap.prototype.update = function () { var _this = this; + var $download = d3.select("#graphdownload"); + this.download = $download.node(); + $download.on("click", function (d) { + _this.download.href = _this.downloadCanvas.toDataURL("image/png"); + }); var $svg = d3.select(this.svg); // Read all the style rules in the document and embed them into the svg. // The svg needs to be self contained, i.e. all the style rules need to be @@ -4815,7 +5384,8 @@ var tf; // Get the size of the entire scene. var sceneSize = this.zoomG.getBBox(); // Since we add padding, account for that here. - sceneSize.height += this.labelPadding; + sceneSize.height += this.labelPadding * 2; + sceneSize.width += this.labelPadding * 2; // Temporarily assign an explicit width/height to the main svg, since // it doesn't have one (uses flex-box), but we need it for the canvas // to work. @@ -4837,6 +5407,10 @@ var tf; // viewpoint rect. d3.select(this.minimapSvg).attr(this.minimapSize); d3.select(this.canvasBuffer).attr(this.minimapSize); + // Download canvas width and height are multiples of the style width and + // height in order to increase pixel density of the PNG for clarity. + d3.select(this.downloadCanvas).style({ width: sceneSize.width, height: sceneSize.height }); + d3.select(this.downloadCanvas).attr({ width: sceneSize.width * 3, height: sceneSize.height * 3 }); if (this.translate != null && this.zoom != null) { // Update the viewpoint rectangle shape since the aspect ratio of the // map has changed. @@ -4868,6 +5442,9 @@ var tf; _a = [_this.canvasBuffer, _this.canvas], _this.canvas = _a[0], _this.canvasBuffer = _a[1]; var _a; }); + var downloadContext = _this.downloadCanvas.getContext("2d"); + downloadContext.clearRect(0, 0, _this.downloadCanvas.width, _this.downloadCanvas.height); + downloadContext.drawImage(image, 0, 0, _this.downloadCanvas.width, _this.downloadCanvas.height); }; image.src = "data:image/svg+xml;base64," + btoa(svgXml); }; @@ -4922,576 +5499,6 @@ var tf; })(scene = tf.scene || (tf.scene = {})); })(tf || (tf = {})); // close module tf.scene </script> -<script>/// <reference path="graph.ts" /> -/// <reference path="render.ts" /> -var tf; -(function (tf) { - var graph; - (function (graph_1) { - var layout; - (function (layout) { - /** Set of parameters that define the look and feel of the graph. */ - layout.PARAMS = { - animation: { - /** Default duration for graph animations in ms. */ - duration: 250 - }, - graph: { - /** Graph parameter for metanode. */ - meta: { - /** - * Dagre's nodesep param - number of pixels that - * separate nodes horizontally in the layout. - * - * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout - */ - nodeSep: 110, - /** - * Dagre's ranksep param - number of pixels - * between each rank in the layout. - * - * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout - */ - rankSep: 25 - }, - /** Graph parameter for metanode. */ - series: { - /** - * Dagre's nodesep param - number of pixels that - * separate nodes horizontally in the layout. - * - * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout - */ - nodeSep: 90, - /** - * Dagre's ranksep param - number of pixels - * between each rank in the layout. - * - * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout - */ - rankSep: 25, - }, - /** - * Padding is used to correctly position the graph SVG inside of its parent - * element. The padding amounts are applied using an SVG transform of X and - * Y coordinates. - */ - padding: { - paddingTop: 40, - paddingLeft: 20 - } - }, - subscene: { - meta: { - paddingTop: 10, - paddingBottom: 10, - paddingLeft: 10, - paddingRight: 10, - /** - * Used to leave room for the label on top of the highest node in - * the core graph. - */ - labelHeight: 20, - /** X-space between each extracted node and the core graph. */ - extractXOffset: 50, - /** Y-space between each extracted node. */ - extractYOffset: 20 - }, - series: { - paddingTop: 10, - paddingBottom: 10, - paddingLeft: 10, - paddingRight: 10, - labelHeight: 10 - } - }, - nodeSize: { - /** Size of meta nodes. */ - meta: { - radius: 5, - width: 60, - /** A scale for the node's height based on number of nodes inside */ - height: d3.scale.linear().domain([1, 200]).range([15, 60]).clamp(true), - /** The radius of the circle denoting the expand button. */ - expandButtonRadius: 3 - }, - /** Size of op nodes. */ - op: { - width: 15, - height: 6, - radius: 3, - labelOffset: -8 - }, - /** Size of series nodes. */ - series: { - expanded: { - // For expanded series nodes, width and height will be - // computed to account for the subscene. - radius: 10, - labelOffset: 0, - }, - vertical: { - // When unexpanded, series whose underlying metagraphs contain - // one or more non-control edges will show as a vertical stack - // of ellipses. - width: 16, - height: 13, - labelOffset: -13, - }, - horizontal: { - // When unexpanded, series whose underlying metagraphs contain - // no non-control edges will show as a horizontal stack of - // ellipses. - width: 24, - height: 8, - radius: 10, - labelOffset: -10, - }, - }, - /** Size of bridge nodes. */ - bridge: { - // NOTE: bridge nodes will normally be invisible, but they must - // take up some space so that the layout step leaves room for - // their edges. - width: 20, - height: 20, - radius: 2, - labelOffset: 0 - } - }, - shortcutSize: { - /** Size of shortcuts for op nodes */ - op: { - width: 10, - height: 4 - }, - /** Size of shortcuts for meta nodes */ - meta: { - width: 12, - height: 4, - radius: 1 - }, - /** Size of shortcuts for series nodes */ - series: { - width: 14, - height: 4, - } - }, - annotations: { - /** X-space between the shape and each annotation-node. */ - xOffset: 10, - /** Y-space between each annotation-node. */ - yOffset: 3, - /** X-space between each annotation-node and its label. */ - labelOffset: 2, - /** Estimate max width for annotation label */ - labelWidth: 35 - }, - constant: { - size: { - width: 4, - height: 4 - } - }, - series: { - /** Maximum number of repeated item for unexpanded series node. */ - maxStackCount: 3, - /** - * Positioning offset ratio for collapsed stack - * of parallel series (series without edges between its members). - */ - parallelStackOffsetRatio: 0.2, - /** - * Positioning offset ratio for collapsed stack - * of tower series (series with edges between its members). - */ - towerStackOffsetRatio: 0.5 - }, - minimap: { - /** The maximum width/height the minimap can have. */ - size: 150 - } - }; - /** Calculate layout for a scene of a group node. */ - function scene(renderNodeInfo) { - // Update layout, size, and annotations of its children nodes and edges. - if (renderNodeInfo.node.isGroupNode) { - layoutChildren(renderNodeInfo); - } - // Update position of its children nodes and edges - if (renderNodeInfo.node.type === graph_1.NodeType.META) { - layoutMetanode(renderNodeInfo); - } - else if (renderNodeInfo.node.type === graph_1.NodeType.SERIES) { - layoutSeriesNode(renderNodeInfo); - } - } - layout.scene = scene; - ; - /** - * Update layout, size, and annotations of its children nodes and edges. - */ - function layoutChildren(renderNodeInfo) { - var children = renderNodeInfo.coreGraph.nodes().map(function (n) { - return renderNodeInfo.coreGraph.node(n); - }).concat(renderNodeInfo.isolatedInExtract, renderNodeInfo.isolatedOutExtract); - _.each(children, function (childNodeInfo) { - // Set size of each child - switch (childNodeInfo.node.type) { - case graph_1.NodeType.OP: - _.extend(childNodeInfo, layout.PARAMS.nodeSize.op); - break; - case graph_1.NodeType.BRIDGE: - _.extend(childNodeInfo, layout.PARAMS.nodeSize.bridge); - break; - case graph_1.NodeType.META: - if (!childNodeInfo.expanded) { - // set fixed width and scalable height based on cardinality - _.extend(childNodeInfo, layout.PARAMS.nodeSize.meta); - childNodeInfo.height = - layout.PARAMS.nodeSize.meta.height(childNodeInfo.node.cardinality); - } - else { - var childGroupNodeInfo = childNodeInfo; - scene(childGroupNodeInfo); // Recursively layout its subscene. - } - break; - case graph_1.NodeType.SERIES: - if (childNodeInfo.expanded) { - _.extend(childNodeInfo, layout.PARAMS.nodeSize.series.expanded); - var childGroupNodeInfo = childNodeInfo; - scene(childGroupNodeInfo); // Recursively layout its subscene. - } - else { - var childGroupNodeInfo = childNodeInfo; - var seriesParams = childGroupNodeInfo.node.hasNonControlEdges ? - layout.PARAMS.nodeSize.series.vertical : - layout.PARAMS.nodeSize.series.horizontal; - _.extend(childNodeInfo, seriesParams); - } - break; - default: - throw Error("Unrecognized node type: " + childNodeInfo.node.type); - } - // Layout each child's annotations - layoutAnnotation(childNodeInfo); - }); - } - /** - * Calculate layout for a graph using dagre - * @param graph the graph to be laid out - * @param params layout parameters - * @return width and height of the core graph - */ - function dagreLayout(graph, params) { - _.extend(graph.graph(), { - nodeSep: params.nodeSep, - rankSep: params.rankSep - }); - var bridgeNodeNames = []; - var nonBridgeNodeNames = []; - // Split out nodes into bridge and non-bridge nodes, and calculate the total - // width we should use for bridge nodes. - _.each(graph.nodes(), function (nodeName) { - var nodeInfo = graph.node(nodeName); - if (nodeInfo.node.type === graph_1.NodeType.BRIDGE) { - bridgeNodeNames.push(nodeName); - } - else { - nonBridgeNodeNames.push(nodeName); - } - }); - // If there are no non-bridge nodes, then the graph has zero size. - if (!nonBridgeNodeNames.length) { - return { - width: 0, - height: 0, - }; - } - dagre.layout(graph); - var graphLabel = graph.graph(); - // Calculate the true bounding box of the graph by iterating over nodes and - // edges rather than accepting dagre's word for it. In particular, we should - // ignore the extra-wide bridge nodes and bridge edges, and allow for - // annotation boxes and labels. - var minX = Infinity; - var minY = Infinity; - var maxX = -Infinity; - var maxY = -Infinity; - _.each(nonBridgeNodeNames, function (nodeName) { - var nodeInfo = graph.node(nodeName); - var w = 0.5 * nodeInfo.width; - var x1 = nodeInfo.x - w - nodeInfo.inboxWidth; - var x2 = nodeInfo.x + w + nodeInfo.outboxWidth; - minX = x1 < minX ? x1 : minX; - maxX = x2 > maxX ? x2 : maxX; - var labelLength = nodeName.length - nodeName.lastIndexOf(graph_1.NAMESPACE_DELIM); - // TODO(jimbo): Account for font width rather than using a magic number. - var charWidth = 3; // 3 pixels per character. - var lw = 0.5 * labelLength * charWidth; - var lx1 = nodeInfo.x - lw; - var lx2 = nodeInfo.x + lw; - minX = lx1 < minX ? lx1 : minX; - maxX = lx2 > maxX ? lx2 : maxX; - // TODO(jimbo): Account for the height of labels above op nodes here. - var h = 0.5 * nodeInfo.outerHeight; - var y1 = nodeInfo.y - h; - var y2 = nodeInfo.y + h; - minY = y1 < minY ? y1 : minY; - maxY = y2 > maxY ? y2 : maxY; - }); - _.each(graph.edges(), function (edgeObj) { - var renderMetaedgeInfo = graph.edge(edgeObj); - if (renderMetaedgeInfo.structural) { - return; // Skip structural edges from min/max calculations. - } - _.each(renderMetaedgeInfo.points, function (point) { - minX = point.x < minX ? point.x : minX; - maxX = point.x > maxX ? point.x : maxX; - minY = point.y < minY ? point.y : minY; - maxY = point.y > maxY ? point.y : maxY; - }); - }); - // Shift all nodes and edge points to account for the left-padding amount, - // and the invisble bridge nodes. - _.each(graph.nodes(), function (nodeName) { - var nodeInfo = graph.node(nodeName); - nodeInfo.x -= minX; - nodeInfo.y -= minY; - }); - _.each(graph.edges(), function (edgeObj) { - _.each(graph.edge(edgeObj).points, function (point) { - point.x -= minX; - point.y -= minY; - }); - }); - return { - width: maxX - minX, - height: maxY - minY, - }; - } - /** Layout a metanode. */ - function layoutMetanode(renderNodeInfo) { - // First, copy params specific to meta nodes onto this render info object. - var params = layout.PARAMS.subscene.meta; - renderNodeInfo = _.extend(renderNodeInfo, params); - // Invoke dagre.layout() on the core graph and record the bounding box - // dimensions. - _.extend(renderNodeInfo.coreBox, dagreLayout(renderNodeInfo.coreGraph, layout.PARAMS.graph.meta)); - // Calculate the position of nodes in isolatedInExtract relative to the - // top-left corner of inExtractBox (the bounding box for all inExtract nodes) - // and calculate the size of the inExtractBox. - var hasInExtract = renderNodeInfo.isolatedInExtract.length > 0; - renderNodeInfo.inExtractBox.width = hasInExtract ? - _(renderNodeInfo.isolatedInExtract).pluck("outerWidth").max() : 0; - renderNodeInfo.inExtractBox.height = - _.reduce(renderNodeInfo.isolatedInExtract, function (height, child, i) { - var yOffset = i > 0 ? params.extractYOffset : 0; - // use outerWidth/Height here to avoid overlaps between extracts - child.x = renderNodeInfo.inExtractBox.width / 2; - child.y = height + yOffset + child.outerHeight / 2; - return height + yOffset + child.outerHeight; - }, 0); - // Calculate the position of nodes in isolatedOutExtract relative to the - // top-left corner of outExtractBox (the bounding box for all outExtract - // nodes) and calculate the size of the outExtractBox. - var hasOutExtract = renderNodeInfo.isolatedOutExtract.length > 0; - renderNodeInfo.outExtractBox.width = hasOutExtract ? - _(renderNodeInfo.isolatedOutExtract).pluck("outerWidth").max() : 0; - renderNodeInfo.outExtractBox.height = - _.reduce(renderNodeInfo.isolatedOutExtract, function (height, child, i) { - var yOffset = i > 0 ? params.extractYOffset : 0; - // use outerWidth/Height here to avoid overlaps between extracts - child.x = renderNodeInfo.outExtractBox.width / 2; - child.y = height + yOffset + child.outerHeight / 2; - return height + yOffset + child.outerHeight; - }, 0); - // Determine the whole metanode's width (from left to right). - renderNodeInfo.width = - params.paddingLeft + renderNodeInfo.coreBox.width + params.paddingRight + - (hasInExtract ? - renderNodeInfo.inExtractBox.width + params.extractXOffset : 0) + - (hasOutExtract ? - params.extractXOffset + renderNodeInfo.outExtractBox.width : 0); - // TODO(jimbo): Remove labelHeight and instead incorporate into box sizes. - // Determine the whole metanode's height (from top to bottom). - renderNodeInfo.height = - renderNodeInfo.labelHeight + - params.paddingTop + - Math.max(renderNodeInfo.inExtractBox.height, renderNodeInfo.coreBox.height, renderNodeInfo.outExtractBox.height) + - params.paddingBottom; - } - /** - * Calculate layout for series node's core graph. Only called for an expanded - * series. - */ - function layoutSeriesNode(node) { - var graph = node.coreGraph; - var params = layout.PARAMS.subscene.series; - _.extend(node, params); - // Layout the core. - _.extend(node.coreBox, dagreLayout(node.coreGraph, layout.PARAMS.graph.series)); - _.each(graph.nodes(), function (nodeName) { - graph.node(nodeName).excluded = false; - }); - // Series do not have in/outExtractBox so no need to include them here. - node.width = node.coreBox.width + params.paddingLeft + params.paddingRight; - node.height = node.coreBox.height + params.paddingTop + params.paddingBottom; - } - /** - * Calculate layout for annotations of a given node. - * This will modify positions of the the given node and its annotations. - * - * @see tf.graph.render.Node and tf.graph.render.Annotation - * for description of each property of each render node. - * - */ - function layoutAnnotation(renderNodeInfo) { - // If the render node is an expanded metanode, then its annotations will not - // be visible and we should skip the annotation calculations. - if (renderNodeInfo.expanded) { - _.extend(renderNodeInfo, { - inboxWidth: 0, - inboxHeight: 0, - outboxWidth: 0, - outboxHeight: 0, - outerWidth: renderNodeInfo.width, - outerHeight: renderNodeInfo.height - }); - return; - } - var inAnnotations = renderNodeInfo.inAnnotations.list; - var outAnnotations = renderNodeInfo.outAnnotations.list; - // Calculate size for in-annotations - _.each(inAnnotations, function (a) { return sizeAnnotation(a); }); - // Calculate size for out-annotations - _.each(outAnnotations, function (a) { return sizeAnnotation(a); }); - var params = layout.PARAMS.annotations; - renderNodeInfo.inboxWidth = - inAnnotations.length > 0 ? - _(inAnnotations).pluck("width").max() + - params.xOffset + params.labelWidth + params.labelOffset : - 0; - renderNodeInfo.outboxWidth = - outAnnotations.length > 0 ? - _(outAnnotations).pluck("width").max() + - params.xOffset + params.labelWidth + params.labelOffset : - 0; - // Calculate annotation node position (a.dx, a.dy) - // and total height for in-annotations - // After this chunk of code: - // inboxHeight = sum of annotation heights+ (annotation.length - 1 * yOffset) - var inboxHeight = _.reduce(inAnnotations, function (height, a, i) { - var yOffset = i > 0 ? params.yOffset : 0; - a.dx = -(renderNodeInfo.width + a.width) / 2 - params.xOffset; - a.dy = height + yOffset + a.height / 2; - return height + yOffset + a.height; - }, 0); - _.each(inAnnotations, function (a) { - a.dy -= inboxHeight / 2; - a.labelOffset = params.labelOffset; - }); - // Calculate annotation node position position (a.dx, a.dy) - // and total height for out-annotations - // After this chunk of code: - // outboxHeight = sum of annotation heights + - // (annotation.length - 1 * yOffset) - var outboxHeight = _.reduce(outAnnotations, function (height, a, i) { - var yOffset = i > 0 ? params.yOffset : 0; - a.dx = (renderNodeInfo.width + a.width) / 2 + params.xOffset; - a.dy = height + yOffset + a.height / 2; - return height + yOffset + a.height; - }, 0); - _.each(outAnnotations, function (a) { - // adjust by (half of ) the total height - // so dy is relative to the host node's center. - a.dy -= outboxHeight / 2; - a.labelOffset = params.labelOffset; - }); - // Creating scales for touch point between the in-annotation edges - // and their hosts. - var inTouchHeight = Math.min(renderNodeInfo.height / 2 - renderNodeInfo.radius, inboxHeight / 2); - inTouchHeight = inTouchHeight < 0 ? 0 : inTouchHeight; - var inY = d3.scale.linear() - .domain([0, inAnnotations.length - 1]) - .range([-inTouchHeight, inTouchHeight]); - // Calculate annotation edge position - _.each(inAnnotations, function (a, i) { - a.points = [ - // The annotation node end - { - dx: a.dx + a.width / 2, - dy: a.dy - }, - // The host node end - { - dx: -renderNodeInfo.width / 2, - // only use scale if there are more than one, - // otherwise center it vertically - dy: inAnnotations.length > 1 ? inY(i) : 0 - } - ]; - }); - // Creating scales for touch point between the out-annotation edges - // and their hosts. - var outTouchHeight = Math.min(renderNodeInfo.height / 2 - renderNodeInfo.radius, outboxHeight / 2); - outTouchHeight = outTouchHeight < 0 ? 0 : outTouchHeight; - var outY = d3.scale.linear() - .domain([0, outAnnotations.length - 1]) - .range([-outTouchHeight, outTouchHeight]); - _.each(outAnnotations, function (a, i) { - // Add point from the border of the annotation node - a.points = [ - // The host node end - { - dx: renderNodeInfo.width / 2, - // only use scale if there are more than one, - // otherwise center it vertically - dy: outAnnotations.length > 1 ? outY(i) : 0 - }, - // The annotation node end - { - dx: a.dx - a.width / 2, - dy: a.dy - } - ]; - }); - renderNodeInfo.outerWidth = renderNodeInfo.width + renderNodeInfo.inboxWidth + - renderNodeInfo.outboxWidth; - renderNodeInfo.outerHeight = - Math.max(renderNodeInfo.height, inboxHeight, outboxHeight); - } - /** - * Set size of an annotation node. - */ - function sizeAnnotation(a) { - switch (a.annotationType) { - case graph_1.render.AnnotationType.CONSTANT: - _.extend(a, layout.PARAMS.constant.size); - break; - case graph_1.render.AnnotationType.SHORTCUT: - if (a.node.type === graph_1.NodeType.OP) { - _.extend(a, layout.PARAMS.shortcutSize.op); - } - else if (a.node.type === graph_1.NodeType.META) { - _.extend(a, layout.PARAMS.shortcutSize.meta); - } - else if (a.node.type === graph_1.NodeType.SERIES) { - _.extend(a, layout.PARAMS.shortcutSize.series); - } - else { - throw Error("Invalid node type: " + a.node.type); - } - break; - case graph_1.render.AnnotationType.SUMMARY: - _.extend(a, layout.PARAMS.constant.size); - break; - } - } - })(layout = graph_1.layout || (graph_1.layout = {})); - })(graph = tf.graph || (tf.graph = {})); -})(tf || (tf = {})); // close module -</script> @@ -5499,9 +5506,23 @@ var tf; -</head><body><div hidden="" by-vulcanize=""><dom-module id="tf-data-coordinator" assetpath="../components/tf-event-dashboard/"> - <script>/// <reference path="../../typings/tsd.d.ts" /> -/// <reference path="../../bower_components/plottable/plottable.d.ts" /> +</head><body><div hidden="" by-vulcanize=""><dom-module id="tf-data-coordinator" assetpath="../tf-event-dashboard/"> + <script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// <reference path="../../typings/tsd.d.ts" /> +/// <reference path="../plottable/plottable.d.ts" /> var TF; (function (TF) { /* The DataCoordinator generates TF.Datasets for each run/tag combination, @@ -5553,14 +5574,28 @@ var TF; TF.DataCoordinator = DataCoordinator; })(TF || (TF = {})); </script> - <script>/// <reference path="../../typings/tsd.d.ts" /> -/// <reference path="../../bower_components/plottable/plottable.d.ts" /> + <script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ var __extends = (this && this.__extends) || function (d, b) { for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p]; function __() { this.constructor = d; } __.prototype = b.prototype; d.prototype = new __(); }; +/// <reference path="../../typings/tsd.d.ts" /> +/// <reference path="../plottable/plottable.d.ts" /> var TF; (function (TF) { /* An extension of Plottable.Dataset that knows how to load data from a backend. @@ -5615,7 +5650,7 @@ var TF; }); </script> </dom-module> -<dom-module id="tf-tooltip-coordinator" assetpath="../components/tf-event-dashboard/"> +<dom-module id="tf-tooltip-coordinator" assetpath="../tf-event-dashboard/"> <script> Polymer({ is: "tf-tooltip-coordinator", @@ -5654,7 +5689,7 @@ var TF; }); </script> </dom-module> -<dom-module id="scrollbar-style" assetpath="../components/tf-dashboard-common/"> +<dom-module id="scrollbar-style" assetpath="../tf-dashboard-common/"> <template> <style> .scrollbar::-webkit-scrollbar-track @@ -5680,7 +5715,7 @@ var TF; </style> </template> </dom-module> -<dom-module id="run-color-style" assetpath="../components/tf-dashboard-common/"> +<dom-module id="run-color-style" assetpath="../tf-dashboard-common/"> <template> <style> [color-class="light-blue"] paper-checkbox { @@ -5740,7 +5775,7 @@ var TF; </style> </template> </dom-module> -<dom-module id="tf-multi-checkbox" assetpath="../components/tf-multi-checkbox/"> +<dom-module id="tf-multi-checkbox" assetpath="../tf-multi-checkbox/"> <style include="scrollbar-style"></style> <style include="run-color-style"></style> @@ -5917,7 +5952,7 @@ var TF; </script> </dom-module> -<dom-module id="tf-run-selector" assetpath="../components/tf-event-dashboard/"> +<dom-module id="tf-run-selector" assetpath="../tf-event-dashboard/"> <template> <div id="top-text"> <template is="dom-if" if="[[xValue]]"> @@ -5927,12 +5962,15 @@ var TF; </div> </template> <template is="dom-if" if="[[!xValue]]"> - <div id="tooltip-help" class="tooltip-container"> - Selected Runs: - </div> + <h3 id="tooltip-help" class="tooltip-container"> + Runs + </h3> </template> </div> <tf-multi-checkbox names="[[runs]]" tooltips="[[tooltips]]" highlights="[[_arrayify(closestRun)]]" out-selected="{{outSelected}}" class-scale="[[classScale]]" hide-missing-tooltips=""></tf-multi-checkbox> + <paper-button class="x-button" id="toggle-all" on-tap="_toggleAll"> + Toggle All Runs + </paper-button> <style> :host { display: flex; @@ -5944,7 +5982,6 @@ var TF; width: 100%; flex-grow: 0; flex-shrink: 0; - padding-left: 35px; padding-right: 16px; padding-bottom: 6px; box-sizing: border-box; @@ -5956,6 +5993,12 @@ var TF; flex-shrink: 1; height: 0px; /* hackhack So the flex-grow takes over and gives it space */ } + .x-button { + font-size: 13px; + background-color: var(--tb-ui-light-accent); + margin-top: 5px; + color: var(--tb-ui-dark-accent); + } .x-tooltip { display: flex; flex-direction: row; @@ -5967,6 +6010,16 @@ var TF; .x-tooltip-value { align-self: flex-end; } + #tooltip-help { + color: var(--paper-grey-800); + margin: 0; + font-weight: normal; + font-size: 14px; + margin-bottom: 5px; + } + paper-button { + margin-left: 0; + } </style> </template> <script> @@ -5982,17 +6035,24 @@ var TF; classScale: Object, // map from run name to color class (css) closestRun: {type: String, value: null}, // which run has a value closest to mouse coordinate }, + _toggleAll: function() { + if (this.outSelected.length > 0) { + this.outSelected = []; + } else { + this.outSelected = this.runs.slice(); + } + }, _arrayify: function(item) { return [item]; }, }); </script> </dom-module> -<dom-module id="tf-x-type-selector" assetpath="../components/tf-event-dashboard/"> +<dom-module id="tf-x-type-selector" assetpath="../tf-event-dashboard/"> <template> <div id="buttons"> - <p>X Type: </p> - <paper-button class="x-button selected" id="step" on-tap="_select" raised=""> + <h3>Horizontal Axis</h3> + <paper-button class="x-button selected" id="step" on-tap="_select"> step </paper-button> <paper-button class="x-button" id="relative" on-tap="_select"> @@ -6004,22 +6064,28 @@ var TF; </div> <style> .x-button { - width: 29%; - font-size: 14px; - background-color: var(--paper-grey-500); - margin-top: 5px; - color: white; + width: 30%; + font-size: 13px; + background: none; + margin-top: 10px; + color: var(--tb-ui-dark-accent); + } + + .x-button:first-of-type { + margin-left: 0; } .x-button.selected { - font-weight: bold; - background-color: var(--tb-orange-strong) !important; + background-color: var(--tb-ui-dark-accent); + color: white!important; } - #buttons p { - text-align: center; - font-size: 12px; + #buttons h3 { + color: var(--paper-grey-800); margin: 0; + font-weight: normal; + font-size: 14px; + margin-bottom: 5px; } </style> </template> @@ -6032,17 +6098,15 @@ var TF; _select: function(e) { var _this = this; ["step", "wall_time", "relative"].forEach(function(id) { - _this.$[id].raised = false; _this.$[id].classList.remove("selected"); }); - e.currentTarget.raised = true; this._setOutXType(e.currentTarget.id); e.currentTarget.classList.add("selected"); }, }); </script> </dom-module> -<dom-module id="tf-run-generator" assetpath="../components/tf-dashboard-common/"> +<dom-module id="tf-run-generator" assetpath="../tf-dashboard-common/"> <template> <iron-ajax id="ajax" auto="" url="[[url]]" handle-as="json" debounce="300" on-response="_setResponse" verbose="true"> </iron-ajax> @@ -6119,7 +6183,7 @@ var TF; }); </script> </dom-module> -<dom-module id="tf-color-scale" assetpath="../components/tf-event-dashboard/"> +<dom-module id="tf-color-scale" assetpath="../tf-event-dashboard/"> <script> (function() { // TODO(danmane) - get Plottable team to make an API point for this @@ -6176,9 +6240,23 @@ var TF; })(); </script> </dom-module> -<dom-module id="tf-url-generator" assetpath="../components/tf-dashboard-common/"> - <script>/// <reference path="../../typings/tsd.d.ts" /> -/// <reference path="../../bower_components/plottable/plottable.d.ts" /> +<dom-module id="tf-url-generator" assetpath="../tf-dashboard-common/"> + <script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// <reference path="../../typings/tsd.d.ts" /> +/// <reference path="../plottable/plottable.d.ts" /> var TF; (function (TF) { var Urls; @@ -6252,7 +6330,7 @@ var TF; Polymer(polymerObject); </script> </dom-module> -<dom-module id="tf-dashboard-layout" assetpath="../components/tf-dashboard-common/"> +<dom-module id="tf-dashboard-layout" assetpath="../tf-dashboard-common/"> <template> <div id="sidebar"> <content select=".sidebar"></content> @@ -6266,23 +6344,22 @@ var TF; #sidebar { width: inherit; height: 100%; - background-color: var(--tb-grey-darker); - background-image: linear-gradient(to right, var(--tb-grey-lighter), var(--tb-grey-lighter)); overflow: ellipsis; - padding-left: 10px; - padding-right: 10px; flex-grow: 0; flex-shrink: 0; } #center { - margin: 0 10px; height: 100%; overflow-y: scroll; - padding-right: 12px; flex-grow: 1; flex-shrink: 1; } + + .tf-graph-dashboard #center { + background: white; + } + :host { display: flex; flex-direction: row; @@ -6296,7 +6373,7 @@ var TF; }); </script> </dom-module> -<dom-module id="dashboard-style" assetpath="../components/tf-dashboard-common/"> +<dom-module id="dashboard-style" assetpath="../tf-dashboard-common/"> <template> <style> .card { @@ -6304,10 +6381,8 @@ var TF; width: 300px; display: flex; flex-direction: column; - margin: 5px 5px; - padding: 5px; - border: 1px solid var(--paper-grey-500); - border-radius: 3px; + margin: 5px; + padding: 0 30px 30px 0; -webkit-user-select: none; -moz-user-select: none; position: relative; @@ -6316,9 +6391,8 @@ var TF; .card .card-title { flex-grow: 0; flex-shrink: 0; - margin-bottom: 2px; + margin-bottom: 10px; font-size: 14px; - font-weight: bold; text-overflow: ellipsis; overflow: hidden; } @@ -6347,7 +6421,7 @@ var TF; .expand-button { position: absolute; left: 0px; - bottom: 0px; + bottom: 20px; color: #2196F3; display: block; } @@ -6360,6 +6434,7 @@ var TF; display: flex; flex-direction: column; height: 100%; + margin-right: 20px; } #categorizer { @@ -6376,21 +6451,30 @@ var TF; flex-grow: 1; } - #download-option { - padding-left: 55px; - color: var(--paper-grey-700); - font-size: 14px; + .sidebar-section { + border-top: solid 1px rgba(0, 0, 0, 0.12); + padding: 20px 0px 20px 30px; } - #download-option paper-toggle-button { - --paper-toggle-button-checked-button-color: var(--tb-orange-strong); - --paper-toggle-button-checked-bar-color: var(--tb-orange-weak); + .sidebar-section:first-child { + border: none; + } + .sidebar-section:last-child { + flex-grow: 1; + display: flex; } + + paper-checkbox { + --paper-checkbox-checked-color: var(--tb-ui-dark-accent); + --paper-checkbox-unchecked-color: var(--tb-ui-dark-accent); + font-size: 14px; + } + </style> </template> </dom-module> -<dom-module id="tf-downloader" assetpath="../components/tf-dashboard-common/"> +<dom-module id="tf-downloader" assetpath="../tf-dashboard-common/"> <template> <paper-dropdown-menu no-label-float="true" label="run to download" selected-item-label="{{_run}}"> <paper-menu class="dropdown-content"> @@ -6460,40 +6544,46 @@ var TF; }); </script> </dom-module> -<dom-module id="tf-regex-group" assetpath="../components/tf-regex-group/"> +<dom-module id="tf-regex-group" assetpath="../tf-regex-group/"> <template> <div class="regex-list"> <template is="dom-repeat" items="{{rawRegexes}}"> <div class="regex-line"> - <paper-input id="text-input" class="regex-input" label="input new regex" no-label-float="" bind-value="{{item.regex}}" invalid="[[!item.valid]]" on-keyup="moveFocus"></paper-input> - <paper-toggle-button class="active-button" checked="{{item.active}}" disabled="[[!item.valid]]"></paper-toggle-button> - - <paper-icon-button icon="delete" class="delete-button" aria-label="Delete Regex" tabindex="0" on-tap="deleteRegex"></paper-icon-button> + <paper-checkbox class="active-button" checked="{{item.active}}" disabled="[[!item.valid]]"></paper-checkbox> + <paper-input id="text-input" class="regex-input" label="Regex filter" no-label-float="" bind-value="{{item.regex}}" invalid="[[!item.valid]]" on-keyup="moveFocus"></paper-input> + <paper-icon-button icon="close" class="delete-button" aria-label="Delete Regex" tabindex="0" on-tap="deleteRegex"></paper-icon-button> </div> <style> .regex-input { - width: 210px; + width: 230px; display: inline-block; - padding-left: 8px; - padding-right: 5px; + margin-left: -3px; } - .active-button { - --paper-toggle-button-checked-button-color: var(--tb-orange-strong); - --paper-toggle-button-checked-bar-color: var(--tb-orange-weak); - border: none; + paper-checkbox { + --paper-checkbox-checked-color: var(--tb-ui-dark-accent); + --paper-checkbox-unchecked-color: var(--tb-ui-dark-accent); } .delete-button { - color: var(--paper-pink-900); - width: 24px; - height: 24px; + color: var(--paper-grey-700); + width: 40px; + height: 40px; + margin-right: -10px; } + .regex-list { margin-bottom: 10px; } + paper-input { --paper-input-container-focus-color: var(--tb-orange-strong); + --paper-input-container-input: { + font-size: 14px; + }; + --paper-input-container-label: { + font-size: 14px; + }; } </style> </template> @@ -6566,38 +6656,44 @@ var TF; }); </script> </dom-module> -<dom-module id="tf-categorizer" assetpath="../components/tf-categorizer/"> +<dom-module id="tf-categorizer" assetpath="../tf-categorizer/"> <template> <div class="inputs"> <tf-regex-group id="regex-group" regexes="{{regexes}}"></tf-regex-group> </div> <div id="underscore-categorization"> - <span>Split On Underscores:</span> - <paper-toggle-button checked="{{splitOnUnderscore}}"></paper-toggle-button> + <paper-checkbox checked$="{{splitOnUnderscore}}">Split on underscores</paper-checkbox> </div> <style> :host { display: block; - padding-bottom: 5px; - padding-top: 5px; + padding-bottom: 15px; } - - .inputs { - padding-left: 5px; - } - - paper-toggle-button { - --paper-toggle-button-checked-button-color: var(--tb-orange-strong); - --paper-toggle-button-checked-bar-color: var(--tb-orange-weak); + paper-checkbox { + --paper-checkbox-checked-color: var(--paper-grey-600); + --paper-checkbox-unchecked-color: var(--paper-grey-600); + font-size: 14px; } #underscore-categorization { - padding-left: 94px; color: var(--paper-grey-700); - font-size: 14px; } </style> </template> - <script>/// <reference path="../../typings/tsd.d.ts" /> + <script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// <reference path="../../typings/tsd.d.ts" /> var Categorizer; (function (Categorizer) { /* Canonical TensorFlow ops are namespaced using forward slashes. @@ -6738,7 +6834,7 @@ var Categorizer; }); </script> </dom-module> -<dom-module id="tf-chart" assetpath="../components/tf-event-dashboard/"> +<dom-module id="tf-chart" assetpath="../tf-event-dashboard/"> <template> <svg id="chartsvg"></svg> <style> @@ -6761,7 +6857,21 @@ var Categorizer; } </style> </template> - <script>var __extends = (this && this.__extends) || function (d, b) { + <script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +var __extends = (this && this.__extends) || function (d, b) { for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p]; function __() { this.constructor = d; } __.prototype = b.prototype; @@ -6901,14 +7011,28 @@ var Plottable; Plottable.DragZoomLayer = DragZoomLayer; })(Plottable || (Plottable = {})); </script> - <script>/// <reference path="../../typings/tsd.d.ts" /> -/// <reference path="../../bower_components/plottable/plottable.d.ts" /> + <script>/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ var __extends = (this && this.__extends) || function (d, b) { for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p]; function __() { this.constructor = d; } __.prototype = b.prototype; d.prototype = new __(); }; +/// <reference path="../../typings/tsd.d.ts" /> +/// <reference path="../plottable/plottable.d.ts" /> var TF; (function (TF) { var Y_TOOLTIP_FORMATTER_PRECISION = 4; @@ -7238,13 +7362,13 @@ var TF; }); </script> </dom-module> -<dom-module id="tf-collapsable-pane" assetpath="../components/tf-collapsable-pane/"> +<dom-module id="tf-collapsable-pane" assetpath="../tf-collapsable-pane/"> <template> <button class="heading" on-tap="togglePane" open-button$="[[opened]]"> <span class="name">[[name]]</span> <span class="hackpadding"></span> <span class="count"> - (<span>[[count]]</span>) + <span>[[count]]</span> </span> </button> <iron-collapse opened="[[opened]]"> @@ -7255,47 +7379,63 @@ var TF; </div> </iron-collapse> <style> + :host { + display: block; + margin: 0 5px 1px 10px; + } + + :host:first-of-type { + margin-top: 20px; + } + + :host:last-of-type { + margin-bottom: 20px; + } + .heading { - margin-top: 10px; - padding-left: 15px; - background-color: #f3f3f3; - border: 1px solid #dedede; - border-radius: 5px; - font-size: 18px; + background-color: white; + border-radius: 2px; + border: none; cursor: pointer; -webkit-tap-highlight-color: rgba(0,0,0,0); width: 100%; - height: 30px; box-sizing: border-box; - font-size: 16px; + font-size: 15px; display: inline-flex; flex-direction: row; align-items: center; justify-content: space-between; line-height: 1; - padding-top: 2px; - padding-bottom: 2px; + box-shadow: 0 1px 5px rgba(0,0,0,0.2); + padding: 10px 15px; } .content { padding: 15px; border: 1px solid #dedede; border-top: none; - border-bottom-left-radius: 5px; - border-bottom-right-radius: 5px; + border-bottom-left-radius: 2px; + border-bottom-right-radius: 2px; + background: white; } + [open-button] { border-bottom-left-radius: 0px !important; border-bottom-right-radius: 0px !important; } + .name { flex-grow: 0; } + .count { flex-grow: 0; float: right; + margin-right: 5px; font-size: 12px; + color: var(--paper-grey-500); } + .hackpadding { /* An obnoxious hack, but I can't get justify-content: space-between to work */ flex-grow: 1; @@ -7321,7 +7461,7 @@ var TF; </script> </dom-module> -<dom-module id="warning-style" assetpath="../components/tf-dashboard-common/"> +<dom-module id="warning-style" assetpath="../tf-dashboard-common/"> <template> <style> .warning { @@ -7331,7 +7471,7 @@ var TF; </style> </template> </dom-module> -<dom-module id="tf-event-dashboard" assetpath="../components/tf-event-dashboard/"> +<dom-module id="tf-event-dashboard" assetpath="../tf-event-dashboard/"> <template> <div id="plumbing"> <tf-url-generator out-runs-url="{{runsUrl}}" out-scalars-url-generator="{{scalarsUrlGen}}" id="urlGenerator"></tf-url-generator> @@ -7348,16 +7488,16 @@ var TF; <tf-dashboard-layout> <div class="sidebar"> - <tf-categorizer id="categorizer" tags="[[_visibleTags]]" categories="{{categories}}"></tf-categorizer> - <span id="download-option"> - Show Data Download Links: - <paper-toggle-button checked="{{_show_download_links}}"></paper-toggle-button> - </span> - - <tf-x-type-selector id="xTypeSelector" out-x-type="{{xType}}"></tf-x-type-selector> - - <tf-run-selector id="runSelector" runs="[[_runs]]" class-scale="[[classScale]]" out-selected="{{selectedRuns}}" tooltips="[[tooltipMap]]" closest-run="[[closestRun]]" x-value="[[tooltipXValue]]" x-type="[[xType]]"></tf-run-selector> - + <div class="sidebar-section"> + <tf-categorizer id="categorizer" tags="[[_visibleTags]]" categories="{{categories}}"></tf-categorizer> + <paper-checkbox id="download-option" checked$="{{_show_download_links}}">Data download links</paper-checkbox> + </div> + <div class="sidebar-section"> + <tf-x-type-selector id="xTypeSelector" out-x-type="{{xType}}"></tf-x-type-selector> + </div> + <div class="sidebar-section"> + <tf-run-selector id="runSelector" runs="[[_runs]]" class-scale="[[classScale]]" out-selected="{{selectedRuns}}" tooltips="[[tooltipMap]]" closest-run="[[closestRun]]" x-value="[[tooltipXValue]]" x-type="[[xType]]"></tf-run-selector> + </div> </div> <div class="center"> <template is="dom-if" if="[[!categories.length]]"> @@ -7442,7 +7582,7 @@ var TF; }); </script> </dom-module> -<dom-module id="tf-histogram-dashboard" assetpath="../components/tf-histogram-dashboard/"> +<dom-module id="tf-histogram-dashboard" assetpath="../tf-histogram-dashboard/"> <template> <div id="plumbing"> <tf-url-generator out-runs-url="{{runsUrl}}" out-compressed-histograms-url-generator="{{compressedHistogramsUrlGen}}" id="urlGenerator"></tf-url-generator> @@ -7458,13 +7598,15 @@ var TF; <tf-dashboard-layout> <div class="sidebar"> - - <tf-categorizer id="categorizer" tags="[[_visibleTags]]" categories="{{categories}}"></tf-categorizer> - - <tf-x-type-selector id="xTypeSelector" out-x-type="{{xType}}"></tf-x-type-selector> - - <tf-run-selector id="runSelector" runs="[[_runs]]" class-scale="[[classScale]]" out-selected="{{selectedRuns}}" tooltips="[[tooltipMap]]" closest-run="[[closestRun]]" x-value="[[tooltipXValue]]" x-type="[[xType]]"></tf-run-selector> - + <div class="sidebar-section"> + <tf-categorizer id="categorizer" tags="[[_visibleTags]]" categories="{{categories}}"></tf-categorizer> + </div> + <div class="sidebar-section"> + <tf-x-type-selector id="xTypeSelector" out-x-type="{{xType}}"></tf-x-type-selector> + </div> + <div class="sidebar-section"> + <tf-run-selector id="runSelector" runs="[[_runs]]" class-scale="[[classScale]]" out-selected="{{selectedRuns}}" tooltips="[[tooltipMap]]" closest-run="[[closestRun]]" x-value="[[tooltipXValue]]" x-type="[[xType]]"></tf-run-selector> + </div> </div> <div class="center"> @@ -7563,7 +7705,7 @@ var TF; }); </script> </dom-module> -<dom-module id="tf-image-loader" assetpath="../components/tf-image-dashboard/"> +<dom-module id="tf-image-loader" assetpath="../tf-image-dashboard/"> <style> :host { display: block; @@ -7610,7 +7752,7 @@ var TF; }); </script> </dom-module> -<dom-module id="tf-image-grid" assetpath="../components/tf-image-dashboard/"> +<dom-module id="tf-image-grid" assetpath="../tf-image-dashboard/"> <template> <style include="scrollbar-style"></style> <div id="fullContainer" class="container scrollbar"> @@ -7725,7 +7867,7 @@ var TF; }); </script> </dom-module> -<dom-module id="tf-image-dashboard" assetpath="../components/tf-image-dashboard/"> +<dom-module id="tf-image-dashboard" assetpath="../tf-image-dashboard/"> <template> <div id="plumbing"> <tf-url-generator out-runs-url="{{runsUrl}}" out-images-url-generator="{{imagesUrlGen}}" out-individual-image-url-generator="{{individualImageUrlGen}}" id="urlGenerator"></tf-url-generator> @@ -7783,7 +7925,7 @@ var TF; }); </script> </dom-module> -<dom-module id="tf-graph-loader" assetpath="../components/tf-graph-loader/"> +<dom-module id="tf-graph-loader" assetpath="../tf-graph-loader/"> </dom-module> <script> @@ -7906,7 +8048,9 @@ Polymer({ } var hierarchyParams = { verifyTemplate: true, - groupSeries: true, + // If a set of numbered op nodes has at least this number of nodes + // then group them into a series node. + seriesNodeMinSize: 5, }; var hierarchyTracker = tf.getSubtaskTracker(tracker, 50, 'Namespace hierarchy'); @@ -7950,7 +8094,7 @@ Polymer({ } }); </script> -<dom-module id="tf-graph-style" assetpath="../components/tf-graph/"> +<dom-module id="tf-graph-style" assetpath="../tf-graph/"> <template> <style> :host { @@ -8289,7 +8433,7 @@ Polymer({ </style> </template> </dom-module> -<dom-module id="tf-graph-minimap" assetpath="../components/tf-graph/"> +<dom-module id="tf-graph-minimap" assetpath="../tf-graph/"> <template> <style> :host { @@ -8334,6 +8478,7 @@ svg { <canvas class="first"></canvas> <canvas class="second"></canvas> +<canvas class="download"></canvas> </template> <script> Polymer({ @@ -8356,7 +8501,7 @@ Polymer({ }); </script> </dom-module> -<dom-module id="tf-graph-scene" assetpath="../components/tf-graph/"> +<dom-module id="tf-graph-scene" assetpath="../tf-graph/"> <template> <style include="tf-graph-style"> :host { @@ -8419,6 +8564,9 @@ Polymer({ <use xlink:href="#op-node-annotation-stamp" x="7" y="2"></use> <use xlink:href="#op-node-annotation-stamp" x="5" y="2"></use> </g> + <svg id="summary-icon" fill="#848484" height="12" viewBox="0 0 24 24" width="12"> + <path d="M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z"></path> + </svg> <g id="linearGradients"></g> </defs> @@ -8433,7 +8581,7 @@ Polymer({ Polymer({ is: 'tf-graph-scene', properties: { - graphHierarchy: Object, + renderHierarchy: Object, name: String, colorBy: { type: String, @@ -8473,9 +8621,10 @@ Polymer({ /** * @type {d3.scale.ordinal} * Scale mapping from template name to a number between 0 and N-1 - * where N is the number of different template names. + * where N is the number of different template names. Used by + * tf.graph.scene.node when computing node color by structure. */ - templateIndex: Object, + templateIndex: Function, /** * @type {tf.scene.Minimap} * A minimap object to notify for zoom events. @@ -8537,16 +8686,17 @@ Polymer({ progress: Object }, observers: [ - '_buildAndFit(graphHierarchy)' + '_buildAndFit(renderHierarchy)' ], getNode: function(nodeName) { - return this.graphHierarchy.getRenderNodeByName(nodeName); + return this.renderHierarchy.getRenderNodeByName(nodeName); }, isNodeExpanded: function(node) { return node.expanded; }, setNodeExpanded: function(renderNode) { - this._build(this.graphHierarchy); + this._build(this.renderHierarchy); + this._updateLabels(!this._zoomed); }, /** * Resets the state of the component. Called whenever the whole graph @@ -8565,20 +8715,15 @@ Polymer({ .selectAll('*').remove(); }, /** Main method for building the scene */ - _build: function(graphHierarchy) { - if (!graphHierarchy) { return; } //handle untruthy input - var templateNames = d3.keys(graphHierarchy.hierarchy.templates); - - this.templateIndex = d3.scale.ordinal() - .domain(templateNames) - .range(d3.range(0, templateNames.length)); + _build: function(renderHierarchy) { + this.templateIndex = renderHierarchy.hierarchy.getTemplateIndex(); tf.time('tf-graph-scene (layout):', function() { // layout the scene for this meta / series node - tf.graph.layout.scene(graphHierarchy.root, this); + tf.graph.layout.layoutScene(renderHierarchy.root, this); }.bind(this)); tf.time('tf-graph-scene (build scene):', function() { - tf.graph.scene.buildGroup(d3.select(this.$.root), graphHierarchy.root, this); + tf.graph.scene.buildGroup(d3.select(this.$.root), renderHierarchy.root, this); tf.graph.scene.addGraphClickListener(this.$.svg, this); }.bind(this)); // Update the minimap again when the graph is done animating. @@ -8641,21 +8786,24 @@ Polymer({ tf.graph.layout.PARAMS.minimap.size, tf.graph.layout.PARAMS.subscene.meta.labelHeight); }, - _buildAndFit: function(graphHierarchy) { + _buildAndFit: function(renderHierarchy) { this._resetState(); - this._build(graphHierarchy); + this._build(renderHierarchy); // Fit to screen after the graph is done animating. setTimeout(this.fit.bind(this), tf.graph.layout.PARAMS.animation.duration); }, _updateLabels: function(showLabels) { var titleStyle = this.getElementsByClassName('title')[0].style; var auxTitleStyle = this.getElementsByClassName('auxTitle')[0].style; - var core = this.getElementsByClassName(tf.graph.scene.Class.Scene.CORE)[0]; + var core = d3.select("." + tf.graph.scene.Class.Scene.GROUP + ">." + + tf.graph.scene.Class.Scene.CORE)[0][0]; // Only show labels if the graph is fully loaded. if (showLabels && core && this.progress && this.progress.value === 100) { var aux = - this.getElementsByClassName(tf.graph.scene.Class.Scene.INEXTRACT)[0] || - this.getElementsByClassName(tf.graph.scene.Class.Scene.OUTEXTRACT)[0]; + d3.select("." + tf.graph.scene.Class.Scene.GROUP + ">." + + tf.graph.scene.Class.Scene.INEXTRACT)[0][0] || + d3.select("." + tf.graph.scene.Class.Scene.GROUP + ">." + + tf.graph.scene.Class.Scene.OUTEXTRACT)[0][0]; var coreX = core.getCTM().e; var auxX = aux ? aux.getCTM().e : null; titleStyle.display = 'inline'; @@ -8760,7 +8908,7 @@ Polymer({ } // Update the minimap to reflect the highlighted (selected) node. this.minimap.update(); - var node = this.graphHierarchy.hierarchy.node(selectedNode); + var node = this.renderHierarchy.hierarchy.node(selectedNode); var nodeParents = []; // Create list of all metanode parents of the selected node. while (node.parentNode != null @@ -8771,8 +8919,8 @@ Polymer({ // Ensure each parent metanode is built and expanded. var topParentNodeToBeExpanded; _.forEachRight(nodeParents, function(parentName) { - this.graphHierarchy.buildSubhierarchy(parentName); - var renderNode = this.graphHierarchy.getRenderNodeByName(parentName); + this.renderHierarchy.buildSubhierarchy(parentName); + var renderNode = this.renderHierarchy.getRenderNodeByName(parentName); if (renderNode.node.isGroupNode && !renderNode.expanded) { renderNode.expanded = true; if (!topParentNodeToBeExpanded) { @@ -8812,7 +8960,7 @@ Polymer({ }, }); </script> -<dom-module id="tf-graph-params" assetpath="../components/tf-graph/"> +<dom-module id="tf-graph-params" assetpath="../tf-graph/"> </dom-module> <script> Polymer({ @@ -8878,7 +9026,7 @@ Polymer({ */ detachAllEdgesForHighDegree: { type: Boolean, - value: false + value: true }, /** @@ -8921,12 +9069,14 @@ Polymer({ } }); </script> -<dom-module id="tf-graph" assetpath="../components/tf-graph/"> +<dom-module id="tf-graph" assetpath="../tf-graph/"> <template> <style> .container { width: 100%; height: 100%; + background: white; + box-shadow: 0 1px 5px rgba(0,0,0,0.2); } .vertical { @@ -8952,7 +9102,7 @@ paper-button { <tf-graph-params id="graphParams"></tf-graph-params> <div class="vertical"> <h2>[[title]]</h2> - <tf-graph-scene id="scene" class="auto" graph-hierarchy="[[_renderHierarchy]]" highlighted-node="[[_getVisible(highlightedNode)]]" selected-node="[[selectedNode]]" color-by="[[colorBy]]" name="[[graphName]]" progress="[[progress]]"></tf-graph-scene> + <tf-graph-scene id="scene" class="auto" render-hierarchy="[[renderHierarchy]]" highlighted-node="[[_getVisible(highlightedNode)]]" selected-node="[[selectedNode]]" color-by="[[colorBy]]" name="[[graphName]]" progress="[[progress]]"></tf-graph-scene> </div> </div> </template> @@ -8985,6 +9135,11 @@ Polymer({ notify: true, readOnly: true, // Produces and doesn't consume. }, + renderHierarchy: { + type: Object, + readOnly: true, + notify: true, + }, // internal properties _graphParams: { type: Object, @@ -8996,27 +9151,24 @@ Polymer({ type: Number, value: 1 }, - _renderHierarchy: { - type: Object, - readOnly: true, - notify: true, - computed: '_buildRenderHierarchy(graphHierarchy, _graphParams)' - }, _allowGraphSelect: { type: Boolean, value: true } }, + observers: [ + '_buildRenderHierarchy(graphHierarchy, _graphParams)' + ], _buildRenderHierarchy: function(graphHierarchy, params) { - return tf.time('new tf.graph.render.Hierarchy', function() { + tf.time('new tf.graph.render.Hierarchy', function() { if (graphHierarchy.root.type !== tf.graph.NodeType.META) { // root must be metanode but sometimes Polymer's dom-if has not // remove tf-graph element yet in <tf-node-info> // and thus mistakenly pass non-metanode to this module. return; } - var renderGraph = new tf.graph.render.RenderGraphInformation( - graphHierarchy, params); + var renderGraph = new tf.graph.render.RenderGraphInfo(graphHierarchy, + params); // Producing the 'color by' parameters to be consumed // by the tf-graph-controls panel. It contains information about the // min and max values and their respective colors, as well as list @@ -9042,14 +9194,14 @@ Polymer({ }; }) }); - return renderGraph; + this._setRenderHierarchy(renderGraph); }.bind(this)); }, _getVisible: function(name) { if (!name) { return name; } - return this._renderHierarchy.getNearestVisibleAncestor(name); + return this.renderHierarchy.getNearestVisibleAncestor(name); }, listeners: { 'graph-select': '_graphSelected', @@ -9060,6 +9212,7 @@ Polymer({ 'node-select': '_nodeSelected', 'node-highlight': '_nodeHighlighted', 'node-unhighlight': '_nodeUnhighlighted', + 'node-toggle-extract': '_nodeToggleExtract', // Annotations @@ -9110,53 +9263,72 @@ Polymer({ }, _nodeToggleExpand: function(event) { var nodeName = event.detail.name; - var renderNode = this._renderHierarchy.getRenderNodeByName(nodeName); + var renderNode = this.renderHierarchy.getRenderNodeByName(nodeName); // Op nodes are not expandable. if (renderNode.node.type === tf.graph.NodeType.OP) { return; } - this._renderHierarchy.buildSubhierarchy(nodeName); + this.renderHierarchy.buildSubhierarchy(nodeName); renderNode.expanded = !renderNode.expanded; this.querySelector('#scene').setNodeExpanded(renderNode); // Also select the expanded node. this._nodeSelected(event); }, + _nodeToggleExtract: function(event) { + // Toggle the include setting of the specified node appropriately. + var nodeName = event.detail.name; + var renderNode = this.renderHierarchy.getRenderNodeByName(nodeName); + if (renderNode.node.include == tf.graph.InclusionType.INCLUDE) { + renderNode.node.include = tf.graph.InclusionType.EXCLUDE; + } else if (renderNode.node.include == tf.graph.InclusionType.EXCLUDE) { + renderNode.node.include = tf.graph.InclusionType.INCLUDE; + } else { + renderNode.node.include = + this.renderHierarchy.isNodeAuxilliary(renderNode) + ? tf.graph.InclusionType.INCLUDE : tf.graph.InclusionType.EXCLUDE; + } + + // Rebuild the render hierarchy. + this._buildRenderHierarchy(this.graphHierarchy, this._graphParams); + }, not: function(x) { return !x; } }); </script> -<dom-module id="tf-graph-icon" assetpath="../components/tf-graph/"> +<dom-module id="tf-graph-icon" assetpath="../tf-graph/"> <template> <template is="dom-if" if="[[_isType(node, type, 'OP')]]"> <template is="dom-if" if="[[_isConst(node, const)]]"> <svg height$="[[height]]" preserveAspectRatio="xMinYMid meet" viewBox="0 0 10 10"> - <circle fill="white" stroke="#848484" cx="5" cy="5" r="3"></circle> + <circle cx="5" cy="5" r="3" fill$="[[_getFill(_computedFill, 'OP')]]" stroke$="[[_getStroke(_computedFill, 'OP')]]"></circle> </svg> </template> <template is="dom-if" if="[[_isSummary(node, summary)]]"> - <img height$="[[height]]" src="[[resolveUrl('../../lib/svg/summary-icon.svg')]]"> + <svg width$="[[height]]" height$="[[height]]" viewBox="0 0 12 12"> + <use x="0" y="0" xlink:href="#summary-icon"></use> + </svg> </template> <template is="dom-if" if="[[_isRegularOp(node, const, summary)]]"> <svg height$="[[height]]" preserveAspectRatio="xMinYMid meet" viewBox="0 0 16 8"> - <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-node-stamp" fill="white" stroke="#ccc" x="8" y="4"></use> + <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-node-stamp" fill$="[[_getFill(_computedFill, 'OP')]]" stroke$="[[_getStroke(_computedFill, 'OP')]]" x="8" y="4"></use> </svg> </template> </template> <template is="dom-if" if="[[_isType(node, type, 'META')]]"> <svg height$="[[height]]" preserveAspectRatio="xMinYMid meet" viewBox="0 0 37 16"> - <rect x="1" y="1" fill="#d9d9d9" stroke="#ccc" stroke-width="2px" height="14" width="35" rx="5" ry="5"></rect> + <rect x="1" y="1" fill$="[[_getFill(_computedFill, 'META')]]" stroke$="[[_getStroke(_computedFill, 'META')]]" stroke-width="2px" height="14" width="35" rx="5" ry="5"></rect> </svg> </template> <template is="dom-if" if="[[_isType(node, type, 'SERIES')]]"> <template is="dom-if" if="[[_isVertical(node, vertical)]]"> <svg height$="[[height]]" preserveAspectRatio="xMinYMid meet" viewBox="0 0 16 15"> - <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-series-vertical-stamp" fill="white" stroke="#ccc" x="0" y="2"></use> + <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-series-vertical-stamp" fill$="[[_getFill(_computedFill, 'SERIES')]]" stroke$="[[_getStroke(_computedFill, 'SERIES')]]" x="0" y="2"></use> </svg> </template> <template is="dom-if" if="[[!_isVertical(node, vertical)]]"> <svg height$="[[height]]" preserveAspectRatio="xMinYMid meet" viewBox="0 0 24 10"> - <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-series-horizontal-stamp" fill="white" stroke="#ccc" x="0" y="1"></use> + <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-series-horizontal-stamp" fill$="[[_getFill(_computedFill, 'SERIES')]]" stroke$="[[_getStroke(_computedFill, 'SERIES')]]" x="0" y="1"></use> </svg> </template> </template> @@ -9179,7 +9351,36 @@ Polymer({ value: null }, - /** Type of node to draw. */ + /** + * Render node information associated with this node. Optional. If + * specified, this is only used when computing the fill of the icon + * element. + * @type {tf.graph.render.RenderNodeInfo} + */ + renderInfo: { + type: Object, + value: null + }, + + /** + * String indicating the type of coloring to use for this node, used + * only for deterimining the fill. + */ + colorBy: { + type: Object, + value: "structural" + }, + + /** + * Function used by structural coloring algorithim to determine which + * color to use based on the template ID of the node. Optional. + */ + templateIndex: { + type: Function, + value: null + }, + + /** Type of node to draw (ignored if node is set). */ type: { type: String, value: null @@ -9203,11 +9404,70 @@ Polymer({ value: false }, + /** + * Fill for the icon, optional. If fill is specified and node is not + * specified, then this value will override the default for the + * element. However, if node is specified, this value will be ignored. + */ + fill: { + type: String, + value: null + }, + /** Height of the SVG element in pixels, used for scaling. */ height: { type: Number, value: 20 + }, + + /** The computed fill for the node. **/ + _computedFill: { + type: String, + computed: + "_getComputedFill(node, renderInfo, colorBy, templateIndex, fill)" + } + + }, + + /** + * Get the computed fill value for the element. + */ + _getComputedFill: function(inputNode, inputRenderInfo, inputColorBy, + inputTemplateIndex, inputFill) { + if (inputNode && inputRenderInfo && + inputColorBy && inputTemplateIndex) { + var ns = tf.graph.scene.node; + var colorBy = ns.ColorBy[inputColorBy.toUpperCase()]; + return ns.getFillForNode(inputTemplateIndex, colorBy, + inputRenderInfo, false); } + return inputFill; + }, + + /** + * Get the fill value for the element, or if that's not possible, return + * the default fill value for the node type. + */ + _getFill: function(inputComputedFill, inputNodeType) { + return inputComputedFill || ({ + OP: tf.graph.render.OpNodeColors.DEFAULT_FILL, + META: tf.graph.render.MetanodeColors.DEFAULT_FILL, + SERIES: tf.graph.render.SeriesNodeColors.DEFAULT_FILL + })[inputNodeType]; + }, + + /** + * Get the stroke value for the element, or if that's not possible, + * return the default stroke value for the node type. + */ + _getStroke: function(inputComputedFill, inputNodeType) { + return inputComputedFill ? + tf.graph.scene.node.getStrokeForFill(inputComputedFill) : + ({ + OP: tf.graph.render.OpNodeColors.DEFAULT_STROKE, + META: tf.graph.render.MetanodeColors.DEFAULT_STROKE, + SERIES: tf.graph.render.SeriesNodeColors.DEFAULT_STROKE + })[inputNodeType]; }, /** @@ -9267,7 +9527,7 @@ Polymer({ })(); </script> </dom-module> -<dom-module id="tf-node-list-item" assetpath="../components/tf-graph-info/"> +<dom-module id="tf-node-list-item" assetpath="../tf-graph-info/"> <style> #list-item { width: 100%; @@ -9302,7 +9562,7 @@ Polymer({ </style> <template> <div id="list-item" on-mouseover="_nodeListener" on-mouseout="_nodeListener" on-click="_nodeListener"> - <tf-graph-icon class="node-icon" node="[[itemNode]]" height="12"></tf-graph-icon> + <tf-graph-icon class="node-icon" height="12" color-by="[[colorBy]]" color-by-params="[[colorByParams]]" node="[[itemNode]]" render-info="[[itemRenderInfo]]" template-index="[[templateIndex]]"></tf-graph-icon> <span title$="[[name]]">[[name]]</span> </div> </template> @@ -9323,11 +9583,19 @@ Polymer({ * @type {tf.graph.Node} */ itemNode: Object, + /** + * The render node information for the item node. Used by the graph + * icon in determining fill color. + */ + itemRenderInfo: Object, name: String, itemType: { type: String, observer: '_itemTypeChanged' - } + }, + colorBy: String, + colorByParams: Object, + templateIndex: Function, }, _itemTypeChanged: function() { @@ -9351,7 +9619,7 @@ Polymer({ })(); </script> </dom-module> -<dom-module id="tf-node-info" assetpath="../components/tf-graph-info/"> +<dom-module id="tf-node-info" assetpath="../tf-graph-info/"> <style> .sub-list-group { padding: 8px 12px 0px; @@ -9432,6 +9700,27 @@ Polymer({ max-width: 20px; padding: 0; } + + .toggle-include-group { + padding-top: 4px; + } + + .toggle-include { + margin: 5px 6px; + text-transform: none; + padding: 4px 6px; + font-size: 10pt; + background-color: #fafafa; + color: #666; + } + + .toggle-include:hover { + background-color: var(--google-yellow-100); + } + + .non-control-list-item { + padding-left: 10px; + } </style> <template> <paper-item> @@ -9442,7 +9731,7 @@ Polymer({ <div class="node-name">[[_getNodeName(nodeName)]]</div> </div> <div secondary=""> - <tf-graph-icon class="node-icon" node="[[_node]]"></tf-graph-icon> + <tf-graph-icon class="node-icon" node="[[_node]]" render-info="[[_getRenderInfo(nodeName, renderHierarchy)]]" color-by="[[colorBy]]" template-index="[[_templateIndex]]"></tf-graph-icon> <template is="dom-if" if="{{_node.op}}"> <div class="subtitle"> Operation: @@ -9486,7 +9775,7 @@ Polymer({ (<span>[[_totalPredecessors]]</span>) <iron-list class="sub-list" id="inputsList" items="[[_predecessors.regular]]"> <template> - <tf-node-list-item card-node="[[_node]]" item-node="[[_getNode(item, graphHierarchy)]]" name="[[item]]" item-type="predecessors"> + <tf-node-list-item class="non-control-list-item" card-node="[[_node]]" item-node="[[_getNode(item, graphHierarchy)]]" item-render-info="[[_getRenderInfo(item, renderHierarchy)]]" name="[[item]]" item-type="predecessors" color-by="[[colorBy]]" template-index="[[_templateIndex]]"> </tf-node-list-item> </template> </iron-list> @@ -9501,7 +9790,7 @@ Polymer({ <template is="dom-if" if="{{_openedControlPred}}" restamp="true"> <iron-list class="sub-list" items="[[_predecessors.control]]"> <template> - <tf-node-list-item card-node="[[_node]]" item-node="[[_getNode(item, graphHierarchy)]]" name="[[item]]" item-type="predecessors"> + <tf-node-list-item card-node="[[_node]]" item-node="[[_getNode(item, graphHierarchy)]]" item-render-info="[[_getRenderInfo(item, renderHierarchy)]]" name="[[item]]" item-type="predecessors" color-by="[[colorBy]]" template-index="[[_templateIndex]]"> </tf-node-list-item> </template> </iron-list> @@ -9516,7 +9805,7 @@ Polymer({ (<span>[[_totalSuccessors]]</span>) <iron-list class="sub-list" id="outputsList" items="[[_successors.regular]]"> <template> - <tf-node-list-item card-node="[[_node]]" item-node="[[_getNode(item, graphHierarchy)]]" name="[[item]]" item-type="successor"> + <tf-node-list-item class="non-control-list-item" card-node="[[_node]]" item-node="[[_getNode(item, graphHierarchy)]]" item-render-info="[[_getRenderInfo(item, renderHierarchy)]]" name="[[item]]" item-type="successor" color-by="[[colorBy]]" template-index="[[_templateIndex]]"> </tf-node-list-item> </template> </iron-list> @@ -9531,7 +9820,7 @@ Polymer({ <template is="dom-if" if="{{_openedControlSucc}}" restamp="true"> <iron-list class="sub-list" items="[[_successors.control]]"> <template> - <tf-node-list-item card-node="[[_node]]" item-node="[[_getNode(item, graphHierarchy)]]" name="[[item]]" item-type="successors"> + <tf-node-list-item card-node="[[_node]]" item-node="[[_getNode(item, graphHierarchy)]]" item-render-info="[[_getRenderInfo(item, renderHierarchy)]]" name="[[item]]" item-type="successors" color-by="[[colorBy]]" template-index="[[_templateIndex]]"> </tf-node-list-item> </template> </iron-list> @@ -9540,6 +9829,11 @@ Polymer({ </div> </template> </div> + <div class="toggle-include-group"> + <paper-button raised="" class="toggle-include" on-click="_toggleInclude"> + <span>[[_auxButtonText]]</span> + </paper-button> + </div> </div> </template> </iron-collapse> @@ -9553,11 +9847,23 @@ Polymer({ properties: { nodeName: String, graphHierarchy: Object, + renderHierarchy: Object, + /** What to color the nodes by (compute time, memory, device etc.) */ + colorBy: String, + _templateIndex: { + type: Function, + computed: '_getTemplateIndex(graphHierarchy)' + }, _node: { type: Object, computed: '_getNode(nodeName, graphHierarchy)', observer: '_resetState' }, + // The enum value of the include property of the selected node. + nodeInclude: { + type: Number, + observer: '_nodeIncludeStateChanged' + }, _attributes: { type: Array, computed: '_getAttributes(_node)' @@ -9598,18 +9904,25 @@ Polymer({ type: Boolean, value: false }, + _auxButtonText: String }, expandNode: function() { this.fire('_node.expand', this.node); }, - _getNode: function(n, graphHierarchy) { - return graphHierarchy.node(n); + _getTemplateIndex: function(graphHierarchy) { + return graphHierarchy.getTemplateIndex(); + }, + _getNode: function(nodeName, graphHierarchy) { + return graphHierarchy.node(nodeName); }, _getNodeName: function(nodeName) { // Insert a zero-width whitespace character before each slash so that // long node names wrap cleanly at path boundaries. return (nodeName || '').replace(/\//g, '\u200B/'); }, + _getRenderInfo: function(nodeName, renderHierarchy) { + return this.renderHierarchy.getOrCreateRenderNodeByName(nodeName); + }, _getAttributes: function(node) { this.async(this._resizeList.bind(this, "#attributesList")); return node && node.attr ? node.attr.map(function(entry) { @@ -9658,12 +9971,22 @@ Polymer({ if (list) { list.fire('iron-resize'); } + }, + _toggleInclude: function() { + var graphElem = document.querySelector("#graph"); + graphElem.fire("node-toggle-extract", { name: this.nodeName }); + var graphBoardElem = document.querySelector("#graphboard"); + graphBoardElem.fire("node-toggle-extract"); + }, + _nodeIncludeStateChanged: function(include, oldInclude) { + this.set("_auxButtonText", + tf.graph.getIncludeNodeButtonString(include)); } }); })(); </script> </dom-module> -<dom-module id="tf-graph-info" assetpath="../components/tf-graph-info/"> +<dom-module id="tf-graph-info" assetpath="../tf-graph-info/"> <template> <style> :host { @@ -9681,7 +10004,7 @@ h2 { </style> <template is="dom-if" if="{{selectedNode}}"> <paper-material elevation="1" class="card"> - <tf-node-info graph-hierarchy="[[graphHierarchy]]" flat-graph="[[graph]]" node-name="[[selectedNode]]" highlighted-node="{{highlightedNode}}"> + <tf-node-info graph-hierarchy="[[graphHierarchy]]" render-hierarchy="[[renderHierarchy]]" flat-graph="[[graph]]" node-name="[[selectedNode]]" node-include="[[selectedNodeInclude]]" highlighted-node="{{highlightedNode}}" color-by="[[colorBy]]"> </tf-node-info> </paper-material> </template> @@ -9695,6 +10018,8 @@ h2 { title: String, graphHierarchy: Object, graph: Object, + renderHierarchy: Object, + colorBy: String, // Two-ways selectedNode: { type: String, @@ -9703,6 +10028,11 @@ h2 { highlightedNode: { type: String, notify: true + }, + // The enum value of the include property of the selected node. + selectedNodeInclude: { + type: Number, + notify: true } }, listeners: { @@ -9723,7 +10053,7 @@ h2 { })(); </script> </dom-module> -<dom-module id="tf-graph-board" assetpath="../components/tf-graph-board/"> +<dom-module id="tf-graph-board" assetpath="../tf-graph-board/"> <template> <style> ::host { @@ -9790,6 +10120,32 @@ paper-progress { --paper-progress-height: 6px; --paper-progress-active-color: #f3913e; } + +.context-menu { + position: absolute; + display: none; + background-color: #e2e2e2; + border-radius: 2px; + font-size: 14px; + min-width: 150px; + border: 1px solid #d4d4d4; +} + +/deep/ .context-menu ul { + list-style-type: none; + margin: 0; + padding: 0; + cursor: default; +} + +/deep/ .context-menu ul li { + padding: 4px 16px; +} + +/deep/ .context-menu ul li:hover { + background-color: #f3913e; + color: white; +} </style> <template is="dom-if" if="[[_isNotComplete(progress)]]"> <div id="progress-bar"> @@ -9799,11 +10155,12 @@ paper-progress { </template> <div class$="[[_getContainerClass(progress)]]"> <div id="main"> - <tf-graph id="graph" graph-hierarchy="[[graphHierarchy]]" selected-node="{{_selectedNode}}" highlighted-node="{{_highlightedNode}}" color-by="[[colorBy]]" color-by-params="{{colorByParams}}" graph-name="[[graphName]]" progress="[[progress]]"></tf-graph> + <tf-graph id="graph" graph-hierarchy="[[graphHierarchy]]" render-hierarchy="{{_renderHierarchy}}" selected-node="{{_selectedNode}}" highlighted-node="{{_highlightedNode}}" color-by="[[colorBy]]" color-by-params="{{colorByParams}}" graph-name="[[graphName]]" progress="[[progress]]"></tf-graph> </div> <div id="info"> - <tf-graph-info id="graph-info" title="selected" graph-hierarchy="[[graphHierarchy]]" graph="[[graph]]" selected-node="{{_selectedNode}}" highlighted-node="{{_highlightedNode}}"></tf-graph-info> + <tf-graph-info id="graph-info" title="selected" graph-hierarchy="[[graphHierarchy]]" render-hierarchy="[[_renderHierarchy]]" graph="[[graph]]" selected-node="{{_selectedNode}}" selected-node-include="{{_selectedNodeInclude}}" highlighted-node="{{_highlightedNode}}" color-by="[[colorBy]]" color-by-params="[[colorByParams]]"></tf-graph-info> </div> + <div class="context-menu"></div> </div> </template> </dom-module> @@ -9825,14 +10182,24 @@ Polymer({ * for the progress bar and the displayed message. */ progress: Object, + colorBy: String, colorByParams: { type: Object, notify: true, }, // Private API: Data routing between child components. _selectedNode: String, + // The enum value of the include property of the selected node. + _selectedNodeInclude: Number, _highlightedNode: String, + _renderHierarchy: Object, }, + listeners: { + 'node-toggle-extract': '_nodeToggleExtract' + }, + observers: [ + '_updateNodeInclude(_selectedNode)' + ], /** True if the progress is not complete yet (< 100 %). */ _isNotComplete: function(progress) { return progress.value < 100; @@ -9846,10 +10213,18 @@ Polymer({ result += ' loading'; } return result; + }, + _updateNodeInclude: function(nodeName) { + var node = this.graphHierarchy.node(nodeName); + this.set("_selectedNodeInclude", + node ? node.include : tf.graph.InclusionType.UNSPECIFIED); + }, + _nodeToggleExtract: function() { + this._updateNodeInclude(this._selectedNode); } }); </script> -<dom-module id="tf-graph-controls" assetpath="../components/tf-graph/"> +<dom-module id="tf-graph-controls" assetpath="../tf-graph/"> <template> <style> :host { @@ -9899,7 +10274,7 @@ table td { } .allcontrols { - padding: 10px; + padding: 30px; } .legend-holder { @@ -9908,10 +10283,6 @@ table td { padding-bottom: 10px; } -#fit { - color: var(--paper-orange-500); -} - paper-radio-button { padding: 5px; } @@ -9979,7 +10350,7 @@ svg.icon { padding: 0 0 0 55px; } -.fit-button-text { +.button-text { text-transform: none; padding: 8px 18px 0 18px; font-size: 14px @@ -9992,10 +10363,11 @@ svg.icon { margin-top: 4px; } -.fit-button { +.iconbutton { padding: 2px; width: 30px; height: 30px; + color: var(--paper-orange-500); } .hidden-input { @@ -10011,12 +10383,20 @@ svg.icon { </style> <div class="allcontrols"> <div class="control-holder"> - <paper-icon-button id="fit" icon="aspect-ratio" class="fit-button" on-click="fit" alt="Fit to screen"> + <paper-icon-button icon="aspect-ratio" class="iconbutton" on-click="fit" alt="Fit to screen"> </paper-icon-button> - <paper-button class="fit-button-text" on-click="fit">Fit to screen + <paper-button class="button-text" on-click="fit">Fit to screen </paper-button> </div> <div class="control-holder"> + <paper-icon-button icon="file-download" class="iconbutton" on-click="download" alt="Download PNG"> + </paper-icon-button> + <paper-button class="button-text" on-click="download">Download PNG + </paper-button> + <a href="#" id="graphdownload" class="title" download="graph.png"> + </a> + </div> + <div class="control-holder"> <div class="title">Run</div> <paper-dropdown-menu no-label-float="" no-animations="" noink="" class="run-dropdown"> <paper-menu id="select" class="dropdown-content" selected="{{selectedDataset}}"> @@ -10137,8 +10517,8 @@ svg.icon { </tr> <tr> <td> - <svg class="image-icon"> - <image id="summary-icon" width="24" height="24" x="0" y="0" class="image-icon"></image> + <svg class="image-icon" viewBox="0 0 12 12" width="24" height="24"> + <use x="0" y="0" class="image-icon" xlink:href="#summary-icon"></use> </svg> </td> <td>Summary</td> @@ -10183,11 +10563,6 @@ svg.icon { (function() { // Private scope. Polymer({ is: 'tf-graph-controls', - ready: function() { - // Set the url to download the summary icon. - d3.select(this.$['summary-icon']) - .attr('xlink:href', this.resolveUrl('../../lib/svg/summary-icon.svg')); - }, properties: { // Public API. hasStats: { @@ -10207,6 +10582,7 @@ Polymer({ type: Number, notify: true, value: 0, + observer: '_selectedDatasetChanged' }, selectedFile: { type: Object, @@ -10258,17 +10634,44 @@ Polymer({ endColor: params.endColor }; }, + download: function() { + this.$.graphdownload.click(); + }, _updateFileInput: function(e) { + var file = e.target.files[0]; + if (!file) { + return; + } + this._setDownloadFilename(file.name); this.set('selectedFile', e); }, _datasetsChanged: function(newDatasets, oldDatasets) { if (oldDatasets != null || this.selected == null) { // Select the first dataset by default. this.set('selectedDataset', 0); + this._setDownloadFilename(this.datasets[this.selectedDataset].path); + } + }, + _selectedDatasetChanged: function(newDataset, oldDataset) { + if (this.datasets) { + this._setDownloadFilename(this.datasets[newDataset].path); } }, _getFile: function() { this.$.file.click(); + }, + _setDownloadFilename: function(graphPath) { + // Strip off everything before the last "/" and strip off the file + // extension in order to get the name of the PNG for the graph. + var dotIndex = graphPath.lastIndexOf('.'); + if (dotIndex) { + graphPath = graphPath.substring(0, dotIndex); + } + var slashIndex = graphPath.lastIndexOf('/'); + if (slashIndex) { + graphPath = graphPath.substring(slashIndex + 1); + } + this.$.graphdownload.setAttribute('download', graphPath + '.png'); } }); @@ -10311,7 +10714,7 @@ function convertToHumanReadable(value, units, unitIndex) { })(); // Closing private scope. </script> </dom-module> -<dom-module id="tf-graph-dashboard" assetpath="../components/tf-graph-dashboard/"> +<dom-module id="tf-graph-dashboard" assetpath="../tf-graph-dashboard/"> <template> <div id="plumbing"> <tf-url-generator out-runs-url="{{_runsUrl}}" out-graph-url-generator="{{_graphUrlGen}}" id="urlGenerator"></tf-url-generator> @@ -10348,6 +10751,7 @@ function convertToHumanReadable(value, units, unitIndex) { } .center { + position: relative; height: 100%; } @@ -10386,15 +10790,13 @@ Polymer({ <paper-header-panel> <paper-toolbar id="toolbar"> <div id="toolbar-content"> - <div class="toolbar-title"> - TensorBoard - </div> - <div class="right-buttons"> - <paper-button class="link-button" on-click="chooseEvents" active$="[[eventDashboard(mode)]]" noink="">Events</paper-button> - <paper-button class="link-button" on-click="chooseImages" active$="[[imageDashboard(mode)]]" noink="">Images</paper-button> - <paper-button class="link-button" on-click="chooseGraphs" active$="[[graphDashboard(mode)]]" noink="">Graph</paper-button> - <paper-button class="link-button" on-click="chooseHistograms" active$="[[histogramDashboard(mode)]]" noink="">Histograms</paper-button> - </div> + <div class="toolbar-title">TensorBoard</div> + <paper-tabs selected="0" noink="" class="tabs"> + <paper-tab on-click="chooseEvents">Events</paper-tab> + <paper-tab on-click="chooseImages">Images</paper-tab> + <paper-tab on-click="chooseGraphs">Graph</paper-tab> + <paper-tab on-click="chooseHistograms">Histograms</paper-tab> + </paper-tabs> </div> </paper-toolbar> <div id="content" class="fit"> @@ -10416,33 +10818,48 @@ Polymer({ </div> </paper-header-panel> <style> + :host { + height: 100%; + display: block; + background-color: var(--paper-grey-100); + } + #toolbar { background-color: var(--tb-orange-strong); - background-image: radial-gradient(ellipse, var(--tb-orange-weak), var(--tb-orange-strong)); + -webkit-font-smoothing: antialiased; } + #toolbar-content { width: 100%; + height: 100%; display: flex; flex-direction: row; justify-content: space-between; align-items: center; } + .toolbar-title { - font-size: 30px; + font-size: 20px; + margin-left: 10px; + text-rendering: optimizeLegibility; + letter-spacing: -0.025em; + font-weight: 500; } + #content { height: 100%; } - .link-button { - height: 30px; - } - [active] { - font-weight: bold; - } - :host { + + .tabs { + width: 400px; + text-transform: uppercase; height: 100%; - display: block; } + + paper-tabs { + --paper-tabs-selection-bar-color: white; + } + </style> </template> <script> diff --git a/tensorflow/tensorboard/gulpfile.js b/tensorflow/tensorboard/gulpfile.js index 61387e730b..867fc2f5ef 100644 --- a/tensorflow/tensorboard/gulpfile.js +++ b/tensorflow/tensorboard/gulpfile.js @@ -27,6 +27,7 @@ var gulpFilter = require('gulp-filter'); var vulcanize = require('gulp-vulcanize'); var minimist = require('minimist'); var replace = require('gulp-replace'); +var header = require('gulp-header'); var fs = require('fs'); var path = require('path'); var options = minimist(process.argv.slice(2), { @@ -162,16 +163,9 @@ gulp.task('vulcanize', ['compile.all', 'tslint-strict'], function() { // fixes https://github.com/Polymer/vulcanize/issues/273 .pipe(replace(linkRegex, '')) .pipe(replace(scriptRegex, '')) - .pipe(gulp.dest('dist')); + .pipe(header('// AUTOGENERATED FILE - DO NOT MODIFY \n')) + .pipe(gulp.dest('../opensource_only/tensorboard')); - // Vulcanize TensorBoard with all external libraries inlined. - gulp.src('components/index.html') - .pipe(vulcanize({ - inlineScripts: true, - inlineCss: true, - stripComments: true, - })) - .pipe(gulp.dest('dist')); gulp.src('app/tf-tensorboard-demo.html') .pipe(vulcanize({ diff --git a/tensorflow/tensorboard/package.json b/tensorflow/tensorboard/package.json index 3b87007522..492fd36053 100644 --- a/tensorflow/tensorboard/package.json +++ b/tensorflow/tensorboard/package.json @@ -12,19 +12,21 @@ "license": "Apache-2.0", "devDependencies": { "gulp": "~3.9.0", - "gulp-typescript": "~2.8.0", - "tsd": "~0.6.3", - "typescript": "~1.5.3", - "gulp-cli": "~0.3.0", - "gulp-util": "~3.0.6", - "gulp-tslint": "~3.1.1-beta", - "gulp-server-livereload": "~1.4.0", + "gulp-cli": "^1.1.0", + "gulp-filter": "~3.0.1", + "gulp-replace": "~0.5.4", + "gulp-server-livereload": "~1.5.4", + "gulp-tslint": "~4.2.2", + "gulp-typescript": "~2.10.0", + "gulp-util": "~3.0.7", + "gulp-vulcanize": "~6.1.0", "merge2": "~0.3.6", - "gulp-filter": "~3.0.0", - "vulcanize": "~1.14.0", - "gulp-vulcanize": "~6.0.1", "minimist": "~1.2.0", - "gulp-replace": "~0.5.4", - "web-component-tester": "~3.3.30" + "tsd": "^0.6.5", + "tslint": "^3.2.1", + "typescript": "^1.6.2", + "vulcanize": "^1.14.0", + "web-component-tester": "~3.4.2", + "gulp-header": "~1.7.1" } } diff --git a/tensorflow/tensorboard/tensorboard_handler.py b/tensorflow/tensorboard/tensorboard_handler.py index 8c4a9e7689..ae190fef7b 100644 --- a/tensorflow/tensorboard/tensorboard_handler.py +++ b/tensorflow/tensorboard/tensorboard_handler.py @@ -44,6 +44,8 @@ from tensorflow.python.platform import resource_loader from tensorflow.python.summary import event_accumulator from tensorflow.tensorboard import float_wrapper + +DATA_PREFIX = '/data' RUNS_ROUTE = '/runs' SCALARS_ROUTE = '/' + event_accumulator.SCALARS IMAGES_ROUTE = '/' + event_accumulator.IMAGES @@ -51,6 +53,7 @@ HISTOGRAMS_ROUTE = '/' + event_accumulator.HISTOGRAMS COMPRESSED_HISTOGRAMS_ROUTE = '/' + event_accumulator.COMPRESSED_HISTOGRAMS INDIVIDUAL_IMAGE_ROUTE = '/individualImage' GRAPH_ROUTE = '/' + event_accumulator.GRAPH +TAB_ROUTES = ['', '/events', '/images', '/graphs', '/histograms'] _IMGHDR_TO_MIMETYPE = { 'bmp': 'image/bmp', @@ -373,32 +376,34 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler): if clean_path.endswith('/'): clean_path = clean_path[:-1] - handlers = { - SCALARS_ROUTE: self._serve_scalars, - GRAPH_ROUTE: self._serve_graph, - HISTOGRAMS_ROUTE: self._serve_histograms, - COMPRESSED_HISTOGRAMS_ROUTE: self._serve_compressed_histograms, - IMAGES_ROUTE: self._serve_images, - INDIVIDUAL_IMAGE_ROUTE: self._serve_image, - RUNS_ROUTE: self._serve_runs, - '': self._serve_index, + data_handlers = { + DATA_PREFIX + SCALARS_ROUTE: self._serve_scalars, + DATA_PREFIX + GRAPH_ROUTE: self._serve_graph, + DATA_PREFIX + HISTOGRAMS_ROUTE: self._serve_histograms, + DATA_PREFIX + COMPRESSED_HISTOGRAMS_ROUTE: + self._serve_compressed_histograms, + DATA_PREFIX + IMAGES_ROUTE: self._serve_images, + DATA_PREFIX + INDIVIDUAL_IMAGE_ROUTE: self._serve_image, + DATA_PREFIX + RUNS_ROUTE: self._serve_runs, '/app.js': self._serve_js } - if clean_path in handlers: - query_params = urlparse.parse_qs(parsed_url.query) - # parse_qs returns a list of values for each key; we're only interested in - # the first. - for key in query_params: - value_count = len(query_params[key]) - if value_count != 1: - self.send_error( - 400, - 'query parameter %s should have exactly one value, had %d' % - (key, value_count)) - return - - query_params[key] = query_params[key][0] - handlers[clean_path](query_params) + query_params = urlparse.parse_qs(parsed_url.query) + # parse_qs returns a list of values for each key; we're only interested in + # the first. + for key in query_params: + value_count = len(query_params[key]) + if value_count != 1: + self.send_error( + 400, + 'query parameter %s should have exactly one value, had %d' % + (key, value_count)) + return + query_params[key] = query_params[key][0] + + if clean_path in data_handlers: + data_handlers[clean_path](query_params) + elif clean_path in TAB_ROUTES: + self._serve_index(query_params) else: self._serve_static_file(clean_path) diff --git a/third_party/eigen3/Eigen/Cholesky b/third_party/eigen3/Eigen/Cholesky index 942240bd82..908764d2f7 100644 --- a/third_party/eigen3/Eigen/Cholesky +++ b/third_party/eigen3/Eigen/Cholesky @@ -1 +1 @@ -#include "external/eigen_archive/eigen-eigen-ce5a455b34c0/Eigen/Cholesky" +#include "external/eigen_archive/eigen-eigen-a0661a2bb165/Eigen/Cholesky" diff --git a/third_party/eigen3/Eigen/Core b/third_party/eigen3/Eigen/Core index e9896a5fba..c78b7c95ee 100644 --- a/third_party/eigen3/Eigen/Core +++ b/third_party/eigen3/Eigen/Core @@ -1 +1 @@ -#include "external/eigen_archive/eigen-eigen-ce5a455b34c0/Eigen/Core" +#include "external/eigen_archive/eigen-eigen-a0661a2bb165/Eigen/Core" diff --git a/third_party/eigen3/Eigen/Eigenvalues b/third_party/eigen3/Eigen/Eigenvalues index 5db8b147c6..235e34cd5f 100644 --- a/third_party/eigen3/Eigen/Eigenvalues +++ b/third_party/eigen3/Eigen/Eigenvalues @@ -1 +1 @@ -#include "external/eigen_archive/eigen-eigen-ce5a455b34c0/Eigen/Eigenvalues" +#include "external/eigen_archive/eigen-eigen-a0661a2bb165/Eigen/Eigenvalues" diff --git a/third_party/eigen3/Eigen/LU b/third_party/eigen3/Eigen/LU index 25e4ebf4f5..cdf52403a3 100644 --- a/third_party/eigen3/Eigen/LU +++ b/third_party/eigen3/Eigen/LU @@ -1 +1 @@ -#include "external/eigen_archive/eigen-eigen-ce5a455b34c0/Eigen/LU" +#include "external/eigen_archive/eigen-eigen-a0661a2bb165/Eigen/LU" diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor index 8f4bbd7ee9..72e6fa6663 100644 --- a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor +++ b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor @@ -1 +1 @@ -#include "external/eigen_archive/eigen-eigen-ce5a455b34c0/unsupported/Eigen/CXX11/Tensor" +#include "external/eigen_archive/eigen-eigen-a0661a2bb165/unsupported/Eigen/CXX11/Tensor" |