diff options
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 4 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api.h | 10 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/context.cc | 31 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/context.h | 8 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/execute.cc | 7 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/kernel_and_device.cc | 31 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/kernel_and_device.h | 3 | ||||
-rw-r--r-- | tensorflow/python/eager/backprop.py | 4 | ||||
-rw-r--r-- | tensorflow/python/eager/context.py | 6 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 2 | ||||
-rw-r--r-- | tensorflow/python/pywrap_tfe.i | 2 |
11 files changed, 93 insertions, 15 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index a0a44440c8..d7073d8e05 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -719,6 +719,10 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, } } // namespace +void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); } + +void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); } + namespace tensorflow { void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 25cf7adbc7..092af45731 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -380,6 +380,16 @@ TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status); +// Some TF ops need a step container to be set to limit the lifetime of some +// resources (mostly TensorArray and Stack, used in while loop gradients in +// graph mode). Calling this on a context tells it to start a step. +TF_CAPI_EXPORT extern void TFE_ContextStartStep(TFE_Context* ctx); + +// Ends a step. When there is no active step (that is, every started step has +// been ended) step containers will be cleared. Note: it is not safe to call +// TFE_ContextEndStep while ops which rely on the step container may be running. +TF_CAPI_EXPORT extern void TFE_ContextEndStep(TFE_Context* ctx); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 6ab2d1ebf1..e5fe87fc37 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/util/env_var.h" @@ -46,6 +47,7 @@ EagerContext::EagerContext(const SessionOptions& opts, local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {}, thread_pool_.get())), log_device_placement_(opts.config.log_device_placement()), + num_active_steps_(0), async_default_(async), env_(opts.env), use_send_tensor_rpc_(false) { @@ -194,6 +196,35 @@ Status EagerContext::FindDeviceByName(const string& name, Device** result) { return Status::OK(); } +void EagerContext::StartStep() { + mutex_lock ml(metadata_mu_); + num_active_steps_++; + if (step_container_ == nullptr) { + step_container_.reset( + new ScopedStepContainer(0, [this](const string& name) { + for (Device* device : devices_) { + device->resource_manager()->Cleanup(name).IgnoreError(); + } + })); + } +} + +void EagerContext::EndStep() { + mutex_lock ml(metadata_mu_); + num_active_steps_--; + if (num_active_steps_ == 0) { + step_container_.reset(); + } +} + +ScopedStepContainer* EagerContext::StepContainer() { + if (num_active_steps_.load() == 0) { + return nullptr; + } + mutex_lock ml(metadata_mu_); + return step_container_.get(); +} + Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) { if (remote_device_manager_ == nullptr) return Status::OK(); #ifndef __ANDROID__ diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index a0b612e6e5..3eea56b5e3 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -153,6 +153,10 @@ class EagerContext { void SetShouldStoreMetadata(bool value); RunMetadata* RunMetadataProto() { return &run_metadata_; } + void StartStep(); + void EndStep(); + ScopedStepContainer* 
StepContainer(); + FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; } #ifndef __ANDROID__ @@ -236,6 +240,10 @@ class EagerContext { // EagerExecutor for async execution. EagerExecutor executor_; + // Information related to step containers. + std::atomic<int> num_active_steps_; + std::unique_ptr<ScopedStepContainer> step_container_ GUARDED_BY(metadata_mu_); + // True if the default value for execution mode is async. Note that this value // can be overridden per thread based on `thread_local_async` overrides. const bool async_default_; diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 3837405e7f..51b770d035 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -653,7 +653,12 @@ Status EagerExecute(EagerContext* ctx, Device* device, // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats // for ops which are a part of functions. // TODO(agarwal): change Run to take vector of handles ? 
- TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats)); + ScopedStepContainer* container = ctx->StepContainer(); + if (container == nullptr) { + TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats)); + } else { + TF_RETURN_IF_ERROR(kernel->Run(container, &inputs, &outputs, maybe_stats)); + } if (maybe_stats != nullptr) { int64 nanos = Env::Default()->NowNanos(); maybe_stats->set_op_end_rel_micros(nanos / EnvTime::kMicrosToNanos - diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index dae5d1983f..3d61ff4dc2 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -60,12 +60,22 @@ Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, return s; } -Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors, - std::vector<Tensor>* output_tensors, +Status KernelAndDevice::Run(std::vector<Tensor>* inputs, + std::vector<Tensor>* outputs, NodeExecStats* stats) { - gtl::InlinedVector<TensorValue, 4> inputs; - for (Tensor& t : *input_tensors) { - inputs.push_back(TensorValue(&t)); + ScopedStepContainer step_container(0, [this](const string& name) { + device_->resource_manager()->Cleanup(name).IgnoreError(); + }); + return this->Run(&step_container, inputs, outputs, stats); +} + +Status KernelAndDevice::Run(ScopedStepContainer* step_container, + std::vector<Tensor>* inputs, + std::vector<Tensor>* outputs, + NodeExecStats* stats) { + gtl::InlinedVector<TensorValue, 4> input_vector; + for (Tensor& t : *inputs) { + input_vector.push_back(TensorValue(&t)); } std::vector<AllocatorAttributes> out_attrs(kernel_->num_outputs()); @@ -77,7 +87,7 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors, OpKernelContext::Params params; params.device = device_; params.frame_iter = FrameAndIter(0, 0); - params.inputs = &inputs; + params.inputs = &input_vector; 
params.op_kernel = kernel_.get(); params.resource_manager = device_->resource_manager(); params.output_attr_array = gtl::vector_as_array(&out_attrs); @@ -94,10 +104,7 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors, params.runner = runner_; } - ScopedStepContainer step_container(0, [this](const string& name) { - device_->resource_manager()->Cleanup(name).IgnoreError(); - }); - params.step_container = &step_container; + params.step_container = step_container; OpKernelContext context(&params); @@ -114,9 +121,9 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors, } if (!context.status().ok()) return context.status(); - output_tensors->clear(); + outputs->clear(); for (int i = 0; i < context.num_outputs(); ++i) { - output_tensors->push_back(Tensor(*context.mutable_output(i))); + outputs->push_back(Tensor(*context.mutable_output(i))); } if (stats != nullptr) { for (const auto& allocator_pair : context.wrapped_allocators()) { diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index c0b676b285..751cf687b2 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -70,6 +70,9 @@ class KernelAndDevice { Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs, NodeExecStats* stats); + Status Run(ScopedStepContainer* step_container, std::vector<Tensor>* inputs, + std::vector<Tensor>* outputs, NodeExecStats* stats); + const OpKernel* kernel() const { return kernel_.get(); } Device* device() const { return device_; } diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 5f60f62874..728b283695 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -705,6 +705,7 @@ class GradientTape(object): self._tape = None self._persistent = persistent self._recording = False + context.context().start_step() def 
__enter__(self): """Enters a context inside which operations are recorded on this tape.""" @@ -733,6 +734,9 @@ class GradientTape(object): tape.pop_tape(self._tape) self._recording = False + def __del__(self): + context.context().end_step() + def watch(self, tensor): """Ensures that `tensor` is being traced by this tape. diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index c79294895b..09223c86d4 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -603,6 +603,12 @@ class Context(object): """Returns a stack of context switches.""" return self._context_switches + def start_step(self): + pywrap_tensorflow.TFE_ContextStartStep(self._handle) + + def end_step(self): + pywrap_tensorflow.TFE_ContextEndStep(self._handle) + _context = None _context_lock = threading.Lock() diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 8084df4e8e..06b4e732a1 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -232,8 +232,6 @@ class FunctionTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testGraphLoopGradient(self): - if context.executing_eagerly(): - self.skipTest('TODO(apassos): support loops in defuns in eager') @function.defun def f(x): diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 1b69e0d06c..157f2341e0 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -63,6 +63,8 @@ limitations under the License. %rename("%s") TFE_DeleteContextOptions; %rename("%s") TFE_Py_TensorShapeSlice; %rename("%s") TFE_Py_TensorShapeOnDevice; +%rename("%s") TFE_ContextStartStep; +%rename("%s") TFE_ContextEndStep; %{ #include "tensorflow/python/eager/pywrap_tfe.h" |