 tensorflow/c/eager/c_api.cc                                |  4
 tensorflow/c/eager/c_api.h                                 | 10
 tensorflow/core/common_runtime/eager/context.cc            | 31
 tensorflow/core/common_runtime/eager/context.h             |  8
 tensorflow/core/common_runtime/eager/execute.cc            |  7
 tensorflow/core/common_runtime/eager/kernel_and_device.cc  | 31
 tensorflow/core/common_runtime/eager/kernel_and_device.h   |  3
 tensorflow/python/eager/backprop.py                        |  4
 tensorflow/python/eager/context.py                         |  6
 tensorflow/python/eager/function_test.py                   |  2
 tensorflow/python/pywrap_tfe.i                             |  2
11 files changed, 93 insertions, 15 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index a0a44440c8..d7073d8e05 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -719,6 +719,10 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
}
} // namespace
+void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); }
+
+void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); }
+
namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 25cf7adbc7..092af45731 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -380,6 +380,16 @@ TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx,
TF_Buffer* buf,
TF_Status* status);
+// Some TF ops need a step container to be set to limit the lifetime of some
+// resources (mostly TensorArray and Stack, used in while loop gradients in
+// graph mode). Calling this on a context tells it to start a step.
+TF_CAPI_EXPORT extern void TFE_ContextStartStep(TFE_Context* ctx);
+
+// Ends a step. When there is no active step (that is, every started step has
+// been ended) step containers will be cleared. Note: it is not safe to call
+// TFE_ContextEndStep while ops which rely on the step container may be running.
+TF_CAPI_EXPORT extern void TFE_ContextEndStep(TFE_Context* ctx);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
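
The header comment above gives the intended usage pattern: bracket a group of eager op executions between TFE_ContextStartStep and TFE_ContextEndStep so that step-scoped resources (the TensorArrays and Stacks created by while-loop gradients) share one step container. A minimal sketch, assuming a valid TFE_Context* obtained elsewhere; the helper name RunOneStep is hypothetical and only the two functions added by this change are called:

void RunOneStep(TFE_Context* ctx) {
  TFE_ContextStartStep(ctx);  // open (or join) the context's step container
  // ... create and run ops with TFE_NewOp/TFE_Execute; any TensorArray or
  // Stack they create lives in the shared step container ...
  TFE_ContextEndStep(ctx);    // ending the last active step clears the container
}
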
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 6ab2d1ebf1..e5fe87fc37 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/util/env_var.h"
@@ -46,6 +47,7 @@ EagerContext::EagerContext(const SessionOptions& opts,
local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION,
&func_lib_def_, {}, thread_pool_.get())),
log_device_placement_(opts.config.log_device_placement()),
+ num_active_steps_(0),
async_default_(async),
env_(opts.env),
use_send_tensor_rpc_(false) {
@@ -194,6 +196,35 @@ Status EagerContext::FindDeviceByName(const string& name, Device** result) {
return Status::OK();
}
+void EagerContext::StartStep() {
+ mutex_lock ml(metadata_mu_);
+ num_active_steps_++;
+ if (step_container_ == nullptr) {
+ step_container_.reset(
+ new ScopedStepContainer(0, [this](const string& name) {
+ for (Device* device : devices_) {
+ device->resource_manager()->Cleanup(name).IgnoreError();
+ }
+ }));
+ }
+}
+
+void EagerContext::EndStep() {
+ mutex_lock ml(metadata_mu_);
+ num_active_steps_--;
+ if (num_active_steps_ == 0) {
+ step_container_.reset();
+ }
+}
+
+ScopedStepContainer* EagerContext::StepContainer() {
+ if (num_active_steps_.load() == 0) {
+ return nullptr;
+ }
+ mutex_lock ml(metadata_mu_);
+ return step_container_.get();
+}
+
Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
if (remote_device_manager_ == nullptr) return Status::OK();
#ifndef __ANDROID__
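
StartStep and EndStep effectively reference-count a single shared container: nested StartStep calls reuse the container created by the outermost one, and the container is only reset (running the cleanup callback on every device's resource manager) once the count returns to zero. A short sketch of the nesting behaviour through the C API, assuming a valid ctx:

TFE_ContextStartStep(ctx);  // num_active_steps_: 0 -> 1, container created
TFE_ContextStartStep(ctx);  // num_active_steps_: 1 -> 2, same container reused
TFE_ContextEndStep(ctx);    // num_active_steps_: 2 -> 1, container kept alive
TFE_ContextEndStep(ctx);    // num_active_steps_: 1 -> 0, container reset and
                            // per-step resources cleaned up on all devices
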
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index a0b612e6e5..3eea56b5e3 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -153,6 +153,10 @@ class EagerContext {
void SetShouldStoreMetadata(bool value);
RunMetadata* RunMetadataProto() { return &run_metadata_; }
+ void StartStep();
+ void EndStep();
+ ScopedStepContainer* StepContainer();
+
FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }
#ifndef __ANDROID__
@@ -236,6 +240,10 @@ class EagerContext {
// EagerExecutor for async execution.
EagerExecutor executor_;
+ // Information related to step containers.
+ std::atomic<int> num_active_steps_;
+ std::unique_ptr<ScopedStepContainer> step_container_ GUARDED_BY(metadata_mu_);
+
// True if the default value for execution mode is async. Note that this value
// can be overridden per thread based on `thread_local_async` overrides.
const bool async_default_;
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 3837405e7f..51b770d035 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -653,7 +653,12 @@ Status EagerExecute(EagerContext* ctx, Device* device,
// FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats
// for ops which are a part of functions.
// TODO(agarwal): change Run to take vector of handles ?
- TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats));
+ ScopedStepContainer* container = ctx->StepContainer();
+ if (container == nullptr) {
+ TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats));
+ } else {
+ TF_RETURN_IF_ERROR(kernel->Run(container, &inputs, &outputs, maybe_stats));
+ }
if (maybe_stats != nullptr) {
int64 nanos = Env::Default()->NowNanos();
maybe_stats->set_op_end_rel_micros(nanos / EnvTime::kMicrosToNanos -
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index dae5d1983f..3d61ff4dc2 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -60,12 +60,22 @@ Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
return s;
}
-Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
- std::vector<Tensor>* output_tensors,
+Status KernelAndDevice::Run(std::vector<Tensor>* inputs,
+ std::vector<Tensor>* outputs,
NodeExecStats* stats) {
- gtl::InlinedVector<TensorValue, 4> inputs;
- for (Tensor& t : *input_tensors) {
- inputs.push_back(TensorValue(&t));
+ ScopedStepContainer step_container(0, [this](const string& name) {
+ device_->resource_manager()->Cleanup(name).IgnoreError();
+ });
+ return this->Run(&step_container, inputs, outputs, stats);
+}
+
+Status KernelAndDevice::Run(ScopedStepContainer* step_container,
+ std::vector<Tensor>* inputs,
+ std::vector<Tensor>* outputs,
+ NodeExecStats* stats) {
+ gtl::InlinedVector<TensorValue, 4> input_vector;
+ for (Tensor& t : *inputs) {
+ input_vector.push_back(TensorValue(&t));
}
std::vector<AllocatorAttributes> out_attrs(kernel_->num_outputs());
@@ -77,7 +87,7 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
OpKernelContext::Params params;
params.device = device_;
params.frame_iter = FrameAndIter(0, 0);
- params.inputs = &inputs;
+ params.inputs = &input_vector;
params.op_kernel = kernel_.get();
params.resource_manager = device_->resource_manager();
params.output_attr_array = gtl::vector_as_array(&out_attrs);
@@ -94,10 +104,7 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
params.runner = runner_;
}
- ScopedStepContainer step_container(0, [this](const string& name) {
- device_->resource_manager()->Cleanup(name).IgnoreError();
- });
- params.step_container = &step_container;
+ params.step_container = step_container;
OpKernelContext context(&params);
@@ -114,9 +121,9 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
}
if (!context.status().ok()) return context.status();
- output_tensors->clear();
+ outputs->clear();
for (int i = 0; i < context.num_outputs(); ++i) {
- output_tensors->push_back(Tensor(*context.mutable_output(i)));
+ outputs->push_back(Tensor(*context.mutable_output(i)));
}
if (stats != nullptr) {
for (const auto& allocator_pair : context.wrapped_allocators()) {
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index c0b676b285..751cf687b2 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -70,6 +70,9 @@ class KernelAndDevice {
Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs,
NodeExecStats* stats);
+ Status Run(ScopedStepContainer* step_container, std::vector<Tensor>* inputs,
+ std::vector<Tensor>* outputs, NodeExecStats* stats);
+
const OpKernel* kernel() const { return kernel_.get(); }
Device* device() const { return device_; }
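
The original Run keeps its old behaviour by wrapping the new overload in a per-call ScopedStepContainer, so resources created by a single kernel invocation are cleaned up as soon as it returns; the new overload lets a caller such as EagerExecute keep one container alive across many invocations. A hypothetical sketch of the caller-owned pattern, assuming kernel (a KernelAndDevice*), device (a Device*), and the inputs/outputs vectors are already set up elsewhere:

// One container shared across several kernel runs, mirroring what
// EagerContext::StartStep sets up; cleanup runs when `container` is destroyed.
tensorflow::ScopedStepContainer container(
    /*step_id=*/0, [device](const tensorflow::string& name) {
      device->resource_manager()->Cleanup(name).IgnoreError();
    });
tensorflow::Status s1 = kernel->Run(&container, &inputs, &outputs, /*stats=*/nullptr);
tensorflow::Status s2 = kernel->Run(&container, &inputs, &outputs, /*stats=*/nullptr);
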
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 5f60f62874..728b283695 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -705,6 +705,7 @@ class GradientTape(object):
self._tape = None
self._persistent = persistent
self._recording = False
+ context.context().start_step()
def __enter__(self):
"""Enters a context inside which operations are recorded on this tape."""
@@ -733,6 +734,9 @@ class GradientTape(object):
tape.pop_tape(self._tape)
self._recording = False
+ def __del__(self):
+ context.context().end_step()
+
def watch(self, tensor):
"""Ensures that `tensor` is being traced by this tape.
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index c79294895b..09223c86d4 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -603,6 +603,12 @@ class Context(object):
"""Returns a stack of context switches."""
return self._context_switches
+ def start_step(self):
+ pywrap_tensorflow.TFE_ContextStartStep(self._handle)
+
+ def end_step(self):
+ pywrap_tensorflow.TFE_ContextEndStep(self._handle)
+
_context = None
_context_lock = threading.Lock()
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 8084df4e8e..06b4e732a1 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -232,8 +232,6 @@ class FunctionTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testGraphLoopGradient(self):
- if context.executing_eagerly():
- self.skipTest('TODO(apassos): support loops in defuns in eager')
@function.defun
def f(x):
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 1b69e0d06c..157f2341e0 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -63,6 +63,8 @@ limitations under the License.
%rename("%s") TFE_DeleteContextOptions;
%rename("%s") TFE_Py_TensorShapeSlice;
%rename("%s") TFE_Py_TensorShapeOnDevice;
+%rename("%s") TFE_ContextStartStep;
+%rename("%s") TFE_ContextEndStep;
%{
#include "tensorflow/python/eager/pywrap_tfe.h"